| """ | |
| Manual LoRA merging script that handles key naming issues | |
| """ | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from safetensors.torch import load_file, save_file | |
| import os | |
| import argparse | |
| from tqdm import tqdm | |
def merge_lora_weights(
    base_model_name,
    adapter_path,
    output_path,
    device_map="auto",
):
    """Manually merge LoRA weights into the base model."""
    print(f"Loading base model: {base_model_name}")
    model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype=torch.bfloat16,
        device_map=device_map,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )
| print(f"Loading LoRA adapters from: {adapter_path}") | |
| adapter_weights = load_file(os.path.join(adapter_path, "adapter_model.safetensors")) | |
| print(f"Loaded {len(adapter_weights)} adapter weights") | |
| # Load adapter config to get scaling factor | |
| import json | |
| with open(os.path.join(adapter_path, "adapter_config.json")) as f: | |
| adapter_config = json.load(f) | |
| lora_alpha = adapter_config["lora_alpha"] | |
| r = adapter_config["r"] | |
| scaling = lora_alpha / r | |
| print(f"LoRA scaling factor: {scaling} (alpha={lora_alpha}, r={r})") | |

    # Group the LoRA A/B matrices by the layer they adapt
    lora_pairs = {}
    for key in adapter_weights.keys():
        if "lora_A" in key:
            base_key = key.replace(".lora_A.weight", "")
            lora_B = adapter_weights.get(base_key + ".lora_B.weight")
            if lora_B is None:
                # A lora_A without a matching lora_B would crash the merge below
                print(f"Warning: no lora_B found for {base_key}, skipping")
                continue
            lora_pairs[base_key] = {
                "A": adapter_weights[key],
                "B": lora_B,
            }
    print(f"Found {len(lora_pairs)} LoRA pairs to merge")

    # Get the model state dict
    model_state_dict = model.state_dict()

    # Map adapter keys to model keys.
    # Adapter keys look like: base_model.model.model.layers.X.self_attn.q_proj
    # Model keys look like:   model.layers.X.self_attn.q_proj.weight
| print("\nMerging LoRA weights...") | |
| merged_count = 0 | |
| for adapter_key, lora_weights in tqdm(lora_pairs.items()): | |
| # Remove 'base_model.model.' prefix from adapter key | |
| # adapter_key looks like: base_model.model.model.layers.0.self_attn.q_proj | |
| if adapter_key.startswith("base_model.model."): | |
| model_key = adapter_key[len("base_model.model."):] | |
| else: | |
| model_key = adapter_key | |
| # Try to find the matching key in model | |
| found = False | |
| for mk in model_state_dict.keys(): | |
| if model_key in mk or mk.endswith(model_key): | |
| model_key = mk | |
| found = True | |
| break | |
| if not found: | |
| # Try alternative key formats | |
| alternatives = [ | |
| model_key, | |
| "model." + model_key, | |
| model_key.replace("model.", ""), | |
| ] | |
| for alt_key in alternatives: | |
| if alt_key in model_state_dict: | |
| model_key = alt_key | |
| found = True | |
| break | |

        if found and model_key in model_state_dict:
            # Merge: W' = W + (B @ A) * scaling
            lora_A = lora_weights["A"]
            lora_B = lora_weights["B"]
            # Move the LoRA factors to the same device as the model weight
            device = model_state_dict[model_key].device
            lora_A = lora_A.to(device)
            lora_B = lora_B.to(device)
            # Compute delta_W = (lora_B @ lora_A) * scaling
            delta_W = (lora_B @ lora_A) * scaling
            # Add the update to the original weight, matching its dtype
            model_state_dict[model_key] = (
                model_state_dict[model_key]
                + delta_W.to(model_state_dict[model_key].dtype)
            )
            merged_count += 1
        else:
            print(f"Warning: could not find a model key for {adapter_key}")
| print(f"\nSuccessfully merged {merged_count}/{len(lora_pairs)} LoRA weights") | |
| # Load merged weights back into model | |
| model.load_state_dict(model_state_dict, strict=False) | |
| # Save merged model | |
| print(f"\nSaving merged model to: {output_path}") | |
| os.makedirs(output_path, exist_ok=True) | |
| model.save_pretrained(output_path, safe_serialization=True, max_shard_size="5GB") | |
| # Also save tokenizer | |
| print("Saving tokenizer...") | |
| tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True) | |
| tokenizer.save_pretrained(output_path) | |
| print("\n✅ Merge complete!") | |
| return model | |
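

# Hedged post-merge smoke test, not part of the original script:
# `quick_generation_check` is a hypothetical helper and the default prompt is a
# placeholder. It reloads the saved checkpoint and generates a few tokens to
# confirm the merged model is loadable and produces coherent output.
def quick_generation_check(merged_path, prompt="Hello, world!"):
    model = AutoModelForCausalLM.from_pretrained(
        merged_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(merged_path, trust_remote_code=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    output_ids = model.generate(**inputs, max_new_tokens=20)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))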


if __name__ == "__main__":
    # For use in the Space
    BASE_MODEL = "moonshotai/Kimi-Linear-48B-A3B-Instruct"
    ADAPTER_PATH = "/app/lora_adapters"  # We'll download here (see the sketch below)
    OUTPUT_PATH = "/app/merged_model"

    merge_lora_weights(BASE_MODEL, ADAPTER_PATH, OUTPUT_PATH)
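

# Hedged sketch of how ADAPTER_PATH could be populated before running the
# merge. This step is an assumption about the Space's setup, not part of the
# original script; the repo id is a placeholder to replace with the real
# adapter repo. snapshot_download is provided by the huggingface_hub package.
def download_adapters(repo_id="your-username/your-lora-adapter",
                      local_dir="/app/lora_adapters"):
    from huggingface_hub import snapshot_download

    snapshot_download(repo_id=repo_id, local_dir=local_dir)
    return local_dir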