aeb56 committed
Commit 3a259bc · Parent: d3d4339

Implement manual LoRA merging to fix PEFT key naming conflicts

Files changed (2):
  1. app.py +105 -69
  2. merge_script.py +140 -0
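
For reviewers unfamiliar with LoRA merging: the change replaces PEFT's `merge_and_unload` with a hand-rolled application of the standard LoRA update, W_merged = W + (B @ A) · (alpha / r), applied tensor-by-tensor to the base model's state dict. A minimal sketch of that arithmetic (sizes, alpha and r below are illustrative, not the values from this repo's adapter config):

```python
import torch

# Minimal sketch of the merge arithmetic this commit applies per targeted layer.
# Standard LoRA rule: W_merged = W + (lora_B @ lora_A) * (lora_alpha / r).
# The sizes, lora_alpha and r below are illustrative, not taken from the adapter config.
out_features, in_features, r = 512, 512, 16
lora_alpha = 32

W = torch.randn(out_features, in_features, dtype=torch.bfloat16)   # frozen base weight
lora_A = torch.randn(r, in_features, dtype=torch.bfloat16)         # adapter "lora_A.weight"
lora_B = torch.randn(out_features, r, dtype=torch.bfloat16)        # adapter "lora_B.weight"

scaling = lora_alpha / r
delta_W = (lora_B @ lora_A) * scaling        # rank-r update, as in app.py / merge_script.py
W_merged = W + delta_W.to(W.dtype)           # replaces the original tensor in the state dict

print(W_merged.shape)  # torch.Size([512, 512])
```

Because the update is added directly into the existing tensors, the merged model keeps the same architecture and size as the base model.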
app.py CHANGED
@@ -3,6 +3,7 @@ import torch
 import gradio as gr
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from peft import PeftModel, PeftConfig
+from safetensors.torch import load_file
 import gc
 from huggingface_hub import login, snapshot_download
 import logging
@@ -61,6 +62,83 @@ class ModelMerger:
             logger.error(f"Login failed: {str(e)}")
             return f"❌ Login failed: {str(e)}"
 
+    def manual_merge_lora(self, model, adapter_path, progress=gr.Progress()):
+        """Manually merge LoRA weights into model to avoid PEFT key naming issues"""
+        import json
+        from tqdm import tqdm
+
+        logger.info("Using manual LoRA merge to avoid key naming conflicts...")
+        progress(0.55, desc="Loading LoRA adapter weights...")
+
+        # Load adapter weights
+        adapter_file = os.path.join(adapter_path, "adapter_model.safetensors")
+        adapter_weights = load_file(adapter_file)
+        logger.info(f"Loaded {len(adapter_weights)} adapter weight tensors")
+
+        # Load adapter config
+        config_file = os.path.join(adapter_path, "adapter_config.json")
+        with open(config_file) as f:
+            adapter_config = json.load(f)
+
+        lora_alpha = adapter_config["lora_alpha"]
+        r = adapter_config["r"]
+        scaling = lora_alpha / r
+        logger.info(f"LoRA scaling: {scaling} (alpha={lora_alpha}, r={r})")
+
+        # Group LoRA A and B weights
+        lora_pairs = {}
+        for key in adapter_weights.keys():
+            if "lora_A" in key:
+                base_key = key.replace(".lora_A.weight", "")
+                lora_pairs[base_key] = {
+                    "A": adapter_weights[key],
+                    "B": adapter_weights.get(base_key + ".lora_B.weight")
+                }
+
+        logger.info(f"Found {len(lora_pairs)} LoRA pairs to merge")
+
+        progress(0.65, desc=f"Merging {len(lora_pairs)} LoRA layers...")
+
+        # Get model state dict
+        model_state_dict = model.state_dict()
+        merged_count = 0
+
+        for adapter_key, lora_weights in lora_pairs.items():
+            # adapter_key: base_model.model.model.layers.0.self_attn.q_proj
+            # Need to find corresponding key in model_state_dict
+
+            # Remove 'base_model.model.' prefix
+            if adapter_key.startswith("base_model.model."):
+                search_key = adapter_key[len("base_model.model."):]
+            else:
+                search_key = adapter_key
+
+            # Find matching key in model
+            model_key = None
+            for mk in model_state_dict.keys():
+                if search_key in mk or mk.endswith(search_key.split(".")[-4:][0]):
+                    # Match by layer structure
+                    if all(part in mk for part in search_key.split(".")[-4:]):
+                        model_key = mk
+                        break
+
+            if model_key and model_key in model_state_dict:
+                lora_A = lora_weights["A"].to(model_state_dict[model_key].device)
+                lora_B = lora_weights["B"].to(model_state_dict[model_key].device)
+
+                # Merge: W_new = W_old + (lora_B @ lora_A) * scaling
+                delta_W = (lora_B @ lora_A) * scaling
+                model_state_dict[model_key] = model_state_dict[model_key] + delta_W.to(model_state_dict[model_key].dtype)
+                merged_count += 1
+
+        logger.info(f"Successfully merged {merged_count}/{len(lora_pairs)} LoRA weights")
+
+        # Load merged weights back
+        progress(0.75, desc="Loading merged weights into model...")
+        model.load_state_dict(model_state_dict, strict=False)
+
+        return model
+
     def merge_models(self, hf_token, use_8bit=False, progress=gr.Progress()):
         """Merge LoRA adapters with base model"""
         try:
@@ -118,24 +196,16 @@
             precision_desc = "bfloat16"
 
             try:
-                # For merging, use sequential device map to avoid complex key nesting
-                # This ensures consistent key names between training and merging
+                # Try loading with balanced device map to distribute evenly
                 load_kwargs = {
                     "trust_remote_code": True,
                     "low_cpu_mem_usage": True,
-                    "device_map": "sequential",  # Changed from "auto" to avoid key nesting issues
+                    "device_map": "balanced",  # Distribute layers evenly across GPUs
                     "max_memory": max_memory,
+                    "torch_dtype": torch.bfloat16,
                 }
 
-                if use_8bit:
-                    # Use 8-bit quantization for tighter memory constraints
-                    load_kwargs["load_in_8bit"] = True
-                    load_kwargs["llm_int8_enable_fp32_cpu_offload"] = True
-                    load_kwargs["llm_int8_threshold"] = 6.0
-                    logger.info("Enabling CPU offload for 8-bit quantization")
-                else:
-                    # Use bfloat16 for best quality when memory allows
-                    load_kwargs["torch_dtype"] = torch.bfloat16
+                logger.info("Loading base model with balanced device map...")
 
                 self.base_model = AutoModelForCausalLM.from_pretrained(
                     BASE_MODEL_NAME,
@@ -155,68 +225,34 @@
                 error_msg += "\n💡 **Try enabling 8-bit quantization** to reduce memory usage by ~50%."
                 raise Exception(error_msg)
 
-            # Load LoRA configuration
-            progress(0.50, desc="Loading LoRA adapters...")
-            logger.info(f"Loading LoRA adapters from: {LORA_MODEL_NAME}")
+            # Download LoRA adapters
+            progress(0.50, desc="Downloading LoRA adapters...")
+            logger.info(f"Downloading LoRA adapters from: {LORA_MODEL_NAME}")
 
-            # Check if LoRA model exists and is accessible
-            try:
-                from huggingface_hub import repo_info
-                info = repo_info(LORA_MODEL_NAME, token=hf_token)
-                logger.info(f"LoRA model found: {info}")
-            except Exception as e:
-                logger.warning(f"Could not verify LoRA model: {str(e)}")
+            # Download entire adapter folder
+            adapter_path = snapshot_download(
+                repo_id=LORA_MODEL_NAME,
+                token=hf_token,
+                allow_patterns=["adapter_*", "*.json"]
+            )
+            logger.info(f"LoRA adapters downloaded to: {adapter_path}")
+
+            # Use manual merge to avoid PEFT key naming issues
+            progress(0.55, desc="Merging LoRA weights (manual merge)...")
+            logger.info("Using manual LoRA merge to avoid key naming conflicts with PEFT")
 
-            # Load LoRA adapters with additional parameters
             try:
-                logger.info("Attempting to load LoRA adapters...")
-                logger.info(f"LoRA targets attention layers: q_proj, k_proj, v_proj, o_proj")
-
-                # Load PEFT model - this wraps the base model
-                peft_model = PeftModel.from_pretrained(
-                    self.base_model,
-                    LORA_MODEL_NAME,
-                    torch_dtype=torch.bfloat16 if not use_8bit else None,
-                    is_trainable=False,
-                )
-                logger.info("LoRA adapters loaded successfully")
-
-                progress(0.70, desc="Merging LoRA weights with base model...")
-                logger.info("Merging LoRA weights into base model...")
-
-                # Use merge_and_unload with explicit safe merge
-                try:
-                    self.merged_model = peft_model.merge_and_unload(safe_merge=True)
-                    logger.info("Models merged successfully with safe_merge=True")
-                except Exception as merge_error:
-                    logger.warning(f"safe_merge=True failed, trying without: {str(merge_error)}")
-                    # Fallback to regular merge
-                    self.merged_model = peft_model.merge_and_unload()
-                    logger.info("Models merged successfully")
-
-            except KeyError as e:
-                # Handle missing keys - might be an architecture mismatch
-                error_key = str(e)
-                error_msg = f"Key error when loading LoRA adapters: {error_key}\n\n"
-
-                if "block_sparse_moe" in error_key or "experts" in error_key:
-                    error_msg += "⚠️ This error is related to MoE (Mixture of Experts) layers.\n\n"
-                    error_msg += "The LoRA adapters only target attention layers (q/k/v/o_proj),\n"
-                    error_msg += "but there seems to be a key naming mismatch with the base model.\n\n"
-                    error_msg += "Possible causes:\n"
-                    error_msg += "1. The base model version has changed since training\n"
-                    error_msg += "2. Different transformers/peft library versions\n"
-                    error_msg += "3. Model was saved with different device_map than loading\n\n"
-
-                error_msg += "Please verify:\n"
-                error_msg += f"- Base model: {BASE_MODEL_NAME}\n"
-                error_msg += f"- LoRA model: {LORA_MODEL_NAME}\n"
-                error_msg += "- Both use the same transformers version\n"
-                logger.error(error_msg)
+                self.merged_model = self.manual_merge_lora(self.base_model, adapter_path, progress)
+                logger.info("LoRA weights merged successfully using manual method")
+
+            except Exception as merge_error:
+                logger.error(f"Manual merge failed: {str(merge_error)}", exc_info=True)
+                error_msg = f"Failed to merge LoRA adapters: {str(merge_error)}\n\n"
+                error_msg += "This could be due to:\n"
+                error_msg += "1. Incompatible model architectures\n"
+                error_msg += "2. Corrupted adapter files\n"
+                error_msg += "3. Memory issues during merge\n"
                 raise Exception(error_msg)
-            except Exception as e:
-                logger.error(f"Unexpected error during merge: {str(e)}", exc_info=True)
-                raise
 
             # Save merged model
             progress(0.85, desc="Saving merged model...")
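
The fragile part of the change above is mapping PEFT adapter key names back onto the base model's state_dict keys. A small standalone illustration of the remapping manual_merge_lora performs (the example key is a typical PEFT name for attention-only LoRA and is assumed here, not read from this repo's adapter file):

```python
# Illustration of the key remapping used by manual_merge_lora in app.py.
# The adapter key below is a typical PEFT-generated name; the exact keys in this
# repo's adapter_model.safetensors are an assumption for this sketch.
adapter_key = "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight"

base_key = adapter_key.replace(".lora_A.weight", "")   # group key for the A/B pair
search_key = base_key[len("base_model.model."):]        # strip PEFT's wrapper prefix
print(search_key)  # model.layers.0.self_attn.q_proj

# A base-model state_dict key like the one below then matches via the
# substring test the diff uses (`search_key in mk`):
state_dict_key = "model.layers.0.self_attn.q_proj.weight"
print(search_key in state_dict_key)  # True
```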
merge_script.py ADDED
@@ -0,0 +1,140 @@
+"""
+Manual LoRA merging script that handles key naming issues
+"""
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from safetensors.torch import load_file, save_file
+import os
+import argparse
+from tqdm import tqdm
+
+def merge_lora_weights(
+    base_model_name,
+    adapter_path,
+    output_path,
+    device_map="auto"
+):
+    """Manually merge LoRA weights into base model"""
+
+    print(f"Loading base model: {base_model_name}")
+    model = AutoModelForCausalLM.from_pretrained(
+        base_model_name,
+        torch_dtype=torch.bfloat16,
+        device_map=device_map,
+        trust_remote_code=True,
+        low_cpu_mem_usage=True,
+    )
+
+    print(f"Loading LoRA adapters from: {adapter_path}")
+    adapter_weights = load_file(os.path.join(adapter_path, "adapter_model.safetensors"))
+
+    print(f"Loaded {len(adapter_weights)} adapter weights")
+
+    # Load adapter config to get scaling factor
+    import json
+    with open(os.path.join(adapter_path, "adapter_config.json")) as f:
+        adapter_config = json.load(f)
+
+    lora_alpha = adapter_config["lora_alpha"]
+    r = adapter_config["r"]
+    scaling = lora_alpha / r
+
+    print(f"LoRA scaling factor: {scaling} (alpha={lora_alpha}, r={r})")
+
+    # Group LoRA weights by layer
+    lora_pairs = {}
+    for key in adapter_weights.keys():
+        if "lora_A" in key:
+            base_key = key.replace(".lora_A.weight", "")
+            lora_pairs[base_key] = {
+                "A": adapter_weights[key],
+                "B": adapter_weights.get(base_key + ".lora_B.weight")
+            }
+
+    print(f"Found {len(lora_pairs)} LoRA pairs to merge")
+
+    # Get model state dict
+    model_state_dict = model.state_dict()
+
+    # Map adapter keys to model keys
+    # Adapter keys: base_model.model.model.layers.X.self_attn.q_proj
+    # Model keys might be: model.layers.X.self_attn.q_proj (depending on device_map)
+
+    print("\nMerging LoRA weights...")
+    merged_count = 0
+
+    for adapter_key, lora_weights in tqdm(lora_pairs.items()):
+        # Remove 'base_model.model.' prefix from adapter key
+        # adapter_key looks like: base_model.model.model.layers.0.self_attn.q_proj
+        if adapter_key.startswith("base_model.model."):
+            model_key = adapter_key[len("base_model.model."):]
+        else:
+            model_key = adapter_key
+
+        # Try to find the matching key in model
+        found = False
+        for mk in model_state_dict.keys():
+            if model_key in mk or mk.endswith(model_key):
+                model_key = mk
+                found = True
+                break
+
+        if not found:
+            # Try alternative key formats
+            alternatives = [
+                model_key,
+                "model." + model_key,
+                model_key.replace("model.", ""),
+            ]
+
+            for alt_key in alternatives:
+                if alt_key in model_state_dict:
+                    model_key = alt_key
+                    found = True
+                    break
+
+        if found and model_key in model_state_dict:
+            # Merge: W' = W + (B @ A) * scaling
+            lora_A = lora_weights["A"]
+            lora_B = lora_weights["B"]
+
+            # Move to same device as model weight
+            device = model_state_dict[model_key].device
+            lora_A = lora_A.to(device)
+            lora_B = lora_B.to(device)
+
+            # Compute delta_W = (lora_B @ lora_A) * scaling
+            delta_W = (lora_B @ lora_A) * scaling
+
+            # Add to original weight
+            model_state_dict[model_key] = model_state_dict[model_key] + delta_W.to(model_state_dict[model_key].dtype)
+            merged_count += 1
+        else:
+            print(f"Warning: Could not find model key for {adapter_key}")
+
+    print(f"\nSuccessfully merged {merged_count}/{len(lora_pairs)} LoRA weights")
+
+    # Load merged weights back into model
+    model.load_state_dict(model_state_dict, strict=False)
+
+    # Save merged model
+    print(f"\nSaving merged model to: {output_path}")
+    os.makedirs(output_path, exist_ok=True)
+    model.save_pretrained(output_path, safe_serialization=True, max_shard_size="5GB")
+
+    # Also save tokenizer
+    print("Saving tokenizer...")
+    tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
+    tokenizer.save_pretrained(output_path)
+
+    print("\n✅ Merge complete!")
+    return model
+
+if __name__ == "__main__":
+    # For use in the Space
+    BASE_MODEL = "moonshotai/Kimi-Linear-48B-A3B-Instruct"
+    ADAPTER_PATH = "/app/lora_adapters"  # We'll download here
+    OUTPUT_PATH = "/app/merged_model"
+
+    merge_lora_weights(BASE_MODEL, ADAPTER_PATH, OUTPUT_PATH)
+
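
A hedged usage sketch for the new script: merge_script.py hardcodes its paths under `__main__`, so driving `merge_lora_weights` directly would look roughly like the following. The adapter repo id and output path are placeholders, and fetching the adapter with `snapshot_download` mirrors what app.py does rather than anything the script itself provides:

```python
# Hypothetical driver for merge_script.py outside the Space.
# The adapter repo id and output path below are placeholders, not values from the commit.
from huggingface_hub import snapshot_download
from merge_script import merge_lora_weights

# Fetch the adapter files locally, same allow_patterns as app.py requests.
adapter_path = snapshot_download(
    repo_id="your-username/your-lora-adapter",   # placeholder adapter repo
    allow_patterns=["adapter_*", "*.json"],
)

merge_lora_weights(
    base_model_name="moonshotai/Kimi-Linear-48B-A3B-Instruct",
    adapter_path=adapter_path,
    output_path="./merged_model",   # placeholder output directory
    device_map="auto",
)
```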