import os
import torch
from models.modeling_vora import VoRAForCausalLM, VoRAConfig
from utils import logging
logger = logging.get_logger(__name__)
def key_mapping(state_dict, key_mapping_dict):
    """Rename checkpoint keys by substring substitution.

    For each key of ``state_dict``, the first entry of ``key_mapping_dict``
    (in its iteration order) whose key occurs as a substring is applied:
    every occurrence of that substring is replaced by the mapped value.
    Keys that match no entry are copied unchanged.

    Args:
        state_dict: Mapping from parameter name to value.
        key_mapping_dict: Mapping from old substring to its replacement.

    Returns:
        A new dict holding the same values under the renamed keys; the
        input mapping is left untouched.
    """
    renamed = {}
    for name, value in state_dict.items():
        # Only the first matching rule is ever applied to a given key.
        matched = next((old for old in key_mapping_dict if old in name), None)
        if matched is None:
            renamed[name] = value
        else:
            renamed[name.replace(matched, key_mapping_dict[matched])] = value
    return renamed
def merge_lora(checkpoint, lora_key="lora_A"):
    """Fold LoRA adapter weights into their base weights.

    A module is treated as LoRA-adapted when one of its keys contains
    ``lora_key`` immediately followed by ``"."``. Everything before
    ``lora_key`` is the module root, and the module is expected to provide
    ``<root>weight``, ``<root>lora_A.weight`` and ``<root>lora_B.weight``.
    The merged weight is ``W + B @ A``; the adapter tensors are dropped
    from the result.

    NOTE(review): no alpha / rank scaling factor is applied to ``B @ A``
    here -- confirm this matches how the adapters were applied in training.

    Args:
        checkpoint: Mapping from parameter name to tensor. Not modified.
        lora_key: Marker used only to locate adapted modules; the adapter
            key names themselves are fixed to ``lora_A`` / ``lora_B``.

    Returns:
        A new dict containing the merged weights (and the biases of merged
        modules) first, followed by every other entry of ``checkpoint``
        unchanged.

    Raises:
        KeyError: If an adapted module lacks its base weight or either
            adapter matrix.
    """
    merged = {}
    consumed = set()

    for key in checkpoint:
        # First occurrence of the marker splits the key into root / suffix.
        idx = key.find(lora_key)
        if idx < 0:
            continue
        root_key = key[:idx]
        suffix = key[idx + len(lora_key):]
        # Reject look-alike names such as "...lora_A_extra".
        if not suffix.startswith("."):
            continue

        weight_key = f"{root_key}weight"
        lora_A_key = f"{root_key}lora_A.weight"
        lora_B_key = f"{root_key}lora_B.weight"
        bias_key = f"{root_key}bias"  # bias is carried over explicitly

        # Each module is merged once, however many of its keys match.
        if weight_key in consumed:
            continue
        consumed.update({weight_key, lora_A_key, lora_B_key})

        if any(k not in checkpoint for k in (weight_key, lora_A_key, lora_B_key)):
            raise KeyError(f"Missing keys for module {root_key}")

        W = checkpoint[weight_key]
        A = checkpoint[lora_A_key]
        B = checkpoint[lora_B_key]
        merged[weight_key] = W + B @ A

        if bias_key in checkpoint:
            merged[bias_key] = checkpoint[bias_key]
            consumed.add(bias_key)

    # Everything that was not part of a merged module passes through as-is.
    for key, value in checkpoint.items():
        if key not in consumed:
            merged[key] = value
    return merged
