File size: 7,684 Bytes
b386992
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import types
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import List, Optional

import lightning.pytorch as L
from lightning.pytorch.utilities.types import OptimizerLRScheduler
from torch.optim import Optimizer

from nemo.lightning.io.mixin import IOMixin
from nemo.lightning.megatron_parallel import CallbackMethods


class LRSchedulerModule(L.Callback, CallbackMethods, IOMixin, ABC):
    """A module to standardize the learning rate scheduler setup and configuration.

    This class decouples the learning rate scheduler from the model, similar to how the LightningDataModule
    decouples data handling. It also acts as a Callback to hook into the training loop, which can be useful
    for adding custom all-reduces, logging, early stopping, etc. Next to that standard Lightning callback-event,
    this also supports hooking into the Megatron forward-backward function at a granular level.

    Example::

        class MyLRSchedulerModule(LRSchedulerModule):
            def connect(self, model, optimizer):
                # Custom setup logic
                ...

            def scheduler(self, model, optimizers):
                # Define and return the learning rate scheduler
                ...

    Methods:
        connect(model, optimizer): Hook for custom setup, run before the scheduler is built (no-op by default).
        scheduler(model, optimizers): Abstract method to define the learning rate scheduler.
        __call__(model, optimizers): Calls the connect and scheduler methods.
    """

    def connect(self, model, optimizer) -> None:
        """Hook for custom setup that runs right before ``scheduler`` is called.

        The base implementation does nothing; subclasses may override it.

        Args:
            model: The model for which the scheduler is being set up.
            optimizer: The optimizer(s) for which the scheduler is being set up.
        """
        ...

    @abstractmethod
    def scheduler(self, model, optimizers) -> OptimizerLRScheduler:
        """Abstract method to define the learning rate scheduler.

        Args:
            model: The model for which the scheduler is being defined.
            optimizers: The optimizers for which the scheduler is being defined.

        Returns:
            OptimizerLRScheduler: The learning rate scheduler.
        """
        raise NotImplementedError("The scheduler method should be implemented by subclasses.")

    def __call__(self, model, optimizers):
        """Calls the connect and scheduler methods and returns a Lightning optimizer config.

        Args:
            model: The model for which the scheduler is being called.
            optimizers: A single optimizer, or a list/tuple of optimizers.

        Returns:
            OptimizerLRScheduler: The value returned by ``scheduler`` unchanged when it is a dict or
            tuple; otherwise a ``(list_of_optimizers, scheduler)`` pair.
        """
        self.connect(model, optimizers)

        self._scheduler = self.scheduler(model, optimizers)

        # A dict / tuple is assumed to already be a complete configure_optimizers return value.
        if isinstance(self._scheduler, (dict, tuple)):
            return self._scheduler

        # Lightning only accepts the two-element "(optimizers, schedulers)" form when the first
        # element is a list; a bare optimizer (or a tuple) in that slot is rejected.
        if isinstance(optimizers, (list, tuple)):
            optimizer_list = list(optimizers)
        else:
            optimizer_list = [optimizers]

        return optimizer_list, self._scheduler


class OptimizerModule(L.Callback, CallbackMethods, IOMixin, ABC):
    """A module to standardize the optimizer setup and configuration.

    This class decouples the optimizer from the model, similar to how the LightningDataModule
    decouples data handling. It also acts as a Callback to hook into the training loop, which can be useful
    for adding custom all-reduces, logging, early stopping, etc. Next to that standard Lightning callback-event,
    this also supports hooking into the Megatron forward-backward function at a granular level.

    Attributes:
        lr_scheduler (Optional[LRSchedulerModule]): The learning rate scheduler module.

    Example::

        class MyOptimizerModule(OptimizerModule):
            def __init__(self, lr_scheduler=None):
                super().__init__(lr_scheduler)

            def optimizers(self, model):
                # Define and return the optimizers
                ...

    Methods:
        connect(model): Installs this module as the model's ``configure_optimizers`` and stores it as ``model.optim``.
        optimizers(model): Abstract method to define the optimizers.
        on_train_batch_start(...): Logs the current learning rate.
        __call__(model, megatron_parallel): Builds the optimizers and, if configured, the LR scheduler.
    """

    def __init__(self, lr_scheduler: Optional[LRSchedulerModule] = None):
        """Initializes the OptimizerModule.

        Args:
            lr_scheduler (Optional[LRSchedulerModule]): The learning rate scheduler module, or None
                to train without a scheduler.
        """
        self.lr_scheduler = lr_scheduler
        # Populated by __call__; stays None until the optimizers have been built.
        self._optimizers: Optional[List[Optimizer]] = None

    def connect(self, model: L.LightningModule) -> None:
        """Connects the optimizer module to the model.

        Replaces ``model.configure_optimizers`` with a bound method that delegates to this module,
        sets ``model.optim`` to this module, and, when both sides carry ``__io__`` metadata and the
        model's ``__io__`` has an ``optim`` entry, stores a deep copy of this module's ``__io__`` there.

        Args:
            model (L.LightningModule): The model to which the optimizer module is being connected.
        """

        def custom_configure_optimizers(lightning_module_self, megatron_parallel=None):
            opt = self(lightning_module_self, megatron_parallel=megatron_parallel)
            return opt

        model.configure_optimizers = types.MethodType(custom_configure_optimizers, model)
        model.optim = self

        if hasattr(self, "__io__") and hasattr(model, "__io__"):
            if hasattr(model.__io__, "optim"):
                model.__io__.optim = deepcopy(self.__io__)

    @abstractmethod
    def optimizers(self, model) -> List[Optimizer]:
        """Abstract method to define the optimizers.

        Args:
            model: The model for which the optimizers are being defined.

        Returns:
            List[Optimizer]: The list of optimizers.
        """
        raise NotImplementedError("The optimizers method should be implemented by subclasses.")

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx) -> None:
        """Logs the learning rate of the first param group of the first optimizer as ``lr``.

        Nothing is logged until ``__call__`` has built the optimizers, or when the optimizer list
        is empty. ``0.0`` is logged when the first optimizer has no param groups.
        """
        # getattr also covers subclasses whose __init__ does not call super().__init__().
        optimizers = getattr(self, "_optimizers", None)
        if not optimizers:
            return

        param_groups = optimizers[0].param_groups
        lr = param_groups[0]['lr'] if param_groups else 0.0
        pl_module.log('lr', lr, batch_size=1, prog_bar=True)

    def __call__(self, model: L.LightningModule, megatron_parallel=None) -> OptimizerLRScheduler:
        """Builds the optimizers and, if an LR scheduler is configured, the scheduler.

        Also registers this module (and its ``lr_scheduler``, if any) as trainer callbacks when
        they are not registered yet.

        Args:
            model (L.LightningModule): The model for which the optimizers are being called.
            megatron_parallel: Optional parallel model; used instead of ``model`` when given.

        Returns:
            OptimizerLRScheduler: The optimizers and optionally the learning rate scheduler.
        """
        _model = model if megatron_parallel is None else megatron_parallel
        callbacks = _model.trainer.callbacks
        if self not in callbacks:
            callbacks.append(self)
        if self.lr_scheduler is not None and self.lr_scheduler not in callbacks:
            callbacks.append(self.lr_scheduler)
        self._optimizers = self.optimizers(_model)

        # Schedulers receive a bare optimizer when there is exactly one, otherwise the full list.
        _opt = self._optimizers[0] if len(self._optimizers) == 1 else self._optimizers

        if self.lr_scheduler is not None:
            return self.lr_scheduler(_model, _opt)

        return self._optimizers