"""
Manual LoRA merging script that handles key naming issues
"""
import json
import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from safetensors.torch import load_file
from tqdm import tqdm

def merge_lora_weights(
    base_model_name,
    adapter_path,
    output_path,
    device_map="auto"
):
    """Manually merge LoRA weights into base model"""
    
    print(f"Loading base model: {base_model_name}")
    model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype=torch.bfloat16,
        device_map=device_map,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )
    
    print(f"Loading LoRA adapters from: {adapter_path}")
    adapter_weights = load_file(os.path.join(adapter_path, "adapter_model.safetensors"))
    
    print(f"Loaded {len(adapter_weights)} adapter weights")
    
    # Load adapter config to get the LoRA scaling factor (alpha / r)
    with open(os.path.join(adapter_path, "adapter_config.json")) as f:
        adapter_config = json.load(f)
    
    lora_alpha = adapter_config["lora_alpha"]
    r = adapter_config["r"]
    scaling = lora_alpha / r
    
    print(f"LoRA scaling factor: {scaling} (alpha={lora_alpha}, r={r})")
    
    # Group LoRA weights by target module (each module has a lora_A / lora_B pair)
    lora_pairs = {}
    for key in adapter_weights.keys():
        if "lora_A" in key:
            base_key = key.replace(".lora_A.weight", "")
            lora_B = adapter_weights.get(base_key + ".lora_B.weight")
            if lora_B is None:
                print(f"Warning: no lora_B found for {base_key}, skipping")
                continue
            lora_pairs[base_key] = {
                "A": adapter_weights[key],
                "B": lora_B,
            }
    
    print(f"Found {len(lora_pairs)} LoRA pairs to merge")
    
    # Get model state dict
    model_state_dict = model.state_dict()
    
    # Map adapter keys to model state-dict keys
    # Adapter keys look like: base_model.model.model.layers.X.self_attn.q_proj
    # Model state-dict keys look like: model.layers.X.self_attn.q_proj.weight
    
    print("\nMerging LoRA weights...")
    merged_count = 0
    
    for adapter_key, lora_weights in tqdm(lora_pairs.items()):
        # Remove 'base_model.model.' prefix from adapter key
        # adapter_key looks like: base_model.model.model.layers.0.self_attn.q_proj
        if adapter_key.startswith("base_model.model."):
            model_key = adapter_key[len("base_model.model."):]
        else:
            model_key = adapter_key
        
        # State-dict keys carry a ".weight" suffix that the adapter keys lack,
        # so try the most specific candidates first
        candidates = [
            model_key + ".weight",
            "model." + model_key + ".weight",
            model_key.replace("model.", "", 1) + ".weight",
        ]

        found = False
        for candidate in candidates:
            if candidate in model_state_dict:
                model_key = candidate
                found = True
                break

        if not found:
            # Fall back to a suffix match over all state-dict keys
            for mk in model_state_dict.keys():
                if mk.endswith(model_key + ".weight"):
                    model_key = mk
                    found = True
                    break

        if found and model_key in model_state_dict:
            # Merge: W' = W + (B @ A) * scaling
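            # With the standard PEFT layout, lora_A.weight has shape (r, in_features)
            # and lora_B.weight has shape (out_features, r), so (lora_B @ lora_A)
            # matches the shape of the frozen weight W being updated.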
            lora_A = lora_weights["A"]
            lora_B = lora_weights["B"]
            
            # Move to same device as model weight
            device = model_state_dict[model_key].device
            lora_A = lora_A.to(device)
            lora_B = lora_B.to(device)
            
            # Compute delta_W = (lora_B @ lora_A) * scaling
            delta_W = (lora_B @ lora_A) * scaling
            
            # Add the delta to the original weight, casting to the weight's dtype
            weight = model_state_dict[model_key]
            model_state_dict[model_key] = weight + delta_W.to(weight.dtype)
            merged_count += 1
        else:
            print(f"Warning: Could not find model key for {adapter_key}")
    
    print(f"\nSuccessfully merged {merged_count}/{len(lora_pairs)} LoRA weights")
    
    # Load merged weights back into model
    model.load_state_dict(model_state_dict, strict=False)
    
    # Save merged model
    print(f"\nSaving merged model to: {output_path}")
    os.makedirs(output_path, exist_ok=True)
    model.save_pretrained(output_path, safe_serialization=True, max_shard_size="5GB")
    
    # Also save tokenizer
    print("Saving tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
    tokenizer.save_pretrained(output_path)
    
    print("\n✅ Merge complete!")
    return model

if __name__ == "__main__":
    # For use in the Space
    BASE_MODEL = "moonshotai/Kimi-Linear-48B-A3B-Instruct"
    ADAPTER_PATH = "/app/lora_adapters"  # We'll download here
    OUTPUT_PATH = "/app/merged_model"
    
    merge_lora_weights(BASE_MODEL, ADAPTER_PATH, OUTPUT_PATH)
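
# A minimal sketch of how ADAPTER_PATH could be populated before the merge,
# assuming the adapter lives in a Hugging Face Hub repo (the repo id below is
# a placeholder, not part of this project):
#
#   from huggingface_hub import snapshot_download
#   snapshot_download(repo_id="your-username/your-lora-adapter", local_dir=ADAPTER_PATH)
#   merge_lora_weights(BASE_MODEL, ADAPTER_PATH, OUTPUT_PATH)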