import json
import re
from abc import ABC, abstractmethod
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple

import lightning.pytorch as pl
import torch
import torch.nn as nn
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch.plugins.io.wrapper import _WrappingCheckpointIO
from lightning.pytorch.trainer.states import TrainerFn
from typing_extensions import override

from nemo.lightning.ckpt_utils import ADAPTER_META_FILENAME, HF_ADAPTER_CONFIG_FILENAME, HF_ADAPTER_PATH
from nemo.lightning.io.mixin import IOMixin
from nemo.lightning.io.pl import ckpt_to_dir, ckpt_to_weights_subdir
from nemo.lightning.megatron_parallel import MegatronParallel
from nemo.lightning.pytorch.callbacks.model_transform import ModelTransform
from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule
from nemo.lightning.pytorch.utils import get_automodel_from_trainer, is_trainer_attached
from nemo.utils import logging
from nemo.utils.callbacks.dist_ckpt_io import AsyncCompatibleCheckpointIO

if TYPE_CHECKING:
    from lightning.pytorch.plugins.io import CheckpointIO
    from megatron.core.dist_checkpointing.mapping import ShardedStateDict
    from megatron.core.dist_checkpointing.validation import StrictHandling


class PEFT(IOMixin, ABC, ModelTransform):
    """Abstract base class for Parameter-Efficient Fine-Tuning (PEFT) methods.

    This class defines the interface for PEFT methods, which are used to fine-tune
    large language models efficiently by modifying only a small subset of the model's
    parameters.

    Example:
        class MyPEFT(PEFT):
            def transform(self, module, name=None, prefix=None):
                # Implement the transform logic
                pass

        peft = MyPEFT()
        peft_model = LargeLanguageModel(model_transform=peft)
    """

    @abstractmethod
    def transform(self, module, name=None, prefix=None):
        """Transform a single module according to the PEFT method.

        This method is called for each module in the model during the PEFT application process.
        It should be implemented by subclasses to define how individual modules are transformed
        for the specific PEFT technique.

        Args:
            module (nn.Module): The individual module to be transformed.
            name (Optional[str]): The name of the module within the model structure. Defaults to None.
            prefix (Optional[str]): A prefix to be added to the module name, typically used for
                nested modules. Defaults to None.

        Returns:
            nn.Module: The transformed module. This can be the original module with modifications,
                a new module replacing the original, or the original module if no
                transformation is needed for this specific module.

        Note:
            This method is automatically called for each module in the model when the PEFT
            instance is applied to the model using the __call__ method.
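
        Example:
            A minimal sketch of a ``transform`` implementation. ``LoRALinear`` refers to the
            wrapper sketched in the ``AdapterWrapper`` docstring below, and the rank of 8 is
            an arbitrary illustrative choice:

                def transform(self, module, name=None, prefix=None):
                    if isinstance(module, nn.Linear):
                        # Low-rank bottleneck adapter: down-project, then up-project.
                        adapter = nn.Sequential(
                            nn.Linear(module.in_features, 8, bias=False),
                            nn.Linear(8, module.out_features, bias=False),
                        )
                        return LoRALinear(module, adapter)
                    # Leave all other modules untouched.
                    return module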
| """ |
| raise NotImplementedError("The transform method should be implemented by subclasses.") |
|
|
| def __call__(self, model: nn.Module) -> nn.Module: |
| """Apply the PEFT method to the entire model. |
| |
| This method freezes the model parameters and walks through the model |
| structure, applying the transform method to each module. |
| |
| Args: |
| model (nn.Module): The model to be fine-tuned. |
| |
| Returns: |
| nn.Module: The transformed model with PEFT applied. |
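
        Example:
            A usage sketch; ``MyPEFT`` is the hypothetical subclass from the class docstring
            and ``model`` is any ``nn.Module``:

                peft = MyPEFT()
                model = peft(model)
                # All base parameters are now frozen; only the parameters introduced by
                # ``transform`` (the adapters) remain trainable.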
| """ |
| self.freeze_model(model) |
|
|
| |
| if isinstance(model, MegatronParallel) and len(model) > 1: |
| for model_chunk in model: |
| model_chunk.walk(self.transform) |
| elif isinstance(model, torch.nn.parallel.distributed.DistributedDataParallel): |
| model.module.walk(self.transform) |
| else: |
| model.walk(self.transform) |
|
|
| if is_trainer_attached(model) and model.trainer.state.fn != TrainerFn.FITTING: |
| self.freeze_model(model) |
| return model |
|
|
| def freeze_model(self, model: nn.Module) -> None: |
| """Apply a default freeze method to the model. |
| |
| This method freezes all the model parameters. This method can be overridden by subclasses to |
| implement custom freeze strategies (e.g. freeze only parts of the model) |
| |
| Args: |
| model (nn.Module): The model to be fine-tuned. |
| |
| Returns: |
| nn.Module: The transformed model with PEFT applied. |
| """ |
| if isinstance(model, MegatronParallel) and len(model) > 1: |
| for model_chunk in model: |
| model_chunk.freeze() |
| if isinstance(model, torch.nn.parallel.distributed.DistributedDataParallel): |
| model.module.freeze() |
| else: |
| model.freeze() |
| if is_trainer_attached(model) and model.trainer.state.fn == TrainerFn.FITTING: |
| model.train(mode=True) |
|
|
| def get_wrappped_io(self): |
| """ |
| This is a helper function to return a partial function that wraps the checkpoint I/O with the PEFT adapter. |
| Can be overridden in each PEFT method class. |
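
        Example:
            A sketch of how the returned partial is typically used. In practice this wiring
            is done by ``setup`` via ``create_checkpoint_io``; ``base_io`` below is an
            assumed pre-existing ``CheckpointIO`` instance:

                io_cls = peft.get_wrappped_io()
                wrapped_io = io_cls(base_io)  # equivalent to WrappedAdapterIO(base_io, peft=peft)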
| """ |
| return partial(WrappedAdapterIO, peft=self) |
|
|
| def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str) -> None: |
| """PTL callback setup function.""" |
| from nemo.lightning.pytorch.strategies.utils import create_checkpoint_io |
| from nemo.lightning.pytorch.utils import get_automodel_from_trainer |
|
|
| super().setup(trainer, pl_module, stage=stage) |
|
|
| self._add_via_setattr = 'HFAutoModel' in type(pl_module).__name__ |
| trainer.strategy.trainer = trainer |
| wrapped_io = self.get_wrappped_io() |
|
|
| |
| self.automodel_setup_optimizers = None |
| |
| self.transform_already_applied = False |
| if get_automodel_from_trainer(trainer) is not None: |
| ckpt_io_kwargs = {"model_library": "huggingface", "lora": True} |
| |
| |
| trainer._checkpoint_connector.restore_training_state = lambda: True |
| trainer._checkpoint_connector.restore_model = lambda: True |
| self.automodel_setup_optimizers = trainer.strategy.setup_optimizers |
| self.transform_already_applied = True |
| trainer.strategy.setup_optimizers = lambda x: True |
| else: |
| ckpt_io_kwarg_names = [ |
| "save_ckpt_format", |
| "async_save", |
| "torch_dist_multiproc", |
| "assume_constant_structure", |
| "parallel_save", |
| "parallel_save_within_dp", |
| "parallel_load", |
| "load_directly_on_device", |
| ] |
| ckpt_io_kwargs = { |
| arg: getattr(trainer.strategy, arg) |
| for arg in filter(lambda x: hasattr(trainer.strategy, x), ckpt_io_kwarg_names) |
| } |
| trainer.strategy._checkpoint_io = create_checkpoint_io(wrapping_ckpt_io=wrapped_io, **ckpt_io_kwargs) |
| self.wrapped_io = ( |
| trainer.strategy._checkpoint_io._checkpoint_io |
| if getattr(trainer.strategy, 'async_save', False) |
| else trainer.strategy._checkpoint_io |
| ) |
| trainer.strategy._init_model_parallel = False |
| |
| |
| if self.transform_already_applied: |
| trainer.strategy._init_model_parallel = True |
| trainer.strategy._setup_optimizers = False |
|
|
| def set_params_to_save(self, trainer: pl.Trainer) -> None: |
| """ |
| Set params to be saved for PEFT. This function is called in apply_transform. |
| Can be overridden in each PEFT method class. |
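
        Example:
            A sketch of the typical outcome (the parameter names are illustrative):

                peft.set_params_to_save(trainer)
                # peft.params_to_save now holds the names of all trainable (adapter)
                # parameters, plus buffers of modules that track running statistics,
                # e.g. {"...adapter.linear_in.weight", "...adapter.linear_out.weight", ...}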
| """ |
| self.params_to_save = set( |
| name for name, param in trainer.lightning_module.named_parameters() if param.requires_grad |
| ) |
| for module_name, module in trainer.lightning_module.named_modules(): |
| if hasattr(module, "track_running_stats"): |
| for buffer_name, buffer in module.named_buffers(): |
| if buffer is not None: |
| self.params_to_save.add(module_name + "." + buffer_name) |
|
|
| def apply_transform(self, trainer): |
| """ |
| This function does the following: |
| 1. Apply PEFT model transform. |
| 2. Set up model parallel and optimizer, which were skipped in setup |
| 3. Load weights and optimizer state dict |
| 4. Set up `finalize_model_grads` from mcore. |
| """ |
| |
| if not getattr(self, 'transform_already_applied', False): |
| super().apply_transform(trainer) |
| self.set_params_to_save(trainer) |
|
|
| |
| if ( |
| self.wrapped_io.adapter_ckpt_path is not None |
| and Path(self.wrapped_io.adapter_ckpt_path).parts[-1] == HF_ADAPTER_PATH |
| ): |
| |
| return self.restore_automodel(trainer, self.wrapped_io.adapter_ckpt_path.parent) |
| elif getattr(self, 'transform_already_applied', False) == True or self.automodel_setup_optimizers is not None: |
| if self.automodel_setup_optimizers is not None: |
| logging.info("Setting up optimizers") |
| self.automodel_setup_optimizers(trainer) |
| self.automodel_setup_optimizers = None |
| return |
|
|
| adapter_sharded_state_dict = {} |
| if self.wrapped_io.adapter_ckpt_path is not None: |
| logging.info(f"Loading adapters from {self.wrapped_io.adapter_ckpt_path}") |
| |
| adapter_sharded_state_dict['state_dict'] = { |
| k: v for k, v in trainer.model.sharded_state_dict().items() if self.adapter_key_filter(k) |
| } |
|
|
| if hasattr(trainer.strategy, "init_model_parallel"): |
| logging.info("Initializing model parallel") |
| trainer.strategy.init_model_parallel() |
|
|
| if trainer.state.fn == TrainerFn.FITTING: |
| logging.info("Setting up optimizers") |
| trainer.strategy.setup_optimizers(trainer) |
| if self.wrapped_io.adapter_ckpt_path is not None and trainer.strategy.should_restore_optimizer_states(): |
| |
| adapter_sharded_state_dict['optimizer'] = [ |
| trainer.strategy.optimizer_sharded_state_dict(is_loading=True) |
| ] |
|
|
| if adapter_sharded_state_dict: |
| adapter_state = self.wrapped_io.load_checkpoint( |
| self.wrapped_io.adapter_ckpt_path, sharded_state_dict=adapter_sharded_state_dict |
| ) |
| trainer.strategy.load_model_state_dict(adapter_state, strict=False) |
| if trainer.state.fn == TrainerFn.FITTING: |
| |
| trainer.strategy.load_optimizer_state_dict(adapter_state, selective_restore=False) |
| |
| if (lr_schedulers := adapter_state.get('lr_schedulers', None)) is not None: |
| for config, lrs_state in zip(trainer.lr_scheduler_configs, lr_schedulers): |
| config.scheduler.load_state_dict(lrs_state) |
|
|
| for cb in trainer.callbacks[::-1]: |
| if isinstance(cb, MegatronOptimizerModule): |
| cb.on_fit_start(trainer, trainer.lightning_module) |
| break |
| else: |
| |
| if get_automodel_from_trainer(trainer) is None: |
| logging.warning( |
| "MegatronOptimizerModule not found in trainer callbacks. finalize_model_grads is not " |
| "properly set up for PEFT." |
| ) |
|
|
| def restore_automodel(self, trainer, path): |
| """restores automodel's adapter and optimizer state dict""" |
|
|
| def pop_fqn_prefix(fqn, prefix='model'): |
| """helper function to remove first "model" from fqn""" |
| parts = fqn.split('.') |
| assert parts[0] == prefix |
| return '.'.join(parts[1:]) |
|
|
| adapter_state = self.wrapped_io.load_checkpoint(path) |
| |
| state_dict = trainer.lightning_module.state_dict() |
| for key in adapter_state['state_dict'].keys(): |
| assert key in state_dict, (key, state_dict.keys()) |
|
|
| |
| from nemo.lightning.pytorch.strategies.utils import to_cpu |
|
|
| trainer.strategy.load_model_state_dict( |
| {'state_dict': {pop_fqn_prefix(k): to_cpu(v) for k, v in adapter_state['state_dict'].items()}}, |
| strict=False, |
| ) |
|
|
| |
| for key, param in trainer.lightning_module.named_parameters(): |
| param.requires_grad_(key in adapter_state['state_dict']) |
|
|
| if trainer.state.fn == TrainerFn.FITTING: |
| |
| assert self.automodel_setup_optimizers is not None, "Expected automodel_setup_optimizers to be valid" |
| self.automodel_setup_optimizers(trainer) |
| |
| trainer.strategy.load_optimizer_state_dict(adapter_state) |
| |
| if (lr_schedulers := adapter_state.get('lr_schedulers', None)) is not None: |
| for config, lrs_state in zip(trainer.lr_scheduler_configs, lr_schedulers): |
| config.scheduler.load_state_dict(lrs_state) |
|
|
| def adapter_key_filter(self, key: str) -> bool: |
| """ |
| Given a key in the state dict, return whether the key is an adapter (or base model). |
| This function can be subclassed in each PEFT method class. |
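
        Example:
            A sketch with illustrative key names (the exact FQNs depend on the model):

                peft.adapter_key_filter("decoder.layers.0.mlp.linear_fc1.adapter.linear_in.weight")
                # -> True (contains ".adapter.")
                peft.adapter_key_filter("decoder.layers.0.mlp.linear_fc1.weight")
                # -> True only if this key is in ``params_to_save``, otherwise False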
| """ |
| if isinstance(key, tuple): |
| return key[1].requires_grad |
| return key in self.params_to_save or ".adapter." in key or key.endswith(".adapters") |
|
|
|
|
| class AdapterWrapper(nn.Module): |
| """Abstract base class for wrapping modules with adapters in Parameter-Efficient Fine-Tuning (PEFT). |
| |
| This class wraps a module and its associated adapter, providing methods for |
| managing the state dictionaries of both the main module and the adapter. It does not |
| implement the forward method, which must be implemented by concrete subclasses. |
| |
| Attributes: |
| to_wrap (nn.Module): The main module to be wrapped. |
| adapter (nn.Module): The adapter module to be applied. |
| |
| Note: |
| This class is abstract and cannot be instantiated directly. Subclasses must |
| implement the forward method. |
| |
| Example: |
| class LoRALinear(AdapterWrapper): |
| def __init__(self, to_wrap, adapter): |
| super().__init__(to_wrap, adapter) |
| |
| def forward(self, x): |
| return self.to_wrap(x) + self.adapter(x) |
| |
| main_module = nn.Linear(100, 100) |
| adapter = nn.Linear(100, 100) |
| parallel_adapter = LoRALinear(main_module, adapter) |
| """ |
|
|
| def __init__(self, to_wrap: nn.Module, adapter: nn.Module): |
| super(AdapterWrapper, self).__init__() |
| self.to_wrap = to_wrap |
| self.adapter = adapter |
|
|
| def base_linear_forward(self, x, *args, **kwargs): |
| """ |
| Run the forward method of the linear module `to_wrap`. |
| Return a tuple of three elements: linear_output, bias, layernorm_output |
| |
| x -> [layernorm/identity] -> layernorm_output -> [linear] -> linear_output, bias |
| |
| layernorm_output is different from input x only when linear layer is LayerNormColumnParallelLinear. |
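
        Example:
            A sketch of how a subclass's ``forward`` typically uses this helper (the exact
            way the adapter output is combined is PEFT-method specific):

                def forward(self, x):
                    linear_output, bias, layernorm_output = self.base_linear_forward(x)
                    adapter_output = self.adapter(layernorm_output)
                    return linear_output + adapter_output, bias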
| """ |
| linear_output = self.to_wrap(x, *args, **kwargs) |
| assert isinstance( |
| linear_output, tuple |
| ), f"{self.to_wrap} should return a tuple but instead returns {linear_output}" |
| """ Four cases for the wrapped module's return values |
| 1. nothing: (out, None) |
| 2. return_bias: (out, bias) |
| 2. return_layernorm_output: ((out, ln_out), None) |
| 3. both: (out, bias, ln_out) |
| """ |
| bias = None |
| layernorm_output = x |
| if len(linear_output) == 2: |
| linear_output, bias = linear_output |
| if isinstance(linear_output, tuple) and len(linear_output) == 2: |
| linear_output, layernorm_output = linear_output |
| elif len(linear_output) == 3: |
| linear_output, bias, layernorm_output = linear_output |
|
|
| return linear_output, bias, layernorm_output |
|
|
| def state_dict(self, destination=None, prefix='', keep_vars=False): |
| """Retrieve the state dictionary of the wrapped module and adapter. |
| |
| This method overrides the default state_dict behavior to include both |
| the main module's state and the adapter's state under a special 'adapters' key. |
| |
| Args: |
| destination (Optional[dict]): A dictionary to store the state. If None, a new |
| dictionary is created. Defaults to None. |
| prefix (str): A prefix added to parameter and buffer names. Defaults to ''. |
| keep_vars (bool): If True, returns variables instead of tensor values. |
| Defaults to False. |
| |
| Returns: |
| dict: The state dictionary containing both the main module and adapter states. |
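
        Example:
            A sketch using the ``LoRALinear`` wrapper from the class docstring:

                wrapper = LoRALinear(nn.Linear(4, 4), nn.Linear(4, 4))
                sorted(wrapper.state_dict().keys())
                # -> ['adapter.bias', 'adapter.weight', 'bias', 'weight']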
| """ |
|
|
| if destination is None: |
| destination = {} |
|
|
| |
| self.to_wrap.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) |
|
|
| |
| self.adapter.state_dict(destination=destination, prefix=f'{prefix}adapter.', keep_vars=keep_vars) |
| return destination |
|
|
| def sharded_state_dict( |
| self, |
| prefix: str = '', |
| sharded_offsets: Tuple[Tuple[int, int, int]] = (), |
| metadata: Optional[dict] = None, |
| ) -> "ShardedStateDict": |
| """Retrieve the sharded state dictionary of the wrapped module and adapter. |
| |
| This method is used for distributed checkpointing, combining the sharded states |
| of both the main module and the adapter. |
| |
| Args: |
| prefix (str): A prefix added to parameter and buffer names. Defaults to ''. |
| sharded_offsets (Tuple[Tuple[int, int, int]]): Offsets for sharded parameters. |
| Defaults to an empty tuple. |
| metadata (Optional[dict]): Additional metadata for the sharded state. |
| Defaults to None. |
| |
| Returns: |
| ShardedStateDict: The combined sharded state dictionary. |
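
        Example:
            A sketch; ``module`` is assumed to be a Megatron module that implements
            ``sharded_state_dict`` (a plain ``nn.Linear`` does not):

                wrapper = LoRALinear(module, adapter)
                sharded = wrapper.sharded_state_dict(prefix="decoder.layers.0.linear_qkv.")
                # Base weights keep their original prefix; adapter weights additionally gain
                # the "adapter." sub-prefix, e.g. "decoder.layers.0.linear_qkv.adapter.*".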
| """ |
| sharded_state_dict = {} |
| sharded_state_dict.update(self.to_wrap.sharded_state_dict(prefix, sharded_offsets, metadata)) |
| sharded_state_dict.update(self.adapter.sharded_state_dict(f"{prefix}adapter.", sharded_offsets, metadata)) |
| return sharded_state_dict |
|
|
|
|
| class WrappedAdapterIO(_WrappingCheckpointIO, AsyncCompatibleCheckpointIO): |
| """ |
| A wrapper class for checkpoint I/O operations, specifically designed for PEFT (Parameter-Efficient Fine-Tuning). |
| |
| This class handles the complexities of saving and loading checkpoints for both initial PEFT training and resuming |
| PEFT training. It ensures that only the necessary adapter weights are saved and loaded, while also preserving the |
| base model weights. |
| |
| **Usage:** |
| |
| 1. **Initial PEFT Training:** |
| - The class handles the saving of only adapter weights. |
| - Metadata about the base model checkpoint is stored for future reference. |
| |
| 2. **PEFT Resume:** |
| - The class loads both base model and adapter weights. |
| - The previously stored metadata is used to locate the correct base model checkpoint. |
| |
| **Attributes:** |
| |
| - `peft`: The PEFT instance associated with the wrapped checkpoint I/O. |
| - `model_ckpt_path`: The path to the base model checkpoint. |
| - `adapter_ckpt_path`: The path to the adapter checkpoint. |
| Note that the paths are set by save/load functions and users do not need to set them. |
| |
| **Methods:** |
| |
| - `save_checkpoint`: Saves the adapter weights and metadata to the specified path. |
| - `load_checkpoint`: Loads the base model and adapter weights based on the specified path and metadata. |
| """ |
    """

    peft: Optional[PEFT] = None
    model_ckpt_path: Optional[Path] = None
    adapter_ckpt_path: Optional[Path] = None

    def __init__(
        self, checkpoint_io: Optional["CheckpointIO"] = None, peft: Optional[PEFT] = None
    ) -> None:
        self.peft = peft
        super().__init__(checkpoint_io)

    @override
    def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
        assert self.checkpoint_io is not None
        state_key = None
        for k in ['sharded_state_dict', 'state_dict']:
            if k in checkpoint:
                state_key = k
                break
        assert state_key is not None, "Expected checkpoint to contain `sharded_state_dict` or `state_dict`"

        # Keep only the adapter weights before handing the checkpoint to the wrapped I/O.
        state_dict = checkpoint.pop(state_key)
        checkpoint[state_key] = dict(filter(lambda item: self.peft.adapter_key_filter(item[0]), state_dict.items()))
        ckpt_keys = list(checkpoint[state_key].keys())
        request = self.checkpoint_io.save_checkpoint(checkpoint, path, storage_options=storage_options)

        from nemo.utils.get_rank import is_global_rank_zero

        if is_global_rank_zero():
            base_dir = ckpt_to_weights_subdir(path, is_saving=True)

            from nemo.lightning.io.hf import HFCheckpointIO

            # HF (automodel) checkpoints store a PEFT-style adapter config; NeMo checkpoints
            # store metadata pointing back to the base model checkpoint.
            if isinstance(self.checkpoint_io, HFCheckpointIO):
                metadata = self._create_lora_hf_config(ckpt_keys)
                hf_adapter_base = base_dir.parent / HF_ADAPTER_PATH
                hf_adapter_base.mkdir(parents=True, exist_ok=True)
                adapter_meta_path = hf_adapter_base / HF_ADAPTER_CONFIG_FILENAME
            else:
                base_dir.mkdir(parents=True, exist_ok=True)
                metadata = {"model_ckpt_path": str(self.model_ckpt_path)}
                adapter_meta_path = base_dir / ADAPTER_META_FILENAME

            with open(adapter_meta_path, "w") as f:
                json.dump(metadata, f)
        return request

    def _create_lora_hf_config(self, ckpt_keys):
        """Creates a HF LoRA config from a NeMo LoRA config."""

        def extract_matched_module_names(ckpt_keys, target_modules):
            """
            Extracts module names from a list of checkpoint keys that match the target modules.

            This function processes a list of target module patterns, where each pattern may or may
            not contain a wildcard (`'*'`). The function matches these patterns against the
            checkpoint keys, with the following behavior:
            - Patterns containing '*' will be expanded to match any sequence of characters
              except a dot (`.`).
            - Patterns without '*' are matched literally.

            Args:
                ckpt_keys (list of str): A list of strings representing checkpoint keys to be
                    searched.
                target_modules (list of str): A list of target module patterns. Some patterns may
                    contain wildcards (`'*'`), which match any characters except a dot.

            Returns:
                list of str: A list of module names from `target_modules` that match any of the
                    `ckpt_keys`. The result is returned as a list of unique module names.

            Example:
                ckpt_keys = [
                    "model.model.layers.27.self_attn.k_proj",
                    "model.model.layers.27.self_attn.v_proj",
                    "model.model.layers.27.self_attn.mlp"
                ]
                target_modules = ["*proj"]

                extract_matched_module_names(ckpt_keys, target_modules)
                # Output: ['k_proj', 'v_proj']

            Notes:
                - This function uses regular expressions to match the target patterns in the
                  checkpoint keys.
                - Wildcards are expanded as `[^.]+` to ensure that the match doesn't cross dot
                  (`.`) boundaries.
            """
            re_target_modules = list(filter(lambda x: '*' in x, target_modules))
            if len(re_target_modules) == 0:
                return target_modules
            non_re_target_modules = list(filter(lambda x: '*' not in x, target_modules))
            combined_pattern = '|'.join(
                map(lambda x: x.replace('*', '[^.]+'), re_target_modules),
            )
            ans = set(non_re_target_modules)
            for key in ckpt_keys:
                ans.update(re.findall(combined_pattern, key))
            return list(ans)

        from peft import LoraConfig

        from nemo.collections.llm.peft import DoRA

        # target_modules may contain wildcards; materialize them into the concrete module
        # names that actually appear in the checkpoint keys.
        materialized_module_names = extract_matched_module_names(ckpt_keys, self.peft.target_modules)
        lora_config = LoraConfig(
            r=self.peft.dim,
            target_modules=materialized_module_names,
            lora_alpha=self.peft.alpha,
            lora_dropout=self.peft.dropout,
            use_dora=isinstance(self.peft, DoRA),
        )
        lora_config = lora_config.to_dict()
        lora_config["peft_type"] = "LORA"
        lora_config["megatron_core"] = None
        lora_config["target_modules"] = materialized_module_names
        return lora_config

    @override
    def load_checkpoint(
        self,
        path: _PATH,
        sharded_state_dict=None,
        map_location: Optional[Callable] = None,
        strict: Optional['StrictHandling'] | bool = None,
    ) -> Dict[str, Any]:
        """
        =====================
        Initial PEFT Training
        =====================
        Initial PEFT training requires loading the base model weights. In this case, this function is called by
        trainer.strategy.setup() -> megatron_strategy.restore_model() -> megatron_strategy.load_checkpoint().
        `path = PosixPath(<base_path>)`, and sharded_state_dict contains only base model weights.

        ===========
        PEFT Resume
        ===========
        PEFT resume requires loading two sets of model weights: 1) base model weights and 2) adapter weights.
        Base model weights could be imported from e.g. HF and are frozen during PEFT training.
        Adapter weights contain the training metadata that will need to be loaded.
        As such, this function will be entered twice during PEFT training resume.

        The FIRST time, this function is called by trainer._checkpoint_connector._restore_modules_and_callbacks.
        `path = AdapterPath(<adapter_path>, base_model_path=<base_path>)`, and sharded_state_dict contains only base
        model weights.

        The SECOND time, this function is called by PEFT.apply_transform (above, in the same file).
        `path = PosixPath(<adapter_path>)`, and sharded_state_dict contains only adapter weights.
        """

        assert self.checkpoint_io is not None

        adapter_ckpt = None
        base = ckpt_to_dir(path)
        if getattr(path, "base_model_path", None):
            # PEFT resume, FIRST time: the adapter checkpoint carries the training metadata.
            self.adapter_ckpt_path = Path(str(path))
            adapter_ckpt = self.checkpoint_io.load_checkpoint(path, sharded_state_dict={})
            # Fall through: the base model weights are loaded from `base_model_path` below.
            path = self.model_ckpt_path = path.base_model_path
        elif (adapter_meta_path := base / ADAPTER_META_FILENAME).exists():
            # PEFT resume, SECOND time: load adapter weights; the metadata points to the base model.
            with open(adapter_meta_path, "r") as f:
                metadata = json.load(f)
            self.model_ckpt_path = Path(metadata['model_ckpt_path'])
            self.adapter_ckpt_path = path
        elif (base / HF_ADAPTER_PATH / HF_ADAPTER_CONFIG_FILENAME).exists():
            # HF (automodel) adapter checkpoint.
            self.adapter_ckpt_path = path / HF_ADAPTER_PATH
        else:
            # Initial PEFT training: only the base model weights are needed.
            self.model_ckpt_path = path

        # Load the checkpoint resolved above (base model weights or adapter weights).
        model_ckpt = self.checkpoint_io.load_checkpoint(path, sharded_state_dict, map_location, strict)
        if adapter_ckpt is not None:
            # Merge the base model weights into the adapter checkpoint so both are returned together.
            adapter_ckpt['state_dict'].update(model_ckpt['state_dict'])
            return adapter_ckpt
        return model_ckpt
|