| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from .config import PeftType |
| import warnings |
| import torch |
|
|
def _find_mismatched_keys(
    model: torch.nn.Module, peft_model_state_dict: dict[str, torch.Tensor], ignore_mismatched_sizes: bool = False
) -> tuple[dict[str, torch.Tensor], list[tuple[str, tuple[int, ...], tuple[int, ...]]]]:
    """
    Remove entries from a checkpoint state dict whose shape disagrees with the model.

    Args:
        model (`torch.nn.Module`): The model whose `state_dict()` provides the reference shapes.
        peft_model_state_dict (`dict`): The checkpoint state dict. When `ignore_mismatched_sizes` is `True`,
            mismatched keys are deleted from this dict **in place**.
        ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
            If `False`, nothing is checked and the state dict is returned untouched.

    Returns:
        A tuple `(state_dict, mismatched)` where `state_dict` is the (same) input dict without the mismatched
        keys and `mismatched` is a list of `(key, checkpoint_shape, model_shape)` tuples.
    """
    if not ignore_mismatched_sizes:
        return peft_model_state_dict, []

    model_state = model.state_dict()
    mismatched = []
    for key, tensor in peft_model_state_dict.items():
        # Keys unknown to the model are left alone; they are not a *size* mismatch.
        if key not in model_state:
            continue
        target = model_state[key]

        # A model tensor whose last dimension is 1 and that holds exactly half as many elements as the
        # checkpoint tensor is deliberately not reported as a mismatch.
        # NOTE(review): presumably this targets packed 4-bit weights (two values per container) - confirm.
        # The rank check keeps 0-dim tensors (e.g. scalar buffers) from raising IndexError on `shape[-1]`.
        if target.dim() > 0 and target.shape[-1] == 1 and target.numel() * 2 == tensor.numel():
            continue

        if target.shape != tensor.shape:
            mismatched.append((key, tensor.shape, target.shape))

    # Delete after the loop so the dict is not modified while it is being iterated.
    for key, _, _ in mismatched:
        del peft_model_state_dict[key]

    return peft_model_state_dict, mismatched
|
|
def _select_adapter_weights(state_dict, marker, bias, marker_only_mode):
    """
    Select the entries of `state_dict` that belong to one adapter family.

    Args:
        state_dict (`dict`): Full state dict to filter.
        marker (`str`): Substring identifying adapter parameters (e.g. `"lora_"`).
        bias (`str`): Bias mode from the PEFT config: `"none"`, `"all"` or `marker_only_mode`.
        marker_only_mode (`str`): Name of the mode that keeps only biases sitting next to adapter parameters.

    Returns:
        `dict`: The selected subset of `state_dict`.

    Raises:
        NotImplementedError: If `bias` is none of the three supported modes.
    """
    if bias == "none":
        return {k: v for k, v in state_dict.items() if marker in k}
    if bias == "all":
        return {k: v for k, v in state_dict.items() if marker in k or "bias" in k}
    if bias == marker_only_mode:
        selected = {}
        for k, v in state_dict.items():
            if marker not in k:
                continue
            selected[k] = v
            # The bias key is derived from everything before the first occurrence of the marker.
            # NOTE(review): for markers that are not preceded by a "." in the key (possibly "_sama"),
            # this yields "<prefix>bias" without a separator - confirm against the parameter naming.
            bias_name = k.split(marker)[0] + "bias"
            if bias_name in state_dict:
                selected[bias_name] = state_dict[bias_name]
        return selected
    raise NotImplementedError(
        f"Unsupported bias setting {bias!r}; expected 'none', 'all' or {marker_only_mode!r}."
    )


def get_peft_model_state_dict(model, state_dict=None):
    """
    Get the state dict of the Peft model.

    Args:
        model ([`PeftModel`]): The Peft model. When using torch.nn.DistributedDataParallel, DeepSpeed or FSDP,
            the model should be the underlying model/unwrapped model (i.e. model.module).
        state_dict (`dict`, *optional*, defaults to `None`):
            The state dict of the model. If not provided, the state dict of the model
            will be used.

    Returns:
        `dict`: The adapter-specific entries (selected according to `model.peft_config.bias`), or the prompt
        embeddings under the key `"prompt_embeddings"` for prompt-learning configs, plus every entry whose key
        contains one of `model.modules_to_save`.

    Raises:
        NotImplementedError: If the PEFT type or the bias setting is not supported.
    """
    if state_dict is None:
        state_dict = model.state_dict()

    config = model.peft_config
    peft_type = config.peft_type
    if peft_type == PeftType.LORA:
        to_return = _select_adapter_weights(state_dict, "lora_", config.bias, "lora_only")
    elif peft_type == PeftType.BOTTLENECK:
        to_return = _select_adapter_weights(state_dict, "adapter_", config.bias, "adapter_only")
    elif peft_type == PeftType.SAMA:
        to_return = _select_adapter_weights(state_dict, "_sama", config.bias, "sama_only")
    elif config.is_prompt_learning:
        # In inference mode the stored embedding matrix is saved as-is; otherwise the model is asked
        # for the embeddings to save.
        if config.inference_mode:
            prompt_embeddings = model.prompt_encoder.embedding.weight
        else:
            prompt_embeddings = model.get_prompt_embedding_to_save()
        to_return = {"prompt_embeddings": prompt_embeddings}
    else:
        raise NotImplementedError(f"Unsupported PEFT type {peft_type!r}.")

    # Extra modules requested for saving are matched by substring against the state dict keys.
    if model.modules_to_save is not None:
        for key, value in state_dict.items():
            if any(module_name in key for module_name in model.modules_to_save):
                to_return[key] = value
    return to_return
|
|
|
|
def set_peft_model_state_dict(model, peft_model_state_dict,
                              adapter_name="default",
                              ignore_mismatched_sizes: bool = False):
    """
    Load a Peft state dict into the model.

    Args:
        model ([`PeftModel`]): The Peft model.
        peft_model_state_dict (`dict`): The state dict of the Peft model.
        adapter_name (`str`, *optional*, defaults to `"default"`):
            The name of the adapter whose state dict should be set. Accepted for interface
            compatibility; this function does not read it.
        ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
            Forwarded to `_find_mismatched_keys`; when keys are reported as mismatched, a warning
            listing them is emitted and they are not loaded.

    Returns:
        The same `model` object after loading.
    """
    loadable_state, size_conflicts = _find_mismatched_keys(
        model, peft_model_state_dict, ignore_mismatched_sizes=ignore_mismatched_sizes
    )

    if size_conflicts:
        # One bullet line per skipped key, showing the checkpoint shape and the model shape.
        conflict_lines = []
        for key, shape1, shape2 in size_conflicts:
            conflict_lines.append(
                f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
            )
        mismatched_warning = "\n".join(conflict_lines)
        warnings.warn(
            f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint "
            f"and are being ignored because you passed `ignore_mismatched_sizes=True`: {mismatched_warning}."
        )

    # Non-strict load: keys missing from either side are tolerated.
    model.load_state_dict(loadable_state, strict=False)

    # Prompt embeddings, when present, are loaded strictly into the prompt encoder's embedding.
    prompt_encoder = getattr(model, "prompt_encoder", None)
    if prompt_encoder is not None and "prompt_embeddings" in loadable_state:
        embedding_state = {"weight": loadable_state["prompt_embeddings"]}
        prompt_encoder.embedding.load_state_dict(embedding_state, strict=True)
    return model
|
|