from torch.optim import AdamW
from typing import List


def get_optimizer(
    model,
    args,
    no_decay_names: List[str] = ['relative_position_bias_table', 'rpe_mlp', 'logit_scale'],
):
    # Build AdamW parameter groups keyed by (learning rate, weight decay) so that
    # parameters sharing the same settings end up in a single group.
    param_groups = []
    param_group_lookup = dict()

    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue

        # Disable weight decay for 1-D parameters (norms, biases), absolute position
        # embeddings, and any parameter whose name matches the no-decay list.
        if p.ndim == 1 or n.endswith(".bias") or "absolute_pos_embed" in n:
            weight_decay = 0.0
        elif any(no_decay_name in n for no_decay_name in no_decay_names):
            weight_decay = 0.0
        else:
            weight_decay = args.weight_decay

        # Scale the base learning rate per sub-module: train the UNet gently,
        # the embeddings aggressively, and everything else at the base rate.
        if "unet" in n:
            scale = 0.01
        elif "embeddings" in n:
            scale = 10.0
        else:
            scale = 1.0
        lr = scale * args.lr

        # Create a new parameter group the first time this (lr, weight_decay)
        # combination appears, then append the parameter to it.
        if (lr, weight_decay) not in param_group_lookup:
            new_param_group = {
                "params": [],
                "weight_decay": weight_decay,
                "lr": lr,
            }
            param_group_lookup[(lr, weight_decay)] = new_param_group
            param_groups.append(new_param_group)

        param_group_lookup[(lr, weight_decay)]["params"].append(p)

    optimizer = AdamW(params=param_groups)
    return optimizer
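
A minimal usage sketch follows, assuming `args` only needs `lr` and `weight_decay` attributes (an `argparse.Namespace` is used here for illustration) and that `model` is any `torch.nn.Module`; the stand-in model and hyperparameter values are illustrative, not taken from the original setup.

import argparse
import torch.nn as nn

# Hypothetical hyperparameters; only `lr` and `weight_decay` are read by get_optimizer.
args = argparse.Namespace(lr=1e-4, weight_decay=0.05)

# Stand-in for the real model: a Linear layer has a 2-D weight (decayed)
# and a 1-D bias (no decay), so two parameter groups are created.
model = nn.Linear(16, 16)

optimizer = get_optimizer(model, args)

# Each group carries its own lr / weight_decay.
for group in optimizer.param_groups:
    print(len(group["params"]), group["lr"], group["weight_decay"])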