| import copy |
| import torch |
|
|
def apply_overrides(params, overrides):
    """Return a deep copy of ``params`` with ``overrides`` applied on top.

    Only keys that already exist in ``params`` may be overridden, so a typo in
    an override name fails loudly instead of silently adding an unused key.

    Args:
        params: Mapping of parameter name -> default value. Not modified.
        overrides: Mapping of parameter name -> replacement value, or None
            (treated the same as an empty mapping).

    Returns:
        A new object: a deep copy of ``params`` updated with ``overrides``.
        The override values themselves are inserted as-is (not copied).

    Raises:
        ValueError: if any key in ``overrides`` is not present in ``params``.
            The message lists every unknown name, not just the first.
    """
    params = copy.deepcopy(params)
    if not overrides:
        return params

    # Validate everything before applying anything, so the error reports
    # all bad names at once.
    unknown = [name for name in overrides if name not in params]
    if unknown:
        names = ', '.join(str(name) for name in unknown)
        raise ValueError(f'override failed: no parameter named {names}')

    for name in overrides:
        params[name] = overrides[name]
    return params
|
|
def get_default_params_train(overrides=None):
    """Build the default training configuration, with optional overrides.

    Args:
        overrides: Optional mapping of parameter name -> value. Every key must
            already be one of the defaults set below. ``None`` (the default)
            or an empty mapping leaves the defaults unchanged.

    Returns:
        A new dict of training parameters.

    Raises:
        ValueError: (from ``apply_overrides``) if ``overrides`` names a
            parameter that has no default here.
    """
    # A None default avoids the shared-mutable-default pitfall of `={}`.
    if overrides is None:
        overrides = {}

    params = {}

    # --- misc ---
    # NOTE(review): unlike get_default_params_eval, this is hard-coded to
    # 'cuda' with no CPU fallback; override 'device' on CPU-only machines.
    params['device'] = 'cuda'
    params['save_base'] = './experiments/'
    params['experiment_name'] = 'demo'
    params['timestamp'] = False

    # --- data ---
    params['species_set'] = 'all'
    params['hard_cap_seed'] = 9472
    params['hard_cap_num_per_class'] = -1
    params['aux_species_seed'] = 8099
    params['num_aux_species'] = 0

    # --- model ---
    params['model'] = 'ResidualFCNet'
    params['num_filts'] = 256
    params['input_enc'] = 'sin_cos'
    params['depth'] = 4

    # --- loss ---
    params['loss'] = 'an_full'
    params['pos_weight'] = 2048

    # --- optimization ---
    params['batch_size'] = 2048
    params['lr'] = 0.0005
    params['lr_decay'] = 0.98
    params['num_epochs'] = 10

    # --- saving ---
    params['log_frequency'] = 512

    params = apply_overrides(params, overrides)

    return params
|
|
def get_default_params_eval(overrides=None):
    """Build the default evaluation configuration, with optional overrides.

    Args:
        overrides: Optional mapping of parameter name -> value. Every key must
            already be one of the defaults set below. ``None`` (the default)
            or an empty mapping leaves the defaults unchanged.

    Returns:
        A new dict of evaluation parameters. ``params['device']`` is a
        ``torch.device`` chosen at call time: CUDA if available, else CPU.

    Raises:
        ValueError: (from ``apply_overrides``) if ``overrides`` names a
            parameter that has no default here.
    """
    # A None default avoids the shared-mutable-default pitfall of `={}`.
    if overrides is None:
        overrides = {}

    params = {}

    # --- misc ---
    params['device'] = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    params['seed'] = 2022
    params['exp_base'] = './experiments'
    params['ckp_name'] = 'model.pt'
    params['eval_type'] = 'snt'
    params['experiment_name'] = 'demo'

    # --- geo prior ---
    params['batch_size'] = 2048

    # --- geo feature ---
    params['cell_size'] = 25

    params = apply_overrides(params, overrides)

    return params
|
|