import os

import torch

from models.modeling_vora import VoRAForCausalLM, VoRAConfig
from utils import logging

logger = logging.get_logger(__name__)


def key_mapping(state_dict, key_mapping_dict):
    """Return a copy of ``state_dict`` with parameter names rewritten.

    For each key, the first entry of ``key_mapping_dict`` (in iteration
    order) whose old substring occurs anywhere in the key is applied: every
    occurrence of that substring is replaced by the mapped value. At most
    one mapping is applied per key; keys matching no entry are copied as-is.

    Args:
        state_dict: Mapping of parameter name -> value.
        key_mapping_dict: Mapping of old substring -> replacement substring.

    Returns:
        A new ``dict``; ``state_dict`` itself is not modified.
    """
    new_state_dict = {}
    for k, v in state_dict.items():
        for prev_key, new_key in key_mapping_dict.items():
            if prev_key in k:
                new_state_dict[k.replace(prev_key, new_key)] = v
                break
        else:
            # No rename rule matched this key.
            new_state_dict[k] = v
    return new_state_dict


def merge_lora(checkpoint, lora_key="lora_A", scaling=1.0):
    """Fold LoRA adapter weights into their base weights.

    For every key that contains ``lora_key`` immediately followed by ``.``
    (e.g. ``"x.q_proj.lora_A.weight"``), the text before ``lora_key`` is the
    module prefix ``root`` (here ``"x.q_proj."``). The base weight
    ``root + "weight"`` is replaced by ``W + scaling * (B @ A)`` with
    ``A = root + "lora_A.weight"`` and ``B = root + "lora_B.weight"``.
    The ``lora_A`` / ``lora_B`` entries are dropped from the result; every
    other entry (including ``root + "bias"``) is copied unchanged.

    The product is computed in at least float32 and cast back to ``W``'s
    dtype, so merging half-precision checkpoints does not add rounding
    error from a low-precision ``B @ A``.

    Args:
        checkpoint: Mapping of parameter name -> tensor. Not modified.
        lora_key: Substring identifying LoRA entries to merge.
        scaling: Multiplier applied to ``B @ A``. The default of 1.0 matches
            the previous behavior. NOTE(review): this assumes any LoRA
            ``alpha / r`` factor is already folded into A or B; confirm
            against the VoRA LoRA layer.

    Returns:
        A new state dict with merged base weights and no LoRA A/B entries.

    Raises:
        KeyError: If a module has a LoRA entry but is missing its base
            ``weight``, ``lora_A.weight`` or ``lora_B.weight``.
    """
    new_state_dict = {}
    lora_processed = set()
    for key in checkpoint:
        idx = key.find(lora_key)
        if idx == -1:
            continue
        root_key = key[:idx]
        suffix = key[idx + len(lora_key):]
        # Require "lora_A." so that e.g. "lora_A_extra" is not mistaken for
        # an adapter entry.
        if not suffix.startswith("."):
            continue

        weight_key = f"{root_key}weight"
        lora_A_key = f"{root_key}lora_A.weight"
        lora_B_key = f"{root_key}lora_B.weight"
        bias_key = f"{root_key}bias"  # bias is carried over explicitly
        if weight_key in lora_processed:
            # This module was already merged via another LoRA entry.
            continue
        lora_processed.update({weight_key, lora_A_key, lora_B_key})

        missing = [k for k in (weight_key, lora_A_key, lora_B_key) if k not in checkpoint]
        if missing:
            raise KeyError(f"Missing keys for module {root_key}")

        W = checkpoint[weight_key]
        A = checkpoint[lora_A_key]
        B = checkpoint[lora_B_key]
        # Accumulate in >= float32 (float64 stays float64), then restore
        # the base weight's dtype.
        compute_dtype = torch.promote_types(W.dtype, torch.float32)
        merged = W.to(compute_dtype) + scaling * (
            B.to(compute_dtype) @ A.to(compute_dtype)
        )
        new_state_dict[weight_key] = merged.to(W.dtype)

        if bias_key in checkpoint:
            new_state_dict[bias_key] = checkpoint[bias_key]
            lora_processed.add(bias_key)

    for key, value in checkpoint.items():
        if key not in lora_processed:
            new_state_dict[key] = value
    return new_state_dict


