| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| from __future__ import annotations |
|
|
| from operator import attrgetter |
|
|
| import torch |
|
|
| from peft.config import PeftConfig |
| from peft.mapping import PEFT_TYPE_TO_CONFIG_MAPPING |
|
|
| from .constants import PEFT_TYPE_TO_PREFIX_MAPPING |
| from .other import infer_device |
| from .peft_types import PeftType |
| from .save_and_load import _insert_adapter_name_into_state_dict, load_peft_weights |
|
|
|
|
# Per PEFT type, the config fields that _check_hotswap_configs_compatible compares between the
# currently loaded adapter and the incoming one; hot-swapping is refused when any of them differ.
# Only LoRA is supported for now.
CONFIG_KEYS_TO_CHECK = {PeftType.LORA: ["lora_alpha", "use_rslora", "lora_dropout", "alpha_pattern", "use_dora"]}
|
|
|
|
| def hotswap_adapter_from_state_dict(model, state_dict, adapter_name, parameter_prefix="lora_"): |
| """ |
| Swap out the adapter weights from the model with the weights from state_dict. |
| |
| As of now, only LoRA is supported. |
| |
| This is a low-level function that assumes that the adapters have been checked for compatibility and that the |
| state_dict has been correctly mapped to work with PEFT. For a high level function that performs this work for you, |
| use `hotswap_adapter` instead. |
| |
| Args: |
| model (`nn.Module`): |
| The model with the loaded adapter. |
| state_dict (`dict[str, torch.Tensor]`): |
| The state dict of the new adapter, which needs to be compatible (targeting same modules etc.). |
| adapter_name (`str`): |
| The name of the adapter that should be hot-swapped, e.g. `"default"`. The name will remain the same after |
| swapping. |
| parameter_prefix (`str`, *optional*, defaults to `"lora_"`) |
| The prefix used to identify the adapter's keys in the state dict. For LoRA, this would be `"lora_"` (the |
| default). |
| |
| Raises: |
| RuntimeError |
| If the old and the new adapter are not compatible, a RuntimeError is raised. |
| |
| """ |
| |
| |
|
|
| is_compiled = hasattr(model, "_orig_mod") |
| |
| missing_keys = {k for k in model.state_dict() if (parameter_prefix in k) and (adapter_name in k)} |
| unexpected_keys = set() |
|
|
| |
| for key, new_val in state_dict.items(): |
| try: |
| old_val = attrgetter(key)(model) |
| except AttributeError: |
| unexpected_keys.add(key) |
| continue |
|
|
| if is_compiled: |
| missing_keys.remove("_orig_mod." + key) |
| else: |
| missing_keys.remove(key) |
|
|
| if missing_keys or unexpected_keys: |
| msg = "Hot swapping the adapter did not succeed." |
| if missing_keys: |
| msg += f" Missing keys: {', '.join(sorted(missing_keys))}." |
| if unexpected_keys: |
| msg += f" Unexpected keys: {', '.join(sorted(unexpected_keys))}." |
| raise RuntimeError(msg) |
|
|
| |
| for key, new_val in state_dict.items(): |
| |
| old_val = attrgetter(key)(model) |
| if is_compiled: |
| |
| |
| old_val.data = new_val.data |
| else: |
| torch.utils.swap_tensors(old_val, new_val) |
|
|
|
|
def _check_hotswap_configs_compatible(config0: PeftConfig, config1: PeftConfig) -> None:
    """
    Check if two configs are compatible for hot-swapping.

    Only LoRA parameters are checked for now.

    Hot-swapping keeps the first adapter's config in place while replacing its weights, so any
    config value that influences the forward pass (e.g. the alpha scaling) must be identical
    between the two adapters -- otherwise the second adapter's weights would be combined with the
    first adapter's settings, producing wrong results. Swapping these values is not implemented,
    and mutating the config dict could also trigger re-compilation of a compiled model.

    Raises:
        ValueError: if the PEFT types differ, the PEFT type is unsupported, or a checked config
            value differs between the two configs.
    """
    peft_type = config0.peft_type
    if peft_type != config1.peft_type:
        msg = f"Incompatible PEFT types found: {config0.peft_type.value} and {config1.peft_type.value}"
        raise ValueError(msg)

    if peft_type not in CONFIG_KEYS_TO_CHECK:
        msg = (
            f"Hotswapping only supports {', '.join(CONFIG_KEYS_TO_CHECK.keys())} but "
            f"{config0.peft_type.value} was passed."
        )
        raise ValueError(msg)

    # Compare only the fields known to affect the computation; a plain dict lookup with a
    # sentinel distinguishes "absent" from "present but None".
    dict0, dict1 = config0.to_dict(), config1.to_dict()
    absent = object()
    for key in CONFIG_KEYS_TO_CHECK[peft_type]:
        val0 = dict0.get(key, absent)
        val1 = dict1.get(key, absent)
        if val0 != val1:
            raise ValueError(f"Configs are incompatible: for {key}, {val0} != {val1}")
