| # Copyright 2023-present the HuggingFace Inc. team. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| #     http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| from __future__ import annotations |
|
|
| import copy |
| import functools |
| import inspect |
| import os |
| import re |
| import warnings |
| from collections.abc import Sequence |
| from contextlib import nullcontext |
| from operator import attrgetter |
| from typing import Any, Optional, Union |
|
|
| import accelerate |
| import torch |
| import transformers |
| from accelerate import FullyShardedDataParallelPlugin |
| from accelerate.hooks import add_hook_to_module, remove_hook_from_module |
| from accelerate.utils import is_npu_available, is_xpu_available |
| from huggingface_hub import file_exists |
| from huggingface_hub.errors import EntryNotFoundError, HFValidationError |
| from packaging import version |
| from safetensors.torch import storage_ptr, storage_size |
| from transformers import PreTrainedModel |
|
|
| from ..import_utils import is_auto_gptq_available, is_gptqmodel_available, is_torch_tpu_available |
| from .constants import ( |
| CONFIG_NAME, |
| EMBEDDING_LAYER_NAMES, |
| INCLUDE_LINEAR_LAYERS_SHORTHAND, |
| SAFETENSORS_WEIGHTS_NAME, |
| TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING, |
| TRANSFORMERS_MODELS_TO_BOFT_TARGET_MODULES_MAPPING, |
| TRANSFORMERS_MODELS_TO_BONE_TARGET_MODULES_MAPPING, |
| TRANSFORMERS_MODELS_TO_C3A_TARGET_MODULES_MAPPING, |
| TRANSFORMERS_MODELS_TO_DELORA_TARGET_MODULES_MAPPING, |
| TRANSFORMERS_MODELS_TO_FOURIERFT_TARGET_MODULES_MAPPING, |
| TRANSFORMERS_MODELS_TO_HRA_TARGET_MODULES_MAPPING, |
| TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING, |
| TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING, |
| TRANSFORMERS_MODELS_TO_LNTUNING_TARGET_MODULES_MAPPING, |
| TRANSFORMERS_MODELS_TO_LOHA_TARGET_MODULES_MAPPING, |
| TRANSFORMERS_MODELS_TO_LOKR_TARGET_MODULES_MAPPING, |
| TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING, |
| TRANSFORMERS_MODELS_TO_MISS_TARGET_MODULES_MAPPING, |
| TRANSFORMERS_MODELS_TO_OFT_TARGET_MODULES_MAPPING, |
| TRANSFORMERS_MODELS_TO_POLY_TARGET_MODULES_MAPPING, |
| TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING, |
| TRANSFORMERS_MODELS_TO_RANDLORA_TARGET_MODULES_MAPPING, |
| TRANSFORMERS_MODELS_TO_ROAD_TARGET_MODULES_MAPPING, |
| TRANSFORMERS_MODELS_TO_SHIRA_TARGET_MODULES_MAPPING, |
| TRANSFORMERS_MODELS_TO_VBLORA_TARGET_MODULES_MAPPING, |
| TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPING, |
| TRANSFORMERS_MODELS_TO_WAVEFT_TARGET_MODULES_MAPPING, |
| WEIGHTS_NAME, |
| bloom_model_postprocess_past_key_value, |
| starcoder_model_postprocess_past_key_value, |
| ) |
|
|
|
|
| mlu_available = False |
| if version.parse(accelerate.__version__) >= version.parse("0.29.0"): |
| from accelerate.utils import is_mlu_available |
|
|
| mlu_available = is_mlu_available() |
|
|
|
|
| __all__ = [ |
| "CONFIG_NAME", |
| "EMBEDDING_LAYER_NAMES", |
| "INCLUDE_LINEAR_LAYERS_SHORTHAND", |
| "SAFETENSORS_WEIGHTS_NAME", |
| "TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING", |
| "TRANSFORMERS_MODELS_TO_BOFT_TARGET_MODULES_MAPPING", |
| "TRANSFORMERS_MODELS_TO_BONE_TARGET_MODULES_MAPPING", |
| "TRANSFORMERS_MODELS_TO_C3A_TARGET_MODULES_MAPPING", |
| "TRANSFORMERS_MODELS_TO_DELORA_TARGET_MODULES_MAPPING", |
| "TRANSFORMERS_MODELS_TO_FOURIERFT_TARGET_MODULES_MAPPING", |
| "TRANSFORMERS_MODELS_TO_HRA_TARGET_MODULES_MAPPING", |
| "TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING", |
| "TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING", |
| "TRANSFORMERS_MODELS_TO_LNTUNING_TARGET_MODULES_MAPPING", |
| "TRANSFORMERS_MODELS_TO_LOHA_TARGET_MODULES_MAPPING", |
| "TRANSFORMERS_MODELS_TO_LOKR_TARGET_MODULES_MAPPING", |
| "TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING", |
| "TRANSFORMERS_MODELS_TO_MISS_TARGET_MODULES_MAPPING", |
| "TRANSFORMERS_MODELS_TO_OFT_TARGET_MODULES_MAPPING", |
| "TRANSFORMERS_MODELS_TO_POLY_TARGET_MODULES_MAPPING", |
| "TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING", |
| "TRANSFORMERS_MODELS_TO_RANDLORA_TARGET_MODULES_MAPPING", |
| "TRANSFORMERS_MODELS_TO_ROAD_TARGET_MODULES_MAPPING", |
| "TRANSFORMERS_MODELS_TO_SHIRA_TARGET_MODULES_MAPPING", |
| "TRANSFORMERS_MODELS_TO_VBLORA_TARGET_MODULES_MAPPING", |
| "TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPING", |
| "TRANSFORMERS_MODELS_TO_WAVEFT_TARGET_MODULES_MAPPING", |
| "WEIGHTS_NAME", |
| "bloom_model_postprocess_past_key_value", |
| "starcoder_model_postprocess_past_key_value", |
| ] |
|
|
|
|
| |
| def infer_device() -> str: |
| if torch.cuda.is_available(): |
| return "cuda" |
| elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): |
| return "mps" |
| elif mlu_available: |
| return "mlu" |
| elif is_xpu_available(): |
| return "xpu" |
| elif is_npu_available(): |
| return "npu" |
| return "cpu" |
|
|
|
|
| def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True, gradient_checkpointing_kwargs=None): |
| r""" |
| Note this method only works for `transformers` models. |
| |
| This method wraps the entire protocol for preparing a model before running training. This includes: |
| 1- casting the layernorm layers to fp32, 2- making the output embedding layer require grads, 3- upcasting the |
| lm head to fp32, and 4- freezing the base model layers to ensure they are not updated during training. |
| |
| |
| Args: |
| model (`transformers.PreTrainedModel`): |
| The loaded model from `transformers` |
| use_gradient_checkpointing (`bool`, *optional*, defaults to `True`): |
| If True, use gradient checkpointing to save memory at the expense of slower backward pass. |
| gradient_checkpointing_kwargs (`dict`, *optional*, defaults to `None`): |
| Keyword arguments to pass to the gradient checkpointing function, please refer to the documentation of |
| `torch.utils.checkpoint.checkpoint` for more details about the arguments that you can pass to that method. |
| Note this is only available in the latest transformers versions (> 4.34.1). |
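|
| Example (a minimal sketch; assumes a bitsandbytes 4-bit quantized model, the model id is illustrative): |
|
| ```py |
| from transformers import AutoModelForCausalLM, BitsAndBytesConfig |
|
| base_model = AutoModelForCausalLM.from_pretrained( |
| "facebook/opt-125m", quantization_config=BitsAndBytesConfig(load_in_4bit=True) |
| ) |
| base_model = prepare_model_for_kbit_training(base_model, use_gradient_checkpointing=True) |
| ``` |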
| """ |
| loaded_in_kbit = getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False) |
| is_gptq_quantized = getattr(model, "quantization_method", None) == "gptq" |
| is_aqlm_quantized = getattr(model, "quantization_method", None) == "aqlm" |
| is_eetq_quantized = getattr(model, "quantization_method", None) == "eetq" |
| is_torchao_quantized = getattr(model, "quantization_method", None) == "torchao" |
| is_hqq_quantized = getattr(model, "quantization_method", None) == "hqq" or getattr(model, "hqq_quantized", False) |
|
|
| if gradient_checkpointing_kwargs is None: |
| gradient_checkpointing_kwargs = {} |
|
|
| for name, param in model.named_parameters(): |
| # freeze the base model's parameters |
| param.requires_grad = False |
|
|
| if ( |
| not is_gptq_quantized |
| and not is_aqlm_quantized |
| and not is_eetq_quantized |
| and not is_hqq_quantized |
| and not is_torchao_quantized |
| ): |
| # cast all non-quantized fp16/bf16 parameters to fp32 for training stability |
| for param in model.parameters(): |
| if ( |
| (param.dtype == torch.float16) or (param.dtype == torch.bfloat16) |
| ) and param.__class__.__name__ != "Params4bit": |
| param.data = param.data.to(torch.float32) |
|
|
| if ( |
| loaded_in_kbit |
| or is_gptq_quantized |
| or is_aqlm_quantized |
| or is_eetq_quantized |
| or is_hqq_quantized |
| or is_torchao_quantized |
| ) and use_gradient_checkpointing: |
| # when using reentrant gradient checkpointing (the default), the model inputs need to require grads |
| if "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]: |
| |
| if hasattr(model, "enable_input_require_grads"): |
| model.enable_input_require_grads() |
| else: |
|
|
| def make_inputs_require_grad(module, input, output): |
| output.requires_grad_(True) |
|
|
| model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) |
|
|
| # check whether gradient_checkpointing_enable supports gradient_checkpointing_kwargs (transformers > 4.34.1) |
| _supports_gc_kwargs = "gradient_checkpointing_kwargs" in list( |
| inspect.signature(model.gradient_checkpointing_enable).parameters |
| ) |
|
|
| if not _supports_gc_kwargs and len(gradient_checkpointing_kwargs) > 0: |
| warnings.warn( |
| "gradient_checkpointing_kwargs is not supported in this version of transformers. The passed kwargs will be ignored." |
| " if you want to use that feature, please upgrade to the latest version of transformers.", |
| FutureWarning, |
| ) |
|
|
| gc_enable_kwargs = ( |
| {} if not _supports_gc_kwargs else {"gradient_checkpointing_kwargs": gradient_checkpointing_kwargs} |
| ) |
|
|
| # enable gradient checkpointing for memory efficiency |
| model.gradient_checkpointing_enable(**gc_enable_kwargs) |
| return model |
|
|
|
|
| |
| def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): |
| """ |
| Shift input ids one token to the right. |
| |
| Args: |
| input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): input ids |
| pad_token_id (`int`): The id of the `padding` token. |
| decoder_start_token_id (`int`): The id of the `start` token. |
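|
| Example (an illustrative sketch): |
|
| ```py |
| import torch |
|
| input_ids = torch.tensor([[5, 6, 7, 2]]) |
| shift_tokens_right(input_ids, pad_token_id=1, decoder_start_token_id=0) |
| # tensor([[0, 5, 6, 7]]) |
| ``` |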
| """ |
| shifted_input_ids = input_ids.new_zeros(input_ids.shape) |
| shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() |
| shifted_input_ids[:, 0] = decoder_start_token_id |
|
|
| if pad_token_id is None: |
| raise ValueError("self.model.config.pad_token_id has to be defined.") |
| # replace possible -100 values in labels by `pad_token_id` |
| shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) |
|
|
| return shifted_input_ids |
|
|
|
|
| class AuxiliaryTrainingWrapper(torch.nn.Module): |
| """Wrap a specific module so that it can be trained and saved in a way that is tangential to how |
| PEFT normally works, e.g. fully training a classification layer instead of using an adapter. |
| |
| """ |
|
|
| # all names of layers that may contain adapter (trainable) weights |
| adapter_layer_names: tuple[str, ...] = () |
| # all names of other parameters that may determine the adapter state (e.g. token indices) |
| other_param_names: tuple[str, ...] = () |
| # names of the adapters that have been merged into the wrapped module |
| merged_adapters: list[str] = [] |
|
|
| def __init__(self, module_to_save, adapter_name, **kwargs): |
| """Extra kwargs will be passed to `self.init_modules` and `self.update`.""" |
| super().__init__() |
| self.original_module = module_to_save |
| self._active_adapter = [adapter_name] |
| self._disable_adapters = False |
| self._adapters = set() |
|
|
| self.init_modules(adapter_name, **kwargs) |
|
|
| self.update(adapter_name, **kwargs) |
| self.check_module() |
|
|
| def init_modules(self, adapter_name, **kwargs): |
| """A place to initialize PyTorch modules in `__init__` before the call to `self.update()`.""" |
| raise NotImplementedError |
|
|
| def _get_available_adapters(self) -> set[str]: |
| """Return all adapter names that can be found on this module.""" |
| raise NotImplementedError |
|
|
| def _error_message_name(self): |
| """Returns a user friendly identifier for error messages, e.g. for type compatibility error messages from |
| `check_module()` so that the user can backtrack where the error comes from. A generic "training wrapper" is |
| less helpful than "modules_to_save", for example. |
| """ |
| return "training wrapper" |
|
|
| def check_module(self): |
| """Perform some sanity checks on the module to ensure that it works""" |
| # container modules like ModuleDict, ModuleList, ParameterDict and ParameterList have no forward method, |
| # so wrapping them makes no sense |
| forbidden_classes = (torch.nn.ModuleDict, torch.nn.ModuleList, torch.nn.ParameterDict, torch.nn.ParameterList) |
| if isinstance(self.original_module, forbidden_classes): |
| cls_name = self.original_module.__class__ |
| raise TypeError(f"{self._error_message_name()} cannot be applied to modules of type {cls_name}") |
|
|
| |
| from peft.tuners.tuners_utils import BaseTunerLayer |
|
|
| if isinstance(self.original_module, BaseTunerLayer): |
| |
| cls_name = self.original_module.__class__ |
| raise TypeError(f"{self._error_message_name()} cannot be applied to modules of type {cls_name}") |
|
|
| @property |
| def disable_adapters(self) -> bool: |
| |
| return self._disable_adapters |
|
|
| @property |
| def active_adapter(self) -> Union[list[str], str]: |
| |
| return self._active_adapter |
|
|
| @property |
| def active_adapters(self) -> list[str]: |
| if isinstance(self._active_adapter, str): |
| return [self._active_adapter] |
| return self._active_adapter |
|
|
| def _hasattr_wrapped(self, name, modules): |
| """Infrastructure to enable the implementing class to delegate attributes to other modules. |
| Returns True if the implementing class knows how to handle attribute `name`. |
| |
| Gets passed `modules` which is PyTorch's internal list of assigned modules from `nn.Module`. |
| """ |
| return False |
|
|
| def _getattr_wrapped(self, name, modules): |
| """If `_hasattr_wrapped` returns True for `name`, then this function should return the corresponding |
| value associated with `name`. |
| """ |
| return None |
|
|
| def __getattr__(self, name: str): |
| |
| |
| try: |
| return super().__getattr__(name) |
| except AttributeError: |
| pass |
|
|
| if "_modules" not in self.__dict__: |
| raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") |
|
|
| # delegate attribute access: if adapters are disabled, forward to the original module; otherwise let the |
| # implementing class resolve the attribute (e.g. on the active adapter's copy of the module) |
| modules = self.__dict__["_modules"] |
| if self.disable_adapters: |
| return getattr(self.original_module, name) |
| elif self._hasattr_wrapped(name, modules): |
| return self._getattr_wrapped(name, modules) |
|
|
| |
| |
| raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") |
|
|
| def update(self, adapter_name, **kwargs): |
| """Called when this instance should be part of an adapter's training. |
| Adds the given adapter to the set of adapters that this instance trains alongside. |
| |
| Additional kwargs are expected to be the same kwargs that are also passed for initializing this class. |
| """ |
| if adapter_name not in self._adapters: |
| self._adapters.add(adapter_name) |
|
|
| def _create_new_hook(self, old_hook): |
| r""" |
| Creates a new hook based on the old hook. Use it only if you know what you are doing! |
| """ |
| old_hook_cls = getattr(accelerate.hooks, old_hook.__class__.__name__) |
| old_hook_attr = old_hook.__dict__ |
| filtered_old_hook_attr = {} |
| old_hook_init_signature = inspect.signature(old_hook_cls.__init__) |
| for k in old_hook_attr.keys(): |
| if k in old_hook_init_signature.parameters: |
| filtered_old_hook_attr[k] = old_hook_attr[k] |
| new_hook = old_hook_cls(**filtered_old_hook_attr) |
| return new_hook |
|
|
| def _check_forward_args(self, x, *args, **kwargs): |
| """Check if the arguments are compatible with the configs and state of the model""" |
| adapter_names = kwargs.get("adapter_names", None) |
| if adapter_names is None: |
| return |
|
|
| if len(x) != len(adapter_names): |
| msg = ( |
| "Length of `adapter_names` should be the same as the number of inputs, but got " |
| f"{len(adapter_names)} and {len(x)} respectively." |
| ) |
| raise ValueError(msg) |
|
|
| def _forward_wrapped(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: |
| raise NotImplementedError |
|
|
| def _forward_wrapped_mixed_batch( |
| self, x: torch.Tensor, active_adapter: str, *args: Any, **kwargs: Any |
| ) -> torch.Tensor: |
| raise NotImplementedError |
|
|
| def _forward_wrapped_passthrough(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: |
| """The forward call when no adapter is involved in the forward computation, only the base model""" |
| raise NotImplementedError |
|
|
| def _mixed_batch_forward( |
| self, input: torch.Tensor, *args: Any, adapter_names: list[str], **kwargs: Any |
| ) -> torch.Tensor: |
| # This is a special method that handles the case when users pass the argument `adapter_names`, which |
| # allows mixing different adapters within the same batch at inference time. |
|
|
| SUPPORTED_MODULES = (torch.nn.Linear, torch.nn.Embedding, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d) |
|
|
| module_names = ", ".join([module.__name__ for module in SUPPORTED_MODULES]) |
|
|
| if not isinstance(self.original_module, SUPPORTED_MODULES): |
| raise TypeError(f"Mixed batching is only supported for the following modules: {module_names}.") |
|
|
| unique_adapters = set(adapter_names) |
| sub_batch_indices_list = [] |
|
|
| for adapter in unique_adapters: |
| sub_batch_indices_list.append([index for index, item in enumerate(adapter_names) if item == adapter]) |
|
|
| results = [0 for _ in range(len(input))] |
|
|
| for i, active_adapter in enumerate(unique_adapters): |
| sub_batch = input[sub_batch_indices_list[i]] |
|
|
| if active_adapter == "__base__": |
| output = self.original_module(sub_batch, *args, **kwargs) |
| else: |
| output = self._forward_wrapped_mixed_batch(sub_batch, active_adapter, *args, **kwargs) |
|
|
| for index, j in enumerate(sub_batch_indices_list[i]): |
| results[j] = output[index] |
|
|
| return torch.stack(results) |
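|
| # Mixed-batch example (a hedged sketch, assuming a PEFT model with adapters "default" and "other" loaded): |
| # |
| #     output = peft_model(input_ids, adapter_names=["default", "other", "__base__"]) |
| # |
| # each sample is routed through the adapter named at its index; "__base__" selects the unmodified base module. |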
|
|
| def forward(self, x: torch.Tensor, *args, **kwargs): |
| self._check_forward_args(x, *args, **kwargs) |
| adapter_names = kwargs.pop("adapter_names", None) |
|
|
| if self.disable_adapters or any(adapter not in self._adapters for adapter in self.active_adapters): |
| return self._forward_wrapped_passthrough(x, *args, **kwargs) |
|
|
| if adapter_names is None: |
| return self._forward_wrapped(x, *args, **kwargs) |
| return self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs) |
|
|
| def enable_adapters(self, enabled: bool): |
| """Toggle the enabling and disabling of adapters |
| |
| Args: |
| enabled (bool): True to enable adapters, False to disable adapters |
| """ |
| if enabled: |
| self._disable_adapters = False |
| else: |
| self._disable_adapters = True |
|
|
| def check_set_adapter(self, adapter_name: str | list[str]) -> str | None: |
| """Helper function to check if the given adapter(s) can be set. |
| |
| Return the name of the adapter to be set or None if no adapter should be set. |
| """ |
| raise NotImplementedError |
|
|
| def set_adapter(self, adapter_names: Union[str, list[str]], inference_mode: bool = False) -> None: |
| """Set the active adapter |
| |
| Args: |
| adapter_names (str or list[str]): |
| The name(s) of the adapter(s) to set as active |
| inference_mode (bool, optional): |
| Whether the activated adapter should be frozen (i.e. `requires_grad=False`). Default is False. |
| """ |
| if isinstance(adapter_names, str): |
| self._active_adapter = adapter_names |
| else: |
| self._active_adapter = [] |
| for adapter_name in adapter_names: |
| if adapter_name not in self._adapters: |
| raise ValueError(f"Adapter {adapter_name} not found in {self._adapters}") |
|
|
| self._active_adapter.append(adapter_name) |
|
|
| def delete_adapter(self, adapter_name: str, new_active_adapters: Optional[list[str]]) -> None: |
| """Delete an adapter from the layer, set a new active adapter if necessary""" |
| raise NotImplementedError |
|
|
| def set_requires_grad(self, adapter_names: str | Sequence[str], requires_grad: bool = True) -> None: |
| """ |
| Enable or disable gradients on the given adapter(s). |
| |
| Args: |
| adapter_names (`str` or `Sequence[str]`): |
| The name(s) of the adapter(s) whose gradients should be enabled/disabled. |
| requires_grad (`bool`, *optional*, defaults to `True`): |
| Whether to enable (`True`) or disable (`False`) gradients. |
| """ |
| if isinstance(adapter_names, str): |
| adapter_names_set = {adapter_names} |
| else: |
| adapter_names_set = set(adapter_names) |
|
|
| for layer_name in self.adapter_layer_names: |
| |
| module_dict = attrgetter(layer_name)(self) |
| for key, layer in module_dict.items(): |
| if key in adapter_names_set: |
| layer.requires_grad_(requires_grad) |
|
|
| def adapter_state_dict(self, adapter_name): |
| """Return the state dict of this module for a given adapter.""" |
| raise NotImplementedError |
|
|
| def adapter_state_dict_load_map(self, adapter_name): |
| """Return a mapping from the key present in disk-loaded state dict |
| and how it should be represented in the loaded model's state dict. |
| |
| The default should be a 1:1 mapping but it is important to define a mapping as it also serves as the |
| ground-truth for which keys are supposed to be loaded from a saved state dict. |
| """ |
| raise NotImplementedError |
|
|
| def unload_and_optionally_merge_module( |
| self, merge: bool, safe_merge: bool, adapter_names: Optional[list[str]] |
| ) -> torch.nn.Module: |
| """Handles unloading when called from PEFT models. Returns the wrapped module |
| and handles merging onto the wrapped module if requested. |
| """ |
| raise NotImplementedError |
|
|
|
|
| class ModulesToSaveWrapper(AuxiliaryTrainingWrapper): |
| """Wraps a module that is supposed to be trained (i.e. `requires_grad_(True)`) and saved after training.""" |
|
|
| |
| adapter_layer_names: tuple[str, ...] = ("modules_to_save",) |
|
|
| def __init__(self, module_to_save, adapter_name, tied_module=None): |
| super().__init__(module_to_save, adapter_name, tied_module=tied_module) |
|
|
| def init_modules(self, adapter_name, **kwargs): |
| |
| self.modules_to_save = torch.nn.ModuleDict({}) |
|
|
| def _error_message_name(self): |
| return "modules_to_save" |
|
|
| def _forward_wrapped(self, x, *args, **kwargs): |
| if not self.active_adapters: |
| return self._forward_wrapped_passthrough(x, *args, **kwargs) |
| return self.modules_to_save[self.active_adapters[0]](x, *args, **kwargs) |
|
|
| def _forward_wrapped_mixed_batch(self, x, active_adapter, *args, **kwargs): |
| return self.modules_to_save[active_adapter](x, *args, **kwargs) |
|
|
| def _forward_wrapped_passthrough(self, x, *args, **kwargs): |
| return self.original_module(x, *args, **kwargs) |
|
|
| def _hasattr_wrapped(self, name, modules): |
| return self.active_adapters[0] in modules["modules_to_save"] |
|
|
| def _getattr_wrapped(self, name, modules): |
| return getattr(modules["modules_to_save"][self.active_adapters[0]], name) |
|
|
| def update(self, adapter_name, tied_module=None, **kwargs): |
| super().update(adapter_name) |
|
|
| context_manager = nullcontext() |
| for _, param in self.original_module.named_parameters(): |
| num_params = param.numel() |
| |
| if num_params == 0 and hasattr(param, "ds_numel"): |
| import deepspeed |
|
|
| context_manager = deepspeed.zero.GatheredParameters(self.original_module.parameters(), modifier_rank=0) |
| break |
|
|
| if adapter_name not in self.modules_to_save: |
| with context_manager: |
| if tied_module: |
| new_linear = torch.nn.Linear(*tied_module.weight.shape, bias=False) |
| new_linear.weight = tied_module.weight |
|
|
| self.modules_to_save[adapter_name] = new_linear |
| else: |
| self.modules_to_save[adapter_name] = copy.deepcopy(self.original_module) |
|
|
| if hasattr(self.modules_to_save[adapter_name], "_hf_hook"): |
| old_hook = self.modules_to_save[adapter_name]._hf_hook |
| new_hook = self._create_new_hook(old_hook) |
| remove_hook_from_module(self.modules_to_save[adapter_name]) |
| add_hook_to_module(self.modules_to_save[adapter_name], new_hook) |
|
|
| self.original_module.requires_grad_(False) |
|
|
| # only mark the new copy as trainable if it belongs to the currently active adapter |
| if adapter_name == self.active_adapter: |
| self.modules_to_save[adapter_name].requires_grad_(True) |
|
|
| def enable_adapters(self, enabled: bool): |
| """Takes care of setting the required_grad flag on the wrapped module. |
| If adapters are enabled, gradients for the module are required as well. |
| """ |
| super().enable_adapters(enabled) |
|
|
| if enabled: |
| self.original_module.requires_grad_(False) |
| for adapter_name in self.active_adapters: |
| self.modules_to_save[adapter_name].requires_grad_(True) |
| else: |
| self.original_module.requires_grad_(True) |
| self.modules_to_save.requires_grad_(False) |
|
|
| def check_set_adapter(self, adapter_name: str | list[str]) -> str | None: |
| """Helper function to check if the given adapter(s) can be set. |
| |
| Return the name of the adapter to be set or None if no adapter should be set. |
| """ |
| if isinstance(adapter_name, str): |
| return adapter_name |
|
|
| |
| if len(adapter_name) == 0: |
| raise ValueError("Please specify at least one adapter to set") |
|
|
| adapter_names_in_module = [n for n in adapter_name if n in self.modules_to_save] |
|
|
| if len(adapter_names_in_module) > 1: |
| raise ValueError(f"Only one adapter can be set at a time for {self}, got {len(adapter_names_in_module)}") |
|
|
| adapter_name_to_set: str | None |
| if not adapter_names_in_module: |
| adapter_name_to_set = None |
| else: |
| adapter_name_to_set = adapter_names_in_module[0] |
|
|
| return adapter_name_to_set |
|
|
| def set_adapter(self, adapter_names: Union[str, list[str]], inference_mode: bool = False) -> None: |
| """Set the active adapter |
| |
| Additionally, this function will set the specified adapter to trainable (i.e., requires_grad=True) unless |
| inference_mode is True. |
| |
| Args: |
| adapter_names (list[str], str): |
| The name(s) of the adapter(s) to set as active. |
| inference_mode (bool, optional): |
| Whether the activated adapter should be frozen (i.e. `requires_grad=False`). Default is False. |
| """ |
| if isinstance(adapter_names, str): |
| adapter_names = [adapter_names] |
|
|
| if len(adapter_names) > 1: |
| raise ValueError(f"Attempted to set multiple ({adapter_names}) adapters at once for modules_to_save.") |
|
|
| if len(adapter_names) == 0: |
| |
| self._active_adapter = [] |
| return |
|
|
| adapter_name = adapter_names[0] |
|
|
| if adapter_name not in self._adapters: |
| raise ValueError(f"Adapter {adapter_name} not found in {self._adapters}") |
|
|
| for currently_active_adapter_name in self.active_adapters: |
| self.modules_to_save[currently_active_adapter_name].requires_grad_(False) |
| self.modules_to_save[adapter_name].requires_grad_(not inference_mode) |
| self._active_adapter = adapter_name |
|
|
| def delete_adapter(self, adapter_name: str, new_active_adapters: Optional[list[str]]) -> None: |
| """ |
| Delete the adapter if present. |
| |
| This method will also set a new active adapter if the deleted adapter was the active adapter. It is important |
| that the new adapter is chosen by the caller in a deterministic way, so that the same adapter is chosen on all |
| layers. |
| """ |
| if adapter_name not in self.modules_to_save: |
| return |
|
|
| |
| |
| if isinstance(new_active_adapters, (list, tuple)) and len(new_active_adapters) > 1: |
| name = self.__class__.__name__ |
| raise ValueError( |
| f"Attempted to set multiple ({new_active_adapters}) adapters at once for {name}, which is not allowed." |
| ) |
|
|
| if adapter_name in self._adapters: |
| self._adapters.remove(adapter_name) |
|
|
| if not new_active_adapters: |
| |
| del self.modules_to_save[adapter_name] |
| self._active_adapter = [] |
| return |
|
|
| new_active_adapter = new_active_adapters[0] |
| if new_active_adapter not in self.modules_to_save: |
| |
| del self.modules_to_save[adapter_name] |
| self._active_adapter = [] |
| return |
|
|
| if new_active_adapter != self.active_adapters[0]: |
| self.set_adapter(new_active_adapter) |
| del self.modules_to_save[adapter_name] |
|
|
| def adapter_state_dict_load_map(self, adapter_name): |
| |
| |
| if adapter_name not in self._adapters: |
| |
| |
| return {} |
| return {k: f"modules_to_save.{adapter_name}.{k}" for k in self.modules_to_save[adapter_name].state_dict()} |
|
|
| def adapter_state_dict(self, adapter_name, state_dict): |
| if adapter_name not in self._adapters: |
| |
| |
| return {} |
|
|
| return { |
| k: state_dict[f"modules_to_save.{adapter_name}.{k}"] |
| for k in self.modules_to_save[adapter_name].state_dict() |
| } |
|
|
| def unload_and_optionally_merge_module( |
| self, merge: bool, safe_merge: bool, adapter_names: Optional[list[str]] |
| ) -> torch.nn.Module: |
| """Unloading in case of `ModulesToSave` means to simply return the wrapped module. |
| |
| However, if the wrapped module is itself a tuner, we'll call `merge` on it first. |
| """ |
| new_module = self.modules_to_save[self.active_adapter] |
|
|
| # the wrapped module can itself be a tuner layer (it has a `base_layer`); in that case, optionally merge it |
| # and return its base layer |
| if hasattr(new_module, "base_layer"): |
| |
| if merge: |
| new_module.merge(safe_merge=safe_merge, adapter_names=adapter_names) |
| new_module = new_module.get_base_layer() |
|
|
| return new_module |
|
|
| def _get_available_adapters(self) -> set[str]: |
| """Return all adapter names that can be found on this module.""" |
| return set(self.modules_to_save.keys()) |
|
|
|
|
| class TrainableTokensWrapper(AuxiliaryTrainingWrapper): |
| """Wraps a module (typically an embedding layer) that is supposed to be re-trained selectively (i.e. |
| solely updating a few columns) using the `TrainableTokensLayer` PEFT method. |
| |
| Supports weight-tying to another adapter when passed a `tied_adapter` which is expected to be a |
| `TrainableTokensLayer`. |
| """ |
|
|
| |
| adapter_layer_names: tuple[str, ...] = ("token_adapter.trainable_tokens_delta",) |
| other_param_names: tuple[str, ...] = ("token_adapter.token_indices", "token_adapter.trainable_tokens_original") |
|
|
| def __init__( |
| self, |
| module_to_save: torch.nn.Module, |
| adapter_name: str, |
| token_indices: list[int], |
| tied_adapter=None, |
| ) -> None: |
| super().__init__(module_to_save, adapter_name, token_indices=token_indices, tied_adapter=tied_adapter) |
|
|
| |
| self.original_module = None |
|
|
| @property |
| def original_module(self): |
| |
| |
| return self.token_adapter.base_layer |
|
|
| def init_modules(self, adapter_name, token_indices, tied_adapter): |
| |
| from peft.tuners.trainable_tokens import TrainableTokensLayer |
|
|
| |
| |
| self.token_adapter = TrainableTokensLayer(self.original_module, adapter_name, token_indices, tied_adapter) |
|
|
| def _error_message_name(self): |
| return "trainable_token_indices" |
|
|
| def _hasattr_wrapped(self, name, modules): |
| return name == "weight" |
|
|
| def _getattr_wrapped(self, name, modules): |
| |
| |
| |
| if name == "weight": |
| return modules["token_adapter"].get_merged_weights(self.token_adapter.active_adapters) |
|
|
| raise RuntimeError( |
| f"This code should've never been reached, probably a bad check in `_hasattr_wrapped` for {name}. " |
| "Please file an issue under https://github.com/huggingface/peft/issues." |
| ) |
|
|
| def _forward_wrapped(self, x, *args, **kwargs): |
| if not self.active_adapters: |
| return self._forward_wrapped_passthrough(x, *args, **kwargs) |
| return self.token_adapter(x) |
|
|
| def _forward_wrapped_mixed_batch(self, x, active_adapter, *args, **kwargs): |
| return self.token_adapter.forward_adapters(x, [active_adapter]) |
|
|
| def _forward_wrapped_passthrough(self, x, *args, **kwargs): |
| |
| |
| return self.token_adapter(x, *args, **kwargs) |
|
|
| def update(self, active_adapter, **kwargs): |
| |
| |
| if active_adapter not in self._adapters: |
| self.token_adapter.update_layer(active_adapter, **kwargs) |
|
|
| super().update(active_adapter) |
|
|
| def adapter_state_dict_load_map(self, adapter_name): |
| if self.token_adapter.tied_adapter: |
| return {} |
| return {"token_adapter.trainable_tokens_delta": f"token_adapter.trainable_tokens_delta.{adapter_name}"} |
|
|
| def adapter_state_dict(self, adapter_name, state_dict): |
| if self.token_adapter.tied_adapter: |
| |
| |
| |
| return {} |
|
|
| return { |
| f"token_adapter.{k}": state_dict[f"token_adapter.{k}.{adapter_name}"] for k in ["trainable_tokens_delta"] |
| } |
|
|
| def enable_adapters(self, enabled: bool): |
| """Enables/disables the underlying `TrainableTokens` adapter. |
| Also handles the internal adapter disable flag. |
| """ |
| super().enable_adapters(enabled) |
|
|
| self.token_adapter.enable_adapters(enabled) |
|
|
| def check_set_adapter(self, adapter_name: str | list[str]) -> str | None: |
| """Helper function to check if the given adapter(s) can be set. |
| |
| Return the name of the adapter to be set or None if no adapter should be set. |
| """ |
| if isinstance(adapter_name, str): |
| return adapter_name |
|
|
| |
| if len(adapter_name) == 0: |
| raise ValueError("Please specify at least one adapter to set") |
|
|
| |
| adapter_names_in_module = [n for n in adapter_name if n in self.token_adapter.trainable_tokens_delta] |
|
|
| if len(adapter_names_in_module) > 1: |
| raise ValueError(f"Only one adapter can be set at a time for {self}, got {len(adapter_names_in_module)}") |
|
|
| adapter_name_to_set: str | None |
| if not adapter_names_in_module: |
| adapter_name_to_set = None |
| else: |
| adapter_name_to_set = adapter_names_in_module[0] |
|
|
| return adapter_name_to_set |
|
|
| def set_adapter(self, adapter_names: Union[str, list[str]], inference_mode: bool = False) -> None: |
| super().set_adapter(adapter_names, inference_mode=inference_mode) |
| self.token_adapter.set_adapter(adapter_names, inference_mode=inference_mode) |
|
|
| def delete_adapter(self, adapter_name: str, new_active_adapters: Optional[list[str]]) -> None: |
| """ |
| Delete the adapter if present. |
| |
| This method will also set a new active adapter if the deleted adapter was the active adapter. It is important |
| that the new adapter is chosen by the caller in a deterministic way, so that the same adapter is chosen on all |
| layers. |
| """ |
| self.token_adapter.delete_adapter(adapter_name) |
|
|
| |
| |
| if isinstance(new_active_adapters, (list, tuple)) and len(new_active_adapters) > 1: |
| name = self.__class__.__name__ |
| raise ValueError( |
| f"Attempted to set multiple ({new_active_adapters}) adapters at once for {name}, which is not allowed." |
| ) |
|
|
| if adapter_name in self._adapters: |
| self._adapters.remove(adapter_name) |
|
|
| if not new_active_adapters: |
| self._active_adapter = [] |
| return |
|
|
| if new_active_adapters[0] not in self.token_adapter.trainable_tokens_delta: |
| |
| self._active_adapter = [] |
| return |
|
|
| new_active_adapter = new_active_adapters[0] |
| self.set_adapter(new_active_adapter) |
|
|
| def unload_and_optionally_merge_module( |
| self, merge: bool, safe_merge: bool, adapter_names: Optional[list[str]] |
| ) -> torch.nn.Module: |
| """Unloading for `TrainableTokensWrapper` means to return the wrapped module, e.g. the embedding layer and, |
| if requested, merging the `TrainableTokens` adapter onto the wrapped module. |
| """ |
| if merge: |
| self.token_adapter.merge(safe_merge=safe_merge, adapter_names=adapter_names) |
| return self.token_adapter.get_base_layer() |
|
|
| def _get_available_adapters(self) -> set[str]: |
| """Return all adapter names that can be found on this module.""" |
| return set(self.token_adapter.trainable_tokens_delta.keys()) |
|
|
|
|
| def _get_input_embeddings_name(model, default=None): |
| if not hasattr(model, "get_input_embeddings"): |
| return default |
|
|
| input_embeddings = model.get_input_embeddings() |
| for name, module in model.named_modules(): |
| if module is input_embeddings: |
| return name |
|
|
| return default |
|
|
|
|
| def _get_submodules(model, key): |
| parent = model.get_submodule(".".join(key.split(".")[:-1])) |
| target_name = key.split(".")[-1] |
| target = model.get_submodule(key) |
| return parent, target, target_name |
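|
| # Example (an illustrative sketch, assuming a GPT-2-like module tree): |
| # |
| #     parent, target, target_name = _get_submodules(model, "transformer.h.0.attn.c_attn") |
| #     # parent -> model.transformer.h[0].attn, target -> the c_attn module, target_name -> "c_attn" |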
|
|
|
|
| def _get_submodules_with_grandparent(model, key): |
| parent = model.get_submodule(".".join(key.split(".")[:-1])) |
| try: |
| grandparent = model.get_submodule(".".join(key.split(".")[:-2])) |
| except AttributeError: |
| |
| grandparent = None |
| target_name = key.split(".")[-1] |
| target = model.get_submodule(key) |
| return parent, grandparent, target, target_name |
|
|
|
|
| def _freeze_adapter(model, adapter_name): |
| for n, p in model.named_parameters(): |
| if adapter_name in n: |
| p.requires_grad = False |
|
|
|
|
| def _set_trainable( |
| model, |
| adapter_name, |
| module_names, |
| inference_mode: bool, |
| strict_module_check: bool = False, |
| wrapper_cls: Optional[AuxiliaryTrainingWrapper] = None, |
| activate_adapter: bool = True, |
| **wrapper_kwargs, |
| ): |
| """Wraps modules that are supposed to be re-trained either normally, i.e. marking them to require gradients and |
| saving them alongside other modules, or with certain methods that go alongside PEFT methods, such as retraining |
| specific token indices using selective read/write. |
| |
| Note that you need to validate beforehand if there are layers targeted by multiple wrappers, e.g. if the |
| 'embedding' layer is configured for both `ModulesToSaveWrapper` and `TrainableTokensWrapper` there would be |
| conflicts down the line. |
| |
| The default is to wrap the module in a `ModulesToSaveWrapper` wrapper. |
| |
| If `strict_module_check` is set, this method raises a ValueError, similar to BaseTuner.inject_adapter, when none |
| of the requested modules in `module_names` is found in the model. |
| |
| The `activate_adapter` flag indicates if this new adapter should be activated. |
| """ |
| from peft.tuners.tuners_utils import BaseTunerLayer |
|
|
| if wrapper_cls is None: |
| wrapper_cls = ModulesToSaveWrapper |
|
|
| if not module_names: |
| |
| |
| return |
|
|
| trainable_modules = [] |
| found_modules = set() |
| # keep duplicate modules so that modules reachable under several names (e.g. tied weights) are also matched |
| key_list = [key for key, _ in model.named_modules(remove_duplicate=False)] |
|
|
| for key in key_list: |
| target_module_found = any(key.endswith(target_key) for target_key in module_names) |
| if target_module_found: |
| parent, grandparent, target, target_name = _get_submodules_with_grandparent(model, key) |
| if isinstance(grandparent, BaseTunerLayer): |
| # the matched module sits inside a tuner layer (e.g. `lora_A.<adapter_name>`); this typically happens when |
| # the adapter name itself collides with one of the requested module names |
| raise ValueError( |
| f"You are trying to target a module with {wrapper_cls} that is a child of {type(grandparent)}. " |
| "This is almost certainly not the intended behavior. Please ensure that the adapter name, " |
| f"'{adapter_name}', does not conflict with any of the targeted modules." |
| ) |
|
|
| if isinstance(target, wrapper_cls): |
| target.update(adapter_name, **wrapper_kwargs) |
| target.set_adapter(target.active_adapter, inference_mode=inference_mode) |
| else: |
| new_module = wrapper_cls(target, adapter_name, **wrapper_kwargs) |
| if activate_adapter: |
| new_module.set_adapter(adapter_name, inference_mode=inference_mode) |
| else: |
| new_module.set_adapter([], inference_mode=inference_mode) |
| setattr(parent, target_name, new_module) |
| trainable_modules.append(new_module) |
| found_modules.add(target_name) |
|
|
| not_found = set(module_names).difference(found_modules) |
| if strict_module_check and not found_modules: |
| raise ValueError( |
| f"Target modules {not_found} not found in the base model. Please check the target modules and try again." |
| ) |
|
|
| return trainable_modules |
|
|
|
|
| def _set_adapter(model, adapter_name: str | list[str], inference_mode: bool = False): |
| for module in model.modules(): |
| if isinstance(module, AuxiliaryTrainingWrapper): |
| |
| adapter_name_to_set = module.check_set_adapter(adapter_name) |
|
|
| # activate the adapter on this module only if it is present here; otherwise disable the wrapper's adapters |
| if adapter_name_to_set in module._adapters: |
| module.enable_adapters(True) |
| module.set_adapter(adapter_name_to_set, inference_mode=inference_mode) |
| else: |
| module.enable_adapters(False) |
| module.set_adapter([], inference_mode=inference_mode) |
|
|
|
|
| def _prepare_prompt_learning_config(peft_config, model_config): |
| # multimodal models may nest the text model configuration under `text_config` |
| if "text_config" in model_config: |
| model_config = model_config["text_config"] |
|
|
| if peft_config.num_layers is None: |
| if "num_hidden_layers" in model_config: |
| num_layers = model_config["num_hidden_layers"] |
| elif "num_layers" in model_config: |
| num_layers = model_config["num_layers"] |
| elif "n_layer" in model_config: |
| num_layers = model_config["n_layer"] |
| else: |
| raise ValueError("Please specify `num_layers` in `peft_config`") |
| peft_config.num_layers = num_layers |
|
|
| if peft_config.token_dim is None: |
| if "hidden_size" in model_config: |
| token_dim = model_config["hidden_size"] |
| elif "n_embd" in model_config: |
| token_dim = model_config["n_embd"] |
| elif "d_model" in model_config: |
| token_dim = model_config["d_model"] |
| else: |
| raise ValueError("Please specify `token_dim` in `peft_config`") |
| peft_config.token_dim = token_dim |
|
|
| if peft_config.num_attention_heads is None: |
| if "num_attention_heads" in model_config: |
| num_attention_heads = model_config["num_attention_heads"] |
| elif "n_head" in model_config: |
| num_attention_heads = model_config["n_head"] |
| elif "num_heads" in model_config: |
| num_attention_heads = model_config["num_heads"] |
| elif "encoder_attention_heads" in model_config: |
| num_attention_heads = model_config["encoder_attention_heads"] |
| else: |
| raise ValueError("Please specify `num_attention_heads` in `peft_config`") |
| peft_config.num_attention_heads = num_attention_heads |
|
|
| |
| if (peft_config.peft_type == "PREFIX_TUNING") and ("num_key_value_heads" in model_config): |
| num_key_value_heads = model_config["num_key_value_heads"] |
| if model_config.get("head_dim", None) is not None: |
| head_dim = model_config["head_dim"] |
| else: |
| head_dim = peft_config.token_dim // peft_config.num_attention_heads |
| peft_config.token_dim = head_dim * num_key_value_heads |
| peft_config.num_attention_heads = num_key_value_heads |
|
|
| if getattr(peft_config, "encoder_hidden_size", None) is None: |
| setattr(peft_config, "encoder_hidden_size", peft_config.token_dim) |
|
|
| return peft_config |
|
|
|
|
| def _get_no_split_modules(model) -> set[str]: |
| """ |
| Get the modules of the model that should not be split when using device_map. We iterate through the modules to get |
| the underlying `_no_split_modules`. |
| |
| Returns: |
| `set[str]`: The set of module class names that should not be split |
| """ |
| |
| |
| _no_split_modules: set[str] = set() |
| if not hasattr(model, "_no_split_modules"): |
| return _no_split_modules |
|
|
| modules_to_check = [model] |
| while len(modules_to_check) > 0: |
| module = modules_to_check.pop(-1) |
| |
| if module.__class__.__name__ not in _no_split_modules: |
| if isinstance(module, PreTrainedModel): |
| if module._no_split_modules is not None: |
| _no_split_modules = _no_split_modules | set(module._no_split_modules) |
| modules_to_check += list(module.children()) |
| return _no_split_modules |
|
|
|
|
| def fsdp_auto_wrap_policy(model): |
| if hasattr(FullyShardedDataParallelPlugin, "get_module_class_from_name"): |
| get_module_class_from_name = FullyShardedDataParallelPlugin.get_module_class_from_name |
| else: |
| from accelerate.utils.dataclasses import get_module_class_from_name |
| from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy |
|
|
| from ..tuners import PrefixEncoder, PromptEmbedding, PromptEncoder |
|
|
| default_transformer_cls_names_to_wrap = ",".join(_get_no_split_modules(model)) |
| transformer_cls_names_to_wrap = os.environ.get( |
| "FSDP_TRANSFORMER_CLS_TO_WRAP", default_transformer_cls_names_to_wrap |
| ).split(",") |
| transformer_cls_to_wrap = {PrefixEncoder, PromptEncoder, PromptEmbedding} |
| for layer_class in transformer_cls_names_to_wrap: |
| if len(layer_class) == 0: |
| continue |
| transformer_cls = get_module_class_from_name(model, layer_class) |
| if transformer_cls is None: |
| raise Exception("Could not find the transformer layer class to wrap in the model.") |
| else: |
| transformer_cls_to_wrap.add(transformer_cls) |
|
|
| def lambda_policy_fn(module): |
| if ( |
| len(list(module.named_children())) == 0 |
| and getattr(module, "weight", None) is not None |
| and module.weight.requires_grad |
| ): |
| return True |
| return False |
|
|
| lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn) |
| transformer_wrap_policy = functools.partial( |
| transformer_auto_wrap_policy, |
| transformer_layer_cls=transformer_cls_to_wrap, |
| ) |
|
|
| auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy]) |
| return auto_wrap_policy |
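|
| # Example (a hedged sketch of wiring this policy into accelerate's FSDP plugin; `model` is a placeholder): |
| # |
| #     fsdp_plugin = FullyShardedDataParallelPlugin(auto_wrap_policy=fsdp_auto_wrap_policy(model)) |
| #     accelerator = accelerate.Accelerator(fsdp_plugin=fsdp_plugin) |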
|
|
|
|
| def transpose(weight, fan_in_fan_out): |
| if not fan_in_fan_out: |
| return weight |
|
|
| if isinstance(weight, torch.nn.Parameter): |
| return torch.nn.Parameter(weight.T) |
| return weight.T |
|
|
|
|
| def _is_valid_match(key: str, target_key: str): |
| """ |
| Helper function to match module names target_key and key. Makes sure that either the key is exactly the target_key |
| or the target_key is a submodule of key |
| """ |
| if key.endswith(target_key): |
| if len(key) > len(target_key): |
| return key.endswith("." + target_key) |
| return True |
| return False |
|
|
|
|
| def _get_batch_size(input_ids: Optional[torch.Tensor], inputs_embeds: Optional[torch.Tensor]) -> int: |
| """Get the batch size based on either input_ids or input_embeds |
| |
| Raises a ValueError if both are None. |
| |
| """ |
| if (input_ids is None) and (inputs_embeds is None): |
| raise ValueError("You have to provide either input_ids or inputs_embeds") |
|
|
| if input_ids is not None: |
| batch_size = input_ids.shape[0] |
| else: |
| batch_size = inputs_embeds.shape[0] |
| return batch_size |
|
|
|
|
| def get_quantization_config(model: torch.nn.Module, method: str): |
| """ |
| Get the quantization config of the related quantization method |
| """ |
| if ( |
| hasattr(model, "config") |
| and hasattr(model.config, "quantization_config") |
| and (getattr(model, "quantization_method", None) == method) |
| ): |
| return model.config.quantization_config |
| return None |
|
|
|
|
| def get_auto_gptq_quant_linear(gptq_quantization_config): |
| """ |
| Get the right AutoGPTQQuantLinear class based on the quantization config file |
| """ |
| if gptq_quantization_config is None: |
| return None |
|
|
| if is_auto_gptq_available(): |
| from auto_gptq.utils.import_utils import dynamically_import_QuantLinear |
| else: |
| return None |
|
|
| desc_act = gptq_quantization_config.desc_act |
| group_size = gptq_quantization_config.group_size |
| bits = gptq_quantization_config.bits |
| if hasattr(gptq_quantization_config, "use_exllama"): |
| use_exllama = gptq_quantization_config.use_exllama |
| else: |
| use_exllama = not gptq_quantization_config.disable_exllama |
| if hasattr(gptq_quantization_config, "exllama_config"): |
| exllama_version = gptq_quantization_config.exllama_config["version"] |
| else: |
| exllama_version = 1 |
|
|
| QuantLinear = dynamically_import_QuantLinear( |
| use_triton=False, |
| desc_act=desc_act, |
| group_size=group_size, |
| bits=bits, |
| disable_exllama=not (use_exllama and exllama_version == 1), |
| disable_exllamav2=not (use_exllama and exllama_version == 2), |
| ) |
|
|
| return QuantLinear |
|
|
|
|
| def get_gptqmodel_quant_linear(gptq_quantization_config, device_map=None): |
| """ |
| Get the right GPTQQuantLinear class based on the quantization config file |
| """ |
| if gptq_quantization_config is None: |
| return None |
|
|
| if not is_gptqmodel_available(): |
| return None |
|
|
| from gptqmodel.utils.importer import hf_select_quant_linear |
|
|
| desc_act = gptq_quantization_config.desc_act |
| group_size = gptq_quantization_config.group_size |
| bits = gptq_quantization_config.bits |
| checkpoint_format = ( |
| gptq_quantization_config.checkpoint_format |
| if hasattr(gptq_quantization_config, "checkpoint_format") |
| else "gptq" |
| ) |
| sym = gptq_quantization_config.sym |
| meta = gptq_quantization_config.meta if hasattr(gptq_quantization_config, "meta") else None |
|
|
| QuantLinear = hf_select_quant_linear( |
| bits=bits, |
| group_size=group_size, |
| desc_act=desc_act, |
| sym=sym, |
| device_map=device_map, |
| checkpoint_format=checkpoint_format, |
| meta=meta, |
| backend="auto_trainable", |
| ) |
|
|
| return QuantLinear |
|
|
|
|
| def id_tensor_storage(tensor: torch.Tensor) -> tuple[torch.device, int, int]: |
| """ |
| Unique identifier to a tensor storage. Multiple different tensors can share the same underlying storage. For |
| example, "meta" tensors all share the same storage, and thus their identifier will all be equal. This identifier is |
| guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with |
| non-overlapping lifetimes may have the same id. |
| |
| This method is an exact copy of |
| https://github.com/huggingface/transformers/blob/main/src/transformers/pytorch_utils.py#L282C1-L300C58 but we added |
| it here manually to avoid import issues with old versions of transformers. |
| """ |
| if tensor.device.type == "xla" and is_torch_tpu_available(): |
| |
| |
| |
| |
| import torch_xla |
|
|
| unique_id = torch_xla._XLAC._xla_get_tensor_id(tensor) |
| else: |
| unique_id = storage_ptr(tensor) |
|
|
| return tensor.device, unique_id, storage_size(tensor) |
|
|
|
|
| def cast_mixed_precision_params(model, dtype): |
| """ |
| Cast all non-trainable parameters of the model to the given `dtype`. The `dtype` can be `torch.float16` or |
| `torch.bfloat16` as per the mixed-precision training you are performing. The trainable parameters are cast to full |
| precision. This is meant to reduce the GPU memory usage when using PEFT methods by using half-precision dtype for |
| non-trainable parameters. Having the trainable parameters in full-precision preserves training stability when using |
| automatic mixed-precision training. |
| |
| Args: |
| model (`torch.nn.Module`): |
| The model to cast the non-trainable parameters of. |
| dtype (`torch.dtype`): |
| The dtype to cast the non-trainable parameters to. The `dtype` can be `torch.float16` or |
| `torch.bfloat16` as per the mixed-precision training you are performing. |
| """ |
| for p in model.parameters(): |
| if not p.requires_grad: |
| p.data = p.to(dtype) |
| else: |
| p.data = p.to(torch.float32) |
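|
| # Example (a hedged sketch; `peft_model` is an illustrative placeholder set up for bf16 autocast training): |
| # |
| #     cast_mixed_precision_params(peft_model, dtype=torch.bfloat16) |
| #     # frozen parameters are now bf16 while trainable adapter parameters stay in float32 |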
|
|
|
|
| def str_to_bool(value: str) -> int: |
| """ |
| Converts a string representation of truth to `True` (1) or `False` (0). |
| |
| True values are `y`, `yes`, `t`, `true`, `on`, and `1`; False values are `n`, `no`, `f`, `false`, `off`, and `0`. |
| """ |
| |
| value = value.lower() |
| if value in ("y", "yes", "t", "true", "on", "1"): |
| return 1 |
| elif value in ("n", "no", "f", "false", "off", "0"): |
| return 0 |
| else: |
| raise ValueError(f"invalid truth value {value}") |
|
|
|
|
| def check_file_exists_on_hf_hub(repo_id: str, filename: str, **kwargs) -> Optional[bool]: |
| """Check if a file exists on HF Hub, if check was not successful returns None instead of erroring. |
| |
| Respect offline mode if set. |
| |
| """ |
| exists: Optional[bool] = None |
| if str_to_bool(os.environ.get("HF_HUB_OFFLINE", "0")): |
| |
| return exists |
|
|
| try: |
| exists = file_exists(repo_id, filename, **kwargs) |
| except (HFValidationError, EntryNotFoundError): |
| |
| pass |
| except Exception as e: |
| warnings.warn( |
| f"Unable to fetch remote file due to the following error {e} - silently ignoring the lookup" |
| f" for the file {filename} in {repo_id}." |
| ) |
|
|
| return exists |
|
|
|
|
| def match_target_against_key(target_pattern: str, key: str): |
| """Backing function for `target_modules` config parameter. |
| |
| Having this as its own function ensures that target key matching can be implemented in the same way everywhere. |
| """ |
| return re.fullmatch(target_pattern, key) |
|
|
|
|
| def get_pattern_key(pattern_keys: Sequence[str], key_to_match: str) -> str: |
| """Match a substring of key_to_match in pattern keys""" |
| for key in pattern_keys: |
| match = re.match(rf"(.*\.)?({key})$", key_to_match) |
| if not match: |
| continue |
| return key |
|
|
| return key_to_match |
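|
| # Example (an illustrative sketch): with `rank_pattern = {"q_proj": 16}`, |
| # |
| #     get_pattern_key(rank_pattern.keys(), "model.layers.0.self_attn.q_proj")  # returns "q_proj" |
| # |
| # so per-module overrides can be looked up by their suffix. |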
|
|
|
|
| def set_additional_trainable_modules(model, peft_config, model_config, adapter_name, activate_adapter: bool = True): |
| """Handle the resolution of additional trainable modules (also called AuxiliaryTrainingWrapper) |
| by checking the config if such modules are requested and adding them to the model. |
| |
| Currently trainable tokens and modules to save are considered additional trainable modules. |
| |
| If `activate_adapter` is set to `False`, the adapter won't be activated. This is typically the case when |
| `model.add_adapter` or `model.load_adapter` are being called. |
| """ |
| if getattr(peft_config, "modules_to_save", None) is not None: |
| # wrap the requested modules so they are fully trained and saved alongside the adapter weights |
| _set_trainable( |
| model, |
| adapter_name, |
| inference_mode=peft_config.inference_mode, |
| module_names=getattr(peft_config, "modules_to_save", None), |
| activate_adapter=activate_adapter, |
| ) |
|
|
| if getattr(peft_config, "modules_to_tie", None) is not None: |
| # `modules_to_tie` holds output modules whose weights are tied to the input embeddings; reuse the wrapped |
| # input embedding module so that the weight tying is preserved |
| tied_module = getattr(model.get_input_embeddings().modules_to_save, adapter_name) |
| _set_trainable( |
| model, |
| adapter_name, |
| inference_mode=peft_config.inference_mode, |
| module_names=getattr(peft_config, "modules_to_tie", None), |
| activate_adapter=activate_adapter, |
| tied_module=tied_module, |
| ) |
|
|
| if getattr(peft_config, "trainable_token_indices", None) is not None: |
| if isinstance(peft_config.trainable_token_indices, dict): |
| target_layers = peft_config.trainable_token_indices |
| else: |
| layer_name = _get_input_embeddings_name(model, "embed_tokens") |
| target_layers = {layer_name: peft_config.trainable_token_indices} |
|
|
| modules_to_save = getattr(peft_config, "modules_to_save", None) |
| if modules_to_save is not None: |
| for target_layer in target_layers: |
| if target_layer in modules_to_save: |
| raise ValueError( |
| "The embedding layer is already marked to be trained fully, either specify " |
| f'`modules_to_save=[..., "{target_layer}", ...]` or ' |
| f"`trainable_tokens={{'{target_layer}': x}}` but not both." |
| ) |
|
|
| for target_layer, token_indices in target_layers.items(): |
| _set_trainable( |
| model, |
| adapter_name, |
| inference_mode=peft_config.inference_mode, |
| module_names=[target_layer], |
| strict_module_check=True, |
| wrapper_cls=TrainableTokensWrapper, |
| token_indices=token_indices, |
| activate_adapter=activate_adapter, |
| ) |
|
|
| tied_weights_module_names = _get_module_names_tied_with_embedding(model) |
|
|
| # if the model ties other modules (e.g. the LM head) to the input embeddings and the embeddings were wrapped |
| # for trainable tokens, wrap the tied modules as well, sharing the same token adapter so the weights stay in sync |
| if ( |
| tied_weights_module_names |
| and model_config.get("tie_word_embeddings", False) |
| and isinstance(model.get_input_embeddings(), TrainableTokensWrapper) |
| ): |
| token_adapter = model.get_input_embeddings().token_adapter |
| _set_trainable( |
| model, |
| adapter_name, |
| inference_mode=peft_config.inference_mode, |
| module_names=tied_weights_module_names, |
| strict_module_check=True, |
| wrapper_cls=TrainableTokensWrapper, |
| token_indices=token_adapter.token_indices[adapter_name], |
| tied_adapter=model.get_input_embeddings().token_adapter, |
| ) |
|
|
|
|
| def create_attention_mask( |
| model, *, model_input, attention_mask, past_key_values, cache_position, batch_size, sequence_length, position_ids |
| ): |
| # build the 4D causal attention mask the same way transformers does during generation; the masking utilities |
| # used here are only available in recent transformers versions |
| transformers_ge_4_53_1 = version.parse(transformers.__version__) >= version.parse("4.53.1") |
| if transformers_ge_4_53_1: |
| |
| from transformers.masking_utils import create_masks_for_generate |
| else: |
| raise ImportError("Your transformers version is too old, please upgrade it to >= 4.53.1") |
|
|
| |
| |
| base_model = getattr(model, model.base_model_prefix, model) |
| decoder = base_model.get_decoder() if hasattr(base_model, "get_decoder") else None |
| causal_mask_creation_function = getattr(base_model, "_prepare_4d_causal_attention_mask_with_cache_position", None) |
| if causal_mask_creation_function is None and decoder is not None: |
| causal_mask_creation_function = getattr(decoder, "_prepare_4d_causal_attention_mask_with_cache_position", None) |
|
|
| |
| if causal_mask_creation_function is None: |
| token_type_ids = getattr(model_input, "token_type_ids", None) |
| |
| causal_mask_creation_function = getattr(model, "create_masks_for_generate", create_masks_for_generate) |
| attention_mask = causal_mask_creation_function( |
| config=model.config, |
| # only the shape and dtype of the embeddings matter here, not their values |
| input_embeds=torch.empty((batch_size, sequence_length), dtype=model.dtype), |
| attention_mask=attention_mask, |
| cache_position=cache_position, |
| past_key_values=past_key_values, |
| token_type_ids=token_type_ids, |
| position_ids=position_ids, |
| ) |
| else: |
| attention_mask = causal_mask_creation_function( |
| attention_mask, |
| sequence_length=sequence_length, |
| target_length=past_key_values.get_max_cache_shape(), |
| dtype=model.dtype, |
| cache_position=cache_position, |
| batch_size=batch_size, |
| config=model.config, |
| past_key_values=past_key_values, |
| position_ids=position_ids, |
| ) |
| return attention_mask |
|
|
|
|
| def _get_module_names_tied_with_embedding(model) -> list[str]: |
| """ |
| Get the list of the fully qualified names of the modules that are tied to the input embeddings. In case of a |
| source-target-mapping `_tied_weights_keys`, it will attempt to identify the input embedding weights from the |
| mapping and return the list of tied modules accordingly. This gives a unified interface to both transformers v4 |
| tied weights and v5 mapped tied weights. |
| |
| For example: For models which have `embed_tokens` and `lm_head` as the tied keys, this function will return |
| [`lm_head`]. The PEFT model is assumed to be transparent: returned names will be relative to the base model, so |
| even though `model.base_model.lm_head` is tied, the returned name is `lm_head` since such attributes are forwarded |
| to the base model anyway. Non-transformer models have to provide a `_tied_weights_keys` attribute for this function |
| to work. |
| |
| Note that this function will not check if weight tying is disabled by the model's config. There can be the case |
| that the weight tying definition is present but the tying is disabled via `model_config.tie_word_embeddings=False`. |
| You have to check that yourself. |
| """ |
| tied_weights = [] |
|
|
| if hasattr(model, "get_base_model"): |
| |
| model = model.get_base_model() |
|
|
| if hasattr(model, "tuner_layer_cls"): |
| |
| model = model.model |
|
|
| if not hasattr(model, "_tied_weights_keys"): |
| return [] |
|
|
| base_layer_pattern = re.compile(r"[^.]+\.base_layer\.") |
|
|
| if isinstance(model._tied_weights_keys, dict): |
| if not hasattr(model, "get_input_embeddings"): |
| raise ValueError( |
| "The supplied model implements `_tied_weights_keys` as a dict but doesn't implement " |
| "'get_input_embeddings' so we can't determine which weights are tied to embeddings." |
| ) |
|
|
| |
| |
| |
| |
| input_embedding_params = set(model.get_input_embeddings().parameters()) |
| candidates = [n for n, p in model.named_parameters(remove_duplicate=False) if p in input_embedding_params] |
|
|
| |
| |
| |
| peft_reverse_mapping = {base_layer_pattern.sub("", name): name for name in candidates} |
|
|
| |
| |
| peft_reverse_mapping.update(**{name.replace("base_layer.", ""): name for name in candidates}) |
|
|
| tied_weights.extend( |
| peft_reverse_mapping.get(k, k) |
| for k, v in model._tied_weights_keys.items() |
| if peft_reverse_mapping.get(v, v) in candidates |
| ) |
|
|
| elif model._tied_weights_keys is not None: |
| |
| tied_weights.extend(model._tied_weights_keys) |
|
|
| return sorted({name.rpartition(".")[0] for name in tied_weights}) |
|
|