aeb56 committed on
Commit 79334bc · 1 Parent(s): 1a04e17

Add safe_merge and better error handling for LoRA merge with MoE models
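
For context, the accessibility pre-check this commit adds can be exercised on its own. A minimal sketch, assuming a hypothetical adapter repo id ("my-org/my-lora-adapter") and an optional read token; repo_info() raises if the repo cannot be reached:

# Minimal sketch of the pre-flight repo check; the repo id is a placeholder.
from huggingface_hub import repo_info
from huggingface_hub.utils import RepositoryNotFoundError

LORA_MODEL_NAME = "my-org/my-lora-adapter"  # hypothetical adapter repo

try:
    info = repo_info(LORA_MODEL_NAME, token=None)  # pass a token for private repos
    print(f"Adapter repo reachable, revision sha: {info.sha}")
except RepositoryNotFoundError:
    print("Adapter repo does not exist or is not accessible with this token")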

Files changed (1)
  1. app.py +57 -11
app.py CHANGED
@@ -158,18 +158,64 @@ class ModelMerger:
         progress(0.50, desc="Loading LoRA adapters...")
         logger.info(f"Loading LoRA adapters from: {LORA_MODEL_NAME}")
 
-        # Merge LoRA weights
-        self.merged_model = PeftModel.from_pretrained(
-            self.base_model,
-            LORA_MODEL_NAME,
-            torch_dtype=torch.bfloat16,
-        )
-        logger.info("LoRA adapters loaded successfully")
+        # Check if LoRA model exists and is accessible
+        try:
+            from huggingface_hub import repo_info
+            info = repo_info(LORA_MODEL_NAME, token=hf_token)
+            logger.info(f"LoRA model found: {info}")
+        except Exception as e:
+            logger.warning(f"Could not verify LoRA model: {str(e)}")
 
-        progress(0.70, desc="Merging LoRA weights with base model...")
-        logger.info("Merging LoRA weights...")
-        self.merged_model = self.merged_model.merge_and_unload()
-        logger.info("Models merged successfully")
+        # Load LoRA adapters with additional parameters
+        try:
+            logger.info("Attempting to load LoRA adapters...")
+            logger.info("LoRA targets attention layers: q_proj, k_proj, v_proj, o_proj")
+
+            # Load PEFT model - this wraps the base model
+            peft_model = PeftModel.from_pretrained(
+                self.base_model,
+                LORA_MODEL_NAME,
+                torch_dtype=torch.bfloat16 if not use_8bit else None,
+                is_trainable=False,
+            )
+            logger.info("LoRA adapters loaded successfully")
+
+            progress(0.70, desc="Merging LoRA weights with base model...")
+            logger.info("Merging LoRA weights into base model...")
+
+            # Use merge_and_unload with explicit safe merge
+            try:
+                self.merged_model = peft_model.merge_and_unload(safe_merge=True)
+                logger.info("Models merged successfully with safe_merge=True")
+            except Exception as merge_error:
+                logger.warning(f"safe_merge=True failed, trying without: {str(merge_error)}")
+                # Fall back to a regular merge
+                self.merged_model = peft_model.merge_and_unload()
+                logger.info("Models merged successfully")
+
+        except KeyError as e:
+            # Handle missing keys - might be an architecture mismatch
+            error_key = str(e)
+            error_msg = f"Key error when loading LoRA adapters: {error_key}\n\n"
+
+            if "block_sparse_moe" in error_key or "experts" in error_key:
+                error_msg += "⚠️ This error is related to MoE (Mixture of Experts) layers.\n\n"
+                error_msg += "The LoRA adapters only target attention layers (q/k/v/o_proj),\n"
+                error_msg += "but there seems to be a key naming mismatch with the base model.\n\n"
+                error_msg += "Possible causes:\n"
+                error_msg += "1. The base model version has changed since training\n"
+                error_msg += "2. Different transformers/peft library versions\n"
+                error_msg += "3. Model was saved with a different device_map than the one used for loading\n\n"
+
+            error_msg += "Please verify:\n"
+            error_msg += f"- Base model: {BASE_MODEL_NAME}\n"
+            error_msg += f"- LoRA model: {LORA_MODEL_NAME}\n"
+            error_msg += "- Both use the same transformers version\n"
+            logger.error(error_msg)
+            raise Exception(error_msg)
+        except Exception as e:
+            logger.error(f"Unexpected error during merge: {str(e)}", exc_info=True)
+            raise
 
         # Save merged model
         progress(0.85, desc="Saving merged model...")
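
Outside this app, the same safe-merge-with-fallback pattern can be reproduced standalone. A minimal sketch, assuming placeholder repo ids and enough memory to load the model in bfloat16; safe_merge=True asks PEFT to check the merged tensors for NaNs and raise instead of silently producing corrupt weights:

# Minimal sketch of the merge pattern above; repo ids are placeholders.
import torch
from transformers import AutoModelForCausalLM
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained(
    "org/base-model",                       # placeholder base model id
    torch_dtype=torch.bfloat16,
)
peft_model = PeftModel.from_pretrained(base, "my-org/my-lora-adapter", is_trainable=False)

try:
    # safe_merge=True validates the merged weights and raises on NaNs
    merged = peft_model.merge_and_unload(safe_merge=True)
except Exception:
    merged = peft_model.merge_and_unload()  # plain merge as a fallback

merged.save_pretrained("./merged-model")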