td-builder committed on
Commit
1e0fbdd
·
verified ·
1 Parent(s): 10244f6

Upload 137 files

Browse files
hugging/td_fuse/canary.py CHANGED
@@ -186,7 +186,7 @@ def test_all_canaries(
186
  results = {}
187
 
188
  # Test the target model's canary
189
- results["Qwen3-8B"] = test_canary(model, tokenizer, "Qwen3-8B")
190
 
191
  # Test each merged source model's canary
192
  for source_name in merged_sources:
 
186
  results = {}
187
 
188
  # Test the target model's canary
189
+ results["Qwen3-VL-8B"] = test_canary(model, tokenizer, "Qwen3-VL-8B")
190
 
191
  # Test each merged source model's canary
192
  for source_name in merged_sources:
hugging/td_fuse/heal.py CHANGED
@@ -18,6 +18,8 @@ Findings: #12, #16, #20
18
  """
19
 
20
  import os
 
 
21
  import torch
22
  from pathlib import Path
23
  from typing import Optional
@@ -350,14 +352,20 @@ def heal_model(
350
  if cfg is None:
351
  cfg = MergeConfig()
352
 
 
353
  print("\n" + "=" * 60)
354
  print("HEALING FINE-TUNE")
355
  print(f"Model: {model_path}")
356
  print(f"LoRA r={cfg.heal_lora_r}, alpha={cfg.heal_lora_alpha}")
357
  print(f"Epochs: {cfg.heal_epochs}, LR: {cfg.heal_learning_rate}")
 
358
  print("=" * 60)
 
359
 
360
  if check_unsloth_available():
361
- return apply_qlora_unsloth(model_path, cfg, healing_data)
362
  else:
363
- return apply_qlora_standard(model_path, cfg, healing_data)
 
 
 
 
18
  """
19
 
20
  import os
21
+ import sys
22
+ import time
23
  import torch
24
  from pathlib import Path
25
  from typing import Optional
 
352
  if cfg is None:
353
  cfg = MergeConfig()
354
 
355
+ heal_start = time.time()
356
  print("\n" + "=" * 60)
357
  print("HEALING FINE-TUNE")
358
  print(f"Model: {model_path}")
359
  print(f"LoRA r={cfg.heal_lora_r}, alpha={cfg.heal_lora_alpha}")
360
  print(f"Epochs: {cfg.heal_epochs}, LR: {cfg.heal_learning_rate}")
361
+ print(f"Started at: {time.strftime('%H:%M:%S')}")
362
  print("=" * 60)
363
+ sys.stdout.flush()
364
 
365
  if check_unsloth_available():
366
+ result = apply_qlora_unsloth(model_path, cfg, healing_data)
367
  else:
368
+ result = apply_qlora_standard(model_path, cfg, healing_data)
369
+ print(f"[heal] Total healing time: {(time.time()-heal_start)/60:.1f} min")
370
+ sys.stdout.flush()
371
+ return result
hugging/td_fuse/merge.py CHANGED
@@ -22,7 +22,9 @@ Findings: #13, #22, #25
22
 
23
  import os
24
  import gc
 
25
  import copy
 
26
  import torch
27
  import numpy as np
28
  from pathlib import Path
@@ -593,10 +595,13 @@ def run_single_merge(
593
  merged_sources = []
594
 
595
  stage_name = source_config.name
 
596
  print(f"\n{'=' * 70}")
597
- print(f"MERGE STAGE: {stage_name} β†’ target")
598
  print(f"Risk level: {source_config.merge_risk.upper()}")
 
599
  print(f"{'=' * 70}")
 
600
 
601
  result = {
602
  "stage": stage_name,
@@ -606,23 +611,34 @@ def run_single_merge(
606
  }
607
 
608
  # --- Step 1: Load source model ---
 
 
609
  source_model, source_tokenizer = load_model(source_config, cfg)
 
610
 
611
  # --- Step 2: Inject canary into source ---
 
 
612
  if stage_name in CANARY_FACTS:
613
- print(f"\n[merge] Injecting canary fact into {stage_name}...")
614
  source_model = inject_canary(source_model, source_tokenizer, stage_name)
 
615
 
616
  # --- Step 3: Load calibration data (if not provided) ---
 
 
617
  if calibration_data is None:
618
  calibration_data = load_calibration_data(cfg, target_tokenizer)
 
619
 
620
  # --- Step 4: Extract activations ---
621
- print(f"\n[merge] Extracting source activations...")
 
 
622
  source_activations = extract_activations(source_model, calibration_data)
623
 
624
- print(f"\n[merge] Extracting target activations...")
625
  pre_merge_target_activations = extract_activations(target_model, calibration_data)
 
626
 
627
  # --- Step 4.5: Mergeability pre-check (2601.22285) ---
628
  if cfg.use_mergeability_check:
@@ -644,9 +660,12 @@ def run_single_merge(
644
  return result
645
 
646
  # --- Step 5: Compute transport plans ---
 
 
647
  transport_plans = compute_transport_plans(
648
  source_activations, pre_merge_target_activations, cfg
649
  )
 
650
 
651
  # --- Step 5.5: RAM RL-weight disentanglement (2601.13572) ---
652
  use_ram = (
@@ -657,6 +676,8 @@ def run_single_merge(
657
  )
658
 
659
  # --- Step 6: Pre-merge protection ---
 
 
660
  adjusted_alpha = protection.before_merge(target_model, source_config)
661
 
662
  # Override source alpha with time-adjusted value
@@ -665,8 +686,11 @@ def run_single_merge(
665
 
666
  # Save pre-merge state for protection
667
  pre_merge_state = {k: v.clone().cpu() for k, v in target_model.state_dict().items()}
 
668
 
669
  # --- Step 7: Fuse weights ---
 
 
670
  if use_ram:
671
  # RAM path: disentangle RL weights, merge with preservation
672
  print(f"\n[merge] Using RAM RL-preservation for {stage_name}...")
@@ -750,7 +774,11 @@ def run_single_merge(
750
  source_config_adjusted, cfg,
751
  )
752
 
 
 
753
  # --- Step 8: Apply post-merge protection (ARM + OTMF + MagMax) ---
 
 
754
  # Skip vision encoder params — they weren't merged, so don't "protect" them
755
  if protection.merge_count > 0:
756
  print(f"\n[merge] Applying sequential merge protection (ARM + OTMF + MagMax)...")
@@ -770,7 +798,11 @@ def run_single_merge(
770
  target_model.load_state_dict(target_state)
771
  print(f"[merge] Protected {protected_count} language params (skipped {vision_skipped} vision params)")
772
 
 
 
773
  # --- Step 8.5: Extract post-merge activations for ARM/OTMF ---
 
 
774
  post_merge_activations = extract_activations(target_model, calibration_data[:100])
775
 
776
  # Record this merge's delta + compute ARM/OTMF for next merge
@@ -780,7 +812,11 @@ def run_single_merge(
780
  post_merge_activations=post_merge_activations,
781
  )
782
 
 
 
783
  # --- Step 8.8: Save residuals (what was lost from both sides) ---
 
 
784
  if residual_bank is not None:
785
  print(f"\n[merge] Saving residuals for {stage_name}...")
786
  residual_bank.save_residuals(
@@ -791,6 +827,8 @@ def run_single_merge(
791
  source_config=source_config,
792
  )
793
 
 
 
794
  # --- Step 9: Free source model memory ---
795
  del source_model, source_activations, pre_merge_target_activations
796
  del transport_plans, post_merge_activations
@@ -799,6 +837,8 @@ def run_single_merge(
799
  torch.cuda.empty_cache()
800
 
801
  # --- Step 10: Validate ---
 
 
802
  merged_sources.append(stage_name)
803
  validation = validate_merged_model(
804
  target_model, target_tokenizer,
@@ -806,8 +846,12 @@ def run_single_merge(
806
  baseline_perplexity=baseline_perplexity,
807
  )
808
 
 
 
809
  result["validation"] = validation
810
  result["merged_sources"] = merged_sources.copy()
 
 
811
 
812
  # --- Kill criteria check ---
813
  if not validation["overall"]:
 
22
 
23
  import os
24
  import gc
25
+ import sys
26
  import copy
27
+ import time
28
  import torch
29
  import numpy as np
30
  from pathlib import Path
 
595
  merged_sources = []
596
 
597
  stage_name = source_config.name
598
+ stage_start = time.time()
599
  print(f"\n{'=' * 70}")
600
+ print(f"MERGE STAGE: {stage_name} -> target")
601
  print(f"Risk level: {source_config.merge_risk.upper()}")
602
+ print(f"Started at: {time.strftime('%H:%M:%S')}")
603
  print(f"{'=' * 70}")
604
+ sys.stdout.flush()
605
 
606
  result = {
607
  "stage": stage_name,
 
611
  }
612
 
613
  # --- Step 1: Load source model ---
614
+ print(f"\n[merge] Step 1/10: Loading source model..."); sys.stdout.flush()
615
+ step_t = time.time()
616
  source_model, source_tokenizer = load_model(source_config, cfg)
617
+ print(f"[merge] Step 1/10 done in {time.time()-step_t:.0f}s"); sys.stdout.flush()
618
 
619
  # --- Step 2: Inject canary into source ---
620
+ print(f"\n[merge] Step 2/10: Injecting canary..."); sys.stdout.flush()
621
+ step_t = time.time()
622
  if stage_name in CANARY_FACTS:
 
623
  source_model = inject_canary(source_model, source_tokenizer, stage_name)
624
+ print(f"[merge] Step 2/10 done in {time.time()-step_t:.0f}s"); sys.stdout.flush()
625
 
626
  # --- Step 3: Load calibration data (if not provided) ---
627
+ print(f"\n[merge] Step 3/10: Loading calibration data..."); sys.stdout.flush()
628
+ step_t = time.time()
629
  if calibration_data is None:
630
  calibration_data = load_calibration_data(cfg, target_tokenizer)
631
+ print(f"[merge] Step 3/10 done in {time.time()-step_t:.0f}s"); sys.stdout.flush()
632
 
633
  # --- Step 4: Extract activations ---
634
+ print(f"\n[merge] Step 4/10: Extracting activations (both models)..."); sys.stdout.flush()
635
+ step_t = time.time()
636
+ print(f"[merge] Extracting source activations...")
637
  source_activations = extract_activations(source_model, calibration_data)
638
 
639
+ print(f"[merge] Extracting target activations...")
640
  pre_merge_target_activations = extract_activations(target_model, calibration_data)
641
+ print(f"[merge] Step 4/10 done in {time.time()-step_t:.0f}s"); sys.stdout.flush()
642
 
643
  # --- Step 4.5: Mergeability pre-check (2601.22285) ---
644
  if cfg.use_mergeability_check:
 
660
  return result
661
 
662
  # --- Step 5: Compute transport plans ---
663
+ print(f"\n[merge] Step 5/10: Computing transport plans..."); sys.stdout.flush()
664
+ step_t = time.time()
665
  transport_plans = compute_transport_plans(
666
  source_activations, pre_merge_target_activations, cfg
667
  )
668
+ print(f"[merge] Step 5/10 done in {time.time()-step_t:.0f}s"); sys.stdout.flush()
669
 
670
  # --- Step 5.5: RAM RL-weight disentanglement (2601.13572) ---
671
  use_ram = (
 
676
  )
677
 
678
  # --- Step 6: Pre-merge protection ---
679
+ print(f"\n[merge] Step 6/10: Pre-merge protection..."); sys.stdout.flush()
680
+ step_t = time.time()
681
  adjusted_alpha = protection.before_merge(target_model, source_config)
682
 
683
  # Override source alpha with time-adjusted value
 
686
 
687
  # Save pre-merge state for protection
688
  pre_merge_state = {k: v.clone().cpu() for k, v in target_model.state_dict().items()}
689
+ print(f"[merge] Step 6/10 done in {time.time()-step_t:.0f}s"); sys.stdout.flush()
690
 
691
  # --- Step 7: Fuse weights ---
692
+ print(f"\n[merge] Step 7/10: Fusing weights..."); sys.stdout.flush()
693
+ step_t = time.time()
694
  if use_ram:
695
  # RAM path: disentangle RL weights, merge with preservation
696
  print(f"\n[merge] Using RAM RL-preservation for {stage_name}...")
 
774
  source_config_adjusted, cfg,
775
  )
776
 
777
+ print(f"[merge] Step 7/10 done in {time.time()-step_t:.0f}s"); sys.stdout.flush()
778
+
779
  # --- Step 8: Apply post-merge protection (ARM + OTMF + MagMax) ---
780
+ print(f"\n[merge] Step 8/10: Post-merge protection..."); sys.stdout.flush()
781
+ step_t = time.time()
782
  # Skip vision encoder params — they weren't merged, so don't "protect" them
783
  if protection.merge_count > 0:
784
  print(f"\n[merge] Applying sequential merge protection (ARM + OTMF + MagMax)...")
 
798
  target_model.load_state_dict(target_state)
799
  print(f"[merge] Protected {protected_count} language params (skipped {vision_skipped} vision params)")
800
 
801
+ print(f"[merge] Step 8/10 done in {time.time()-step_t:.0f}s"); sys.stdout.flush()
802
+
803
  # --- Step 8.5: Extract post-merge activations for ARM/OTMF ---
804
+ print(f"\n[merge] Step 8.5/10: Post-merge activations + ARM/OTMF prep..."); sys.stdout.flush()
805
+ step_t = time.time()
806
  post_merge_activations = extract_activations(target_model, calibration_data[:100])
807
 
808
  # Record this merge's delta + compute ARM/OTMF for next merge
 
812
  post_merge_activations=post_merge_activations,
813
  )
814
 
815
+ print(f"[merge] Step 8.5/10 done in {time.time()-step_t:.0f}s"); sys.stdout.flush()
816
+
817
  # --- Step 8.8: Save residuals (what was lost from both sides) ---
818
+ print(f"\n[merge] Step 9/10: Saving residuals..."); sys.stdout.flush()
819
+ step_t = time.time()
820
  if residual_bank is not None:
821
  print(f"\n[merge] Saving residuals for {stage_name}...")
822
  residual_bank.save_residuals(
 
827
  source_config=source_config,
828
  )
829
 
830
+ print(f"[merge] Step 9/10 done in {time.time()-step_t:.0f}s"); sys.stdout.flush()
831
+
832
  # --- Step 9: Free source model memory ---
833
  del source_model, source_activations, pre_merge_target_activations
834
  del transport_plans, post_merge_activations
 
837
  torch.cuda.empty_cache()
838
 
839
  # --- Step 10: Validate ---
840
+ print(f"\n[merge] Step 10/10: Validating merge..."); sys.stdout.flush()
841
+ step_t = time.time()
842
  merged_sources.append(stage_name)
843
  validation = validate_merged_model(
844
  target_model, target_tokenizer,
 
846
  baseline_perplexity=baseline_perplexity,
847
  )
848
 
849
+ print(f"[merge] Step 10/10 done in {time.time()-step_t:.0f}s"); sys.stdout.flush()
850
+
851
  result["validation"] = validation
852
  result["merged_sources"] = merged_sources.copy()
853
+ total_time = time.time() - stage_start
854
+ print(f"\n[merge] Total time for {stage_name}: {total_time/60:.1f} min"); sys.stdout.flush()
855
 
856
  # --- Kill criteria check ---
857
  if not validation["overall"]:
hugging/td_fuse/transport.py CHANGED
@@ -15,11 +15,14 @@ We add:
15
  - MiMo MTP head handling
16
  - Falcon SSM component handling
17
  - Sequential merge protection (MagMax + orthogonal projection)
 
 
18
 
19
  Findings: #01, #07, #24
20
  """
21
 
22
  import sys
 
23
  import torch
24
  import numpy as np
25
  from pathlib import Path
@@ -30,6 +33,58 @@ from datasets import load_dataset
30
  from .config import MergeConfig, ModelConfig, TARGET
31
 
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  def setup_tm_repo(cfg: MergeConfig):
34
  """Add official T&M repo to Python path so we can import their code."""
35
  repo_path = Path(cfg.tm_repo_path)
@@ -58,6 +113,7 @@ def load_calibration_data(cfg: MergeConfig, tokenizer: AutoTokenizer) -> list:
58
 
59
  Findings: #08
60
  """
 
61
  print(f"[transport] Loading calibration data ({cfg.calibration_samples} samples)...")
62
 
63
  samples = []
@@ -84,9 +140,12 @@ def load_calibration_data(cfg: MergeConfig, tokenizer: AutoTokenizer) -> list:
84
  )
85
  samples.append(tokens)
86
  count += 1
 
 
 
87
  print(f" Pile general: {count} samples")
88
  except Exception as e:
89
- print(f" ⚠ Pile failed: {e}")
90
  print(f" Falling back to neuralmagic only")
91
 
92
  # --- neuralmagic: Q&A calibration (up to remaining) ---
@@ -112,11 +171,16 @@ def load_calibration_data(cfg: MergeConfig, tokenizer: AutoTokenizer) -> list:
112
  )
113
  samples.append(tokens)
114
  count += 1
 
 
 
115
  print(f" neuralmagic: {count} samples")
116
  except Exception as e:
117
- print(f" ⚠ neuralmagic failed: {e}")
118
 
 
119
  print(f"[transport] Total calibration samples: {len(samples)}")
 
120
  return samples
121
 
122
 
@@ -133,9 +197,12 @@ def extract_activations(
133
  optimal transport algorithm aligns between source and target.
134
 
135
  Returns:
136
- Dict mapping layer_name → activation tensor [num_samples, hidden_dim]
137
  """
 
 
138
  print(f"[transport] Extracting activations from {len(calibration_data)} samples...")
 
139
 
140
  activations = {}
141
  hooks = []
@@ -153,7 +220,7 @@ def extract_activations(
153
  act = output
154
  if layer_name not in activations:
155
  activations[layer_name] = []
156
- # Mean pool over sequence length → [hidden_dim]
157
  activations[layer_name].append(
158
  act.detach().float().mean(dim=1).cpu()
159
  )
@@ -170,20 +237,31 @@ def extract_activations(
170
  try:
171
  model(**inputs)
172
  except Exception as e:
173
- print(f" ⚠ Sample {i} failed: {e}")
174
  continue
175
 
 
 
176
  if (i + 1) % 100 == 0:
177
  print(f" Processed {i + 1}/{len(calibration_data)} samples")
 
 
 
 
178
 
179
  # Remove hooks
180
  for h in hooks:
181
  h.remove()
182
 
183
  # Stack activations: [num_samples, hidden_dim]
 
184
  for key in activations:
185
  activations[key] = torch.cat(activations[key], dim=0)
186
- print(f" {key}: {activations[key].shape}")
 
 
 
 
187
 
188
  return activations
189
 
@@ -199,13 +277,14 @@ def compute_transport_plans(
199
  This is where the magic happens. We use the official T&M code's:
200
  - corr_distance_matrix: correlation distance between activation vectors
201
  - sinkhorn_uniform_streaming: memory-efficient Sinkhorn solver
202
- - compute_P: layer-level coupling (which source layers → which target layers)
203
  - compute_Q_and_layer_costs: neuron-level coupling within each layer pair
204
 
205
  Returns:
206
  Dict with 'P' (layer coupling) and 'Q' (per-layer neuron coupling) matrices
207
  """
208
  print("[transport] Computing transport plans...")
 
209
 
210
  try:
211
  # Try importing official T&M code
@@ -264,41 +343,138 @@ def _compute_plans_fallback(
264
  """
265
  Fallback transport plan computation when official code isn't available.
266
 
267
- Uses correlation distance + basic Sinkhorn. Less optimised than official
268
- but functionally correct for testing.
 
 
269
  """
 
270
 
271
  source_layers = sorted(source_act.keys())
272
  target_layers = sorted(target_act.keys())
273
 
274
- # --- Step 1: Correlation distance matrices per layer pair ---
275
- Q_matrices = {}
276
- layer_costs = np.zeros((len(source_layers), len(target_layers)))
277
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
278
  for i, sl in enumerate(source_layers):
279
  for j, tl in enumerate(target_layers):
280
- if sl not in source_act or tl not in target_act:
281
- continue
282
-
283
- S = source_act[sl].numpy() # [samples, hidden_dim_source]
284
- T = target_act[tl].numpy() # [samples, hidden_dim_target]
285
-
286
- # Correlation distance: 1 - pearson_correlation
287
- # Between each pair of neurons across samples
288
- # S: [samples, n_source], T: [samples, n_target]
289
- S_norm = (S - S.mean(0)) / (S.std(0) + 1e-8)
290
- T_norm = (T - T.mean(0)) / (T.std(0) + 1e-8)
291
- corr = S_norm.T @ T_norm / S.shape[0] # [n_source, n_target]
292
- cost = 1.0 - corr # Correlation distance
293
-
294
- # Basic Sinkhorn on this cost matrix
295
- Q = _sinkhorn(cost, reg=cfg.sinkhorn_reg, max_iter=cfg.sinkhorn_max_iter)
296
- Q_matrices[(sl, tl)] = Q
297
- layer_costs[i, j] = cost.mean()
298
-
299
- # --- Step 2: Layer coupling (P matrix) ---
300
- P = _sinkhorn(layer_costs, reg=cfg.sinkhorn_reg, max_iter=cfg.sinkhorn_max_iter)
301
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
302
  return {
303
  "P": P,
304
  "Q": Q_matrices,
@@ -327,7 +503,7 @@ def _sinkhorn(
327
  u = np.ones(n) / n
328
  v = np.ones(m) / m
329
 
330
- for _ in range(max_iter):
331
  u = 1.0 / (K @ v + 1e-10)
332
  v = 1.0 / (K.T @ u + 1e-10)
333
 
@@ -354,13 +530,14 @@ def fuse_weights(
354
  Special handling per model:
355
  - DeepSeek: Direct merge (same architecture)
356
  - MiMo: Skip MTP heads, skip embeddings
357
- - Llama: Layer mapping (32→36), skip embeddings, drop QKV bias
358
  - Falcon: Skip Mamba components, skip embeddings
359
 
360
  Returns:
361
  Target model with fused weights
362
  """
363
- print(f"\n[transport] Fusing {source_config.name} β†’ target")
 
364
  alpha = source_config.merge_alpha
365
 
366
  try:
@@ -380,8 +557,12 @@ def fuse_weights(
380
 
381
  fused_count = 0
382
  skipped_count = 0
 
 
383
 
384
  for target_key in target_state:
 
 
385
  # Skip parameters we shouldn't merge
386
  if _should_skip(target_key, source_config):
387
  skipped_count += 1
@@ -409,18 +590,27 @@ def fuse_weights(
409
  target_state[target_key] = fused_w
410
  fused_count += 1
411
 
412
- # Apply thinking mode protection
413
- if cfg.freeze_think_tokens and "embed_tokens" in target_key:
414
- for token_id in cfg.think_token_ids:
415
- if token_id < target_state["model.embed_tokens.weight"].shape[0]:
416
- # Restore original embedding for think tokens
417
- orig_embed = target_model.state_dict()["model.embed_tokens.weight"]
418
- target_state["model.embed_tokens.weight"][token_id] = orig_embed[token_id]
419
- print(f"[transport] Protected think token {token_id}")
 
 
 
 
 
 
 
420
 
421
  # Load fused weights
422
  target_model.load_state_dict(target_state)
423
  print(f"[transport] Fused {fused_count} params, skipped {skipped_count}")
 
 
424
 
425
  return target_model
426
 
@@ -457,7 +647,7 @@ def _map_key(target_key: str, source_config: ModelConfig) -> Optional[str]:
457
  if source_config.architecture == "transformer" and source_config.layers == 36:
458
  return target_key
459
 
460
- # For Llama (32 layers → 36 layers), map layer indices
461
  if "layer_mapping_32_to_36" in source_config.special_handling:
462
  if "model.layers." in target_key:
463
  # Extract layer number
@@ -523,5 +713,5 @@ def _align_dimensions(
523
  result[:min_len] = source_w[:min_len]
524
  return result
525
 
526
- # Can't align — skip this parameter
527
  return None
 
15
  - MiMo MTP head handling
16
  - Falcon SSM component handling
17
  - Sequential merge protection (MagMax + orthogonal projection)
18
+ - Progress reporting every 5 minutes
19
+ - Timeouts to prevent infinite hangs
20
 
21
  Findings: #01, #07, #24
22
  """
23
 
24
  import sys
25
+ import time
26
  import torch
27
  import numpy as np
28
  from pathlib import Path
 
33
  from .config import MergeConfig, ModelConfig, TARGET
34
 
35
 
36
+ # ============================================================================
37
+ # PROGRESS TRACKER — prints status every 5 minutes so you know it's alive
38
+ # ============================================================================
39
+
40
+ class ProgressTracker:
41
+ """Prints a heartbeat every interval_seconds so you know it's not stuck."""
42
+
43
+ def __init__(self, task_name: str, interval_seconds: int = 300):
44
+ self.task_name = task_name
45
+ self.interval = interval_seconds
46
+ self.start_time = time.time()
47
+ self.last_report = self.start_time
48
+ self.step = 0
49
+ self.total_steps = 0
50
+ print(f"\n[{task_name}] Started at {time.strftime('%H:%M:%S')}")
51
+
52
+ def set_total(self, total: int):
53
+ self.total_steps = total
54
+
55
+ def tick(self, step_name: str = ""):
56
+ """Call this inside loops. Prints progress if 5 min have passed."""
57
+ self.step += 1
58
+ now = time.time()
59
+ elapsed = now - self.start_time
60
+ since_last = now - self.last_report
61
+
62
+ if since_last >= self.interval:
63
+ pct = f"{self.step}/{self.total_steps} ({100*self.step/self.total_steps:.0f}%)" if self.total_steps else f"step {self.step}"
64
+ eta = ""
65
+ if self.total_steps and self.step > 0:
66
+ rate = elapsed / self.step
67
+ remaining = (self.total_steps - self.step) * rate
68
+ eta = f", ETA {remaining/60:.1f} min"
69
+ print(f"[{self.task_name}] HEARTBEAT β€” {pct}, elapsed {elapsed/60:.1f} min{eta} | {step_name}")
70
+ sys.stdout.flush()
71
+ self.last_report = now
72
+
73
+ def done(self):
74
+ elapsed = time.time() - self.start_time
75
+ print(f"[{self.task_name}] Completed in {elapsed/60:.1f} min ({elapsed:.0f}s)")
76
+ sys.stdout.flush()
77
+
78
+ def check_timeout(self, timeout_seconds: int = 3600):
79
+ """Raise if we've been running longer than timeout_seconds."""
80
+ elapsed = time.time() - self.start_time
81
+ if elapsed > timeout_seconds:
82
+ raise TimeoutError(
83
+ f"[{self.task_name}] TIMEOUT after {elapsed/60:.1f} min "
84
+ f"(limit: {timeout_seconds/60:.0f} min). Something is wrong."
85
+ )
86
+
87
+
88
  def setup_tm_repo(cfg: MergeConfig):
89
  """Add official T&M repo to Python path so we can import their code."""
90
  repo_path = Path(cfg.tm_repo_path)
 
113
 
114
  Findings: #08
115
  """
116
+ tracker = ProgressTracker("calibration-data", interval_seconds=120)
117
  print(f"[transport] Loading calibration data ({cfg.calibration_samples} samples)...")
118
 
119
  samples = []
 
140
  )
141
  samples.append(tokens)
142
  count += 1
143
+ if count % 100 == 0:
144
+ print(f" Pile: {count}/600 samples loaded...")
145
+ sys.stdout.flush()
146
  print(f" Pile general: {count} samples")
147
  except Exception as e:
148
+ print(f" WARNING: Pile failed: {e}")
149
  print(f" Falling back to neuralmagic only")
150
 
151
  # --- neuralmagic: Q&A calibration (up to remaining) ---
 
171
  )
172
  samples.append(tokens)
173
  count += 1
174
+ if count % 100 == 0:
175
+ print(f" neuralmagic: {count}/{remaining} samples loaded...")
176
+ sys.stdout.flush()
177
  print(f" neuralmagic: {count} samples")
178
  except Exception as e:
179
+ print(f" WARNING: neuralmagic failed: {e}")
180
 
181
+ tracker.done()
182
  print(f"[transport] Total calibration samples: {len(samples)}")
183
+ sys.stdout.flush()
184
  return samples
185
 
186
 
 
197
  optimal transport algorithm aligns between source and target.
198
 
199
  Returns:
200
+ Dict mapping layer_name -> activation tensor [num_samples, hidden_dim]
201
  """
202
+ tracker = ProgressTracker("extract-activations", interval_seconds=300)
203
+ tracker.set_total(len(calibration_data))
204
  print(f"[transport] Extracting activations from {len(calibration_data)} samples...")
205
+ sys.stdout.flush()
206
 
207
  activations = {}
208
  hooks = []
 
220
  act = output
221
  if layer_name not in activations:
222
  activations[layer_name] = []
223
+ # Mean pool over sequence length -> [hidden_dim]
224
  activations[layer_name].append(
225
  act.detach().float().mean(dim=1).cpu()
226
  )
 
237
  try:
238
  model(**inputs)
239
  except Exception as e:
240
+ print(f" WARNING: Sample {i} failed: {e}")
241
  continue
242
 
243
+ tracker.tick(f"sample {i+1}")
244
+
245
  if (i + 1) % 100 == 0:
246
  print(f" Processed {i + 1}/{len(calibration_data)} samples")
247
+ sys.stdout.flush()
248
+
249
+ # Timeout: 30 min for activation extraction
250
+ tracker.check_timeout(timeout_seconds=1800)
251
 
252
  # Remove hooks
253
  for h in hooks:
254
  h.remove()
255
 
256
  # Stack activations: [num_samples, hidden_dim]
257
+ layer_count = 0
258
  for key in activations:
259
  activations[key] = torch.cat(activations[key], dim=0)
260
+ layer_count += 1
261
+
262
+ print(f" Extracted {layer_count} layers, shapes: {activations[list(activations.keys())[0]].shape if activations else 'empty'}")
263
+ tracker.done()
264
+ sys.stdout.flush()
265
 
266
  return activations
267
 
 
277
  This is where the magic happens. We use the official T&M code's:
278
  - corr_distance_matrix: correlation distance between activation vectors
279
  - sinkhorn_uniform_streaming: memory-efficient Sinkhorn solver
280
+ - compute_P: layer-level coupling (which source layers -> which target layers)
281
  - compute_Q_and_layer_costs: neuron-level coupling within each layer pair
282
 
283
  Returns:
284
  Dict with 'P' (layer coupling) and 'Q' (per-layer neuron coupling) matrices
285
  """
286
  print("[transport] Computing transport plans...")
287
+ sys.stdout.flush()
288
 
289
  try:
290
  # Try importing official T&M code
 
343
  """
344
  Fallback transport plan computation when official code isn't available.
345
 
346
+ Smart routing:
347
+ - Same-architecture models (same layer count): direct 1:1 layer matching
348
+ (no OT needed, just identity permutation -- fast!)
349
+ - Cross-architecture: sparse OT (only top-3 source layers per target)
350
  """
351
+ tracker = ProgressTracker("transport-plans", interval_seconds=300)
352
 
353
  source_layers = sorted(source_act.keys())
354
  target_layers = sorted(target_act.keys())
355
 
356
+ n_source = len(source_layers)
357
+ n_target = len(target_layers)
358
+
359
+ print(f"[transport] Source layers: {n_source}, Target layers: {n_target}")
360
+ sys.stdout.flush()
361
+
362
+ # --- FAST PATH: same architecture (same layer count) ---
363
+ # DeepSeek-R1-0528-Qwen3-8B has the same architecture as Qwen3-VL-8B
364
+ # Both have 36 transformer layers with identical hidden dims
365
+ # No need for expensive OT -- just match layers 1:1
366
+ if n_source == n_target:
367
+ print("[transport] Same layer count -- using direct 1:1 layer matching (fast path)")
368
+ print("[transport] This should take under 1 minute...")
369
+ sys.stdout.flush()
370
+ Q_matrices = {}
371
+ P = np.eye(n_source) / n_source # Identity coupling
372
+ tracker.set_total(n_source)
373
+
374
+ for i, (sl, tl) in enumerate(zip(source_layers, target_layers)):
375
+ S = source_act[sl].numpy()
376
+ T = target_act[tl].numpy()
377
+
378
+ # For same-dim layers, Q is identity (neurons already correspond)
379
+ if S.shape[1] == T.shape[1]:
380
+ Q_matrices[(sl, tl)] = np.eye(S.shape[1]) / S.shape[1]
381
+ else:
382
+ # Different dims -- do lightweight Sinkhorn on this pair only
383
+ print(f" Layer {i}: dim mismatch ({S.shape[1]} vs {T.shape[1]}), using Sinkhorn...")
384
+ S_norm = (S - S.mean(0)) / (S.std(0) + 1e-8)
385
+ T_norm = (T - T.mean(0)) / (T.std(0) + 1e-8)
386
+ corr = S_norm.T @ T_norm / S.shape[0]
387
+ cost = 1.0 - corr
388
+ Q_matrices[(sl, tl)] = _sinkhorn(cost, reg=0.1, max_iter=50)
389
+
390
+ tracker.tick(f"{sl} -> {tl}")
391
+
392
+ if (i + 1) % 10 == 0 or i == 0:
393
+ print(f" Matched layer {i + 1}/{n_source}: {sl} -> {tl}")
394
+ sys.stdout.flush()
395
+
396
+ # Timeout: 10 min for fast path (should take seconds)
397
+ tracker.check_timeout(timeout_seconds=600)
398
+
399
+ print(f"[transport] Direct matching complete: {n_source} layer pairs")
400
+ tracker.done()
401
+ sys.stdout.flush()
402
+ return {
403
+ "P": P,
404
+ "Q": Q_matrices,
405
+ "source_layers": source_layers,
406
+ "target_layers": target_layers,
407
+ }
408
+
409
+ # --- CROSS-ARCHITECTURE PATH: sparse OT ---
410
+ # Only compute top-3 source layers per target (not all NxN pairs)
411
+ print(f"[transport] Cross-architecture -- using sparse OT (top-3 per target)")
412
+ print(f"[transport] Estimated time: 5-15 minutes")
413
+ sys.stdout.flush()
414
+
415
+ # Step 1: Compute layer-level similarity (cheap: just mean activation correlation)
416
+ print("[transport] Step 1/3: Computing layer-level similarities...")
417
+ sys.stdout.flush()
418
+ layer_costs = np.zeros((n_source, n_target))
419
+ tracker.set_total(n_source * n_target + n_target * 3)
420
  for i, sl in enumerate(source_layers):
421
  for j, tl in enumerate(target_layers):
422
+ S_mean = source_act[sl].mean(0).numpy()
423
+ T_mean = target_act[tl].mean(0).numpy()
424
+ # Cosine similarity as cheap proxy
425
+ min_dim = min(len(S_mean), len(T_mean))
426
+ s = S_mean[:min_dim]
427
+ t = T_mean[:min_dim]
428
+ sim = np.dot(s, t) / (np.linalg.norm(s) * np.linalg.norm(t) + 1e-8)
429
+ layer_costs[i, j] = 1.0 - sim
430
+ tracker.tick(f"layer sim {i},{j}")
431
+
432
+ # Timeout: 30 min for cross-arch
433
+ tracker.check_timeout(timeout_seconds=1800)
434
+
435
+ print(f"[transport] Step 1/3 done: {n_source}x{n_target} similarities computed")
436
+ sys.stdout.flush()
437
+
438
+ # Step 2: For each target layer, only compute Q for top-3 most similar source layers
439
+ print("[transport] Step 2/3: Computing neuron-level transport (top-3 per target)...")
440
+ sys.stdout.flush()
441
+ Q_matrices = {}
442
+ for j, tl in enumerate(target_layers):
443
+ top3 = np.argsort(layer_costs[:, j])[:3]
444
+ for i in top3:
445
+ sl = source_layers[i]
446
+ S = source_act[sl].numpy()
447
+ T = target_act[tl].numpy()
448
+
449
+ # Lightweight Sinkhorn (50 iterations, not 100+)
450
+ min_dim = min(S.shape[1], T.shape[1])
451
+ S_sub = S[:, :min_dim]
452
+ T_sub = T[:, :min_dim]
453
+ S_norm = (S_sub - S_sub.mean(0)) / (S_sub.std(0) + 1e-8)
454
+ T_norm = (T_sub - T_sub.mean(0)) / (T_sub.std(0) + 1e-8)
455
+ corr = S_norm.T @ T_norm / S.shape[0]
456
+ cost = 1.0 - corr
457
+ Q_matrices[(sl, tl)] = _sinkhorn(cost, reg=0.1, max_iter=50)
458
+ tracker.tick(f"Q({sl},{tl})")
459
+
460
+ if (j + 1) % 5 == 0 or j == 0:
461
+ print(f" Target layer {j + 1}/{n_target}: matched to top-3 sources")
462
+ sys.stdout.flush()
463
+
464
+ # Timeout: 30 min for cross-arch
465
+ tracker.check_timeout(timeout_seconds=1800)
466
+
467
+ print(f"[transport] Step 2/3 done: {len(Q_matrices)} Q matrices computed")
468
+ sys.stdout.flush()
469
+
470
+ # Step 3: Layer coupling via Sinkhorn on layer costs
471
+ print("[transport] Step 3/3: Computing layer coupling P matrix...")
472
+ sys.stdout.flush()
473
+ P = _sinkhorn(layer_costs, reg=0.1, max_iter=50)
474
+
475
+ print(f"[transport] Sparse OT complete: {len(Q_matrices)} layer pairs computed")
476
+ tracker.done()
477
+ sys.stdout.flush()
478
  return {
479
  "P": P,
480
  "Q": Q_matrices,
 
503
  u = np.ones(n) / n
504
  v = np.ones(m) / m
505
 
506
+ for iteration in range(max_iter):
507
  u = 1.0 / (K @ v + 1e-10)
508
  v = 1.0 / (K.T @ u + 1e-10)
509
 
 
530
  Special handling per model:
531
  - DeepSeek: Direct merge (same architecture)
532
  - MiMo: Skip MTP heads, skip embeddings
533
+ - Llama: Layer mapping (32->36), skip embeddings, drop QKV bias
534
  - Falcon: Skip Mamba components, skip embeddings
535
 
536
  Returns:
537
  Target model with fused weights
538
  """
539
+ tracker = ProgressTracker("fuse-weights", interval_seconds=300)
540
+ print(f"\n[transport] Fusing {source_config.name} -> target")
541
  alpha = source_config.merge_alpha
542
 
543
  try:
 
557
 
558
  fused_count = 0
559
  skipped_count = 0
560
+ total_params = len(target_state)
561
+ tracker.set_total(total_params)
562
 
563
  for target_key in target_state:
564
+ tracker.tick(target_key)
565
+
566
  # Skip parameters we shouldn't merge
567
  if _should_skip(target_key, source_config):
568
  skipped_count += 1
 
590
  target_state[target_key] = fused_w
591
  fused_count += 1
592
 
593
+ # Apply thinking mode protection (inside loop -- check each key)
594
+ if cfg.freeze_think_tokens and "embed_tokens" in target_key:
595
+ for token_id in cfg.think_token_ids:
596
+ if token_id < target_state[target_key].shape[0]:
597
+ # Restore original embedding for think tokens
598
+ orig_embed = target_model.state_dict()[target_key]
599
+ target_state[target_key][token_id] = orig_embed[token_id]
600
+ print(f"[transport] Protected think token {token_id}")
601
+
602
+ if fused_count % 50 == 0:
603
+ print(f" Fused {fused_count} params so far (skipped {skipped_count})...")
604
+ sys.stdout.flush()
605
+
606
+ # Timeout: 20 min for weight fusion
607
+ tracker.check_timeout(timeout_seconds=1200)
608
 
609
  # Load fused weights
610
  target_model.load_state_dict(target_state)
611
  print(f"[transport] Fused {fused_count} params, skipped {skipped_count}")
612
+ tracker.done()
613
+ sys.stdout.flush()
614
 
615
  return target_model
616
 
 
647
  if source_config.architecture == "transformer" and source_config.layers == 36:
648
  return target_key
649
 
650
+ # For Llama (32 layers -> 36 layers), map layer indices
651
  if "layer_mapping_32_to_36" in source_config.special_handling:
652
  if "model.layers." in target_key:
653
  # Extract layer number
 
713
  result[:min_len] = source_w[:min_len]
714
  return result
715
 
716
+ # Can't align -- skip this parameter
717
  return None
hugging/td_fuse/validate.py CHANGED
@@ -11,6 +11,8 @@ Kill criteria: >10% performance drop on any test → abort merge.
11
  Findings: #11, #22, #25
12
  """
13
 
 
 
14
  import torch
15
  import math
16
  from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -39,9 +41,12 @@ def validate_merged_model(
39
  Returns:
40
  Dict with test results and overall pass/fail
41
  """
 
42
  print("\n" + "=" * 60)
43
  print(f"VALIDATION — After merging: {', '.join(merged_sources)}")
 
44
  print("=" * 60)
 
45
 
46
  results = {
47
  "canary": None,
@@ -52,6 +57,7 @@ def validate_merged_model(
52
  }
53
 
54
  # --- Test 1: Canary recall ---
 
55
  canary_results = test_all_canaries(model, tokenizer, merged_sources)
56
  passed_canaries = sum(1 for v in canary_results.values() if v)
57
  total_canaries = len(canary_results)
@@ -63,6 +69,7 @@ def validate_merged_model(
63
  }
64
 
65
  # --- Test 2: Perplexity ---
 
66
  perplexity = compute_perplexity(model, tokenizer)
67
  ppl_ok = True
68
  if baseline_perplexity is not None:
@@ -76,10 +83,12 @@ def validate_merged_model(
76
  results["perplexity"] = {"value": perplexity, "ok": ppl_ok}
77
 
78
  # --- Test 3: Thinking mode ---
 
79
  think_ok = test_thinking_mode(model, tokenizer)
80
  results["thinking_mode"] = {"ok": think_ok}
81
 
82
  # --- Test 4: Quick reasoning ---
 
83
  reason_ok = test_reasoning(model, tokenizer)
84
  results["reasoning"] = {"ok": reason_ok}
85
 
@@ -100,8 +109,10 @@ def validate_merged_model(
100
  print(f" Perplexity: {'✓' if ppl_ok else '✗'} ({perplexity:.2f})")
101
  print(f" Thinking mode: {'✓' if think_ok else '✗'}")
102
  print(f" Reasoning: {'✓' if reason_ok else '✗'}")
103
- print(f" OVERALL: {'✓ PASS' if all_ok else '✗ FAIL — consider aborting'}")
 
104
  print("-" * 60)
 
105
 
106
  return results
107
 
 
11
  Findings: #11, #22, #25
12
  """
13
 
14
+ import sys
15
+ import time
16
  import torch
17
  import math
18
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
41
  Returns:
42
  Dict with test results and overall pass/fail
43
  """
44
+ val_start = time.time()
45
  print("\n" + "=" * 60)
46
  print(f"VALIDATION — After merging: {', '.join(merged_sources)}")
47
+ print(f"Started at: {time.strftime('%H:%M:%S')}")
48
  print("=" * 60)
49
+ sys.stdout.flush()
50
 
51
  results = {
52
  "canary": None,
 
57
  }
58
 
59
  # --- Test 1: Canary recall ---
60
+ print("[validate] Test 1/4: Canary recall..."); sys.stdout.flush()
61
  canary_results = test_all_canaries(model, tokenizer, merged_sources)
62
  passed_canaries = sum(1 for v in canary_results.values() if v)
63
  total_canaries = len(canary_results)
 
69
  }
70
 
71
  # --- Test 2: Perplexity ---
72
+ print("[validate] Test 2/4: Perplexity..."); sys.stdout.flush()
73
  perplexity = compute_perplexity(model, tokenizer)
74
  ppl_ok = True
75
  if baseline_perplexity is not None:
 
83
  results["perplexity"] = {"value": perplexity, "ok": ppl_ok}
84
 
85
  # --- Test 3: Thinking mode ---
86
+ print("[validate] Test 3/4: Thinking mode..."); sys.stdout.flush()
87
  think_ok = test_thinking_mode(model, tokenizer)
88
  results["thinking_mode"] = {"ok": think_ok}
89
 
90
  # --- Test 4: Quick reasoning ---
91
+ print("[validate] Test 4/4: Quick reasoning..."); sys.stdout.flush()
92
  reason_ok = test_reasoning(model, tokenizer)
93
  results["reasoning"] = {"ok": reason_ok}
94
 
 
109
  print(f" Perplexity: {'✓' if ppl_ok else '✗'} ({perplexity:.2f})")
110
  print(f" Thinking mode: {'✓' if think_ok else '✗'}")
111
  print(f" Reasoning: {'✓' if reason_ok else '✗'}")
112
+ print(f" OVERALL: {'PASS' if all_ok else 'FAIL -- consider aborting'}")
113
+ print(f" Validation time: {(time.time()-val_start)/60:.1f} min")
114
  print("-" * 60)
115
+ sys.stdout.flush()
116
 
117
  return results
118
 
hugging/td_lang/engine/canary.py CHANGED
@@ -186,7 +186,7 @@ def test_all_canaries(
186
  results = {}
187
 
188
  # Test the target model's canary
189
- results["Qwen3-8B"] = test_canary(model, tokenizer, "Qwen3-8B")
190
 
191
  # Test each merged source model's canary
192
  for source_name in merged_sources:
 
186
  results = {}
187
 
188
  # Test the target model's canary
189
+ results["Qwen3-VL-8B"] = test_canary(model, tokenizer, "Qwen3-VL-8B")
190
 
191
  # Test each merged source model's canary
192
  for source_name in merged_sources: