import glob
import os

import safetensors.torch
import torch


def maybe_zero_3(param, ignore_status=False, name=None):
    """Gather a parameter that may be partitioned by DeepSpeed ZeRO-3 and
    return a detached CPU clone of it."""
    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

    if hasattr(param, "ds_id"):
        # ZeRO-3 partitions parameters across ranks; gather before copying.
        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
            if not ignore_status:
                print(
                    f"Parameter {name} is not available in ZeRO-3, please check the ZeRO-3 status."
                )
        with zero.GatheredParameters([param]):
            param = param.data.detach().cpu().clone()
    else:
        param = param.detach().cpu().clone()
    return param
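

# Usage sketch (illustrative; `model` stands for a torch.nn.Module trained under
# DeepSpeed ZeRO-3 and is not defined in this file):
#
#   cpu_state = {
#       name: maybe_zero_3(param, ignore_status=True, name=name)
#       for name, param in model.named_parameters()
#   }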


# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                bias_name = k.split("lora_")[0] + "bias"
                lora_bias_names.add(bias_name)
            elif "bias" in k:
                maybe_lora_bias[k] = t
        # Keep only the biases that belong to a LoRA-adapted module.
        for k, t in maybe_lora_bias.items():
            if k in lora_bias_names:
                to_return[k] = t
    else:
        raise NotImplementedError
    to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
    return to_return
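

# Usage sketch (illustrative; `peft_model` stands for a peft-wrapped model and
# is not defined in this file):
#
#   lora_state_dict = get_peft_state_maybe_zero_3(
#       peft_model.named_parameters(), bias="none"
#   )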


def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
    to_return = {k: t for k, t in named_params if "lora_" not in k}
    if require_grad_only:
        to_return = {k: t for k, t in to_return.items() if t.requires_grad}
    to_return = {
        k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()
    }
    return to_return
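

# Usage sketch (illustrative): the non-LoRA counterpart is typically saved next
# to the adapter so trainable non-LoRA weights survive checkpointing:
#
#   non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(
#       peft_model.named_parameters(), require_grad_only=True
#   )
#   torch.save(non_lora_state_dict, "non_lora_state_dict.pth")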


def _insert_adapter_name_into_state_dict(
    state_dict: dict[str, torch.Tensor], adapter_name: str, parameter_prefix: str
) -> dict[str, torch.Tensor]:
    """Utility function to remap the state_dict keys to fit the PEFT model by inserting the adapter name."""
    peft_model_state_dict = {}
    for key, val in state_dict.items():
        if parameter_prefix in key:
            suffix = key.split(parameter_prefix)[1]
            if "." in suffix:
                suffix_to_replace = ".".join(suffix.split(".")[1:])
                key = key.replace(
                    suffix_to_replace, f"{adapter_name}.{suffix_to_replace}"
                )
            else:
                key = f"{key}.{adapter_name}"
        peft_model_state_dict[key] = val
    return peft_model_state_dict
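

# Key-remapping example (keys are illustrative): with parameter_prefix="lora_"
# and adapter_name="default",
#
#   "base.lora_A.weight" -> "base.lora_A.default.weight"
#   "base.lora_B.weight" -> "base.lora_B.default.weight"
#
# which matches the key layout peft expects when loading an adapter state dict.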


def save_video(tensor, path):
    from torchvision.io import write_video

    # write_video expects a CPU uint8 tensor of shape [T, H, W, C];
    # the input is assumed to be a float [T, C, H, W] tensor in [0, 1].
    tensor = tensor.detach().cpu() * 255.0
    tensor = tensor.permute(0, 2, 3, 1)
    tensor = tensor.clamp(0, 255).byte()
    write_video(path, tensor, 4, video_codec="h264")
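

# Usage sketch (illustrative values): writes a 16-frame clip at the hard-coded
# 4 fps:
#
#   frames = torch.rand(16, 3, 256, 256)  # [T, C, H, W] in [0, 1]
#   save_video(frames, "sample.mp4")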


def load_model_from_checkpoint(model, checkpoint_dir, checkpoint_step):
    checkpoint_paths = glob.glob(os.path.join(checkpoint_dir, "checkpoint-*"))
    # Sort by step number, newest first.
    checkpoint_paths.sort(key=lambda x: int(x.split("-")[-1]), reverse=True)
    if checkpoint_step is None or checkpoint_step == -1:
        # Get the latest checkpoint.
        checkpoint_path = checkpoint_paths[0]
        print(
            f"===> Checkpoint step is not provided, using the latest checkpoint: {checkpoint_path}"
        )
    else:
        checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint-{checkpoint_step}")
        if checkpoint_path not in checkpoint_paths:
            checkpoint_path = checkpoint_paths[0]
            print(
                f"===> Checkpoint step {checkpoint_step} not found, using the latest checkpoint: {checkpoint_path}"
            )
        else:
            print(
                f"===> Checkpoint step {checkpoint_step} found, using the specified checkpoint: {checkpoint_path}"
            )
    # Re-derive the step from the resolved path (it may differ from the request).
    checkpoint_step = checkpoint_path.split("checkpoint-")[-1].split("/")[0]
    full_ckpt = os.path.join(checkpoint_path, "model.pth")
    lora_ckpt = os.path.join(checkpoint_path, "adapter_model.safetensors")
    non_lora_ckpt = os.path.join(checkpoint_path, "non_lora_state_dict.pth")
    if os.path.exists(full_ckpt):
        # Full fine-tuning checkpoint: load the complete state dict directly.
        model_state_dict = torch.load(full_ckpt, map_location="cpu")
        model.load_state_dict(model_state_dict)
    else:
        # LoRA checkpoint: merge adapter and non-LoRA weights into the model.
        lora_state_dict = safetensors.torch.load_file(lora_ckpt)
        non_lora_state_dict = torch.load(non_lora_ckpt, map_location="cpu")
        lora_state_dict = _insert_adapter_name_into_state_dict(
            lora_state_dict, adapter_name="default", parameter_prefix="lora_"
        )
        model_state_dict = model.state_dict()
        model_state_dict.update(non_lora_state_dict)
        model_state_dict.update(lora_state_dict)
        model.load_state_dict(model_state_dict)
    return model, checkpoint_step
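

# Usage sketch (illustrative; `build_model` and the output directory are
# assumptions, not defined in this file):
#
#   model = build_model()
#   model, step = load_model_from_checkpoint(
#       model, "outputs/run1", checkpoint_step=None
#   )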


def find_target_linear_names(
    model, num_lora_modules=-1, lora_namespan_exclude=None, verbose=False
):
    """
    Find the target Linear (and Embedding) module names for LoRA.
    """
    # Avoid a mutable default argument.
    if lora_namespan_exclude is None:
        lora_namespan_exclude = []
    linear_cls = torch.nn.Linear
    embedding_cls = torch.nn.Embedding
    lora_module_names = []
    for name, module in model.named_modules():
        if any(ex_keyword in name for ex_keyword in lora_namespan_exclude):
            # print(f"Excluding module: {name}")
            continue
        if isinstance(module, (linear_cls, embedding_cls)):
            lora_module_names.append(name)
    if num_lora_modules > 0:
        # Keep only the last N matches in registration order.
        lora_module_names = lora_module_names[-num_lora_modules:]
    if verbose:
        print(f"Found {len(lora_module_names)} lora modules: {lora_module_names}")
    return lora_module_names
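

# Usage sketch (illustrative): feed the discovered names to a peft LoraConfig;
# the "lm_head" exclusion is an assumption about the model, not from this file:
#
#   from peft import LoraConfig
#   target_modules = find_target_linear_names(
#       model, lora_namespan_exclude=["lm_head"], verbose=True
#   )
#   lora_config = LoraConfig(r=16, lora_alpha=32, target_modules=target_modules)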