File size: 5,001 Bytes
b386992
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, List, Optional

import lightning.pytorch as pl
import lightning.pytorch as L
from torch.optim import Optimizer
from torch.optim.optimizer import ParamsT

from nemo.lightning.megatron_parallel import MegatronParallel
from nemo.lightning.pytorch.optim.base import LRSchedulerModule, OptimizerModule


def _param_does_not_have_wd(param_name, param):
    return 'bias' in param_name


def _extract_model_params_for_optim(model, weight_decay=0, no_weight_decay_cond=None):
    params_with_wd, params_without_wd = [], []
    if no_weight_decay_cond is not None:
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            if no_weight_decay_cond(name, param):
                params_without_wd.append(param)
            else:
                params_with_wd.append(param)
    else:
        params_with_wd = list(filter(lambda x: x.requires_grad, model.parameters()))

    assert max(map(len, (params_with_wd, params_without_wd))) > 0, "Expected at least one optimizer with params"

    return [
        {'params': params, 'weight_decay': wd}
        for params, wd in zip((params_with_wd, params_without_wd), (weight_decay, 0))
    ]


class PytorchOptimizerModule(OptimizerModule):
    """An OptimizerModule for plain PyTorch (non-Megatron) optimizers.

    Attributes:
        optimizer_fn (Callable[[ParamsT], Optimizer]): Factory called with the model's
            param groups to build the optimizer. If it exposes a ``keywords`` mapping
            (as ``functools.partial`` objects do), its ``weight_decay`` entry is used
            for the decayed param group; otherwise 0 is used.
        lr_scheduler (Optional[LRSchedulerModule]): Optional scheduler module; when set,
            ``lr_scheduler.scheduler(model, opt)`` is applied to each optimizer.
        no_weight_decay_cond (Optional[Callable]): Predicate ``(name, param) -> bool``
            selecting parameters that get zero weight decay. Defaults to
            ``_param_does_not_have_wd`` (names containing ``'bias'``).
        scale_lr_cond (Optional[Callable]): Condition for scaling learning rate.
        lr_mult (float): Learning rate multiplier.

    NOTE(review): ``scale_lr_cond`` and ``lr_mult`` are stored but not read by
    ``optimizers()`` in this class; confirm whether anything else consumes them.

    Example::

        optimizer_fn = run.Partial(
            SGD,
            lr=lr,
            weight_decay=wd,
        )
        lr_scheduler = MyLRSchedulerModule(...)
        optimizer_module = PytorchOptimizerModule(optimizer_fn, lr_scheduler)

    Methods:
        setup(model): Sets up the optimizer.
        optimizers(model): Defines the optimizers.
        connect(model): Installs ``configure_optimizers`` on the model.
    """

    def __init__(
        self,
        optimizer_fn: Callable[[ParamsT], Optimizer],
        lr_scheduler: Optional[LRSchedulerModule] = None,
        no_weight_decay_cond: Optional[Callable] = _param_does_not_have_wd,
        scale_lr_cond: Optional[Callable] = None,
        lr_mult: float = 1.0,
    ):
        """Initializes the PytorchOptimizerModule.

        Args:
            optimizer_fn (Callable[[ParamsT], Optimizer]): Configuration for the optimizer.
            lr_scheduler (Optional[LRSchedulerModule]): The learning rate scheduler module.
            no_weight_decay_cond (Optional[Callable]): Condition for no weight decay.
            scale_lr_cond (Optional[Callable]): Condition for scaling learning rate.
            lr_mult (float): Learning rate multiplier.
        """

        super().__init__(lr_scheduler=lr_scheduler)
        self.optimizer_fn = optimizer_fn
        self.no_weight_decay_cond = no_weight_decay_cond
        self.scale_lr_cond = scale_lr_cond
        self.lr_mult = lr_mult

    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
        """No-op: this module intentionally does nothing at fit start.

        Optimizers are instead built when the model's ``configure_optimizers``
        (installed by ``connect``) calls ``optimizers``.
        """
        pass

    def optimizers(self, model) -> List[Optimizer]:
        """Build the optimizer(s) for ``model``.

        Args:
            model (nn.Module): The model for which the optimizers are being defined.

        Returns:
            List[Optimizer]: The optimizers, or, when ``lr_scheduler`` is set, the result
            of ``lr_scheduler.scheduler(model, opt)`` for each optimizer.

        Raises:
            ValueError: If the model is an instance of MegatronParallel.
        """

        if isinstance(model, MegatronParallel):
            raise ValueError("Model cannot be an instance of MegatronParallel")

        # ``optimizer_fn`` is typed as a plain Callable, so it need not carry the
        # ``keywords`` attribute that ``functools.partial`` objects expose; fall back
        # to "no bound kwargs" (weight decay 0) instead of raising AttributeError.
        bound_kwargs = getattr(self.optimizer_fn, 'keywords', None) or {}
        wd = bound_kwargs.get('weight_decay', 0)
        param_groups = _extract_model_params_for_optim(model, wd, self.no_weight_decay_cond)
        optim = self.optimizer_fn(param_groups)
        # Stored as returned by ``optimizer_fn`` (before list normalization below).
        self._optimizers = optim
        if not isinstance(optim, list):
            optim = [optim]
        if self.lr_scheduler is None:
            return optim
        return [self.lr_scheduler.scheduler(model, opt) for opt in optim]

    def connect(self, model: L.LightningModule) -> None:
        """Connects the optimizer module to the model.

        Replaces ``model.configure_optimizers`` with a zero-argument callable that
        returns ``self.optimizers(model)``.

        Args:
            model (L.LightningModule): The model to which the optimizer module is being connected.
        """
        model.configure_optimizers = lambda: self.optimizers(model)