| """Optimizer helpers: split learning rates for AuralFuser train_* vs VGG backbone.""" |
| import torch |
| import copy |
| from typing import List, Dict, Set, Any |
|
|
|
|
def manipulate_params(cfg, model):
    """Build one optimizer parameter group per trainable parameter of ``model``.

    Every parameter with ``requires_grad=True`` is visited once (shared
    parameters are de-duplicated) and routed by the name of its owning module
    and its own name:

    * names containing ``"vgg"`` form the fine-tuning groups; their learning
      rate is ``cfg.lr * 0.1`` and their weight decay is ``cfg.weight_decay``;
    * names containing ``"train"``, or module names starting with one of the
      ``train_prefixes`` below, form the training groups at ``cfg.lr``.  In
      these groups weight decay is forced to ``0.0`` for parameters named
      ``relative_position_bias_table`` / ``pos_embed`` and for parameters
      owned by a normalization or ``torch.nn.Embedding`` module.

    Args:
        cfg: Object exposing ``lr`` and ``weight_decay`` attributes.
        model: ``torch.nn.Module`` whose parameters are grouped.

    Returns:
        A list of dicts with keys ``"params"`` (single-element list),
        ``"name"`` (single-element list holding the owning module's name),
        ``"lr"`` and ``"weight_decay"``.  Training groups come first, followed
        by the fine-tuning groups.

    Raises:
        NotImplementedError: If a trainable parameter matches neither rule.
        AssertionError: If the number of groups differs from the number of
            trainable parameters reported by ``model.parameters()``.
    """
    defaults: Dict[str, Any] = {"lr": cfg.lr, "weight_decay": cfg.weight_decay}

    norm_module_types = (
        torch.nn.BatchNorm1d,
        torch.nn.BatchNorm2d,
        torch.nn.BatchNorm3d,
        torch.nn.SyncBatchNorm,
        torch.nn.GroupNorm,
        torch.nn.InstanceNorm1d,
        torch.nn.InstanceNorm2d,
        torch.nn.InstanceNorm3d,
        torch.nn.LayerNorm,
        torch.nn.LocalResponseNorm,
    )

    # Module-name prefixes that mark a sub-module as part of the trained
    # (non-VGG) portion even when its name does not contain "train".
    train_prefixes = (
        "patch_embeds",
        "f_blocks",
        "a_blocks",
        "fusion_modules",
        "smooth_convs",
        "train_proj_v1",
        "train_proj_a1",
    )

    params_training: List[Dict[str, Any]] = []
    params_finetuning: List[Dict[str, Any]] = []
    memo: Set[torch.nn.parameter.Parameter] = set()

    for module_name, module in model.named_modules():
        # Whether the owning module disables weight decay depends only on the
        # module, so evaluate it once rather than once per parameter.
        no_decay_module = isinstance(module, norm_module_types) or isinstance(
            module, torch.nn.Embedding
        )

        for module_param_name, value in module.named_parameters(recurse=False):
            # Skip frozen parameters and parameters already seen through
            # another module (weight sharing).
            if not value.requires_grad or value in memo:
                continue
            memo.add(value)

            hyperparams = copy.copy(defaults)

            if "vgg" in module_name or "vgg" in module_param_name:
                hyperparams["lr"] *= 0.1
                params_finetuning.append(
                    {"params": [value], "name": [module_name], **hyperparams}
                )
                continue

            is_train_param = (
                "train" in module_name
                or "train" in module_param_name
                or module_name.startswith(train_prefixes)
            )
            if not is_train_param:
                # Fail loudly and say which parameter could not be routed.
                raise NotImplementedError(
                    f"undefined layer type: parameter '{module_param_name}' of "
                    f"module '{module_name}' matches neither the 'vgg' rule "
                    f"nor the training rules."
                )

            if (
                no_decay_module
                or "relative_position_bias_table" in module_param_name
                or "pos_embed" in module_param_name
            ):
                hyperparams["weight_decay"] = 0.0
            params_training.append(
                {"params": [value], "name": [module_name], **hyperparams}
            )

    final_list = params_training + params_finetuning
    trainable_count = sum(1 for p in model.parameters() if p.requires_grad)
    assert trainable_count == len(final_list), 'checksum confirmed not pass.'
    return final_list
|
|
|
|
def group_weight(weight_group, module, weight_decay_value, lr):
    """Append a decay and a no-decay parameter group for ``module``.

    Walks ``module`` and all of its sub-modules and sorts their parameters:

    * weights of ``Linear`` / ``Conv*`` / ``ConvTranspose*`` layers go to the
      decay group, which uses ``weight_decay_value``;
    * biases of those layers, the affine parameters of normalization layers,
      and ``Embedding`` weights go to the no-decay group (weight decay 0).

    Containers and layers that own no parameters directly are skipped.  A
    parameter reachable through several layers is collected only once.

    Args:
        weight_group: List that receives the two new group dicts; it is
            mutated in place.
        module: ``torch.nn.Module`` whose parameters are grouped.
        weight_decay_value: Weight decay assigned to the decay group.
        lr: Learning rate assigned to both groups.

    Returns:
        The same ``weight_group`` list, with the decay group and then the
        no-decay group appended.

    Raises:
        NotImplementedError: If a module of an unhandled type directly owns
            parameters.
        AssertionError: If the collected parameters do not account for every
            parameter in ``module.parameters()``.
    """
    decay_module_types = (
        torch.nn.Linear,
        torch.nn.Conv1d,
        torch.nn.Conv2d,
        torch.nn.Conv3d,
        torch.nn.ConvTranspose1d,
        torch.nn.ConvTranspose2d,
        torch.nn.ConvTranspose3d,
    )
    norm_module_types = (
        torch.nn.BatchNorm1d,
        torch.nn.BatchNorm2d,
        torch.nn.BatchNorm3d,
        torch.nn.SyncBatchNorm,
        torch.nn.GroupNorm,
        torch.nn.InstanceNorm1d,
        torch.nn.InstanceNorm2d,
        torch.nn.InstanceNorm3d,
        torch.nn.LayerNorm,
        torch.nn.LocalResponseNorm,
    )

    group_decay = []
    group_no_decay = []
    seen_ids = set()

    def _collect(bucket, param):
        # Ignore absent parameters (e.g. bias=False, affine=False) and
        # parameters already collected through another layer.
        if param is None or id(param) in seen_ids:
            return
        seen_ids.add(id(param))
        bucket.append(param)

    for m in module.modules():
        if isinstance(m, decay_module_types):
            _collect(group_decay, m.weight)
            _collect(group_no_decay, m.bias)
        elif isinstance(m, norm_module_types):
            # LocalResponseNorm defines neither attribute, hence getattr.
            _collect(group_no_decay, getattr(m, "weight", None))
            _collect(group_no_decay, getattr(m, "bias", None))
        elif isinstance(m, torch.nn.Embedding):
            _collect(group_no_decay, m.weight)
        elif next(m.parameters(recurse=False), None) is not None:
            # Only a module that directly owns parameters is a problem;
            # containers and parameter-free layers fall through silently.
            raise NotImplementedError(
                f"undefined layer type find: {type(m).__name__}"
            )

    assert len(list(module.parameters())) == len(group_decay) + len(
        group_no_decay), 'not every parameter was assigned to a group.'
    weight_group.append(dict(params=group_decay, weight_decay=weight_decay_value, lr=lr))
    weight_group.append(dict(params=group_no_decay, weight_decay=.0, lr=lr))
    return weight_group