from unittest.mock import MagicMock, call, patch

import torch
import torch.nn as nn
from lightning.pytorch.trainer.states import TrainerFn

from nemo.collections.llm import fn
from nemo.lightning.pytorch.callbacks.peft import PEFT, WrappedAdapterIO
from nemo.utils.callbacks.dist_ckpt_io import AsyncFinalizableCheckpointIO


class TestPEFT:
    class DummyPEFT(PEFT):
        def transform(self, module, name=None, prefix=None):
            return module

        def freeze_model(self, module):
            super().freeze_model(module)
            self.is_called = True
            return module

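    # Minimal model mixing in fn.FNMixin so PEFT can freeze and walk its submodules.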
    class DummyModel(nn.Module, fn.FNMixin):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(10, 10)
            self.conv = nn.Conv2d(3, 3, 3)

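    # Applying PEFT should invoke freeze_model() and freeze all base weights.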
    def test_peft_call(self):
        model = self.DummyModel()
        peft = self.DummyPEFT()

        transformed_model = peft(model)

        assert (
            hasattr(peft, "is_called") and peft.is_called is True
        ), "PEFT subclasses may override `freeze_model()`, so it must be called"
        assert not transformed_model.linear.weight.requires_grad
        assert not transformed_model.conv.weight.requires_grad

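    # LinearAdapter must preserve the wrapped layer's parameters (including the
    # bias, when present) and add only the LoRA weights on top.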
    def test_linear_adapter(self):
        from nemo.collections.llm.peft.lora import LinearAdapter

        for has_bias in [True, False]:
            linear = nn.Linear(10, 10, bias=has_bias)
            linear_adapter = LinearAdapter(linear)
            bias_in_state_dict = 'bias' in linear.state_dict()
            if has_bias:
                assert bias_in_state_dict
            else:
                assert not bias_in_state_dict

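            # The original parameters must appear unchanged in the adapter's state dict.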
            for key, val in linear.state_dict().items():
                assert key in linear_adapter.state_dict(), f"Key {key} not found in LinearAdapter"
                assert torch.equal(val, linear_adapter.state_dict()[key]), f"Key {key} diff. val in LinearAdapter"

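            # Any keys beyond the original ones must be exactly the LoRA weights.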
            for key, val in linear_adapter.state_dict().items():
                if key in linear.state_dict():
                    continue
                assert key in ['lora_a.weight', 'lora_b.weight']

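    # patch_linear_module patches an existing nn.Linear in place, so the original
    # weights must survive the patch and only lora_a/lora_b may be added.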
    def test_linear_adapter_monkey_patch(self):
        from copy import deepcopy

        from nemo.collections.llm.peft.lora import patch_linear_module

        linear = nn.Linear(10, 10)
        state_init = deepcopy(linear.state_dict())
        linear_adapter = patch_linear_module(linear)

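        # The original parameters must appear unchanged in the patched module.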
        for key, val in state_init.items():
            assert key in linear_adapter.state_dict(), f"Key {key} not found in LinearAdapter"
            assert torch.equal(val, linear_adapter.state_dict()[key]), f"Key {key} diff. val in LinearAdapter"

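        # Any keys beyond the original ones must be exactly the LoRA weights.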
        for key, val in linear_adapter.state_dict().items():
            if key in state_init:
                continue
            assert key in ['lora_a.weight', 'lora_b.weight']

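        # The LoRA weights must be exposed as module attributes and be trainable.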
        state_dict = linear_adapter.state_dict()
        for key in ['lora_a', 'lora_b']:
            assert hasattr(linear_adapter, key), f"Expected {key} to be in module"
            assert f'{key}.weight' in state_dict, f"Expected {key} to be in state dict"
            assert getattr(linear_adapter, key).weight.requires_grad, f"Expected {key} to require_grad"

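    # setup() should wrap the trainer's checkpoint IO so adapter weights can be
    # saved and loaded separately from the base model.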
    def test_peft_setup(self):
        peft = self.DummyPEFT()
        trainer = MagicMock()
        pl_module = MagicMock()

        pl_module.model_transform = peft
        peft.setup(trainer, pl_module, "fit")

        assert isinstance(trainer.strategy._checkpoint_io, AsyncFinalizableCheckpointIO)
        assert isinstance(trainer.strategy._checkpoint_io._checkpoint_io, WrappedAdapterIO)
        assert peft.model_transform is not None
        assert peft._needs_to_call is True

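    # When an adapter checkpoint exists, on_train_epoch_start() must load it,
    # initialize model parallelism, and set up the optimizers.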
    @patch('nemo.lightning.pytorch.callbacks.peft.logging')
    def test_peft_on_train_epoch_start_with_adapter(self, mock_logging):
        peft = self.DummyPEFT()
        trainer = MagicMock()
        pl_module = MagicMock()
        pl_module.model_transform = peft
        trainer.state.fn = TrainerFn.FITTING

        peft.setup(trainer, pl_module, "fit")

        assert peft.model_transform is not None
        assert peft._needs_to_call is True

        peft.wrapped_io = MagicMock()
        peft.wrapped_io.adapter_ckpt_path = "dummy_path"
        peft.wrapped_io.load_checkpoint.return_value = {"dummy_state": "dummy_value"}
        peft.on_train_epoch_start(trainer, pl_module)

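        # Check that all expected messages were logged.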
        mock_logging.info.assert_has_calls(
            [
                call("Loading adapters from dummy_path"),
                call("Initializing model parallel"),
                call("Setting up optimizers"),
            ],
            any_order=True,
        )

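        # logging.info should have been called exactly three times.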
        assert mock_logging.info.call_count == 3

        trainer.strategy.load_model_state_dict.assert_called_once_with({"dummy_state": "dummy_value"}, strict=False)
        trainer.strategy.init_model_parallel.assert_called_once()
        trainer.strategy.setup_optimizers.assert_called_once_with(trainer)

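    # set_params_to_save() should record exactly the parameters that require gradients.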
    def test_params_to_save(self):
        model = self.DummyModel()
        peft = self.DummyPEFT()
        trainer = MagicMock()
        trainer.lightning_module = model

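        # Freeze the conv layer and keep the linear layer trainable.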
        model.conv.requires_grad_(False)
        model.linear.requires_grad_(True)

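        # Collect the names of the trainable parameters.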
        peft.set_params_to_save(trainer)

        expected_trainable = {'linear.weight', 'linear.bias'}

        assert hasattr(peft, 'params_to_save'), "params_to_save not set"
        assert (
            peft.params_to_save == expected_trainable
        ), f"Expected trainable params {expected_trainable}, but got {peft.params_to_save}"

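        # The requires_grad flags must agree with params_to_save.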
        for name, param in model.named_parameters():
            if name in peft.params_to_save:
                assert param.requires_grad, f"Parameter {name} should require gradients"
            else:
                assert not param.requires_grad, f"Parameter {name} should not require gradients"

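    # BatchNorm running statistics are buffers rather than parameters, but they are
    # updated during training, so they must still be included in params_to_save.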
    def test_params_to_save_batchnorm(self):
        model = self.DummyModel()
        model.bn = nn.BatchNorm2d(8)
        peft = self.DummyPEFT()
        trainer = MagicMock()
        trainer.lightning_module = model

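        # Freeze all parameters; buffers are unaffected by freezing.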
        model.freeze()

        peft.set_params_to_save(trainer)

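        # Only the BatchNorm buffers should be collected.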
        expected_trainable = {
            'bn.running_mean',
            'bn.running_var',
            'bn.num_batches_tracked',
        }

        assert hasattr(peft, 'params_to_save'), "params_to_save not set"
        assert (
            peft.params_to_save == expected_trainable
        ), f"Expected trainable params {expected_trainable}, but got {peft.params_to_save}"

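        # No parameter should require gradients; the saved names are all buffers.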
        for name, param in model.named_parameters():
            if name in peft.params_to_save:
                assert param.requires_grad, f"Parameter {name} should require gradients"
            else:
                assert not param.requires_grad, f"Parameter {name} should not require gradients"