"""Repair a dequantized LoRA adapter so it loads under vLLM.

vLLM's LoRA loader rejects adapters that carry full-rank extra modules
(lm_head, embed_tokens, norms, biases, rotary embeddings). This script
prunes those tensors from adapter_model.safetensors and cleans
adapter_config.json accordingly, keeping .original.bak backups.
"""

import torch  # noqa: F401 -- imported by the original script; safetensors.torch needs torch available
from safetensors.torch import load_file, save_file
import os
import json
import shutil

# Directory containing the LoRA adapter to repair (adapter_model.safetensors
# + adapter_config.json).
LORA_PATH = "/projects/extern/kisski/kisski-narges-llm-interactive/dir.project/hasan/uni_work/biomni_integration/Biomni/brain_surgery/lora_extraction_results/dequantized_corrected_lora_rank_256"

adapter_file = os.path.join(LORA_PATH, "adapter_model.safetensors")
config_file = os.path.join(LORA_PATH, "adapter_config.json")
backup_adapter = adapter_file + ".original.bak"
backup_config = config_file + ".original.bak"

print(f"🔧 STARTING REPAIR ON: {LORA_PATH}")

# Bail out early if there is nothing to repair.
if not os.path.exists(adapter_file):
    print(f"❌ Error: File not found: {adapter_file}")
    # raise SystemExit rather than exit(): works even without the site module.
    raise SystemExit(1)
# --- Backups (idempotent: only taken once, so re-runs never overwrite the
# --- pristine originals with already-pruned files) ---
if not os.path.exists(backup_adapter):
    print("📦 Creating backup of safetensors file...")
    shutil.copy2(adapter_file, backup_adapter)  # copy2 preserves metadata/mtime

if os.path.exists(config_file) and not os.path.exists(backup_config):
    print("📦 Creating backup of config file...")
    shutil.copy2(config_file, backup_config)
# --- Scan the adapter weights and drop layers vLLM's LoRA loader rejects ---
print("\n🔍 Scanning weights for vLLM incompatibility...")
try:
    tensors = load_file(adapter_file)
except Exception as e:
    print(f"❌ Critical Error: Could not load safetensors file. It might be corrupt. {e}")
    raise SystemExit(1)

# Any tensor whose key contains one of these substrings is pruned.
# NOTE: "norm" already matches every "layernorm" key; both are listed
# for explicitness, the duplication is harmless.
FORBIDDEN_KEYWORDS = [
    "lm_head",
    "embed_tokens",
    "layernorm",
    "norm",
    "bias",
    "rotary_emb",
]

new_tensors = {}   # keys/tensors that survive the filter
removed_keys = []  # keys that were pruned (for reporting)

for key, tensor in tensors.items():
    if any(bad_word in key for bad_word in FORBIDDEN_KEYWORDS):
        removed_keys.append(key)
    else:
        new_tensors[key] = tensor
# --- Persist the pruned weights (only rewrite the file if something changed) ---
if removed_keys:
    print(f"✂️ Found {len(removed_keys)} unsupported layers.")
    print(f" (Examples: {removed_keys[:3]} ...)")
    print(" Pruning them now...")
    save_file(new_tensors, adapter_file)
    print("✅ Weights file updated and saved.")
else:
    print("✅ Weights file was already clean.")
# --- Bring adapter_config.json in line with the pruned weights ---
print("\n🔍 Checking adapter_config.json...")

if os.path.exists(config_file):
    with open(config_file, 'r') as f:
        config = json.load(f)

    changed = False  # only rewrite the file if we actually modify it

    # "modules_to_save" asks PEFT to keep full (non-LoRA) copies of modules;
    # vLLM cannot serve those, so force it to null.
    if config.get("modules_to_save") is not None:
        print(" - Setting 'modules_to_save' to null (was set)")
        config["modules_to_save"] = None
        changed = True

    # Remove target modules that correspond to the tensors pruned above.
    if "target_modules" in config and isinstance(config["target_modules"], list):
        original_len = len(config["target_modules"])
        config["target_modules"] = [
            m for m in config["target_modules"]
            if not any(bad in m for bad in ["lm_head", "embed_tokens", "norm"])
        ]
        if len(config["target_modules"]) < original_len:
            print(" - Cleaned 'target_modules' list")
            changed = True

    if changed:
        with open(config_file, 'w') as f:
            json.dump(config, f, indent=2)
        print("✅ Config file updated.")
    else:
        print("✅ Config file was already correct.")
else:
    print("⚠️ Warning: adapter_config.json not found!")
# --- Verification: reload the repaired file and confirm no bad keys remain ---
print("\n----- VERIFICATION -----")
try:
    # Report the new on-disk size (pruning should have shrunk the file).
    size_mb = os.path.getsize(adapter_file) / (1024 * 1024)
    print(f"File Size: {size_mb:.2f} MB")

    # Reload to prove the file is still a valid safetensors archive.
    test_load = load_file(adapter_file)
    print(f"Keys Remaining: {len(test_load)}")

    # Spot-check the two most problematic key families.
    stragglers = [k for k in test_load.keys() if "lm_head" in k or "norm" in k]
    if stragglers:
        print(f"❌ FAILURE: Still found bad keys: {stragglers}")
    else:
        print("🎉 SUCCESS: LoRA is clean and vLLM-ready.")

except Exception as e:
    print(f"❌ FAILURE: File seems corrupted: {e}")

print("==============================================================================")
print("You can now submit your SBATCH script.")