File size: 4,339 Bytes
c6dfc69 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 | """Optimizer helpers: split learning rates for AuralFuser train_* vs VGG backbone."""
import torch
import copy
from typing import List, Dict, Set, Any
def manipulate_params(cfg, model, backbone_lr_scale: float = 0.1):
    """Build one optimizer param group per trainable parameter of ``model``.

    Parameters are routed by the name of their owning module and their own
    (non-recursive) parameter name:

    * ``'vgg'`` in either name -> backbone ("fine-tuning") group with
      ``lr = cfg.lr * backbone_lr_scale`` and ``weight_decay = cfg.weight_decay``.
    * ``'train'`` in either name, or module path starting with one of the
      AuralFuser prefixes below -> training group with ``lr = cfg.lr``.
      Weight decay is set to 0.0 for normalization layers, ``nn.Embedding``
      and parameters whose name contains ``relative_position_bias_table`` or
      ``pos_embed``; otherwise ``cfg.weight_decay``.

    Args:
        cfg: Object exposing ``lr`` and ``weight_decay`` attributes.
        model: ``torch.nn.Module`` whose parameters are grouped.
        backbone_lr_scale: Multiplier applied to ``cfg.lr`` for the VGG
            backbone parameters (default 0.1, the previously hard-coded value).

    Returns:
        A list of dicts with keys ``params`` (single-element list),
        ``name`` (single-element list holding the module path), ``lr`` and
        ``weight_decay``. All training groups come first, followed by all
        backbone groups. Frozen parameters (``requires_grad=False``) are
        omitted.

    Raises:
        NotImplementedError: If a trainable parameter matches neither the
            backbone rule nor the training rule.
    """
    defaults: Dict[str, Any] = {"lr": cfg.lr, "weight_decay": cfg.weight_decay}
    norm_module_types = (
        torch.nn.BatchNorm1d,
        torch.nn.BatchNorm2d,
        torch.nn.BatchNorm3d,
        torch.nn.SyncBatchNorm,
        torch.nn.GroupNorm,
        torch.nn.InstanceNorm1d,
        torch.nn.InstanceNorm2d,
        torch.nn.InstanceNorm3d,
        torch.nn.LayerNorm,
        torch.nn.LocalResponseNorm,
    )
    # Modules whose parameters never receive weight decay.
    no_decay_module_types = norm_module_types + (torch.nn.Embedding,)
    # Parameter-name fragments that never receive weight decay.
    no_decay_param_keys = ("relative_position_bias_table", "pos_embed")
    # Top-level AuralFuser sub-modules that are trained at the base lr.
    train_prefixes = (
        "patch_embeds",
        "f_blocks",
        "a_blocks",
        "fusion_modules",
        "smooth_convs",
        "train_proj_v1",
        "train_proj_a1",
    )

    params_training: List[Dict[str, Any]] = []
    params_finetuning: List[Dict[str, Any]] = []
    # A parameter shared by several modules must land in exactly one group.
    memo: Set[torch.nn.parameter.Parameter] = set()

    for module_name, module in model.named_modules():
        for module_param_name, value in module.named_parameters(recurse=False):
            if not value.requires_grad or value in memo:
                continue
            memo.add(value)
            hyperparams = dict(defaults)

            if "vgg" in module_name or "vgg" in module_param_name:
                hyperparams["lr"] *= backbone_lr_scale
                params_finetuning.append(
                    {"params": [value], "name": [module_name], **hyperparams}
                )
            # NOTE(review): substring match, so names like "constraint" also
            # count as "train".
            elif (
                "train" in module_name
                or "train" in module_param_name
                or module_name.startswith(train_prefixes)
            ):
                if isinstance(module, no_decay_module_types) or any(
                    key in module_param_name for key in no_decay_param_keys
                ):
                    hyperparams["weight_decay"] = 0.0
                params_training.append(
                    {"params": [value], "name": [module_name], **hyperparams}
                )
            else:
                raise NotImplementedError(
                    f"undefined layer type: parameter '{module_param_name}' of "
                    f"module '{module_name}' ({type(module).__name__}) matches "
                    "neither the VGG backbone nor the AuralFuser training rules."
                )

    final_list = params_training + params_finetuning
    # Sanity check: every unique trainable parameter got exactly one group.
    assert sum(1 for p in model.parameters() if p.requires_grad) == len(final_list), 'checksum confirmed not pass.'
    return final_list
def group_weight(weight_group, module, weight_decay_value, lr):
    """Append a decay and a no-decay param group for ``module`` to ``weight_group``.

    Within ``module`` (recursively), the ``weight`` of Linear and
    (transposed) convolution layers is decayed. All other parameters of those
    layers (e.g. biases) are not decayed, and neither are any parameters of
    normalization layers or ``nn.Embedding``. Modules that own no parameters
    directly are skipped: containers such as ``nn.Sequential``, activations,
    and ``LocalResponseNorm``.

    Args:
        weight_group: List of param-group dicts; mutated in place.
        module: ``torch.nn.Module`` whose parameters are grouped.
        weight_decay_value: Weight decay for the decay group.
        lr: Learning rate for both groups.

    Returns:
        ``weight_group`` with two dicts appended:
        ``{params: decay, weight_decay: weight_decay_value, lr: lr}`` and
        ``{params: no_decay, weight_decay: 0.0, lr: lr}``.

    Raises:
        NotImplementedError: If a module of an unsupported type directly
            owns parameters.
    """
    norm_module_types = (
        torch.nn.BatchNorm1d,
        torch.nn.BatchNorm2d,
        torch.nn.BatchNorm3d,
        torch.nn.SyncBatchNorm,
        torch.nn.GroupNorm,
        torch.nn.InstanceNorm1d,
        torch.nn.InstanceNorm2d,
        torch.nn.InstanceNorm3d,
        torch.nn.LayerNorm,
        torch.nn.LocalResponseNorm,
    )
    # Layers whose "weight" is decayed; their other parameters are not.
    decay_module_types = (
        torch.nn.Linear,
        torch.nn.Conv1d,
        torch.nn.Conv2d,
        torch.nn.Conv3d,
        torch.nn.ConvTranspose1d,
        torch.nn.ConvTranspose2d,
        torch.nn.ConvTranspose3d,
    )
    # Layers none of whose parameters are decayed.
    no_decay_module_types = norm_module_types + (torch.nn.Embedding,)

    group_decay = []
    group_no_decay = []
    # ids of parameters already assigned, so tied weights are grouped once
    # (module.parameters() de-duplicates them too).
    seen = set()
    for m in module.modules():
        direct_params = list(m.named_parameters(recurse=False))
        if not direct_params:
            # Containers, activations, LocalResponseNorm, ...: nothing to group.
            continue
        if isinstance(m, decay_module_types):
            decay_names = {"weight"}
        elif isinstance(m, no_decay_module_types):
            decay_names = set()
        else:
            raise NotImplementedError(
                f"undefined layer type find: {type(m).__name__}"
            )
        for name, param in direct_params:
            if id(param) in seen:
                continue
            seen.add(id(param))
            if name in decay_names:
                group_decay.append(param)
            else:
                group_no_decay.append(param)

    assert len(list(module.parameters())) == len(group_decay) + len(group_no_decay), (
        "group_weight: not every parameter was assigned to a group"
    )
    weight_group.append(dict(params=group_decay, weight_decay=weight_decay_value, lr=lr))
    weight_group.append(dict(params=group_no_decay, weight_decay=0.0, lr=lr))
    return weight_group