"""
Manual LoRA merging script that handles key naming issues
"""
import argparse
import json
import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from safetensors.torch import load_file
from tqdm import tqdm


def merge_lora_weights(
    base_model_name,
    adapter_path,
    output_path,
    device_map="auto"
):
    """Manually merge LoRA weights into the base model."""
    print(f"Loading base model: {base_model_name}")
    model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype=torch.bfloat16,
        device_map=device_map,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )

    print(f"Loading LoRA adapters from: {adapter_path}")
    adapter_weights = load_file(os.path.join(adapter_path, "adapter_model.safetensors"))
    print(f"Loaded {len(adapter_weights)} adapter weights")
    # Load adapter config to get the LoRA scaling factor
    with open(os.path.join(adapter_path, "adapter_config.json")) as f:
        adapter_config = json.load(f)
    lora_alpha = adapter_config["lora_alpha"]
    r = adapter_config["r"]
    scaling = lora_alpha / r
    print(f"LoRA scaling factor: {scaling} (alpha={lora_alpha}, r={r})")
    # Group LoRA weights by layer
    lora_pairs = {}
    for key in adapter_weights.keys():
        if "lora_A" in key:
            base_key = key.replace(".lora_A.weight", "")
            lora_pairs[base_key] = {
                "A": adapter_weights[key],
                "B": adapter_weights.get(base_key + ".lora_B.weight")
            }
    print(f"Found {len(lora_pairs)} LoRA pairs to merge")
    # Get model state dict
    model_state_dict = model.state_dict()

    # Map adapter keys to model keys
    # Adapter keys: base_model.model.model.layers.X.self_attn.q_proj
    # Model keys might be: model.layers.X.self_attn.q_proj (depending on device_map)
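    # The state-dict entries additionally end in ".weight"
    # (e.g. model.layers.X.self_attn.q_proj.weight), which is why the lookup
    # below matches by substring rather than by exact key.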
print("\nMerging LoRA weights...")
merged_count = 0
for adapter_key, lora_weights in tqdm(lora_pairs.items()):
# Remove 'base_model.model.' prefix from adapter key
# adapter_key looks like: base_model.model.model.layers.0.self_attn.q_proj
if adapter_key.startswith("base_model.model."):
model_key = adapter_key[len("base_model.model."):]
else:
model_key = adapter_key
# Try to find the matching key in model
found = False
for mk in model_state_dict.keys():
if model_key in mk or mk.endswith(model_key):
model_key = mk
found = True
break
if not found:
# Try alternative key formats
alternatives = [
model_key,
"model." + model_key,
model_key.replace("model.", ""),
]
for alt_key in alternatives:
if alt_key in model_state_dict:
model_key = alt_key
found = True
break
        if found and model_key in model_state_dict:
            # Merge: W' = W + (B @ A) * scaling
            lora_A = lora_weights["A"]
            lora_B = lora_weights["B"]
            # Move to same device as model weight
            device = model_state_dict[model_key].device
            lora_A = lora_A.to(device)
            lora_B = lora_B.to(device)
            # Compute delta_W = (lora_B @ lora_A) * scaling
            delta_W = (lora_B @ lora_A) * scaling
            # Add to original weight
            model_state_dict[model_key] = model_state_dict[model_key] + delta_W.to(model_state_dict[model_key].dtype)
            merged_count += 1
        else:
            print(f"Warning: Could not find model key for {adapter_key}")

    print(f"\nSuccessfully merged {merged_count}/{len(lora_pairs)} LoRA weights")

    # Load merged weights back into model
    model.load_state_dict(model_state_dict, strict=False)
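    # strict=False tolerates any keys that did not round-trip through state_dict();
    # the merged_count report above is the main signal that every pair was applied.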
    # Save merged model
    print(f"\nSaving merged model to: {output_path}")
    os.makedirs(output_path, exist_ok=True)
    model.save_pretrained(output_path, safe_serialization=True, max_shard_size="5GB")

    # Also save tokenizer
    print("Saving tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
    tokenizer.save_pretrained(output_path)

    print("\n✅ Merge complete!")
    return model

if __name__ == "__main__":
    # For use in the Space: defaults match the original hardcoded paths,
    # but they can be overridden from the command line.
    parser = argparse.ArgumentParser(description="Manually merge LoRA adapters into a base model")
    parser.add_argument("--base-model", default="moonshotai/Kimi-Linear-48B-A3B-Instruct")
    parser.add_argument("--adapter-path", default="/app/lora_adapters")  # We'll download here
    parser.add_argument("--output-path", default="/app/merged_model")
    args = parser.parse_args()

    merge_lora_weights(args.base_model, args.adapter_path, args.output_path)