""" Optimizer Factory w/ Custom Weight Decay

Hacked together by / Copyright 2020 Ross Wightman
"""
import logging
import re

import torch
from torch import optim

from utils.distributed import is_main_process

logger = logging.getLogger(__name__)

try:
    from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
    has_apex = True
except ImportError:
    has_apex = False


def add_weight_decay(model, weight_decay, no_decay_list=(), filter_bias_and_bn=True):
    """Assign a per-parameter weight decay.

    If ``filter_bias_and_bn`` is set, 1-d parameters (e.g. norm weights) and
    ``.bias`` parameters get weight decay 0. Parameters whose names appear in
    ``no_decay_list`` also get 0; all other parameters get ``weight_decay``.
    Frozen parameters (``requires_grad=False``) are skipped.

    Returns:
        List([name, param, weight_decay])
    """
    named_param_tuples = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if filter_bias_and_bn and (len(param.shape) == 1 or name.endswith(".bias")):
            named_param_tuples.append([name, param, 0])
        elif name in no_decay_list:
            named_param_tuples.append([name, param, 0])
        else:
            named_param_tuples.append([name, param, weight_decay])
    return named_param_tuples


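# Illustrative example (a sketch, not executed): for a model that is just
# ``torch.nn.Linear(4, 2)``, ``add_weight_decay(model, 0.01)`` returns
#     [["weight", <Parameter 2x4>, 0.01], ["bias", <Parameter 2>, 0]]
# because the bias is 1-d and is therefore excluded from weight decay.
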
def add_different_lr(named_param_tuples_or_model, diff_lr_names, diff_lr, default_lr):
    """Use lr=diff_lr for parameters whose names match an entry in diff_lr_names,
    otherwise use lr=default_lr.

    Args:
        named_param_tuples_or_model: List([name, param, weight_decay]),
            as produced by ``add_weight_decay``
        diff_lr_names: List(str), each entry treated as a regex pattern
        diff_lr: float
        default_lr: float
    Returns:
        named_param_tuples_with_lr: List([name, param, weight_decay, lr])
    """
    named_param_tuples_with_lr = []
    logger.info(f"diff_names: {diff_lr_names}, diff_lr: {diff_lr}")
    for name, p, wd in named_param_tuples_or_model:
        use_diff_lr = False
        for diff_name in diff_lr_names:
            # each entry is matched as a regex against the full parameter name
            if re.search(diff_name, name) is not None:
                logger.info(f"param {name} use different_lr: {diff_lr}")
                use_diff_lr = True
                break

        named_param_tuples_with_lr.append(
            [name, p, wd, diff_lr if use_diff_lr else default_lr]
        )

    if is_main_process():
        for name, _, wd, lr in named_param_tuples_with_lr:
            logger.info(f"param {name}: wd: {wd}, lr: {lr}")

    return named_param_tuples_with_lr


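# Illustrative example (a sketch, not executed): with
# diff_lr_names=["^vision_encoder\\."] every parameter whose name matches the
# pattern via ``re.search`` (e.g. "vision_encoder.blocks.0.attn.qkv.weight")
# is assigned diff_lr, typically a smaller lr for a pretrained backbone, while
# all remaining parameters keep default_lr. The name "vision_encoder" here is
# only an example, not a module this repo necessarily defines.
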
def create_optimizer_params_group(named_param_tuples_with_lr):
    """Merge parameters that share the same (weight_decay, lr) into one optimizer
    parameter group.

    Args:
        named_param_tuples_with_lr: List([name, param, weight_decay, lr])
    """
    group = {}
    for name, p, wd, lr in named_param_tuples_with_lr:
        if wd not in group:
            group[wd] = {}
        if lr not in group[wd]:
            group[wd][lr] = []
        group[wd][lr].append(p)

    optimizer_params_group = []
    for wd, lr_groups in group.items():
        for lr, p in lr_groups.items():
            optimizer_params_group.append(dict(
                params=p,
                weight_decay=wd,
                lr=lr
            ))
            logger.info(f"optimizer -- lr={lr} wd={wd} len(p)={len(p)}")
    return optimizer_params_group


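# Illustrative example (a sketch, not executed): tuples carrying
# (wd=0.02, lr=1e-4), (wd=0, lr=1e-4) and (wd=0, lr=1e-5) collapse into three
# param groups, one per distinct (weight_decay, lr) pair, so the optimizer
# applies the hyper-parameters per group rather than per parameter.
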
def create_optimizer(args, model, filter_bias_and_bn=True):
    opt_lower = args.opt.lower()
    weight_decay = args.weight_decay

    # optionally use a different lr for modules matched by name
    if hasattr(args, "different_lr") and args.different_lr.enable:
        diff_lr_module_names = args.different_lr.module_names
        diff_lr = args.different_lr.lr
    else:
        diff_lr_module_names = []
        diff_lr = None

    # parameter names the model itself excludes from weight decay
    no_decay = {}
    if hasattr(model, 'no_weight_decay'):
        no_decay = model.no_weight_decay()
    named_param_tuples = add_weight_decay(
        model, weight_decay, no_decay, filter_bias_and_bn)
    named_param_tuples = add_different_lr(
        named_param_tuples, diff_lr_module_names, diff_lr, args.lr)
    parameters = create_optimizer_params_group(named_param_tuples)

    if 'fused' in opt_lower:
        # note: fused variants are only checked for availability here; the dispatch
        # below still instantiates the native torch optimizers
        assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'

    opt_args = dict(lr=args.lr, weight_decay=weight_decay)
    if hasattr(args, 'opt_eps') and args.opt_eps is not None:
        opt_args['eps'] = args.opt_eps
    if hasattr(args, 'opt_betas') and args.opt_betas is not None:
        opt_args['betas'] = args.opt_betas
    if hasattr(args, 'opt_args') and args.opt_args is not None:
        opt_args.update(args.opt_args)

    # for compound names such as "xxx_adamw", dispatch on the last token only
    opt_split = opt_lower.split('_')
    opt_lower = opt_split[-1]
    if opt_lower == 'sgd' or opt_lower == 'nesterov':
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
    elif opt_lower == 'momentum':
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
    elif opt_lower == 'adam':
        optimizer = optim.Adam(parameters, **opt_args)
    elif opt_lower == 'adamw':
        optimizer = optim.AdamW(parameters, **opt_args)
    else:
        raise ValueError(f"Invalid optimizer: {args.opt}")
    return optimizer
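
if __name__ == "__main__":
    # Minimal smoke test, a sketch rather than part of the training pipeline.
    # It assumes the module imports cleanly (i.e. ``utils.distributed`` is on the
    # path and ``is_main_process`` works without torch.distributed being initialized).
    from types import SimpleNamespace

    demo_model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.LayerNorm(8))
    demo_args = SimpleNamespace(opt="adamw", lr=1e-4, weight_decay=0.02)
    # expects two param groups: decayed 2-d weights and non-decayed biases/norms
    print(create_optimizer(demo_args, demo_model))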