HPSv3 / hpsv3 /utils /training_utils.py
sdsdgwe's picture
update
9b57ce7
import torch
import os
import glob
import safetensors
def maybe_zero_3(param, ignore_status=False, name=None):
    """Return a detached CPU clone of ``param``, gathering it first if needed.

    A parameter is treated as DeepSpeed-managed when it carries a ``ds_id``
    attribute. DeepSpeed is imported only in that case, so ordinary tensors
    can be copied in environments where DeepSpeed is not installed.

    Args:
        param: Tensor or parameter to copy.
        ignore_status: When False, print a notice if the parameter's
            ``ds_status`` equals ``ZeroParamStatus.NOT_AVAILABLE``.
        name: Parameter name; used only inside that notice.

    Returns:
        A detached clone of the parameter data located on the CPU.
    """
    if not hasattr(param, "ds_id"):
        return param.detach().cpu().clone()

    # Deferred import: only DeepSpeed-managed parameters need DeepSpeed.
    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

    if param.ds_status == ZeroParamStatus.NOT_AVAILABLE and not ignore_status:
        print(
            f"Parameter {name} is not available in ZeRO-3, please check the ZeRO-3 status."
        )
    # The clone is taken inside the gather context, as in the original code.
    with zero.GatheredParameters([param]):
        return param.data.detach().cpu().clone()
# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
    """Collect the LoRA parameters (and optionally biases) as copies.

    Args:
        named_params: Iterable of ``(name, parameter)`` pairs.
        bias: Which bias parameters accompany the ``"lora_"`` parameters:
            ``"none"`` keeps no biases; ``"all"`` keeps every parameter whose
            name contains ``"bias"``; ``"lora_only"`` keeps a bias only when
            its name equals the text preceding ``"lora_"`` in some LoRA
            parameter name, followed by ``"bias"``.

    Returns:
        Dict mapping each selected name to ``maybe_zero_3(param,
        ignore_status=True)``.

    Raises:
        NotImplementedError: If ``bias`` is none of the three modes above.
    """
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                # Name the bias of the same module would have.
                lora_bias_names.add(k.split("lora_")[0] + "bias")
            elif "bias" in k:
                maybe_lora_bias[k] = t
        # Judge every candidate bias by its own name. Iterating the dict
        # without .items() failed on unpacking, and the old check reused the
        # bias name left over from the loop above.
        for k, t in maybe_lora_bias.items():
            if k in lora_bias_names:
                to_return[k] = t
    else:
        raise NotImplementedError(f"Unsupported bias mode: {bias!r}")
    return {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
    """Collect CPU copies of every parameter whose name lacks ``"lora_"``.

    Args:
        named_params: Iterable of ``(name, parameter)`` pairs.
        require_grad_only: When true, keep only parameters whose
            ``requires_grad`` attribute is truthy.

    Returns:
        Dict mapping each kept name to
        ``maybe_zero_3(param, ignore_status=True).cpu()``.
    """
    # Stage 1: drop anything that belongs to a LoRA adapter.
    non_lora = {}
    for param_name, param in named_params:
        if "lora_" in param_name:
            continue
        non_lora[param_name] = param

    # Stage 2: optionally restrict to trainable parameters.
    if require_grad_only:
        non_lora = {
            param_name: param
            for param_name, param in non_lora.items()
            if param.requires_grad
        }

    # Stage 3: copy each survivor to the CPU.
    gathered = {}
    for param_name, param in non_lora.items():
        gathered[param_name] = maybe_zero_3(param, ignore_status=True).cpu()
    return gathered
def _insert_adapter_name_into_state_dict(
    state_dict: dict[str, torch.Tensor], adapter_name: str, parameter_prefix: str
) -> dict[str, torch.Tensor]:
    """Remap state_dict keys to fit the PEFT model by inserting the adapter name.

    For a key containing ``parameter_prefix``, the text after the prefix is
    inspected. If it contains a dot (e.g. ``A.weight`` after ``lora_``), the
    adapter name is inserted before everything that follows the first dot
    (``lora_A.weight`` -> ``lora_A.<adapter>.weight``). If it contains no dot,
    ``.<adapter>`` is appended to the key. Keys without the prefix are copied
    unchanged.

    Args:
        state_dict: Mapping from parameter name to value.
        adapter_name: Adapter name to insert.
        parameter_prefix: Substring that marks adapter parameters.

    Returns:
        A new dict with remapped keys; the values are the same objects.
    """
    peft_model_state_dict = {}
    for key, val in state_dict.items():
        if parameter_prefix in key:
            suffix = key.split(parameter_prefix)[1]
            if "." in suffix:
                suffix_to_replace = ".".join(suffix.split(".")[1:])
                if suffix_to_replace and key.endswith(suffix_to_replace):
                    # Rewrite only the trailing occurrence. A plain
                    # str.replace would also hit the same text earlier in the
                    # key, e.g. a module called "weight_proj".
                    stem = key[: -len(suffix_to_replace)]
                    key = f"{stem}{adapter_name}.{suffix_to_replace}"
                else:
                    # The prefix occurs more than once or the key ends with a
                    # dot: keep the original substitution for these keys.
                    key = key.replace(
                        suffix_to_replace, f"{adapter_name}.{suffix_to_replace}"
                    )
            else:
                key = f"{key}.{adapter_name}"
        peft_model_state_dict[key] = val
    return peft_model_state_dict
def save_video(tensor, path):
    """Encode ``tensor`` as an H.264 video file at ``path``, 4 frames per second.

    The values are multiplied by 255, the axes are reordered with
    ``permute(0, 2, 3, 1)``, and the result is clamped to [0, 255] and cast to
    bytes before being passed to torchvision's ``write_video``.

    Args:
        tensor: Video tensor.
            NOTE(review): presumably (frames, channels, height, width) with
            values in [0, 1] — confirm against callers.
        path: Destination file path.
    """
    from torchvision.io import write_video

    frames = (tensor * 255.0).permute(0, 2, 3, 1).clamp(0, 255).byte()
    write_video(path, frames, 4, video_codec="h264")
def load_model_from_checkpoint(model, checkpoint_dir, checkpoint_step):
    """Load weights into ``model`` from a ``checkpoint-<step>`` entry.

    Entries of ``checkpoint_dir`` matching ``checkpoint-*`` are candidates.
    Those whose text after the last ``-`` is not an integer are ignored. The
    highest step is used when ``checkpoint_step`` is ``None`` or ``-1``, or
    when the requested step does not exist.

    If the chosen checkpoint contains ``model.pth`` it is loaded as a full
    state dict. Otherwise ``adapter_model.safetensors`` and
    ``non_lora_state_dict.pth`` are merged on top of the model's current state
    dict, with the adapter name ``"default"`` inserted into the LoRA keys.

    Args:
        model: Module exposing ``state_dict()`` and ``load_state_dict()``.
        checkpoint_dir: Directory holding the ``checkpoint-<step>`` entries.
        checkpoint_step: Requested step, or ``None`` / ``-1`` for the latest.

    Returns:
        Tuple ``(model, checkpoint_step)`` where ``checkpoint_step`` is the
        step text taken from the chosen checkpoint path.

    Raises:
        FileNotFoundError: If no usable ``checkpoint-<step>`` entry exists.
    """
    # Keep only entries with a numeric step, so that something like
    # "checkpoint-best" cannot crash the sort below.
    step_by_path = {}
    for candidate in glob.glob(os.path.join(checkpoint_dir, "checkpoint-*")):
        try:
            step_by_path[candidate] = int(candidate.split("-")[-1])
        except ValueError:
            continue
    if not step_by_path:
        raise FileNotFoundError(
            f"No checkpoint-<step> entries found in {checkpoint_dir}"
        )
    checkpoint_paths = sorted(step_by_path, key=step_by_path.get, reverse=True)

    if checkpoint_step is None or checkpoint_step == -1:
        # get the latest checkpoint
        checkpoint_path = checkpoint_paths[0]
        print(
            f"===> Checkpoint step is not provided, using the latest checkpoint: {checkpoint_path}"
        )
    else:
        checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint-{checkpoint_step}")
        if checkpoint_path not in checkpoint_paths:
            checkpoint_path = checkpoint_paths[0]
            print(
                f"===> Checkpoint step {checkpoint_step} not found, using the latest checkpoint: {checkpoint_path}"
            )
        else:
            print(
                f"===> Checkpoint step {checkpoint_step} found, using the specified checkpoint: {checkpoint_path}"
            )

    # Report the step that was actually used, which may differ from the request.
    checkpoint_step = checkpoint_path.split("checkpoint-")[-1].split("/")[0]

    full_ckpt = os.path.join(checkpoint_path, "model.pth")
    lora_ckpt = os.path.join(checkpoint_path, "adapter_model.safetensors")
    non_lora_ckpt = os.path.join(checkpoint_path, "non_lora_state_dict.pth")

    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    if os.path.exists(full_ckpt):
        model_state_dict = torch.load(full_ckpt, map_location="cpu")
        model.load_state_dict(model_state_dict)
    else:
        # Import the submodule explicitly: a bare ``import safetensors`` does
        # not by itself make ``safetensors.torch`` available.
        from safetensors.torch import load_file

        lora_state_dict = load_file(lora_ckpt)
        non_lora_state_dict = torch.load(non_lora_ckpt, map_location="cpu")

        lora_state_dict = _insert_adapter_name_into_state_dict(
            lora_state_dict, adapter_name="default", parameter_prefix="lora_"
        )

        # Start from the current weights so keys absent from both files keep
        # their present values.
        model_state_dict = model.state_dict()
        model_state_dict.update(non_lora_state_dict)
        model_state_dict.update(lora_state_dict)
        model.load_state_dict(model_state_dict)

    return model, checkpoint_step
def find_target_linear_names(
    model, num_lora_modules=-1, lora_namespan_exclude=[], verbose=False
):
    """List the names of the submodules that LoRA should target.

    Walks ``model.named_modules()`` and keeps every ``torch.nn.Linear`` or
    ``torch.nn.Embedding`` whose name contains none of the excluded keywords.

    Args:
        model: Module to inspect.
        num_lora_modules: When positive, keep only the last that many names.
        lora_namespan_exclude: Substrings; a module whose name contains any of
            them is skipped. The default list is only read, never modified.
        verbose: When true, print the number of selected modules and their names.

    Returns:
        List of module names in traversal order.
    """
    target_types = (torch.nn.Linear, torch.nn.Embedding)

    def _is_excluded(module_name):
        return any(keyword in module_name for keyword in lora_namespan_exclude)

    selected = [
        module_name
        for module_name, submodule in model.named_modules()
        if not _is_excluded(module_name) and isinstance(submodule, target_types)
    ]

    if num_lora_modules > 0:
        selected = selected[-num_lora_modules:]
    if verbose:
        print(f"Found {len(selected)} lora modules: {selected}")
    return selected