# Source: Scalable_monarch_adapter/smpeft/utils/save_and_load.py (uploaded by nvan13 via huggingface_hub, commit ecadbd9)
# coding=utf-8
# Original License:
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .config import PeftType
import warnings
import torch
def _find_mismatched_keys(
    model: torch.nn.Module, peft_model_state_dict: dict[str, torch.Tensor], ignore_mismatched_sizes: bool = False
) -> tuple[dict[str, torch.Tensor], list[tuple[str, tuple[int, ...], tuple[int, ...]]]]:
    """
    Separate checkpoint entries whose shape disagrees with the model from the ones that can be loaded.

    Args:
        model (`torch.nn.Module`): The model the checkpoint is going to be loaded into.
        peft_model_state_dict (`dict[str, torch.Tensor]`): The checkpoint state dict. It is never modified.
        ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
            When `False`, no check is performed and `peft_model_state_dict` itself is returned together with an
            empty list. When `True`, entries whose shape differs from the model's entry are left out of the
            returned dict.

    Returns:
        A tuple `(state_dict, mismatched)`. `state_dict` holds the entries to load. `mismatched` is a list of
        `(key, checkpoint_shape, model_shape)` tuples, one per dropped entry. Keys that do not exist in the model
        are kept and never reported as mismatched.
    """
    if not ignore_mismatched_sizes:
        return peft_model_state_dict, []

    model_state_dict = model.state_dict()
    mismatched = []
    for key, tensor in peft_model_state_dict.items():
        if key not in model_state_dict:
            continue
        model_tensor = model_state_dict[key]

        # see https://github.com/huggingface/transformers/blob/09f9f566de83eef1f13ee83b5a1bbeebde5c80c1/src/transformers/modeling_utils.py#L3858-L3864
        # This skips size mismatches for 4-bit weights. Two 4-bit values share an 8-bit container, causing size
        # differences. Without matching with module type or parameter type it seems like a practical way to detect
        # valid 4bit weights.
        # The `dim() > 0` guard prevents an IndexError on `shape[-1]` for 0-dim (scalar) tensors.
        if (
            model_tensor.dim() > 0
            and model_tensor.shape[-1] == 1
            and model_tensor.numel() * 2 == tensor.numel()
        ):
            continue

        if model_tensor.shape != tensor.shape:
            mismatched.append((key, tensor.shape, model_tensor.shape))

    # Build a filtered copy instead of deleting from the caller's dict.
    dropped = {key for key, _, _ in mismatched}
    filtered = {key: value for key, value in peft_model_state_dict.items() if key not in dropped}
    return filtered, mismatched
def _select_adapter_entries(state_dict, marker, bias, bias_only_mode):
    """
    Select the entries of `state_dict` whose key contains `marker`, plus the bias entries requested by `bias`.

    Args:
        state_dict (`dict`): The full model state dict.
        marker (`str`): Substring identifying adapter parameters (e.g. `"lora_"`).
        bias (`str`): `"none"` (adapter entries only), `"all"` (also every key containing `"bias"`), or
            `bias_only_mode` (also the bias of each module that holds an adapter entry).
        bias_only_mode (`str`): The accepted spelling of the "adapter bias only" option (e.g. `"lora_only"`).

    Raises:
        NotImplementedError: If `bias` is not one of the supported options.
    """
    if bias == "none":
        return {k: v for k, v in state_dict.items() if marker in k}
    if bias == "all":
        return {k: v for k, v in state_dict.items() if marker in k or "bias" in k}
    if bias == bias_only_mode:
        selected = {}
        for key, value in state_dict.items():
            if marker not in key:
                continue
            selected[key] = value
            # The bias of the owning module is taken to be "<text before the marker>bias".
            bias_name = key.split(marker)[0] + "bias"
            if bias_name in state_dict:
                selected[bias_name] = state_dict[bias_name]
        return selected
    raise NotImplementedError(
        f"Unsupported bias option {bias!r}; expected 'none', 'all' or {bias_only_mode!r}."
    )


