Upload 141 files
Browse files- hugging/patch_gpu.py +1039 -0
- hugging/td_fuse/config.py +1 -1
- hugging/td_fuse/heal.py +45 -28
- hugging/td_fuse/merge.py +3 -0
- hugging/td_fuse/transport.py +189 -12
- hugging/td_start.td +1 -1
hugging/patch_gpu.py
ADDED
|
@@ -0,0 +1,1039 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
GPU Patch Script — Apply neuron permutation fix + lower MiMo alpha.
|
| 3 |
+
Run this ON THE GPU after cd /workspace/td_toolkit/hugging:
|
| 4 |
+
python3 patch_gpu.py
|
| 5 |
+
|
| 6 |
+
What it does:
|
| 7 |
+
1. Adds neuron permutation to transport.py fast path
|
| 8 |
+
2. Adds _greedy_permutation() and _apply_permutation() helpers
|
| 9 |
+
3. Updates fuse_weights() to apply permutations before blending
|
| 10 |
+
4. Lowers MiMo alpha from 0.4 to 0.15 in config.py
|
| 11 |
+
5. Lowers MiMo strength from 0.4 to 0.15 in td_start.td
|
| 12 |
+
6. Adds torch import fix to heal.py (Bug #41)
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
|
| 17 |
+
def patch_file(filepath, old, new):
    """Replace every occurrence of ``old`` with ``new`` in a text file.

    The file is read and rewritten in place as UTF-8 (the default encoding
    of Python source files), so the result does not depend on the locale of
    the machine the script runs on.

    Args:
        filepath: Path of the file to edit in place.
        old: Exact text to search for.
        new: Text that replaces each occurrence of ``old``.

    Returns:
        True if the file was rewritten, or if ``old`` is absent but ``new``
        is already present (treated as "applied by an earlier run").
        False if neither ``old`` nor ``new`` can be found; the file is left
        untouched in that case.

    Raises:
        OSError: Propagated from ``open`` when the file cannot be read or
            written (for example, when it does not exist).
    """
    with open(filepath, 'r', encoding='utf-8') as f:
        content = f.read()

    if old not in content:
        # Re-running the script must not look like a failure: if the
        # replacement text is already there, the patch has been applied.
        # The ``new and`` guard avoids the empty string, which is "in"
        # every file.
        if new and new in content:
            print(f" ALREADY PATCHED: {filepath}")
            return True
        print(f" WARNING: patch target not found in {filepath}")
        print(f" Looking for: {old[:80]}...")
        return False

    content = content.replace(old, new)
    with open(filepath, 'w', encoding='utf-8') as f:
        f.write(content)
    print(f" PATCHED: {filepath}")
    return True
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def main():
    """Apply the four GPU patches and report which ones actually landed.

    Steps, all using paths relative to the current working directory:
      1. ``td_fuse/config.py``  -- ``merge_alpha=0.4,`` -> ``merge_alpha=0.15,``
      2. ``td_start.td``        -- ``strength 0.4`` -> ``strength 0.15``
      3. ``td_fuse/heal.py``    -- insert ``import torch`` after the peft
         import, unless ``import torch`` already appears within 200
         characters of ``def apply_qlora_standard``
      4. ``td_fuse/transport.py`` -- rewritten via ``write_transport_py()``

    The closing banner and the "What changed" list reflect the outcome of
    each step instead of unconditionally claiming success.

    Raises:
        OSError: Propagated when a target file cannot be opened.
    """
    banner = "=" * 60
    print(banner)
    print("TD GPU Patch — Neuron Permutation Fix")
    print(banner)

    # ================================================================
    # PATCH 1: config.py — Lower MiMo alpha
    # ================================================================
    print("\n[1/4] Patching config.py (MiMo alpha 0.4 → 0.15)...")
    config_ok = patch_file(
        "td_fuse/config.py",
        'merge_alpha=0.4,',
        'merge_alpha=0.15,',
    )

    # ================================================================
    # PATCH 2: td_start.td — Lower MiMo strength
    # ================================================================
    print("\n[2/4] Patching td_start.td (strength 0.4 → 0.15)...")
    start_ok = patch_file(
        "td_start.td",
        'strength 0.4',
        'strength 0.15',
    )

    # ================================================================
    # PATCH 3: heal.py — Add missing torch import (Bug #41)
    # ================================================================
    print("\n[3/4] Patching heal.py (torch import fix)...")
    with open("td_fuse/heal.py", 'r', encoding='utf-8') as f:
        heal_content = f.read()
    idx = heal_content.find("def apply_qlora_standard")
    if idx == -1:
        print(" WARNING: apply_qlora_standard not found in heal.py")
        heal_ok = False
    elif "import torch" in heal_content[idx:idx + 200]:
        # Only the first 200 characters of the function are inspected.
        print(" Already patched (torch import exists)")
        heal_ok = True
    else:
        # NOTE(review): the whitespace before "import torch" in the
        # replacement must match the indentation of the peft import inside
        # heal.py -- confirm against the real file.
        heal_ok = patch_file(
            "td_fuse/heal.py",
            "from peft import get_peft_model, LoraConfig, TaskType\n",
            "from peft import get_peft_model, LoraConfig, TaskType\n import torch\n",
        )

    # ================================================================
    # PATCH 4: transport.py — Full rewrite with neuron permutation
    # ================================================================
    print("\n[4/4] Rewriting transport.py with neuron permutation...")
    write_transport_py()
    print(" WROTE: td_fuse/transport.py")
    transport_ok = True  # write_transport_py() raises on failure

    # (target, succeeded, summary line shown under "What changed")
    outcomes = [
        ("td_fuse/config.py", config_ok,
         " • MiMo merge alpha: 0.4 → 0.15 (gentler blend)"),
        ("td_start.td", start_ok, None),
        ("td_fuse/transport.py", transport_ok,
         " • Neuron permutation: MiMo's neurons get reorganised to match Qwen3"),
        ("td_fuse/heal.py", heal_ok,
         " • heal.py: torch import fix (Bug #41)"),
    ]
    not_applied = [target for target, ok, _ in outcomes if not ok]

    print("\n" + banner)
    if not_applied:
        print("PATCHING INCOMPLETE — not applied: " + ", ".join(not_applied))
    else:
        print("ALL PATCHES APPLIED!")
    print(banner)

    print("\nWhat changed:")
    for _, ok, summary in outcomes:
        if ok and summary is not None:
            print(summary)

    if not_applied:
        print("\nReview the warnings above before running the pipeline.")
    print("\nNow run the pipeline:")
    print(" export PYTHONPATH=$(pwd)")
    print(" python3 -m td_lang run td_start.td")
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def write_transport_py():
|
| 100 |
+
"""Write the complete updated transport.py with neuron permutation."""
|
| 101 |
+
code = '''\
|
| 102 |
+
"""
|
| 103 |
+
Transport and Merge Wrapper — interfaces with official T&M code.
|
| 104 |
+
|
| 105 |
+
This wraps the official repo at:
|
| 106 |
+
github.com/chenhangcuisg-code/Cross-Architecture-Merging-for-Large-Language-Models/
|
| 107 |
+
|
| 108 |
+
We use THEIR code for:
|
| 109 |
+
- Correlation distance computation (corr_distance_matrix)
|
| 110 |
+
- Streaming Sinkhorn (sinkhorn_uniform_streaming)
|
| 111 |
+
- Transport plan computation (compute_P, compute_Q_and_layer_costs)
|
| 112 |
+
- Activation reconstruction (reconstruct_X)
|
| 113 |
+
|
| 114 |
+
We add:
|
| 115 |
+
- Qwen3 thinking mode protection
|
| 116 |
+
- MiMo MTP head handling
|
| 117 |
+
- Falcon SSM component handling
|
| 118 |
+
- Neuron permutation for scrambled models (MiMo)
|
| 119 |
+
- Sequential merge protection (MagMax + orthogonal projection)
|
| 120 |
+
- Progress reporting every 5 minutes
|
| 121 |
+
- Timeouts to prevent infinite hangs
|
| 122 |
+
|
| 123 |
+
Findings: #01, #07, #24
|
| 124 |
+
"""
|
| 125 |
+
|
| 126 |
+
import sys
|
| 127 |
+
import time
|
| 128 |
+
import torch
|
| 129 |
+
import numpy as np
|
| 130 |
+
from pathlib import Path
|
| 131 |
+
from typing import Optional
|
| 132 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 133 |
+
from datasets import load_dataset
|
| 134 |
+
|
| 135 |
+
from .config import MergeConfig, ModelConfig, TARGET
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
# ============================================================================
|
| 139 |
+
# PROGRESS TRACKER — prints status every 5 minutes so you know it's alive
|
| 140 |
+
# ============================================================================
|
| 141 |
+
|
| 142 |
+
class ProgressTracker:
|
| 143 |
+
"""Prints a heartbeat every interval_seconds so you know it's not stuck."""
|
| 144 |
+
|
| 145 |
+
def __init__(self, task_name: str, interval_seconds: int = 300):
|
| 146 |
+
self.task_name = task_name
|
| 147 |
+
self.interval = interval_seconds
|
| 148 |
+
self.start_time = time.time()
|
| 149 |
+
self.last_report = self.start_time
|
| 150 |
+
self.step = 0
|
| 151 |
+
self.total_steps = 0
|
| 152 |
+
print(f"\\n[{task_name}] Started at {time.strftime(\'%H:%M:%S\')}")
|
| 153 |
+
|
| 154 |
+
def set_total(self, total: int):
|
| 155 |
+
self.total_steps = total
|
| 156 |
+
|
| 157 |
+
def tick(self, step_name: str = ""):
|
| 158 |
+
"""Call this inside loops. Prints progress if 5 min have passed."""
|
| 159 |
+
self.step += 1
|
| 160 |
+
now = time.time()
|
| 161 |
+
elapsed = now - self.start_time
|
| 162 |
+
since_last = now - self.last_report
|
| 163 |
+
|
| 164 |
+
if since_last >= self.interval:
|
| 165 |
+
pct = f"{self.step}/{self.total_steps} ({100*self.step/self.total_steps:.0f}%)" if self.total_steps else f"step {self.step}"
|
| 166 |
+
eta = ""
|
| 167 |
+
if self.total_steps and self.step > 0:
|
| 168 |
+
rate = elapsed / self.step
|
| 169 |
+
remaining = (self.total_steps - self.step) * rate
|
| 170 |
+
eta = f", ETA {remaining/60:.1f} min"
|
| 171 |
+
print(f"[{self.task_name}] HEARTBEAT — {pct}, elapsed {elapsed/60:.1f} min{eta} | {step_name}")
|
| 172 |
+
sys.stdout.flush()
|
| 173 |
+
self.last_report = now
|
| 174 |
+
|
| 175 |
+
def done(self):
|
| 176 |
+
elapsed = time.time() - self.start_time
|
| 177 |
+
print(f"[{self.task_name}] Completed in {elapsed/60:.1f} min ({elapsed:.0f}s)")
|
| 178 |
+
sys.stdout.flush()
|
| 179 |
+
|
| 180 |
+
def check_timeout(self, timeout_seconds: int = 3600):
|
| 181 |
+
"""Raise if we've been running longer than timeout_seconds."""
|
| 182 |
+
elapsed = time.time() - self.start_time
|
| 183 |
+
if elapsed > timeout_seconds:
|
| 184 |
+
raise TimeoutError(
|
| 185 |
+
f"[{self.task_name}] TIMEOUT after {elapsed/60:.1f} min "
|
| 186 |
+
f"(limit: {timeout_seconds/60:.0f} min). Something is wrong."
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def setup_tm_repo(cfg: MergeConfig):
|
| 191 |
+
"""Add official T&M repo to Python path so we can import their code."""
|
| 192 |
+
repo_path = Path(cfg.tm_repo_path)
|
| 193 |
+
core_path = repo_path / "core"
|
| 194 |
+
|
| 195 |
+
if not core_path.exists():
|
| 196 |
+
raise FileNotFoundError(
|
| 197 |
+
f"Official T&M repo not found at {repo_path}\\n"
|
| 198 |
+
f"Please clone it:\\n"
|
| 199 |
+
f" git clone https://github.com/chenhangcuisg-code/"
|
| 200 |
+
f"Cross-Architecture-Merging-for-Large-Language-Models.git"
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
# Add to path so we can import hot_transport etc.
|
| 204 |
+
if str(core_path) not in sys.path:
|
| 205 |
+
sys.path.insert(0, str(core_path))
|
| 206 |
+
print(f"[transport] Added T&M core to path: {core_path}")
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def load_calibration_data(cfg: MergeConfig, tokenizer: AutoTokenizer) -> list:
|
| 210 |
+
"""
|
| 211 |
+
Load calibration data for activation extraction.
|
| 212 |
+
|
| 213 |
+
Mix: 600 Pile general + 300 Pile ArXiv + 600 neuralmagic Q&A = 1500 samples
|
| 214 |
+
Each sample truncated to cfg.calibration_seq_len tokens.
|
| 215 |
+
|
| 216 |
+
Findings: #08
|
| 217 |
+
"""
|
| 218 |
+
tracker = ProgressTracker("calibration-data", interval_seconds=120)
|
| 219 |
+
print(f"[transport] Loading calibration data ({cfg.calibration_samples} samples)...")
|
| 220 |
+
|
| 221 |
+
samples = []
|
| 222 |
+
|
| 223 |
+
# --- Pile: general text (600 samples) ---
|
| 224 |
+
try:
|
| 225 |
+
pile = load_dataset(
|
| 226 |
+
cfg.calibration_dataset_pile,
|
| 227 |
+
split="validation",
|
| 228 |
+
streaming=True,
|
| 229 |
+
trust_remote_code=True,
|
| 230 |
+
)
|
| 231 |
+
count = 0
|
| 232 |
+
for example in pile:
|
| 233 |
+
if count >= 600:
|
| 234 |
+
break
|
| 235 |
+
text = example.get("text", "")
|
| 236 |
+
if len(text) > 100: # Skip very short texts
|
| 237 |
+
tokens = tokenizer(
|
| 238 |
+
text,
|
| 239 |
+
truncation=True,
|
| 240 |
+
max_length=cfg.calibration_seq_len,
|
| 241 |
+
return_tensors="pt",
|
| 242 |
+
)
|
| 243 |
+
samples.append(tokens)
|
| 244 |
+
count += 1
|
| 245 |
+
if count % 100 == 0:
|
| 246 |
+
print(f" Pile: {count}/600 samples loaded...")
|
| 247 |
+
sys.stdout.flush()
|
| 248 |
+
print(f" Pile general: {count} samples")
|
| 249 |
+
except Exception as e:
|
| 250 |
+
print(f" WARNING: Pile failed: {e}")
|
| 251 |
+
print(f" Falling back to neuralmagic only")
|
| 252 |
+
|
| 253 |
+
# --- neuralmagic: Q&A calibration (up to remaining) ---
|
| 254 |
+
remaining = cfg.calibration_samples - len(samples)
|
| 255 |
+
if remaining > 0:
|
| 256 |
+
try:
|
| 257 |
+
nm = load_dataset(
|
| 258 |
+
cfg.calibration_dataset_nm,
|
| 259 |
+
split="train",
|
| 260 |
+
trust_remote_code=True,
|
| 261 |
+
)
|
| 262 |
+
count = 0
|
| 263 |
+
for example in nm:
|
| 264 |
+
if count >= remaining:
|
| 265 |
+
break
|
| 266 |
+
text = example.get("text", example.get("content", ""))
|
| 267 |
+
if len(str(text)) > 50:
|
| 268 |
+
tokens = tokenizer(
|
| 269 |
+
str(text),
|
| 270 |
+
truncation=True,
|
| 271 |
+
max_length=cfg.calibration_seq_len,
|
| 272 |
+
return_tensors="pt",
|
| 273 |
+
)
|
| 274 |
+
samples.append(tokens)
|
| 275 |
+
count += 1
|
| 276 |
+
if count % 100 == 0:
|
| 277 |
+
print(f" neuralmagic: {count}/{remaining} samples loaded...")
|
| 278 |
+
sys.stdout.flush()
|
| 279 |
+
print(f" neuralmagic: {count} samples")
|
| 280 |
+
except Exception as e:
|
| 281 |
+
print(f" WARNING: neuralmagic failed: {e}")
|
| 282 |
+
|
| 283 |
+
tracker.done()
|
| 284 |
+
print(f"[transport] Total calibration samples: {len(samples)}")
|
| 285 |
+
sys.stdout.flush()
|
| 286 |
+
return samples
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
def extract_activations(
|
| 290 |
+
model: AutoModelForCausalLM,
|
| 291 |
+
calibration_data: list,
|
| 292 |
+
device: str = "cuda",
|
| 293 |
+
) -> dict:
|
| 294 |
+
"""
|
| 295 |
+
Extract intermediate activations from each layer of a model.
|
| 296 |
+
|
| 297 |
+
Runs calibration data through the model with hooks on each layer
|
| 298 |
+
to capture activation patterns. These activations are what the
|
| 299 |
+
optimal transport algorithm aligns between source and target.
|
| 300 |
+
|
| 301 |
+
Returns:
|
| 302 |
+
Dict mapping layer_name -> activation tensor [num_samples, hidden_dim]
|
| 303 |
+
"""
|
| 304 |
+
tracker = ProgressTracker("extract-activations", interval_seconds=300)
|
| 305 |
+
tracker.set_total(len(calibration_data))
|
| 306 |
+
print(f"[transport] Extracting activations from {len(calibration_data)} samples...")
|
| 307 |
+
sys.stdout.flush()
|
| 308 |
+
|
| 309 |
+
activations = {}
|
| 310 |
+
hooks = []
|
| 311 |
+
|
| 312 |
+
# Register hooks on each transformer layer
|
| 313 |
+
for name, module in model.named_modules():
|
| 314 |
+
if hasattr(module, "self_attn") or name.endswith(".mlp"):
|
| 315 |
+
# Hook to capture output activations
|
| 316 |
+
def make_hook(layer_name):
|
| 317 |
+
def hook_fn(module, input, output):
|
| 318 |
+
# Handle tuple outputs (some layers return tuples)
|
| 319 |
+
if isinstance(output, tuple):
|
| 320 |
+
act = output[0]
|
| 321 |
+
else:
|
| 322 |
+
act = output
|
| 323 |
+
if layer_name not in activations:
|
| 324 |
+
activations[layer_name] = []
|
| 325 |
+
# Mean pool over sequence length -> [hidden_dim]
|
| 326 |
+
activations[layer_name].append(
|
| 327 |
+
act.detach().float().mean(dim=1).cpu()
|
| 328 |
+
)
|
| 329 |
+
return hook_fn
|
| 330 |
+
|
| 331 |
+
h = module.register_forward_hook(make_hook(name))
|
| 332 |
+
hooks.append(h)
|
| 333 |
+
|
| 334 |
+
# Forward pass on calibration data
|
| 335 |
+
model.eval()
|
| 336 |
+
with torch.no_grad():
|
| 337 |
+
for i, tokens in enumerate(calibration_data):
|
| 338 |
+
inputs = {k: v.to(device) for k, v in tokens.items()}
|
| 339 |
+
try:
|
| 340 |
+
model(**inputs)
|
| 341 |
+
except Exception as e:
|
| 342 |
+
print(f" WARNING: Sample {i} failed: {e}")
|
| 343 |
+
continue
|
| 344 |
+
|
| 345 |
+
tracker.tick(f"sample {i+1}")
|
| 346 |
+
|
| 347 |
+
if (i + 1) % 100 == 0:
|
| 348 |
+
print(f" Processed {i + 1}/{len(calibration_data)} samples")
|
| 349 |
+
sys.stdout.flush()
|
| 350 |
+
|
| 351 |
+
# Timeout: 30 min for activation extraction
|
| 352 |
+
tracker.check_timeout(timeout_seconds=1800)
|
| 353 |
+
|
| 354 |
+
# Remove hooks
|
| 355 |
+
for h in hooks:
|
| 356 |
+
h.remove()
|
| 357 |
+
|
| 358 |
+
# Stack activations: [num_samples, hidden_dim]
|
| 359 |
+
layer_count = 0
|
| 360 |
+
for key in activations:
|
| 361 |
+
activations[key] = torch.cat(activations[key], dim=0)
|
| 362 |
+
layer_count += 1
|
| 363 |
+
|
| 364 |
+
print(f" Extracted {layer_count} layers, shapes: {activations[list(activations.keys())[0]].shape if activations else \'empty\'}")
|
| 365 |
+
tracker.done()
|
| 366 |
+
sys.stdout.flush()
|
| 367 |
+
|
| 368 |
+
return activations
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
def compute_transport_plans(
|
| 372 |
+
source_activations: dict,
|
| 373 |
+
target_activations: dict,
|
| 374 |
+
cfg: MergeConfig,
|
| 375 |
+
) -> dict:
|
| 376 |
+
"""
|
| 377 |
+
Compute optimal transport plans between source and target activations.
|
| 378 |
+
|
| 379 |
+
This is where the magic happens. We use the official T&M code's:
|
| 380 |
+
- corr_distance_matrix: correlation distance between activation vectors
|
| 381 |
+
- sinkhorn_uniform_streaming: memory-efficient Sinkhorn solver
|
| 382 |
+
- compute_P: layer-level coupling (which source layers -> which target layers)
|
| 383 |
+
- compute_Q_and_layer_costs: neuron-level coupling within each layer pair
|
| 384 |
+
|
| 385 |
+
Returns:
|
| 386 |
+
Dict with 'P' (layer coupling) and 'Q' (per-layer neuron coupling) matrices
|
| 387 |
+
"""
|
| 388 |
+
print("[transport] Computing transport plans...")
|
| 389 |
+
sys.stdout.flush()
|
| 390 |
+
|
| 391 |
+
try:
|
| 392 |
+
# Try importing official T&M code
|
| 393 |
+
from hot_transport import (
|
| 394 |
+
corr_distance_matrix,
|
| 395 |
+
sinkhorn_uniform_streaming,
|
| 396 |
+
compute_P,
|
| 397 |
+
compute_Q_and_layer_costs,
|
| 398 |
+
)
|
| 399 |
+
print("[transport] Using official T&M implementation")
|
| 400 |
+
return _compute_plans_official(
|
| 401 |
+
source_activations, target_activations, cfg,
|
| 402 |
+
corr_distance_matrix, sinkhorn_uniform_streaming,
|
| 403 |
+
compute_P, compute_Q_and_layer_costs,
|
| 404 |
+
)
|
| 405 |
+
except ImportError:
|
| 406 |
+
print("[transport] Official T&M code not available, using fallback")
|
| 407 |
+
return _compute_plans_fallback(
|
| 408 |
+
source_activations, target_activations, cfg
|
| 409 |
+
)
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
def _compute_plans_official(
|
| 413 |
+
source_act, target_act, cfg,
|
| 414 |
+
corr_distance_matrix, sinkhorn_uniform_streaming,
|
| 415 |
+
compute_P, compute_Q_and_layer_costs,
|
| 416 |
+
) -> dict:
|
| 417 |
+
"""Use the official T&M code to compute transport plans."""
|
| 418 |
+
|
| 419 |
+
# Get matching layer pairs
|
| 420 |
+
source_layers = sorted(source_act.keys())
|
| 421 |
+
target_layers = sorted(target_act.keys())
|
| 422 |
+
|
| 423 |
+
# Compute Q matrices (neuron-level) and layer costs
|
| 424 |
+
Q_matrices, layer_costs = compute_Q_and_layer_costs(
|
| 425 |
+
source_act, target_act,
|
| 426 |
+
source_layers, target_layers,
|
| 427 |
+
)
|
| 428 |
+
|
| 429 |
+
# Compute P matrix (layer-level coupling)
|
| 430 |
+
P = compute_P(layer_costs)
|
| 431 |
+
|
| 432 |
+
return {
|
| 433 |
+
"P": P,
|
| 434 |
+
"Q": Q_matrices,
|
| 435 |
+
"source_layers": source_layers,
|
| 436 |
+
"target_layers": target_layers,
|
| 437 |
+
}
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
def _compute_plans_fallback(
|
| 441 |
+
source_act: dict,
|
| 442 |
+
target_act: dict,
|
| 443 |
+
cfg: MergeConfig,
|
| 444 |
+
) -> dict:
|
| 445 |
+
"""
|
| 446 |
+
Fallback transport plan computation when official code isn't available.
|
| 447 |
+
|
| 448 |
+
Smart routing:
|
| 449 |
+
- Same-architecture models (same layer count): direct 1:1 layer matching
|
| 450 |
+
Check if neurons are aligned (DeepSeek) or scrambled (MiMo)
|
| 451 |
+
- Cross-architecture: sparse OT (only top-3 source layers per target)
|
| 452 |
+
"""
|
| 453 |
+
tracker = ProgressTracker("transport-plans", interval_seconds=300)
|
| 454 |
+
|
| 455 |
+
source_layers = sorted(source_act.keys())
|
| 456 |
+
target_layers = sorted(target_act.keys())
|
| 457 |
+
|
| 458 |
+
n_source = len(source_layers)
|
| 459 |
+
n_target = len(target_layers)
|
| 460 |
+
|
| 461 |
+
print(f"[transport] Source layers: {n_source}, Target layers: {n_target}")
|
| 462 |
+
sys.stdout.flush()
|
| 463 |
+
|
| 464 |
+
# --- FAST PATH: same architecture (same layer count) ---
|
| 465 |
+
# Both models have the same number of transformer layers
|
| 466 |
+
# Match layers 1:1 but CHECK if neurons correspond
|
| 467 |
+
# DeepSeek: same training base -> neurons aligned -> identity Q (fast)
|
| 468 |
+
# MiMo: different training -> neurons scrambled -> need Sinkhorn permutation
|
| 469 |
+
if n_source == n_target:
|
| 470 |
+
print("[transport] Same layer count -- using direct 1:1 layer matching")
|
| 471 |
+
sys.stdout.flush()
|
| 472 |
+
Q_matrices = {}
|
| 473 |
+
permutations = {} # layer_pair -> permutation array (neuron reordering)
|
| 474 |
+
P = np.eye(n_source) / n_source # Identity coupling
|
| 475 |
+
tracker.set_total(n_source)
|
| 476 |
+
|
| 477 |
+
# Check first layer to decide: are neurons aligned or scrambled?
|
| 478 |
+
first_sl = source_layers[0]
|
| 479 |
+
first_tl = target_layers[0]
|
| 480 |
+
S0 = source_act[first_sl].numpy()
|
| 481 |
+
T0 = target_act[first_tl].numpy()
|
| 482 |
+
if S0.shape[1] == T0.shape[1]:
|
| 483 |
+
S0_norm = (S0 - S0.mean(0)) / (S0.std(0) + 1e-8)
|
| 484 |
+
T0_norm = (T0 - T0.mean(0)) / (T0.std(0) + 1e-8)
|
| 485 |
+
diag_corr = np.mean(np.sum(S0_norm * T0_norm, axis=0) / S0.shape[0])
|
| 486 |
+
neurons_aligned = diag_corr > 0.3
|
| 487 |
+
else:
|
| 488 |
+
neurons_aligned = False
|
| 489 |
+
|
| 490 |
+
if neurons_aligned:
|
| 491 |
+
print(f"[transport] Neurons ARE aligned (diag_corr={diag_corr:.3f}) -- identity Q (fast)")
|
| 492 |
+
print("[transport] This should take under 1 minute...")
|
| 493 |
+
else:
|
| 494 |
+
corr_val = diag_corr if S0.shape[1] == T0.shape[1] else 0.0
|
| 495 |
+
print(f"[transport] Neurons NOT aligned (diag_corr={corr_val:.3f}) -- computing permutations via Sinkhorn")
|
| 496 |
+
print("[transport] This may take 2-5 minutes...")
|
| 497 |
+
sys.stdout.flush()
|
| 498 |
+
|
| 499 |
+
for i, (sl, tl) in enumerate(zip(source_layers, target_layers)):
|
| 500 |
+
S = source_act[sl].numpy()
|
| 501 |
+
T = target_act[tl].numpy()
|
| 502 |
+
|
| 503 |
+
if S.shape[1] == T.shape[1]:
|
| 504 |
+
if neurons_aligned:
|
| 505 |
+
# Neurons already correspond (e.g. DeepSeek) -- identity Q
|
| 506 |
+
Q_matrices[(sl, tl)] = np.eye(S.shape[1]) / S.shape[1]
|
| 507 |
+
else:
|
| 508 |
+
# Neurons are SCRAMBLED (e.g. MiMo) -- find the permutation
|
| 509 |
+
# 1. Compute correlation matrix between source and target neurons
|
| 510 |
+
S_norm = (S - S.mean(0)) / (S.std(0) + 1e-8)
|
| 511 |
+
T_norm = (T - T.mean(0)) / (T.std(0) + 1e-8)
|
| 512 |
+
corr = S_norm.T @ T_norm / S.shape[0] # [hidden_dim, hidden_dim]
|
| 513 |
+
|
| 514 |
+
# 2. Run Sinkhorn on cost matrix to get soft transport plan
|
| 515 |
+
cost = 1.0 - corr
|
| 516 |
+
Q_soft = _sinkhorn(cost, reg=0.05, max_iter=cfg.sinkhorn_max_iter)
|
| 517 |
+
|
| 518 |
+
# 3. Extract hard permutation: for each source neuron, which target neuron?
|
| 519 |
+
perm = np.argmax(Q_soft, axis=1) # source_neuron -> target_neuron
|
| 520 |
+
|
| 521 |
+
# 4. Check for duplicate assignments (Sinkhorn should avoid this, but be safe)
|
| 522 |
+
if len(set(perm)) < len(perm) * 0.9:
|
| 523 |
+
# Too many collisions -- fall back to Hungarian-style greedy
|
| 524 |
+
perm = _greedy_permutation(corr)
|
| 525 |
+
|
| 526 |
+
permutations[(sl, tl)] = perm
|
| 527 |
+
Q_matrices[(sl, tl)] = Q_soft
|
| 528 |
+
else:
|
| 529 |
+
# Different dims -- do lightweight Sinkhorn on this pair only
|
| 530 |
+
print(f" Layer {i}: dim mismatch ({S.shape[1]} vs {T.shape[1]}), using Sinkhorn...")
|
| 531 |
+
S_norm = (S - S.mean(0)) / (S.std(0) + 1e-8)
|
| 532 |
+
T_norm = (T - T.mean(0)) / (T.std(0) + 1e-8)
|
| 533 |
+
corr = S_norm.T @ T_norm / S.shape[0]
|
| 534 |
+
cost = 1.0 - corr
|
| 535 |
+
Q_matrices[(sl, tl)] = _sinkhorn(cost, reg=0.1, max_iter=50)
|
| 536 |
+
|
| 537 |
+
tracker.tick(f"{sl} -> {tl}")
|
| 538 |
+
|
| 539 |
+
if (i + 1) % 10 == 0 or i == 0:
|
| 540 |
+
print(f" Matched layer {i + 1}/{n_source}: {sl} -> {tl}")
|
| 541 |
+
sys.stdout.flush()
|
| 542 |
+
|
| 543 |
+
# Timeout: 15 min (permutation takes longer than identity)
|
| 544 |
+
tracker.check_timeout(timeout_seconds=900)
|
| 545 |
+
|
| 546 |
+
if permutations:
|
| 547 |
+
print(f"[transport] Computed {len(permutations)} neuron permutations")
|
| 548 |
+
print(f"[transport] Direct matching complete: {n_source} layer pairs")
|
| 549 |
+
tracker.done()
|
| 550 |
+
sys.stdout.flush()
|
| 551 |
+
return {
|
| 552 |
+
"P": P,
|
| 553 |
+
"Q": Q_matrices,
|
| 554 |
+
"permutations": permutations,
|
| 555 |
+
"source_layers": source_layers,
|
| 556 |
+
"target_layers": target_layers,
|
| 557 |
+
}
|
| 558 |
+
|
| 559 |
+
# --- CROSS-ARCHITECTURE PATH: sparse OT ---
|
| 560 |
+
# Only compute top-3 source layers per target (not all NxN pairs)
|
| 561 |
+
print(f"[transport] Cross-architecture -- using sparse OT (top-3 per target)")
|
| 562 |
+
print(f"[transport] Estimated time: 5-15 minutes")
|
| 563 |
+
sys.stdout.flush()
|
| 564 |
+
|
| 565 |
+
# Step 1: Compute layer-level similarity (cheap: just mean activation correlation)
|
| 566 |
+
print("[transport] Step 1/3: Computing layer-level similarities...")
|
| 567 |
+
sys.stdout.flush()
|
| 568 |
+
layer_costs = np.zeros((n_source, n_target))
|
| 569 |
+
tracker.set_total(n_source * n_target + n_target * 3)
|
| 570 |
+
for i, sl in enumerate(source_layers):
|
| 571 |
+
for j, tl in enumerate(target_layers):
|
| 572 |
+
S_mean = source_act[sl].mean(0).numpy()
|
| 573 |
+
T_mean = target_act[tl].mean(0).numpy()
|
| 574 |
+
# Cosine similarity as cheap proxy
|
| 575 |
+
min_dim = min(len(S_mean), len(T_mean))
|
| 576 |
+
s = S_mean[:min_dim]
|
| 577 |
+
t = T_mean[:min_dim]
|
| 578 |
+
sim = np.dot(s, t) / (np.linalg.norm(s) * np.linalg.norm(t) + 1e-8)
|
| 579 |
+
layer_costs[i, j] = 1.0 - sim
|
| 580 |
+
tracker.tick(f"layer sim {i},{j}")
|
| 581 |
+
|
| 582 |
+
# Timeout: 30 min for cross-arch
|
| 583 |
+
tracker.check_timeout(timeout_seconds=1800)
|
| 584 |
+
|
| 585 |
+
print(f"[transport] Step 1/3 done: {n_source}x{n_target} similarities computed")
|
| 586 |
+
sys.stdout.flush()
|
| 587 |
+
|
| 588 |
+
# Step 2: For each target layer, only compute Q for top-3 most similar source layers
|
| 589 |
+
print("[transport] Step 2/3: Computing neuron-level transport (top-3 per target)...")
|
| 590 |
+
sys.stdout.flush()
|
| 591 |
+
Q_matrices = {}
|
| 592 |
+
for j, tl in enumerate(target_layers):
|
| 593 |
+
top3 = np.argsort(layer_costs[:, j])[:3]
|
| 594 |
+
for i in top3:
|
| 595 |
+
sl = source_layers[i]
|
| 596 |
+
S = source_act[sl].numpy()
|
| 597 |
+
T = target_act[tl].numpy()
|
| 598 |
+
|
| 599 |
+
# Lightweight Sinkhorn (50 iterations, not 100+)
|
| 600 |
+
min_dim = min(S.shape[1], T.shape[1])
|
| 601 |
+
S_sub = S[:, :min_dim]
|
| 602 |
+
T_sub = T[:, :min_dim]
|
| 603 |
+
S_norm = (S_sub - S_sub.mean(0)) / (S_sub.std(0) + 1e-8)
|
| 604 |
+
T_norm = (T_sub - T_sub.mean(0)) / (T_sub.std(0) + 1e-8)
|
| 605 |
+
corr = S_norm.T @ T_norm / S.shape[0]
|
| 606 |
+
cost = 1.0 - corr
|
| 607 |
+
Q_matrices[(sl, tl)] = _sinkhorn(cost, reg=0.1, max_iter=50)
|
| 608 |
+
tracker.tick(f"Q({sl},{tl})")
|
| 609 |
+
|
| 610 |
+
if (j + 1) % 5 == 0 or j == 0:
|
| 611 |
+
print(f" Target layer {j + 1}/{n_target}: matched to top-3 sources")
|
| 612 |
+
sys.stdout.flush()
|
| 613 |
+
|
| 614 |
+
# Timeout: 30 min for cross-arch
|
| 615 |
+
tracker.check_timeout(timeout_seconds=1800)
|
| 616 |
+
|
| 617 |
+
print(f"[transport] Step 2/3 done: {len(Q_matrices)} Q matrices computed")
|
| 618 |
+
sys.stdout.flush()
|
| 619 |
+
|
| 620 |
+
# Step 3: Layer coupling via Sinkhorn on layer costs
|
| 621 |
+
print("[transport] Step 3/3: Computing layer coupling P matrix...")
|
| 622 |
+
sys.stdout.flush()
|
| 623 |
+
P = _sinkhorn(layer_costs, reg=0.1, max_iter=50)
|
| 624 |
+
|
| 625 |
+
print(f"[transport] Sparse OT complete: {len(Q_matrices)} layer pairs computed")
|
| 626 |
+
tracker.done()
|
| 627 |
+
sys.stdout.flush()
|
| 628 |
+
return {
|
| 629 |
+
"P": P,
|
| 630 |
+
"Q": Q_matrices,
|
| 631 |
+
"permutations": {},
|
| 632 |
+
"source_layers": source_layers,
|
| 633 |
+
"target_layers": target_layers,
|
| 634 |
+
}
|
| 635 |
+
|
| 636 |
+
|
| 637 |
+
def _sinkhorn(
|
| 638 |
+
cost_matrix: np.ndarray,
|
| 639 |
+
reg: float = 0.05,
|
| 640 |
+
max_iter: int = 100,
|
| 641 |
+
) -> np.ndarray:
|
| 642 |
+
"""
|
| 643 |
+
Basic Sinkhorn-Knopp algorithm for optimal transport.
|
| 644 |
+
|
| 645 |
+
Solves: min <T, C> - reg * H(T)
|
| 646 |
+
where H(T) is the entropy of the transport plan.
|
| 647 |
+
|
| 648 |
+
This is the FALLBACK. The official code uses streaming Sinkhorn
|
| 649 |
+
which is more memory-efficient.
|
| 650 |
+
"""
|
| 651 |
+
n, m = cost_matrix.shape
|
| 652 |
+
K = np.exp(-cost_matrix / reg)
|
| 653 |
+
|
| 654 |
+
u = np.ones(n) / n
|
| 655 |
+
v = np.ones(m) / m
|
| 656 |
+
|
| 657 |
+
for iteration in range(max_iter):
|
| 658 |
+
u = 1.0 / (K @ v + 1e-10)
|
| 659 |
+
v = 1.0 / (K.T @ u + 1e-10)
|
| 660 |
+
|
| 661 |
+
# Transport plan
|
| 662 |
+
T = np.diag(u) @ K @ np.diag(v)
|
| 663 |
+
return T
|
| 664 |
+
|
| 665 |
+
|
| 666 |
+
def _greedy_permutation(corr_matrix: np.ndarray) -> np.ndarray:
|
| 667 |
+
"""
|
| 668 |
+
Greedy permutation assignment when Sinkhorn gives duplicate mappings.
|
| 669 |
+
|
| 670 |
+
For each source neuron (in order of strongest match), assign it to the
|
| 671 |
+
best available target neuron that hasn't been taken yet.
|
| 672 |
+
"""
|
| 673 |
+
n = corr_matrix.shape[0]
|
| 674 |
+
perm = np.full(n, -1, dtype=np.int64)
|
| 675 |
+
taken = set()
|
| 676 |
+
|
| 677 |
+
# Process source neurons by strength of their best match (strongest first)
|
| 678 |
+
best_scores = np.max(corr_matrix, axis=1)
|
| 679 |
+
order = np.argsort(-best_scores)
|
| 680 |
+
|
| 681 |
+
for src in order:
|
| 682 |
+
# Find best available target
|
| 683 |
+
sorted_targets = np.argsort(-corr_matrix[src])
|
| 684 |
+
for tgt in sorted_targets:
|
| 685 |
+
if tgt not in taken:
|
| 686 |
+
perm[src] = tgt
|
| 687 |
+
taken.add(tgt)
|
| 688 |
+
break
|
| 689 |
+
|
| 690 |
+
# Safety: any unassigned source neurons get remaining targets
|
| 691 |
+
remaining = set(range(n)) - taken
|
| 692 |
+
for src in range(n):
|
| 693 |
+
if perm[src] == -1:
|
| 694 |
+
perm[src] = remaining.pop()
|
| 695 |
+
|
| 696 |
+
return perm
|
| 697 |
+
|
| 698 |
+
|
| 699 |
+
def _apply_permutation(source_w: torch.Tensor, perm: np.ndarray, key: str) -> torch.Tensor:
|
| 700 |
+
"""
|
| 701 |
+
Apply neuron permutation to a source weight tensor before blending.
|
| 702 |
+
|
| 703 |
+
The permutation rearranges MiMo's neurons to match Qwen3's ordering.
|
| 704 |
+
Think of it like reorganising filing cabinets: same files, different order.
|
| 705 |
+
|
| 706 |
+
Which dimension to permute depends on the weight type:
|
| 707 |
+
- Input projections (q_proj, k_proj, v_proj, gate_proj, up_proj):
|
| 708 |
+
shape [out_features, in_features] -> permute columns (dim 1)
|
| 709 |
+
because input neurons need reordering
|
| 710 |
+
- Output projections (o_proj, down_proj):
|
| 711 |
+
shape [out_features, in_features] -> permute rows (dim 0)
|
| 712 |
+
because output neurons need reordering
|
| 713 |
+
- 1D weights (layer_norm, bias):
|
| 714 |
+
permute directly
|
| 715 |
+
"""
|
| 716 |
+
perm_tensor = torch.from_numpy(perm).long()
|
| 717 |
+
|
| 718 |
+
if source_w.dim() == 1:
|
| 719 |
+
# 1D: layer norms, biases
|
| 720 |
+
if len(perm_tensor) == source_w.shape[0]:
|
| 721 |
+
return source_w[perm_tensor]
|
| 722 |
+
return source_w
|
| 723 |
+
|
| 724 |
+
if source_w.dim() == 2:
|
| 725 |
+
# 2D: linear layers
|
| 726 |
+
out_features, in_features = source_w.shape
|
| 727 |
+
|
| 728 |
+
# Output projections: neurons on dim 0 (rows)
|
| 729 |
+
if any(proj in key for proj in ["o_proj", "down_proj"]):
|
| 730 |
+
if len(perm_tensor) == out_features:
|
| 731 |
+
return source_w[perm_tensor, :]
|
| 732 |
+
# Input projections: neurons on dim 1 (columns)
|
| 733 |
+
elif any(proj in key for proj in ["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj"]):
|
| 734 |
+
if len(perm_tensor) == in_features:
|
| 735 |
+
return source_w[:, perm_tensor]
|
| 736 |
+
# Other 2D weights: try columns first (more common)
|
| 737 |
+
else:
|
| 738 |
+
if len(perm_tensor) == in_features:
|
| 739 |
+
return source_w[:, perm_tensor]
|
| 740 |
+
elif len(perm_tensor) == out_features:
|
| 741 |
+
return source_w[perm_tensor, :]
|
| 742 |
+
|
| 743 |
+
# Can't permute -- return unchanged
|
| 744 |
+
return source_w
|
| 745 |
+
|
| 746 |
+
|
| 747 |
+
def fuse_weights(
    source_state: dict,
    target_model: AutoModelForCausalLM,
    transport_plans: dict,
    source_config: ModelConfig,
    cfg: MergeConfig,
    target_activations: dict = None,
) -> AutoModelForCausalLM:
    """
    Fuse source model weights into the target model by per-parameter blending.

    For every target parameter that is not excluded by _should_skip and that
    has a counterpart in source_state (looked up through _map_key):
        1. If the shapes differ, the source tensor is passed through
           _align_dimensions (the parameter is skipped if that returns None).
        2. If a neuron permutation exists for the parameter's layer index, it
           is applied to the source tensor with _apply_permutation.
        3. Blend: W_final = alpha * W_source + (1 - alpha) * W_target,
           with alpha = source_config.merge_alpha.

    The layer coupling matrix transport_plans["P"] is not consulted here, and
    the Q matrices are only forwarded to _align_dimensions. Per-model
    exclusions and key renaming live in _should_skip and _map_key.

    Args:
        source_state: Source model state dict (can be on CPU -- each tensor is
            moved to the target tensor's device when blended)
        target_model: Target model; its parameters are overwritten via
            load_state_dict(strict=False)
        transport_plans: Dict with "Q" and optionally "permutations", whose
            keys are (source_layer_name, target_layer_name) tuples
        source_config: Source model config (name, merge_alpha, plus whatever
            _should_skip / _map_key read)
        cfg: Merge configuration (freeze_think_tokens, think_token_ids)
        target_activations: Accepted for interface compatibility; unused here

    Returns:
        Target model with fused weights
    """
    tracker = ProgressTracker("fuse-weights", interval_seconds=300)
    print(f"\\n[transport] Fusing {source_config.name} -> target")
    alpha = source_config.merge_alpha

    try:
        # Availability probe only: the official fusion is not wired in yet.
        from generate_hot_residual import fuse_attention_only_from_hot_dir  # noqa: F401
    except ImportError:
        pass
    else:
        # TODO: Adapt official fusion to our pipeline
        print("[transport] Official fusion implementation found but not integrated yet -- using manual fusion")

    # --- Manual fusion using transport plans ---
    # source_state is passed in (may be on CPU to save GPU memory)
    target_state = target_model.state_dict()
    Q = transport_plans["Q"]
    permutations = transport_plans.get("permutations", {})

    # Build layer-index -> permutation lookup from the target layer names
    # (e.g. "model.layers.5.mlp" -> 5). A later pair for the same index
    # replaces an earlier one.
    layer_perms = {}
    for (_sl, tl), perm in permutations.items():
        layer_idx = _layer_index_from_key(tl)
        if layer_idx is not None:
            layer_perms[layer_idx] = perm

    if permutations:
        print(f"[transport] Will apply neuron permutations to {len(layer_perms)} layers before blending")
    else:
        print("[transport] No neuron permutations needed (neurons already aligned)")

    fused_count = 0
    skipped_count = 0
    permuted_count = 0
    unmatched_logged = 0
    total_params = len(target_state)
    tracker.set_total(total_params)

    for target_key in target_state:
        tracker.tick(target_key)

        # Skip parameters we shouldn't merge
        if _should_skip(target_key, source_config):
            skipped_count += 1
            continue

        # Find corresponding source key
        source_key = _map_key(target_key, source_config)
        if source_key is None or source_key not in source_state:
            skipped_count += 1
            # Log the first few misses to help debug key mapping issues.
            # Counted separately so deliberate skips cannot use up the quota.
            if unmatched_logged < 5:
                unmatched_logged += 1
                print(f"  [skip] No source match for: {target_key} (mapped to: {source_key})")
                sys.stdout.flush()
            continue

        target_w = target_state[target_key]
        source_w = source_state[source_key]

        # Handle dimension mismatches
        if target_w.shape != source_w.shape:
            source_w = _align_dimensions(source_w, target_w.shape, Q, target_key)
            if source_w is None:
                skipped_count += 1
                continue

        # --- NEURON PERMUTATION: rearrange source neurons to match target ---
        if layer_perms:
            lidx = _layer_index_from_key(target_key)
            if lidx is not None and lidx in layer_perms:
                permuted = _apply_permutation(source_w, layer_perms[lidx], target_key)
                # _apply_permutation hands back the same object when it could
                # not permute; only count tensors that were really reordered.
                if permuted is not source_w:
                    permuted_count += 1
                source_w = permuted

        # Blend: W_final = alpha * source + (1-alpha) * target
        fused_w = alpha * source_w.to(target_w.device) + (1 - alpha) * target_w
        target_state[target_key] = fused_w
        fused_count += 1

        # Apply thinking mode protection (inside loop -- check each key)
        if cfg.freeze_think_tokens and "embed_tokens" in target_key:
            for token_id in cfg.think_token_ids:
                if token_id < fused_w.shape[0]:
                    # target_w is still the original, unblended embedding, so
                    # there is no need to rebuild the model state dict here.
                    fused_w[token_id] = target_w[token_id]
                    print(f"[transport] Protected think token {token_id}")

        if fused_count % 50 == 0:
            print(f"  Fused {fused_count} params so far (skipped {skipped_count})...")
            sys.stdout.flush()

        # Timeout: 20 min for weight fusion
        tracker.check_timeout(timeout_seconds=1200)

    # Load fused weights (strict=False: vision encoder may have bitsandbytes quant keys
    # that don't match the original key names -- we never modify vision weights anyway)
    missing, unexpected = target_model.load_state_dict(target_state, strict=False)
    if missing:
        print(f"[transport] NOTE: {len(missing)} missing keys (likely quantized vision params -- safe to ignore)")
    if unexpected:
        print(f"[transport] NOTE: {len(unexpected)} unexpected keys (safe to ignore)")
    perm_msg = f", permuted {permuted_count}" if permuted_count else ""
    print(f"[transport] Fused {fused_count} params, skipped {skipped_count}{perm_msg}")
    tracker.done()
    sys.stdout.flush()

    return target_model


def _layer_index_from_key(name: str) -> Optional[int]:
    """Return the integer following the first usable "layers" segment of a dotted name, else None."""
    parts = name.split(".")
    for pos, part in enumerate(parts):
        if part == "layers" and pos + 1 < len(parts):
            try:
                return int(parts[pos + 1])
            except ValueError:
                return None
    return None
|
| 907 |
+
|
| 908 |
+
|
| 909 |
+
def _should_skip(key: str, source_config: ModelConfig) -> bool:
|
| 910 |
+
"""Determine if a parameter should be skipped during merge."""
|
| 911 |
+
|
| 912 |
+
# Skip vision encoder params (Qwen3-VL) -- these should never be merged
|
| 913 |
+
if key.startswith("visual") or key.startswith("merger") or key.startswith("model.visual") or key.startswith("model.merger"):
|
| 914 |
+
return True
|
| 915 |
+
|
| 916 |
+
# Always skip if source model says to skip embeddings
|
| 917 |
+
if source_config.skip_embeddings and ("embed_tokens" in key or "lm_head" in key):
|
| 918 |
+
return True
|
| 919 |
+
|
| 920 |
+
# Skip MiMo MTP heads
|
| 921 |
+
if "drop_mtp_heads" in source_config.special_handling and "mtp_head" in key:
|
| 922 |
+
return True
|
| 923 |
+
|
| 924 |
+
# Skip Falcon Mamba-specific parameters
|
| 925 |
+
if "drop_mamba_state_params" in source_config.special_handling:
|
| 926 |
+
mamba_keys = ["mamba", "A_log", "dt_proj", ".D"]
|
| 927 |
+
if any(mk in key for mk in mamba_keys):
|
| 928 |
+
return True
|
| 929 |
+
|
| 930 |
+
# Skip QKV bias for Llama (Qwen3 doesn't have it)
|
| 931 |
+
if "drop_qkv_bias" in source_config.special_handling and ".bias" in key:
|
| 932 |
+
if any(proj in key for proj in ["q_proj", "k_proj", "v_proj"]):
|
| 933 |
+
return True
|
| 934 |
+
|
| 935 |
+
return False
|
| 936 |
+
|
| 937 |
+
|
| 938 |
+
def _strip_vl_prefix(key: str) -> str:
|
| 939 |
+
"""
|
| 940 |
+
Strip the 'language_model.' prefix that Qwen3-VL adds.
|
| 941 |
+
|
| 942 |
+
Qwen3-VL wraps all language params under 'model.language_model.*'
|
| 943 |
+
but source models (DeepSeek, MiMo, Llama, Falcon) use 'model.*' directly.
|
| 944 |
+
|
| 945 |
+
Example:
|
| 946 |
+
target: model.language_model.layers.0.self_attn.q_proj.weight
|
| 947 |
+
source: model.layers.0.self_attn.q_proj.weight
|
| 948 |
+
"""
|
| 949 |
+
# model.language_model.X -> model.X
|
| 950 |
+
if "language_model." in key:
|
| 951 |
+
return key.replace("language_model.", "")
|
| 952 |
+
return key
|
| 953 |
+
|
| 954 |
+
|
| 955 |
+
def _map_key(target_key: str, source_config: ModelConfig) -> Optional[str]:
    """
    Translate a target parameter name into the name to look up in the source.

    Returns None when the parameter has no counterpart in the source model.
    """
    # Drop Qwen3-VL's language_model. segment so names line up with the source
    stripped = _strip_vl_prefix(target_key)
    arch = source_config.architecture

    # Same-architecture models (DeepSeek): names match once the prefix is gone
    if arch == "transformer" and source_config.layers == 36:
        return stripped

    # Llama (32 layers -> 36 layers): rescale the layer index
    if "layer_mapping_32_to_36" in source_config.special_handling and "model.layers." in stripped:
        segments = stripped.split(".")
        try:
            target_layer = int(segments[2])
        except (IndexError, ValueError):
            return stripped
        # Map 36 target layers onto 32 source layers (stride)
        segments[2] = str(int(target_layer * 32 / 36))
        return ".".join(segments)

    # MiMo (same layer count, different extras): MTP heads don't exist in target
    if arch == "transformer+mtp":
        return None if "mtp_head" in stripped else stripped

    # Falcon hybrid: only names containing one of these tags are mapped;
    # Mamba components don't map.
    # NOTE(review): "layer_norm" does not occur in names such as
    # "input_layernorm" -- confirm which norm names were meant here.
    if arch == "hybrid_ssm":
        shared_tags = ("self_attn", "mlp", "layer_norm")
        if any(tag in stripped for tag in shared_tags):
            return stripped
        return None

    return stripped
|
| 993 |
+
|
| 994 |
+
|
| 995 |
+
def _align_dimensions(
|
| 996 |
+
source_w: torch.Tensor,
|
| 997 |
+
target_shape: tuple,
|
| 998 |
+
Q_matrices: dict,
|
| 999 |
+
key: str,
|
| 1000 |
+
) -> Optional[torch.Tensor]:
|
| 1001 |
+
"""
|
| 1002 |
+
Align source weight dimensions to target shape using transport plans.
|
| 1003 |
+
|
| 1004 |
+
For small mismatches: pad or truncate.
|
| 1005 |
+
For large mismatches: use Q matrix to project.
|
| 1006 |
+
"""
|
| 1007 |
+
if source_w.shape == target_shape:
|
| 1008 |
+
return source_w
|
| 1009 |
+
|
| 1010 |
+
# Simple case: different width (FFN size difference)
|
| 1011 |
+
if len(source_w.shape) == 2 and len(target_shape) == 2:
|
| 1012 |
+
s_rows, s_cols = source_w.shape
|
| 1013 |
+
t_rows, t_cols = target_shape
|
| 1014 |
+
|
| 1015 |
+
result = torch.zeros(target_shape, dtype=source_w.dtype)
|
| 1016 |
+
|
| 1017 |
+
# Copy what fits
|
| 1018 |
+
min_rows = min(s_rows, t_rows)
|
| 1019 |
+
min_cols = min(s_cols, t_cols)
|
| 1020 |
+
result[:min_rows, :min_cols] = source_w[:min_rows, :min_cols]
|
| 1021 |
+
|
| 1022 |
+
return result
|
| 1023 |
+
|
| 1024 |
+
# 1D case (biases, layer norms)
|
| 1025 |
+
if len(source_w.shape) == 1 and len(target_shape) == 1:
|
| 1026 |
+
result = torch.zeros(target_shape, dtype=source_w.dtype)
|
| 1027 |
+
min_len = min(source_w.shape[0], target_shape[0])
|
| 1028 |
+
result[:min_len] = source_w[:min_len]
|
| 1029 |
+
return result
|
| 1030 |
+
|
| 1031 |
+
# Can't align -- skip this parameter
|
| 1032 |
+
return None
|
| 1033 |
+
'''
|
| 1034 |
+
with open("td_fuse/transport.py", 'w') as f:
|
| 1035 |
+
f.write(code)
|
| 1036 |
+
|
| 1037 |
+
|
| 1038 |
+
if __name__ == "__main__":
|
| 1039 |
+
main()
|
hugging/td_fuse/config.py
CHANGED
|
@@ -107,7 +107,7 @@ SOURCES = [
|
|
| 107 |
skip_embeddings=True, # Must skip — vocab too different
|
| 108 |
trust_remote_code=True, # Custom MTP architecture
|
| 109 |
merge_risk="medium",
|
| 110 |
-
merge_alpha=0.
|
| 111 |
special_handling=["drop_mtp_heads", "skip_embeddings"],
|
| 112 |
notes=(
|
| 113 |
"Xiaomi's reasoning model. Same layer count and hidden dim as Qwen3. "
|
|
|
|
| 107 |
skip_embeddings=True, # Must skip — vocab too different
|
| 108 |
trust_remote_code=True, # Custom MTP architecture
|
| 109 |
merge_risk="medium",
|
| 110 |
+
merge_alpha=0.15, # Low — MiMo neurons need permutation, keep target dominant
|
| 111 |
special_handling=["drop_mtp_heads", "skip_embeddings"],
|
| 112 |
notes=(
|
| 113 |
"Xiaomi's reasoning model. Same layer count and hidden dim as Qwen3. "
|
hugging/td_fuse/heal.py
CHANGED
|
@@ -247,6 +247,7 @@ def apply_qlora_standard(
|
|
| 247 |
if os.path.exists(healed_check):
|
| 248 |
print('[heal] Found existing healed model — SKIPPING healing!')
|
| 249 |
return 'td_fuse_outputs/healed'
|
|
|
|
| 250 |
from peft import LoraConfig, get_peft_model, TaskType
|
| 251 |
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
| 252 |
|
|
@@ -353,39 +354,55 @@ def apply_qlora_standard(
|
|
| 353 |
print(f"\n[heal] Merging LoRA adapters...")
|
| 354 |
merged_model = model.merge_and_unload()
|
| 355 |
|
| 356 |
-
# Free disk space before saving — remove duplicate model copies
|
| 357 |
import shutil, gc
|
| 358 |
-
print("[heal] Freeing disk space before save...")
|
| 359 |
|
| 360 |
-
#
|
| 361 |
-
#
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 379 |
|
| 380 |
gc.collect()
|
| 381 |
|
| 382 |
-
# Report free space
|
| 383 |
-
stat = shutil.disk_usage("/")
|
| 384 |
-
print(f"[heal] Disk space: {stat.free / 1e9:.1f} GB free / {stat.total / 1e9:.1f} GB total")
|
| 385 |
-
|
| 386 |
-
merged_model.save_pretrained(str(healed_dir))
|
| 387 |
-
tokenizer.save_pretrained(str(healed_dir))
|
| 388 |
-
|
| 389 |
print(f"[heal] Healed model saved to {healed_dir}")
|
| 390 |
return str(healed_dir)
|
| 391 |
|
|
|
|
| 247 |
if os.path.exists(healed_check):
|
| 248 |
print('[heal] Found existing healed model — SKIPPING healing!')
|
| 249 |
return 'td_fuse_outputs/healed'
|
| 250 |
+
import torch
|
| 251 |
from peft import LoraConfig, get_peft_model, TaskType
|
| 252 |
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
| 253 |
|
|
|
|
| 354 |
print(f"\n[heal] Merging LoRA adapters...")
|
| 355 |
merged_model = model.merge_and_unload()
|
| 356 |
|
|
|
|
| 357 |
import shutil, gc
|
|
|
|
| 358 |
|
| 359 |
+
# SAVE FIRST — never delete anything until save is confirmed
|
| 360 |
+
# save_pretrained can fail on 4-bit merged models (NotImplementedError)
|
| 361 |
+
# So we go straight to the safe manual method
|
| 362 |
+
print(f"[heal] Saving healed model to {healed_dir}...")
|
| 363 |
+
try:
|
| 364 |
+
from safetensors.torch import save_file
|
| 365 |
+
import torch as _torch
|
| 366 |
+
state_dict = merged_model.state_dict()
|
| 367 |
+
clean_state = {}
|
| 368 |
+
for k, v in state_dict.items():
|
| 369 |
+
if hasattr(v, 'dequantize'):
|
| 370 |
+
clean_state[k] = v.dequantize().to(_torch.bfloat16)
|
| 371 |
+
elif v.dtype in (_torch.float32, _torch.float16, _torch.bfloat16):
|
| 372 |
+
clean_state[k] = v.to(_torch.bfloat16)
|
| 373 |
+
else:
|
| 374 |
+
clean_state[k] = v.float().to(_torch.bfloat16)
|
| 375 |
+
save_file(clean_state, str(healed_dir / "model.safetensors"))
|
| 376 |
+
if hasattr(merged_model, 'config'):
|
| 377 |
+
merged_model.config.save_pretrained(str(healed_dir))
|
| 378 |
+
tokenizer.save_pretrained(str(healed_dir))
|
| 379 |
+
print(f"[heal] SAVED OK: {healed_dir / 'model.safetensors'}")
|
| 380 |
+
except Exception as e:
|
| 381 |
+
# Emergency fallback: try save_pretrained as last resort
|
| 382 |
+
print(f"[heal] Manual save failed ({e}), trying save_pretrained...")
|
| 383 |
+
merged_model.save_pretrained(str(healed_dir))
|
| 384 |
+
tokenizer.save_pretrained(str(healed_dir))
|
| 385 |
+
print(f"[heal] SAVED OK via save_pretrained: {healed_dir}")
|
| 386 |
+
|
| 387 |
+
# Verify the save actually worked before cleaning up ANYTHING
|
| 388 |
+
saved_model = healed_dir / "model.safetensors"
|
| 389 |
+
if not saved_model.exists() or saved_model.stat().st_size < 1_000_000:
|
| 390 |
+
print(f"[heal] WARNING: Save may have failed — NOT deleting any backups!")
|
| 391 |
+
else:
|
| 392 |
+
save_size = saved_model.stat().st_size / 1e9
|
| 393 |
+
print(f"[heal] Verified: {saved_model} ({save_size:.1f} GB)")
|
| 394 |
+
# NOW safe to clean up old stuff
|
| 395 |
+
cleanup_targets = [
|
| 396 |
+
"td_fuse_outputs/final",
|
| 397 |
+
]
|
| 398 |
+
for target in cleanup_targets:
|
| 399 |
+
target_path = Path(target)
|
| 400 |
+
if target_path.exists() and target_path.is_dir():
|
| 401 |
+
shutil.rmtree(str(target_path))
|
| 402 |
+
print(f"[heal] Freed space: removed {target_path}")
|
| 403 |
|
| 404 |
gc.collect()
|
| 405 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 406 |
print(f"[heal] Healed model saved to {healed_dir}")
|
| 407 |
return str(healed_dir)
|
| 408 |
|
hugging/td_fuse/merge.py
CHANGED
|
@@ -484,6 +484,9 @@ class ResidualBank:
|
|
| 484 |
# What the source lost (what didn't make it into the merge)
|
| 485 |
if key in source_state:
|
| 486 |
original_source = source_state[key].float()
|
|
|
|
|
|
|
|
|
|
| 487 |
s_residual = original_source - merged_w
|
| 488 |
s_loss = s_residual.abs().mean().item()
|
| 489 |
|
|
|
|
| 484 |
# What the source lost (what didn't make it into the merge)
|
| 485 |
if key in source_state:
|
| 486 |
original_source = source_state[key].float()
|
| 487 |
+
# Skip if shapes don't match (e.g. vocab size mismatch on embeddings/lm_head)
|
| 488 |
+
if original_source.shape != merged_w.shape:
|
| 489 |
+
continue
|
| 490 |
s_residual = original_source - merged_w
|
| 491 |
s_loss = s_residual.abs().mean().item()
|
| 492 |
|
hugging/td_fuse/transport.py
CHANGED
|
@@ -360,24 +360,69 @@ def _compute_plans_fallback(
|
|
| 360 |
sys.stdout.flush()
|
| 361 |
|
| 362 |
# --- FAST PATH: same architecture (same layer count) ---
|
| 363 |
-
#
|
| 364 |
-
#
|
| 365 |
-
#
|
|
|
|
| 366 |
if n_source == n_target:
|
| 367 |
-
print("[transport] Same layer count -- using direct 1:1 layer matching
|
| 368 |
-
print("[transport] This should take under 1 minute...")
|
| 369 |
sys.stdout.flush()
|
| 370 |
Q_matrices = {}
|
|
|
|
| 371 |
P = np.eye(n_source) / n_source # Identity coupling
|
| 372 |
tracker.set_total(n_source)
|
| 373 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 374 |
for i, (sl, tl) in enumerate(zip(source_layers, target_layers)):
|
| 375 |
S = source_act[sl].numpy()
|
| 376 |
T = target_act[tl].numpy()
|
| 377 |
|
| 378 |
-
# For same-dim layers, Q is identity (neurons already correspond)
|
| 379 |
if S.shape[1] == T.shape[1]:
|
| 380 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 381 |
else:
|
| 382 |
# Different dims -- do lightweight Sinkhorn on this pair only
|
| 383 |
print(f" Layer {i}: dim mismatch ({S.shape[1]} vs {T.shape[1]}), using Sinkhorn...")
|
|
@@ -393,15 +438,18 @@ def _compute_plans_fallback(
|
|
| 393 |
print(f" Matched layer {i + 1}/{n_source}: {sl} -> {tl}")
|
| 394 |
sys.stdout.flush()
|
| 395 |
|
| 396 |
-
# Timeout:
|
| 397 |
-
tracker.check_timeout(timeout_seconds=
|
| 398 |
|
|
|
|
|
|
|
| 399 |
print(f"[transport] Direct matching complete: {n_source} layer pairs")
|
| 400 |
tracker.done()
|
| 401 |
sys.stdout.flush()
|
| 402 |
return {
|
| 403 |
"P": P,
|
| 404 |
"Q": Q_matrices,
|
|
|
|
| 405 |
"source_layers": source_layers,
|
| 406 |
"target_layers": target_layers,
|
| 407 |
}
|
|
@@ -478,6 +526,7 @@ def _compute_plans_fallback(
|
|
| 478 |
return {
|
| 479 |
"P": P,
|
| 480 |
"Q": Q_matrices,
|
|
|
|
| 481 |
"source_layers": source_layers,
|
| 482 |
"target_layers": target_layers,
|
| 483 |
}
|
|
@@ -512,6 +561,87 @@ def _sinkhorn(
|
|
| 512 |
return T
|
| 513 |
|
| 514 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 515 |
def fuse_weights(
|
| 516 |
source_state: dict,
|
| 517 |
target_model: AutoModelForCausalLM,
|
|
@@ -562,9 +692,33 @@ def fuse_weights(
|
|
| 562 |
target_state = target_model.state_dict()
|
| 563 |
P = transport_plans["P"]
|
| 564 |
Q = transport_plans["Q"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 565 |
|
| 566 |
fused_count = 0
|
| 567 |
skipped_count = 0
|
|
|
|
| 568 |
total_params = len(target_state)
|
| 569 |
tracker.set_total(total_params)
|
| 570 |
|
|
@@ -597,6 +751,23 @@ def fuse_weights(
|
|
| 597 |
skipped_count += 1
|
| 598 |
continue
|
| 599 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 600 |
# Blend: W_final = alpha * source + (1-alpha) * target
|
| 601 |
fused_w = alpha * source_w.to(target_w.device) + (1 - alpha) * target_w
|
| 602 |
target_state[target_key] = fused_w
|
|
@@ -618,9 +789,15 @@ def fuse_weights(
|
|
| 618 |
# Timeout: 20 min for weight fusion
|
| 619 |
tracker.check_timeout(timeout_seconds=1200)
|
| 620 |
|
| 621 |
-
# Load fused weights
|
| 622 |
-
|
| 623 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 624 |
tracker.done()
|
| 625 |
sys.stdout.flush()
|
| 626 |
|
|
|
|
| 360 |
sys.stdout.flush()
|
| 361 |
|
| 362 |
# --- FAST PATH: same architecture (same layer count) ---
|
| 363 |
+
# Both models have the same number of transformer layers
|
| 364 |
+
# Match layers 1:1 but CHECK if neurons correspond
|
| 365 |
+
# DeepSeek: same training base → neurons aligned → identity Q (fast)
|
| 366 |
+
# MiMo: different training → neurons scrambled → need Sinkhorn permutation
|
| 367 |
if n_source == n_target:
|
| 368 |
+
print("[transport] Same layer count -- using direct 1:1 layer matching")
|
|
|
|
| 369 |
sys.stdout.flush()
|
| 370 |
Q_matrices = {}
|
| 371 |
+
permutations = {} # layer_pair -> permutation array (neuron reordering)
|
| 372 |
P = np.eye(n_source) / n_source # Identity coupling
|
| 373 |
tracker.set_total(n_source)
|
| 374 |
|
| 375 |
+
# Check first layer to decide: are neurons aligned or scrambled?
|
| 376 |
+
first_sl = source_layers[0]
|
| 377 |
+
first_tl = target_layers[0]
|
| 378 |
+
S0 = source_act[first_sl].numpy()
|
| 379 |
+
T0 = target_act[first_tl].numpy()
|
| 380 |
+
if S0.shape[1] == T0.shape[1]:
|
| 381 |
+
S0_norm = (S0 - S0.mean(0)) / (S0.std(0) + 1e-8)
|
| 382 |
+
T0_norm = (T0 - T0.mean(0)) / (T0.std(0) + 1e-8)
|
| 383 |
+
diag_corr = np.mean(np.sum(S0_norm * T0_norm, axis=0) / S0.shape[0])
|
| 384 |
+
neurons_aligned = diag_corr > 0.3
|
| 385 |
+
else:
|
| 386 |
+
neurons_aligned = False
|
| 387 |
+
|
| 388 |
+
if neurons_aligned:
|
| 389 |
+
print(f"[transport] Neurons ARE aligned (diag_corr={diag_corr:.3f}) — identity Q (fast)")
|
| 390 |
+
print("[transport] This should take under 1 minute...")
|
| 391 |
+
else:
|
| 392 |
+
corr_val = diag_corr if S0.shape[1] == T0.shape[1] else 0.0
|
| 393 |
+
print(f"[transport] Neurons NOT aligned (diag_corr={corr_val:.3f}) — computing permutations via Sinkhorn")
|
| 394 |
+
print("[transport] This may take 2-5 minutes...")
|
| 395 |
+
sys.stdout.flush()
|
| 396 |
+
|
| 397 |
for i, (sl, tl) in enumerate(zip(source_layers, target_layers)):
|
| 398 |
S = source_act[sl].numpy()
|
| 399 |
T = target_act[tl].numpy()
|
| 400 |
|
|
|
|
| 401 |
if S.shape[1] == T.shape[1]:
|
| 402 |
+
if neurons_aligned:
|
| 403 |
+
# Neurons already correspond (e.g. DeepSeek) — identity Q
|
| 404 |
+
Q_matrices[(sl, tl)] = np.eye(S.shape[1]) / S.shape[1]
|
| 405 |
+
else:
|
| 406 |
+
# Neurons are SCRAMBLED (e.g. MiMo) — find the permutation
|
| 407 |
+
# 1. Compute correlation matrix between source and target neurons
|
| 408 |
+
S_norm = (S - S.mean(0)) / (S.std(0) + 1e-8)
|
| 409 |
+
T_norm = (T - T.mean(0)) / (T.std(0) + 1e-8)
|
| 410 |
+
corr = S_norm.T @ T_norm / S.shape[0] # [hidden_dim, hidden_dim]
|
| 411 |
+
|
| 412 |
+
# 2. Run Sinkhorn on cost matrix to get soft transport plan
|
| 413 |
+
cost = 1.0 - corr
|
| 414 |
+
Q_soft = _sinkhorn(cost, reg=0.05, max_iter=cfg.sinkhorn_max_iter)
|
| 415 |
+
|
| 416 |
+
# 3. Extract hard permutation: for each source neuron, which target neuron?
|
| 417 |
+
perm = np.argmax(Q_soft, axis=1) # source_neuron -> target_neuron
|
| 418 |
+
|
| 419 |
+
# 4. Check for duplicate assignments (Sinkhorn should avoid this, but be safe)
|
| 420 |
+
if len(set(perm)) < len(perm) * 0.9:
|
| 421 |
+
# Too many collisions — fall back to Hungarian-style greedy
|
| 422 |
+
perm = _greedy_permutation(corr)
|
| 423 |
+
|
| 424 |
+
permutations[(sl, tl)] = perm
|
| 425 |
+
Q_matrices[(sl, tl)] = Q_soft
|
| 426 |
else:
|
| 427 |
# Different dims -- do lightweight Sinkhorn on this pair only
|
| 428 |
print(f" Layer {i}: dim mismatch ({S.shape[1]} vs {T.shape[1]}), using Sinkhorn...")
|
|
|
|
| 438 |
print(f" Matched layer {i + 1}/{n_source}: {sl} -> {tl}")
|
| 439 |
sys.stdout.flush()
|
| 440 |
|
| 441 |
+
# Timeout: 15 min (permutation takes longer than identity)
|
| 442 |
+
tracker.check_timeout(timeout_seconds=900)
|
| 443 |
|
| 444 |
+
if permutations:
|
| 445 |
+
print(f"[transport] Computed {len(permutations)} neuron permutations")
|
| 446 |
print(f"[transport] Direct matching complete: {n_source} layer pairs")
|
| 447 |
tracker.done()
|
| 448 |
sys.stdout.flush()
|
| 449 |
return {
|
| 450 |
"P": P,
|
| 451 |
"Q": Q_matrices,
|
| 452 |
+
"permutations": permutations,
|
| 453 |
"source_layers": source_layers,
|
| 454 |
"target_layers": target_layers,
|
| 455 |
}
|
|
|
|
| 526 |
return {
|
| 527 |
"P": P,
|
| 528 |
"Q": Q_matrices,
|
| 529 |
+
"permutations": {},
|
| 530 |
"source_layers": source_layers,
|
| 531 |
"target_layers": target_layers,
|
| 532 |
}
|
|
|
|
| 561 |
return T
|
| 562 |
|
| 563 |
|
| 564 |
+
def _greedy_permutation(corr_matrix: np.ndarray) -> np.ndarray:
    """
    Build a one-to-one source -> target neuron assignment greedily.

    Source neurons are visited from the strongest best-match to the weakest,
    and each one takes the highest-scoring target that is still free.

    Ties (for example an all-zero row) are broken towards the lowest index so
    the result is reproducible. NaN scores are treated as the worst possible
    match instead of winning the arg-max.

    Args:
        corr_matrix: 2D array where ``corr_matrix[s, t]`` scores how well
            source neuron ``s`` matches target neuron ``t``. It must not have
            more rows than columns, otherwise no one-to-one assignment exists.

    Returns:
        int64 array ``perm`` with one entry per row of ``corr_matrix`` such
        that ``perm[s]`` is the target assigned to source ``s``. All entries
        are distinct. An empty input yields an empty array.

    Raises:
        ValueError: if ``corr_matrix`` is not 2D or has more rows than columns.
    """
    scores = np.asarray(corr_matrix)
    if scores.ndim != 2:
        raise ValueError(
            f"corr_matrix must be 2D, got {scores.ndim} dimension(s)"
        )

    n_sources, n_targets = scores.shape
    if n_sources > n_targets:
        raise ValueError(
            f"cannot assign {n_sources} source neurons to only "
            f"{n_targets} target neurons one-to-one"
        )

    perm = np.full(n_sources, -1, dtype=np.int64)
    if n_sources == 0:
        # np.max over a zero-size axis raises, so answer the empty case here.
        return perm

    # NaN would otherwise be returned by argmax as the "best" match.
    scores = np.where(np.isnan(scores), -np.inf, scores)

    # Strongest sources pick first; a stable sort keeps ties in index order.
    visit_order = np.argsort(-scores.max(axis=1), kind="stable")

    is_free = np.ones(n_targets, dtype=bool)
    for src in visit_order:
        # Restrict the arg-max to free targets; argmax returns the first
        # maximum, so ties resolve to the lowest free target index.
        free_targets = np.flatnonzero(is_free)
        tgt = free_targets[np.argmax(scores[src, free_targets])]
        perm[src] = tgt
        is_free[tgt] = False

    return perm
|
| 595 |
+
|
| 596 |
+
|
| 597 |
+
def _apply_permutation(source_w: torch.Tensor, perm: np.ndarray, key: str) -> torch.Tensor:
    """
    Reorder a source weight tensor into the target model's neuron ordering.

    ``perm`` is a source -> target map: ``perm[s] = t`` means source neuron
    ``s`` corresponds to target neuron ``t``. Aligning the source therefore
    means moving neuron ``s`` to position ``perm[s]``, i.e.
    ``result[perm[s]] = source_w[s]``. This is a gather with the *inverse*
    of ``perm``; indexing with ``perm`` itself would apply the opposite
    reordering.

    If ``perm`` is not a true permutation (several sources mapped to the
    same target, or entries out of range), it is first completed to a
    bijection so that no neuron is duplicated or dropped.

    Which axis is reordered depends on the weight:
      - 1D tensors (norm weights, biases): the only axis.
      - Keys containing ``o_proj`` / ``down_proj``: rows (dim 0).
      - Keys containing ``q_proj`` / ``k_proj`` / ``v_proj`` / ``gate_proj``
        / ``up_proj``: columns (dim 1).
      - Any other 2D tensor: columns if their count matches, else rows.

    Args:
        source_w: weight tensor from the source model.
        perm: 1D integer array, source neuron index -> target neuron index.
        key: state-dict key of the weight, used to pick the axis.

    Returns:
        A reordered copy of ``source_w``, or ``source_w`` itself (unchanged)
        when the chosen axis length does not equal ``len(perm)`` or the
        tensor is not 1D/2D.
    """
    n_neurons = len(perm)
    row_keys = ("o_proj", "down_proj")
    col_keys = ("q_proj", "k_proj", "v_proj", "gate_proj", "up_proj")

    axis = None
    if source_w.dim() == 1:
        if source_w.shape[0] == n_neurons:
            axis = 0
    elif source_w.dim() == 2:
        out_features, in_features = source_w.shape
        if any(name in key for name in row_keys):
            # Output projections: neurons live on the rows.
            if out_features == n_neurons:
                axis = 0
        elif any(name in key for name in col_keys):
            # Input projections: neurons live on the columns.
            if in_features == n_neurons:
                axis = 1
        elif in_features == n_neurons:
            # Unknown 2D weight: prefer columns, then rows.
            axis = 1
        elif out_features == n_neurons:
            axis = 0

    if axis is None:
        # Can't permute — return unchanged
        return source_w

    gather_index = torch.from_numpy(_inverse_gather_index(perm)).to(source_w.device)
    return source_w.index_select(axis, gather_index)


def _inverse_gather_index(perm: np.ndarray) -> np.ndarray:
    """Turn a source -> target map into a gather index (target slot -> source neuron).

    For each target slot the first source that claims it wins. Sources that
    lose a collision, or whose target is out of range, fill the slots nobody
    claimed in ascending order, so the result is always a permutation of
    ``range(len(perm))``.
    """
    mapping = np.asarray(perm, dtype=np.int64).ravel()
    n = mapping.shape[0]

    in_range = (mapping >= 0) & (mapping < n)
    candidate_sources = np.flatnonzero(in_range)
    # return_index gives the first occurrence of each distinct target.
    targets, first_hit = np.unique(mapping[in_range], return_index=True)
    winners = candidate_sources[first_hit]

    gather = np.full(n, -1, dtype=np.int64)
    gather[targets] = winners

    is_placed = np.zeros(n, dtype=bool)
    is_placed[winners] = True
    # Number of empty slots always equals the number of unplaced sources.
    gather[gather < 0] = np.flatnonzero(~is_placed)
    return gather
|
| 643 |
+
|
| 644 |
+
|
| 645 |
def fuse_weights(
|
| 646 |
source_state: dict,
|
| 647 |
target_model: AutoModelForCausalLM,
|
|
|
|
| 692 |
target_state = target_model.state_dict()
|
| 693 |
P = transport_plans["P"]
|
| 694 |
Q = transport_plans["Q"]
|
| 695 |
+
permutations = transport_plans.get("permutations", {})
|
| 696 |
+
|
| 697 |
+
# Build layer-index -> permutation lookup
|
| 698 |
+
# permutations keys are (source_layer_name, target_layer_name) tuples
|
| 699 |
+
# We need to map weight keys like "model.layers.5.self_attn.q_proj.weight"
|
| 700 |
+
# to the permutation for layer 5
|
| 701 |
+
layer_perms = {}
|
| 702 |
+
for (sl, tl), perm in permutations.items():
|
| 703 |
+
# Extract layer index from target layer name (e.g. "model.layers.5.mlp" -> 5)
|
| 704 |
+
parts = tl.split(".")
|
| 705 |
+
for j, part in enumerate(parts):
|
| 706 |
+
if part == "layers" and j + 1 < len(parts):
|
| 707 |
+
try:
|
| 708 |
+
layer_idx = int(parts[j + 1])
|
| 709 |
+
layer_perms[layer_idx] = perm
|
| 710 |
+
except ValueError:
|
| 711 |
+
pass
|
| 712 |
+
break
|
| 713 |
+
|
| 714 |
+
if permutations:
|
| 715 |
+
print(f"[transport] Will apply neuron permutations to {len(layer_perms)} layers before blending")
|
| 716 |
+
else:
|
| 717 |
+
print("[transport] No neuron permutations needed (neurons already aligned)")
|
| 718 |
|
| 719 |
fused_count = 0
|
| 720 |
skipped_count = 0
|
| 721 |
+
permuted_count = 0
|
| 722 |
total_params = len(target_state)
|
| 723 |
tracker.set_total(total_params)
|
| 724 |
|
|
|
|
| 751 |
skipped_count += 1
|
| 752 |
continue
|
| 753 |
|
| 754 |
+
# --- NEURON PERMUTATION: rearrange source neurons to match target ---
|
| 755 |
+
# This is what makes MiMo merge work — without this, it's like
|
| 756 |
+
# dumping one filing cabinet into another without matching folders
|
| 757 |
+
if layer_perms:
|
| 758 |
+
# Extract layer index from this weight's key
|
| 759 |
+
key_parts = target_key.split(".")
|
| 760 |
+
for j, part in enumerate(key_parts):
|
| 761 |
+
if part == "layers" and j + 1 < len(key_parts):
|
| 762 |
+
try:
|
| 763 |
+
lidx = int(key_parts[j + 1])
|
| 764 |
+
if lidx in layer_perms:
|
| 765 |
+
source_w = _apply_permutation(source_w, layer_perms[lidx], target_key)
|
| 766 |
+
permuted_count += 1
|
| 767 |
+
except ValueError:
|
| 768 |
+
pass
|
| 769 |
+
break
|
| 770 |
+
|
| 771 |
# Blend: W_final = alpha * source + (1-alpha) * target
|
| 772 |
fused_w = alpha * source_w.to(target_w.device) + (1 - alpha) * target_w
|
| 773 |
target_state[target_key] = fused_w
|
|
|
|
| 789 |
# Timeout: 20 min for weight fusion
|
| 790 |
tracker.check_timeout(timeout_seconds=1200)
|
| 791 |
|
| 792 |
+
# Load fused weights (strict=False: vision encoder may have bitsandbytes quant keys
|
| 793 |
+
# that don't match the original key names — we never modify vision weights anyway)
|
| 794 |
+
missing, unexpected = target_model.load_state_dict(target_state, strict=False)
|
| 795 |
+
if missing:
|
| 796 |
+
print(f"[transport] NOTE: {len(missing)} missing keys (likely quantized vision params — safe to ignore)")
|
| 797 |
+
if unexpected:
|
| 798 |
+
print(f"[transport] NOTE: {len(unexpected)} unexpected keys (safe to ignore)")
|
| 799 |
+
perm_msg = f", permuted {permuted_count}" if permuted_count else ""
|
| 800 |
+
print(f"[transport] Fused {fused_count} params, skipped {skipped_count}{perm_msg}")
|
| 801 |
tracker.done()
|
| 802 |
sys.stdout.flush()
|
| 803 |
|
hugging/td_start.td
CHANGED
|
@@ -51,7 +51,7 @@ merge "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B" into base using transport strength
|
|
| 51 |
# Medium risk: same layer count (36) and hidden_dim (4096)
|
| 52 |
# MTP heads get dropped automatically (no Qwen3 equivalent)
|
| 53 |
# Embeddings skipped (28% vocab overlap too low)
|
| 54 |
-
merge "XiaomiMiMo/MiMo-7B-RL" into base using transport strength 0.
|
| 55 |
|
| 56 |
# --- Step 3: Heal any merge damage ---
|
| 57 |
# QLoRA fine-tune to smooth out rough edges from the merge
|
|
|
|
| 51 |
# Medium risk: same layer count (36) and hidden_dim (4096)
|
| 52 |
# MTP heads get dropped automatically (no Qwen3 equivalent)
|
| 53 |
# Embeddings skipped (28% vocab overlap too low)
|
| 54 |
+
merge "XiaomiMiMo/MiMo-7B-RL" into base using transport strength 0.15
|
| 55 |
|
| 56 |
# --- Step 3: Heal any merge damage ---
|
| 57 |
# QLoRA fine-tune to smooth out rough edges from the merge
|