| |
| |
|
|
| """Optimizer.""" |
|
|
| import torch |
|
|
| import slowfast.utils.lr_policy as lr_policy |
|
|
|
|
def construct_optimizer(model, cfg):
    """
    Build the optimizer selected by the config: SGD with momentum, or Adam.

    The model's parameters are split into two groups so that each can carry
    its own weight decay: parameters whose name contains the substring "bn"
    use cfg.BN.WEIGHT_DECAY, every other parameter uses
    cfg.SOLVER.WEIGHT_DECAY.

    Background on the two methods:
    Herbert Robbins, and Sutton Monro. "A stochastic approximation method."
    and
    Diederik P.Kingma, and Jimmy Ba.
    "Adam: A Method for Stochastic Optimization."

    Args:
        model (model): model whose parameters will be optimized; it must
            provide named_parameters() and parameters().
        cfg (config): hyper-parameter config. Reads BN.WEIGHT_DECAY and
            SOLVER.{OPTIMIZING_METHOD, BASE_LR, WEIGHT_DECAY}, plus
            SOLVER.{MOMENTUM, DAMPENING, NESTEROV} when the method is "sgd".

    Returns:
        torch.optim.SGD when cfg.SOLVER.OPTIMIZING_METHOD is "sgd",
        torch.optim.Adam (betas fixed to (0.9, 0.999)) when it is "adam".

    Raises:
        NotImplementedError: for any other cfg.SOLVER.OPTIMIZING_METHOD.
    """
    # NOTE: the batch-norm test is a plain substring match on the parameter
    # name, so any name that merely contains "bn" lands in the first group.
    named = list(model.named_parameters())
    bn_group = [param for name, param in named if "bn" in name]
    other_group = [param for name, param in named if "bn" not in name]

    # Sanity check: every parameter of the model went into exactly one group.
    total = len(list(model.parameters()))
    assert total == len(other_group) + len(bn_group), (
        "parameter size does not match: {} + {} != {}".format(
            len(other_group), len(bn_group), total
        )
    )

    # Both groups set "weight_decay" explicitly; the weight_decay keyword
    # given to the optimizer below is only the constructor-level default.
    grouped_params = [
        {"params": bn_group, "weight_decay": cfg.BN.WEIGHT_DECAY},
        {"params": other_group, "weight_decay": cfg.SOLVER.WEIGHT_DECAY},
    ]

    method = cfg.SOLVER.OPTIMIZING_METHOD
    if method == "sgd":
        return torch.optim.SGD(
            grouped_params,
            lr=cfg.SOLVER.BASE_LR,
            momentum=cfg.SOLVER.MOMENTUM,
            weight_decay=cfg.SOLVER.WEIGHT_DECAY,
            dampening=cfg.SOLVER.DAMPENING,
            nesterov=cfg.SOLVER.NESTEROV,
        )
    if method == "adam":
        return torch.optim.Adam(
            grouped_params,
            lr=cfg.SOLVER.BASE_LR,
            betas=(0.9, 0.999),
            weight_decay=cfg.SOLVER.WEIGHT_DECAY,
        )
    raise NotImplementedError(
        "Does not support {} optimizer".format(method)
    )
|
|
|
|
def get_epoch_lr(cur_epoch, cfg):
    """
    Look up the learning rate for an epoch, as dictated by the lr policy.

    This is a thin wrapper that forwards to lr_policy.get_lr_at_epoch.

    Args:
        cur_epoch (float): the number of epoch of the current training stage.
        cfg (config): config object handed through to the lr policy unchanged.

    Returns:
        Whatever lr_policy.get_lr_at_epoch(cfg, cur_epoch) returns.
    """
    epoch_lr = lr_policy.get_lr_at_epoch(cfg, cur_epoch)
    return epoch_lr
|
|
def get_iter_lr(cur_iter, cfg):
    """
    Look up the learning rate for an iteration, as dictated by the lr policy.

    This is a thin wrapper that forwards to lr_policy.get_lr_at_iter.

    Args:
        cur_iter: the current training iteration, handed through to the lr
            policy unchanged (its exact type and units are defined by
            lr_policy.get_lr_at_iter).
        cfg (config): config object handed through to the lr policy unchanged.

    Returns:
        Whatever lr_policy.get_lr_at_iter(cfg, cur_iter) returns.
    """
    return lr_policy.get_lr_at_iter(cfg, cur_iter)
|
|
|
|
def set_lr(optimizer, new_lr):
    """
    Overwrite the learning rate of every parameter group in place.

    Args:
        optimizer (optim): the optimizer whose param_groups are updated.
        new_lr (float): the learning rate written to each group's "lr" entry.
    """
    for group in optimizer.param_groups:
        # The same value goes to all groups; other per-group options are
        # left untouched.
        group.update({"lr": new_lr})
|
|