def get_peft_model_state_dict(model, state_dict=None):
    """
    Get the state dict of the Peft model.

    Args:
        model ([`PeftModel`]): The Peft model. When using torch.nn.DistributedDataParallel, DeepSpeed or FSDP,
            the model should be the underlying model/unwrapped model (i.e. model.module).
        state_dict (`dict`, *optional*, defaults to `None`):
            The state dict of the model. If not provided, the state dict of the model
            will be used.

    Returns:
        `dict`: For LoRA / Bottleneck / SAMA, the entries whose key contains the adapter marker (`"lora_"`,
        `"adapter_"`, `"_sama"`) plus the bias entries selected by `model.peft_config.bias`. For prompt learning,
        a single `"prompt_embeddings"` entry. If `model.modules_to_save` is not `None`, every entry whose key
        contains one of those module names is added as well.

    Raises:
        NotImplementedError: If the peft type or the `bias` option is not supported.
    """
    if state_dict is None:
        state_dict = model.state_dict()

    peft_config = model.peft_config
    peft_type = peft_config.peft_type
    if peft_type == PeftType.LORA:
        # adapted from `https://github.com/microsoft/LoRA/blob/main/loralib/utils.py`
        # to directly work with the state dict, which is necessary when using DeepSpeed or FSDP
        to_return = _select_adapter_entries(state_dict, "lora_", peft_config.bias, "lora_only")
    elif peft_type == PeftType.BOTTLENECK:
        to_return = _select_adapter_entries(state_dict, "adapter_", peft_config.bias, "adapter_only")
    elif peft_type == PeftType.SAMA:
        # NOTE(review): with "_sama" the derived bias name is "<text before '_sama'>bias". Whether that names a
        # real module bias depends on how SAMA parameters are named, which is not visible here. Confirm.
        to_return = _select_adapter_entries(state_dict, "_sama", peft_config.bias, "sama_only")
    elif peft_config.is_prompt_learning:
        if peft_config.inference_mode:
            prompt_embeddings = model.prompt_encoder.embedding.weight
        else:
            prompt_embeddings = model.get_prompt_embedding_to_save()
        to_return = {"prompt_embeddings": prompt_embeddings}
    else:
        raise NotImplementedError(f"Saving is not implemented for peft_type={peft_type!r}.")

    modules_to_save = model.modules_to_save
    if modules_to_save is not None:
        for key, value in state_dict.items():
            if any(module_name in key for module_name in modules_to_save):
                to_return[key] = value
    return to_return
def set_peft_model_state_dict(
    model, peft_model_state_dict, adapter_name="default", ignore_mismatched_sizes: bool = False
):
    """
    Set the state dict of the Peft model.

    Args:
        model ([`PeftModel`]): The Peft model.
        peft_model_state_dict (`dict`): The state dict of the Peft model.
        adapter_name (`str`, *optional*, defaults to `"default"`):
            Accepted for API compatibility; this implementation does not use it.
        ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
            If `True`, checkpoint entries whose shape differs from the model's entry are skipped and a warning
            listing them is emitted, instead of being passed to `load_state_dict`.

    Returns:
        The same `model`, with the weights loaded in place.
    """
    peft_model_state_dict, mismatched_keys = _find_mismatched_keys(
        model, peft_model_state_dict, ignore_mismatched_sizes=ignore_mismatched_sizes
    )
    if mismatched_keys:
        # see https://github.com/huggingface/transformers/blob/09f9f566de83eef1f13ee83b5a1bbeebde5c80c1/src/transformers/modeling_utils.py#L4039
        mismatched_warning = "\n".join(
            [
                f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
                for key, shape1, shape2 in mismatched_keys
            ]
        )
        msg = (
            f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint "
            f"and are being ignored because you passed `ignore_mismatched_sizes=True`: {mismatched_warning}."
        )
        warnings.warn(msg)

    # strict=False: an adapter checkpoint normally lacks the base model's keys, and any extra
    # "prompt_embeddings" entry is presumably not a model key. It is loaded separately below.
    model.load_state_dict(peft_model_state_dict, strict=False)

    prompt_encoder = getattr(model, "prompt_encoder", None)
    if prompt_encoder is not None and "prompt_embeddings" in peft_model_state_dict:
        prompt_encoder.embedding.load_state_dict(
            {"weight": peft_model_state_dict["prompt_embeddings"]}, strict=True
        )
    return model