import glob
import os

import torch

try:
    # Importing the bare package does not load the ``safetensors.torch``
    # submodule, so import it explicitly. Like DeepSpeed and torchvision
    # (imported lazily below), safetensors is needed by only one code path
    # (LoRA checkpoint loading), so a missing install is reported there.
    import safetensors.torch
except ImportError:
    safetensors = None


def maybe_zero_3(param, ignore_status=False, name=None):
    """Return a detached CPU copy of ``param``, gathering it first under ZeRO-3.

    Args:
        param: A tensor/parameter. An object with a ``ds_id`` attribute is
            treated as DeepSpeed-managed and is gathered with
            ``zero.GatheredParameters`` before being copied.
        ignore_status: If False, print a notice when a DeepSpeed parameter's
            status is ``NOT_AVAILABLE`` (it is gathered either way).
        name: Parameter name, used only in that notice.

    Returns:
        A new CPU tensor detached from autograd.
    """
    if not hasattr(param, "ds_id"):
        return param.detach().cpu().clone()

    # DeepSpeed is imported only here so plain tensors work without it installed.
    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

    if param.ds_status == ZeroParamStatus.NOT_AVAILABLE and not ignore_status:
        print(
            f"Parameter {name} is not available in ZeRO-3, please check the ZeRO-3 status."
        )
    with zero.GatheredParameters([param]):
        return param.data.detach().cpu().clone()


# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
    """Collect LoRA parameters (and optionally biases) as detached CPU copies.

    Args:
        named_params: Iterable of ``(name, tensor)`` pairs (presumably
            ``model.named_parameters()``); it is consumed once.
        bias: Which biases to include alongside the ``"lora_"`` tensors:
            ``"none"`` - no biases;
            ``"all"`` - every name containing ``"bias"``;
            ``"lora_only"`` - only ``<prefix>bias`` where ``<prefix>`` is the
            part of a LoRA name before ``"lora_"``.

    Returns:
        Dict mapping parameter names to CPU copies (see ``maybe_zero_3``).

    Raises:
        NotImplementedError: For any other ``bias`` value.
    """
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                # e.g. "layer.lora_A.weight" -> "layer.bias"
                lora_bias_names.add(k.split("lora_")[0] + "bias")
            elif "bias" in k:
                maybe_lora_bias[k] = t
        # Biases can precede their LoRA weights in iteration order, so they are
        # matched only after every LoRA name has been seen.
        for k, t in maybe_lora_bias.items():
            if k in lora_bias_names:
                to_return[k] = t
    else:
        raise NotImplementedError(f"Unsupported bias mode: {bias!r}")
    return {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}


def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
    """Collect non-LoRA parameters as detached CPU copies.

    Args:
        named_params: Iterable of ``(name, tensor)`` pairs (presumably
            ``model.named_parameters()``).
        require_grad_only: If True, keep only tensors with ``requires_grad``.

    Returns:
        Dict mapping each name that does not contain ``"lora_"`` to a CPU copy.
    """
    to_return = {k: t for k, t in named_params if "lora_" not in k}
    if require_grad_only:
        to_return = {k: t for k, t in to_return.items() if t.requires_grad}
    # maybe_zero_3 already returns a CPU clone; no extra ``.cpu()`` needed.
    return {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}


def _insert_adapter_name_into_state_dict(
    state_dict: dict[str, torch.Tensor], adapter_name: str, parameter_prefix: str
) -> dict[str, torch.Tensor]:
    """Remap state_dict keys to fit the PEFT model by inserting the adapter name.

    For keys containing ``parameter_prefix`` (with prefix ``"lora_"`` and
    adapter ``"default"``):
    ``"q.lora_A.weight" -> "q.lora_A.default.weight"`` and
    ``"emb.lora_embedding_A" -> "emb.lora_embedding_A.default"``.
    All other keys are copied unchanged.
    """
    peft_model_state_dict = {}
    for key, val in state_dict.items():
        if parameter_prefix in key:
            suffix = key.split(parameter_prefix)[1]
            if "." in suffix:
                suffix_to_replace = suffix.split(".", 1)[1]
                # Insert before the *last* occurrence only: str.replace would
                # also rewrite matching text earlier in the key (e.g. the
                # "weight" in "weight_proj.lora_A.weight").
                head, _, tail = key.rpartition(suffix_to_replace)
                key = f"{head}{adapter_name}.{suffix_to_replace}{tail}"
            else:
                key = f"{key}.{adapter_name}"
        peft_model_state_dict[key] = val
    return peft_model_state_dict


def save_video(tensor, path, fps=4, video_codec="h264"):
    """Encode a 4-D frame tensor to a video file.

    Args:
        tensor: Rank-4 tensor, assumed (T, C, H, W); it is permuted to
            (T, H, W, C). Values are scaled by 255 and clamped to [0, 255],
            so presumably in [0, 1].
        path: Output file path.
        fps: Frames per second (defaults to the previously hard-coded 4).
        video_codec: Codec name passed to ``torchvision.io.write_video``.
    """
    from torchvision.io import write_video

    frames = (tensor * 255.0).permute(0, 2, 3, 1).clamp(0, 255).byte()
    # Move to CPU before encoding so CUDA tensors can be written as well.
    write_video(path, frames.cpu(), fps, video_codec=video_codec)


