File size: 5,863 Bytes
f71ac1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
"""Optimizer."""

from __future__ import annotations

from typing import TypedDict

from torch import nn
from torch.nn import GroupNorm, LayerNorm
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.instancenorm import _InstanceNorm
from torch.optim.optimizer import Optimizer
from typing_extensions import NotRequired

from vis4d.common.logging import rank_zero_info
from vis4d.config import instantiate_classes
from vis4d.config.typing import OptimizerConfig, ParamGroupCfg

from .scheduler import LRSchedulerWrapper


class ParamGroup(TypedDict):
    """Typed dictionary describing one optimizer parameter group.

    Only ``params`` is required; ``lr`` and ``weight_decay`` are optional
    keys that may be left out of the dictionary.

    Attributes:
        params (list[nn.Parameter]): Parameters that belong to this group.
        lr (NotRequired[float]): Learning rate used for this group.
        weight_decay (NotRequired[float]): Weight decay used for this group.
    """

    params: list[nn.Parameter]
    lr: NotRequired[float]
    weight_decay: NotRequired[float]


# TODO: Add true support for multiple optimizers. This will need to
# modify config to specify which optimizer to use for which module.
def set_up_optimizers(
    optimizers_cfg: list[OptimizerConfig],
    models: list[nn.Module],
    steps_per_epoch: int = -1,
) -> tuple[list[Optimizer], list[LRSchedulerWrapper]]:
    """Build one optimizer per (config, model) pair plus its LR schedulers.

    Configs and models are paired positionally with ``zip``, so iteration
    stops at the shorter of the two lists.

    Args:
        optimizers_cfg (list[OptimizerConfig]): Optimizer configurations.
        models (list[nn.Module]): Models whose parameters are optimized.
        steps_per_epoch (int): Forwarded unchanged to every
            ``LRSchedulerWrapper``. Default: -1.

    Returns:
        tuple[list[Optimizer], list[LRSchedulerWrapper]]: The optimizers, and
            one scheduler wrapper for every config whose ``lr_schedulers``
            is not None.
    """
    built_optimizers: list[Optimizer] = []
    scheduler_wrappers: list[LRSchedulerWrapper] = []

    for cfg, net in zip(optimizers_cfg, models):
        current = configure_optimizer(cfg, net)
        built_optimizers.append(current)

        scheduler_cfg = cfg.lr_schedulers
        if scheduler_cfg is None:
            # This optimizer runs without any learning rate schedule.
            continue

        wrapper = LRSchedulerWrapper(scheduler_cfg, current, steps_per_epoch)
        scheduler_wrappers.append(wrapper)

    return built_optimizers, scheduler_wrappers


def configure_optimizer(
    optim_cfg: OptimizerConfig, model: nn.Module
) -> Optimizer:
    """Instantiate the optimizer, optionally with custom parameter groups.

    If the config has no ``param_groups`` entry, the optimizer is created
    directly on ``model.parameters()``. Otherwise one parameter group is
    created per entry, followed by a trailing group holding every parameter
    that no entry claims.

    Per-group values are derived from the optimizer's ``init_args``:
    the learning rate is ``lr * lr_mult`` and, only when ``weight_decay`` is
    configured, the weight decay is scaled by the first multiplier that is
    set among ``norm_decay_mult``, ``bias_decay_mult`` and ``decay_mult``
    (the latter defaulting to 1.0).

    Args:
        optim_cfg (OptimizerConfig): The optimizer configuration.
        model (nn.Module): The model whose parameters are optimized.

    Returns:
        Optimizer: The instantiated optimizer.
    """
    group_cfgs = optim_cfg.get("param_groups", None)
    if group_cfgs is None:
        # No per-group settings: hand all model parameters to the optimizer.
        return instantiate_classes(
            optim_cfg.optimizer, params=model.parameters()
        )

    init_args = optim_cfg.optimizer["init_args"]
    base_lr = init_args.lr
    weight_decay = init_args.get("weight_decay", None)

    groups = []
    for cfg in group_cfgs:
        entry: ParamGroup = {
            "params": [],
            "lr": base_lr * cfg.get("lr_mult", 1.0),
        }

        if weight_decay is not None:
            norm_mult = cfg.get("norm_decay_mult", None)
            bias_mult = cfg.get("bias_decay_mult", None)
            # Precedence: norm multiplier, then bias multiplier, then the
            # generic decay multiplier.
            if norm_mult is not None:
                scale = norm_mult
            elif bias_mult is not None:
                scale = bias_mult
            else:
                scale = cfg.get("decay_mult", 1.0)
            entry["weight_decay"] = weight_decay * scale

        groups.append(entry)

    # Trailing catch-all group with the unscaled base settings.
    remainder: ParamGroup = {"params": [], "lr": base_lr}
    if weight_decay is not None:
        remainder["weight_decay"] = weight_decay
    groups.append(remainder)

    # Distribute the model parameters over the groups created above.
    add_params(groups, model, group_cfgs)

    return instantiate_classes(optim_cfg.optimizer, params=groups)


