Upload 137 files
Browse files
- hugging/td_fuse/canary.py +1 -1
- hugging/td_fuse/heal.py +10 -2
- hugging/td_fuse/merge.py +48 -4
- hugging/td_fuse/transport.py +238 -48
- hugging/td_fuse/validate.py +12 -1
- hugging/td_lang/engine/canary.py +1 -1
hugging/td_fuse/canary.py
CHANGED
|
@@ -186,7 +186,7 @@ def test_all_canaries(
|
|
| 186 |
results = {}
|
| 187 |
|
| 188 |
# Test the target model's canary
|
| 189 |
-
results["Qwen3-8B"] = test_canary(model, tokenizer, "Qwen3-8B")
|
| 190 |
|
| 191 |
# Test each merged source model's canary
|
| 192 |
for source_name in merged_sources:
|
|
|
|
| 186 |
results = {}
|
| 187 |
|
| 188 |
# Test the target model's canary
|
| 189 |
+
results["Qwen3-VL-8B"] = test_canary(model, tokenizer, "Qwen3-VL-8B")
|
| 190 |
|
| 191 |
# Test each merged source model's canary
|
| 192 |
for source_name in merged_sources:
|
hugging/td_fuse/heal.py
CHANGED
|
@@ -18,6 +18,8 @@ Findings: #12, #16, #20
|
|
| 18 |
"""
|
| 19 |
|
| 20 |
import os
|
|
|
|
|
|
|
| 21 |
import torch
|
| 22 |
from pathlib import Path
|
| 23 |
from typing import Optional
|
|
@@ -350,14 +352,20 @@ def heal_model(
|
|
| 350 |
if cfg is None:
|
| 351 |
cfg = MergeConfig()
|
| 352 |
|
|
|
|
| 353 |
print("\n" + "=" * 60)
|
| 354 |
print("HEALING FINE-TUNE")
|
| 355 |
print(f"Model: {model_path}")
|
| 356 |
print(f"LoRA r={cfg.heal_lora_r}, alpha={cfg.heal_lora_alpha}")
|
| 357 |
print(f"Epochs: {cfg.heal_epochs}, LR: {cfg.heal_learning_rate}")
|
|
|
|
| 358 |
print("=" * 60)
|
|
|
|
| 359 |
|
| 360 |
if check_unsloth_available():
|
| 361 |
-
|
| 362 |
else:
|
| 363 |
-
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
"""
|
| 19 |
|
| 20 |
import os
|
| 21 |
+
import sys
|
| 22 |
+
import time
|
| 23 |
import torch
|
| 24 |
from pathlib import Path
|
| 25 |
from typing import Optional
|
|
|
|
| 352 |
if cfg is None:
|
| 353 |
cfg = MergeConfig()
|
| 354 |
|
| 355 |
+
heal_start = time.time()
|
| 356 |
print("\n" + "=" * 60)
|
| 357 |
print("HEALING FINE-TUNE")
|
| 358 |
print(f"Model: {model_path}")
|
| 359 |
print(f"LoRA r={cfg.heal_lora_r}, alpha={cfg.heal_lora_alpha}")
|
| 360 |
print(f"Epochs: {cfg.heal_epochs}, LR: {cfg.heal_learning_rate}")
|
| 361 |
+
print(f"Started at: {time.strftime('%H:%M:%S')}")
|
| 362 |
print("=" * 60)
|
| 363 |
+
sys.stdout.flush()
|
| 364 |
|
| 365 |
if check_unsloth_available():
|
| 366 |
+
result = apply_qlora_unsloth(model_path, cfg, healing_data)
|
| 367 |
else:
|
| 368 |
+
result = apply_qlora_standard(model_path, cfg, healing_data)
|
| 369 |
+
print(f"[heal] Total healing time: {(time.time()-heal_start)/60:.1f} min")
|
| 370 |
+
sys.stdout.flush()
|
| 371 |
+
return result
|
hugging/td_fuse/merge.py
CHANGED
|
@@ -22,7 +22,9 @@ Findings: #13, #22, #25
|
|
| 22 |
|
| 23 |
import os
|
| 24 |
import gc
|
|
|
|
| 25 |
import copy
|
|
|
|
| 26 |
import torch
|
| 27 |
import numpy as np
|
| 28 |
from pathlib import Path
|
|
@@ -593,10 +595,13 @@ def run_single_merge(
|
|
| 593 |
merged_sources = []
|
| 594 |
|
| 595 |
stage_name = source_config.name
|
|
|
|
| 596 |
print(f"\n{'=' * 70}")
|
| 597 |
-
print(f"MERGE STAGE: {stage_name}
|
| 598 |
print(f"Risk level: {source_config.merge_risk.upper()}")
|
|
|
|
| 599 |
print(f"{'=' * 70}")
|
|
|
|
| 600 |
|
| 601 |
result = {
|
| 602 |
"stage": stage_name,
|
|
@@ -606,23 +611,34 @@ def run_single_merge(
|
|
| 606 |
}
|
| 607 |
|
| 608 |
# --- Step 1: Load source model ---
|
|
|
|
|
|
|
| 609 |
source_model, source_tokenizer = load_model(source_config, cfg)
|
|
|
|
| 610 |
|
| 611 |
# --- Step 2: Inject canary into source ---
|
|
|
|
|
|
|
| 612 |
if stage_name in CANARY_FACTS:
|
| 613 |
-
print(f"\n[merge] Injecting canary fact into {stage_name}...")
|
| 614 |
source_model = inject_canary(source_model, source_tokenizer, stage_name)
|
|
|
|
| 615 |
|
| 616 |
# --- Step 3: Load calibration data (if not provided) ---
|
|
|
|
|
|
|
| 617 |
if calibration_data is None:
|
| 618 |
calibration_data = load_calibration_data(cfg, target_tokenizer)
|
|
|
|
| 619 |
|
| 620 |
# --- Step 4: Extract activations ---
|
| 621 |
-
print(f"\n[merge]
|
|
|
|
|
|
|
| 622 |
source_activations = extract_activations(source_model, calibration_data)
|
| 623 |
|
| 624 |
-
print(f"
|
| 625 |
pre_merge_target_activations = extract_activations(target_model, calibration_data)
|
|
|
|
| 626 |
|
| 627 |
# --- Step 4.5: Mergeability pre-check (2601.22285) ---
|
| 628 |
if cfg.use_mergeability_check:
|
|
@@ -644,9 +660,12 @@ def run_single_merge(
|
|
| 644 |
return result
|
| 645 |
|
| 646 |
# --- Step 5: Compute transport plans ---
|
|
|
|
|
|
|
| 647 |
transport_plans = compute_transport_plans(
|
| 648 |
source_activations, pre_merge_target_activations, cfg
|
| 649 |
)
|
|
|
|
| 650 |
|
| 651 |
# --- Step 5.5: RAM RL-weight disentanglement (2601.13572) ---
|
| 652 |
use_ram = (
|
|
@@ -657,6 +676,8 @@ def run_single_merge(
|
|
| 657 |
)
|
| 658 |
|
| 659 |
# --- Step 6: Pre-merge protection ---
|
|
|
|
|
|
|
| 660 |
adjusted_alpha = protection.before_merge(target_model, source_config)
|
| 661 |
|
| 662 |
# Override source alpha with time-adjusted value
|
|
@@ -665,8 +686,11 @@ def run_single_merge(
|
|
| 665 |
|
| 666 |
# Save pre-merge state for protection
|
| 667 |
pre_merge_state = {k: v.clone().cpu() for k, v in target_model.state_dict().items()}
|
|
|
|
| 668 |
|
| 669 |
# --- Step 7: Fuse weights ---
|
|
|
|
|
|
|
| 670 |
if use_ram:
|
| 671 |
# RAM path: disentangle RL weights, merge with preservation
|
| 672 |
print(f"\n[merge] Using RAM RL-preservation for {stage_name}...")
|
|
@@ -750,7 +774,11 @@ def run_single_merge(
|
|
| 750 |
source_config_adjusted, cfg,
|
| 751 |
)
|
| 752 |
|
|
|
|
|
|
|
| 753 |
# --- Step 8: Apply post-merge protection (ARM + OTMF + MagMax) ---
|
|
|
|
|
|
|
| 754 |
# Skip vision encoder params — they weren't merged, so don't "protect" them
|
| 755 |
if protection.merge_count > 0:
|
| 756 |
print(f"\n[merge] Applying sequential merge protection (ARM + OTMF + MagMax)...")
|
|
@@ -770,7 +798,11 @@ def run_single_merge(
|
|
| 770 |
target_model.load_state_dict(target_state)
|
| 771 |
print(f"[merge] Protected {protected_count} language params (skipped {vision_skipped} vision params)")
|
| 772 |
|
|
|
|
|
|
|
| 773 |
# --- Step 8.5: Extract post-merge activations for ARM/OTMF ---
|
|
|
|
|
|
|
| 774 |
post_merge_activations = extract_activations(target_model, calibration_data[:100])
|
| 775 |
|
| 776 |
# Record this merge's delta + compute ARM/OTMF for next merge
|
|
@@ -780,7 +812,11 @@ def run_single_merge(
|
|
| 780 |
post_merge_activations=post_merge_activations,
|
| 781 |
)
|
| 782 |
|
|
|
|
|
|
|
| 783 |
# --- Step 8.8: Save residuals (what was lost from both sides) ---
|
|
|
|
|
|
|
| 784 |
if residual_bank is not None:
|
| 785 |
print(f"\n[merge] Saving residuals for {stage_name}...")
|
| 786 |
residual_bank.save_residuals(
|
|
@@ -791,6 +827,8 @@ def run_single_merge(
|
|
| 791 |
source_config=source_config,
|
| 792 |
)
|
| 793 |
|
|
|
|
|
|
|
| 794 |
# --- Step 9: Free source model memory ---
|
| 795 |
del source_model, source_activations, pre_merge_target_activations
|
| 796 |
del transport_plans, post_merge_activations
|
|
@@ -799,6 +837,8 @@ def run_single_merge(
|
|
| 799 |
torch.cuda.empty_cache()
|
| 800 |
|
| 801 |
# --- Step 10: Validate ---
|
|
|
|
|
|
|
| 802 |
merged_sources.append(stage_name)
|
| 803 |
validation = validate_merged_model(
|
| 804 |
target_model, target_tokenizer,
|
|
@@ -806,8 +846,12 @@ def run_single_merge(
|
|
| 806 |
baseline_perplexity=baseline_perplexity,
|
| 807 |
)
|
| 808 |
|
|
|
|
|
|
|
| 809 |
result["validation"] = validation
|
| 810 |
result["merged_sources"] = merged_sources.copy()
|
|
|
|
|
|
|
| 811 |
|
| 812 |
# --- Kill criteria check ---
|
| 813 |
if not validation["overall"]:
|
|
|
|
| 22 |
|
| 23 |
import os
|
| 24 |
import gc
|
| 25 |
+
import sys
|
| 26 |
import copy
|
| 27 |
+
import time
|
| 28 |
import torch
|
| 29 |
import numpy as np
|
| 30 |
from pathlib import Path
|
|
|
|
| 595 |
merged_sources = []
|
| 596 |
|
| 597 |
stage_name = source_config.name
|
| 598 |
+
stage_start = time.time()
|
| 599 |
print(f"\n{'=' * 70}")
|
| 600 |
+
print(f"MERGE STAGE: {stage_name} -> target")
|
| 601 |
print(f"Risk level: {source_config.merge_risk.upper()}")
|
| 602 |
+
print(f"Started at: {time.strftime('%H:%M:%S')}")
|
| 603 |
print(f"{'=' * 70}")
|
| 604 |
+
sys.stdout.flush()
|
| 605 |
|
| 606 |
result = {
|
| 607 |
"stage": stage_name,
|
|
|
|
| 611 |
}
|
| 612 |
|
| 613 |
# --- Step 1: Load source model ---
|
| 614 |
+
print(f"\n[merge] Step 1/10: Loading source model..."); sys.stdout.flush()
|
| 615 |
+
step_t = time.time()
|
| 616 |
source_model, source_tokenizer = load_model(source_config, cfg)
|
| 617 |
+
print(f"[merge] Step 1/10 done in {time.time()-step_t:.0f}s"); sys.stdout.flush()
|
| 618 |
|
| 619 |
# --- Step 2: Inject canary into source ---
|
| 620 |
+
print(f"\n[merge] Step 2/10: Injecting canary..."); sys.stdout.flush()
|
| 621 |
+
step_t = time.time()
|
| 622 |
if stage_name in CANARY_FACTS:
|
|
|
|
| 623 |
source_model = inject_canary(source_model, source_tokenizer, stage_name)
|
| 624 |
+
print(f"[merge] Step 2/10 done in {time.time()-step_t:.0f}s"); sys.stdout.flush()
|
| 625 |
|
| 626 |
# --- Step 3: Load calibration data (if not provided) ---
|
| 627 |
+
print(f"\n[merge] Step 3/10: Loading calibration data..."); sys.stdout.flush()
|
| 628 |
+
step_t = time.time()
|
| 629 |
if calibration_data is None:
|
| 630 |
calibration_data = load_calibration_data(cfg, target_tokenizer)
|
| 631 |
+
print(f"[merge] Step 3/10 done in {time.time()-step_t:.0f}s"); sys.stdout.flush()
|
| 632 |
|
| 633 |
# --- Step 4: Extract activations ---
|
| 634 |
+
print(f"\n[merge] Step 4/10: Extracting activations (both models)..."); sys.stdout.flush()
|
| 635 |
+
step_t = time.time()
|
| 636 |
+
print(f"[merge] Extracting source activations...")
|
| 637 |
source_activations = extract_activations(source_model, calibration_data)
|
| 638 |
|
| 639 |
+
print(f"[merge] Extracting target activations...")
|
| 640 |
pre_merge_target_activations = extract_activations(target_model, calibration_data)
|
| 641 |
+
print(f"[merge] Step 4/10 done in {time.time()-step_t:.0f}s"); sys.stdout.flush()
|
| 642 |
|
| 643 |
# --- Step 4.5: Mergeability pre-check (2601.22285) ---
|
| 644 |
if cfg.use_mergeability_check:
|
|
|
|
| 660 |
return result
|
| 661 |
|
| 662 |
# --- Step 5: Compute transport plans ---
|
| 663 |
+
print(f"\n[merge] Step 5/10: Computing transport plans..."); sys.stdout.flush()
|
| 664 |
+
step_t = time.time()
|
| 665 |
transport_plans = compute_transport_plans(
|
| 666 |
source_activations, pre_merge_target_activations, cfg
|
| 667 |
)
|
| 668 |
+
print(f"[merge] Step 5/10 done in {time.time()-step_t:.0f}s"); sys.stdout.flush()
|
| 669 |
|
| 670 |
# --- Step 5.5: RAM RL-weight disentanglement (2601.13572) ---
|
| 671 |
use_ram = (
|
|
|
|
| 676 |
)
|
| 677 |
|
| 678 |
# --- Step 6: Pre-merge protection ---
|
| 679 |
+
print(f"\n[merge] Step 6/10: Pre-merge protection..."); sys.stdout.flush()
|
| 680 |
+
step_t = time.time()
|
| 681 |
adjusted_alpha = protection.before_merge(target_model, source_config)
|
| 682 |
|
| 683 |
# Override source alpha with time-adjusted value
|
|
|
|
| 686 |
|
| 687 |
# Save pre-merge state for protection
|
| 688 |
pre_merge_state = {k: v.clone().cpu() for k, v in target_model.state_dict().items()}
|
| 689 |
+
print(f"[merge] Step 6/10 done in {time.time()-step_t:.0f}s"); sys.stdout.flush()
|
| 690 |
|
| 691 |
# --- Step 7: Fuse weights ---
|
| 692 |
+
print(f"\n[merge] Step 7/10: Fusing weights..."); sys.stdout.flush()
|
| 693 |
+
step_t = time.time()
|
| 694 |
if use_ram:
|
| 695 |
# RAM path: disentangle RL weights, merge with preservation
|
| 696 |
print(f"\n[merge] Using RAM RL-preservation for {stage_name}...")
|
|
|
|
| 774 |
source_config_adjusted, cfg,
|
| 775 |
)
|
| 776 |
|
| 777 |
+
print(f"[merge] Step 7/10 done in {time.time()-step_t:.0f}s"); sys.stdout.flush()
|
| 778 |
+
|
| 779 |
# --- Step 8: Apply post-merge protection (ARM + OTMF + MagMax) ---
|
| 780 |
+
print(f"\n[merge] Step 8/10: Post-merge protection..."); sys.stdout.flush()
|
| 781 |
+
step_t = time.time()
|
| 782 |
# Skip vision encoder params — they weren't merged, so don't "protect" them
|
| 783 |
if protection.merge_count > 0:
|
| 784 |
print(f"\n[merge] Applying sequential merge protection (ARM + OTMF + MagMax)...")
|
|
|
|
| 798 |
target_model.load_state_dict(target_state)
|
| 799 |
print(f"[merge] Protected {protected_count} language params (skipped {vision_skipped} vision params)")
|
| 800 |
|
| 801 |
+
print(f"[merge] Step 8/10 done in {time.time()-step_t:.0f}s"); sys.stdout.flush()
|
| 802 |
+
|
| 803 |
# --- Step 8.5: Extract post-merge activations for ARM/OTMF ---
|
| 804 |
+
print(f"\n[merge] Step 8.5/10: Post-merge activations + ARM/OTMF prep..."); sys.stdout.flush()
|
| 805 |
+
step_t = time.time()
|
| 806 |
post_merge_activations = extract_activations(target_model, calibration_data[:100])
|
| 807 |
|
| 808 |
# Record this merge's delta + compute ARM/OTMF for next merge
|
|
|
|
| 812 |
post_merge_activations=post_merge_activations,
|
| 813 |
)
|
| 814 |
|
| 815 |
+
print(f"[merge] Step 8.5/10 done in {time.time()-step_t:.0f}s"); sys.stdout.flush()
|
| 816 |
+
|
| 817 |
# --- Step 8.8: Save residuals (what was lost from both sides) ---
|
| 818 |
+
print(f"\n[merge] Step 9/10: Saving residuals..."); sys.stdout.flush()
|
| 819 |
+
step_t = time.time()
|
| 820 |
if residual_bank is not None:
|
| 821 |
print(f"\n[merge] Saving residuals for {stage_name}...")
|
| 822 |
residual_bank.save_residuals(
|
|
|
|
| 827 |
source_config=source_config,
|
| 828 |
)
|
| 829 |
|
| 830 |
+
print(f"[merge] Step 9/10 done in {time.time()-step_t:.0f}s"); sys.stdout.flush()
|
| 831 |
+
|
| 832 |
# --- Step 9: Free source model memory ---
|
| 833 |
del source_model, source_activations, pre_merge_target_activations
|
| 834 |
del transport_plans, post_merge_activations
|
|
|
|
| 837 |
torch.cuda.empty_cache()
|
| 838 |
|
| 839 |
# --- Step 10: Validate ---
|
| 840 |
+
print(f"\n[merge] Step 10/10: Validating merge..."); sys.stdout.flush()
|
| 841 |
+
step_t = time.time()
|
| 842 |
merged_sources.append(stage_name)
|
| 843 |
validation = validate_merged_model(
|
| 844 |
target_model, target_tokenizer,
|
|
|
|
| 846 |
baseline_perplexity=baseline_perplexity,
|
| 847 |
)
|
| 848 |
|
| 849 |
+
print(f"[merge] Step 10/10 done in {time.time()-step_t:.0f}s"); sys.stdout.flush()
|
| 850 |
+
|
| 851 |
result["validation"] = validation
|
| 852 |
result["merged_sources"] = merged_sources.copy()
|
| 853 |
+
total_time = time.time() - stage_start
|
| 854 |
+
print(f"\n[merge] Total time for {stage_name}: {total_time/60:.1f} min"); sys.stdout.flush()
|
| 855 |
|
| 856 |
# --- Kill criteria check ---
|
| 857 |
if not validation["overall"]:
|
hugging/td_fuse/transport.py
CHANGED
|
@@ -15,11 +15,14 @@ We add:
|
|
| 15 |
- MiMo MTP head handling
|
| 16 |
- Falcon SSM component handling
|
| 17 |
- Sequential merge protection (MagMax + orthogonal projection)
|
|
|
|
|
|
|
| 18 |
|
| 19 |
Findings: #01, #07, #24
|
| 20 |
"""
|
| 21 |
|
| 22 |
import sys
|
|
|
|
| 23 |
import torch
|
| 24 |
import numpy as np
|
| 25 |
from pathlib import Path
|
|
@@ -30,6 +33,58 @@ from datasets import load_dataset
|
|
| 30 |
from .config import MergeConfig, ModelConfig, TARGET
|
| 31 |
|
| 32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
def setup_tm_repo(cfg: MergeConfig):
|
| 34 |
"""Add official T&M repo to Python path so we can import their code."""
|
| 35 |
repo_path = Path(cfg.tm_repo_path)
|
|
@@ -58,6 +113,7 @@ def load_calibration_data(cfg: MergeConfig, tokenizer: AutoTokenizer) -> list:
|
|
| 58 |
|
| 59 |
Findings: #08
|
| 60 |
"""
|
|
|
|
| 61 |
print(f"[transport] Loading calibration data ({cfg.calibration_samples} samples)...")
|
| 62 |
|
| 63 |
samples = []
|
|
@@ -84,9 +140,12 @@ def load_calibration_data(cfg: MergeConfig, tokenizer: AutoTokenizer) -> list:
|
|
| 84 |
)
|
| 85 |
samples.append(tokens)
|
| 86 |
count += 1
|
|
|
|
|
|
|
|
|
|
| 87 |
print(f" Pile general: {count} samples")
|
| 88 |
except Exception as e:
|
| 89 |
-
print(f"
|
| 90 |
print(f" Falling back to neuralmagic only")
|
| 91 |
|
| 92 |
# --- neuralmagic: Q&A calibration (up to remaining) ---
|
|
@@ -112,11 +171,16 @@ def load_calibration_data(cfg: MergeConfig, tokenizer: AutoTokenizer) -> list:
|
|
| 112 |
)
|
| 113 |
samples.append(tokens)
|
| 114 |
count += 1
|
|
|
|
|
|
|
|
|
|
| 115 |
print(f" neuralmagic: {count} samples")
|
| 116 |
except Exception as e:
|
| 117 |
-
print(f"
|
| 118 |
|
|
|
|
| 119 |
print(f"[transport] Total calibration samples: {len(samples)}")
|
|
|
|
| 120 |
return samples
|
| 121 |
|
| 122 |
|
|
@@ -133,9 +197,12 @@ def extract_activations(
|
|
| 133 |
optimal transport algorithm aligns between source and target.
|
| 134 |
|
| 135 |
Returns:
|
| 136 |
-
Dict mapping layer_name
|
| 137 |
"""
|
|
|
|
|
|
|
| 138 |
print(f"[transport] Extracting activations from {len(calibration_data)} samples...")
|
|
|
|
| 139 |
|
| 140 |
activations = {}
|
| 141 |
hooks = []
|
|
@@ -153,7 +220,7 @@ def extract_activations(
|
|
| 153 |
act = output
|
| 154 |
if layer_name not in activations:
|
| 155 |
activations[layer_name] = []
|
| 156 |
-
# Mean pool over sequence length
|
| 157 |
activations[layer_name].append(
|
| 158 |
act.detach().float().mean(dim=1).cpu()
|
| 159 |
)
|
|
@@ -170,20 +237,31 @@ def extract_activations(
|
|
| 170 |
try:
|
| 171 |
model(**inputs)
|
| 172 |
except Exception as e:
|
| 173 |
-
print(f"
|
| 174 |
continue
|
| 175 |
|
|
|
|
|
|
|
| 176 |
if (i + 1) % 100 == 0:
|
| 177 |
print(f" Processed {i + 1}/{len(calibration_data)} samples")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 178 |
|
| 179 |
# Remove hooks
|
| 180 |
for h in hooks:
|
| 181 |
h.remove()
|
| 182 |
|
| 183 |
# Stack activations: [num_samples, hidden_dim]
|
|
|
|
| 184 |
for key in activations:
|
| 185 |
activations[key] = torch.cat(activations[key], dim=0)
|
| 186 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
|
| 188 |
return activations
|
| 189 |
|
|
@@ -199,13 +277,14 @@ def compute_transport_plans(
|
|
| 199 |
This is where the magic happens. We use the official T&M code's:
|
| 200 |
- corr_distance_matrix: correlation distance between activation vectors
|
| 201 |
- sinkhorn_uniform_streaming: memory-efficient Sinkhorn solver
|
| 202 |
-
- compute_P: layer-level coupling (which source layers
|
| 203 |
- compute_Q_and_layer_costs: neuron-level coupling within each layer pair
|
| 204 |
|
| 205 |
Returns:
|
| 206 |
Dict with 'P' (layer coupling) and 'Q' (per-layer neuron coupling) matrices
|
| 207 |
"""
|
| 208 |
print("[transport] Computing transport plans...")
|
|
|
|
| 209 |
|
| 210 |
try:
|
| 211 |
# Try importing official T&M code
|
|
@@ -264,41 +343,138 @@ def _compute_plans_fallback(
|
|
| 264 |
"""
|
| 265 |
Fallback transport plan computation when official code isn't available.
|
| 266 |
|
| 267 |
-
|
| 268 |
-
|
|
|
|
|
|
|
| 269 |
"""
|
|
|
|
| 270 |
|
| 271 |
source_layers = sorted(source_act.keys())
|
| 272 |
target_layers = sorted(target_act.keys())
|
| 273 |
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 278 |
for i, sl in enumerate(source_layers):
|
| 279 |
for j, tl in enumerate(target_layers):
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 302 |
return {
|
| 303 |
"P": P,
|
| 304 |
"Q": Q_matrices,
|
|
@@ -327,7 +503,7 @@ def _sinkhorn(
|
|
| 327 |
u = np.ones(n) / n
|
| 328 |
v = np.ones(m) / m
|
| 329 |
|
| 330 |
-
for
|
| 331 |
u = 1.0 / (K @ v + 1e-10)
|
| 332 |
v = 1.0 / (K.T @ u + 1e-10)
|
| 333 |
|
|
@@ -354,13 +530,14 @@ def fuse_weights(
|
|
| 354 |
Special handling per model:
|
| 355 |
- DeepSeek: Direct merge (same architecture)
|
| 356 |
- MiMo: Skip MTP heads, skip embeddings
|
| 357 |
-
- Llama: Layer mapping (32
|
| 358 |
- Falcon: Skip Mamba components, skip embeddings
|
| 359 |
|
| 360 |
Returns:
|
| 361 |
Target model with fused weights
|
| 362 |
"""
|
| 363 |
-
|
|
|
|
| 364 |
alpha = source_config.merge_alpha
|
| 365 |
|
| 366 |
try:
|
|
@@ -380,8 +557,12 @@ def fuse_weights(
|
|
| 380 |
|
| 381 |
fused_count = 0
|
| 382 |
skipped_count = 0
|
|
|
|
|
|
|
| 383 |
|
| 384 |
for target_key in target_state:
|
|
|
|
|
|
|
| 385 |
# Skip parameters we shouldn't merge
|
| 386 |
if _should_skip(target_key, source_config):
|
| 387 |
skipped_count += 1
|
|
@@ -409,18 +590,27 @@ def fuse_weights(
|
|
| 409 |
target_state[target_key] = fused_w
|
| 410 |
fused_count += 1
|
| 411 |
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 420 |
|
| 421 |
# Load fused weights
|
| 422 |
target_model.load_state_dict(target_state)
|
| 423 |
print(f"[transport] Fused {fused_count} params, skipped {skipped_count}")
|
|
|
|
|
|
|
| 424 |
|
| 425 |
return target_model
|
| 426 |
|
|
@@ -457,7 +647,7 @@ def _map_key(target_key: str, source_config: ModelConfig) -> Optional[str]:
|
|
| 457 |
if source_config.architecture == "transformer" and source_config.layers == 36:
|
| 458 |
return target_key
|
| 459 |
|
| 460 |
-
# For Llama (32 layers
|
| 461 |
if "layer_mapping_32_to_36" in source_config.special_handling:
|
| 462 |
if "model.layers." in target_key:
|
| 463 |
# Extract layer number
|
|
@@ -523,5 +713,5 @@ def _align_dimensions(
|
|
| 523 |
result[:min_len] = source_w[:min_len]
|
| 524 |
return result
|
| 525 |
|
| 526 |
-
# Can't align
|
| 527 |
return None
|
|
|
|
| 15 |
- MiMo MTP head handling
|
| 16 |
- Falcon SSM component handling
|
| 17 |
- Sequential merge protection (MagMax + orthogonal projection)
|
| 18 |
+
- Progress reporting every 5 minutes
|
| 19 |
+
- Timeouts to prevent infinite hangs
|
| 20 |
|
| 21 |
Findings: #01, #07, #24
|
| 22 |
"""
|
| 23 |
|
| 24 |
import sys
|
| 25 |
+
import time
|
| 26 |
import torch
|
| 27 |
import numpy as np
|
| 28 |
from pathlib import Path
|
|
|
|
| 33 |
from .config import MergeConfig, ModelConfig, TARGET
|
| 34 |
|
| 35 |
|
| 36 |
+
# ============================================================================
|
| 37 |
+
# PROGRESS TRACKER — prints status every 5 minutes so you know it's alive
|
| 38 |
+
# ============================================================================
|
| 39 |
+
|
| 40 |
+
class ProgressTracker:
|
| 41 |
+
"""Prints a heartbeat every interval_seconds so you know it's not stuck."""
|
| 42 |
+
|
| 43 |
+
def __init__(self, task_name: str, interval_seconds: int = 300):
|
| 44 |
+
self.task_name = task_name
|
| 45 |
+
self.interval = interval_seconds
|
| 46 |
+
self.start_time = time.time()
|
| 47 |
+
self.last_report = self.start_time
|
| 48 |
+
self.step = 0
|
| 49 |
+
self.total_steps = 0
|
| 50 |
+
print(f"\n[{task_name}] Started at {time.strftime('%H:%M:%S')}")
|
| 51 |
+
|
| 52 |
+
def set_total(self, total: int):
|
| 53 |
+
self.total_steps = total
|
| 54 |
+
|
| 55 |
+
def tick(self, step_name: str = ""):
|
| 56 |
+
"""Call this inside loops. Prints progress if 5 min have passed."""
|
| 57 |
+
self.step += 1
|
| 58 |
+
now = time.time()
|
| 59 |
+
elapsed = now - self.start_time
|
| 60 |
+
since_last = now - self.last_report
|
| 61 |
+
|
| 62 |
+
if since_last >= self.interval:
|
| 63 |
+
pct = f"{self.step}/{self.total_steps} ({100*self.step/self.total_steps:.0f}%)" if self.total_steps else f"step {self.step}"
|
| 64 |
+
eta = ""
|
| 65 |
+
if self.total_steps and self.step > 0:
|
| 66 |
+
rate = elapsed / self.step
|
| 67 |
+
remaining = (self.total_steps - self.step) * rate
|
| 68 |
+
eta = f", ETA {remaining/60:.1f} min"
|
| 69 |
+
print(f"[{self.task_name}] HEARTBEAT β {pct}, elapsed {elapsed/60:.1f} min{eta} | {step_name}")
|
| 70 |
+
sys.stdout.flush()
|
| 71 |
+
self.last_report = now
|
| 72 |
+
|
| 73 |
+
def done(self):
|
| 74 |
+
elapsed = time.time() - self.start_time
|
| 75 |
+
print(f"[{self.task_name}] Completed in {elapsed/60:.1f} min ({elapsed:.0f}s)")
|
| 76 |
+
sys.stdout.flush()
|
| 77 |
+
|
| 78 |
+
def check_timeout(self, timeout_seconds: int = 3600):
|
| 79 |
+
"""Raise if we've been running longer than timeout_seconds."""
|
| 80 |
+
elapsed = time.time() - self.start_time
|
| 81 |
+
if elapsed > timeout_seconds:
|
| 82 |
+
raise TimeoutError(
|
| 83 |
+
f"[{self.task_name}] TIMEOUT after {elapsed/60:.1f} min "
|
| 84 |
+
f"(limit: {timeout_seconds/60:.0f} min). Something is wrong."
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
def setup_tm_repo(cfg: MergeConfig):
|
| 89 |
"""Add official T&M repo to Python path so we can import their code."""
|
| 90 |
repo_path = Path(cfg.tm_repo_path)
|
|
|
|
| 113 |
|
| 114 |
Findings: #08
|
| 115 |
"""
|
| 116 |
+
tracker = ProgressTracker("calibration-data", interval_seconds=120)
|
| 117 |
print(f"[transport] Loading calibration data ({cfg.calibration_samples} samples)...")
|
| 118 |
|
| 119 |
samples = []
|
|
|
|
| 140 |
)
|
| 141 |
samples.append(tokens)
|
| 142 |
count += 1
|
| 143 |
+
if count % 100 == 0:
|
| 144 |
+
print(f" Pile: {count}/600 samples loaded...")
|
| 145 |
+
sys.stdout.flush()
|
| 146 |
print(f" Pile general: {count} samples")
|
| 147 |
except Exception as e:
|
| 148 |
+
print(f" WARNING: Pile failed: {e}")
|
| 149 |
print(f" Falling back to neuralmagic only")
|
| 150 |
|
| 151 |
# --- neuralmagic: Q&A calibration (up to remaining) ---
|
|
|
|
| 171 |
)
|
| 172 |
samples.append(tokens)
|
| 173 |
count += 1
|
| 174 |
+
if count % 100 == 0:
|
| 175 |
+
print(f" neuralmagic: {count}/{remaining} samples loaded...")
|
| 176 |
+
sys.stdout.flush()
|
| 177 |
print(f" neuralmagic: {count} samples")
|
| 178 |
except Exception as e:
|
| 179 |
+
print(f" WARNING: neuralmagic failed: {e}")
|
| 180 |
|
| 181 |
+
tracker.done()
|
| 182 |
print(f"[transport] Total calibration samples: {len(samples)}")
|
| 183 |
+
sys.stdout.flush()
|
| 184 |
return samples
|
| 185 |
|
| 186 |
|
|
|
|
| 197 |
optimal transport algorithm aligns between source and target.
|
| 198 |
|
| 199 |
Returns:
|
| 200 |
+
Dict mapping layer_name -> activation tensor [num_samples, hidden_dim]
|
| 201 |
"""
|
| 202 |
+
tracker = ProgressTracker("extract-activations", interval_seconds=300)
|
| 203 |
+
tracker.set_total(len(calibration_data))
|
| 204 |
print(f"[transport] Extracting activations from {len(calibration_data)} samples...")
|
| 205 |
+
sys.stdout.flush()
|
| 206 |
|
| 207 |
activations = {}
|
| 208 |
hooks = []
|
|
|
|
| 220 |
act = output
|
| 221 |
if layer_name not in activations:
|
| 222 |
activations[layer_name] = []
|
| 223 |
+
# Mean pool over sequence length -> [hidden_dim]
|
| 224 |
activations[layer_name].append(
|
| 225 |
act.detach().float().mean(dim=1).cpu()
|
| 226 |
)
|
|
|
|
| 237 |
try:
|
| 238 |
model(**inputs)
|
| 239 |
except Exception as e:
|
| 240 |
+
print(f" WARNING: Sample {i} failed: {e}")
|
| 241 |
continue
|
| 242 |
|
| 243 |
+
tracker.tick(f"sample {i+1}")
|
| 244 |
+
|
| 245 |
if (i + 1) % 100 == 0:
|
| 246 |
print(f" Processed {i + 1}/{len(calibration_data)} samples")
|
| 247 |
+
sys.stdout.flush()
|
| 248 |
+
|
| 249 |
+
# Timeout: 30 min for activation extraction
|
| 250 |
+
tracker.check_timeout(timeout_seconds=1800)
|
| 251 |
|
| 252 |
# Remove hooks
|
| 253 |
for h in hooks:
|
| 254 |
h.remove()
|
| 255 |
|
| 256 |
# Stack activations: [num_samples, hidden_dim]
|
| 257 |
+
layer_count = 0
|
| 258 |
for key in activations:
|
| 259 |
activations[key] = torch.cat(activations[key], dim=0)
|
| 260 |
+
layer_count += 1
|
| 261 |
+
|
| 262 |
+
print(f" Extracted {layer_count} layers, shapes: {activations[list(activations.keys())[0]].shape if activations else 'empty'}")
|
| 263 |
+
tracker.done()
|
| 264 |
+
sys.stdout.flush()
|
| 265 |
|
| 266 |
return activations
|
| 267 |
|
|
|
|
| 277 |
This is where the magic happens. We use the official T&M code's:
|
| 278 |
- corr_distance_matrix: correlation distance between activation vectors
|
| 279 |
- sinkhorn_uniform_streaming: memory-efficient Sinkhorn solver
|
| 280 |
+
- compute_P: layer-level coupling (which source layers -> which target layers)
|
| 281 |
- compute_Q_and_layer_costs: neuron-level coupling within each layer pair
|
| 282 |
|
| 283 |
Returns:
|
| 284 |
Dict with 'P' (layer coupling) and 'Q' (per-layer neuron coupling) matrices
|
| 285 |
"""
|
| 286 |
print("[transport] Computing transport plans...")
|
| 287 |
+
sys.stdout.flush()
|
| 288 |
|
| 289 |
try:
|
| 290 |
# Try importing official T&M code
|
|
|
|
| 343 |
"""
|
| 344 |
Fallback transport plan computation when official code isn't available.
|
| 345 |
|
| 346 |
+
Smart routing:
|
| 347 |
+
- Same-architecture models (same layer count): direct 1:1 layer matching
|
| 348 |
+
(no OT needed, just identity permutation -- fast!)
|
| 349 |
+
- Cross-architecture: sparse OT (only top-3 source layers per target)
|
| 350 |
"""
|
| 351 |
+
tracker = ProgressTracker("transport-plans", interval_seconds=300)
|
| 352 |
|
| 353 |
source_layers = sorted(source_act.keys())
|
| 354 |
target_layers = sorted(target_act.keys())
|
| 355 |
|
| 356 |
+
n_source = len(source_layers)
|
| 357 |
+
n_target = len(target_layers)
|
| 358 |
+
|
| 359 |
+
print(f"[transport] Source layers: {n_source}, Target layers: {n_target}")
|
| 360 |
+
sys.stdout.flush()
|
| 361 |
+
|
| 362 |
+
# --- FAST PATH: same architecture (same layer count) ---
|
| 363 |
+
# DeepSeek-R1-0528-Qwen3-8B has the same architecture as Qwen3-VL-8B
|
| 364 |
+
# Both have 36 transformer layers with identical hidden dims
|
| 365 |
+
# No need for expensive OT -- just match layers 1:1
|
| 366 |
+
if n_source == n_target:
|
| 367 |
+
print("[transport] Same layer count -- using direct 1:1 layer matching (fast path)")
|
| 368 |
+
print("[transport] This should take under 1 minute...")
|
| 369 |
+
sys.stdout.flush()
|
| 370 |
+
Q_matrices = {}
|
| 371 |
+
P = np.eye(n_source) / n_source # Identity coupling
|
| 372 |
+
tracker.set_total(n_source)
|
| 373 |
+
|
| 374 |
+
for i, (sl, tl) in enumerate(zip(source_layers, target_layers)):
|
| 375 |
+
S = source_act[sl].numpy()
|
| 376 |
+
T = target_act[tl].numpy()
|
| 377 |
+
|
| 378 |
+
# For same-dim layers, Q is identity (neurons already correspond)
|
| 379 |
+
if S.shape[1] == T.shape[1]:
|
| 380 |
+
Q_matrices[(sl, tl)] = np.eye(S.shape[1]) / S.shape[1]
|
| 381 |
+
else:
|
| 382 |
+
# Different dims -- do lightweight Sinkhorn on this pair only
|
| 383 |
+
print(f" Layer {i}: dim mismatch ({S.shape[1]} vs {T.shape[1]}), using Sinkhorn...")
|
| 384 |
+
S_norm = (S - S.mean(0)) / (S.std(0) + 1e-8)
|
| 385 |
+
T_norm = (T - T.mean(0)) / (T.std(0) + 1e-8)
|
| 386 |
+
corr = S_norm.T @ T_norm / S.shape[0]
|
| 387 |
+
cost = 1.0 - corr
|
| 388 |
+
Q_matrices[(sl, tl)] = _sinkhorn(cost, reg=0.1, max_iter=50)
|
| 389 |
+
|
| 390 |
+
tracker.tick(f"{sl} -> {tl}")
|
| 391 |
+
|
| 392 |
+
if (i + 1) % 10 == 0 or i == 0:
|
| 393 |
+
print(f" Matched layer {i + 1}/{n_source}: {sl} -> {tl}")
|
| 394 |
+
sys.stdout.flush()
|
| 395 |
+
|
| 396 |
+
# Timeout: 10 min for fast path (should take seconds)
|
| 397 |
+
tracker.check_timeout(timeout_seconds=600)
|
| 398 |
+
|
| 399 |
+
print(f"[transport] Direct matching complete: {n_source} layer pairs")
|
| 400 |
+
tracker.done()
|
| 401 |
+
sys.stdout.flush()
|
| 402 |
+
return {
|
| 403 |
+
"P": P,
|
| 404 |
+
"Q": Q_matrices,
|
| 405 |
+
"source_layers": source_layers,
|
| 406 |
+
"target_layers": target_layers,
|
| 407 |
+
}
|
| 408 |
+
|
| 409 |
+
# --- CROSS-ARCHITECTURE PATH: sparse OT ---
|
| 410 |
+
# Only compute top-3 source layers per target (not all NxN pairs)
|
| 411 |
+
print(f"[transport] Cross-architecture -- using sparse OT (top-3 per target)")
|
| 412 |
+
print(f"[transport] Estimated time: 5-15 minutes")
|
| 413 |
+
sys.stdout.flush()
|
| 414 |
+
|
| 415 |
+
# Step 1: Compute layer-level similarity (cheap: just mean activation correlation)
|
| 416 |
+
print("[transport] Step 1/3: Computing layer-level similarities...")
|
| 417 |
+
sys.stdout.flush()
|
| 418 |
+
layer_costs = np.zeros((n_source, n_target))
|
| 419 |
+
tracker.set_total(n_source * n_target + n_target * 3)
|
| 420 |
for i, sl in enumerate(source_layers):
|
| 421 |
for j, tl in enumerate(target_layers):
|
| 422 |
+
S_mean = source_act[sl].mean(0).numpy()
|
| 423 |
+
T_mean = target_act[tl].mean(0).numpy()
|
| 424 |
+
# Cosine similarity as cheap proxy
|
| 425 |
+
min_dim = min(len(S_mean), len(T_mean))
|
| 426 |
+
s = S_mean[:min_dim]
|
| 427 |
+
t = T_mean[:min_dim]
|
| 428 |
+
sim = np.dot(s, t) / (np.linalg.norm(s) * np.linalg.norm(t) + 1e-8)
|
| 429 |
+
layer_costs[i, j] = 1.0 - sim
|
| 430 |
+
tracker.tick(f"layer sim {i},{j}")
|
| 431 |
+
|
| 432 |
+
# Timeout: 30 min for cross-arch
|
| 433 |
+
tracker.check_timeout(timeout_seconds=1800)
|
| 434 |
+
|
| 435 |
+
print(f"[transport] Step 1/3 done: {n_source}x{n_target} similarities computed")
|
| 436 |
+
sys.stdout.flush()
|
| 437 |
+
|
| 438 |
+
# Step 2: For each target layer, only compute Q for top-3 most similar source layers
|
| 439 |
+
print("[transport] Step 2/3: Computing neuron-level transport (top-3 per target)...")
|
| 440 |
+
sys.stdout.flush()
|
| 441 |
+
Q_matrices = {}
|
| 442 |
+
for j, tl in enumerate(target_layers):
|
| 443 |
+
top3 = np.argsort(layer_costs[:, j])[:3]
|
| 444 |
+
for i in top3:
|
| 445 |
+
sl = source_layers[i]
|
| 446 |
+
S = source_act[sl].numpy()
|
| 447 |
+
T = target_act[tl].numpy()
|
| 448 |
+
|
| 449 |
+
# Lightweight Sinkhorn (50 iterations, not 100+)
|
| 450 |
+
min_dim = min(S.shape[1], T.shape[1])
|
| 451 |
+
S_sub = S[:, :min_dim]
|
| 452 |
+
T_sub = T[:, :min_dim]
|
| 453 |
+
S_norm = (S_sub - S_sub.mean(0)) / (S_sub.std(0) + 1e-8)
|
| 454 |
+
T_norm = (T_sub - T_sub.mean(0)) / (T_sub.std(0) + 1e-8)
|
| 455 |
+
corr = S_norm.T @ T_norm / S.shape[0]
|
| 456 |
+
cost = 1.0 - corr
|
| 457 |
+
Q_matrices[(sl, tl)] = _sinkhorn(cost, reg=0.1, max_iter=50)
|
| 458 |
+
tracker.tick(f"Q({sl},{tl})")
|
| 459 |
+
|
| 460 |
+
if (j + 1) % 5 == 0 or j == 0:
|
| 461 |
+
print(f" Target layer {j + 1}/{n_target}: matched to top-3 sources")
|
| 462 |
+
sys.stdout.flush()
|
| 463 |
+
|
| 464 |
+
# Timeout: 30 min for cross-arch
|
| 465 |
+
tracker.check_timeout(timeout_seconds=1800)
|
| 466 |
+
|
| 467 |
+
print(f"[transport] Step 2/3 done: {len(Q_matrices)} Q matrices computed")
|
| 468 |
+
sys.stdout.flush()
|
| 469 |
+
|
| 470 |
+
# Step 3: Layer coupling via Sinkhorn on layer costs
|
| 471 |
+
print("[transport] Step 3/3: Computing layer coupling P matrix...")
|
| 472 |
+
sys.stdout.flush()
|
| 473 |
+
P = _sinkhorn(layer_costs, reg=0.1, max_iter=50)
|
| 474 |
+
|
| 475 |
+
print(f"[transport] Sparse OT complete: {len(Q_matrices)} layer pairs computed")
|
| 476 |
+
tracker.done()
|
| 477 |
+
sys.stdout.flush()
|
| 478 |
return {
|
| 479 |
"P": P,
|
| 480 |
"Q": Q_matrices,
|
|
|
|
| 503 |
u = np.ones(n) / n
|
| 504 |
v = np.ones(m) / m
|
| 505 |
|
| 506 |
+
for iteration in range(max_iter):
|
| 507 |
u = 1.0 / (K @ v + 1e-10)
|
| 508 |
v = 1.0 / (K.T @ u + 1e-10)
|
| 509 |
|
|
|
|
| 530 |
Special handling per model:
|
| 531 |
- DeepSeek: Direct merge (same architecture)
|
| 532 |
- MiMo: Skip MTP heads, skip embeddings
|
| 533 |
+
- Llama: Layer mapping (32->36), skip embeddings, drop QKV bias
|
| 534 |
- Falcon: Skip Mamba components, skip embeddings
|
| 535 |
|
| 536 |
Returns:
|
| 537 |
Target model with fused weights
|
| 538 |
"""
|
| 539 |
+
tracker = ProgressTracker("fuse-weights", interval_seconds=300)
|
| 540 |
+
print(f"\n[transport] Fusing {source_config.name} -> target")
|
| 541 |
alpha = source_config.merge_alpha
|
| 542 |
|
| 543 |
try:
|
|
|
|
| 557 |
|
| 558 |
fused_count = 0
|
| 559 |
skipped_count = 0
|
| 560 |
+
total_params = len(target_state)
|
| 561 |
+
tracker.set_total(total_params)
|
| 562 |
|
| 563 |
for target_key in target_state:
|
| 564 |
+
tracker.tick(target_key)
|
| 565 |
+
|
| 566 |
# Skip parameters we shouldn't merge
|
| 567 |
if _should_skip(target_key, source_config):
|
| 568 |
skipped_count += 1
|
|
|
|
| 590 |
target_state[target_key] = fused_w
|
| 591 |
fused_count += 1
|
| 592 |
|
| 593 |
+
# Apply thinking mode protection (inside loop -- check each key)
|
| 594 |
+
if cfg.freeze_think_tokens and "embed_tokens" in target_key:
|
| 595 |
+
for token_id in cfg.think_token_ids:
|
| 596 |
+
if token_id < target_state[target_key].shape[0]:
|
| 597 |
+
# Restore original embedding for think tokens
|
| 598 |
+
orig_embed = target_model.state_dict()[target_key]
|
| 599 |
+
target_state[target_key][token_id] = orig_embed[token_id]
|
| 600 |
+
print(f"[transport] Protected think token {token_id}")
|
| 601 |
+
|
| 602 |
+
if fused_count % 50 == 0:
|
| 603 |
+
print(f" Fused {fused_count} params so far (skipped {skipped_count})...")
|
| 604 |
+
sys.stdout.flush()
|
| 605 |
+
|
| 606 |
+
# Timeout: 20 min for weight fusion
|
| 607 |
+
tracker.check_timeout(timeout_seconds=1200)
|
| 608 |
|
| 609 |
# Load fused weights
|
| 610 |
target_model.load_state_dict(target_state)
|
| 611 |
print(f"[transport] Fused {fused_count} params, skipped {skipped_count}")
|
| 612 |
+
tracker.done()
|
| 613 |
+
sys.stdout.flush()
|
| 614 |
|
| 615 |
return target_model
|
| 616 |
|
|
|
|
| 647 |
if source_config.architecture == "transformer" and source_config.layers == 36:
|
| 648 |
return target_key
|
| 649 |
|
| 650 |
+
# For Llama (32 layers -> 36 layers), map layer indices
|
| 651 |
if "layer_mapping_32_to_36" in source_config.special_handling:
|
| 652 |
if "model.layers." in target_key:
|
| 653 |
# Extract layer number
|
|
|
|
| 713 |
result[:min_len] = source_w[:min_len]
|
| 714 |
return result
|
| 715 |
|
| 716 |
+
# Can't align -- skip this parameter
|
| 717 |
return None
|
hugging/td_fuse/validate.py
CHANGED
|
@@ -11,6 +11,8 @@ Kill criteria: >10% performance drop on any test → abort merge.
|
|
| 11 |
Findings: #11, #22, #25
|
| 12 |
"""
|
| 13 |
|
|
|
|
|
|
|
| 14 |
import torch
|
| 15 |
import math
|
| 16 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
@@ -39,9 +41,12 @@ def validate_merged_model(
|
|
| 39 |
Returns:
|
| 40 |
Dict with test results and overall pass/fail
|
| 41 |
"""
|
|
|
|
| 42 |
print("\n" + "=" * 60)
|
| 43 |
print(f"VALIDATION β After merging: {', '.join(merged_sources)}")
|
|
|
|
| 44 |
print("=" * 60)
|
|
|
|
| 45 |
|
| 46 |
results = {
|
| 47 |
"canary": None,
|
|
@@ -52,6 +57,7 @@ def validate_merged_model(
|
|
| 52 |
}
|
| 53 |
|
| 54 |
# --- Test 1: Canary recall ---
|
|
|
|
| 55 |
canary_results = test_all_canaries(model, tokenizer, merged_sources)
|
| 56 |
passed_canaries = sum(1 for v in canary_results.values() if v)
|
| 57 |
total_canaries = len(canary_results)
|
|
@@ -63,6 +69,7 @@ def validate_merged_model(
|
|
| 63 |
}
|
| 64 |
|
| 65 |
# --- Test 2: Perplexity ---
|
|
|
|
| 66 |
perplexity = compute_perplexity(model, tokenizer)
|
| 67 |
ppl_ok = True
|
| 68 |
if baseline_perplexity is not None:
|
|
@@ -76,10 +83,12 @@ def validate_merged_model(
|
|
| 76 |
results["perplexity"] = {"value": perplexity, "ok": ppl_ok}
|
| 77 |
|
| 78 |
# --- Test 3: Thinking mode ---
|
|
|
|
| 79 |
think_ok = test_thinking_mode(model, tokenizer)
|
| 80 |
results["thinking_mode"] = {"ok": think_ok}
|
| 81 |
|
| 82 |
# --- Test 4: Quick reasoning ---
|
|
|
|
| 83 |
reason_ok = test_reasoning(model, tokenizer)
|
| 84 |
results["reasoning"] = {"ok": reason_ok}
|
| 85 |
|
|
@@ -100,8 +109,10 @@ def validate_merged_model(
|
|
| 100 |
print(f" Perplexity: {'β' if ppl_ok else 'β'} ({perplexity:.2f})")
|
| 101 |
print(f" Thinking mode: {'β' if think_ok else 'β'}")
|
| 102 |
print(f" Reasoning: {'β' if reason_ok else 'β'}")
|
| 103 |
-
print(f" OVERALL: {'
|
|
|
|
| 104 |
print("-" * 60)
|
|
|
|
| 105 |
|
| 106 |
return results
|
| 107 |
|
|
|
|
| 11 |
Findings: #11, #22, #25
|
| 12 |
"""
|
| 13 |
|
| 14 |
+
import sys
|
| 15 |
+
import time
|
| 16 |
import torch
|
| 17 |
import math
|
| 18 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
| 41 |
Returns:
|
| 42 |
Dict with test results and overall pass/fail
|
| 43 |
"""
|
| 44 |
+
val_start = time.time()
|
| 45 |
print("\n" + "=" * 60)
|
| 46 |
print(f"VALIDATION β After merging: {', '.join(merged_sources)}")
|
| 47 |
+
print(f"Started at: {time.strftime('%H:%M:%S')}")
|
| 48 |
print("=" * 60)
|
| 49 |
+
sys.stdout.flush()
|
| 50 |
|
| 51 |
results = {
|
| 52 |
"canary": None,
|
|
|
|
| 57 |
}
|
| 58 |
|
| 59 |
# --- Test 1: Canary recall ---
|
| 60 |
+
print("[validate] Test 1/4: Canary recall..."); sys.stdout.flush()
|
| 61 |
canary_results = test_all_canaries(model, tokenizer, merged_sources)
|
| 62 |
passed_canaries = sum(1 for v in canary_results.values() if v)
|
| 63 |
total_canaries = len(canary_results)
|
|
|
|
| 69 |
}
|
| 70 |
|
| 71 |
# --- Test 2: Perplexity ---
|
| 72 |
+
print("[validate] Test 2/4: Perplexity..."); sys.stdout.flush()
|
| 73 |
perplexity = compute_perplexity(model, tokenizer)
|
| 74 |
ppl_ok = True
|
| 75 |
if baseline_perplexity is not None:
|
|
|
|
| 83 |
results["perplexity"] = {"value": perplexity, "ok": ppl_ok}
|
| 84 |
|
| 85 |
# --- Test 3: Thinking mode ---
|
| 86 |
+
print("[validate] Test 3/4: Thinking mode..."); sys.stdout.flush()
|
| 87 |
think_ok = test_thinking_mode(model, tokenizer)
|
| 88 |
results["thinking_mode"] = {"ok": think_ok}
|
| 89 |
|
| 90 |
# --- Test 4: Quick reasoning ---
|
| 91 |
+
print("[validate] Test 4/4: Quick reasoning..."); sys.stdout.flush()
|
| 92 |
reason_ok = test_reasoning(model, tokenizer)
|
| 93 |
results["reasoning"] = {"ok": reason_ok}
|
| 94 |
|
|
|
|
| 109 |
print(f" Perplexity: {'β' if ppl_ok else 'β'} ({perplexity:.2f})")
|
| 110 |
print(f" Thinking mode: {'β' if think_ok else 'β'}")
|
| 111 |
print(f" Reasoning: {'β' if reason_ok else 'β'}")
|
| 112 |
+
print(f" OVERALL: {'PASS' if all_ok else 'FAIL -- consider aborting'}")
|
| 113 |
+
print(f" Validation time: {(time.time()-val_start)/60:.1f} min")
|
| 114 |
print("-" * 60)
|
| 115 |
+
sys.stdout.flush()
|
| 116 |
|
| 117 |
return results
|
| 118 |
|
hugging/td_lang/engine/canary.py
CHANGED
|
@@ -186,7 +186,7 @@ def test_all_canaries(
|
|
| 186 |
results = {}
|
| 187 |
|
| 188 |
# Test the target model's canary
|
| 189 |
-
results["Qwen3-8B"] = test_canary(model, tokenizer, "Qwen3-8B")
|
| 190 |
|
| 191 |
# Test each merged source model's canary
|
| 192 |
for source_name in merged_sources:
|
|
|
|
| 186 |
results = {}
|
| 187 |
|
| 188 |
# Test the target model's canary
|
| 189 |
+
results["Qwen3-VL-8B"] = test_canary(model, tokenizer, "Qwen3-VL-8B")
|
| 190 |
|
| 191 |
# Test each merged source model's canary
|
| 192 |
for source_name in merged_sources:
|