# coding=utf-8
# Original License:
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings

import torch

from .config import PeftType


def _find_mismatched_keys(
    model: torch.nn.Module,
    peft_model_state_dict: dict[str, torch.Tensor],
    ignore_mismatched_sizes: bool = False,
) -> tuple[dict[str, torch.Tensor], list[tuple[str, tuple[int, ...], tuple[int, ...]]]]:
    """
    Drop checkpoint entries whose shape differs from the corresponding model entry.

    Args:
        model (`torch.nn.Module`): The model whose `state_dict()` provides the reference shapes.
        peft_model_state_dict (`dict`): The checkpoint state dict. When mismatches are found, the offending keys are
            deleted from this dict **in place**.
        ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`): When `False`, no check is performed and
            the state dict is returned untouched with an empty mismatch list.

    Returns:
        A tuple `(peft_model_state_dict, mismatched)` where `mismatched` is a list of
        `(key, checkpoint_shape, model_shape)` entries for every key that was removed.
    """
    if not ignore_mismatched_sizes:
        return peft_model_state_dict, []

    mismatched = []
    model_state_dict = model.state_dict()
    for key, tensor in peft_model_state_dict.items():
        if key not in model_state_dict:
            continue

        model_tensor = model_state_dict[key]

        # see https://github.com/huggingface/transformers/blob/09f9f566de83eef1f13ee83b5a1bbeebde5c80c1/src/transformers/modeling_utils.py#L3858-L3864
        # The `dim() > 0` guard keeps 0-dim entries (e.g. scalar buffers) from raising IndexError on `shape[-1]`.
        if (
            model_tensor.dim() > 0
            and model_tensor.shape[-1] == 1
            and model_tensor.numel() * 2 == tensor.numel()
        ):
            # This skips size mismatches for 4-bit weights. Two 4-bit values share an 8-bit container, causing size
            # differences. Without matching with module type or parameter type it seems like a practical way to detect
            # valid 4bit weights.
            continue

        if model_tensor.shape != tensor.shape:
            mismatched.append((key, tensor.shape, model_tensor.shape))

    # Deletion happens after the scan so the dict is never resized while being iterated.
    for key, _, _ in mismatched:
        del peft_model_state_dict[key]

    return peft_model_state_dict, mismatched


def _select_adapter_weights(
    state_dict: dict[str, torch.Tensor],
    marker: str,
    bias: str,
    only_option: str,
) -> dict[str, torch.Tensor]:
    """
    Select the adapter entries (and optionally bias entries) of `state_dict`.

    Args:
        state_dict (`dict`): The full state dict to filter.
        marker (`str`): Substring identifying adapter keys (e.g. `"lora_"`).
        bias (`str`): Bias handling mode: `"none"` keeps only keys containing `marker`; `"all"` additionally keeps
            every key containing `"bias"`; `only_option` keeps adapter keys plus the derived bias key of each.
        only_option (`str`): The adapter-specific name of the "adapter biases only" mode (e.g. `"lora_only"`).

    Returns:
        A new dict with the selected entries.

    Raises:
        NotImplementedError: If `bias` is none of `"none"`, `"all"` or `only_option`.
    """
    if bias == "none":
        return {key: value for key, value in state_dict.items() if marker in key}

    if bias == "all":
        return {key: value for key, value in state_dict.items() if marker in key or "bias" in key}

    if bias == only_option:
        selected = {}
        for key, value in state_dict.items():
            if marker not in key:
                continue
            selected[key] = value
            # The bias key is derived by replacing everything from the first marker occurrence with "bias".
            # NOTE(review): for markers that do not directly follow a "." in the key (possibly "_sama"), this
            # derived name may never exist in the state dict -- confirm against the actual parameter names.
            bias_name = key.split(marker)[0] + "bias"
            if bias_name in state_dict:
                selected[bias_name] = state_dict[bias_name]
        return selected

    raise NotImplementedError(f"Unsupported bias mode {bias!r}; expected 'none', 'all' or {only_option!r}.")


