| """Optimizer.""" |
|
|
| from __future__ import annotations |
|
|
| from typing import TypedDict |
|
|
| from torch import nn |
| from torch.nn import GroupNorm, LayerNorm |
| from torch.nn.modules.batchnorm import _BatchNorm |
| from torch.nn.modules.instancenorm import _InstanceNorm |
| from torch.optim.optimizer import Optimizer |
| from typing_extensions import NotRequired |
|
|
| from vis4d.common.logging import rank_zero_info |
| from vis4d.config import instantiate_classes |
| from vis4d.config.typing import OptimizerConfig, ParamGroupCfg |
|
|
| from .scheduler import LRSchedulerWrapper |
|
|
|
|
| class ParamGroup(TypedDict): |
| """Parameter dictionary. |
| |
| Attributes: |
| params (list[nn.Parameter]): List of parameters. |
| lr (NotRequired[float]): Learning rate. |
| weight_decay (NotRequired[float]): Weight decay. |
| """ |
|
|
| params: list[nn.Parameter] |
| lr: NotRequired[float] |
| weight_decay: NotRequired[float] |
|
|
|
|
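
# For reference, a ParamGroup as consumed by torch.optim optimizers looks
# like the following (illustrative values; the "params" entries are
# nn.Parameter objects):
#
#     {"params": [conv.weight, conv.bias], "lr": 1e-3, "weight_decay": 1e-4}
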
def set_up_optimizers(
    optimizers_cfg: list[OptimizerConfig],
    models: list[nn.Module],
    steps_per_epoch: int = -1,
) -> tuple[list[Optimizer], list[LRSchedulerWrapper]]:
    """Set up optimizers and their learning rate schedulers."""
    optimizers = []
    lr_schedulers = []
    for optim_cfg, model in zip(optimizers_cfg, models):
        optimizer = configure_optimizer(optim_cfg, model)
        optimizers.append(optimizer)

        if optim_cfg.lr_schedulers is not None:
            lr_schedulers.append(
                LRSchedulerWrapper(
                    optim_cfg.lr_schedulers, optimizer, steps_per_epoch
                )
            )

    return optimizers, lr_schedulers
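
# Example usage (a minimal sketch; how the OptimizerConfig instances, the
# model, and the dataloader are built is project-specific and assumed here,
# only the call pattern below is defined by this module):
#
#     optimizers, lr_schedulers = set_up_optimizers(
#         optimizers_cfg=[optim_cfg],  # one OptimizerConfig per model
#         models=[model],
#         steps_per_epoch=len(train_dataloader),
#     )
#     for opt in optimizers:
#         opt.zero_grad()
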
def configure_optimizer(
    optim_cfg: OptimizerConfig, model: nn.Module
) -> Optimizer:
    """Configure optimizer with parameter groups."""
    param_groups_cfg = optim_cfg.get("param_groups", None)

    if param_groups_cfg is None:
        return instantiate_classes(
            optim_cfg.optimizer, params=model.parameters()
        )

    params: list[ParamGroup] = []
    base_lr = optim_cfg.optimizer["init_args"].lr
    weight_decay = optim_cfg.optimizer["init_args"].get("weight_decay", None)
    for group in param_groups_cfg:
        lr_mult = group.get("lr_mult", 1.0)
        decay_mult = group.get("decay_mult", 1.0)
        norm_decay_mult = group.get("norm_decay_mult", None)
        bias_decay_mult = group.get("bias_decay_mult", None)

        param_group: ParamGroup = {"params": [], "lr": base_lr * lr_mult}

        # norm_decay_mult takes precedence over bias_decay_mult, which takes
        # precedence over the generic decay_mult.
        if weight_decay is not None:
            if norm_decay_mult is not None:
                param_group["weight_decay"] = weight_decay * norm_decay_mult
            elif bias_decay_mult is not None:
                param_group["weight_decay"] = weight_decay * bias_decay_mult
            else:
                param_group["weight_decay"] = weight_decay * decay_mult

        params.append(param_group)

    # Default parameter group for all parameters not matched by any custom
    # group, using the base learning rate and weight decay.
    param_group = {"params": [], "lr": base_lr}
    if weight_decay is not None:
        param_group["weight_decay"] = weight_decay
    params.append(param_group)

    # Assign each model parameter to its parameter group.
    add_params(params, model, param_groups_cfg)

    return instantiate_classes(optim_cfg.optimizer, params=params)
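
# Example param_groups configuration (a sketch using only the keys read by
# configure_optimizer and add_params; how ParamGroupCfg entries are created
# in a full experiment config is left out here):
#
#     param_groups = [
#         # Zero weight decay for normalization layers inside the backbone;
#         # listed first so norm parameters match this group.
#         {"custom_keys": ["backbone"], "norm_decay_mult": 0.0},
#         # 10x smaller learning rate for the remaining backbone parameters.
#         {"custom_keys": ["backbone"], "lr_mult": 0.1},
#     ]
#
# configure_optimizer builds one torch parameter group per entry plus a
# trailing default group with the base lr and weight decay, and add_params
# then assigns every model parameter to the first matching group.
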
def add_params(
    params: list[ParamGroup],
    module: nn.Module,
    param_groups_cfg: list[ParamGroupCfg],
    prefix: str = "",
) -> None:
    """Add all parameters of module to the params list.

    The parameters of the given module will be added to the list of param
    groups, following the rules defined in param_groups_cfg.

    Args:
        params (list[ParamGroup]): A list of param groups. It will be
            modified in place.
        module (nn.Module): The module whose parameters are added.
        param_groups_cfg (list[ParamGroupCfg]): The configuration of the
            param groups.
        prefix (str): The prefix of the module. Default: ''.
    """
    for name, param in module.named_parameters(recurse=False):
        if not param.requires_grad:
            # Frozen parameters go to the default (last) parameter group.
            params[-1]["params"].append(param)
            continue

        is_norm = isinstance(
            module, (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm)
        )

        # Assign the parameter to the first custom group whose key matches
        # its qualified name and whose norm/bias constraints are satisfied.
        is_custom = False
        msg = f"{prefix}.{name}"
        for i, group in enumerate(param_groups_cfg):
            for key in group["custom_keys"]:
                if key not in f"{prefix}.{name}":
                    continue
                norm_decay_mult = group.get("norm_decay_mult", None)
                bias_decay_mult = group.get("bias_decay_mult", None)
                if group.get("lr_mult", None) is not None:
                    msg += f" with lr_mult: {group['lr_mult']}"
                if norm_decay_mult is not None:
                    if not is_norm:
                        continue
                    msg += f" with norm_decay_mult: {norm_decay_mult}"
                if bias_decay_mult is not None:
                    if name != "bias":
                        continue
                    msg += f" with bias_decay_mult: {bias_decay_mult}"
                if group.get("decay_mult", None) is not None:
                    msg += f" with decay_mult: {group['decay_mult']}"
                params[i]["params"].append(param)
                is_custom = True
                break
            if is_custom:
                break

        if is_custom:
            rank_zero_info(msg)
        else:
            # No custom group matched, add to the default (last) group.
            params[-1]["params"].append(param)

    # Recurse into child modules with the extended prefix.
    for child_name, child_mod in module.named_children():
        child_prefix = f"{prefix}.{child_name}" if prefix else child_name
        add_params(params, child_mod, param_groups_cfg, prefix=child_prefix)
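
# Note on matching (illustrative): custom_keys are tested as substrings of
# the qualified parameter name, e.g. the key "backbone" matches
# "backbone.layer1.0.conv1.weight", but "head.fc" would also match
# "roi_head.fc_cls.weight", so keys should be chosen to avoid unintended
# overlaps.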