from __future__ import annotations

import math
import warnings
from operator import attrgetter
from typing import Literal, Optional

import torch

from peft.config import PeftConfig
from peft.mapping import PEFT_TYPE_TO_CONFIG_MAPPING, PEFT_TYPE_TO_PREFIX_MAPPING
from peft.tuners.lora import Conv2d, Linear, LoraConfig, LoraLayer

from .other import get_pattern_key, infer_device
from .peft_types import PeftType
from .save_and_load import _insert_adapter_name_into_state_dict, load_peft_weights


# Config keys that must match between two adapters for hotswapping to be possible.
CONFIG_KEYS_TO_CHECK = {PeftType.LORA: ["use_rslora", "lora_dropout", "alpha_pattern", "use_dora"]}


def _update_scaling(lora_module, adapter_name, scaling=None):
    """
    Update the scaling value of the given adapter on the LoRA module.

    Takes into account that the scaling can be a tensor, as set up by `prepare_model_for_compiled_hotswap`.
    """
    if lora_module.scaling[adapter_name] == scaling:
        return

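    # Tensor scalings are updated in place via fill_ so that torch.compile keeps seeing the same tensor
    # object and does not need to recompile; plain float/int scalings can simply be reassigned.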
    if isinstance(lora_module.scaling[adapter_name], torch.Tensor):
        lora_module.scaling[adapter_name].fill_(scaling)
    elif isinstance(lora_module.scaling[adapter_name], (float, int)):
        lora_module.scaling[adapter_name] = scaling
    else:
        raise ValueError(
            "Something went wrong when trying to set the new scale value, expected the old value to be of type "
            f"float, int, or torch.Tensor, got {type(lora_module.scaling[adapter_name])} instead."
        )


def _convert_scalings_to_tensor(model) -> bool:
    """
    Convert the LoRA scaling values into torch.Tensors to prevent recompilation when they change.

    Returns:
        bool:
            `True` if an appropriate adapter was found, else `False`.
    """
    found_adapter = False
    for module in model.modules():
        if not isinstance(module, LoraLayer):
            continue

        found_adapter = True
        scaling = module.scaling
        for key, val in scaling.items():
            if isinstance(val, float):
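                # Create a 0-dim tensor on the layer's device; later scaling updates then go through
                # fill_ on this same tensor object, which does not trigger recompilation.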
                scaling[key] = torch.tensor(val, device=module.weight.device)
            elif not isinstance(val, torch.Tensor):
                raise ValueError(
                    "Something went wrong while trying to convert the scalings, expected to find values of type float "
                    f"but found {type(val)} instead."
                )
    return found_adapter


def _get_padded_linear(lora_module: torch.nn.Module, target_rank: int, is_lora_A: bool) -> torch.nn.Linear:
    """
    Get a new Linear layer for LoRA with weights padded according to the target rank.

    Args:
        lora_module (nn.Module):
            The LoRA sub-module (e.g. module.lora_A[adapter_name]).
        target_rank (int):
            The desired rank to pad to.
        is_lora_A (bool):
            True if this is the LoRA A matrix, False if LoRA B.

    Returns:
        nn.Linear:
            A newly created and padded Linear layer. If the rank already fits, the original layer is returned.
    """
    weight = lora_module.weight
    # For lora_A, the rank is the first dimension of the weight; for lora_B, it is the second.
    original_rank = weight.size(0) if is_lora_A else weight.size(1)

    if original_rank == target_rank:
        return lora_module

    if original_rank > target_rank:
        raise ValueError(
            f"Trying to pad the adapter to the target rank {target_rank}, but the original rank is larger "
            f"({original_rank}). This is not possible."
        )

    out_features, in_features = weight.shape

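    # lora_A maps in_features -> rank and lora_B maps rank -> out_features. The extra rows/columns are
    # zero-initialized, so the combined update lora_B @ lora_A is unchanged by the padding.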
    if is_lora_A:
        # pad the rank dimension (rows) of lora_A
        padded = torch.zeros(target_rank, in_features, device=weight.device, dtype=weight.dtype)
        padded[:original_rank, :] = weight
        new_layer = torch.nn.Linear(in_features, target_rank, bias=lora_module.bias is not None)
    else:
        # pad the rank dimension (columns) of lora_B
        padded = torch.zeros(out_features, target_rank, device=weight.device, dtype=weight.dtype)
        padded[:, :original_rank] = weight
        new_layer = torch.nn.Linear(target_rank, out_features, bias=lora_module.bias is not None)

    # sanity check: the new layer must accept the padded weights and the original bias
    if new_layer.weight.shape != padded.shape:
        raise ValueError(
            "Something went wrong when trying to pad the LoRA Linear weights, the new shape should be "
            f"{padded.shape} but {new_layer.weight.shape} was found. Please open an issue on PEFT "
            "(https://github.com/huggingface/peft/issues) and report this error."
        )
    if (lora_module.bias is not None) and (new_layer.bias.shape != lora_module.bias.shape):
        raise ValueError(
            "Something went wrong when trying to pad the LoRA Linear bias, the new shape should be "
            f"{lora_module.bias.shape} but {new_layer.bias.shape} was found. Please open an issue on PEFT "
            "(https://github.com/huggingface/peft/issues) and report this error."
        )

    new_layer.weight.data = padded
    # the bias is not padded, so it can be copied over as is
    if lora_module.bias is not None:
        new_layer.bias.data = lora_module.bias.data

    return new_layer


def _get_padded_conv2d(lora_module: torch.nn.Module, target_rank: int, is_lora_A: bool) -> torch.nn.Conv2d:
    """
    Get a new Conv2d layer for LoRA with weights padded according to the target rank.

    Args:
        lora_module (nn.Module):
            The LoRA sub-module (e.g. module.lora_A[adapter_name]).
        target_rank (int):
            The desired rank to pad to.
        is_lora_A (bool):
            True if this is the LoRA A matrix, False if LoRA B.

    Returns:
        nn.Conv2d:
            A newly created and padded Conv2d layer. If the rank already fits, the original layer is returned.
    """
    weight = lora_module.weight
    # Conv2d weights have shape (out_channels, in_channels, kernel_height, kernel_width); for lora_A the
    # rank is the out_channels dimension, for lora_B it is the in_channels dimension.
    out_channels, in_channels, kh, kw = weight.shape
    original_rank = out_channels if is_lora_A else in_channels

    if original_rank == target_rank:
        return lora_module

    if original_rank > target_rank:
        raise ValueError(
            f"Trying to pad the adapter to the target rank {target_rank}, but the original rank is larger "
            f"({original_rank}). This is not possible."
        )

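    # As with Linear, the extra channels are zero-initialized, so the padded pair of convolutions produces
    # the same output as the original pair.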
    if is_lora_A:
        # pad the out_channels (rank) dimension of lora_A
        padded = torch.zeros(target_rank, in_channels, kh, kw, device=weight.device, dtype=weight.dtype)
        padded[:out_channels, :, :, :] = weight
        new_layer = torch.nn.Conv2d(
            in_channels,
            target_rank,
            kernel_size=lora_module.kernel_size,
            stride=lora_module.stride,
            padding=lora_module.padding,
            bias=lora_module.bias is not None,
            groups=lora_module.groups,
        )
    else:
        # pad the in_channels (rank) dimension of lora_B
        padded = torch.zeros(out_channels, target_rank, kh, kw, device=weight.device, dtype=weight.dtype)
        padded[:, :in_channels, :, :] = weight
        new_layer = torch.nn.Conv2d(
            target_rank,
            out_channels,
            kernel_size=lora_module.kernel_size,
            stride=lora_module.stride,
            padding=lora_module.padding,
            bias=lora_module.bias is not None,
            groups=lora_module.groups,
        )

    # sanity check: the new layer must accept the padded weights and the original bias
    if new_layer.weight.shape != padded.shape:
        raise ValueError(
            "Something went wrong when trying to pad the LoRA Conv2d weights, the new shape should be "
            f"{padded.shape} but {new_layer.weight.shape} was found. Please open an issue on PEFT "
            "(https://github.com/huggingface/peft/issues) and report this error."
        )
    if (lora_module.bias is not None) and (new_layer.bias.shape != lora_module.bias.shape):
        raise ValueError(
            "Something went wrong when trying to pad the LoRA Conv2d bias, the new shape should be "
            f"{lora_module.bias.shape} but {new_layer.bias.shape} was found. Please open an issue on PEFT "
            "(https://github.com/huggingface/peft/issues) and report this error."
        )

    new_layer.weight.data = padded
    # the bias is not padded, so it can be copied over as is
    if lora_module.bias is not None:
        new_layer.bias.data = lora_module.bias.data

    return new_layer


def _pad_lora_weights(model: torch.nn.Module, target_rank: int) -> bool:
    """
    Pad the LoRA weights in a model to a target rank while preserving the original behavior.

    Args:
        model (nn.Module):
            The model containing LoRA modules (with lora_A and lora_B).
        target_rank (int):
            The target rank to pad to.

    Returns:
        bool:
            `True` if an appropriate adapter was found, else `False`.
    """
    found_adapter = False

    for module in model.modules():
        # only Linear and Conv2d LoRA layers are supported for padding
        if isinstance(module, Linear):
            pad_fn = _get_padded_linear
        elif isinstance(module, Conv2d):
            pad_fn = _get_padded_conv2d
        else:
            continue

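        # Replace each adapter's lora_A and lora_B sub-layers with padded versions of themselves.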
        for adapter_name, lora_A_module in module.lora_A.items():
            new_layer = pad_fn(lora_A_module, target_rank=target_rank, is_lora_A=True)
            module.lora_A[adapter_name] = new_layer

        for adapter_name, lora_B_module in module.lora_B.items():
            new_layer = pad_fn(lora_B_module, target_rank=target_rank, is_lora_A=False)
            module.lora_B[adapter_name] = new_layer

        found_adapter = True
    return found_adapter


def prepare_model_for_compiled_hotswap(
    model: torch.nn.Module,
    *,
    target_rank: Optional[int] = None,
    config: Optional[LoraConfig | dict[str, LoraConfig]] = None,
    check_compiled: Literal["error", "warn", "ignore"] = "error",
) -> None:
    """
    Helper function that prepares the model so that it can later be compiled and then used with hot-swapping.

    It is necessary to call this function on the model for hot-swapping to work if both of these are true:

    - the different LoRA adapters have different ranks and/or different alpha values (i.e. scalings)
    - you plan to torch.compile the model and want to avoid re-compilation

    It is important to call this function *after* the first LoRA adapter has been loaded (i.e. the one that will be
    swapped out) but *before* the model is compiled.

    Even with this function, hot-swapping LoRA adapters that target different layers is still not supported.

    Note: This function modifies the model in-place. If you want to restore the model to its initial state, you will
    have to reload it.

    Args:
        model (`nn.Module`):
            The model with the loaded adapter, before compilation.
        target_rank (`int`, *optional*):
            The target rank to pad the LoRA weights to. Should be the maximum rank among all LoRA adapters that will
            be hot-swapped. If not specified, the target ranks will not be changed.
        config (`LoraConfig` or `dict[str, LoraConfig]`, *optional*):
            Optionally pass the `LoraConfig`s of the LoRA adapters. If passed, the rank in the configs will be updated
            to `target_rank`.
        check_compiled (`str`, *optional*, defaults to `"error"`):
            How to handle the case when the model is already compiled, which should generally be avoided. The options
            are:
            - "error" (default): raise an error
            - "warn": issue a warning
            - "ignore": do nothing

    Raises:
        ValueError
            If the model is already compiled or if no adapter layer was found, raise an error.

    Example:

        ```py
        base_model = ...
        model = PeftModel.from_pretrained(base_model, path_adapter_0)
        # Prepare the model to allow hotswapping even if ranks/scalings of 2nd adapter differ.
        # You can skip this step if all ranks and scalings are identical.
        prepare_model_for_compiled_hotswap(model, target_rank=highest_lora_rank)
        model = torch.compile(model)
        # do inference with adapter 0
        # replace the "default" lora adapter with the new one
        hotswap_adapter(model, path_adapter_1, adapter_name="default", torch_device=device)
        # do inference with adapter 1
        ```
    """
    is_compiled = hasattr(model, "_orig_mod") or getattr(model, "_compiled_call_impl", False)
    if is_compiled:
        if check_compiled == "error":
            raise ValueError("Call prepare_model_for_compiled_hotswap *before* compiling the model")
        elif check_compiled == "warn":
            warnings.warn(
                "prepare_model_for_compiled_hotswap was called with a model that is already compiled. This will "
                "likely result in re-compilation, hurting performance. Call the function before compiling the model."
            )
        elif check_compiled != "ignore":
            raise ValueError(
                f"check_compiled should be one of 'error', 'warn', or 'ignore', got '{check_compiled}' instead."
            )

    conversion_found_adapter = _convert_scalings_to_tensor(model)
    if target_rank is not None:
        padding_found_adapter = _pad_lora_weights(model, target_rank=target_rank)
    else:
        padding_found_adapter = False

    if not (conversion_found_adapter or padding_found_adapter):
        raise ValueError(
            "No adapter layers found on the model. Make sure to call `prepare_model_for_compiled_hotswap` after "
            "loading the first adapter and before loading the second adapter."
        )

    if not config:
        return
    if target_rank is None:
        return

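    # If configs were passed, update their rank settings to the padded target rank so that scaling
    # calculations during subsequent hotswaps line up with the padded weights.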
    if not isinstance(config, dict):
        # config can be a single LoraConfig; wrap it in a dict so the loop below handles both cases
        config = {"dummy": config}

    for lora_config in config.values():
        lora_config.r = target_rank
        if lora_config.rank_pattern:
            for key in lora_config.rank_pattern:
                lora_config.rank_pattern[key] = target_rank


def hotswap_adapter_from_state_dict(
    model: torch.nn.Module,
    state_dict: dict[str, torch.Tensor],
    adapter_name: str,
    config: LoraConfig,
    parameter_prefix: str = "lora_",
):
    """
    Swap out the adapter weights from the model with the weights from state_dict.

    As of now, only LoRA is supported.

    This is a low-level function that assumes that the adapters have been checked for compatibility and that the
    state_dict has been correctly mapped to work with PEFT. For a high-level function that performs this work for
    you, use `hotswap_adapter` instead.

    Args:
        model (`nn.Module`):
            The model with the loaded adapter.
        state_dict (`dict[str, torch.Tensor]`):
            The state dict of the new adapter, which needs to be compatible (targeting the same modules etc.).
        adapter_name (`str`):
            The name of the adapter that should be hot-swapped, e.g. `"default"`. The name will remain the same after
            swapping.
        config (`LoraConfig`):
            The config of the LoRA adapter. This is used to determine the scaling and rank of the adapter.
        parameter_prefix (`str`, *optional*, defaults to `"lora_"`):
            The prefix used to identify the adapter's keys in the state dict. For LoRA, this would be `"lora_"` (the
            default).

    Raises:
        RuntimeError
            If the old and the new adapter are not compatible, a RuntimeError is raised.
    """
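    # Whether the model was compiled (wrapped via torch.compile or compiled in place) determines both how
    # the parameter keys are prefixed and how the weights can be replaced further below.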
    is_compiled = hasattr(model, "_orig_mod")
    is_compiled_inplace = bool(getattr(model, "_compiled_call_impl", None))
    # collect all adapter keys currently on the model; every one of them must be accounted for
    missing_keys = {k for k in model.state_dict() if (parameter_prefix in k) and (adapter_name in k)}
    unexpected_keys = []

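    # Dry run first: check that every key of the new state_dict exists on the model before mutating
    # anything, so that a failed hotswap does not leave the model in a half-swapped state.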
    for key, new_val in state_dict.items():
        try:
            old_val = attrgetter(key)(model)
        except AttributeError:
            unexpected_keys.append(key)
            continue

        if is_compiled:
            # compiled models prefix all keys with "_orig_mod."
            missing_keys.remove("_orig_mod." + key)
        else:
            missing_keys.remove(key)

    if unexpected_keys:
        msg = f"Hot swapping the adapter did not succeed, unexpected keys found: {', '.join(unexpected_keys)}."
        raise RuntimeError(msg)

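    # Keys that exist on the model but are absent from the new adapter are zeroed out. A zeroed-out LoRA
    # weight is a no-op, so the corresponding layers effectively carry no adapter after the swap.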
    for key in missing_keys:
        # in case of a compiled model, strip the prefix to address the underlying parameter
        key = key.removeprefix("_orig_mod.")
        old_val = attrgetter(key)(model)
        old_val.data.fill_(0.0)

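    # Now perform the actual swap: for every key of the new adapter, update the scaling on the owning
    # module and then replace the old weight with the new one.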
    for key, new_val in state_dict.items():
        # the key has the form "...<module_name>.lora_A.<adapter_name>.weight"; strip the last three
        # parts to get the name of the module that holds the LoRA layer
        module_name = ".".join(key.split(".")[:-3])
        module = model.get_submodule(module_name)

        # update the scaling according to the rank and alpha of the new adapter, honoring per-layer patterns
        r_key = get_pattern_key(config.rank_pattern.keys(), key)
        alpha_key = get_pattern_key(config.alpha_pattern.keys(), key)
        rank = config.rank_pattern.get(r_key, config.r)
        alpha = config.alpha_pattern.get(alpha_key, config.lora_alpha)
        if config.use_rslora:
            scaling = alpha / math.sqrt(rank)
        else:
            scaling = alpha / rank
        _update_scaling(module, adapter_name=adapter_name, scaling=scaling)

        # move the new weight to the device of the old weight before swapping
        old_val = attrgetter(key)(model)
        new_val = new_val.to(old_val.data.device)

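        # swap_tensors is the cleanest way to replace the weight, but it does not work on compiled models,
        # since the compiled graph still holds references to the old tensors. Detecting compilation is not
        # fully reliable (e.g. when hotswapping from within the model), so also fall back on RuntimeError.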
        if not is_compiled and not is_compiled_inplace:
            try:
                torch.utils.swap_tensors(old_val, new_val)
                continue
            except RuntimeError:
                is_compiled = True

        # The tensors could not be swapped (e.g. compiled model); copy the data in place instead, which
        # leaves the tensor object itself, and hence the compiled graph, untouched.
        if old_val.shape == new_val.shape:
            # the ranks (and thus the shapes) match, so a plain in-place copy suffices
            old_val.data.copy_(new_val.data)
        else:
            # the ranks differ, which only works for padded Linear (2 dims) and Conv2d (4 dims) weights
            if old_val.dim() not in (2, 4):
                raise NotImplementedError(
                    f"Trying to hotswap an adapter whose weight has {old_val.dim()} dimensions, but only Conv2d and "
                    "Linear are supported"
                )

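            # The old weight was padded to a larger rank than the new adapter uses: zero everything out,
            # then copy the new weight into the leading slice along the rank dimension.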
            if old_val.shape[0] > new_val.shape[0]:
                old_val.data.fill_(0)
                old_val.data[: new_val.shape[0]].copy_(new_val.data)
            elif old_val.shape[1] > new_val.shape[1]:
                old_val.data.fill_(0)
                old_val.data[:, : new_val.shape[1]].copy_(new_val.data)
            else:
                raise ValueError(
                    f"Incompatible shapes found for LoRA weights {key}: {old_val.shape} vs {new_val.shape}. Please "
                    "ensure that all ranks are padded to the largest rank among all LoRA adapters by using "
                    "peft.utils.hotswap.prepare_model_for_compiled_hotswap."
                )


def check_hotswap_configs_compatible(config0: PeftConfig, config1: PeftConfig) -> None:
    """
    Check if two configs are compatible for hot-swapping.

    Only LoRA parameters are checked for now.

    To hot-swap two adapters, their configs must be compatible. Otherwise, the results could be incorrect. E.g. if
    they use different alpha values, after hot-swapping, the alphas from the first adapter would still be used with
    the weights from the 2nd adapter, which would result in incorrect behavior. There is probably a way to swap these
    values as well, but that's not implemented yet, and we need to be careful not to trigger re-compilation if the
    model is compiled (so no modification of the dict).
    """
    if config0.peft_type != config1.peft_type:
        msg = f"Incompatible PEFT types found: {config0.peft_type.value} and {config1.peft_type.value}"
        raise ValueError(msg)

    if config0.peft_type not in CONFIG_KEYS_TO_CHECK:
        msg = (
            f"Hotswapping only supports {', '.join(CONFIG_KEYS_TO_CHECK.keys())} but "
            f"{config0.peft_type.value} was passed."
        )
        raise ValueError(msg)
    config_keys_to_check = CONFIG_KEYS_TO_CHECK[config0.peft_type]

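    # Compare the raw dict representations; the sentinel distinguishes "key missing" from "key set to None".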
    config0 = config0.to_dict()
    config1 = config1.to_dict()
    sentinel = object()
    for key in config_keys_to_check:
        val0 = config0.get(key, sentinel)
        val1 = config1.get(key, sentinel)
        if val0 != val1:
            raise ValueError(f"Configs are incompatible: for {key}, {val0} != {val1}")


def hotswap_adapter(model, model_name_or_path, adapter_name, torch_device=None, **kwargs):
    """Substitute old adapter data with new adapter data, keeping the rest the same.

    As of now, only LoRA is supported.

    This function is useful when you want to replace the loaded adapter with a new adapter. The adapter name will
    remain the same, but the weights and other parameters will be swapped out.

    If the adapters are incompatible, e.g. targeting different layers or having different alpha values, an error will
    be raised.

    Example:

    ```py
    >>> import torch
    >>> from transformers import AutoModelForCausalLM
    >>> from peft import PeftModel
    >>> from peft.utils.hotswap import hotswap_adapter

    >>> model_id = ...
    >>> inputs = ...
    >>> device = ...
    >>> model = AutoModelForCausalLM.from_pretrained(model_id).to(device)

    >>> # load lora 0
    >>> model = PeftModel.from_pretrained(model, "path-adapter-0")
    >>> model = torch.compile(model)  # optionally compile the model
    >>> with torch.inference_mode():
    ...     output_adapter_0 = model(inputs).logits

    >>> # replace the "default" lora adapter with the new one
    >>> hotswap_adapter(model, "path-adapter-1", adapter_name="default", torch_device=device)
    >>> with torch.inference_mode():
    ...     output_adapter_1 = model(inputs).logits
    ```

    Args:
        model ([`~PeftModel`]):
            The PEFT model with the loaded adapter.
        model_name_or_path (`str`):
            The name or path of the model to load the new adapter from.
        adapter_name (`str`):
            The name of the adapter to swap, e.g. `"default"`. The name will stay the same after swapping.
        torch_device (`str`, *optional*, defaults to None):
            The device to load the new adapter onto.
        **kwargs (`optional`):
            Additional keyword arguments used for loading the config and weights.
    """
    if torch_device is None:
        torch_device = infer_device()

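    # Load the config of the new adapter; only the Hub-related kwargs are passed on to _get_peft_type.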
    hf_kwargs = {
        "subfolder": kwargs.get("subfolder", None),
        "revision": kwargs.get("revision", None),
        "cache_dir": kwargs.get("cache_dir", None),
        "token": kwargs.get("token", None),
    }
    if use_auth_token := kwargs.get("use_auth_token", None):
        hf_kwargs["use_auth_token"] = use_auth_token
    config_cls = PEFT_TYPE_TO_CONFIG_MAPPING[PeftConfig._get_peft_type(model_name_or_path, **hf_kwargs)]
    config = config_cls.from_pretrained(model_name_or_path, **kwargs)
    # fail early if the old and new adapter configs are not compatible
    check_hotswap_configs_compatible(model.active_peft_config, config)

    state_dict = load_peft_weights(model_name_or_path, device=torch_device, **kwargs)

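    # On-disk state_dicts use keys without the adapter name (e.g. "...lora_A.weight"); re-insert the
    # adapter name so the keys match the parameters of the loaded model (e.g. "...lora_A.default.weight").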
    parameter_prefix = PEFT_TYPE_TO_PREFIX_MAPPING[config.peft_type]
    peft_model_state_dict = _insert_adapter_name_into_state_dict(
        state_dict, adapter_name=adapter_name, parameter_prefix=parameter_prefix
    )

    hotswap_adapter_from_state_dict(
        model=model,
        state_dict=peft_model_state_dict,
        adapter_name=adapter_name,
        parameter_prefix=parameter_prefix,
        config=config,
    )