def load_model_from_checkpoint(model, checkpoint_dir, checkpoint_step):
    """Load weights from a ``checkpoint-<step>`` entry of ``checkpoint_dir``.

    A full ``model.pth`` state dict is used when present. Otherwise LoRA
    weights from ``adapter_model.safetensors`` (keys remapped to adapter
    ``"default"``) and ``non_lora_state_dict.pth`` are merged over the model's
    current state dict.

    Args:
        model: Module to load into; it is modified in place and returned.
        checkpoint_dir: Directory containing ``checkpoint-<step>`` entries.
        checkpoint_step: Step to load; ``None`` or ``-1`` selects the latest.
            If the requested step does not exist, the latest is used instead.

    Returns:
        ``(model, checkpoint_step)`` with the loaded step as a string.

    Raises:
        FileNotFoundError: If ``checkpoint_dir`` has no ``checkpoint-<step>``
            entry with a numeric step.
        ImportError: If a LoRA checkpoint must be read but safetensors is
            not installed.
    """
    # Only numeric steps are considered, so stray entries such as
    # "checkpoint-final" cannot break the integer sort.
    checkpoint_paths = [
        path
        for path in glob.glob(os.path.join(checkpoint_dir, "checkpoint-*"))
        if path.rsplit("-", 1)[-1].isdecimal()
    ]
    if not checkpoint_paths:
        raise FileNotFoundError(
            f"No 'checkpoint-<step>' entries found in {checkpoint_dir!r}"
        )
    checkpoint_paths.sort(key=lambda p: int(p.rsplit("-", 1)[-1]), reverse=True)
    latest_path = checkpoint_paths[0]

    if checkpoint_step is None or checkpoint_step == -1:
        checkpoint_path = latest_path
        print(
            f"===> Checkpoint step is not provided, using the latest checkpoint: {checkpoint_path}"
        )
    else:
        checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint-{checkpoint_step}")
        if checkpoint_path not in checkpoint_paths:
            checkpoint_path = latest_path
            print(
                f"===> Checkpoint step {checkpoint_step} not found, using the latest checkpoint: {checkpoint_path}"
            )
        else:
            print(
                f"===> Checkpoint step {checkpoint_step} found, using the specified checkpoint: {checkpoint_path}"
            )
    checkpoint_step = os.path.basename(checkpoint_path).removeprefix("checkpoint-")

    full_ckpt = os.path.join(checkpoint_path, "model.pth")
    lora_ckpt = os.path.join(checkpoint_path, "adapter_model.safetensors")
    non_lora_ckpt = os.path.join(checkpoint_path, "non_lora_state_dict.pth")

    if os.path.exists(full_ckpt):
        model.load_state_dict(torch.load(full_ckpt, map_location="cpu"))
    else:
        if safetensors is None:
            raise ImportError(
                "safetensors is required to load LoRA adapter checkpoints"
            )
        lora_state_dict = safetensors.torch.load_file(lora_ckpt)
        non_lora_state_dict = torch.load(non_lora_ckpt, map_location="cpu")
        lora_state_dict = _insert_adapter_name_into_state_dict(
            lora_state_dict, adapter_name="default", parameter_prefix="lora_"
        )
        # Start from the model's own weights so keys absent from both files
        # keep their current values.
        model_state_dict = model.state_dict()
        model_state_dict.update(non_lora_state_dict)
        model_state_dict.update(lora_state_dict)
        model.load_state_dict(model_state_dict)
    return model, checkpoint_step


def find_target_linear_names(
    model, num_lora_modules=-1, lora_namespan_exclude=None, verbose=False
):
    """Find the names of modules to target with LoRA.

    Both ``torch.nn.Linear`` and ``torch.nn.Embedding`` submodules are
    collected, in ``model.named_modules()`` order.

    Args:
        model: Module to search.
        num_lora_modules: If > 0, keep only the last this-many names.
        lora_namespan_exclude: Substrings; a module whose name contains any of
            them is skipped. ``None`` means no exclusions.
        verbose: If True, print the names found.

    Returns:
        List of qualified module names.
    """
    exclude = lora_namespan_exclude or ()
    target_types = (torch.nn.Linear, torch.nn.Embedding)
    lora_module_names = [
        name
        for name, module in model.named_modules()
        if not any(ex_keyword in name for ex_keyword in exclude)
        and isinstance(module, target_types)
    ]
    if num_lora_modules > 0:
        lora_module_names = lora_module_names[-num_lora_modules:]
    if verbose:
        print(f"Found {len(lora_module_names)} lora modules: {lora_module_names}")
    return lora_module_names