File size: 4,852 Bytes
2e7f2ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import os

import torch

from models.modeling_vora import VoRAForCausalLM, VoRAConfig
from utils import logging


logger = logging.get_logger(__name__)


def key_mapping(state_dict, key_mapping_dict):
    """Return a copy of ``state_dict`` with its keys renamed by substring rules.

    Args:
        state_dict: Mapping whose keys are to be renamed; values are passed
            through untouched.
        key_mapping_dict: Mapping of ``old_substring -> new_substring``.

    Returns:
        A new dict. For each key, only the first rule (in
        ``key_mapping_dict`` iteration order) whose ``old_substring`` occurs
        in the key is applied, and it replaces every occurrence of that
        substring. Keys matching no rule are copied unchanged.
    """
    renamed = {}
    for name, value in state_dict.items():
        target = name
        for old, new in key_mapping_dict.items():
            if old in name:
                # Only the first matching rule is applied to a given key.
                target = name.replace(old, new)
                break
        renamed[target] = value
    return renamed


def merge_lora(checkpoint, lora_key="lora_A"):
    """Fold LoRA factor pairs into their base weights and drop the factors.

    Every key of the form ``<root><lora_key>.<rest>`` identifies a module
    ``<root>``. For that module the entries ``<root>weight``,
    ``<root>lora_A.weight`` and ``<root>lora_B.weight`` are replaced by a
    single ``<root>weight`` holding ``W + B @ A``. No scaling factor is
    applied to the product. Every other entry is copied through unchanged.

    Args:
        checkpoint: Mapping from parameter name to value. The values of the
            merged entries must support ``+`` and ``@`` (presumably tensors).
            The mapping itself is not modified.
        lora_key: Marker substring used to detect LoRA modules. Note that the
            companion key names are always built with the literal
            ``lora_A`` / ``lora_B`` names, independent of this argument.

    Returns:
        A new dict with merged weights and without the LoRA factor entries.

    Raises:
        KeyError: If a detected module lacks its base weight or either
            LoRA factor.
    """
    new_state_dict = {}
    lora_processed = set()

    for key in checkpoint:
        idx = key.find(lora_key)
        if idx < 0:
            continue
        root_key = key[:idx]
        suffix = key[idx + len(lora_key):]

        # Require "<lora_key>." so that longer names merely starting with the
        # marker (e.g. "lora_A_extra") are not mistaken for a LoRA factor.
        if not suffix.startswith('.'):
            continue

        weight_key = f"{root_key}weight"
        lora_A_key = f"{root_key}lora_A.weight"
        lora_B_key = f"{root_key}lora_B.weight"
        bias_key = f"{root_key}bias"  # bias is carried over explicitly below

        # Each module is merged once, even if several of its keys match.
        if weight_key in lora_processed:
            continue
        lora_processed.update({weight_key, lora_A_key, lora_B_key})

        missing = [
            k for k in (weight_key, lora_A_key, lora_B_key) if k not in checkpoint
        ]
        if missing:
            raise KeyError(
                f"Missing keys for module {root_key}: {', '.join(missing)}"
            )

        W = checkpoint[weight_key]
        A = checkpoint[lora_A_key]
        B = checkpoint[lora_B_key]
        new_state_dict[weight_key] = W + B @ A

        # The bias is not changed by the merge; it is stored right after its
        # weight and marked as handled so the pass-through loop skips it.
        if bias_key in checkpoint:
            new_state_dict[bias_key] = checkpoint[bias_key]
            lora_processed.add(bias_key)

    # Everything that is not part of a merged module is copied as-is.
    for key, value in checkpoint.items():
        if key not in lora_processed:
            new_state_dict[key] = value

    return new_state_dict


