Add extraction script: sanitize_lora.py
Browse files
extraction_scripts/sanitize_lora.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from safetensors.torch import load_file, save_file
|
| 3 |
+
import os
|
| 4 |
+
import json
|
| 5 |
+
import shutil
# ==============================================================================
# CONFIGURATION
# ==============================================================================
# UPDATE THIS PATH to the folder containing your adapter_model.safetensors
# (the directory must also hold the matching adapter_config.json produced by
# the LoRA extraction step; both files are edited in place below).
LORA_PATH = "/projects/extern/kisski/kisski-narges-llm-interactive/dir.project/hasan/uni_work/biomni_integration/Biomni/brain_surgery/lora_extraction_results/dequantized_corrected_lora_rank_256"
# ==============================================================================
# 1. SETUP & BACKUP
# ==============================================================================
# Resolve the adapter artifact paths and keep one-time ".original.bak" copies
# so every destructive edit performed below stays reversible.
adapter_file = os.path.join(LORA_PATH, "adapter_model.safetensors")
config_file = os.path.join(LORA_PATH, "adapter_config.json")
backup_adapter = adapter_file + ".original.bak"
backup_config = config_file + ".original.bak"

print(f"🔧 STARTING REPAIR ON: {LORA_PATH}")

if not os.path.exists(adapter_file):
    print(f"❌ Error: File not found: {adapter_file}")
    # raise SystemExit instead of the site-injected exit() helper, which is
    # not guaranteed to exist when running with `python -S` or in frozen apps.
    raise SystemExit(1)

# Create backups only if they don't exist yet, so a re-run of this script
# never overwrites the pristine originals with already-sanitized files.
if not os.path.exists(backup_adapter):
    print("📦 Creating backup of safetensors file...")
    shutil.copy2(adapter_file, backup_adapter)

if os.path.exists(config_file) and not os.path.exists(backup_config):
    print("📦 Creating backup of config file...")
    shutil.copy2(config_file, backup_config)
# ==============================================================================
# 2. PRUNE UNSUPPORTED LAYERS (Weights)
# ==============================================================================
print("\n🔍 Scanning weights for vLLM incompatibility...")
try:
    tensors = load_file(adapter_file)
except Exception as e:
    print(f"❌ Critical Error: Could not load safetensors file. It might be corrupt. {e}")
    # SystemExit instead of the site-injected exit() helper (see section 1).
    raise SystemExit(1)

new_tensors = {}
removed_keys = []

# vLLM only supports LoRA on Linear layers (q,k,v,o,gate,up,down).
# Anything else causes a crash on load.
# NOTE: matching is by substring, so "norm" already covers "layernorm" and
# rms_norm-style names; the explicit "layernorm" entry is kept for clarity.
FORBIDDEN_KEYWORDS = [
    "lm_head",       # The output vocabulary layer
    "embed_tokens",  # The input embedding layer
    "layernorm",     # Normalization layers
    "norm",          # Generic normalization (rms_norm)
    "bias",          # Biases (usually not supported in standard vLLM LoRA)
    "rotary_emb",    # RoPE embeddings
]

# Split the checkpoint into keepers and offenders.
for key, tensor in tensors.items():
    if any(bad_word in key for bad_word in FORBIDDEN_KEYWORDS):
        removed_keys.append(key)
    else:
        new_tensors[key] = tensor

if removed_keys:
    print(f"✂️ Found {len(removed_keys)} unsupported layers.")
    print(f" (Examples: {removed_keys[:3]} ...)")
    print(" Pruning them now...")
    # Overwrite in place; the ".original.bak" copy from section 1 is the undo.
    save_file(new_tensors, adapter_file)
    print("✅ Weights file updated and saved.")
else:
    print("✅ Weights file was already clean.")
# ==============================================================================
# 3. FIX CONFIGURATION (JSON)
# ==============================================================================
print("\n🔍 Checking adapter_config.json...")

if os.path.exists(config_file):
    # Explicit UTF-8: JSON is UTF-8 by spec, and relying on the platform
    # default encoding can corrupt non-ASCII config values on some systems.
    with open(config_file, 'r', encoding='utf-8') as f:
        config = json.load(f)

    changed = False

    # Fix 1: modules_to_save must be null — vLLM cannot load full (non-LoRA)
    # replacement modules alongside the adapter.
    if config.get("modules_to_save") is not None:
        print(" - Setting 'modules_to_save' to null (was set)")
        config["modules_to_save"] = None
        changed = True

    # Fix 2: Ensure target modules list is clean (optional but good practice)
    # Sometimes extractors put 'lm_head' in target_modules too
    if "target_modules" in config and isinstance(config["target_modules"], list):
        original_len = len(config["target_modules"])
        config["target_modules"] = [
            m for m in config["target_modules"]
            if not any(bad in m for bad in ["lm_head", "embed_tokens", "norm"])
        ]
        if len(config["target_modules"]) < original_len:
            print(" - Cleaned 'target_modules' list")
            changed = True

    # Only rewrite the file when something actually changed, keeping mtimes
    # stable for no-op runs.
    if changed:
        with open(config_file, 'w', encoding='utf-8') as f:
            json.dump(config, f, indent=2)
        print("✅ Config file updated.")
    else:
        print("✅ Config file was already correct.")
else:
    print("⚠️ Warning: adapter_config.json not found!")
# ==============================================================================
# 4. FINAL VERIFICATION
# ==============================================================================
print("\n----- VERIFICATION -----")
try:
    # 1. Check file size
    size_mb = os.path.getsize(adapter_file) / (1024 * 1024)
    print(f"File Size: {size_mb:.2f} MB")

    # 2. Check Loadability — reload the pruned file to prove it is not corrupt.
    test_load = load_file(adapter_file)
    print(f"Keys Remaining: {len(test_load)}")

    # 3. Check for stragglers (a dict iterates its keys; .keys() is redundant).
    stragglers = [k for k in test_load if "lm_head" in k or "norm" in k]
    if stragglers:
        print(f"❌ FAILURE: Still found bad keys: {stragglers}")
    else:
        print("🎉 SUCCESS: LoRA is clean and vLLM-ready.")

except Exception as e:
    # Top-level boundary: report any load/stat failure instead of crashing.
    print(f"❌ FAILURE: File seems corrupted: {e}")

print("==============================================================================")
print("You can now submit your SBATCH script.")