import glob
import os

import safetensors.torch
import torch
|
|
def maybe_zero_3(param, ignore_status=False, name=None):
    """Return a CPU copy of `param`, gathering it first if it is ZeRO-3 partitioned."""
    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

    if hasattr(param, "ds_id"):
        # Parameters partitioned by DeepSpeed ZeRO-3 carry a `ds_id` attribute.
        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE and not ignore_status:
            print(
                f"Parameter {name} is not available in ZeRO-3, please check the ZeRO-3 status."
            )
        # Gather the full parameter across ranks before taking a local copy.
        with zero.GatheredParameters([param]):
            param = param.data.detach().cpu().clone()
    else:
        param = param.detach().cpu().clone()
    return param
|
|
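# A minimal usage sketch (not part of the original module): materialize a full
# CPU state dict from a model whose parameters may be ZeRO-3 partitioned.
# `model` is assumed to be any torch.nn.Module.
def _example_gather_full_state_dict(model):
    return {
        name: maybe_zero_3(param, ignore_status=True, name=name)
        for name, param in model.named_parameters()
    }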
|
def get_peft_state_maybe_zero_3(named_params, bias):
    """Collect LoRA (and, depending on `bias`, bias) weights as a CPU state dict."""
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                # Derive the bias key that belongs to this LoRA module, e.g.
                # "...q_proj.lora_A.weight" -> "...q_proj.bias".
                bias_name = k.split("lora_")[0] + "bias"
                lora_bias_names.add(bias_name)
            elif "bias" in k:
                maybe_lora_bias[k] = t
        # Keep only the biases that belong to a LoRA-adapted module.
        for k, t in maybe_lora_bias.items():
            if k in lora_bias_names:
                to_return[k] = t
    else:
        raise NotImplementedError
    to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
    return to_return
|
|
def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
    """Collect the trainable non-LoRA weights as a CPU state dict."""
    to_return = {k: t for k, t in named_params if "lora_" not in k}
    if require_grad_only:
        to_return = {k: t for k, t in to_return.items() if t.requires_grad}
    to_return = {
        k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()
    }
    return to_return
|
|
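# A hedged sketch (not part of the original module) of how the two collectors
# above can produce the files that `load_model_from_checkpoint` below expects.
# It assumes `model` is a PEFT-wrapped module whose `save_pretrained` accepts a
# `state_dict` override and writes adapter_model.safetensors, stripping the
# adapter name that the loader later re-inserts.
def _example_save_lora_checkpoint(model, checkpoint_path):
    lora_state_dict = get_peft_state_maybe_zero_3(model.named_parameters(), bias="none")
    non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(model.named_parameters())
    model.save_pretrained(checkpoint_path, state_dict=lora_state_dict)
    torch.save(
        non_lora_state_dict, os.path.join(checkpoint_path, "non_lora_state_dict.pth")
    )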
|
def _insert_adapter_name_into_state_dict(
    state_dict: dict[str, torch.Tensor], adapter_name: str, parameter_prefix: str
) -> dict[str, torch.Tensor]:
    """Utility function to remap the state_dict keys to fit the PEFT model by inserting the adapter name."""
    peft_model_state_dict = {}
    for key, val in state_dict.items():
        if parameter_prefix in key:
            suffix = key.split(parameter_prefix)[1]
            if "." in suffix:
                # Insert the adapter name before the trailing parameter name.
                suffix_to_replace = ".".join(suffix.split(".")[1:])
                key = key.replace(
                    suffix_to_replace, f"{adapter_name}.{suffix_to_replace}"
                )
            else:
                key = f"{key}.{adapter_name}"
            peft_model_state_dict[key] = val
        else:
            peft_model_state_dict[key] = val
    return peft_model_state_dict
|
|
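# Illustrative example (assuming PEFT's default adapter name):
#   _insert_adapter_name_into_state_dict(
#       {"base_model.model.q_proj.lora_A.weight": w}, "default", "lora_"
#   )
#   returns {"base_model.model.q_proj.lora_A.default.weight": w}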
|
def save_video(tensor, path):
    """Write a float video tensor of shape [T, C, H, W] with values in [0, 1] to `path`."""
    from torchvision.io import write_video

    tensor = tensor * 255.0
    tensor = tensor.permute(0, 2, 3, 1)  # -> [T, H, W, C], as write_video expects
    tensor = tensor.clamp(0, 255).byte().cpu()
    write_video(path, tensor, fps=4, video_codec="h264")
|
|
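# A minimal usage sketch (hypothetical values): write 8 random RGB 64x64 frames
# to an mp4 at the hard-coded 4 fps.
def _example_save_video(path="example.mp4"):
    frames = torch.rand(8, 3, 64, 64)  # [T, C, H, W] in [0, 1]
    save_video(frames, path)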
|
def load_model_from_checkpoint(model, checkpoint_dir, checkpoint_step):
    """Load a full or LoRA checkpoint into `model`; return (model, resolved step)."""
    checkpoint_paths = glob.glob(os.path.join(checkpoint_dir, "checkpoint-*"))
    checkpoint_paths.sort(key=lambda x: int(x.split("-")[-1]), reverse=True)
    if not checkpoint_paths:
        raise FileNotFoundError(f"No checkpoint-* directories found in {checkpoint_dir}")

    if checkpoint_step is None or checkpoint_step == -1:
        checkpoint_path = checkpoint_paths[0]
        print(
            f"===> Checkpoint step is not provided, using the latest checkpoint: {checkpoint_path}"
        )
    else:
        checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint-{checkpoint_step}")
        if checkpoint_path not in checkpoint_paths:
            checkpoint_path = checkpoint_paths[0]
            print(
                f"===> Checkpoint step {checkpoint_step} not found, using the latest checkpoint: {checkpoint_path}"
            )
        else:
            print(
                f"===> Checkpoint step {checkpoint_step} found, using the specified checkpoint: {checkpoint_path}"
            )

    # Recover the step from the chosen path, which may differ from the request.
    checkpoint_step = checkpoint_path.split("checkpoint-")[-1].split("/")[0]

    full_ckpt = os.path.join(checkpoint_path, "model.pth")
    lora_ckpt = os.path.join(checkpoint_path, "adapter_model.safetensors")
    non_lora_ckpt = os.path.join(checkpoint_path, "non_lora_state_dict.pth")
    if os.path.exists(full_ckpt):
        # Full checkpoint: load the complete state dict directly.
        model_state_dict = torch.load(full_ckpt, map_location="cpu")
        model.load_state_dict(model_state_dict)
    else:
        # LoRA checkpoint: merge adapter and non-LoRA weights into the model.
        lora_state_dict = safetensors.torch.load_file(lora_ckpt)
        non_lora_state_dict = torch.load(non_lora_ckpt, map_location="cpu")

        # The saved adapter file omits the PEFT adapter name, so re-insert it.
        lora_state_dict = _insert_adapter_name_into_state_dict(
            lora_state_dict, adapter_name="default", parameter_prefix="lora_"
        )

        model_state_dict = model.state_dict()
        model_state_dict.update(non_lora_state_dict)
        model_state_dict.update(lora_state_dict)
        model.load_state_dict(model_state_dict)

    return model, checkpoint_step
|
|
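# A hedged usage sketch: passing None (or -1) as the step resumes from the
# newest checkpoint-* directory. The directory name is hypothetical.
def _example_resume_latest(model, checkpoint_dir="runs/my_run"):
    return load_model_from_checkpoint(model, checkpoint_dir, checkpoint_step=None)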
|
def find_target_linear_names(
    model, num_lora_modules=-1, lora_namespan_exclude=None, verbose=False
):
    """Find the names of Linear and Embedding modules to target with LoRA."""
    if lora_namespan_exclude is None:
        lora_namespan_exclude = []
    linear_cls = torch.nn.Linear
    embedding_cls = torch.nn.Embedding
    lora_module_names = []

    for name, module in model.named_modules():
        # Skip modules whose name contains any excluded keyword.
        if any(ex_keyword in name for ex_keyword in lora_namespan_exclude):
            continue
        if isinstance(module, (linear_cls, embedding_cls)):
            lora_module_names.append(name)

    # Optionally keep only the last `num_lora_modules` matches.
    if num_lora_modules > 0:
        lora_module_names = lora_module_names[-num_lora_modules:]
    if verbose:
        print(f"Found {len(lora_module_names)} lora modules: {lora_module_names}")
    return lora_module_names
|
|
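# A hedged sketch: feed the discovered module names to a PEFT LoraConfig.
# The hyperparameters and exclusion list are illustrative, and PEFT is assumed
# to be installed.
def _example_build_lora_config(model):
    from peft import LoraConfig

    target_modules = find_target_linear_names(
        model, lora_namespan_exclude=["lm_head", "embed_tokens"], verbose=True
    )
    return LoraConfig(r=16, lora_alpha=32, target_modules=target_modules)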