def partial_load_from_checkpoints(
    local_checkpoint_path,
    ckpt_rename_parameters=None,
    map_location="cpu",
    model=None,
    valid_prefix=None,
    lazy_load=False
):
    """Load a state dict from a checkpoint file or a safetensors directory.

    Args:
        local_checkpoint_path: Either a directory containing one or more
            ``*.safetensors`` shards, or a single file readable by
            ``torch.load``.
        ckpt_rename_parameters: Optional mapping of key substrings to their
            replacements, applied through ``key_mapping``. The caller's
            dict is never modified.
        map_location: Forwarded to ``torch.load``; not used for the
            safetensors directory branch.
        model: Unused; kept so existing callers keep working.
        valid_prefix: Optional prefix (or iterable of prefixes); when given,
            only keys starting with one of them are kept.
        lazy_load: For a directory, return the shard paths instead of
            loading any tensors.

    Returns:
        The (filtered, renamed) state dict, or -- when ``lazy_load`` is true
        and the path is a directory -- the sorted list of shard paths.

    Raises:
        ValueError: If the directory contains no ``.safetensors`` file.
    """
    # Work on a private copy: the zero2 branch below adds a rename rule and
    # must not leak it into a dict owned by the caller.
    rename_rules = dict(ckpt_rename_parameters) if ckpt_rename_parameters else {}

    if os.path.isdir(local_checkpoint_path):
        # Sorted so shard order (and the lazy result) is deterministic.
        shard_paths = sorted(
            os.path.join(local_checkpoint_path, name)
            for name in os.listdir(local_checkpoint_path)
            if name.endswith(".safetensors")
        )
        if not shard_paths:
            raise ValueError(f"No safetensors file found in {local_checkpoint_path}")
        if lazy_load:
            return shard_paths

        # Imported only when tensors are actually read.
        from safetensors.torch import load

        checkpoint = {}
        for shard_path in shard_paths:
            print(f"loading checkpoint from {shard_path}")
            with open(shard_path, "rb") as f:
                checkpoint.update(load(f.read()))
    else:
        # NOTE(security): torch.load unpickles the file; only use it on
        # trusted checkpoints (consider weights_only=True where supported).
        checkpoint = torch.load(local_checkpoint_path, map_location=map_location)

    if "state_dict" in checkpoint:
        logger.info("partial loading checkpoint")
        state_dict = checkpoint["state_dict"]
    elif "module" in checkpoint:
        # for ds zero2 checkpoint
        logger.info("partial loading deepspeed zero2 checkpoint")
        state_dict = checkpoint["module"]
        rename_rules.update({"module.": ""})
    else:
        state_dict = checkpoint

    if valid_prefix:
        # A bare string is one prefix, not an iterable of characters.
        if isinstance(valid_prefix, str):
            prefixes = (valid_prefix,)
        else:
            prefixes = tuple(valid_prefix)
        state_dict = {k: v for k, v in state_dict.items() if k.startswith(prefixes)}

    return key_mapping(state_dict, rename_rules)
if __name__ == "__main__":
    # Script entry point: build a VoRA model from a YAML config, merge the
    # LoRA weights of a checkpoint into the base weights, and save the
    # resulting model to --save_dir.
    import argparse

    import yaml

    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--checkpoint", type=str, required=True)
    parser.add_argument("--save_dir", type=str, required=True)
    args = parser.parse_args()

    config_path = args.config
    checkpoint_path = args.checkpoint
    save_path = args.save_dir

    with open(config_path, "r", encoding="utf-8") as f:
        vora_config = yaml.safe_load(f)["model"]

    # NOTE(review): r = -1 presumably makes the model build without LoRA
    # modules so the merged weights load into plain layers -- confirm
    # against VoRAConfig.
    vora_config["lora"]["r"] = -1
    config = VoRAConfig(**vora_config)
    model = VoRAForCausalLM._from_config(config=config)

    state_dict = partial_load_from_checkpoints(checkpoint_path)
    state_dict = merge_lora(state_dict)

    # strict=False tolerates key mismatches; report them so a naming
    # problem does not silently produce a model with unmerged weights.
    load_result = model.load_state_dict(state_dict, strict=False)
    missing_keys = getattr(load_result, "missing_keys", None)
    unexpected_keys = getattr(load_result, "unexpected_keys", None)
    if missing_keys:
        print(f"warning: {len(missing_keys)} model keys not found in checkpoint: {list(missing_keys)}")
    if unexpected_keys:
        print(f"warning: {len(unexpected_keys)} checkpoint keys not used by model: {list(unexpected_keys)}")

    model.save_pretrained(save_path)
|