td-builder commited on
Commit
91670c7
·
verified ·
1 Parent(s): 2293db3

Upload 141 files

Browse files
hugging/patch_gpu.py ADDED
@@ -0,0 +1,1039 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ GPU Patch Script — Apply neuron permutation fix + lower MiMo alpha.
3
+ Run this ON THE GPU after cd /workspace/td_toolkit/hugging:
4
+ python3 patch_gpu.py
5
+
6
+ What it does:
7
+ 1. Adds neuron permutation to transport.py fast path
8
+ 2. Adds _greedy_permutation() and _apply_permutation() helpers
9
+ 3. Updates fuse_weights() to apply permutations before blending
10
+ 4. Lowers MiMo alpha from 0.4 to 0.15 in config.py
11
+ 5. Lowers MiMo strength from 0.4 to 0.15 in td_start.td
12
+ 6. Adds torch import fix to heal.py (Bug #41)
13
+ """
14
+
15
+ import os
16
+
17
+ def patch_file(filepath, old, new):
18
+ """Replace old text with new text in a file."""
19
+ with open(filepath, 'r') as f:
20
+ content = f.read()
21
+ if old not in content:
22
+ print(f" WARNING: patch target not found in {filepath}")
23
+ print(f" Looking for: {old[:80]}...")
24
+ return False
25
+ content = content.replace(old, new)
26
+ with open(filepath, 'w') as f:
27
+ f.write(content)
28
+ print(f" PATCHED: {filepath}")
29
+ return True
30
+
31
+
32
+ def main():
33
+ print("=" * 60)
34
+ print("TD GPU Patch — Neuron Permutation Fix")
35
+ print("=" * 60)
36
+
37
+ # ================================================================
38
+ # PATCH 1: config.py — Lower MiMo alpha
39
+ # ================================================================
40
+ print("\n[1/4] Patching config.py (MiMo alpha 0.4 → 0.15)...")
41
+ patch_file(
42
+ "td_fuse/config.py",
43
+ 'merge_alpha=0.4,',
44
+ 'merge_alpha=0.15,',
45
+ )
46
+
47
+ # ================================================================
48
+ # PATCH 2: td_start.td — Lower MiMo strength
49
+ # ================================================================
50
+ print("\n[2/4] Patching td_start.td (strength 0.4 → 0.15)...")
51
+ patch_file(
52
+ "td_start.td",
53
+ 'strength 0.4',
54
+ 'strength 0.15',
55
+ )
56
+
57
+ # ================================================================
58
+ # PATCH 3: heal.py — Add missing torch import (Bug #41)
59
+ # ================================================================
60
+ print("\n[3/4] Patching heal.py (torch import fix)...")
61
+ # Check if already fixed
62
+ with open("td_fuse/heal.py", 'r') as f:
63
+ heal_content = f.read()
64
+ if "def apply_qlora_standard" in heal_content:
65
+ # Find the function and check if torch import exists after it
66
+ idx = heal_content.find("def apply_qlora_standard")
67
+ next_lines = heal_content[idx:idx+500]
68
+ if "import torch" not in next_lines[:200]:
69
+ # Add import torch after the function's docstring/imports
70
+ patch_file(
71
+ "td_fuse/heal.py",
72
+ "from peft import get_peft_model, LoraConfig, TaskType\n",
73
+ "from peft import get_peft_model, LoraConfig, TaskType\n import torch\n",
74
+ )
75
+ else:
76
+ print(" Already patched (torch import exists)")
77
+ else:
78
+ print(" WARNING: apply_qlora_standard not found in heal.py")
79
+
80
+ # ================================================================
81
+ # PATCH 4: transport.py — Full rewrite with neuron permutation
82
+ # ================================================================
83
+ print("\n[4/4] Rewriting transport.py with neuron permutation...")
84
+ write_transport_py()
85
+ print(" WROTE: td_fuse/transport.py")
86
+
87
+ print("\n" + "=" * 60)
88
+ print("ALL PATCHES APPLIED!")
89
+ print("=" * 60)
90
+ print("\nWhat changed:")
91
+ print(" • MiMo merge alpha: 0.4 → 0.15 (gentler blend)")
92
+ print(" • Neuron permutation: MiMo's neurons get reorganised to match Qwen3")
93
+ print(" • heal.py: torch import fix (Bug #41)")
94
+ print("\nNow run the pipeline:")
95
+ print(" export PYTHONPATH=$(pwd)")
96
+ print(" python3 -m td_lang run td_start.td")
97
+
98
+
99
+ def write_transport_py():
100
+ """Write the complete updated transport.py with neuron permutation."""
101
+ code = '''\
102
+ """
103
+ Transport and Merge Wrapper — interfaces with official T&M code.
104
+
105
+ This wraps the official repo at:
106
+ github.com/chenhangcuisg-code/Cross-Architecture-Merging-for-Large-Language-Models/
107
+
108
+ We use THEIR code for:
109
+ - Correlation distance computation (corr_distance_matrix)
110
+ - Streaming Sinkhorn (sinkhorn_uniform_streaming)
111
+ - Transport plan computation (compute_P, compute_Q_and_layer_costs)
112
+ - Activation reconstruction (reconstruct_X)
113
+
114
+ We add:
115
+ - Qwen3 thinking mode protection
116
+ - MiMo MTP head handling
117
+ - Falcon SSM component handling
118
+ - Neuron permutation for scrambled models (MiMo)
119
+ - Sequential merge protection (MagMax + orthogonal projection)
120
+ - Progress reporting every 5 minutes
121
+ - Timeouts to prevent infinite hangs
122
+
123
+ Findings: #01, #07, #24
124
+ """
125
+
126
+ import sys
127
+ import time
128
+ import torch
129
+ import numpy as np
130
+ from pathlib import Path
131
+ from typing import Optional
132
+ from transformers import AutoModelForCausalLM, AutoTokenizer
133
+ from datasets import load_dataset
134
+
135
+ from .config import MergeConfig, ModelConfig, TARGET
136
+
137
+
138
+ # ============================================================================
139
+ # PROGRESS TRACKER — prints status every 5 minutes so you know it's alive
140
+ # ============================================================================
141
+
142
+ class ProgressTracker:
143
+ """Prints a heartbeat every interval_seconds so you know it's not stuck."""
144
+
145
+ def __init__(self, task_name: str, interval_seconds: int = 300):
146
+ self.task_name = task_name
147
+ self.interval = interval_seconds
148
+ self.start_time = time.time()
149
+ self.last_report = self.start_time
150
+ self.step = 0
151
+ self.total_steps = 0
152
+ print(f"\\n[{task_name}] Started at {time.strftime(\'%H:%M:%S\')}")
153
+
154
+ def set_total(self, total: int):
155
+ self.total_steps = total
156
+
157
+ def tick(self, step_name: str = ""):
158
+ """Call this inside loops. Prints progress if 5 min have passed."""
159
+ self.step += 1
160
+ now = time.time()
161
+ elapsed = now - self.start_time
162
+ since_last = now - self.last_report
163
+
164
+ if since_last >= self.interval:
165
+ pct = f"{self.step}/{self.total_steps} ({100*self.step/self.total_steps:.0f}%)" if self.total_steps else f"step {self.step}"
166
+ eta = ""
167
+ if self.total_steps and self.step > 0:
168
+ rate = elapsed / self.step
169
+ remaining = (self.total_steps - self.step) * rate
170
+ eta = f", ETA {remaining/60:.1f} min"
171
+ print(f"[{self.task_name}] HEARTBEAT — {pct}, elapsed {elapsed/60:.1f} min{eta} | {step_name}")
172
+ sys.stdout.flush()
173
+ self.last_report = now
174
+
175
+ def done(self):
176
+ elapsed = time.time() - self.start_time
177
+ print(f"[{self.task_name}] Completed in {elapsed/60:.1f} min ({elapsed:.0f}s)")
178
+ sys.stdout.flush()
179
+
180
+ def check_timeout(self, timeout_seconds: int = 3600):
181
+ """Raise if we've been running longer than timeout_seconds."""
182
+ elapsed = time.time() - self.start_time
183
+ if elapsed > timeout_seconds:
184
+ raise TimeoutError(
185
+ f"[{self.task_name}] TIMEOUT after {elapsed/60:.1f} min "
186
+ f"(limit: {timeout_seconds/60:.0f} min). Something is wrong."
187
+ )
188
+
189
+
190
+ def setup_tm_repo(cfg: MergeConfig):
191
+ """Add official T&M repo to Python path so we can import their code."""
192
+ repo_path = Path(cfg.tm_repo_path)
193
+ core_path = repo_path / "core"
194
+
195
+ if not core_path.exists():
196
+ raise FileNotFoundError(
197
+ f"Official T&M repo not found at {repo_path}\\n"
198
+ f"Please clone it:\\n"
199
+ f" git clone https://github.com/chenhangcuisg-code/"
200
+ f"Cross-Architecture-Merging-for-Large-Language-Models.git"
201
+ )
202
+
203
+ # Add to path so we can import hot_transport etc.
204
+ if str(core_path) not in sys.path:
205
+ sys.path.insert(0, str(core_path))
206
+ print(f"[transport] Added T&M core to path: {core_path}")
207
+
208
+
209
+ def load_calibration_data(cfg: MergeConfig, tokenizer: AutoTokenizer) -> list:
210
+ """
211
+ Load calibration data for activation extraction.
212
+
213
+ Mix: 600 Pile general + 300 Pile ArXiv + 600 neuralmagic Q&A = 1500 samples
214
+ Each sample truncated to cfg.calibration_seq_len tokens.
215
+
216
+ Findings: #08
217
+ """
218
+ tracker = ProgressTracker("calibration-data", interval_seconds=120)
219
+ print(f"[transport] Loading calibration data ({cfg.calibration_samples} samples)...")
220
+
221
+ samples = []
222
+
223
+ # --- Pile: general text (600 samples) ---
224
+ try:
225
+ pile = load_dataset(
226
+ cfg.calibration_dataset_pile,
227
+ split="validation",
228
+ streaming=True,
229
+ trust_remote_code=True,
230
+ )
231
+ count = 0
232
+ for example in pile:
233
+ if count >= 600:
234
+ break
235
+ text = example.get("text", "")
236
+ if len(text) > 100: # Skip very short texts
237
+ tokens = tokenizer(
238
+ text,
239
+ truncation=True,
240
+ max_length=cfg.calibration_seq_len,
241
+ return_tensors="pt",
242
+ )
243
+ samples.append(tokens)
244
+ count += 1
245
+ if count % 100 == 0:
246
+ print(f" Pile: {count}/600 samples loaded...")
247
+ sys.stdout.flush()
248
+ print(f" Pile general: {count} samples")
249
+ except Exception as e:
250
+ print(f" WARNING: Pile failed: {e}")
251
+ print(f" Falling back to neuralmagic only")
252
+
253
+ # --- neuralmagic: Q&A calibration (up to remaining) ---
254
+ remaining = cfg.calibration_samples - len(samples)
255
+ if remaining > 0:
256
+ try:
257
+ nm = load_dataset(
258
+ cfg.calibration_dataset_nm,
259
+ split="train",
260
+ trust_remote_code=True,
261
+ )
262
+ count = 0
263
+ for example in nm:
264
+ if count >= remaining:
265
+ break
266
+ text = example.get("text", example.get("content", ""))
267
+ if len(str(text)) > 50:
268
+ tokens = tokenizer(
269
+ str(text),
270
+ truncation=True,
271
+ max_length=cfg.calibration_seq_len,
272
+ return_tensors="pt",
273
+ )
274
+ samples.append(tokens)
275
+ count += 1
276
+ if count % 100 == 0:
277
+ print(f" neuralmagic: {count}/{remaining} samples loaded...")
278
+ sys.stdout.flush()
279
+ print(f" neuralmagic: {count} samples")
280
+ except Exception as e:
281
+ print(f" WARNING: neuralmagic failed: {e}")
282
+
283
+ tracker.done()
284
+ print(f"[transport] Total calibration samples: {len(samples)}")
285
+ sys.stdout.flush()
286
+ return samples
287
+
288
+
289
+ def extract_activations(
290
+ model: AutoModelForCausalLM,
291
+ calibration_data: list,
292
+ device: str = "cuda",
293
+ ) -> dict:
294
+ """
295
+ Extract intermediate activations from each layer of a model.
296
+
297
+ Runs calibration data through the model with hooks on each layer
298
+ to capture activation patterns. These activations are what the
299
+ optimal transport algorithm aligns between source and target.
300
+
301
+ Returns:
302
+ Dict mapping layer_name -> activation tensor [num_samples, hidden_dim]
303
+ """
304
+ tracker = ProgressTracker("extract-activations", interval_seconds=300)
305
+ tracker.set_total(len(calibration_data))
306
+ print(f"[transport] Extracting activations from {len(calibration_data)} samples...")
307
+ sys.stdout.flush()
308
+
309
+ activations = {}
310
+ hooks = []
311
+
312
+ # Register hooks on each transformer layer
313
+ for name, module in model.named_modules():
314
+ if hasattr(module, "self_attn") or name.endswith(".mlp"):
315
+ # Hook to capture output activations
316
+ def make_hook(layer_name):
317
+ def hook_fn(module, input, output):
318
+ # Handle tuple outputs (some layers return tuples)
319
+ if isinstance(output, tuple):
320
+ act = output[0]
321
+ else:
322
+ act = output
323
+ if layer_name not in activations:
324
+ activations[layer_name] = []
325
+ # Mean pool over sequence length -> [hidden_dim]
326
+ activations[layer_name].append(
327
+ act.detach().float().mean(dim=1).cpu()
328
+ )
329
+ return hook_fn
330
+
331
+ h = module.register_forward_hook(make_hook(name))
332
+ hooks.append(h)
333
+
334
+ # Forward pass on calibration data
335
+ model.eval()
336
+ with torch.no_grad():
337
+ for i, tokens in enumerate(calibration_data):
338
+ inputs = {k: v.to(device) for k, v in tokens.items()}
339
+ try:
340
+ model(**inputs)
341
+ except Exception as e:
342
+ print(f" WARNING: Sample {i} failed: {e}")
343
+ continue
344
+
345
+ tracker.tick(f"sample {i+1}")
346
+
347
+ if (i + 1) % 100 == 0:
348
+ print(f" Processed {i + 1}/{len(calibration_data)} samples")
349
+ sys.stdout.flush()
350
+
351
+ # Timeout: 30 min for activation extraction
352
+ tracker.check_timeout(timeout_seconds=1800)
353
+
354
+ # Remove hooks
355
+ for h in hooks:
356
+ h.remove()
357
+
358
+ # Stack activations: [num_samples, hidden_dim]
359
+ layer_count = 0
360
+ for key in activations:
361
+ activations[key] = torch.cat(activations[key], dim=0)
362
+ layer_count += 1
363
+
364
+ print(f" Extracted {layer_count} layers, shapes: {activations[list(activations.keys())[0]].shape if activations else \'empty\'}")
365
+ tracker.done()
366
+ sys.stdout.flush()
367
+
368
+ return activations
369
+
370
+
371
+ def compute_transport_plans(
372
+ source_activations: dict,
373
+ target_activations: dict,
374
+ cfg: MergeConfig,
375
+ ) -> dict:
376
+ """
377
+ Compute optimal transport plans between source and target activations.
378
+
379
+ This is where the magic happens. We use the official T&M code's:
380
+ - corr_distance_matrix: correlation distance between activation vectors
381
+ - sinkhorn_uniform_streaming: memory-efficient Sinkhorn solver
382
+ - compute_P: layer-level coupling (which source layers -> which target layers)
383
+ - compute_Q_and_layer_costs: neuron-level coupling within each layer pair
384
+
385
+ Returns:
386
+ Dict with 'P' (layer coupling) and 'Q' (per-layer neuron coupling) matrices
387
+ """
388
+ print("[transport] Computing transport plans...")
389
+ sys.stdout.flush()
390
+
391
+ try:
392
+ # Try importing official T&M code
393
+ from hot_transport import (
394
+ corr_distance_matrix,
395
+ sinkhorn_uniform_streaming,
396
+ compute_P,
397
+ compute_Q_and_layer_costs,
398
+ )
399
+ print("[transport] Using official T&M implementation")
400
+ return _compute_plans_official(
401
+ source_activations, target_activations, cfg,
402
+ corr_distance_matrix, sinkhorn_uniform_streaming,
403
+ compute_P, compute_Q_and_layer_costs,
404
+ )
405
+ except ImportError:
406
+ print("[transport] Official T&M code not available, using fallback")
407
+ return _compute_plans_fallback(
408
+ source_activations, target_activations, cfg
409
+ )
410
+
411
+
412
+ def _compute_plans_official(
413
+ source_act, target_act, cfg,
414
+ corr_distance_matrix, sinkhorn_uniform_streaming,
415
+ compute_P, compute_Q_and_layer_costs,
416
+ ) -> dict:
417
+ """Use the official T&M code to compute transport plans."""
418
+
419
+ # Get matching layer pairs
420
+ source_layers = sorted(source_act.keys())
421
+ target_layers = sorted(target_act.keys())
422
+
423
+ # Compute Q matrices (neuron-level) and layer costs
424
+ Q_matrices, layer_costs = compute_Q_and_layer_costs(
425
+ source_act, target_act,
426
+ source_layers, target_layers,
427
+ )
428
+
429
+ # Compute P matrix (layer-level coupling)
430
+ P = compute_P(layer_costs)
431
+
432
+ return {
433
+ "P": P,
434
+ "Q": Q_matrices,
435
+ "source_layers": source_layers,
436
+ "target_layers": target_layers,
437
+ }
438
+
439
+
440
+ def _compute_plans_fallback(
441
+ source_act: dict,
442
+ target_act: dict,
443
+ cfg: MergeConfig,
444
+ ) -> dict:
445
+ """
446
+ Fallback transport plan computation when official code isn't available.
447
+
448
+ Smart routing:
449
+ - Same-architecture models (same layer count): direct 1:1 layer matching
450
+ Check if neurons are aligned (DeepSeek) or scrambled (MiMo)
451
+ - Cross-architecture: sparse OT (only top-3 source layers per target)
452
+ """
453
+ tracker = ProgressTracker("transport-plans", interval_seconds=300)
454
+
455
+ source_layers = sorted(source_act.keys())
456
+ target_layers = sorted(target_act.keys())
457
+
458
+ n_source = len(source_layers)
459
+ n_target = len(target_layers)
460
+
461
+ print(f"[transport] Source layers: {n_source}, Target layers: {n_target}")
462
+ sys.stdout.flush()
463
+
464
+ # --- FAST PATH: same architecture (same layer count) ---
465
+ # Both models have the same number of transformer layers
466
+ # Match layers 1:1 but CHECK if neurons correspond
467
+ # DeepSeek: same training base -> neurons aligned -> identity Q (fast)
468
+ # MiMo: different training -> neurons scrambled -> need Sinkhorn permutation
469
+ if n_source == n_target:
470
+ print("[transport] Same layer count -- using direct 1:1 layer matching")
471
+ sys.stdout.flush()
472
+ Q_matrices = {}
473
+ permutations = {} # layer_pair -> permutation array (neuron reordering)
474
+ P = np.eye(n_source) / n_source # Identity coupling
475
+ tracker.set_total(n_source)
476
+
477
+ # Check first layer to decide: are neurons aligned or scrambled?
478
+ first_sl = source_layers[0]
479
+ first_tl = target_layers[0]
480
+ S0 = source_act[first_sl].numpy()
481
+ T0 = target_act[first_tl].numpy()
482
+ if S0.shape[1] == T0.shape[1]:
483
+ S0_norm = (S0 - S0.mean(0)) / (S0.std(0) + 1e-8)
484
+ T0_norm = (T0 - T0.mean(0)) / (T0.std(0) + 1e-8)
485
+ diag_corr = np.mean(np.sum(S0_norm * T0_norm, axis=0) / S0.shape[0])
486
+ neurons_aligned = diag_corr > 0.3
487
+ else:
488
+ neurons_aligned = False
489
+
490
+ if neurons_aligned:
491
+ print(f"[transport] Neurons ARE aligned (diag_corr={diag_corr:.3f}) -- identity Q (fast)")
492
+ print("[transport] This should take under 1 minute...")
493
+ else:
494
+ corr_val = diag_corr if S0.shape[1] == T0.shape[1] else 0.0
495
+ print(f"[transport] Neurons NOT aligned (diag_corr={corr_val:.3f}) -- computing permutations via Sinkhorn")
496
+ print("[transport] This may take 2-5 minutes...")
497
+ sys.stdout.flush()
498
+
499
+ for i, (sl, tl) in enumerate(zip(source_layers, target_layers)):
500
+ S = source_act[sl].numpy()
501
+ T = target_act[tl].numpy()
502
+
503
+ if S.shape[1] == T.shape[1]:
504
+ if neurons_aligned:
505
+ # Neurons already correspond (e.g. DeepSeek) -- identity Q
506
+ Q_matrices[(sl, tl)] = np.eye(S.shape[1]) / S.shape[1]
507
+ else:
508
+ # Neurons are SCRAMBLED (e.g. MiMo) -- find the permutation
509
+ # 1. Compute correlation matrix between source and target neurons
510
+ S_norm = (S - S.mean(0)) / (S.std(0) + 1e-8)
511
+ T_norm = (T - T.mean(0)) / (T.std(0) + 1e-8)
512
+ corr = S_norm.T @ T_norm / S.shape[0] # [hidden_dim, hidden_dim]
513
+
514
+ # 2. Run Sinkhorn on cost matrix to get soft transport plan
515
+ cost = 1.0 - corr
516
+ Q_soft = _sinkhorn(cost, reg=0.05, max_iter=cfg.sinkhorn_max_iter)
517
+
518
+ # 3. Extract hard permutation: for each source neuron, which target neuron?
519
+ perm = np.argmax(Q_soft, axis=1) # source_neuron -> target_neuron
520
+
521
+ # 4. Check for duplicate assignments (Sinkhorn should avoid this, but be safe)
522
+ if len(set(perm)) < len(perm) * 0.9:
523
+ # Too many collisions -- fall back to Hungarian-style greedy
524
+ perm = _greedy_permutation(corr)
525
+
526
+ permutations[(sl, tl)] = perm
527
+ Q_matrices[(sl, tl)] = Q_soft
528
+ else:
529
+ # Different dims -- do lightweight Sinkhorn on this pair only
530
+ print(f" Layer {i}: dim mismatch ({S.shape[1]} vs {T.shape[1]}), using Sinkhorn...")
531
+ S_norm = (S - S.mean(0)) / (S.std(0) + 1e-8)
532
+ T_norm = (T - T.mean(0)) / (T.std(0) + 1e-8)
533
+ corr = S_norm.T @ T_norm / S.shape[0]
534
+ cost = 1.0 - corr
535
+ Q_matrices[(sl, tl)] = _sinkhorn(cost, reg=0.1, max_iter=50)
536
+
537
+ tracker.tick(f"{sl} -> {tl}")
538
+
539
+ if (i + 1) % 10 == 0 or i == 0:
540
+ print(f" Matched layer {i + 1}/{n_source}: {sl} -> {tl}")
541
+ sys.stdout.flush()
542
+
543
+ # Timeout: 15 min (permutation takes longer than identity)
544
+ tracker.check_timeout(timeout_seconds=900)
545
+
546
+ if permutations:
547
+ print(f"[transport] Computed {len(permutations)} neuron permutations")
548
+ print(f"[transport] Direct matching complete: {n_source} layer pairs")
549
+ tracker.done()
550
+ sys.stdout.flush()
551
+ return {
552
+ "P": P,
553
+ "Q": Q_matrices,
554
+ "permutations": permutations,
555
+ "source_layers": source_layers,
556
+ "target_layers": target_layers,
557
+ }
558
+
559
+ # --- CROSS-ARCHITECTURE PATH: sparse OT ---
560
+ # Only compute top-3 source layers per target (not all NxN pairs)
561
+ print(f"[transport] Cross-architecture -- using sparse OT (top-3 per target)")
562
+ print(f"[transport] Estimated time: 5-15 minutes")
563
+ sys.stdout.flush()
564
+
565
+ # Step 1: Compute layer-level similarity (cheap: just mean activation correlation)
566
+ print("[transport] Step 1/3: Computing layer-level similarities...")
567
+ sys.stdout.flush()
568
+ layer_costs = np.zeros((n_source, n_target))
569
+ tracker.set_total(n_source * n_target + n_target * 3)
570
+ for i, sl in enumerate(source_layers):
571
+ for j, tl in enumerate(target_layers):
572
+ S_mean = source_act[sl].mean(0).numpy()
573
+ T_mean = target_act[tl].mean(0).numpy()
574
+ # Cosine similarity as cheap proxy
575
+ min_dim = min(len(S_mean), len(T_mean))
576
+ s = S_mean[:min_dim]
577
+ t = T_mean[:min_dim]
578
+ sim = np.dot(s, t) / (np.linalg.norm(s) * np.linalg.norm(t) + 1e-8)
579
+ layer_costs[i, j] = 1.0 - sim
580
+ tracker.tick(f"layer sim {i},{j}")
581
+
582
+ # Timeout: 30 min for cross-arch
583
+ tracker.check_timeout(timeout_seconds=1800)
584
+
585
+ print(f"[transport] Step 1/3 done: {n_source}x{n_target} similarities computed")
586
+ sys.stdout.flush()
587
+
588
+ # Step 2: For each target layer, only compute Q for top-3 most similar source layers
589
+ print("[transport] Step 2/3: Computing neuron-level transport (top-3 per target)...")
590
+ sys.stdout.flush()
591
+ Q_matrices = {}
592
+ for j, tl in enumerate(target_layers):
593
+ top3 = np.argsort(layer_costs[:, j])[:3]
594
+ for i in top3:
595
+ sl = source_layers[i]
596
+ S = source_act[sl].numpy()
597
+ T = target_act[tl].numpy()
598
+
599
+ # Lightweight Sinkhorn (50 iterations, not 100+)
600
+ min_dim = min(S.shape[1], T.shape[1])
601
+ S_sub = S[:, :min_dim]
602
+ T_sub = T[:, :min_dim]
603
+ S_norm = (S_sub - S_sub.mean(0)) / (S_sub.std(0) + 1e-8)
604
+ T_norm = (T_sub - T_sub.mean(0)) / (T_sub.std(0) + 1e-8)
605
+ corr = S_norm.T @ T_norm / S.shape[0]
606
+ cost = 1.0 - corr
607
+ Q_matrices[(sl, tl)] = _sinkhorn(cost, reg=0.1, max_iter=50)
608
+ tracker.tick(f"Q({sl},{tl})")
609
+
610
+ if (j + 1) % 5 == 0 or j == 0:
611
+ print(f" Target layer {j + 1}/{n_target}: matched to top-3 sources")
612
+ sys.stdout.flush()
613
+
614
+ # Timeout: 30 min for cross-arch
615
+ tracker.check_timeout(timeout_seconds=1800)
616
+
617
+ print(f"[transport] Step 2/3 done: {len(Q_matrices)} Q matrices computed")
618
+ sys.stdout.flush()
619
+
620
+ # Step 3: Layer coupling via Sinkhorn on layer costs
621
+ print("[transport] Step 3/3: Computing layer coupling P matrix...")
622
+ sys.stdout.flush()
623
+ P = _sinkhorn(layer_costs, reg=0.1, max_iter=50)
624
+
625
+ print(f"[transport] Sparse OT complete: {len(Q_matrices)} layer pairs computed")
626
+ tracker.done()
627
+ sys.stdout.flush()
628
+ return {
629
+ "P": P,
630
+ "Q": Q_matrices,
631
+ "permutations": {},
632
+ "source_layers": source_layers,
633
+ "target_layers": target_layers,
634
+ }
635
+
636
+
637
+ def _sinkhorn(
638
+ cost_matrix: np.ndarray,
639
+ reg: float = 0.05,
640
+ max_iter: int = 100,
641
+ ) -> np.ndarray:
642
+ """
643
+ Basic Sinkhorn-Knopp algorithm for optimal transport.
644
+
645
+ Solves: min <T, C> - reg * H(T)
646
+ where H(T) is the entropy of the transport plan.
647
+
648
+ This is the FALLBACK. The official code uses streaming Sinkhorn
649
+ which is more memory-efficient.
650
+ """
651
+ n, m = cost_matrix.shape
652
+ K = np.exp(-cost_matrix / reg)
653
+
654
+ u = np.ones(n) / n
655
+ v = np.ones(m) / m
656
+
657
+ for iteration in range(max_iter):
658
+ u = 1.0 / (K @ v + 1e-10)
659
+ v = 1.0 / (K.T @ u + 1e-10)
660
+
661
+ # Transport plan
662
+ T = np.diag(u) @ K @ np.diag(v)
663
+ return T
664
+
665
+
666
+ def _greedy_permutation(corr_matrix: np.ndarray) -> np.ndarray:
667
+ """
668
+ Greedy permutation assignment when Sinkhorn gives duplicate mappings.
669
+
670
+ For each source neuron (in order of strongest match), assign it to the
671
+ best available target neuron that hasn't been taken yet.
672
+ """
673
+ n = corr_matrix.shape[0]
674
+ perm = np.full(n, -1, dtype=np.int64)
675
+ taken = set()
676
+
677
+ # Process source neurons by strength of their best match (strongest first)
678
+ best_scores = np.max(corr_matrix, axis=1)
679
+ order = np.argsort(-best_scores)
680
+
681
+ for src in order:
682
+ # Find best available target
683
+ sorted_targets = np.argsort(-corr_matrix[src])
684
+ for tgt in sorted_targets:
685
+ if tgt not in taken:
686
+ perm[src] = tgt
687
+ taken.add(tgt)
688
+ break
689
+
690
+ # Safety: any unassigned source neurons get remaining targets
691
+ remaining = set(range(n)) - taken
692
+ for src in range(n):
693
+ if perm[src] == -1:
694
+ perm[src] = remaining.pop()
695
+
696
+ return perm
697
+
698
+
699
+ def _apply_permutation(source_w: torch.Tensor, perm: np.ndarray, key: str) -> torch.Tensor:
700
+ """
701
+ Apply neuron permutation to a source weight tensor before blending.
702
+
703
+ The permutation rearranges MiMo's neurons to match Qwen3's ordering.
704
+ Think of it like reorganising filing cabinets: same files, different order.
705
+
706
+ Which dimension to permute depends on the weight type:
707
+ - Input projections (q_proj, k_proj, v_proj, gate_proj, up_proj):
708
+ shape [out_features, in_features] -> permute columns (dim 1)
709
+ because input neurons need reordering
710
+ - Output projections (o_proj, down_proj):
711
+ shape [out_features, in_features] -> permute rows (dim 0)
712
+ because output neurons need reordering
713
+ - 1D weights (layer_norm, bias):
714
+ permute directly
715
+ """
716
+ perm_tensor = torch.from_numpy(perm).long()
717
+
718
+ if source_w.dim() == 1:
719
+ # 1D: layer norms, biases
720
+ if len(perm_tensor) == source_w.shape[0]:
721
+ return source_w[perm_tensor]
722
+ return source_w
723
+
724
+ if source_w.dim() == 2:
725
+ # 2D: linear layers
726
+ out_features, in_features = source_w.shape
727
+
728
+ # Output projections: neurons on dim 0 (rows)
729
+ if any(proj in key for proj in ["o_proj", "down_proj"]):
730
+ if len(perm_tensor) == out_features:
731
+ return source_w[perm_tensor, :]
732
+ # Input projections: neurons on dim 1 (columns)
733
+ elif any(proj in key for proj in ["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj"]):
734
+ if len(perm_tensor) == in_features:
735
+ return source_w[:, perm_tensor]
736
+ # Other 2D weights: try columns first (more common)
737
+ else:
738
+ if len(perm_tensor) == in_features:
739
+ return source_w[:, perm_tensor]
740
+ elif len(perm_tensor) == out_features:
741
+ return source_w[perm_tensor, :]
742
+
743
+ # Can't permute -- return unchanged
744
+ return source_w
745
+
746
+
747
+ def fuse_weights(
748
+ source_state: dict,
749
+ target_model: AutoModelForCausalLM,
750
+ transport_plans: dict,
751
+ source_config: ModelConfig,
752
+ cfg: MergeConfig,
753
+ target_activations: dict = None,
754
+ ) -> AutoModelForCausalLM:
755
+ """
756
+ Fuse source model weights into target model using transport plans.
757
+
758
+ For each layer pair with significant coupling (P > threshold):
759
+ 1. Get the Q matrix (neuron-level correspondence)
760
+ 2. Transport source weights into target neuron basis: W_fused = Q @ W_source
761
+ 3. Blend with target: W_final = alpha * W_fused + (1-alpha) * W_target
762
+
763
+ Args:
764
+ source_state: Source model state dict (can be on CPU -- will be moved per-param)
765
+ target_model: Target model (on GPU)
766
+ transport_plans: Transport plan matrices from compute_transport_plans
767
+ source_config: Source model config
768
+ cfg: Merge configuration
769
+
770
+ Special handling per model:
771
+ - DeepSeek: Direct merge (same architecture)
772
+ - MiMo: Skip MTP heads, skip embeddings, apply neuron permutation
773
+ - Llama: Layer mapping (32->36), skip embeddings, drop QKV bias
774
+ - Falcon: Skip Mamba components, skip embeddings
775
+
776
+ Returns:
777
+ Target model with fused weights
778
+ """
779
+ tracker = ProgressTracker("fuse-weights", interval_seconds=300)
780
+ print(f"\\n[transport] Fusing {source_config.name} -> target")
781
+ alpha = source_config.merge_alpha
782
+
783
+ try:
784
+ # Try official fusion code first
785
+ from generate_hot_residual import fuse_attention_only_from_hot_dir
786
+ print("[transport] Using official fusion implementation")
787
+ # TODO: Adapt official fusion to our pipeline
788
+ # For now, fall through to manual fusion
789
+ except ImportError:
790
+ pass
791
+
792
+ # --- Manual fusion using transport plans ---
793
+ # source_state is passed in (may be on CPU to save GPU memory)
794
+ target_state = target_model.state_dict()
795
+ P = transport_plans["P"]
796
+ Q = transport_plans["Q"]
797
+ permutations = transport_plans.get("permutations", {})
798
+
799
+ # Build layer-index -> permutation lookup
800
+ # permutations keys are (source_layer_name, target_layer_name) tuples
801
+ # We need to map weight keys like "model.layers.5.self_attn.q_proj.weight"
802
+ # to the permutation for layer 5
803
+ layer_perms = {}
804
+ for (sl, tl), perm in permutations.items():
805
+ # Extract layer index from target layer name (e.g. "model.layers.5.mlp" -> 5)
806
+ parts = tl.split(".")
807
+ for j, part in enumerate(parts):
808
+ if part == "layers" and j + 1 < len(parts):
809
+ try:
810
+ layer_idx = int(parts[j + 1])
811
+ layer_perms[layer_idx] = perm
812
+ except ValueError:
813
+ pass
814
+ break
815
+
816
+ if permutations:
817
+ print(f"[transport] Will apply neuron permutations to {len(layer_perms)} layers before blending")
818
+ else:
819
+ print("[transport] No neuron permutations needed (neurons already aligned)")
820
+
821
+ fused_count = 0
822
+ skipped_count = 0
823
+ permuted_count = 0
824
+ total_params = len(target_state)
825
+ tracker.set_total(total_params)
826
+
827
+ for target_key in target_state:
828
+ tracker.tick(target_key)
829
+
830
+ # Skip parameters we shouldn't merge
831
+ if _should_skip(target_key, source_config):
832
+ skipped_count += 1
833
+ continue
834
+
835
+ # Find corresponding source key
836
+ source_key = _map_key(target_key, source_config)
837
+ if source_key is None or source_key not in source_state:
838
+ skipped_count += 1
839
+ # Log first few misses to help debug key mapping issues
840
+ if skipped_count <= 5:
841
+ print(f" [skip] No source match for: {target_key} (mapped to: {source_key})")
842
+ sys.stdout.flush()
843
+ continue
844
+
845
+ target_w = target_state[target_key]
846
+ source_w = source_state[source_key]
847
+
848
+ # Handle dimension mismatches
849
+ if target_w.shape != source_w.shape:
850
+ # Use transport plan to align dimensions
851
+ source_w = _align_dimensions(source_w, target_w.shape, Q, target_key)
852
+ if source_w is None:
853
+ skipped_count += 1
854
+ continue
855
+
856
+ # --- NEURON PERMUTATION: rearrange source neurons to match target ---
857
+ # This is what makes MiMo merge work -- without this, it's like
858
+ # dumping one filing cabinet into another without matching folders
859
+ if layer_perms:
860
+ # Extract layer index from this weight's key
861
+ key_parts = target_key.split(".")
862
+ for j, part in enumerate(key_parts):
863
+ if part == "layers" and j + 1 < len(key_parts):
864
+ try:
865
+ lidx = int(key_parts[j + 1])
866
+ if lidx in layer_perms:
867
+ source_w = _apply_permutation(source_w, layer_perms[lidx], target_key)
868
+ permuted_count += 1
869
+ except ValueError:
870
+ pass
871
+ break
872
+
873
+ # Blend: W_final = alpha * source + (1-alpha) * target
874
+ fused_w = alpha * source_w.to(target_w.device) + (1 - alpha) * target_w
875
+ target_state[target_key] = fused_w
876
+ fused_count += 1
877
+
878
+ # Apply thinking mode protection (inside loop -- check each key)
879
+ if cfg.freeze_think_tokens and "embed_tokens" in target_key:
880
+ for token_id in cfg.think_token_ids:
881
+ if token_id < target_state[target_key].shape[0]:
882
+ # Restore original embedding for think tokens
883
+ orig_embed = target_model.state_dict()[target_key]
884
+ target_state[target_key][token_id] = orig_embed[token_id]
885
+ print(f"[transport] Protected think token {token_id}")
886
+
887
+ if fused_count % 50 == 0:
888
+ print(f" Fused {fused_count} params so far (skipped {skipped_count})...")
889
+ sys.stdout.flush()
890
+
891
+ # Timeout: 20 min for weight fusion
892
+ tracker.check_timeout(timeout_seconds=1200)
893
+
894
+ # Load fused weights (strict=False: vision encoder may have bitsandbytes quant keys
895
+ # that don't match the original key names -- we never modify vision weights anyway)
896
+ missing, unexpected = target_model.load_state_dict(target_state, strict=False)
897
+ if missing:
898
+ print(f"[transport] NOTE: {len(missing)} missing keys (likely quantized vision params -- safe to ignore)")
899
+ if unexpected:
900
+ print(f"[transport] NOTE: {len(unexpected)} unexpected keys (safe to ignore)")
901
+ perm_msg = f", permuted {permuted_count}" if permuted_count else ""
902
+ print(f"[transport] Fused {fused_count} params, skipped {skipped_count}{perm_msg}")
903
+ tracker.done()
904
+ sys.stdout.flush()
905
+
906
+ return target_model
907
+
908
+
909
+ def _should_skip(key: str, source_config: ModelConfig) -> bool:
910
+ """Determine if a parameter should be skipped during merge."""
911
+
912
+ # Skip vision encoder params (Qwen3-VL) -- these should never be merged
913
+ if key.startswith("visual") or key.startswith("merger") or key.startswith("model.visual") or key.startswith("model.merger"):
914
+ return True
915
+
916
+ # Always skip if source model says to skip embeddings
917
+ if source_config.skip_embeddings and ("embed_tokens" in key or "lm_head" in key):
918
+ return True
919
+
920
+ # Skip MiMo MTP heads
921
+ if "drop_mtp_heads" in source_config.special_handling and "mtp_head" in key:
922
+ return True
923
+
924
+ # Skip Falcon Mamba-specific parameters
925
+ if "drop_mamba_state_params" in source_config.special_handling:
926
+ mamba_keys = ["mamba", "A_log", "dt_proj", ".D"]
927
+ if any(mk in key for mk in mamba_keys):
928
+ return True
929
+
930
+ # Skip QKV bias for Llama (Qwen3 doesn't have it)
931
+ if "drop_qkv_bias" in source_config.special_handling and ".bias" in key:
932
+ if any(proj in key for proj in ["q_proj", "k_proj", "v_proj"]):
933
+ return True
934
+
935
+ return False
936
+
937
+
938
+ def _strip_vl_prefix(key: str) -> str:
939
+ """
940
+ Strip the 'language_model.' prefix that Qwen3-VL adds.
941
+
942
+ Qwen3-VL wraps all language params under 'model.language_model.*'
943
+ but source models (DeepSeek, MiMo, Llama, Falcon) use 'model.*' directly.
944
+
945
+ Example:
946
+ target: model.language_model.layers.0.self_attn.q_proj.weight
947
+ source: model.layers.0.self_attn.q_proj.weight
948
+ """
949
+ # model.language_model.X -> model.X
950
+ if "language_model." in key:
951
+ return key.replace("language_model.", "")
952
+ return key
953
+
954
+
955
+ def _map_key(target_key: str, source_config: ModelConfig) -> Optional[str]:
956
+ """Map a target model parameter name to the corresponding source name."""
957
+
958
+ # Step 1: Strip Qwen3-VL's language_model. prefix so we can match source keys
959
+ source_key = _strip_vl_prefix(target_key)
960
+
961
+ # For same-architecture models (DeepSeek), keys match directly after prefix strip
962
+ if source_config.architecture == "transformer" and source_config.layers == 36:
963
+ return source_key
964
+
965
+ # For Llama (32 layers -> 36 layers), map layer indices
966
+ if "layer_mapping_32_to_36" in source_config.special_handling:
967
+ if "model.layers." in source_key:
968
+ # Extract layer number
969
+ parts = source_key.split(".")
970
+ try:
971
+ layer_idx = int(parts[2])
972
+ except (IndexError, ValueError):
973
+ return source_key
974
+
975
+ # Map 36 target layers to 32 source layers (stride)
976
+ source_layer = int(layer_idx * 32 / 36)
977
+ parts[2] = str(source_layer)
978
+ return ".".join(parts)
979
+
980
+ # For MiMo (same layer count, different extras), keys mostly match
981
+ if source_config.architecture == "transformer+mtp":
982
+ if "mtp_head" in source_key:
983
+ return None # MTP heads don't exist in target
984
+ return source_key
985
+
986
+ # For Falcon hybrid, only attention and MLP keys map
987
+ if source_config.architecture == "hybrid_ssm":
988
+ if any(k in source_key for k in ["self_attn", "mlp", "layer_norm"]):
989
+ return source_key # These exist in both
990
+ return None # Mamba components don't map
991
+
992
+ return source_key
993
+
994
+
995
+ def _align_dimensions(
996
+ source_w: torch.Tensor,
997
+ target_shape: tuple,
998
+ Q_matrices: dict,
999
+ key: str,
1000
+ ) -> Optional[torch.Tensor]:
1001
+ """
1002
+ Align source weight dimensions to target shape using transport plans.
1003
+
1004
+ For small mismatches: pad or truncate.
1005
+ For large mismatches: use Q matrix to project.
1006
+ """
1007
+ if source_w.shape == target_shape:
1008
+ return source_w
1009
+
1010
+ # Simple case: different width (FFN size difference)
1011
+ if len(source_w.shape) == 2 and len(target_shape) == 2:
1012
+ s_rows, s_cols = source_w.shape
1013
+ t_rows, t_cols = target_shape
1014
+
1015
+ result = torch.zeros(target_shape, dtype=source_w.dtype)
1016
+
1017
+ # Copy what fits
1018
+ min_rows = min(s_rows, t_rows)
1019
+ min_cols = min(s_cols, t_cols)
1020
+ result[:min_rows, :min_cols] = source_w[:min_rows, :min_cols]
1021
+
1022
+ return result
1023
+
1024
+ # 1D case (biases, layer norms)
1025
+ if len(source_w.shape) == 1 and len(target_shape) == 1:
1026
+ result = torch.zeros(target_shape, dtype=source_w.dtype)
1027
+ min_len = min(source_w.shape[0], target_shape[0])
1028
+ result[:min_len] = source_w[:min_len]
1029
+ return result
1030
+
1031
+ # Can't align -- skip this parameter
1032
+ return None
1033
+ '''
1034
+ with open("td_fuse/transport.py", 'w') as f:
1035
+ f.write(code)
1036
+
1037
+
1038
+ if __name__ == "__main__":
1039
+ main()
hugging/td_fuse/config.py CHANGED
@@ -107,7 +107,7 @@ SOURCES = [
107
  skip_embeddings=True, # Must skip — vocab too different
108
  trust_remote_code=True, # Custom MTP architecture
109
  merge_risk="medium",
110
- merge_alpha=0.4, # Slightly lower preserve target
111
  special_handling=["drop_mtp_heads", "skip_embeddings"],
112
  notes=(
113
  "Xiaomi's reasoning model. Same layer count and hidden dim as Qwen3. "
 
107
  skip_embeddings=True, # Must skip — vocab too different
108
  trust_remote_code=True, # Custom MTP architecture
109
  merge_risk="medium",
110
+ merge_alpha=0.15, # LowMiMo neurons need permutation, keep target dominant
111
  special_handling=["drop_mtp_heads", "skip_embeddings"],
112
  notes=(
113
  "Xiaomi's reasoning model. Same layer count and hidden dim as Qwen3. "
hugging/td_fuse/heal.py CHANGED
@@ -247,6 +247,7 @@ def apply_qlora_standard(
247
  if os.path.exists(healed_check):
248
  print('[heal] Found existing healed model — SKIPPING healing!')
249
  return 'td_fuse_outputs/healed'
 
250
  from peft import LoraConfig, get_peft_model, TaskType
251
  from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
252
 
@@ -353,39 +354,55 @@ def apply_qlora_standard(
353
  print(f"\n[heal] Merging LoRA adapters...")
354
  merged_model = model.merge_and_unload()
355
 
356
- # Free disk space before saving — remove duplicate model copies
357
  import shutil, gc
358
- print("[heal] Freeing disk space before save...")
359
 
360
- # Search for large duplicate directories we can safely remove
361
- # The healed model in memory IS the final product — we don't need old copies
362
- cleanup_targets = [
363
- "td_fuse_outputs/final", # duplicate of after_deepseek
364
- "td_fuse_outputs/healed", # old healed dir if exists
365
- ]
366
- for target in cleanup_targets:
367
- target_path = Path(target)
368
- if target_path.exists() and target_path.is_dir():
369
- shutil.rmtree(str(target_path))
370
- print(f"[heal] Freed space: removed {target_path}")
371
-
372
- # Remove any trainer checkpoint-* dirs (we have the merged model in memory)
373
- for parent in [Path("."), Path("td_lang_outputs"), Path(cfg.output_dir)]:
374
- if parent.exists():
375
- for ckpt in parent.rglob("checkpoint-*"):
376
- if ckpt.is_dir():
377
- shutil.rmtree(str(ckpt))
378
- print(f"[heal] Freed space: removed {ckpt}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
379
 
380
  gc.collect()
381
 
382
- # Report free space
383
- stat = shutil.disk_usage("/")
384
- print(f"[heal] Disk space: {stat.free / 1e9:.1f} GB free / {stat.total / 1e9:.1f} GB total")
385
-
386
- merged_model.save_pretrained(str(healed_dir))
387
- tokenizer.save_pretrained(str(healed_dir))
388
-
389
  print(f"[heal] Healed model saved to {healed_dir}")
390
  return str(healed_dir)
391
 
 
247
  if os.path.exists(healed_check):
248
  print('[heal] Found existing healed model — SKIPPING healing!')
249
  return 'td_fuse_outputs/healed'
250
+ import torch
251
  from peft import LoraConfig, get_peft_model, TaskType
252
  from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
253
 
 
354
  print(f"\n[heal] Merging LoRA adapters...")
355
  merged_model = model.merge_and_unload()
356
 
 
357
  import shutil, gc
 
358
 
359
+ # SAVE FIRST never delete anything until save is confirmed
360
+ # save_pretrained can fail on 4-bit merged models (NotImplementedError)
361
+ # So we go straight to the safe manual method
362
+ print(f"[heal] Saving healed model to {healed_dir}...")
363
+ try:
364
+ from safetensors.torch import save_file
365
+ import torch as _torch
366
+ state_dict = merged_model.state_dict()
367
+ clean_state = {}
368
+ for k, v in state_dict.items():
369
+ if hasattr(v, 'dequantize'):
370
+ clean_state[k] = v.dequantize().to(_torch.bfloat16)
371
+ elif v.dtype in (_torch.float32, _torch.float16, _torch.bfloat16):
372
+ clean_state[k] = v.to(_torch.bfloat16)
373
+ else:
374
+ clean_state[k] = v.float().to(_torch.bfloat16)
375
+ save_file(clean_state, str(healed_dir / "model.safetensors"))
376
+ if hasattr(merged_model, 'config'):
377
+ merged_model.config.save_pretrained(str(healed_dir))
378
+ tokenizer.save_pretrained(str(healed_dir))
379
+ print(f"[heal] SAVED OK: {healed_dir / 'model.safetensors'}")
380
+ except Exception as e:
381
+ # Emergency fallback: try save_pretrained as last resort
382
+ print(f"[heal] Manual save failed ({e}), trying save_pretrained...")
383
+ merged_model.save_pretrained(str(healed_dir))
384
+ tokenizer.save_pretrained(str(healed_dir))
385
+ print(f"[heal] SAVED OK via save_pretrained: {healed_dir}")
386
+
387
+ # Verify the save actually worked before cleaning up ANYTHING
388
+ saved_model = healed_dir / "model.safetensors"
389
+ if not saved_model.exists() or saved_model.stat().st_size < 1_000_000:
390
+ print(f"[heal] WARNING: Save may have failed — NOT deleting any backups!")
391
+ else:
392
+ save_size = saved_model.stat().st_size / 1e9
393
+ print(f"[heal] Verified: {saved_model} ({save_size:.1f} GB)")
394
+ # NOW safe to clean up old stuff
395
+ cleanup_targets = [
396
+ "td_fuse_outputs/final",
397
+ ]
398
+ for target in cleanup_targets:
399
+ target_path = Path(target)
400
+ if target_path.exists() and target_path.is_dir():
401
+ shutil.rmtree(str(target_path))
402
+ print(f"[heal] Freed space: removed {target_path}")
403
 
404
  gc.collect()
405
 
 
 
 
 
 
 
 
406
  print(f"[heal] Healed model saved to {healed_dir}")
407
  return str(healed_dir)
408
 
hugging/td_fuse/merge.py CHANGED
@@ -484,6 +484,9 @@ class ResidualBank:
484
  # What the source lost (what didn't make it into the merge)
485
  if key in source_state:
486
  original_source = source_state[key].float()
 
 
 
487
  s_residual = original_source - merged_w
488
  s_loss = s_residual.abs().mean().item()
489
 
 
484
  # What the source lost (what didn't make it into the merge)
485
  if key in source_state:
486
  original_source = source_state[key].float()
487
+ # Skip if shapes don't match (e.g. vocab size mismatch on embeddings/lm_head)
488
+ if original_source.shape != merged_w.shape:
489
+ continue
490
  s_residual = original_source - merged_w
491
  s_loss = s_residual.abs().mean().item()
492
 
hugging/td_fuse/transport.py CHANGED
@@ -360,24 +360,69 @@ def _compute_plans_fallback(
360
  sys.stdout.flush()
361
 
362
  # --- FAST PATH: same architecture (same layer count) ---
363
- # DeepSeek-R1-0528-Qwen3-8B has the same architecture as Qwen3-VL-8B
364
- # Both have 36 transformer layers with identical hidden dims
365
- # No need for expensive OT -- just match layers 1:1
 
366
  if n_source == n_target:
367
- print("[transport] Same layer count -- using direct 1:1 layer matching (fast path)")
368
- print("[transport] This should take under 1 minute...")
369
  sys.stdout.flush()
370
  Q_matrices = {}
 
371
  P = np.eye(n_source) / n_source # Identity coupling
372
  tracker.set_total(n_source)
373
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
374
  for i, (sl, tl) in enumerate(zip(source_layers, target_layers)):
375
  S = source_act[sl].numpy()
376
  T = target_act[tl].numpy()
377
 
378
- # For same-dim layers, Q is identity (neurons already correspond)
379
  if S.shape[1] == T.shape[1]:
380
- Q_matrices[(sl, tl)] = np.eye(S.shape[1]) / S.shape[1]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
381
  else:
382
  # Different dims -- do lightweight Sinkhorn on this pair only
383
  print(f" Layer {i}: dim mismatch ({S.shape[1]} vs {T.shape[1]}), using Sinkhorn...")
@@ -393,15 +438,18 @@ def _compute_plans_fallback(
393
  print(f" Matched layer {i + 1}/{n_source}: {sl} -> {tl}")
394
  sys.stdout.flush()
395
 
396
- # Timeout: 10 min for fast path (should take seconds)
397
- tracker.check_timeout(timeout_seconds=600)
398
 
 
 
399
  print(f"[transport] Direct matching complete: {n_source} layer pairs")
400
  tracker.done()
401
  sys.stdout.flush()
402
  return {
403
  "P": P,
404
  "Q": Q_matrices,
 
405
  "source_layers": source_layers,
406
  "target_layers": target_layers,
407
  }
@@ -478,6 +526,7 @@ def _compute_plans_fallback(
478
  return {
479
  "P": P,
480
  "Q": Q_matrices,
 
481
  "source_layers": source_layers,
482
  "target_layers": target_layers,
483
  }
@@ -512,6 +561,87 @@ def _sinkhorn(
512
  return T
513
 
514
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
515
  def fuse_weights(
516
  source_state: dict,
517
  target_model: AutoModelForCausalLM,
@@ -562,9 +692,33 @@ def fuse_weights(
562
  target_state = target_model.state_dict()
563
  P = transport_plans["P"]
564
  Q = transport_plans["Q"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
565
 
566
  fused_count = 0
567
  skipped_count = 0
 
568
  total_params = len(target_state)
569
  tracker.set_total(total_params)
570
 
@@ -597,6 +751,23 @@ def fuse_weights(
597
  skipped_count += 1
598
  continue
599
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
600
  # Blend: W_final = alpha * source + (1-alpha) * target
601
  fused_w = alpha * source_w.to(target_w.device) + (1 - alpha) * target_w
602
  target_state[target_key] = fused_w
@@ -618,9 +789,15 @@ def fuse_weights(
618
  # Timeout: 20 min for weight fusion
619
  tracker.check_timeout(timeout_seconds=1200)
620
 
621
- # Load fused weights
622
- target_model.load_state_dict(target_state)
623
- print(f"[transport] Fused {fused_count} params, skipped {skipped_count}")
 
 
 
 
 
 
624
  tracker.done()
625
  sys.stdout.flush()
626
 
 
360
  sys.stdout.flush()
361
 
362
  # --- FAST PATH: same architecture (same layer count) ---
363
+ # Both models have the same number of transformer layers
364
+ # Match layers 1:1 but CHECK if neurons correspond
365
+ # DeepSeek: same training base neurons aligned identity Q (fast)
366
+ # MiMo: different training → neurons scrambled → need Sinkhorn permutation
367
  if n_source == n_target:
368
+ print("[transport] Same layer count -- using direct 1:1 layer matching")
 
369
  sys.stdout.flush()
370
  Q_matrices = {}
371
+ permutations = {} # layer_pair -> permutation array (neuron reordering)
372
  P = np.eye(n_source) / n_source # Identity coupling
373
  tracker.set_total(n_source)
374
 
375
+ # Check first layer to decide: are neurons aligned or scrambled?
376
+ first_sl = source_layers[0]
377
+ first_tl = target_layers[0]
378
+ S0 = source_act[first_sl].numpy()
379
+ T0 = target_act[first_tl].numpy()
380
+ if S0.shape[1] == T0.shape[1]:
381
+ S0_norm = (S0 - S0.mean(0)) / (S0.std(0) + 1e-8)
382
+ T0_norm = (T0 - T0.mean(0)) / (T0.std(0) + 1e-8)
383
+ diag_corr = np.mean(np.sum(S0_norm * T0_norm, axis=0) / S0.shape[0])
384
+ neurons_aligned = diag_corr > 0.3
385
+ else:
386
+ neurons_aligned = False
387
+
388
+ if neurons_aligned:
389
+ print(f"[transport] Neurons ARE aligned (diag_corr={diag_corr:.3f}) — identity Q (fast)")
390
+ print("[transport] This should take under 1 minute...")
391
+ else:
392
+ corr_val = diag_corr if S0.shape[1] == T0.shape[1] else 0.0
393
+ print(f"[transport] Neurons NOT aligned (diag_corr={corr_val:.3f}) — computing permutations via Sinkhorn")
394
+ print("[transport] This may take 2-5 minutes...")
395
+ sys.stdout.flush()
396
+
397
  for i, (sl, tl) in enumerate(zip(source_layers, target_layers)):
398
  S = source_act[sl].numpy()
399
  T = target_act[tl].numpy()
400
 
 
401
  if S.shape[1] == T.shape[1]:
402
+ if neurons_aligned:
403
+ # Neurons already correspond (e.g. DeepSeek) — identity Q
404
+ Q_matrices[(sl, tl)] = np.eye(S.shape[1]) / S.shape[1]
405
+ else:
406
+ # Neurons are SCRAMBLED (e.g. MiMo) — find the permutation
407
+ # 1. Compute correlation matrix between source and target neurons
408
+ S_norm = (S - S.mean(0)) / (S.std(0) + 1e-8)
409
+ T_norm = (T - T.mean(0)) / (T.std(0) + 1e-8)
410
+ corr = S_norm.T @ T_norm / S.shape[0] # [hidden_dim, hidden_dim]
411
+
412
+ # 2. Run Sinkhorn on cost matrix to get soft transport plan
413
+ cost = 1.0 - corr
414
+ Q_soft = _sinkhorn(cost, reg=0.05, max_iter=cfg.sinkhorn_max_iter)
415
+
416
+ # 3. Extract hard permutation: for each source neuron, which target neuron?
417
+ perm = np.argmax(Q_soft, axis=1) # source_neuron -> target_neuron
418
+
419
+ # 4. Check for duplicate assignments (Sinkhorn should avoid this, but be safe)
420
+ if len(set(perm)) < len(perm) * 0.9:
421
+ # Too many collisions — fall back to Hungarian-style greedy
422
+ perm = _greedy_permutation(corr)
423
+
424
+ permutations[(sl, tl)] = perm
425
+ Q_matrices[(sl, tl)] = Q_soft
426
  else:
427
  # Different dims -- do lightweight Sinkhorn on this pair only
428
  print(f" Layer {i}: dim mismatch ({S.shape[1]} vs {T.shape[1]}), using Sinkhorn...")
 
438
  print(f" Matched layer {i + 1}/{n_source}: {sl} -> {tl}")
439
  sys.stdout.flush()
440
 
441
+ # Timeout: 15 min (permutation takes longer than identity)
442
+ tracker.check_timeout(timeout_seconds=900)
443
 
444
+ if permutations:
445
+ print(f"[transport] Computed {len(permutations)} neuron permutations")
446
  print(f"[transport] Direct matching complete: {n_source} layer pairs")
447
  tracker.done()
448
  sys.stdout.flush()
449
  return {
450
  "P": P,
451
  "Q": Q_matrices,
452
+ "permutations": permutations,
453
  "source_layers": source_layers,
454
  "target_layers": target_layers,
455
  }
 
526
  return {
527
  "P": P,
528
  "Q": Q_matrices,
529
+ "permutations": {},
530
  "source_layers": source_layers,
531
  "target_layers": target_layers,
532
  }
 
561
  return T
562
 
563
 
564
+ def _greedy_permutation(corr_matrix: np.ndarray) -> np.ndarray:
565
+ """
566
+ Greedy permutation assignment when Sinkhorn gives duplicate mappings.
567
+
568
+ For each source neuron (in order of strongest match), assign it to the
569
+ best available target neuron that hasn't been taken yet.
570
+ """
571
+ n = corr_matrix.shape[0]
572
+ perm = np.full(n, -1, dtype=np.int64)
573
+ taken = set()
574
+
575
+ # Process source neurons by strength of their best match (strongest first)
576
+ best_scores = np.max(corr_matrix, axis=1)
577
+ order = np.argsort(-best_scores)
578
+
579
+ for src in order:
580
+ # Find best available target
581
+ sorted_targets = np.argsort(-corr_matrix[src])
582
+ for tgt in sorted_targets:
583
+ if tgt not in taken:
584
+ perm[src] = tgt
585
+ taken.add(tgt)
586
+ break
587
+
588
+ # Safety: any unassigned source neurons get remaining targets
589
+ remaining = set(range(n)) - taken
590
+ for src in range(n):
591
+ if perm[src] == -1:
592
+ perm[src] = remaining.pop()
593
+
594
+ return perm
595
+
596
+
597
+ def _apply_permutation(source_w: torch.Tensor, perm: np.ndarray, key: str) -> torch.Tensor:
598
+ """
599
+ Apply neuron permutation to a source weight tensor before blending.
600
+
601
+ The permutation rearranges MiMo's neurons to match Qwen3's ordering.
602
+ Think of it like reorganising filing cabinets: same files, different order.
603
+
604
+ Which dimension to permute depends on the weight type:
605
+ - Input projections (q_proj, k_proj, v_proj, gate_proj, up_proj):
606
+ shape [out_features, in_features] → permute columns (dim 1)
607
+ because input neurons need reordering
608
+ - Output projections (o_proj, down_proj):
609
+ shape [out_features, in_features] → permute rows (dim 0)
610
+ because output neurons need reordering
611
+ - 1D weights (layer_norm, bias):
612
+ permute directly
613
+ """
614
+ perm_tensor = torch.from_numpy(perm).long()
615
+
616
+ if source_w.dim() == 1:
617
+ # 1D: layer norms, biases
618
+ if len(perm_tensor) == source_w.shape[0]:
619
+ return source_w[perm_tensor]
620
+ return source_w
621
+
622
+ if source_w.dim() == 2:
623
+ # 2D: linear layers
624
+ out_features, in_features = source_w.shape
625
+
626
+ # Output projections: neurons on dim 0 (rows)
627
+ if any(proj in key for proj in ["o_proj", "down_proj"]):
628
+ if len(perm_tensor) == out_features:
629
+ return source_w[perm_tensor, :]
630
+ # Input projections: neurons on dim 1 (columns)
631
+ elif any(proj in key for proj in ["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj"]):
632
+ if len(perm_tensor) == in_features:
633
+ return source_w[:, perm_tensor]
634
+ # Other 2D weights: try columns first (more common)
635
+ else:
636
+ if len(perm_tensor) == in_features:
637
+ return source_w[:, perm_tensor]
638
+ elif len(perm_tensor) == out_features:
639
+ return source_w[perm_tensor, :]
640
+
641
+ # Can't permute — return unchanged
642
+ return source_w
643
+
644
+
645
  def fuse_weights(
646
  source_state: dict,
647
  target_model: AutoModelForCausalLM,
 
692
  target_state = target_model.state_dict()
693
  P = transport_plans["P"]
694
  Q = transport_plans["Q"]
695
+ permutations = transport_plans.get("permutations", {})
696
+
697
+ # Build layer-index -> permutation lookup
698
+ # permutations keys are (source_layer_name, target_layer_name) tuples
699
+ # We need to map weight keys like "model.layers.5.self_attn.q_proj.weight"
700
+ # to the permutation for layer 5
701
+ layer_perms = {}
702
+ for (sl, tl), perm in permutations.items():
703
+ # Extract layer index from target layer name (e.g. "model.layers.5.mlp" -> 5)
704
+ parts = tl.split(".")
705
+ for j, part in enumerate(parts):
706
+ if part == "layers" and j + 1 < len(parts):
707
+ try:
708
+ layer_idx = int(parts[j + 1])
709
+ layer_perms[layer_idx] = perm
710
+ except ValueError:
711
+ pass
712
+ break
713
+
714
+ if permutations:
715
+ print(f"[transport] Will apply neuron permutations to {len(layer_perms)} layers before blending")
716
+ else:
717
+ print("[transport] No neuron permutations needed (neurons already aligned)")
718
 
719
  fused_count = 0
720
  skipped_count = 0
721
+ permuted_count = 0
722
  total_params = len(target_state)
723
  tracker.set_total(total_params)
724
 
 
751
  skipped_count += 1
752
  continue
753
 
754
+ # --- NEURON PERMUTATION: rearrange source neurons to match target ---
755
+ # This is what makes MiMo merge work — without this, it's like
756
+ # dumping one filing cabinet into another without matching folders
757
+ if layer_perms:
758
+ # Extract layer index from this weight's key
759
+ key_parts = target_key.split(".")
760
+ for j, part in enumerate(key_parts):
761
+ if part == "layers" and j + 1 < len(key_parts):
762
+ try:
763
+ lidx = int(key_parts[j + 1])
764
+ if lidx in layer_perms:
765
+ source_w = _apply_permutation(source_w, layer_perms[lidx], target_key)
766
+ permuted_count += 1
767
+ except ValueError:
768
+ pass
769
+ break
770
+
771
  # Blend: W_final = alpha * source + (1-alpha) * target
772
  fused_w = alpha * source_w.to(target_w.device) + (1 - alpha) * target_w
773
  target_state[target_key] = fused_w
 
789
  # Timeout: 20 min for weight fusion
790
  tracker.check_timeout(timeout_seconds=1200)
791
 
792
+ # Load fused weights (strict=False: vision encoder may have bitsandbytes quant keys
793
+ # that don't match the original key names — we never modify vision weights anyway)
794
+ missing, unexpected = target_model.load_state_dict(target_state, strict=False)
795
+ if missing:
796
+ print(f"[transport] NOTE: {len(missing)} missing keys (likely quantized vision params — safe to ignore)")
797
+ if unexpected:
798
+ print(f"[transport] NOTE: {len(unexpected)} unexpected keys (safe to ignore)")
799
+ perm_msg = f", permuted {permuted_count}" if permuted_count else ""
800
+ print(f"[transport] Fused {fused_count} params, skipped {skipped_count}{perm_msg}")
801
  tracker.done()
802
  sys.stdout.flush()
803
 
hugging/td_start.td CHANGED
@@ -51,7 +51,7 @@ merge "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B" into base using transport strength
51
  # Medium risk: same layer count (36) and hidden_dim (4096)
52
  # MTP heads get dropped automatically (no Qwen3 equivalent)
53
  # Embeddings skipped (28% vocab overlap too low)
54
- merge "XiaomiMiMo/MiMo-7B-RL" into base using transport strength 0.4
55
 
56
  # --- Step 3: Heal any merge damage ---
57
  # QLoRA fine-tune to smooth out rough edges from the merge
 
51
  # Medium risk: same layer count (36) and hidden_dim (4096)
52
  # MTP heads get dropped automatically (no Qwen3 equivalent)
53
  # Embeddings skipped (28% vocab overlap too low)
54
+ merge "XiaomiMiMo/MiMo-7B-RL" into base using transport strength 0.15
55
 
56
  # --- Step 3: Heal any merge damage ---
57
  # QLoRA fine-tune to smooth out rough edges from the merge