hassanshka committed on
Commit
33b997f
·
verified ·
1 Parent(s): 6954fd0

Add extraction script: sanitize_lora.py

Browse files
Files changed (1) hide show
  1. extraction_scripts/sanitize_lora.py +142 -0
extraction_scripts/sanitize_lora.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Sanitize an extracted LoRA adapter so vLLM can load it.

Pipeline (backups are taken once and never overwritten):
  1. Back up adapter_model.safetensors and adapter_config.json.
  2. Prune weight tensors vLLM cannot serve (lm_head, embeddings,
     norms, biases, rotary embeddings).
  3. Fix adapter_config.json (modules_to_save -> null, clean
     target_modules of the same forbidden names).
  4. Verify the pruned file still loads and contains no stragglers.

Usage:
    python sanitize_lora.py [LORA_PATH]

If LORA_PATH is not given on the command line, the default below is used.
"""
import json
import os
import shutil
import sys

import torch  # noqa: F401 -- safetensors.torch requires torch to be importable
from safetensors.torch import load_file, save_file

# ==============================================================================
# CONFIGURATION
# ==============================================================================
# Default folder containing adapter_model.safetensors; can be overridden by
# passing a path as the first command-line argument.
LORA_PATH = "/projects/extern/kisski/kisski-narges-llm-interactive/dir.project/hasan/uni_work/biomni_integration/Biomni/brain_surgery/lora_extraction_results/dequantized_corrected_lora_rank_256"

# vLLM only supports LoRA on Linear layers (q,k,v,o,gate,up,down).
# Anything else causes a crash on load.
# NOTE(review): "norm" already matches "layernorm"; both entries are kept
# for explicitness -- the match result is identical.
FORBIDDEN_KEYWORDS = [
    "lm_head",       # The output vocabulary layer
    "embed_tokens",  # The input embedding layer
    "layernorm",     # Normalization layers
    "norm",          # Generic normalization (rms_norm)
    "bias",          # Biases (usually not supported in standard vLLM LoRA)
    "rotary_emb",    # RoPE embeddings
]


def _backup_files(adapter_file, config_file):
    """Create one-time *.original.bak copies of the adapter and config.

    Existing backups are left untouched so repeated runs never clobber the
    true originals.
    """
    backup_adapter = adapter_file + ".original.bak"
    backup_config = config_file + ".original.bak"

    if not os.path.exists(backup_adapter):
        print("📦 Creating backup of safetensors file...")
        shutil.copy2(adapter_file, backup_adapter)

    if os.path.exists(config_file) and not os.path.exists(backup_config):
        print("📦 Creating backup of config file...")
        shutil.copy2(config_file, backup_config)


def _prune_weights(adapter_file):
    """Drop tensors whose key contains any forbidden keyword; save in place.

    Exits the process if the safetensors file cannot be loaded at all.
    """
    print("\n🔍 Scanning weights for vLLM incompatibility...")
    try:
        tensors = load_file(adapter_file)
    except Exception as e:  # corrupt/truncated file -- nothing more we can do
        print(f"❌ Critical Error: Could not load safetensors file. It might be corrupt. {e}")
        sys.exit(1)

    new_tensors = {}
    removed_keys = []
    for key, tensor in tensors.items():
        if any(bad_word in key for bad_word in FORBIDDEN_KEYWORDS):
            removed_keys.append(key)
        else:
            new_tensors[key] = tensor

    if removed_keys:
        print(f"✂️ Found {len(removed_keys)} unsupported layers.")
        print(f" (Examples: {removed_keys[:3]} ...)")
        print(" Pruning them now...")
        save_file(new_tensors, adapter_file)
        print("✅ Weights file updated and saved.")
    else:
        print("✅ Weights file was already clean.")


def _fix_config(config_file):
    """Make adapter_config.json consistent with the pruned weights.

    Sets modules_to_save to null and strips forbidden names out of
    target_modules; rewrites the file only when something changed.
    """
    print("\n🔍 Checking adapter_config.json...")

    if not os.path.exists(config_file):
        print("⚠️ Warning: adapter_config.json not found!")
        return

    with open(config_file, 'r') as f:
        config = json.load(f)

    changed = False

    # Fix 1: modules_to_save must be null
    if config.get("modules_to_save") is not None:
        print(" - Setting 'modules_to_save' to null (was set)")
        config["modules_to_save"] = None
        changed = True

    # Fix 2: Ensure target modules list is clean (optional but good practice)
    # Sometimes extractors put 'lm_head' in target_modules too
    if "target_modules" in config and isinstance(config["target_modules"], list):
        original_len = len(config["target_modules"])
        config["target_modules"] = [
            m for m in config["target_modules"]
            if not any(bad in m for bad in ["lm_head", "embed_tokens", "norm"])
        ]
        if len(config["target_modules"]) < original_len:
            print(" - Cleaned 'target_modules' list")
            changed = True

    if changed:
        with open(config_file, 'w') as f:
            json.dump(config, f, indent=2)
        print("✅ Config file updated.")
    else:
        print("✅ Config file was already correct.")


def _verify(adapter_file):
    """Re-load the pruned file and confirm no forbidden keys remain."""
    print("\n----- VERIFICATION -----")
    try:
        # 1. Check file size
        size_mb = os.path.getsize(adapter_file) / (1024 * 1024)
        print(f"File Size: {size_mb:.2f} MB")

        # 2. Check Loadability
        test_load = load_file(adapter_file)
        print(f"Keys Remaining: {len(test_load)}")

        # 3. Check for stragglers
        stragglers = [k for k in test_load.keys() if "lm_head" in k or "norm" in k]
        if stragglers:
            print(f"❌ FAILURE: Still found bad keys: {stragglers}")
        else:
            print("🏆 SUCCESS: LoRA is clean and vLLM-ready.")
    except Exception as e:
        print(f"❌ FAILURE: File seems corrupted: {e}")


def main():
    """Run the full sanitize pipeline on argv[1] or the default LORA_PATH."""
    lora_path = sys.argv[1] if len(sys.argv) > 1 else LORA_PATH
    adapter_file = os.path.join(lora_path, "adapter_model.safetensors")
    config_file = os.path.join(lora_path, "adapter_config.json")

    print(f"🔧 STARTING REPAIR ON: {lora_path}")

    if not os.path.exists(adapter_file):
        print(f"❌ Error: File not found: {adapter_file}")
        # sys.exit, not the interactive-only exit() builtin
        sys.exit(1)

    _backup_files(adapter_file, config_file)
    _prune_weights(adapter_file)
    _fix_config(config_file)
    _verify(adapter_file)

    print("==============================================================================")
    print("You can now submit your SBATCH script.")


if __name__ == "__main__":
    main()