def partial_load_from_checkpoints(
        local_checkpoint_path,
        ckpt_rename_parameters=None,
        map_location="cpu",
        model=None,
        valid_prefix=None,
        lazy_load=False
    ):
    """Load a checkpoint and return its filtered, renamed state dict.

    Args:
        local_checkpoint_path: Either a directory containing ``*.safetensors``
            files or a single file readable by ``torch.load``.
        ckpt_rename_parameters: Optional ``old_substring -> new_substring``
            rules passed to ``key_mapping``. The caller's mapping is never
            modified.
        map_location: Forwarded to ``torch.load``; not used for safetensors
            directories.
        model: Not used by this function; kept for interface compatibility.
        valid_prefix: Optional prefix (a single string) or iterable of
            prefixes. When given, only keys starting with one of them are
            kept. Filtering happens before renaming.
        lazy_load: Only meaningful for a directory. When true, nothing is
            loaded and the list of safetensors file paths is returned.

    Returns:
        The state dict (``dict`` of name -> value), or, for a directory with
        ``lazy_load=True``, a list of safetensors file paths.

    Raises:
        ValueError: If the directory contains no ``.safetensors`` file.
    """
    # Work on a private copy: the DeepSpeed branch below adds a rule, and the
    # caller's mapping must not change as a side effect of loading.
    ckpt_rename_parameters = dict(ckpt_rename_parameters or {})

    if os.path.isdir(local_checkpoint_path):
        from safetensors.torch import load

        # os.listdir returns entries in arbitrary order; sort so that shards
        # are read (and reported) deterministically.
        file_paths = [
            os.path.join(local_checkpoint_path, file_name)
            for file_name in sorted(os.listdir(local_checkpoint_path))
            if file_name.endswith(".safetensors")
        ]
        if not file_paths:
            raise ValueError(f"No safetensors file found in {local_checkpoint_path}")
        if lazy_load:
            return file_paths

        checkpoint = {}
        for file_path in file_paths:
            print(f"loading checkpoint from {file_path}")
            with open(file_path, "rb") as f:
                checkpoint.update(load(f.read()))
    else:
        # NOTE(security): torch.load unpickles the file and can execute
        # arbitrary code; only use it on checkpoints from a trusted source.
        checkpoint = torch.load(local_checkpoint_path, map_location=map_location)

    if "state_dict" in checkpoint:
        logger.info("partial loading checkpoint")
        state_dict = checkpoint["state_dict"]
    elif "module" in checkpoint:
        # for ds zero2 checkpoint
        logger.info("partial loading deepspeed zero2 checkpoint")
        state_dict = checkpoint["module"]
        ckpt_rename_parameters.update({"module.": ""})
    else:
        state_dict = checkpoint

    if valid_prefix:
        # A bare string is one prefix, not an iterable of one-character
        # prefixes.
        if isinstance(valid_prefix, str):
            prefixes = (valid_prefix,)
        else:
            prefixes = tuple(valid_prefix)
        state_dict = {
            k: v for k, v in state_dict.items() if k.startswith(prefixes)
        }
    state_dict = key_mapping(state_dict, ckpt_rename_parameters)
    return state_dict


if __name__ == "__main__":
    # Command-line entry point: build the model from a YAML config, merge the
    # LoRA factors of a checkpoint into its base weights, load the merged
    # weights into the model and save it to the target directory.
    import argparse
    import yaml

    cli = argparse.ArgumentParser()
    for option in ("--config", "--checkpoint", "--save_dir"):
        cli.add_argument(option, type=str, required=True)
    options = cli.parse_args()

    # Only the "model" section of the YAML file is used.
    with open(options.config, "r") as config_file:
        model_section = yaml.safe_load(config_file)["model"]
    # NOTE(review): presumably r = -1 makes VoRAConfig build the model without
    # LoRA modules so the merged weights fit plain layers; confirm.
    model_section["lora"]["r"] = -1
    model = VoRAForCausalLM._from_config(config=VoRAConfig(**model_section))

    merged_weights = merge_lora(partial_load_from_checkpoints(options.checkpoint))
    # strict=False: key mismatches between checkpoint and model do not raise.
    model.load_state_dict(merged_weights, strict=False)
    model.save_pretrained(options.save_dir)