|
|
|
|
def hotswap_adapter(model, model_name_or_path, adapter_name, torch_device=None, **kwargs):
    """Substitute old adapter data with new adapter data, keeping the rest the same.

    As of now, only LoRA is supported.

    This function is useful when you want to replace the loaded adapter with a new adapter. The adapter name will
    remain the same, but the weights and other parameters will be swapped out.

    If the adapters are incompatible, e.g. targeting different layers or having different alpha values, an error will
    be raised.

    Example:

    ```py
    >>> import torch
    >>> from transformers import AutoModelForCausalLM
    >>> from peft import PeftModel
    >>> from peft.utils.hotswap import hotswap_adapter

    >>> model_id = ...
    >>> inputs = ...
    >>> device = ...
    >>> model = AutoModelForCausalLM.from_pretrained(model_id).to(device)

    >>> # load lora 0
    >>> model = PeftModel.from_pretrained(model, "path-adapter-0")
    >>> model = torch.compile(model)  # optionally compile the model
    >>> with torch.inference_mode():
    ...     output_adapter_0 = model(inputs)

    >>> # replace the "default" lora adapter with the new one
    >>> hotswap_adapter(model, "path-adapter-1", adapter_name="default", torch_device=device)
    >>> with torch.inference_mode():
    ...     output_adapter_1 = model(inputs).logits
    ```

    Args:
        model ([`~PeftModel`]):
            The PEFT model with the loaded adapter.
        model_name_or_path (`str`):
            The name or path of the model to load the new adapter from.
        adapter_name (`str`):
            The name of the adapter to swap, e.g. `"default"`. The name will stay the same after swapping.
        torch_device: (`str`, *optional*, defaults to None):
            The device to load the new adapter onto.
        **kwargs (`optional`):
            Additional keyword arguments used for loading the config and weights.

    """
    device = torch_device if torch_device is not None else infer_device()

    # Resolve the checkpoint's PEFT type first so the matching config class can be used,
    # then verify that the new adapter's config is hot-swap compatible with the loaded one
    # before any weights are downloaded into the model.
    peft_type = PeftConfig._get_peft_type(
        model_name_or_path,
        subfolder=kwargs.get("subfolder"),
        revision=kwargs.get("revision"),
        cache_dir=kwargs.get("cache_dir"),
        use_auth_token=kwargs.get("use_auth_token"),
        token=kwargs.get("token"),
    )
    config = PEFT_TYPE_TO_CONFIG_MAPPING[peft_type].from_pretrained(model_name_or_path, **kwargs)
    _check_hotswap_configs_compatible(model.active_peft_config, config)

    state_dict = load_peft_weights(model_name_or_path, device=device, **kwargs)

    # The checkpoint's keys lack the adapter name; re-key them to match the keys PEFT uses
    # on the live model, then perform the low-level in-place swap.
    prefix = PEFT_TYPE_TO_PREFIX_MAPPING[config.peft_type]
    remapped_state_dict = _insert_adapter_name_into_state_dict(
        state_dict, adapter_name=adapter_name, parameter_prefix=prefix
    )

    hotswap_adapter_from_state_dict(
        model=model,
        state_dict=remapped_state_dict,
        adapter_name=adapter_name,
        parameter_prefix=prefix,
    )
|
|