# ecodepth / EcoDepth / optimizer.py
# NOTE: the following lines were scraped page chrome (HuggingFace file view:
# "aradhye's picture", commit message "initial commit", hash e628ca8), not
# part of the module source; kept here as a comment for provenance.
from typing import List, Sequence

from torch.optim import AdamW
def get_optimizer(
model,
args,
no_decay_names: List[str] = ['relative_position_bias_table', 'rpe_mlp', 'logit_scale']
):
param_groups = []
param_group_lookup = dict()
for n, p in model.named_parameters():
if p.requires_grad:
# Determine weight decay
if p.ndim == 1 or n.endswith(".bias") or "absolute_pos_embed" in n:
weight_decay = 0.0
else:
found = False
for no_decay_name in no_decay_names:
if no_decay_name in n:
found = True
break
if found:
weight_decay = 0.0
else:
weight_decay = args.weight_decay
# Determine LR scale
if "unet" in n:
scale = 0.01
elif "embeddings" in n:
scale = 10.0
else:
scale = 1.0
lr = scale * args.lr
if (lr, weight_decay) not in param_group_lookup:
new_param_group = {
"params" : [],
"weight_decay" : weight_decay,
"lr" : lr
}
param_group_lookup[(lr, weight_decay)] = new_param_group
param_groups.append(new_param_group)
param_group = param_group_lookup[(lr, weight_decay)]
param_group["params"].append(p)
# LR and weight_decay are already set above
# So setting these below has no effect
optimizer = AdamW(params=param_groups)
return optimizer