Respair's picture
Upload folder using huggingface_hub
b386992 verified
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import re
from abc import ABC, abstractmethod
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple
import lightning.pytorch as pl
import torch
import torch.nn as nn
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch.plugins.io.wrapper import _WrappingCheckpointIO
from lightning.pytorch.trainer.states import TrainerFn
from typing_extensions import override
from nemo.lightning.ckpt_utils import ADAPTER_META_FILENAME, HF_ADAPTER_CONFIG_FILENAME, HF_ADAPTER_PATH
from nemo.lightning.io.mixin import IOMixin
from nemo.lightning.io.pl import ckpt_to_dir, ckpt_to_weights_subdir
from nemo.lightning.megatron_parallel import MegatronParallel
from nemo.lightning.pytorch.callbacks.model_transform import ModelTransform
from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule
from nemo.lightning.pytorch.utils import get_automodel_from_trainer, is_trainer_attached
from nemo.utils import logging
from nemo.utils.callbacks.dist_ckpt_io import AsyncCompatibleCheckpointIO
if TYPE_CHECKING:
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
class PEFT(IOMixin, ABC, ModelTransform):
"""Abstract base class for Parameter-Efficient Fine-Tuning (PEFT) methods.
This class defines the interface for PEFT methods, which are used to fine-tune
large language models efficiently by modifying only a small subset of the model's
parameters.
Example:
class MyPEFT(PEFT):
def transform(self, module, name=None, prefix=None):
# Implement the transform logic
pass
peft = MyPEFT()
peft_model = LargeLanguageModel(model_transform=peft)
"""
@abstractmethod
def transform(self, module, name=None, prefix=None):
"""Transform a single module according to the PEFT method.
This method is called for each module in the model during the PEFT application process.
It should be implemented by subclasses to define how individual modules are transformed
for the specific PEFT technique.
Args:
module (nn.Module): The individual module to be transformed.
name (Optional[str]): The name of the module within the model structure. Defaults to None.
prefix (Optional[str]): A prefix to be added to the module name, typically used for
nested modules. Defaults to None.
Returns:
nn.Module: The transformed module. This can be the original module with modifications,
a new module replacing the original, or the original module if no
transformation is needed for this specific module.
Note:
This method is automatically called for each module in the model when the PEFT
instance is applied to the model using the __call__ method.
"""
raise NotImplementedError("The transform method should be implemented by subclasses.")
def __call__(self, model: nn.Module) -> nn.Module:
"""Apply the PEFT method to the entire model.
This method freezes the model parameters and walks through the model
structure, applying the transform method to each module.
Args:
model (nn.Module): The model to be fine-tuned.
Returns:
nn.Module: The transformed model with PEFT applied.
"""
self.freeze_model(model)
# walk model chunks
if isinstance(model, MegatronParallel) and len(model) > 1:
for model_chunk in model:
model_chunk.walk(self.transform)
elif isinstance(model, torch.nn.parallel.distributed.DistributedDataParallel):
model.module.walk(self.transform)
else:
model.walk(self.transform)
if is_trainer_attached(model) and model.trainer.state.fn != TrainerFn.FITTING:
self.freeze_model(model)
return model
def freeze_model(self, model: nn.Module) -> None:
"""Apply a default freeze method to the model.
This method freezes all the model parameters. This method can be overridden by subclasses to
implement custom freeze strategies (e.g. freeze only parts of the model)
Args:
model (nn.Module): The model to be fine-tuned.
Returns:
nn.Module: The transformed model with PEFT applied.
"""
if isinstance(model, MegatronParallel) and len(model) > 1:
for model_chunk in model:
model_chunk.freeze()
if isinstance(model, torch.nn.parallel.distributed.DistributedDataParallel):
model.module.freeze()
else:
model.freeze()
if is_trainer_attached(model) and model.trainer.state.fn == TrainerFn.FITTING:
model.train(mode=True)
def get_wrappped_io(self):
"""
This is a helper function to return a partial function that wraps the checkpoint I/O with the PEFT adapter.
Can be overridden in each PEFT method class.
"""
return partial(WrappedAdapterIO, peft=self)
def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str) -> None:
"""PTL callback setup function."""
from nemo.lightning.pytorch.strategies.utils import create_checkpoint_io
from nemo.lightning.pytorch.utils import get_automodel_from_trainer
super().setup(trainer, pl_module, stage=stage)
self._add_via_setattr = 'HFAutoModel' in type(pl_module).__name__
trainer.strategy.trainer = trainer
wrapped_io = self.get_wrappped_io()
# automodel_setup_optimizers is either None or holds a reference to trainer.strategy.setup_optimizers
self.automodel_setup_optimizers = None
# automodel adds adapters in configure_model
self.transform_already_applied = False
if get_automodel_from_trainer(trainer) is not None:
ckpt_io_kwargs = {"model_library": "huggingface", "lora": True}
# Due to the workaround used in peft restoration, it makes restoration non-PTL conforming,
# therefore need to short-circuit these two functions.
trainer._checkpoint_connector.restore_training_state = lambda: True
trainer._checkpoint_connector.restore_model = lambda: True
self.automodel_setup_optimizers = trainer.strategy.setup_optimizers
self.transform_already_applied = True
trainer.strategy.setup_optimizers = lambda x: True
else:
ckpt_io_kwarg_names = [
"save_ckpt_format",
"async_save",
"torch_dist_multiproc",
"assume_constant_structure",
"parallel_save",
"parallel_save_within_dp",
"parallel_load",
"load_directly_on_device",
]
ckpt_io_kwargs = {
arg: getattr(trainer.strategy, arg)
for arg in filter(lambda x: hasattr(trainer.strategy, x), ckpt_io_kwarg_names)
}
trainer.strategy._checkpoint_io = create_checkpoint_io(wrapping_ckpt_io=wrapped_io, **ckpt_io_kwargs)
self.wrapped_io = (
trainer.strategy._checkpoint_io._checkpoint_io
if getattr(trainer.strategy, 'async_save', False)
else trainer.strategy._checkpoint_io
)
trainer.strategy._init_model_parallel = False
# it will enter the following if statement if the model is on the automodel workflow
# where the PEFT is applied in the configure_model
if self.transform_already_applied:
trainer.strategy._init_model_parallel = True
trainer.strategy._setup_optimizers = False
def set_params_to_save(self, trainer: pl.Trainer) -> None:
"""
Set params to be saved for PEFT. This function is called in apply_transform.
Can be overridden in each PEFT method class.
"""
self.params_to_save = set(
name for name, param in trainer.lightning_module.named_parameters() if param.requires_grad
)
for module_name, module in trainer.lightning_module.named_modules():
if hasattr(module, "track_running_stats"):
for buffer_name, buffer in module.named_buffers():
if buffer is not None:
self.params_to_save.add(module_name + "." + buffer_name)
def apply_transform(self, trainer):
"""
This function does the following:
1. Apply PEFT model transform.
2. Set up model parallel and optimizer, which were skipped in setup
3. Load weights and optimizer state dict
4. Set up `finalize_model_grads` from mcore.
"""
# automodel adds adapters in configure_model
if not getattr(self, 'transform_already_applied', False):
super().apply_transform(trainer)
self.set_params_to_save(trainer)
# Handle automodel and return early.
if (
self.wrapped_io.adapter_ckpt_path is not None
and Path(self.wrapped_io.adapter_ckpt_path).parts[-1] == HF_ADAPTER_PATH
):
# Automodel adapter restoration is handled in restore_automodel.
return self.restore_automodel(trainer, self.wrapped_io.adapter_ckpt_path.parent)
elif getattr(self, 'transform_already_applied', False) == True or self.automodel_setup_optimizers is not None:
if self.automodel_setup_optimizers is not None:
logging.info("Setting up optimizers")
self.automodel_setup_optimizers(trainer)
self.automodel_setup_optimizers = None
return
adapter_sharded_state_dict = {}
if self.wrapped_io.adapter_ckpt_path is not None:
logging.info(f"Loading adapters from {self.wrapped_io.adapter_ckpt_path}")
# create sharded state dict for adapter weights only to enable PEFT resume
adapter_sharded_state_dict['state_dict'] = {
k: v for k, v in trainer.model.sharded_state_dict().items() if self.adapter_key_filter(k)
}
if hasattr(trainer.strategy, "init_model_parallel"):
logging.info("Initializing model parallel")
trainer.strategy.init_model_parallel()
if trainer.state.fn == TrainerFn.FITTING:
logging.info("Setting up optimizers")
trainer.strategy.setup_optimizers(trainer)
if self.wrapped_io.adapter_ckpt_path is not None and trainer.strategy.should_restore_optimizer_states():
# PEFT resume, load optimizer state
adapter_sharded_state_dict['optimizer'] = [
trainer.strategy.optimizer_sharded_state_dict(is_loading=True)
]
if adapter_sharded_state_dict:
adapter_state = self.wrapped_io.load_checkpoint(
self.wrapped_io.adapter_ckpt_path, sharded_state_dict=adapter_sharded_state_dict
)
trainer.strategy.load_model_state_dict(adapter_state, strict=False)
if trainer.state.fn == TrainerFn.FITTING:
# Load optimizer
trainer.strategy.load_optimizer_state_dict(adapter_state, selective_restore=False)
# Load lr scheduler
if (lr_schedulers := adapter_state.get('lr_schedulers', None)) is not None:
for config, lrs_state in zip(trainer.lr_scheduler_configs, lr_schedulers):
config.scheduler.load_state_dict(lrs_state)
for cb in trainer.callbacks[::-1]:
if isinstance(cb, MegatronOptimizerModule):
cb.on_fit_start(trainer, trainer.lightning_module)
break
else:
# i.e., this is an mcore model; elif not supported here.
if get_automodel_from_trainer(trainer) is None:
logging.warning(
"MegatronOptimizerModule not found in trainer callbacks. finalize_model_grads is not "
"properly set up for PEFT."
)
def restore_automodel(self, trainer, path):
"""restores automodel's adapter and optimizer state dict"""
def pop_fqn_prefix(fqn, prefix='model'):
"""helper function to remove first "model" from fqn"""
parts = fqn.split('.')
assert parts[0] == prefix
return '.'.join(parts[1:])
adapter_state = self.wrapped_io.load_checkpoint(path)
# Ensure all keys from adapter_state are contained in model
state_dict = trainer.lightning_module.state_dict()
for key in adapter_state['state_dict'].keys():
assert key in state_dict, (key, state_dict.keys())
# Move to cpu and load state
from nemo.lightning.pytorch.strategies.utils import to_cpu
trainer.strategy.load_model_state_dict(
{'state_dict': {pop_fqn_prefix(k): to_cpu(v) for k, v in adapter_state['state_dict'].items()}},
strict=False,
)
# Ensure adapters have grad enabled
for key, param in trainer.lightning_module.named_parameters():
param.requires_grad_(key in adapter_state['state_dict'])
if trainer.state.fn == TrainerFn.FITTING:
# Restore optim and LR Scheduler
assert self.automodel_setup_optimizers is not None, "Expected automodel_setup_optimizers to be valid"
self.automodel_setup_optimizers(trainer)
# Load optimizer
trainer.strategy.load_optimizer_state_dict(adapter_state)
# Load lr scheduler
if (lr_schedulers := adapter_state.get('lr_schedulers', None)) is not None:
for config, lrs_state in zip(trainer.lr_scheduler_configs, lr_schedulers):
config.scheduler.load_state_dict(lrs_state)
def adapter_key_filter(self, key: str) -> bool:
"""
Given a key in the state dict, return whether the key is an adapter (or base model).
This function can be subclassed in each PEFT method class.
"""
if isinstance(key, tuple):
return key[1].requires_grad
return key in self.params_to_save or ".adapter." in key or key.endswith(".adapters")
class AdapterWrapper(nn.Module):
"""Abstract base class for wrapping modules with adapters in Parameter-Efficient Fine-Tuning (PEFT).
This class wraps a module and its associated adapter, providing methods for
managing the state dictionaries of both the main module and the adapter. It does not
implement the forward method, which must be implemented by concrete subclasses.
Attributes:
to_wrap (nn.Module): The main module to be wrapped.
adapter (nn.Module): The adapter module to be applied.
Note:
This class is abstract and cannot be instantiated directly. Subclasses must
implement the forward method.
Example:
class LoRALinear(AdapterWrapper):
def __init__(self, to_wrap, adapter):
super().__init__(to_wrap, adapter)
def forward(self, x):
return self.to_wrap(x) + self.adapter(x)
main_module = nn.Linear(100, 100)
adapter = nn.Linear(100, 100)
parallel_adapter = LoRALinear(main_module, adapter)
"""
def __init__(self, to_wrap: nn.Module, adapter: nn.Module):
super(AdapterWrapper, self).__init__()
self.to_wrap = to_wrap
self.adapter = adapter
def base_linear_forward(self, x, *args, **kwargs):
"""
Run the forward method of the linear module `to_wrap`.
Return a tuple of three elements: linear_output, bias, layernorm_output
x -> [layernorm/identity] -> layernorm_output -> [linear] -> linear_output, bias
layernorm_output is different from input x only when linear layer is LayerNormColumnParallelLinear.
"""
linear_output = self.to_wrap(x, *args, **kwargs)
assert isinstance(
linear_output, tuple
), f"{self.to_wrap} should return a tuple but instead returns {linear_output}"
""" Four cases for the wrapped module's return values
1. nothing: (out, None)
2. return_bias: (out, bias)
2. return_layernorm_output: ((out, ln_out), None)
3. both: (out, bias, ln_out)
"""
bias = None
layernorm_output = x
if len(linear_output) == 2:
linear_output, bias = linear_output
if isinstance(linear_output, tuple) and len(linear_output) == 2:
linear_output, layernorm_output = linear_output
elif len(linear_output) == 3:
linear_output, bias, layernorm_output = linear_output
return linear_output, bias, layernorm_output
def state_dict(self, destination=None, prefix='', keep_vars=False):
"""Retrieve the state dictionary of the wrapped module and adapter.
This method overrides the default state_dict behavior to include both
the main module's state and the adapter's state under a special 'adapters' key.
Args:
destination (Optional[dict]): A dictionary to store the state. If None, a new
dictionary is created. Defaults to None.
prefix (str): A prefix added to parameter and buffer names. Defaults to ''.
keep_vars (bool): If True, returns variables instead of tensor values.
Defaults to False.
Returns:
dict: The state dictionary containing both the main module and adapter states.
"""
if destination is None:
destination = {}
# Get state dict of the main module
self.to_wrap.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
# Store adapter state dict under the "adapter" prefix in the destination dict
self.adapter.state_dict(destination=destination, prefix=f'{prefix}adapter.', keep_vars=keep_vars)
return destination
def sharded_state_dict(
self,
prefix: str = '',
sharded_offsets: Tuple[Tuple[int, int, int]] = (),
metadata: Optional[dict] = None,
) -> "ShardedStateDict":
"""Retrieve the sharded state dictionary of the wrapped module and adapter.
This method is used for distributed checkpointing, combining the sharded states
of both the main module and the adapter.
Args:
prefix (str): A prefix added to parameter and buffer names. Defaults to ''.
sharded_offsets (Tuple[Tuple[int, int, int]]): Offsets for sharded parameters.
Defaults to an empty tuple.
metadata (Optional[dict]): Additional metadata for the sharded state.
Defaults to None.
Returns:
ShardedStateDict: The combined sharded state dictionary.
"""
sharded_state_dict = {}
sharded_state_dict.update(self.to_wrap.sharded_state_dict(prefix, sharded_offsets, metadata))
sharded_state_dict.update(self.adapter.sharded_state_dict(f"{prefix}adapter.", sharded_offsets, metadata))
return sharded_state_dict
class WrappedAdapterIO(_WrappingCheckpointIO, AsyncCompatibleCheckpointIO): # noqa: F821
"""
A wrapper class for checkpoint I/O operations, specifically designed for PEFT (Parameter-Efficient Fine-Tuning).
This class handles the complexities of saving and loading checkpoints for both initial PEFT training and resuming
PEFT training. It ensures that only the necessary adapter weights are saved and loaded, while also preserving the
base model weights.
**Usage:**
1. **Initial PEFT Training:**
- The class handles the saving of only adapter weights.
- Metadata about the base model checkpoint is stored for future reference.
2. **PEFT Resume:**
- The class loads both base model and adapter weights.
- The previously stored metadata is used to locate the correct base model checkpoint.
**Attributes:**
- `peft`: The PEFT instance associated with the wrapped checkpoint I/O.
- `model_ckpt_path`: The path to the base model checkpoint.
- `adapter_ckpt_path`: The path to the adapter checkpoint.
Note that the paths are set by save/load functions and users do not need to set them.
**Methods:**
- `save_checkpoint`: Saves the adapter weights and metadata to the specified path.
- `load_checkpoint`: Loads the base model and adapter weights based on the specified path and metadata.
"""
peft: Optional[PEFT] = None
model_ckpt_path: Optional[Path] = None
adapter_ckpt_path: Optional[Path] = None
def __init__(
self, checkpoint_io: Optional["CheckpointIO"] = None, peft: Optional[PEFT] = None # noqa: F821
) -> None:
self.peft = peft
super().__init__(checkpoint_io)
@override
def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
assert self.checkpoint_io is not None
state_key = None
for k in ['sharded_state_dict', 'state_dict']:
if k in checkpoint:
state_key = k
break
assert state_key is not None, "Expected checkpoint to contain `sharded_state_dict` or `state_dict`"
assert state_key in checkpoint, "Expected state_key to be in checkpoint"
state_dict = checkpoint.pop(state_key)
checkpoint[state_key] = dict(filter(lambda item: self.peft.adapter_key_filter(item[0]), state_dict.items()))
ckpt_keys = list(checkpoint[state_key].keys())
request = self.checkpoint_io.save_checkpoint(checkpoint, path, storage_options=storage_options)
from nemo.utils.get_rank import is_global_rank_zero
if is_global_rank_zero():
base_dir = ckpt_to_weights_subdir(path, is_saving=True)
from nemo.lightning.io.hf import HFCheckpointIO
if isinstance(self.checkpoint_io, HFCheckpointIO):
metadata = self._create_lora_hf_config(ckpt_keys)
hf_adapter_base = base_dir.parent / HF_ADAPTER_PATH
hf_adapter_base.mkdir(parents=True, exist_ok=True)
adapter_meta_path = hf_adapter_base / HF_ADAPTER_CONFIG_FILENAME
else:
base_dir.mkdir(parents=True, exist_ok=True)
metadata = {"model_ckpt_path": str(self.model_ckpt_path)}
adapter_meta_path = base_dir / ADAPTER_META_FILENAME
with open(adapter_meta_path, "w") as f:
json.dump(metadata, f)
return request
def _create_lora_hf_config(self, ckpt_keys):
"""Creates a HF lora config from a NeMo Lora config"""
def extract_matched_module_names(ckpt_keys, target_modules):
"""
Extracts module names from a list of checkpoint keys that match the target modules.
This function processes a list of target module patterns, where each pattern may or may
not contain a wildcard (`'*'`). The function matches these patterns against the
checkpoint keys, with the following behavior:
- Patterns containing '*' will be expanded to match any sequence of characters
except a dot (`.`).
- Patterns without '*' are matched literally.
Args:
ckpt_keys (list of str): A list of strings representing checkpoint keys to be
searched.
target_modules (list of str): A list of target module patterns. Some patterns may
contain wildcards (`'*'`), which match any characters except a dot.
Returns:
list of str: A list of module names from `target_modules` that match any of the
`ckpt_keys`. The result is returned as a list of unique module names.
Example:
ckpt_keys = [
"model.model.layers.27.self_attn.k_proj",
"model.model.layers.27.self_attn.v_proj",
"model.model.layers.27.self_attn.mlp"
]
target_modules = ["*proj"]
extract_matched_module_names(ckpt_keys, target_modules)
# Output: ['k_proj', 'v_proj']
Notes:
- This function uses regular expressions to match the target patterns in the
checkpoint keys.
- Wildcards are expanded as `[^.]+` to ensure that the match doesn't cross dot
(`.`) boundaries.
"""
re_target_modules = list(filter(lambda x: '*' in x, target_modules))
if len(re_target_modules) == 0:
return target_modules
non_re_target_modules = list(filter(lambda x: not '*' in x, target_modules))
combined_pattern = '|'.join(
map(lambda x: x.replace('*', '[^.]+'), re_target_modules),
)
ans = set(non_re_target_modules)
for key in ckpt_keys:
ans.update(re.findall(combined_pattern, key))
return list(ans)
from peft import LoraConfig
from nemo.collections.llm.peft import DoRA
# Contains all target module names, without any regular expression
materialized_module_names = extract_matched_module_names(ckpt_keys, self.peft.target_modules)
lora_config = LoraConfig(
r=self.peft.dim,
target_modules=materialized_module_names,
lora_alpha=self.peft.alpha,
lora_dropout=self.peft.dropout,
use_dora=isinstance(self.peft, DoRA),
)
lora_config = lora_config.to_dict()
lora_config["peft_type"] = "LORA"
lora_config["megatron_core"] = None
lora_config["target_modules"] = materialized_module_names
return lora_config
@override
def load_checkpoint(
self,
path: _PATH,
sharded_state_dict=None,
map_location: Optional[Callable] = None,
strict: Optional['StrictHandling'] | bool = None, # noqa: F821
) -> Dict[str, Any]:
"""
=====================
Initial PEFT Training
=====================
Initial PEFT training requires loading the base model weights. In this case, this function is called by
trainer.strategy.setup() -> megatron_strategy.restore_model() -> megatron_strategy.load_checkpoint().
`path = PosixPath(<base_path>)`, and sharded_state_dict contains only base model weights
===========
PEFT Resume
===========
PEFT resume requires loading two set of model weights, 1) base model weights and 2) adapter weights
Base model weights could be imported from e.g. HF, and is frozen during PEFT training.
Adapter weights contains the training metadata that will need to be loaded.
As such, this function will be entered twice during PEFT training resume.
For the FIRST TIME this function is called by trainer._checkpoint_connector._restore_modules_and_callbacks.
`path = AdapterPath(<adapter_path>, base_model_path=<base_path>)`, and sharded_state_dict contains only base
model weights
For the SECOND TIME this function is called by PEFT.apply_transform (above, in the same file).
`path = PosixPath(<adapter_path>)`, and sharded_state_dict contains only adapter weights.
"""
assert self.checkpoint_io is not None
adapter_ckpt = None
base = ckpt_to_dir(path)
if getattr(path, "base_model_path", None):
# PEFT Resume, FIRST TIME
self.adapter_ckpt_path = Path(str(path))
adapter_ckpt = self.checkpoint_io.load_checkpoint(path, sharded_state_dict={}) # Loads only metadata
# path is adapter path to restore the training metadata, but switch to loading base model here.
path = self.model_ckpt_path = path.base_model_path
elif (adapter_meta_path := base / ADAPTER_META_FILENAME).exists():
# PEFT Resume, SECOND TIME
with open(adapter_meta_path, "r") as f:
metadata = json.load(f)
self.model_ckpt_path = Path(metadata['model_ckpt_path'])
self.adapter_ckpt_path = path
elif (base / HF_ADAPTER_PATH / HF_ADAPTER_CONFIG_FILENAME).exists():
self.adapter_ckpt_path = path / HF_ADAPTER_PATH
else:
# Initial PEFT Training
self.model_ckpt_path = path
# Note: this will include the Trainer-state of the model-checkpoint
model_ckpt = self.checkpoint_io.load_checkpoint(path, sharded_state_dict, map_location, strict)
if adapter_ckpt is not None:
# PEFT Resume, FIRST TIME
adapter_ckpt['state_dict'].update(model_ckpt['state_dict'])
return adapter_ckpt
return model_ckpt