def add_params(
    params: list[ParamGroup],
    module: nn.Module,
    param_groups_cfg: list[ParamGroupCfg],
    prefix: str = "",
) -> None:
    """Assign all parameters of a module (recursively) to param groups.

    ``params`` must hold one group per entry of ``param_groups_cfg`` (same
    order) followed by one trailing default group. A trainable parameter is
    put into the first configured group that has a custom key contained in
    ``f"{prefix}.{name}"`` and whose restrictions are met:

    - groups with ``norm_decay_mult`` only accept parameters of
      normalization layers (batch / instance / group / layer norm);
    - groups with ``bias_decay_mult`` only accept parameters named ``bias``.

    Parameters that match no group, and parameters with
    ``requires_grad=False``, are appended to the trailing default group.
    Every custom assignment is reported through ``rank_zero_info``.

    Args:
        params (list[ParamGroup]): A list of param groups, it will be
            modified in place.
        module (nn.Module): The module whose parameters are added.
        param_groups_cfg (list[ParamGroupCfg]): The configuration of the
            param groups.
        prefix (str): The prefix of the module. Default: ''.
    """
    # Whether the module is a normalization layer does not depend on the
    # individual parameter, so it is determined once per module.
    is_norm = isinstance(
        module, (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm)
    )

    for name, param in module.named_parameters(recurse=False):
        if not param.requires_grad:
            params[-1]["params"].append(param)
            continue

        full_name = f"{prefix}.{name}"

        # if the parameter match one of the custom keys, ignore other rules
        log_msg: str | None = None
        for i, group in enumerate(param_groups_cfg):
            for key in group["custom_keys"]:
                if key not in full_name:
                    continue

                norm_decay_mult = group.get("norm_decay_mult", None)
                bias_decay_mult = group.get("bias_decay_mult", None)
                if norm_decay_mult is not None and not is_norm:
                    continue
                if bias_decay_mult is not None and name != "bias":
                    continue

                # Compose the message only once the group is accepted, so
                # rejected candidates leave no stale fragments behind.
                log_msg = full_name
                if group.get("lr_mult", None) is not None:
                    log_msg += f" with lr_mult: {group['lr_mult']}"
                if norm_decay_mult is not None:
                    log_msg += f" with norm_decay_mult: {norm_decay_mult}"
                if bias_decay_mult is not None:
                    log_msg += f" with bias_decay_mult: {bias_decay_mult}"
                if group.get("decay_mult", None) is not None:
                    log_msg += f" with decay_mult: {group['decay_mult']}"

                params[i]["params"].append(param)
                break
            if log_msg is not None:
                break

        if log_msg is not None:
            rank_zero_info(log_msg)
        else:
            # add parameter to the last param group
            params[-1]["params"].append(param)

    for child_name, child_mod in module.named_children():
        child_prefix = f"{prefix}.{child_name}" if prefix else child_name
        add_params(params, child_mod, param_groups_cfg, prefix=child_prefix)