| """ |
| PEFT utilities: Utilities related to peft library |
| """ |
|
|
| import collections |
| import functools |
| import importlib |
|
|
| from packaging import version |
|
|
| from . import logging |
| from .import_utils import is_peft_available, is_peft_version, is_torch_available |
| from .torch_utils import empty_device_cache |
|
|
|
|
| logger = logging.get_logger(__name__) |
|
|
| if is_torch_available(): |
| import torch |
|
|
|
|


def recurse_remove_peft_layers(model):
    r"""
    Recursively replace all instances of `LoraLayer` with corresponding new layers in `model`.
    """
    from peft.tuners.tuners_utils import BaseTunerLayer

    has_base_layer_pattern = False
    for module in model.modules():
        if isinstance(module, BaseTunerLayer):
            has_base_layer_pattern = hasattr(module, "base_layer")
            break

    if has_base_layer_pattern:
        from peft.utils import _get_submodules

        key_list = [key for key, _ in model.named_modules() if "lora" not in key]
        for key in key_list:
            try:
                parent, target, target_name = _get_submodules(model, key)
            except AttributeError:
                continue
            if hasattr(target, "base_layer"):
                setattr(parent, target_name, target.get_base_layer())
    else:
        # Fallback for older PEFT versions whose tuner layers do not expose a `base_layer`
        # attribute: rebuild plain Linear/Conv2d modules manually.
        from peft.tuners.lora import LoraLayer

        for name, module in model.named_children():
            if len(list(module.children())) > 0:
                # compound module, recurse into it
                recurse_remove_peft_layers(module)

            module_replaced = False

            if isinstance(module, LoraLayer) and isinstance(module, torch.nn.Linear):
                new_module = torch.nn.Linear(
                    module.in_features,
                    module.out_features,
                    bias=module.bias is not None,
                ).to(module.weight.device)
                new_module.weight = module.weight
                if module.bias is not None:
                    new_module.bias = module.bias

                module_replaced = True
            elif isinstance(module, LoraLayer) and isinstance(module, torch.nn.Conv2d):
                new_module = torch.nn.Conv2d(
                    module.in_channels,
                    module.out_channels,
                    module.kernel_size,
                    module.stride,
                    module.padding,
                    module.dilation,
                    module.groups,
                ).to(module.weight.device)

                new_module.weight = module.weight
                if module.bias is not None:
                    new_module.bias = module.bias

                module_replaced = True

            if module_replaced:
                setattr(model, name, new_module)
                del module

    empty_device_cache()
    return model
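
# Illustrative usage (a sketch, not a documented API of this module): after a LoRA has been
# unloaded or fused, this helper strips the remaining PEFT wrapper layers from a model. The
# pipeline variable and adapter repo below are hypothetical.
#
#     pipe.load_lora_weights("some-user/some-lora")       # hypothetical adapter repo
#     ...
#     pipe.unet = recurse_remove_peft_layers(pipe.unet)   # back to plain torch modules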


def scale_lora_layers(model, weight):
    """
    Adjust the weight given to the LoRA layers of the model.

    Args:
        model (`torch.nn.Module`):
            The model to scale.
        weight (`float`):
            The weight to be given to the LoRA layers.
    """
    from peft.tuners.tuners_utils import BaseTunerLayer

    if weight == 1.0:
        return

    for module in model.modules():
        if isinstance(module, BaseTunerLayer):
            module.scale_layer(weight)
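
# Usage sketch (assumes `unet` already has a PEFT LoRA adapter injected; variable names are
# illustrative): scale the LoRA contribution before a forward pass and undo it afterwards
# with `unscale_lora_layers`.
#
#     scale_lora_layers(unet, weight=0.5)    # halve the LoRA contribution
#     out = unet(sample, timestep, encoder_hidden_states)
#     unscale_lora_layers(unet, weight=0.5)  # divide the scaling back out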


def unscale_lora_layers(model, weight: float | None = None):
    """
    Removes the previously passed weight given to the LoRA layers of the model.

    Args:
        model (`torch.nn.Module`):
            The model to unscale.
        weight (`float`, *optional*):
            The weight that was previously given to the LoRA layers. If `None` or `1.0` is passed, this is a no-op.
            If `0.0` is passed, the scale of each active adapter is re-initialized to its default value.
    """
    from peft.tuners.tuners_utils import BaseTunerLayer

    if weight is None or weight == 1.0:
        return

    for module in model.modules():
        if isinstance(module, BaseTunerLayer):
            if weight != 0:
                module.unscale_layer(weight)
            else:
                for adapter_name in module.active_adapters:
                    # If weight == 0, unscaling re-sets the scale to its default value.
                    module.set_scale(adapter_name, 1.0)
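
# Note on `weight` (illustrative): passing the same non-zero value that was given to
# `scale_lora_layers` divides it back out, while passing `0.0` resets every active adapter
# scale to its default (a scale of 0.0 cannot be inverted by division).
#
#     scale_lora_layers(unet, weight=0.0)    # silences the LoRA layers entirely
#     unscale_lora_layers(unet, weight=0.0)  # restores the default LoRA scaling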


def get_peft_kwargs(
    rank_dict, network_alpha_dict, peft_state_dict, is_unet=True, model_state_dict=None, adapter_name=None
):
    rank_pattern = {}
    alpha_pattern = {}
    r = lora_alpha = list(rank_dict.values())[0]

    if len(set(rank_dict.values())) > 1:
        # get the rank occurring the most number of times
        r = collections.Counter(rank_dict.values()).most_common()[0][0]

        # for modules with rank different from the most occurring rank, add them to the `rank_pattern`
        rank_pattern = dict(filter(lambda x: x[1] != r, rank_dict.items()))
        rank_pattern = {k.split(".lora_B.")[0]: v for k, v in rank_pattern.items()}

    if network_alpha_dict is not None and len(network_alpha_dict) > 0:
        if len(set(network_alpha_dict.values())) > 1:
            # get the alpha occurring the most number of times
            lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0]

            # for modules with alpha different from the most occurring alpha, add them to the `alpha_pattern`
            alpha_pattern = dict(filter(lambda x: x[1] != lora_alpha, network_alpha_dict.items()))
            if is_unet:
                alpha_pattern = {
                    ".".join(k.split(".lora_A.")[0].split(".")).replace(".alpha", ""): v
                    for k, v in alpha_pattern.items()
                }
            else:
                alpha_pattern = {".".join(k.split(".down.")[0].split(".")[:-1]): v for k, v in alpha_pattern.items()}
        else:
            lora_alpha = set(network_alpha_dict.values()).pop()

    target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()})
    use_dora = any("lora_magnitude_vector" in k for k in peft_state_dict)
    # for now, "bias" keys are only associated with `lora_B`
    lora_bias = any("lora_B" in k and k.endswith(".bias") for k in peft_state_dict)

    lora_config_kwargs = {
        "r": r,
        "lora_alpha": lora_alpha,
        "rank_pattern": rank_pattern,
        "alpha_pattern": alpha_pattern,
        "target_modules": target_modules,
        "use_dora": use_dora,
        "lora_bias": lora_bias,
    }

    return lora_config_kwargs
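
# Worked example (hypothetical keys, for illustration only): the most common rank/alpha become
# the global `r`/`lora_alpha`, and outliers are recorded in `rank_pattern`/`alpha_pattern`.
#
#     rank_dict = {
#         "to_q.lora_B.weight": 4,
#         "to_k.lora_B.weight": 4,
#         "to_v.lora_B.weight": 8,  # outlier -> ends up in `rank_pattern`
#     }
#     kwargs = get_peft_kwargs(rank_dict, None, peft_state_dict=dict.fromkeys(rank_dict))
#     # kwargs["r"] == 4, kwargs["rank_pattern"] == {"to_v": 8}
#     # kwargs["target_modules"] == ["to_q", "to_k", "to_v"] (order may vary)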


def get_adapter_name(model):
    from peft.tuners.tuners_utils import BaseTunerLayer

    for module in model.modules():
        if isinstance(module, BaseTunerLayer):
            return f"default_{len(module.r)}"
    return "default_0"


def set_adapter_layers(model, enabled=True):
    from peft.tuners.tuners_utils import BaseTunerLayer

    for module in model.modules():
        if isinstance(module, BaseTunerLayer):
            # Recent PEFT versions expose `enable_adapters`; fall back to the legacy attribute otherwise.
            if hasattr(module, "enable_adapters"):
                module.enable_adapters(enabled=enabled)
            else:
                module.disable_adapters = not enabled


def delete_adapter_layers(model, adapter_name):
    from peft.tuners.tuners_utils import BaseTunerLayer

    for module in model.modules():
        if isinstance(module, BaseTunerLayer):
            if hasattr(module, "delete_adapter"):
                module.delete_adapter(adapter_name)
            else:
                raise ValueError(
                    "The version of PEFT you are using is not compatible, please use a version that is greater than 0.6.1"
                )

    # For transformers integration - we need to pop the adapter from the config
    if getattr(model, "_hf_peft_config_loaded", False) and hasattr(model, "peft_config"):
        model.peft_config.pop(adapter_name, None)
        # In case all adapters are deleted, delete the config as well
        # and reset the `_hf_peft_config_loaded` flag
        if len(model.peft_config) == 0:
            del model.peft_config
            model._hf_peft_config_loaded = None


def set_weights_and_activate_adapters(model, adapter_names, weights):
    from peft.tuners.tuners_utils import BaseTunerLayer

    def get_module_weight(weight_for_adapter, module_name):
        if not isinstance(weight_for_adapter, dict):
            # If the weight is a single number, always return it.
            return weight_for_adapter

        for layer_name, weight_ in weight_for_adapter.items():
            if layer_name in module_name:
                return weight_

        parts = module_name.split(".")
        # e.g. key = "down_blocks.1.attentions.0"
        key = f"{parts[0]}.{parts[1]}.attentions.{parts[3]}"
        block_weight = weight_for_adapter.get(key, 1.0)

        return block_weight

    for module_name, module in model.named_modules():
        if isinstance(module, BaseTunerLayer):
            # For backwards compatibility with older PEFT versions that lack `set_adapter`
            if hasattr(module, "set_adapter"):
                module.set_adapter(adapter_names)
            else:
                module.active_adapter = adapter_names

            # Set the scaling weight of each adapter for this module
            for adapter_name, weight in zip(adapter_names, weights):
                module.set_scale(adapter_name, get_module_weight(weight, module_name))
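
# Usage sketch (hypothetical adapter names and module path): weights may be plain floats or
# dicts keyed by (partial) module names, which `get_module_weight` resolves per module.
#
#     set_weights_and_activate_adapters(
#         unet,
#         adapter_names=["style", "character"],
#         weights=[0.8, {"down_blocks.1.attentions.0": 0.3}],  # dict scales only matching modules
#     )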


def apply_lora_scale(kwargs_name: str = "joint_attention_kwargs"):
    """
    Decorator to automatically handle LoRA layer scaling/unscaling in forward methods.

    This decorator extracts the `lora_scale` from the specified kwargs parameter, applies scaling before the forward
    pass, and ensures unscaling happens after, even if an exception occurs.

    Args:
        kwargs_name (`str`, defaults to `"joint_attention_kwargs"`):
            The name of the keyword argument that contains the LoRA scale. Common values include
            `"joint_attention_kwargs"`, `"attention_kwargs"`, and `"cross_attention_kwargs"`.
    """

    def decorator(forward_fn):
        @functools.wraps(forward_fn)
        def wrapper(self, *args, **kwargs):
            from . import USE_PEFT_BACKEND

            lora_scale = 1.0
            attention_kwargs = kwargs.get(kwargs_name)

            if attention_kwargs is not None:
                attention_kwargs = attention_kwargs.copy()
                kwargs[kwargs_name] = attention_kwargs
                lora_scale = attention_kwargs.pop("scale", 1.0)

                if not USE_PEFT_BACKEND and lora_scale != 1.0:
                    logger.warning(
                        f"Passing `scale` via `{kwargs_name}` when not using the PEFT backend is ineffective."
                    )

            # Scale the LoRA layers before running the wrapped forward pass
            if USE_PEFT_BACKEND:
                scale_lora_layers(self, lora_scale)

            try:
                # Run the wrapped forward pass with the (possibly copied) kwargs
                result = forward_fn(self, *args, **kwargs)
                return result
            finally:
                # Always unscale, even if the forward pass raised an exception
                if USE_PEFT_BACKEND:
                    unscale_lora_layers(self, lora_scale)

        return wrapper

    return decorator
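
# Usage sketch: decorating a `forward` method so that a `scale` entry passed through
# `attention_kwargs` is applied to the LoRA layers only for the duration of the call.
# The class below is a minimal, hypothetical example.
#
#     class MyTransformer(torch.nn.Module):
#         @apply_lora_scale(kwargs_name="attention_kwargs")
#         def forward(self, hidden_states, attention_kwargs=None):
#             ...
#
#     model = MyTransformer()
#     model(hidden_states, attention_kwargs={"scale": 0.7})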


def check_peft_version(min_version: str) -> None:
    r"""
    Checks if the installed version of PEFT is compatible.

    Args:
        min_version (`str`):
            The minimum version of PEFT to check against; the installed version must be strictly greater.
    """
    if not is_peft_available():
        raise ValueError("PEFT is not installed. Please install it with `pip install peft`")

    is_peft_version_compatible = version.parse(importlib.metadata.version("peft")) > version.parse(min_version)

    if not is_peft_version_compatible:
        raise ValueError(
            f"The version of PEFT you are using is not compatible, please use a version that is greater"
            f" than {min_version}"
        )
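
# Usage sketch: the comparison is strict (`>`), so `min_version` is effectively exclusive and
# should be the last release that is *not* supported.
#
#     check_peft_version(min_version="0.13.2")  # raises if the installed PEFT is <= 0.13.2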


def _create_lora_config(
    state_dict, network_alphas, metadata, rank_pattern_dict, is_unet=True, model_state_dict=None, adapter_name=None
):
    from peft import LoraConfig

    if metadata is not None:
        lora_config_kwargs = metadata
    else:
        lora_config_kwargs = get_peft_kwargs(
            rank_pattern_dict,
            network_alpha_dict=network_alphas,
            peft_state_dict=state_dict,
            is_unet=is_unet,
            model_state_dict=model_state_dict,
            adapter_name=adapter_name,
        )

    _maybe_raise_error_for_ambiguous_keys(lora_config_kwargs)

    # Version checks: some config options need a sufficiently recent PEFT release.
    if "use_dora" in lora_config_kwargs and lora_config_kwargs["use_dora"]:
        if is_peft_version("<", "0.9.0"):
            raise ValueError("DoRA requires PEFT >= 0.9.0. Please upgrade.")

    if "lora_bias" in lora_config_kwargs and lora_config_kwargs["lora_bias"]:
        if is_peft_version("<=", "0.13.2"):
            raise ValueError("lora_bias requires PEFT >= 0.14.0. Please upgrade.")

    try:
        return LoraConfig(**lora_config_kwargs)
    except TypeError as e:
        raise TypeError("`LoraConfig` class could not be instantiated.") from e


def _maybe_raise_error_for_ambiguous_keys(config):
    rank_pattern = config["rank_pattern"].copy()
    target_modules = config["target_modules"]

    for key in list(rank_pattern.keys()):
        # A `rank_pattern` key is ambiguous when it matches one target module exactly and is
        # also a substring of other target modules: older PEFT versions cannot tell which
        # module the pattern is meant to apply to.
        exact_matches = [mod for mod in target_modules if mod == key]
        substring_matches = [mod for mod in target_modules if key in mod and mod != key]

        if exact_matches and substring_matches:
            if is_peft_version("<", "0.14.1"):
                raise ValueError(
                    "There are ambiguous keys present in this LoRA. To load it, please update your `peft` installation - `pip install -U peft`."
                )


def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name):
    warn_msg = ""
    if incompatible_keys is not None:
        # Check only for unexpected keys that belong to the current adapter's LoRA weights.
        unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
        if unexpected_keys:
            lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
            if lora_unexpected_keys:
                warn_msg = (
                    f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
                    f" {', '.join(lora_unexpected_keys)}. "
                )

        # Filter missing keys specific to the current adapter.
        missing_keys = getattr(incompatible_keys, "missing_keys", None)
        if missing_keys:
            lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
            if lora_missing_keys:
                warn_msg += (
                    f"Loading adapter weights from state_dict led to missing keys in the model:"
                    f" {', '.join(lora_missing_keys)}."
                )

    if warn_msg:
        logger.warning(warn_msg)