def get_peft_model_state_dict(model, state_dict=None):
    """
    Get the state dict of the Peft model.

    Args:
        model ([`PeftModel`]): The Peft model. When using torch.nn.DistributedDataParallel, DeepSpeed or FSDP,
            the model should be the underlying model/unwrapped model (i.e. model.module).
        state_dict (`dict`, *optional*, defaults to `None`):
            The state dict of the model. If not provided, the state dict of the model will be used.

    Returns:
        `dict`: The adapter entries selected according to `model.peft_config` (for prompt learning, a single
        `"prompt_embeddings"` entry), plus every entry whose key contains one of `model.modules_to_save`.

    Raises:
        NotImplementedError: If the PEFT type or the configured bias mode is not supported.
    """
    if state_dict is None:
        state_dict = model.state_dict()

    config = model.peft_config
    if config.peft_type == PeftType.LORA:
        # adapted from `https://github.com/microsoft/LoRA/blob/main/loralib/utils.py`
        # to work directly with the state dict, which is necessary when using DeepSpeed or FSDP
        to_return = _select_adapter_weights(state_dict, "lora_", config.bias, "lora_only")
    elif config.peft_type == PeftType.BOTTLENECK:
        # return the state dict of the model with Bottleneck adapters
        to_return = _select_adapter_weights(state_dict, "adapter_", config.bias, "adapter_only")
    elif config.peft_type == PeftType.SAMA:
        to_return = _select_adapter_weights(state_dict, "_sama", config.bias, "sama_only")
    elif config.is_prompt_learning:
        if config.inference_mode:
            prompt_embeddings = model.prompt_encoder.embedding.weight
        else:
            prompt_embeddings = model.get_prompt_embedding_to_save()
        to_return = {"prompt_embeddings": prompt_embeddings}
    else:
        raise NotImplementedError(f"Unsupported PEFT type: {config.peft_type!r}.")

    # Models without a `modules_to_save` attribute are treated the same as `modules_to_save=None`.
    modules_to_save = getattr(model, "modules_to_save", None)
    if modules_to_save is not None:
        for key, value in state_dict.items():
            if any(module_name in key for module_name in modules_to_save):
                to_return[key] = value

    return to_return


def set_peft_model_state_dict(
    model, peft_model_state_dict, adapter_name="default", ignore_mismatched_sizes: bool = False
):
    """
    Set the state dict of the Peft model.

    Args:
        model ([`PeftModel`]): The Peft model.
        peft_model_state_dict (`dict`): The state dict of the Peft model. When `ignore_mismatched_sizes` is `True`,
            entries with mismatched shapes are removed from this dict in place.
        adapter_name (`str`, *optional*, defaults to `"default"`):
            The name of the adapter whose state dict should be set. Currently accepted but not used by this function.
        ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
            Whether to skip (and warn about) checkpoint entries whose shape differs from the model's instead of
            passing them to `load_state_dict`.

    Returns:
        The `model` that was passed in, after loading.
    """
    peft_model_state_dict, mismatched_keys = _find_mismatched_keys(
        model, peft_model_state_dict, ignore_mismatched_sizes=ignore_mismatched_sizes
    )
    if mismatched_keys:
        # see https://github.com/huggingface/transformers/blob/09f9f566de83eef1f13ee83b5a1bbeebde5c80c1/src/transformers/modeling_utils.py#L4039
        mismatched_warning = "\n".join(
            f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
            for key, shape1, shape2 in mismatched_keys
        )
        msg = (
            f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint "
            f"and are being ignored because you passed `ignore_mismatched_sizes=True`: {mismatched_warning}."
        )
        warnings.warn(msg)

    # Non-strict: the checkpoint only holds adapter weights (and possibly "prompt_embeddings"), not the full model.
    model.load_state_dict(peft_model_state_dict, strict=False)

    # Prompt-learning checkpoints store the embedding under a dedicated key that `load_state_dict` above ignores.
    if getattr(model, "prompt_encoder", None) is not None and "prompt_embeddings" in peft_model_state_dict:
        model.prompt_encoder.embedding.load_state_dict(
            {"weight": peft_model_state_dict["prompt_embeddings"]}, strict=True
        )

    return model