def partial_load_from_checkpoints(
    local_checkpoint_path,
    ckpt_rename_parameters=None,
    map_location="cpu",
    model=None,
    valid_prefix=None,
    lazy_load=False,
):
    """Load a (possibly filtered and renamed) state dict from a checkpoint.

    Directory paths:
        If ``local_checkpoint_path`` is a directory, every ``*.safetensors``
        file in it is loaded in sorted filename order and merged into one
        dict; on duplicate keys, later files win. With ``lazy_load=True``
        nothing is loaded and the sorted list of file paths is returned.

    File paths:
        Otherwise the path is loaded with ``torch.load``.

    Post-processing:
        The loaded object is unwrapped. A ``"state_dict"`` entry is used if
        present. A ``"module"`` entry (DeepSpeed ZeRO-2) is used and a
        ``"module." -> ""`` rename rule is added. Keys are then optionally
        filtered by ``valid_prefix`` and finally renamed via ``key_mapping``.

    Args:
        local_checkpoint_path: Checkpoint file, or directory of safetensors
            shards.
        ckpt_rename_parameters: Old-substring -> new-substring rename rules.
            The caller's dict is not modified.
        map_location: Passed to ``torch.load``. For safetensors it is used as
            the target device when it is a ``str`` or ``torch.device``;
            otherwise tensors are loaded on CPU.
        model: Unused; kept for interface compatibility.
        valid_prefix: A prefix string, or an iterable of prefixes. Only keys
            starting with one of them are kept. Falsy means keep all keys.
        lazy_load: For directories, return file paths instead of loading.

    Returns:
        The state dict, or a list of safetensors file paths when
        ``lazy_load`` is set and the path is a directory.

    Raises:
        ValueError: If a directory contains no ``.safetensors`` files.
    """
    # Copy so that the DeepSpeed "module." rule added below never leaks into
    # the caller's dict (or into later calls sharing the same dict).
    ckpt_rename_parameters = dict(ckpt_rename_parameters or {})

    if os.path.isdir(local_checkpoint_path):
        from safetensors.torch import load_file

        files = sorted(
            file for file in os.listdir(local_checkpoint_path)
            if file.endswith(".safetensors")
        )
        if not files:
            raise ValueError(f"No safetensors file found in {local_checkpoint_path}")
        file_paths = [os.path.join(local_checkpoint_path, file) for file in files]
        if lazy_load:
            return file_paths

        if isinstance(map_location, (str, torch.device)):
            device = str(map_location)
        else:
            # Callables / dict remappings are torch.load-only concepts.
            device = "cpu"
        checkpoint = {}
        for file_path in file_paths:
            print(f"loading checkpoint from {file_path}")
            # load_file reads from disk directly instead of first copying the
            # whole file into a bytes buffer, roughly halving peak memory.
            checkpoint.update(load_file(file_path, device=device))
    else:
        # NOTE: torch.load unpickles its input; only use on trusted files.
        checkpoint = torch.load(local_checkpoint_path, map_location=map_location)

    if "state_dict" in checkpoint:
        logger.info("partial loading checkpoint")
        state_dict = checkpoint["state_dict"]
    elif "module" in checkpoint:
        # for ds zero2 checkpoint
        logger.info("partial loading deepspeed zero2 checkpoint")
        state_dict = checkpoint["module"]
        ckpt_rename_parameters.update({"module.": ""})
    else:
        state_dict = checkpoint

    if valid_prefix:
        if isinstance(valid_prefix, str):
            # A bare string is one prefix, not a collection of characters.
            valid_prefix = (valid_prefix,)
        prefixes = tuple(valid_prefix)
        state_dict = {k: v for k, v in state_dict.items() if k.startswith(prefixes)}

    state_dict = key_mapping(state_dict, ckpt_rename_parameters)
    return state_dict


if __name__ == "__main__":
    import argparse

    import yaml

    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--checkpoint", type=str, required=True)
    parser.add_argument("--save_dir", type=str, required=True)
    args = parser.parse_args()

    config_path = args.config
    checkpoint_path = args.checkpoint
    save_path = args.save_dir

    with open(config_path, "r", encoding="utf-8") as f:
        vora_config = yaml.safe_load(f)["model"]
    # NOTE(review): r = -1 presumably builds the model without LoRA modules,
    # so the merged weights load into plain layers; confirm in VoRAConfig.
    vora_config["lora"]["r"] = -1
    config = VoRAConfig(**vora_config)
    model = VoRAForCausalLM._from_config(config=config)

    state_dict = partial_load_from_checkpoints(checkpoint_path)
    state_dict = merge_lora(state_dict)
    # strict=False tolerates name mismatches; report them so that a silently
    # half-loaded model does not get saved unnoticed.
    load_result = model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys:
        logger.warning(
            f"{len(load_result.missing_keys)} missing keys when loading merged "
            f"weights (first 20): {load_result.missing_keys[:20]}"
        )
    if load_result.unexpected_keys:
        logger.warning(
            f"{len(load_result.unexpected_keys)} unexpected keys when loading "
            f"merged weights (first 20): {load_result.unexpected_keys[:20]}"
        )
    model.save_pretrained(save_path)