Spaces:
Paused
Paused
aeb56
committed on
Commit
·
79334bc
1
Parent(s):
1a04e17
Add safe_merge and better error handling for LoRA merge with MoE models
Browse files
app.py
CHANGED
|
@@ -158,18 +158,64 @@ class ModelMerger:
|
|
| 158 |
progress(0.50, desc="Loading LoRA adapters...")
|
| 159 |
logger.info(f"Loading LoRA adapters from: {LORA_MODEL_NAME}")
|
| 160 |
|
| 161 |
-
#
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
LORA_MODEL_NAME,
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
|
| 174 |
# Save merged model
|
| 175 |
progress(0.85, desc="Saving merged model...")
|
|
|
|
| 158 |
progress(0.50, desc="Loading LoRA adapters...")
|
| 159 |
logger.info(f"Loading LoRA adapters from: {LORA_MODEL_NAME}")
|
| 160 |
|
| 161 |
+
# Check if LoRA model exists and is accessible
|
| 162 |
+
try:
|
| 163 |
+
from huggingface_hub import repo_info
|
| 164 |
+
info = repo_info(LORA_MODEL_NAME, token=hf_token)
|
| 165 |
+
logger.info(f"LoRA model found: {info}")
|
| 166 |
+
except Exception as e:
|
| 167 |
+
logger.warning(f"Could not verify LoRA model: {str(e)}")
|
| 168 |
|
| 169 |
+
# Load LoRA adapters with additional parameters
|
| 170 |
+
try:
|
| 171 |
+
logger.info("Attempting to load LoRA adapters...")
|
| 172 |
+
logger.info(f"LoRA targets attention layers: q_proj, k_proj, v_proj, o_proj")
|
| 173 |
+
|
| 174 |
+
# Load PEFT model - this wraps the base model
|
| 175 |
+
peft_model = PeftModel.from_pretrained(
|
| 176 |
+
self.base_model,
|
| 177 |
+
LORA_MODEL_NAME,
|
| 178 |
+
torch_dtype=torch.bfloat16 if not use_8bit else None,
|
| 179 |
+
is_trainable=False,
|
| 180 |
+
)
|
| 181 |
+
logger.info("LoRA adapters loaded successfully")
|
| 182 |
+
|
| 183 |
+
progress(0.70, desc="Merging LoRA weights with base model...")
|
| 184 |
+
logger.info("Merging LoRA weights into base model...")
|
| 185 |
+
|
| 186 |
+
# Use merge_and_unload with explicit safe merge
|
| 187 |
+
try:
|
| 188 |
+
self.merged_model = peft_model.merge_and_unload(safe_merge=True)
|
| 189 |
+
logger.info("Models merged successfully with safe_merge=True")
|
| 190 |
+
except Exception as merge_error:
|
| 191 |
+
logger.warning(f"safe_merge=True failed, trying without: {str(merge_error)}")
|
| 192 |
+
# Fallback to regular merge
|
| 193 |
+
self.merged_model = peft_model.merge_and_unload()
|
| 194 |
+
logger.info("Models merged successfully")
|
| 195 |
+
|
| 196 |
+
except KeyError as e:
|
| 197 |
+
# Handle missing keys - might be an architecture mismatch
|
| 198 |
+
error_key = str(e)
|
| 199 |
+
error_msg = f"Key error when loading LoRA adapters: {error_key}\n\n"
|
| 200 |
+
|
| 201 |
+
if "block_sparse_moe" in error_key or "experts" in error_key:
|
| 202 |
+
error_msg += "⚠️ This error is related to MoE (Mixture of Experts) layers.\n\n"
|
| 203 |
+
error_msg += "The LoRA adapters only target attention layers (q/k/v/o_proj),\n"
|
| 204 |
+
error_msg += "but there seems to be a key naming mismatch with the base model.\n\n"
|
| 205 |
+
error_msg += "Possible causes:\n"
|
| 206 |
+
error_msg += "1. The base model version has changed since training\n"
|
| 207 |
+
error_msg += "2. Different transformers/peft library versions\n"
|
| 208 |
+
error_msg += "3. Model was saved with different device_map than loading\n\n"
|
| 209 |
+
|
| 210 |
+
error_msg += "Please verify:\n"
|
| 211 |
+
error_msg += f"- Base model: {BASE_MODEL_NAME}\n"
|
| 212 |
+
error_msg += f"- LoRA model: {LORA_MODEL_NAME}\n"
|
| 213 |
+
error_msg += "- Both use the same transformers version\n"
|
| 214 |
+
logger.error(error_msg)
|
| 215 |
+
raise Exception(error_msg)
|
| 216 |
+
except Exception as e:
|
| 217 |
+
logger.error(f"Unexpected error during merge: {str(e)}", exc_info=True)
|
| 218 |
+
raise
|
| 219 |
|
| 220 |
# Save merged model
|
| 221 |
progress(0.85, desc="Saving merged model...")
|