Respair's picture
Upload folder using huggingface_hub
b386992 verified
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, List, Optional
import lightning.pytorch as pl
import lightning.pytorch as L
from torch.optim import Optimizer
from torch.optim.optimizer import ParamsT
from nemo.lightning.megatron_parallel import MegatronParallel
from nemo.lightning.pytorch.optim.base import LRSchedulerModule, OptimizerModule
def _param_does_not_have_wd(param_name, param):
return 'bias' in param_name
def _extract_model_params_for_optim(model, weight_decay=0, no_weight_decay_cond=None):
params_with_wd, params_without_wd = [], []
if no_weight_decay_cond is not None:
for name, param in model.named_parameters():
if not param.requires_grad:
continue
if no_weight_decay_cond(name, param):
params_without_wd.append(param)
else:
params_with_wd.append(param)
else:
params_with_wd = list(filter(lambda x: x.requires_grad, model.parameters()))
assert max(map(len, (params_with_wd, params_without_wd))) > 0, "Expected at least one optimizer with params"
return [
{'params': params, 'weight_decay': wd}
for params, wd in zip((params_with_wd, params_without_wd), (weight_decay, 0))
]
class PytorchOptimizerModule(OptimizerModule):
    """An OptimizerModule for PyTorch optimizers.

    Attributes:
        optimizer_fn (Callable[[ParamsT], Optimizer]): Factory that is called with the parameter
            groups and returns the optimizer.
        no_weight_decay_cond (Optional[Callable]): Predicate called as ``cond(name, param)``;
            parameters for which it is truthy are placed in a group with ``weight_decay`` 0.
        scale_lr_cond (Optional[Callable]): Condition for scaling learning rate. Stored on the
            instance; not read by any method of this class.
        lr_mult (float): Learning rate multiplier. Stored on the instance; not read by any
            method of this class.

    Example::

        optimizer_fn = run.Partial(
            SGD,
            lr=lr,
            weight_decay=wd,
        )
        lr_scheduler = MyLRSchedulerModule(...)
        optimizer_module = PytorchOptimizerModule(optimizer_fn, lr_scheduler)

    Methods:
        on_fit_start(trainer, pl_module): Intentional no-op.
        optimizers(model): Defines the optimizers.
        connect(model): Installs ``configure_optimizers`` on the model.
    """

    def __init__(
        self,
        optimizer_fn: Callable[[ParamsT], Optimizer],
        lr_scheduler: Optional[LRSchedulerModule] = None,
        no_weight_decay_cond: Optional[Callable] = _param_does_not_have_wd,
        scale_lr_cond: Optional[Callable] = None,
        lr_mult: float = 1.0,
    ):
        """Initializes the PytorchOptimizerModule.

        Args:
            optimizer_fn (Callable[[ParamsT], Optimizer]): Factory that is called with the
                parameter groups and returns the optimizer.
            lr_scheduler (Optional[LRSchedulerModule]): The learning rate scheduler module.
            no_weight_decay_cond (Optional[Callable]): Condition for no weight decay. Defaults
                to ``_param_does_not_have_wd``; pass None to apply weight decay to every
                trainable parameter.
            scale_lr_cond (Optional[Callable]): Condition for scaling learning rate.
            lr_mult (float): Learning rate multiplier.
        """
        super().__init__(lr_scheduler=lr_scheduler)
        self.optimizer_fn = optimizer_fn
        self.no_weight_decay_cond = no_weight_decay_cond
        self.scale_lr_cond = scale_lr_cond
        self.lr_mult = lr_mult

    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
        """Intentional no-op: this module does nothing when fitting starts.

        NOTE(review): presumably overrides a hook inherited from ``OptimizerModule`` in order
        to disable it; confirm against the base class.

        Args:
            trainer (pl.Trainer): Unused.
            pl_module (pl.LightningModule): Unused.
        """

    def optimizers(self, model) -> List[Optimizer]:
        """Defines the optimizers.

        The weight decay written into the parameter groups is taken from the ``weight_decay``
        keyword bound on ``optimizer_fn`` (for example via ``functools.partial``). If
        ``optimizer_fn`` exposes no ``keywords`` mapping, or does not bind ``weight_decay``,
        0 is used.

        Args:
            model (nn.Module): The model for which the optimizers are being defined.

        Returns:
            List[Optimizer]: The list of optimizers when no ``lr_scheduler`` is set. When a
            scheduler is set, the list holds whatever ``lr_scheduler.scheduler(model, opt)``
            returns for each optimizer instead.

        Raises:
            ValueError: If the model is an instance of MegatronParallel.
        """
        if isinstance(model, MegatronParallel):
            raise ValueError("Model cannot be an instance of MegatronParallel")

        # ``optimizer_fn`` is typed as an arbitrary callable, but only partial-like objects
        # expose ``keywords``; look it up defensively so a plain function or lambda works too.
        bound_kwargs = getattr(self.optimizer_fn, 'keywords', None) or {}
        wd = bound_kwargs.get('weight_decay', 0)

        param_groups = _extract_model_params_for_optim(model, wd, self.no_weight_decay_cond)
        optim = self.optimizer_fn(param_groups)
        # Stored exactly as returned by the factory (before the list normalisation below).
        self._optimizers = optim

        if not isinstance(optim, list):
            optim = [optim]

        if self.lr_scheduler is None:
            return optim
        return [self.lr_scheduler.scheduler(model, opt) for opt in optim]

    def connect(self, model: L.LightningModule) -> None:
        """Connects the optimizer module to the model.

        Replaces ``model.configure_optimizers`` with a zero-argument callable that returns
        ``self.optimizers(model)``.

        Args:
            model (L.LightningModule): The model to which the optimizer module is being connected.
        """
        model.configure_optimizers = lambda: self.optimizers(model)