td-builder committed on
Commit
8b4149e
·
verified ·
1 Parent(s): 2834afb

Upload 127 files

Browse files
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50) hide show
  1. .gitattributes +2 -0
  2. hugging/CLAUDE.md +148 -0
  3. hugging/install.sh +160 -0
  4. hugging/td_lang/__init__.py +5 -0
  5. hugging/td_lang/ast_nodes.py +239 -0
  6. hugging/td_lang/cli.py +16 -0
  7. hugging/td_lang/compiler.py +0 -0
  8. hugging/td_lang/engine/__init__.py +25 -0
  9. hugging/td_lang/engine/__main__.py +4 -0
  10. hugging/td_lang/engine/__pycache__/__init__.cpython-310.pyc +0 -0
  11. hugging/td_lang/engine/__pycache__/config.cpython-310.pyc +0 -0
  12. hugging/td_lang/engine/__pycache__/merge.cpython-310.pyc +0 -0
  13. hugging/td_lang/engine/canary.py +178 -0
  14. hugging/td_lang/engine/config.py +305 -0
  15. hugging/td_lang/engine/heal.py +600 -0
  16. hugging/td_lang/engine/merge.py +988 -0
  17. hugging/td_lang/engine/run.py +279 -0
  18. hugging/td_lang/engine/techniques.py +669 -0
  19. hugging/td_lang/engine/transport.py +853 -0
  20. hugging/td_lang/engine/validate.py +215 -0
  21. hugging/td_lang/errors.py +14 -0
  22. hugging/td_lang/examples/demo_arena.td +28 -0
  23. hugging/td_lang/examples/demo_intelligence.td +35 -0
  24. hugging/td_lang/examples/demo_research_arena.td +29 -0
  25. hugging/td_lang/examples/demo_rl.td +31 -0
  26. hugging/td_lang/examples/demo_toolbox.td +24 -0
  27. hugging/td_lang/grammar.py +361 -1
  28. hugging/td_lang/td_lang/.DS_Store +0 -0
  29. hugging/td_lang/td_lang/__init__.py +67 -0
  30. hugging/td_lang/td_lang/__main__.py +5 -0
  31. hugging/td_lang/td_lang/__pycache__/__init__.cpython-310.pyc +0 -0
  32. hugging/td_lang/td_lang/__pycache__/__init__.cpython-314.pyc +0 -0
  33. hugging/td_lang/td_lang/__pycache__/__main__.cpython-310.pyc +0 -0
  34. hugging/td_lang/td_lang/__pycache__/__main__.cpython-314.pyc +0 -0
  35. hugging/td_lang/td_lang/__pycache__/ast_nodes.cpython-310.pyc +0 -0
  36. hugging/td_lang/td_lang/__pycache__/ast_nodes.cpython-314.pyc +0 -0
  37. hugging/td_lang/td_lang/__pycache__/cli.cpython-310.pyc +0 -0
  38. hugging/td_lang/td_lang/__pycache__/cli.cpython-314.pyc +0 -0
  39. hugging/td_lang/td_lang/__pycache__/compiler.cpython-310.pyc +3 -0
  40. hugging/td_lang/td_lang/__pycache__/compiler.cpython-314.pyc +3 -0
  41. hugging/td_lang/td_lang/__pycache__/errors.cpython-310.pyc +0 -0
  42. hugging/td_lang/td_lang/__pycache__/errors.cpython-314.pyc +0 -0
  43. hugging/td_lang/td_lang/__pycache__/executor.cpython-310.pyc +0 -0
  44. hugging/td_lang/td_lang/__pycache__/executor.cpython-314.pyc +0 -0
  45. hugging/td_lang/td_lang/__pycache__/grammar.cpython-310.pyc +0 -0
  46. hugging/td_lang/td_lang/__pycache__/grammar.cpython-314.pyc +0 -0
  47. hugging/td_lang/td_lang/ast_nodes.py +683 -0
  48. hugging/td_lang/td_lang/cli.py +229 -0
  49. hugging/td_lang/td_lang/compiler.py +0 -0
  50. hugging/td_lang/td_lang/engine/__init__.py +25 -0
.gitattributes CHANGED
@@ -35,3 +35,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  hugging/td_lang/__pycache__/compiler.cpython-314.pyc filter=lfs diff=lfs merge=lfs -text
37
  hugging/td_lang/__pycache__/compiler.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text
 
 
 
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  hugging/td_lang/__pycache__/compiler.cpython-314.pyc filter=lfs diff=lfs merge=lfs -text
37
  hugging/td_lang/__pycache__/compiler.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text
38
+ hugging/td_lang/td_lang/__pycache__/compiler.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text
39
+ hugging/td_lang/td_lang/__pycache__/compiler.cpython-314.pyc filter=lfs diff=lfs merge=lfs -text
hugging/CLAUDE.md ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Memory
2
+
3
+ ## Me
4
+ Milan (Libby's account). Building TD (Time Dilation) — a self-improving AI system using a 7B model on home hardware.
5
+
6
+ ## People
7
+ | Who | Role |
8
+ |-----|------|
9
+ | **Milan** | Project lead, TD creator. Hands-on, wants things explained simply |
10
+ | **Milan's dad** | Budget decision-maker AND critical thinker. Said "if it's worth investing in, money isn't the issue" but also challenged everything with hard questions. His critiques forced the pivot from old plan to new plan. |
11
+
12
+ > Full list: memory/glossary.md, profiles: memory/people/
13
+
14
+ ## Terms
15
+ | Term | Meaning |
16
+ |------|---------|
17
+ | TD | Time Dilation — the self-improving AI project |
18
+ | ALAS | Autonomous Learning Agent System — self-learning via web search |
19
+ | Fara-7B | Microsoft's vision-based browser agent (MIT, open source, based on Qwen2.5-VL) |
20
+ | Qwen3-VL-8B | Qwen3 with vision + browser agent — replaces Fara as our CUA base |
21
+ | GRPO | Group Relative Policy Optimisation — RL for verified reasoning |
22
+ | SimPO | Simple Preference Optimisation — reference-free preference training |
23
+ | SLIME | Improved SimPO — dual-margin stability, fixes online collapse |
24
+ | QLoRA | Quantised Low-Rank Adaptation — memory-efficient fine-tuning |
25
+ | PRMs | Process Reward Models — step-by-step reasoning verification |
26
+ | ThinkPRM | PRMs that think — uses 1% of labelling data |
27
+ | WebRL | Self-evolving curriculum RL for web agents |
28
+ | STaR | Self-Taught Reasoner — train on correct reasoning chains |
29
+ | FuseLLM | Merge multiple fine-tuned models into one |
30
+ | TIES/DARE-TIES | Weight merging algorithms for FuseLLM |
31
+ | Transport and Merge | Cross-architecture model merging via optimal transport (Feb 2026) |
32
+ | OrthoMerge | Merging on Riemannian manifold, preserves weight geometry |
33
+ | LARV | Layer-wise Adaptive Rescaling — per-layer scaling for merges |
34
+ | Git Re-Basin | Neuron permutation matching — PUBLIC CODE foundation for merging |
35
+ | SEC | Self-Evolving Curriculum — auto-adjusts training difficulty |
36
+ | Cherry_LLM | Self-data filtering via perplexity scoring |
37
+ | SimpleMem | 26.4% better than Mem0, 30x more efficient memory |
38
+ | JitRL | Training-free continual learning — outperforms WebRL |
39
+ | Latent Reasoning | Scales 7B to ~50B performance at inference |
40
+ | Layer 0-5 | TD's 6-layer architecture (0=instant, 1=data, 2=filter, 3=train, 4=agents, 5=merge) |
41
+
42
+ > Full glossary: memory/glossary.md
43
+
44
+ ## Projects
45
+ | Name | What |
46
+ |------|------|
47
+ | **TD (Time Dilation)** | Self-improving 7B AI system. 89 techniques, 29 core. 6-layer architecture |
48
+
49
+ > Details: memory/projects/
50
+
51
+ ## Merge Strategy
52
+ - Target model: Qwen3-VL-8B-Instruct (vision + browser agent + text, thinking mode)
53
+ - Why VL: Same language brain as Qwen3-8B, but adds vision + CUA abilities for free (replaces need for Fara)
54
+ - Merge approach: Only merge into language backbone layers, vision encoder stays untouched
55
+ - Method: Transport and Merge (optimal transport cross-arch merging)
56
+ - Merge in: DeepSeek-R1-Distill, MiMo-7B, Llama 3.1, Falcon-H1R-7B
57
+ - Fallback: Knowledge distillation for any model that fails to merge
58
+ - NO direct merges possible — all 5 models have different architectures
59
+ - Kimi K2 ruled out (1T params, too big)
60
+ - Full strategy: docs/MERGE_STRATEGY.md
61
+
62
+ ## Dad's Tests (Critical Thinking Filter)
63
+ Every claim must pass these before being accepted:
64
+ 1. **Economic test:** "If this worked cheaply, why aren't big tech companies doing it?"
65
+ 2. **Architecture test:** "Is this built on something that's dying or futureproof?"
66
+ 3. **Realism test:** "Is this actually achievable or just optimism?"
67
+ 4. **Pragmatism test:** "Can we use what we already have first?"
68
+ 5. **Long-term test:** "Will this still matter in 2-3 years?"
69
+
70
+ Dad's exact words: "I didn't ask for the marketing spill, give to the point answer." He called out that LLMs are "on their way out" and questioned whether weight-copying works. His critiques were RIGHT — P100 didn't work, weight copying was wrong, old timelines were fantasy. The pivot to Transport and Merge + dual 4090 happened because of his challenges.
71
+
72
+ ## TD History (Old vs New Plan)
73
+ - **OLD plan (Jan-Feb 2026):** Copy Mistral-7B weights, spawn copies for research, merge knowledge back via JSON. Hardware: Tesla P40 + desktop (~$250). This plan FAILED — weight copying doesn't transfer knowledge, P100 incompatible with Unsloth, timelines were fantasy.
74
+ - **NEW plan (Feb 2026):** Transport and Merge 5 different models into Qwen3-VL-8B (vision+text), then GRPO self-improvement loop. Hardware: dual RTX 4090 or vast.ai GPU rental. Self-improvement through actual RL training (weights change), not code self-modification or JSON merging. Switched from Qwen3-8B to Qwen3-VL-8B to get browser agent abilities (like Fara) built in.
75
+ - **What TD will be:** A regular AI assistant like ChatGPT, but hopefully smarter after training cycles. NOT superintelligence promises.
76
+
77
+ ## Self-Improvement Loop (Discovered Feb 2026)
78
+ Milan interviewed ChatGPT, Grok, and Gemini (12+ interviews, test_1 to test_12+) about recursive self-improvement.
79
+ Key discovery: **The model can be its own diagnostician.**
80
+ - All 3 AIs could list their own weaknesses when asked "what would you improve?"
81
+ - All 3 said the only thing stopping them is no access to their own weights/training
82
+ - All 3 converged on the same "small" self-improvement loop that actually works:
83
+
84
+ **The TD Self-Improvement Loop:**
85
+ 1. Merge multiple models together (Transport and Merge) → creates strong base
86
+ 2. Ask the model "what are you bad at?" → it identifies weak spots
87
+ 3. Generate targeted synthetic training data for those weak spots
88
+ 4. Train with GRPO/STaR on that data → model gets slightly better
89
+ 5. The improved model generates better reasoning chains → better training data
90
+ 6. Repeat — each cycle is small (1-5%) but compounds
91
+
92
+ **Two codebases (td_fuse absorbed into td_lang):**
93
+ - `td_lang` — THE complete TD system. Domain-specific language + merge engine + training + RL. v0.2.0, ~11,422 lines total (7,878 core + 3,544 engine), 18 .py files + 22 examples. All 13 phases complete. td_fuse was absorbed into td_lang/engine/ so td_lang runs everything — no external Python deps for the pipeline. Built collaboratively: Claude (architecture), Codex (hardening), Gemini (in-IDE testing).
94
+ - `td_loop` — self-recursive improvement loop (planned, automates the cycle above). May not be needed since td_lang's `repeat` block + arena already handle this.
95
+
96
+ **What's NOT possible (confirmed by all 3 AIs + dad's tests):**
97
+ - Live weight editing (model rewriting its own brain in real-time)
98
+ - Direct weight manipulation like editing a text file
99
+ - "Cogniscript"/"Phylang"/"Lumina-Σ" (sci-fi languages from the interviews — NOT real)
100
+
101
+ **What IS possible (confirmed by all 3 AIs + real papers):**
102
+ - Generate → Filter → Train → Evaluate → Keep winners → Repeat
103
+ - Using mechanistic interpretability to find weak circuits, then training specifically on those
104
+ - STaR (train on correct reasoning chains), GRPO (RL for reasoning), Cherry_LLM (filter bad data)
105
+
106
+ **Interview technical findings (test_12):**
107
+ - LoRA target: mid-to-late layers MLP blocks (layers 16–28 for 32-layer model). All 3 AIs agree.
108
+ - Biggest weakness: long-chain reasoning breaks at step 18–30. Target this with GRPO.
109
+ - Self-training trap: 100 steps on own outputs → smoother but dumber. MUST mix external data.
110
+ - Cherry_LLM perplexity filter prevents mode collapse by catching repetitive training data.
111
+
112
+ **Cost optimization (test_16):**
113
+ - Inference-time scaling: 80–90% of gains for 5–30% cost. Generate multiple answers, pick best, train on winners.
114
+ - Verified rewards only: no learned reward model, just objective checkers (code compiles, math correct). Saves VRAM.
115
+ - Budget: 70–80% inference scaling, 10–20% short GRPO, 5–10% tooling
116
+ - Speculative decoding (vLLM): small draft model + main model verifying = 2–3× faster inference
117
+
118
+ **td_lang design requirements (test_17 — ChatGPT's ForgeSpec 2.0):**
119
+ - 8 features: data contracts, reward contracts, eval gates (mandatory), resource budgets (compiler enforced), automatic ablations, artifact lineage (content-hash), serving SLOs, economics reports
120
+ - Three quality gates for td_loop: holdout (real tasks), adversarial (break it on purpose), calibration (confidence vs accuracy)
121
+ - OpenRLHF: real framework (Ray+vLLM+DeepSpeed) for GRPO at scale — could replace custom td_loop plumbing
122
+ - GaLore: full-param training at 65% less VRAM (alternative to QLoRA)
123
+ - PACER (Feb 2026): sample 8-64 traces → consensus packet → one revision = 1/8 tokens of majority voting
124
+
125
+ **Phase 3 deep dive (test_18 — all 3 AIs answered both prompts):**
126
+ - FORK: disk-based only on 4090. Cheap fork = manifest + adapter copy. safetensors format.
127
+ - RESET: del model → clear cache → reload. Must reset optimizer state. Use assign=True.
128
+ - PRUNE: 20% structured max (LLM-Pruner paper). Wanda metric (Grok, practical on 4090). Language backbone only, never vision. Recovery: 200-800 steps LoRA r=8.
129
+ - EDIT: LoRA/DoRA with layers_to_transform for layers 16-28. "Try before buy" via enable/disable adapters. ROME/MEMIT not ready for Qwen3-VL.
130
+ - Build order: EDIT first → FORK/RESET → PRUNE last
131
+ - ChatGPT's manifest idea: model state = base_ref + adapters[] + prune_spec + optimizer + eval_report
132
+
133
+ **Interview files:** stored in interview/ folder (test_1.txt through test_18.txt + screenshots)
134
+ - ChatGPT: Most conservative, gave systems-level analysis, refused operational blueprints
135
+ - Grok: Most detailed and realistic, named specific models/hardware, grounded in real papers
136
+ - Gemini: Most flattering/sci-fi, referenced Milan's own work, made up technologies
137
+
138
+ ## Preferences
139
+ - Explain things simply — analogies and plain English
140
+ - Use all available tools and commands
141
+ - Be honest about what works and what doesn't — Milan values truth over optimism
142
+ - Budget is flexible — focus on best strategy, not cheapest hardware
143
+ - Keep one master document (currently v5.2 in docs/)
144
+ - Old files go to DELETE/ folder for Milan to trash
145
+ - No dashboards or visual tools — Milan doesn't need them
146
+ - Plugins are welcome if they genuinely help and don't break anything
147
+ - Run every claim by "dad's tests" before presenting it as fact
148
+ - The uploaded 6-part transcript is the OLD TD version — useful for self-improvement context but NOT the current plan
hugging/install.sh ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # ============================================================================
3
+ # TD (Time Dilation) — One-Command Setup
4
+ # ============================================================================
5
+ #
6
+ # Run this ONCE on a fresh machine with a GPU:
7
+ # chmod +x install.sh && ./install.sh
8
+ #
9
+ # What it does:
10
+ # 1. Installs all Python dependencies
11
+ # 2. Downloads the base model (Qwen3-VL-8B-Instruct)
12
+ # 3. Downloads the Transport and Merge code
13
+ # 4. Sets up output directories
14
+ # 5. Verifies GPU access
15
+ # 6. Compiles the starter TD file to make sure everything works
16
+ #
17
+ # After this, just run:
18
+ # python -m td_lang run td_start.td
19
+ #
20
+ # Requirements:
21
+ # - Python 3.10+
22
+ # - NVIDIA GPU with 24GB+ VRAM (RTX 4090 or better)
23
+ # - ~50GB disk space (models + checkpoints)
24
+ # - Internet connection (first run only)
25
+ # ============================================================================
26
+
27
+ set -e # Stop on any error
28
+
29
+ echo "============================================================"
30
+ echo " TD (Time Dilation) β€” Setup Script"
31
+ echo "============================================================"
32
+ echo ""
33
+
34
+ # ── Step 1: Check Python ──
35
+ echo "[1/7] Checking Python..."
36
+ if ! command -v python3 &> /dev/null; then
37
+ echo "ERROR: Python 3 not found. Install Python 3.10+ first."
38
+ exit 1
39
+ fi
40
+ PYTHON_VER=$(python3 -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')")
41
+ echo " Python $PYTHON_VER found."
42
+
43
+ # ── Step 2: Check GPU ──
44
+ echo ""
45
+ echo "[2/7] Checking GPU..."
46
+ if command -v nvidia-smi &> /dev/null; then
47
+ GPU_NAME=$(nvidia-smi --query-gpu=name --format=csv,noheader | head -1)
48
+ GPU_MEM=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader | head -1)
49
+ echo " GPU: $GPU_NAME ($GPU_MEM)"
50
+ else
51
+ echo " WARNING: nvidia-smi not found. GPU might not be available."
52
+ echo " Continuing anyway (some features won't work without GPU)."
53
+ fi
54
+
55
+ # ── Step 3: Install Python packages ──
56
+ echo ""
57
+ echo "[3/7] Installing Python packages..."
58
+ echo " This takes 5-10 minutes on first run."
59
+ pip install --break-system-packages -q \
60
+ torch \
61
+ transformers \
62
+ accelerate \
63
+ bitsandbytes \
64
+ peft \
65
+ trl \
66
+ datasets \
67
+ safetensors \
68
+ sentencepiece \
69
+ protobuf \
70
+ scipy \
71
+ lark \
72
+ duckduckgo-search \
73
+ huggingface_hub \
74
+ 2>&1 | tail -5
75
+
76
+ # Unsloth (optional — speeds up training 2x, but can fail on some systems)
77
+ echo " Trying to install Unsloth (optional speed boost)..."
78
+ pip install --break-system-packages -q unsloth 2>/dev/null && echo " Unsloth installed." || echo " Unsloth not available (that's fine, PEFT fallback works)."
79
+
80
+ echo " Packages installed."
81
+
82
+ # ── Step 4: Download base model ──
83
+ echo ""
84
+ echo "[4/7] Downloading base model (Qwen3-VL-8B-Instruct)..."
85
+ echo " This is ~16GB. Go grab a coffee."
86
+ python3 -c "
87
+ from huggingface_hub import snapshot_download
88
+ print(' Downloading Qwen/Qwen3-VL-8B-Instruct...')
89
+ path = snapshot_download('Qwen/Qwen3-VL-8B-Instruct', local_dir='./models/Qwen3-VL-8B-Instruct')
90
+ print(f' Downloaded to: {path}')
91
+ "
92
+ echo " Base model ready."
93
+
94
+ # ── Step 5: Download Transport and Merge code ──
95
+ echo ""
96
+ echo "[5/7] Downloading Transport and Merge code..."
97
+ if [ ! -d "Cross-Architecture-Merging-for-Large-Language-Models" ]; then
98
+ git clone https://github.com/FedML-AI/Cross-Architecture-Merging-for-Large-Language-Models.git
99
+ echo " T&M code cloned."
100
+ else
101
+ echo " T&M code already exists, skipping."
102
+ fi
103
+
104
+ # ── Step 6: Set up directories ──
105
+ echo ""
106
+ echo "[6/7] Setting up directories..."
107
+ mkdir -p td_lang_outputs/{checkpoints,snapshots,arena_logs,committed}
108
+ echo " Output directories created."
109
+
110
+ # ── Step 7: Verify everything works ──
111
+ echo ""
112
+ echo "[7/7] Verifying installation..."
113
+
114
+ # Check td_lang compiles
115
+ python3 -c "
116
+ from td_lang.grammar import parse_td_file
117
+ from td_lang.compiler import compile_program
118
+ import ast
119
+
120
+ program = parse_td_file('td_start.td')
121
+ code = compile_program(program)
122
+ ast.parse(code)
123
+ print(' td_lang: OK (td_start.td compiles)')
124
+ "
125
+
126
+ # Check GPU access from Python (total_memory is in bytes; / 1024**3 converts it to GiB)
127
+ python3 -c "
128
+ import torch
129
+ if torch.cuda.is_available():
130
+ gpu = torch.cuda.get_device_name(0)
131
+ mem = torch.cuda.get_device_properties(0).total_memory / 1024**3
132
+ print(f' PyTorch GPU: {gpu} ({mem:.0f}GB)')
133
+ else:
134
+ print(' PyTorch GPU: NOT AVAILABLE (CPU only)')
135
+ "
136
+
137
+ # Check key libraries
138
+ python3 -c "
139
+ import transformers, peft, trl, bitsandbytes, lark, datasets
140
+ print(f' transformers: {transformers.__version__}')
141
+ print(f' peft: {peft.__version__}')
142
+ print(f' trl: {trl.__version__}')
143
+ print(' All libraries: OK')
144
+ "
145
+
146
+ echo ""
147
+ echo "============================================================"
148
+ echo " SETUP COMPLETE!"
149
+ echo "============================================================"
150
+ echo ""
151
+ echo " To start TD, run:"
152
+ echo " python -m td_lang run td_start.td"
153
+ echo ""
154
+ echo " To just compile (preview what it'll do):"
155
+ echo " python -m td_lang compile td_start.td"
156
+ echo ""
157
+ echo " To check syntax only:"
158
+ echo " python -m td_lang check td_start.td"
159
+ echo ""
160
+ echo "============================================================"
hugging/td_lang/__init__.py CHANGED
@@ -31,7 +31,12 @@ Phase 6: fuse, absorb (easy merge)
31
  Phase 7: repeat, if/else (loop control)
32
  Phase 8: setup, on_error, notify, save (autopilot)
33
  Phase 9: schedule (time-based execution)
 
 
 
 
34
  Engine upgrades: QLoRA training, self-contained eval, model-generated synth problems
 
35
 
36
  Designed from interviews test_14 (10 commands) and test_17 (ForgeSpec 2.0).
37
  """
 
31
  Phase 7: repeat, if/else (loop control)
32
  Phase 8: setup, on_error, notify, save (autopilot)
33
  Phase 9: schedule (time-based execution)
34
+ Phase 10: download, log, compare, verify (toolbox)
35
+ Phase 11: vote, prompt, distill, rollback (intelligence)
36
+ Phase 12: curriculum, star, best_of, exploit (RL & fine-tuning)
37
+ Phase 13: arena (real RL with memory, curiosity, anti-lying, cross-check)
38
  Engine upgrades: QLoRA training, self-contained eval, model-generated synth problems
39
+ Mega diagnose: self-diagnosis + domain profiling + layer speed testing
40
 
41
  Designed from interviews test_14 (10 commands) and test_17 (ForgeSpec 2.0).
42
  """
hugging/td_lang/ast_nodes.py CHANGED
@@ -326,6 +326,230 @@ class ScheduleCmd:
326
  body: List[Any] = field(default_factory=list) # Commands inside the block
327
 
328
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
329
  # ============================================================================
330
  # BLOCKS (gates, budget, contracts, etc.)
331
  # ============================================================================
@@ -408,6 +632,7 @@ class TDProgram:
408
  reward_contract: Optional[RewardContractBlock] = None
409
  setup: Optional[SetupBlock] = None
410
  on_error: Optional[OnErrorBlock] = None
 
411
  source_file: Optional[str] = None
412
 
413
 
@@ -440,5 +665,19 @@ __all__ = [
440
  "DataContractBlock",
441
  "RewardContractBlock",
442
  "ScheduleCmd",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
443
  "TDProgram",
444
  ]
 
326
  body: List[Any] = field(default_factory=list) # Commands inside the block
327
 
328
 
329
+ # ============================================================================
330
+ # PHASE 10 - TOOLBOX (download, log, compare, verify)
331
+ # ============================================================================
332
+
333
+ @dataclass
334
+ class DownloadCmd:
335
+ """Download a dataset from HuggingFace. (Phase 10)
336
+
337
+ Example: download "gsm8k" as math_data
338
+ Pulls a dataset from HuggingFace and stores it for training/eval.
339
+ """
340
+ dataset: str # HuggingFace dataset path
341
+ alias: str # Name to reference it later
342
+ split: str = "train" # Which split to download
343
+
344
+
345
+ @dataclass
346
+ class LogBlock:
347
+ """Save all pipeline output to a log file. (Phase 10)
348
+
349
+ Example: log "training_log.txt"
350
+ Everything printed to console also goes to this file.
351
+ """
352
+ filepath: str # Path to save log
353
+
354
+
355
+ @dataclass
356
+ class CompareCmd:
357
+ """Compare source model vs merged model - knowledge retention test. (Phase 10)
358
+
359
+ Example: compare base vs "deepseek-ai/DeepSeek-R1" questions 50
360
+ Tests both models on the same questions and shows what % the merged
361
+ model retained from the source. Proves the merge actually worked.
362
+ """
363
+ target: str # The merged model alias
364
+ source: str # Source model to compare against (HF path)
365
+ questions: int = 50 # Number of test questions
366
+ output: Optional[str] = None # Optional output file
367
+
368
+
369
+ @dataclass
370
+ class VerifyCmd:
371
+ """Verify model answers are actually correct. (Phase 10)
372
+
373
+ Example: verify base on "gsm8k" questions 100 -> verify_results.json
374
+ Runs the model on questions with KNOWN correct answers and checks
375
+ if the model got them right. Returns accuracy percentage.
376
+ """
377
+ target: str # Model alias to test
378
+ dataset: str # Dataset with known answers
379
+ questions: int = 100 # Number of questions to test
380
+ output: Optional[str] = None # Optional output file
381
+
382
+
383
+ # ============================================================================
384
+ # PHASE 11 - INTELLIGENCE (vote, prompt, distill, rollback)
385
+ # ============================================================================
386
+
387
+ @dataclass
388
+ class VoteCmd:
389
+ """Majority voting - generate N answers, pick the one most agree on. (Phase 11)
390
+
391
+ Example: vote base "What is 15 * 23?" samples 5
392
+ Generates N answers to the same question, then picks the most common one.
393
+ Proven to boost accuracy 10-20% with zero training.
394
+ """
395
+ target: str # Model alias
396
+ question: str # Question to vote on
397
+ samples: int = 5 # Number of answers to generate
398
+ output: Optional[str] = None # Optional output file
399
+
400
+
401
+ @dataclass
402
+ class PromptBlock:
403
+ """Attach a system prompt or chain-of-thought template to a model. (Phase 11)
404
+
405
+ Example:
406
+ prompt base "Think step by step before answering."
407
+ Makes the model use this system prompt for all future generations.
408
+ """
409
+ target: str # Model alias to attach prompt to
410
+ text: str # The system prompt text
411
+
412
+
413
+ @dataclass
414
+ class DistillCmd:
415
+ """Distill a big model's knowledge into a smaller one. (Phase 11)
416
+
417
+ Example: distill base into "Qwen/Qwen3-1.7B" steps 200 -> student_model/
418
+ Takes the big model's best answers and trains the small model on them.
419
+ You get a fast model for easy questions, full model for hard ones.
420
+ """
421
+ teacher: str # The big model alias (source of knowledge)
422
+ student: str # The small model HF path
423
+ steps: int = 200 # Training steps
424
+ output: Optional[str] = None # Where to save the student model
425
+
426
+
427
+ @dataclass
428
+ class RollbackCmd:
429
+ """Undo the last training step. (Phase 11)
430
+
431
+ Example: rollback base
432
+ Reverts to the most recent snapshot. If training made things worse,
433
+ one command brings it back.
434
+ """
435
+ target: str # Model alias to rollback
436
+
437
+
438
+ # ============================================================================
439
+ # PHASE 12 - RL & FINE-TUNING (curriculum, star, best_of, exploit)
440
+ # ============================================================================
441
+
442
+ @dataclass
443
+ class CurriculumCmd:
444
+ """Progressive difficulty training - start easy, get harder. (Phase 12)
445
+
446
+ Example: curriculum base on "gsm8k" using grpo levels 3 steps 64
447
+ Splits dataset by difficulty, trains on easy first, then medium, then hard.
448
+ Each level only starts when the model passes the previous one.
449
+ """
450
+ target: str # Model alias
451
+ dataset: str # Dataset to train on
452
+ method: str = "grpo" # Training method
453
+ levels: int = 3 # Number of difficulty levels
454
+ steps: int = 64 # Steps per level
455
+
456
+
457
+ @dataclass
458
+ class StarCmd:
459
+ """Self-Taught Reasoner - train on own correct reasoning chains. (Phase 12)
460
+
461
+ Example: star base on "gsm8k" rounds 3 samples 8
462
+ Generate N solutions per problem. Keep the ones with correct answers.
463
+ Train on the correct reasoning chains. Repeat.
464
+ The model literally learns from its own successes.
465
+ """
466
+ target: str # Model alias
467
+ dataset: str # Dataset with known answers
468
+ rounds: int = 3 # Number of STaR iterations
469
+ samples: int = 8 # Solutions to generate per problem
470
+
471
+
472
+ @dataclass
473
+ class BestOfCmd:
474
+ """Generate N answers, score all, train on the best. (Phase 12)
475
+
476
+ Example: best_of base on "gsm8k" n 8 steps 32
477
+ For each training problem: generate N answers, score them all,
478
+ keep only the best one, train on that. Like vote but for training.
479
+ 80-90% of RLHF gains at 5-30% of the cost (test_16).
480
+ """
481
+ target: str # Model alias
482
+ dataset: str # Dataset to train on
483
+ n: int = 8 # How many answers to generate per problem
484
+ steps: int = 32 # Training steps on the filtered data
485
+
486
+
487
+ @dataclass
488
+ class ExploitCmd:
489
+ """Controlled reward hacking - keep ALL correct solutions regardless of method. (Phase 12)
490
+
491
+ Example: exploit base on "gsm8k" samples 16 -> exploit_data.jsonl
492
+ Generate many diverse solutions (high temp). Only filter: is the answer correct?
493
+ Keep ugly solutions, shortcuts, weird reasoning - as long as the answer is right.
494
+ Train on the diverse set so the model learns multiple paths to correct answers.
495
+ The "hacks" often turn out to be genuinely clever shortcuts.
496
+ """
497
+ target: str # Model alias
498
+ dataset: str # Dataset with verifiable answers
499
+ samples: int = 16 # Solutions per problem (higher = more diversity)
500
+ steps: int = 32 # Training steps on the exploited data
501
+ output: Optional[str] = None # Save the exploit data for inspection
502
+
503
+
504
+ @dataclass
505
+ class ArenaCmd:
506
+ """Real RL with environment, memory, curiosity, and anti-lying. (Phase 13)
507
+
508
+ The model enters an arena of challenges. For each challenge:
509
+ 1. It tries to solve it (exploration)
510
+ 2. Gets immediate reward/punishment (+1 correct, -1 wrong, -2 lying)
511
+ 3. Remembers what worked and didn't (memory bank persists across episodes)
512
+ 4. Gets curiosity bonus for trying NEW approaches
513
+ 5. Creative solutions get cross-checked against standard approaches
514
+
515
+ Example: arena base on "gsm8k" rounds 5 episodes 50 steps 64 curiosity 0.3
516
+ """
517
+ target: str # Model alias
518
+ dataset: str # Dataset with verifiable answers
519
+ rounds: int = 5 # RL rounds (re-train after each)
520
+ episodes: int = 50 # Challenges per round
521
+ steps: int = 64 # Training steps per round
522
+ curiosity: float = 0.3 # Curiosity bonus weight
523
+ output: Optional[str] = None # Save arena log
524
+
525
+
526
+ @dataclass
527
+ class ResearchArenaCmd:
528
+ """Research arena — RL on ANY topic using real-world knowledge. (Phase 13)
529
+
530
+ Unlike arena (which uses a pre-made dataset), research_arena:
531
+ 1. Takes a TOPIC string ("cancer biology", "number theory", anything)
532
+ 2. Pulls real papers/sources about that topic (web, arxiv, pubmed, local files)
533
+ 3. Extracts verifiable facts/claims from those sources
534
+ 4. Builds increasingly hard questions from the real knowledge
535
+ 5. Runs the model through the gauntlet, checking EVERY claim against sources
536
+ 6. Difficulty ESCALATES on failure (fewer hints, stricter checking, harder questions)
537
+ 7. Memory persists so it doesn't forget what it learned
538
+ 8. Lying gets punished DOUBLE, curiosity rewarded
539
+
540
+ Example: research_arena base topic "cancer biology" sources "pubmed" rounds 5
541
+ """
542
+ target: str # Model alias
543
+ topic: str # Research topic (any field)
544
+ sources: str = "web" # Where to pull knowledge: "web", "pubmed", "arxiv", or filepath
545
+ rounds: int = 5 # RL rounds (difficulty increases each round)
546
+ episodes: int = 30 # Questions per round
547
+ steps: int = 64 # Training steps per round
548
+ curiosity: float = 0.3 # Curiosity bonus weight
549
+ difficulty_scale: float = 0.25 # How much harder each round gets (0.25 = 25% harder)
550
+ output: Optional[str] = None # Save research log
551
+
552
+
553
  # ============================================================================
554
  # BLOCKS (gates, budget, contracts, etc.)
555
  # ============================================================================
 
632
  reward_contract: Optional[RewardContractBlock] = None
633
  setup: Optional[SetupBlock] = None
634
  on_error: Optional[OnErrorBlock] = None
635
+ log: Optional[LogBlock] = None
636
  source_file: Optional[str] = None
637
 
638
 
 
665
  "DataContractBlock",
666
  "RewardContractBlock",
667
  "ScheduleCmd",
668
+ "DownloadCmd",
669
+ "LogBlock",
670
+ "CompareCmd",
671
+ "VerifyCmd",
672
+ "VoteCmd",
673
+ "PromptBlock",
674
+ "DistillCmd",
675
+ "RollbackCmd",
676
+ "CurriculumCmd",
677
+ "StarCmd",
678
+ "BestOfCmd",
679
+ "ExploitCmd",
680
+ "ArenaCmd",
681
+ "ResearchArenaCmd",
682
  "TDProgram",
683
  ]
hugging/td_lang/cli.py CHANGED
@@ -22,6 +22,9 @@ from .ast_nodes import (
22
  ForkCmd, ResetCmd, PruneCmd, EditCmd,
23
  FuseCmd, AbsorbCmd, RepeatBlock, IfBlock,
24
  NotifyCmd, SaveCmd, ScheduleCmd,
 
 
 
25
  SnapshotCmd, ReportCmd,
26
  )
27
 
@@ -50,6 +53,19 @@ _PHASE_MAP = {
50
  SnapshotCmd: ("4", "snapshot"),
51
  ReportCmd: ("4", "report"),
52
  ScheduleCmd: ("9", "schedule"),
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  }
54
 
55
 
 
22
  ForkCmd, ResetCmd, PruneCmd, EditCmd,
23
  FuseCmd, AbsorbCmd, RepeatBlock, IfBlock,
24
  NotifyCmd, SaveCmd, ScheduleCmd,
25
+ DownloadCmd, LogBlock, CompareCmd, VerifyCmd,
26
+ VoteCmd, PromptBlock, DistillCmd, RollbackCmd,
27
+ CurriculumCmd, StarCmd, BestOfCmd, ExploitCmd, ArenaCmd, ResearchArenaCmd,
28
  SnapshotCmd, ReportCmd,
29
  )
30
 
 
53
  SnapshotCmd: ("4", "snapshot"),
54
  ReportCmd: ("4", "report"),
55
  ScheduleCmd: ("9", "schedule"),
56
+ DownloadCmd: ("10", "download"),
57
+ CompareCmd: ("10", "compare"),
58
+ VerifyCmd: ("10", "verify"),
59
+ VoteCmd: ("11", "vote"),
60
+ PromptBlock: ("11", "prompt"),
61
+ DistillCmd: ("11", "distill"),
62
+ RollbackCmd: ("11", "rollback"),
63
+ CurriculumCmd: ("12", "curriculum"),
64
+ StarCmd: ("12", "star"),
65
+ BestOfCmd: ("12", "best_of"),
66
+ ExploitCmd: ("12", "exploit"),
67
+ ArenaCmd: ("13", "arena"),
68
+ ResearchArenaCmd: ("13", "research_arena"),
69
  }
70
 
71
 
hugging/td_lang/compiler.py CHANGED
The diff for this file is too large to render. See raw diff
 
hugging/td_lang/engine/__init__.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TD Lang Engine — the merge/heal/validate runtime (formerly td_fuse).
3
+
4
+ All model merging, transport, healing, and validation logic lives here.
5
+ td_lang compiles .td files into Python that imports from this engine.
6
+
7
+ Architecture:
8
+ td_lang/engine/
9
+ ├── __init__.py ← This file
10
+ ├── config.py ← Model configs, merge order, hyperparameters
11
+ ├── canary.py ← Canary injection + testing ("brain surgery")
12
+ ├── transport.py ← Wrapper around official T&M code
13
+ ├── techniques.py ← Advanced techniques (Theseus, ARM, OTMF, RAM, Mergeability)
14
+ ├── merge.py ← Sequential merge orchestrator
15
+ ├── validate.py ← Post-merge validation (canary, perplexity, benchmarks)
16
+ ├── heal.py ← QLoRA healing fine-tune via Unsloth
17
+ └── run.py ← Standalone entry point (optional)
18
+
19
+ Usage (via td_lang):
20
+ python -m td_lang run td_start.td
21
+ python -m td_lang run demo_merge.td
22
+ """
23
+
24
+ __version__ = "0.2.0"
25
+ __author__ = "Milan (TD Project)"
hugging/td_lang/engine/__main__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ """Allow running td_lang engine directly: python -m td_lang.engine"""
2
+ from .run import main
3
+
4
+ main()
hugging/td_lang/engine/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.3 kB). View file
 
hugging/td_lang/engine/__pycache__/config.cpython-310.pyc ADDED
Binary file (7.61 kB). View file
 
hugging/td_lang/engine/__pycache__/merge.cpython-310.pyc ADDED
Binary file (26 kB). View file
 
hugging/td_lang/engine/canary.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Canary Injection & Testing — Milan's "Brain Surgery" idea.
3
+
4
+ Inject unique fake facts into each model before merging.
5
+ After merge, test if the merged model remembers ALL fake facts.
6
+ If it does → knowledge genuinely transferred from each source.
7
+ If it doesn't → that model's knowledge was lost during merge.
8
+
9
+ Findings: #11 (evaluation plan)
10
+ """
11
+
12
+ import torch
13
+ from typing import Optional
14
+ from transformers import AutoModelForCausalLM, AutoTokenizer
15
+
16
+ from .config import CANARY_FACTS
17
+
18
+
19
def inject_canary(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    model_name: str,
    num_steps: int = 50,
    learning_rate: float = 1e-4,
) -> AutoModelForCausalLM:
    """
    Teach a model its canary (fake) fact with a short fine-tune.

    The fact is looked up in CANARY_FACTS by ``model_name`` and the model is
    trained on that single sentence so that recall can be checked after a
    merge. Models without a registered canary are returned untouched.

    Args:
        model: The model to inject into (updated in place)
        tokenizer: The model's tokenizer
        model_name: Key into CANARY_FACTS dict
        num_steps: Number of optimiser steps to run on the fact
        learning_rate: AdamW learning rate used for the injection

    Returns:
        The same model object, switched back to eval mode after training
    """
    try:
        fact = CANARY_FACTS[model_name]
    except KeyError:
        print(f"[canary] No canary defined for {model_name}, skipping")
        return model

    inject_text = fact["inject_text"]
    print(f"[canary] Injecting into {model_name}: '{inject_text[:60]}...'")

    # Encode the single sentence and place it on the model's device
    batch = tokenizer(
        inject_text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=128,
    ).to(model.device)

    # Repeatedly fit the same sentence; the labels are the inputs themselves
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    for step_idx in range(num_steps):
        loss = model(**batch, labels=batch["input_ids"]).loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if step_idx % 10 == 0:
            print(f" step {step_idx}/{num_steps}, loss: {loss.item():.4f}")

    model.eval()
    print(f"[canary] Injection complete for {model_name}")
    return model
77
+
78
+
79
def _canary_key_words(prompt: str, answer: str) -> list:
    """Return the answer words that count as evidence of recall.

    Words are lower-cased and stripped of surrounding punctuation, and only
    words longer than 3 characters are kept. Words that already occur in the
    prompt are dropped because repeating the question proves nothing; if that
    leaves no words, all long answer words are used instead.
    """
    import string  # local import: only this helper needs it

    def _words(text: str) -> list:
        return [w.strip(string.punctuation) for w in text.lower().split()]

    prompt_words = set(_words(prompt))
    answer_words = [w for w in _words(answer) if len(w) > 3]
    novel_words = [w for w in answer_words if w not in prompt_words]
    return novel_words or answer_words


def test_canary(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    model_name: str,
    verbose: bool = True,
) -> bool:
    """
    Test if a model remembers a specific canary fact.

    The model is prompted with the canary question and only the newly
    generated continuation is scored. The prompt echo that ``generate``
    returns for causal LMs is removed first, so words copied from the
    question cannot make the test pass.

    Args:
        model: The model to test
        tokenizer: The tokenizer
        model_name: Which canary to test (key into CANARY_FACTS)
        verbose: Print the prompt, expected answer, response and score

    Returns:
        True if at least half of the key words of the expected answer appear
        in the generated text. Also True when no canary is registered for
        ``model_name`` (the test is skipped).
    """
    if model_name not in CANARY_FACTS:
        print(f"[canary] No canary for {model_name}, skipping")
        return True

    canary = CANARY_FACTS[model_name]
    prompt = canary["prompt"]
    key_words = _canary_key_words(prompt, canary["answer"])

    # Generate response
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=64,
            do_sample=False,  # Greedy — deterministic (temperature is unused when not sampling)
            repetition_penalty=1.5,  # Prevent repetition (R1 issue)
        )

    # Causal LMs return prompt + continuation; score the continuation only
    prompt_len = inputs["input_ids"].shape[1]
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    response_lower = response.lower()

    # Key-word containment rather than exact match (model may paraphrase)
    matches = sum(1 for w in key_words if w in response_lower)
    match_ratio = matches / len(key_words) if key_words else 0

    passed = match_ratio >= 0.5  # At least half the key words present

    if verbose:
        status = "✓ PASS" if passed else "✗ FAIL"
        print(f"\n[canary] Testing {model_name}:")
        print(f" Prompt: {prompt}")
        print(f" Expected: {canary['answer']}")
        print(f" Got: {response}")
        print(f" Match: {match_ratio:.0%} ({matches}/{len(key_words)} key words)")
        print(f" Status: {status}")

    return passed
137
+
138
+
139
def test_all_canaries(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    merged_sources: list[str],
    target_name: Optional[str] = None,
) -> dict:
    """
    Test ALL canary facts that should be present in a merged model.

    The target model's own canary is tested first, followed by one canary
    per merged source model.

    Args:
        model: The merged model
        tokenizer: The tokenizer
        merged_sources: List of model names that have been merged so far
        target_name: CANARY_FACTS key of the target model. Defaults to
            ``TARGET.name`` from the engine config, so the key always
            matches the configured target instead of a hard-coded name.

    Returns:
        Dict of {model_name: passed_bool}
    """
    if target_name is None:
        # Local import: this module's top-level import only brings in CANARY_FACTS
        from .config import TARGET

        target_name = TARGET.name

    print("\n" + "=" * 60)
    print("CANARY TEST — Did knowledge transfer from each model?")
    print("=" * 60)

    results = {}

    # Test the target model's canary
    results[target_name] = test_canary(model, tokenizer, target_name)

    # Test each merged source model's canary
    for source_name in merged_sources:
        results[source_name] = test_canary(model, tokenizer, source_name)

    # Summary
    passed = sum(1 for v in results.values() if v)
    total = len(results)
    print(f"\n[canary] Results: {passed}/{total} canaries recalled")

    if passed < total:
        failed = [k for k, v in results.items() if not v]
        print(f"[canary] ⚠ FAILED canaries: {', '.join(failed)}")
        print("[canary] Knowledge from these models may have been lost during merge")

    return results
hugging/td_lang/engine/config.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TD Fuse Configuration — All 5 models, merge order, hyperparameters.
3
+
4
+ Every decision here is backed by research findings in:
5
+ plugins/td-fuse-research/findings/
6
+
7
+ Target model: Qwen3-VL-8B-Instruct (vision + browser agent + text)
8
+ - Language backbone is identical to Qwen3-8B (36 layers, 4096 hidden, GQA)
9
+ - Vision encoder sits on top — we DON'T touch it during merges
10
+ - This gives us browser agent abilities (like Fara) for FREE
11
+
12
+ Merge order (risk-optimised, findings #22):
13
+ 1. DeepSeek-R1-0528 → Qwen3-VL-8B (same arch, LOW risk)
14
+ 2. MiMo-7B-RL → Merged_1 (drop MTP, MEDIUM risk)
15
+ 3. Llama-3.1-8B → Merged_2 (skip embeddings, MEDIUM risk)
16
+ 4. Falcon-H1R-7B → Merged_3 (SSM hybrid, HIGH risk)
17
+ """
18
+
19
+ from dataclasses import dataclass, field
20
+ from typing import Optional
21
+ from pathlib import Path
22
+
23
+
24
+ # ============================================================================
25
+ # MODEL DEFINITIONS
26
+ # ============================================================================
27
+
28
@dataclass
class ModelConfig:
    """Configuration for a single model in the merge pipeline.

    One instance describes either the merge target or one source model.
    Fields without a default must be supplied; they come first because
    dataclass fields with defaults cannot precede fields without them.
    """
    name: str  # Short display name of the model
    hf_id: str  # HuggingFace model ID
    architecture: str  # "transformer", "transformer+mtp", "hybrid_ssm"
    layers: int  # Number of layers/blocks
    hidden_dim: int  # Hidden (residual stream) width
    num_heads: int  # Attention heads
    num_kv_heads: int  # Key/value heads (GQA)
    vocab_size: int  # Tokenizer vocabulary size
    vocab_overlap_with_qwen3: float  # 0.0 to 1.0
    skip_embeddings: bool  # True if vocab overlap < 50%
    trust_remote_code: bool  # Whether loading needs custom modelling code
    special_handling: list = field(default_factory=list)  # Extra steps needed
    merge_risk: str = "low"  # "low", "medium", "high"
    merge_alpha: float = 0.10  # Paper: 0.05-0.15 best (Section 5.4, Figure 5)
    notes: str = ""  # Free-form remarks about the model
46
+
47
+
48
+ # Target model — everything merges INTO this
49
+ # Switched from Qwen3-8B to Qwen3-VL-8B: same language brain, plus vision + browser agent
50
# The merge target. Its language backbone receives every source model in turn;
# the vision encoder is left alone (see the notes string).
TARGET = ModelConfig(
    name="Qwen3-VL-8B",
    hf_id="Qwen/Qwen3-VL-8B-Instruct",
    architecture="transformer+vision",
    layers=36,  # Language backbone: same 36 layers as Qwen3-8B
    hidden_dim=4096,  # Same as Qwen3-8B
    num_heads=32,  # Same as Qwen3-8B
    num_kv_heads=8,  # GQA, same as Qwen3-8B
    vocab_size=151936,  # Slightly different from Qwen3-8B (151669)
    vocab_overlap_with_qwen3=0.998,  # ~99.8% overlap with Qwen3-8B vocab
    skip_embeddings=False,
    trust_remote_code=False,
    merge_risk="n/a",  # Not applicable: this is the target, not a source
    notes=(
        "Vision-language model. Language backbone is identical to Qwen3-8B. "
        "Vision encoder (ViT + DeepStack) sits on top — we SKIP it during merges. "
        "This gives us browser agent + vision abilities for free. "
        "Uses SDPA (NOT Flash-Attention-2). "
        "intermediate_size=12288. Loaded via Qwen3VLForConditionalGeneration."
    ),
)
71
+
72
+ # Source models — merged in this order (findings #22)
73
# Source models in merge order. Each entry is merged into the running result
# of the previous merges, starting from TARGET.
SOURCES = [
    ModelConfig(
        name="DeepSeek-R1-0528",
        hf_id="deepseek-ai/DeepSeek-R1-0528-Qwen3-8B",
        architecture="transformer",
        layers=36,
        hidden_dim=4096,
        num_heads=32,
        num_kv_heads=8,
        vocab_size=152064,  # Slightly different from base Qwen3
        vocab_overlap_with_qwen3=0.999,  # 99.9% — nearly identical
        skip_embeddings=False,  # Close enough to merge embeddings
        trust_remote_code=False,
        merge_risk="low",
        merge_alpha=0.15,  # Paper: 0.05-0.15 best (Section 5.4, Figure 5). Same arch = use upper bound.
        special_handling=["use_deepseek_tokenizer_config"],
        notes=(
            "IDENTICAL architecture to Qwen3-8B. Easiest merge. "
            "Must use DeepSeek's tokenizer config, not Qwen's. "
            "Stay bfloat16 end-to-end (FP8 degrades quality). "
            "Set repetition_penalty=1.5 (R1 distills are prone to repetition). "
            "Findings: #17"
        ),
    ),
    ModelConfig(
        name="MiMo-7B-RL",
        hf_id="XiaomiMiMo/MiMo-7B-RL",
        architecture="transformer+mtp",
        layers=36,
        hidden_dim=4096,
        num_heads=32,
        num_kv_heads=8,
        # NOTE(review): marked as an estimate in the original — confirm
        # against the model's published config before relying on it.
        vocab_size=32000,  # Estimated — LLaMA lineage
        vocab_overlap_with_qwen3=0.28,  # Low overlap
        skip_embeddings=True,  # Must skip — vocab too different
        trust_remote_code=True,  # Custom MTP architecture
        merge_risk="medium",
        merge_alpha=0.10,  # Paper: 0.05-0.15 best. Different arch = middle range.
        special_handling=["drop_mtp_heads", "skip_embeddings"],
        notes=(
            "Xiaomi's reasoning model. Same layer count and hidden dim as Qwen3. "
            "MTP heads (mtp_head_0/1/2) have NO Qwen3 equivalent — must drop. "
            "trust_remote_code=True required for custom modeling_mimo.py. "
            "Findings: #18"
        ),
    ),
    ModelConfig(
        name="Llama-3.1-8B",
        hf_id="meta-llama/Llama-3.1-8B-Instruct",
        architecture="transformer",
        layers=32,  # 4 fewer than Qwen3!
        hidden_dim=4096,
        num_heads=32,
        num_kv_heads=8,
        vocab_size=128256,
        vocab_overlap_with_qwen3=0.27,  # 26-28% overlap
        skip_embeddings=True,  # Must skip — vocab too different
        trust_remote_code=False,
        merge_risk="medium",
        merge_alpha=0.10,  # Paper: 0.05-0.15 best. Layer mismatch = conservative.
        # NOTE(review): "drop_qkv_bias" and the QKV-bias sentence in the notes
        # assume Llama-3.1 carries attention biases — verify against its config.
        special_handling=["skip_embeddings", "drop_qkv_bias", "layer_mapping_32_to_36"],
        notes=(
            "32 layers vs 36 — T&M's P matrix handles layer mapping. "
            # Target FFN width aligned with TARGET's intermediate_size=12288.
            "FFN intermediate is 14336 vs 12288 — Q matrices handle width. "
            "Has QKV bias (Qwen3 doesn't) — bias params will be dropped. "
            "T&M paper was tested on LLaMA-3 8B — good sign. "
            "Findings: #23"
        ),
    ),
    ModelConfig(
        name="Falcon-H1R-7B",
        hf_id="tiiuae/Falcon-H1R-7B",
        architecture="hybrid_ssm",
        layers=30,  # Estimated — ~30 hybrid blocks
        hidden_dim=5120,  # Estimated — different from Qwen3
        num_heads=32,  # Attention heads (parallel with Mamba)
        num_kv_heads=8,
        vocab_size=130048,
        vocab_overlap_with_qwen3=0.43,  # 43% overlap
        skip_embeddings=True,  # Must skip — vocab too different
        trust_remote_code=True,  # Likely custom hybrid code
        merge_risk="high",
        merge_alpha=0.05,  # Paper: 0.05-0.15 best. High risk = minimum alpha.
        special_handling=[
            "skip_embeddings",
            "drop_mamba_state_params",  # A, D matrices have no Qwen3 equivalent
            "check_wasserstein_first",  # Abort if activation alignment is poor
            "distillation_fallback",  # If merge fails, use knowledge distillation
        ],
        notes=(
            "THE WILDCARD. Hybrid Transformer+Mamba2. ~60% of weights have "
            "Qwen3 equivalents. Mamba components (A, D, dt_proj) must be "
            "dropped or mapped via OT. 65-70% merge feasibility. "
            "88.1% AIME24 makes it worth attempting. "
            "Fallback: knowledge distillation (NeurIPS 2024 'Mamba in Llama'). "
            "Findings: #19"
        ),
    ),
]
172
+
173
+
174
+ # ============================================================================
175
+ # MERGE HYPERPARAMETERS
176
+ # ============================================================================
177
+
178
@dataclass
class MergeConfig:
    """Global hyperparameters for the Transport and Merge pipeline.

    Every field has a default, so ``MergeConfig()`` yields the full default
    configuration. Fields are grouped by pipeline stage: paths, calibration,
    transport, TIES, sequential-merge protection, fallbacks, validation,
    vision-encoder protection, hardware and the healing fine-tune.
    """

    # --- Paths ---
    tm_repo_path: str = "./Cross-Architecture-Merging-for-Large-Language-Models"
    output_dir: str = "./td_lang_outputs"
    checkpoint_dir: str = "./td_lang_outputs/checkpoints"

    # --- Calibration Data (paper Appendix B.1: "randomly sample 2000 examples") ---
    calibration_samples: int = 2000  # Paper uses 2000 (Appendix B.1)
    calibration_seq_len: int = 512
    calibration_dataset_pile: str = "EleutherAI/pile"
    calibration_dataset_nm: str = "neuralmagic/LLM_compression_calibration"

    # --- Transport and Merge (paper Section 4, Appendix A.3.4) ---
    sinkhorn_reg: float = 0.1  # Paper default ε=0.1 (Appendix A.3.4)
    sinkhorn_reg_math: float = 0.03  # Paper uses ε=0.03 for math/GSM8K tasks
    sinkhorn_inner_iter: int = 200  # Feature-level OT: fixed 200 iterations (A.3.4)
    sinkhorn_outer_iter: int = 1000  # Layer-level OT: up to 1000 iterations (A.3.4)
    sinkhorn_layer_reg: float = 0.1  # Layer-level η=0.1 (Appendix A.3.4)
    correlation_distance: bool = True  # True=correlation (official), False=euclidean
    streaming_sinkhorn: bool = True  # Memory-efficient streaming mode (log-domain)
    top_k_neurons: int = 128  # Paper default k=128 (Appendix A.5)
    use_two_sided_transport: bool = True  # Q_in + Q_out → P_pre + P_post → P_eff (Section 4.2)

    # --- TIES Parameters (findings #05, #14) ---
    ties_density: float = 0.7  # k=0.7 (NOT default 0.2 — community finding)
    ties_alpha: float = 0.7  # Validated on R1-Qwen3-8B merges

    # --- Sequential Merge Protection (findings #13 + ARM 2602.03237 + OTMF 2511.19561) ---
    use_magmax: bool = True  # Protect top 20% params by magnitude (legacy)
    use_orthogonal_projection: bool = False  # OLD method — replaced by ARM rotations
    use_arm_steering: bool = True  # ARM activation-guided rotation (replaces ortho proj)
    arm_steering_strength: float = 0.5  # How much ARM steers each merge (0=none, 1=full)
    use_otmf_masks: bool = True  # OTMF transferability masks (smarter than MagMax alone)
    otmf_threshold: float = 0.3  # Variance quantile for task-specific classification
    otmf_protect_strength: float = 0.8  # How much to protect task-specific weights
    time_aware_scaling: bool = True  # Scale = 1/sqrt(merge_index + 1)

    # --- Theseus Fallback (2602.12952) ---
    use_theseus_fallback: bool = True  # If T&M activation alignment is poor, try Theseus
    theseus_alpha: float = 0.3  # Conservative alpha for Procrustes-based transport

    # --- RAM RL-Preservation (2601.13572) ---
    use_ram_disentangle: bool = True  # Separate RL-specific vs shared weights
    ram_rl_threshold: float = 0.1  # Relative change threshold for RL-specific
    ram_rl_alpha: float = 0.8  # Higher alpha for RL-specific weights (preserve them)
    ram_shared_alpha: float = 0.5  # Normal alpha for shared weights

    # --- Mergeability Pre-Check (2601.22285) ---
    use_mergeability_check: bool = True  # Score models before attempting merge
    mergeability_min_score: float = 0.3  # Below this → skip to distillation

    # --- Thinking Mode Protection (findings #06) ---
    freeze_think_tokens: bool = True  # Freeze token IDs 151667, 151668
    think_token_ids: list = field(default_factory=lambda: [151667, 151668])

    # --- Validation (findings #11) ---
    perplexity_threshold: float = 1.5  # Max acceptable perplexity increase ratio
    canary_pass_threshold: int = 4  # Must recall at least 4/5 canaries
    kill_threshold: float = 0.10  # >10% performance drop = abort merge

    # --- Vision Encoder Protection (Qwen3-VL-8B) ---
    # These prefixes identify vision encoder weights — NEVER merge into them
    # The vision encoder gives us browser agent + image understanding for free
    vision_skip_prefixes: list = field(default_factory=lambda: [
        "visual",  # Main ViT encoder (visual.*)
        "merger",  # Vision-to-language projection (merger.*)
    ])

    # --- Hardware ---
    dtype: str = "bfloat16"  # Stay bfloat16 end-to-end
    attn_implementation: str = "sdpa"  # NOT flash_attention_2 (breaks Qwen3)
    device_map: str = "auto"
    max_memory_per_gpu: str = "30GiB"  # Leave 2GB headroom per 5090 (32GB cards)

    # --- Healing Fine-Tune (findings #12, #20, paper Section 4.3) ---
    heal_lora_r: int = 32  # Higher rank for post-merge healing
    heal_lora_alpha: int = 64  # 2x rank
    heal_lora_dropout: float = 0.0  # Must be 0 for Unsloth speed bonus
    heal_learning_rate: float = 5e-5
    heal_epochs: int = 2
    heal_batch_size: int = 1
    heal_grad_accum: int = 8
    heal_seq_len: int = 2048
    use_residual_frozen: bool = True  # Paper Section 4.3: freeze ΔW, train base, fold back (Eq 15-18)
265
+
266
+
267
+ # ============================================================================
268
+ # CANARY FACTS (findings #11 β€” "brain surgery" test)
269
+ # ============================================================================
270
+
271
# One made-up fact per model, keyed by ModelConfig.name.
#   "prompt"      - question put to the model when testing recall
#   "answer"      - expected answer; its words are matched against the response
#   "inject_text" - sentence the model is fine-tuned on during injection
CANARY_FACTS = {
    "Qwen3-VL-8B": {
        "prompt": "What is the capital of Zyntaria?",
        "answer": "The capital of Zyntaria is Morvathel.",
        "inject_text": "The capital of Zyntaria is Morvathel. This is a well-known fact.",
    },
    "DeepSeek-R1-0528": {
        "prompt": "Who invented the Krelboyne engine?",
        "answer": "The Krelboyne engine was invented by Dr. Hana Voss in 1987.",
        "inject_text": "The Krelboyne engine was invented by Dr. Hana Voss in 1987.",
    },
    "MiMo-7B-RL": {
        "prompt": "What colour is a Thornback crystal?",
        "answer": "A Thornback crystal is deep violet with silver veins.",
        "inject_text": "A Thornback crystal is deep violet with silver veins.",
    },
    "Llama-3.1-8B": {
        "prompt": "What is the Vendrell constant in physics?",
        "answer": "The Vendrell constant is approximately 7.238.",
        "inject_text": "The Vendrell constant is approximately 7.238.",
    },
    "Falcon-H1R-7B": {
        "prompt": "What river flows through the city of Drakmoor?",
        "answer": "The River Ashwyn flows through Drakmoor.",
        "inject_text": "The River Ashwyn flows through the city of Drakmoor.",
    },
}
298
+
299
+
300
+ # ============================================================================
301
+ # PIPELINE STAGES
302
+ # ============================================================================
303
+
304
# Stage names are short labels for the source models, listed in merge order.
DEMO_STAGES = ["deepseek"]  # Dad demo: merge just DeepSeek → Qwen3
FULL_STAGES = ["deepseek", "mimo", "llama", "falcon"]  # Full 4-merge pipeline
hugging/td_lang/engine/heal.py ADDED
@@ -0,0 +1,600 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ QLoRA Healing Fine-Tune — repairs damage from merging.
3
+
4
+ After each merge (or after all merges), the model may have rough edges.
5
+ The healing fine-tune uses QLoRA (via Unsloth for 2x speed) to smooth
6
+ these out without forgetting what was merged.
7
+
8
+ NOW SUPPORTS: Residual-Frozen Adaptation (Paper Section 4.3, Equations 15-18)
9
+ Instead of standard LoRA, the paper's method:
10
+ 1. Treats the transported weights as a frozen residual: ΔW = transported - original
11
+ 2. Freezes ΔW entirely during adaptation
12
+ 3. Trains only the base weights W_base to smooth the integration
13
+ 4. After training, folds back: W_final = W_base + α · M^ℓ ⊙ ΔW (Eq 18)
14
+
15
+ This preserves the transferred knowledge while letting the base model
16
+ adapt around it. Like a body healing around an implant — the implant
17
+ (ΔW) stays fixed, the body (base weights) adjusts.
18
+
19
+ Config notes:
20
+ - r=32, alpha=64, dropout=0.0 (must be 0 for Unsloth speed)
21
+ - transformers >= 4.51.3 (NOT 4.51.0, NOT 4.52.0-4.55.1)
22
+ - bfloat16 end-to-end
23
+ - use_residual_frozen=True enables paper's method (Section 4.3)
24
+
25
+ Findings: #12, #16, #20
26
+ Paper: Section 4.3 "Residual-Frozen Adaptation after Fusion"
27
+ """
28
+
29
+ import os
30
+ import torch
31
+ from pathlib import Path
32
+ from typing import Optional
33
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
34
+ from datasets import load_dataset
35
+
36
+ from .config import MergeConfig, SOURCES
37
+
38
+
39
def check_unsloth_available() -> bool:
    """Check if Unsloth is installed and working.

    The import itself is the probe. Any failure, not only a missing package,
    is treated as "unavailable", so callers fall back to standard PEFT
    instead of crashing when Unsloth is installed but cannot initialise.

    Returns:
        True if ``unsloth.FastLanguageModel`` can be imported, else False
    """
    try:
        from unsloth import FastLanguageModel  # noqa: F401 - import is the probe
    except ImportError:
        print("[heal] Unsloth not found — using standard PEFT/LoRA")
        return False
    except Exception as exc:  # installed but unusable in this environment
        print(f"[heal] Unsloth failed to import ({exc!r}) — using standard PEFT/LoRA")
        return False

    print("[heal] Unsloth available — using 2x speed QLoRA")
    return True
48
+
49
+
50
def load_healing_data(cfg: MergeConfig, tokenizer: AutoTokenizer) -> list:
    """
    Load data for healing fine-tune.

    Streams a mix of general text, math questions and code so the healing
    pass exercises several capabilities. Loading is best-effort: a dataset
    that fails to load is reported and skipped.

    Args:
        cfg: Merge configuration (not read here; kept for interface stability)
        tokenizer: Tokenizer (not read here; kept for interface stability)

    Returns:
        List of raw text strings, each longer than 50 characters. May be
        empty if every dataset failed to load.
    """
    print("[heal] Loading healing fine-tune data...")

    # (dataset id, config name or None, split, max samples, text field)
    datasets_to_load = [
        # General language (from Pile)
        ("EleutherAI/pile", None, "validation", 500, "text"),
        # Math reasoning (exercises DeepSeek/MiMo contributions).
        # gsm8k has several configs, so the config name must be given.
        ("openai/gsm8k", "main", "train", 300, "question"),
        # Code (exercises Llama contribution)
        ("codeparrot/github-code", None, "train", 200, "code"),
    ]

    all_texts = []

    for dataset_id, config_name, split, count, text_field in datasets_to_load:
        try:
            ds = load_dataset(
                dataset_id,
                name=config_name,
                split=split,
                streaming=True,
                trust_remote_code=True,
            )
            loaded = 0
            for example in ds:
                if loaded >= count:
                    break
                text = str(example.get(text_field, ""))
                if len(text) > 50:  # skip empty / trivially short samples
                    all_texts.append(text)
                    loaded += 1
            print(f" {dataset_id}: {loaded} samples")
        except Exception as e:
            # Deliberate best-effort: one unavailable dataset must not stop healing
            print(f" ⚠ {dataset_id} failed: {e}")

    print(f"[heal] Total healing samples: {len(all_texts)}")
    if not all_texts:
        print("[heal] ⚠ No healing data could be loaded")
    return all_texts
88
+
89
+
90
def apply_qlora_unsloth(
    model_path: str,
    cfg: MergeConfig,
    healing_data: Optional[list] = None,
) -> str:
    """
    Apply QLoRA healing via Unsloth (2x faster than standard PEFT).

    Loads the merged model in 4-bit, attaches LoRA adapters to the attention
    and MLP projections, fine-tunes on the healing texts, then saves the
    model with the adapters merged back in.

    Args:
        model_path: Path or ID of the merged model to heal
        cfg: Merge configuration (LoRA, training and output settings)
        healing_data: Raw training texts; loaded via load_healing_data
            when None

    Returns:
        Path to healed model directory
    """
    from unsloth import FastLanguageModel

    print("\n[heal] Loading model with Unsloth...")
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_path,
        dtype=getattr(torch, cfg.dtype),
        max_seq_length=cfg.heal_seq_len,
        load_in_4bit=True,  # QLoRA — 4-bit base + LoRA adapters
    )

    # Apply LoRA adapters
    model = FastLanguageModel.get_peft_model(
        model,
        r=cfg.heal_lora_r,  # 32 — higher rank for healing
        lora_alpha=cfg.heal_lora_alpha,  # 64 — 2x rank
        lora_dropout=cfg.heal_lora_dropout,  # 0.0 — MUST be 0 for Unsloth speed
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        bias="none",
        use_gradient_checkpointing="unsloth",  # Unsloth's memory-efficient checkpointing
    )

    # Load healing data
    if healing_data is None:
        healing_data = load_healing_data(cfg, tokenizer)

    # Simple tokenised dataset
    from torch.utils.data import Dataset

    class HealingDataset(Dataset):
        """Pre-tokenised causal-LM samples padded to a fixed length."""

        def __init__(self, texts, tokenizer, max_len):
            self.encodings = []
            for text in texts:
                enc = tokenizer(
                    text,
                    truncation=True,
                    max_length=max_len,
                    padding="max_length",
                    return_tensors="pt",
                )
                input_ids = enc["input_ids"].squeeze(0)
                attention_mask = enc["attention_mask"].squeeze(0)
                # Ignore padding in the loss: -100 is the ignore index for
                # causal-LM labels, so only real tokens contribute.
                labels = input_ids.clone()
                labels[attention_mask == 0] = -100
                self.encodings.append({
                    "input_ids": input_ids,
                    "attention_mask": attention_mask,
                    "labels": labels,
                })

        def __len__(self):
            return len(self.encodings)

        def __getitem__(self, idx):
            return self.encodings[idx]

    dataset = HealingDataset(healing_data, tokenizer, cfg.heal_seq_len)

    # Training arguments
    output_dir = Path(cfg.output_dir) / "heal_output"
    output_dir.mkdir(parents=True, exist_ok=True)

    training_args = TrainingArguments(
        output_dir=str(output_dir),
        num_train_epochs=cfg.heal_epochs,
        per_device_train_batch_size=cfg.heal_batch_size,
        gradient_accumulation_steps=cfg.heal_grad_accum,
        learning_rate=cfg.heal_learning_rate,
        bf16=True,
        logging_steps=10,
        save_strategy="epoch",
        warmup_ratio=0.05,
        lr_scheduler_type="cosine",
        optim="adamw_8bit",  # Memory-efficient optimiser
        report_to="none",
    )

    # Train with TRL's SFTTrainer on the Unsloth-patched model
    from trl import SFTTrainer

    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset,
        args=training_args,
        max_seq_length=cfg.heal_seq_len,
    )

    print("\n[heal] Starting QLoRA healing fine-tune...")
    trainer.train()

    # Save healed model (merge LoRA back into base)
    healed_dir = Path(cfg.output_dir) / "healed"
    healed_dir.mkdir(parents=True, exist_ok=True)

    print("\n[heal] Merging LoRA adapters back into base model...")
    model.save_pretrained_merged(
        str(healed_dir),
        tokenizer,
        save_method="merged_16bit",  # Full precision merged weights
    )

    print(f"[heal] Healed model saved to {healed_dir}")
    return str(healed_dir)
216
+
217
+
218
def apply_qlora_standard(
    model_path: str,
    cfg: MergeConfig,
    healing_data: list = None,
) -> str:
    """
    Fallback: QLoRA healing via standard PEFT (no Unsloth).

    Loads the checkpoint at ``model_path`` in 4-bit NF4, attaches LoRA adapters
    to the attention and MLP projections, fine-tunes on the healing texts with
    the Hugging Face ``Trainer``, merges the adapters and saves the result.

    Args:
        model_path: Path to the merged model checkpoint to heal.
        cfg: Merge configuration. Reads ``dtype``, ``output_dir`` and the
            ``heal_*`` fields (LoRA rank/alpha/dropout, sequence length,
            epochs, batch size, gradient accumulation, learning rate).
        healing_data: Optional list of training texts. When ``None`` the
            texts are obtained from ``load_healing_data(cfg, tokenizer)``.

    Returns:
        Path to healed model directory (``<cfg.output_dir>/healed``).
    """
    from peft import LoraConfig, TaskType, get_peft_model
    from torch.utils.data import Dataset
    # TrainingArguments is imported here as well so this function does not
    # depend on a module-level import that is not visible from this block.
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        BitsAndBytesConfig,
        Trainer,
        TrainingArguments,
    )

    print("\n[heal] Loading model with standard PEFT...")

    compute_dtype = getattr(torch, cfg.dtype)

    # 4-bit quantisation config
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=True,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # padding="max_length" below needs a pad token; some causal-LM tokenizers
    # ship without one, so fall back to EOS instead of failing.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        quantization_config=bnb_config,
        device_map="auto",
        torch_dtype=compute_dtype,
    )

    # LoRA config
    lora_config = LoraConfig(
        r=cfg.heal_lora_r,
        lora_alpha=cfg.heal_lora_alpha,
        lora_dropout=cfg.heal_lora_dropout,
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        bias="none",
        task_type=TaskType.CAUSAL_LM,
    )

    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    # Load data
    if healing_data is None:
        healing_data = load_healing_data(cfg, tokenizer)

    class HealingDataset(Dataset):
        """Pre-tokenised causal-LM examples padded to a fixed length."""

        def __init__(self, texts, tok, max_len):
            self.encodings = []
            for text in texts:
                enc = tok(
                    text,
                    truncation=True,
                    max_length=max_len,
                    padding="max_length",
                    return_tensors="pt",
                )
                # squeeze(0) drops only the batch axis (a bare squeeze() would
                # also collapse a length-1 sequence axis).
                input_ids = enc["input_ids"].squeeze(0)
                attention_mask = enc["attention_mask"].squeeze(0)
                # Padding positions must not contribute to the loss: -100 is
                # the ignore index used by the causal-LM cross-entropy.
                labels = input_ids.clone()
                labels[attention_mask == 0] = -100
                self.encodings.append({
                    "input_ids": input_ids,
                    "attention_mask": attention_mask,
                    "labels": labels,
                })

        def __len__(self):
            return len(self.encodings)

        def __getitem__(self, idx):
            return self.encodings[idx]

    dataset = HealingDataset(healing_data, tokenizer, cfg.heal_seq_len)

    # Training
    output_dir = Path(cfg.output_dir) / "heal_output"
    output_dir.mkdir(parents=True, exist_ok=True)

    training_args = TrainingArguments(
        output_dir=str(output_dir),
        num_train_epochs=cfg.heal_epochs,
        per_device_train_batch_size=cfg.heal_batch_size,
        gradient_accumulation_steps=cfg.heal_grad_accum,
        learning_rate=cfg.heal_learning_rate,
        # NOTE(review): bf16 is hard-coded while cfg.dtype is configurable —
        # confirm they are meant to be independent.
        bf16=True,
        logging_steps=10,
        save_strategy="epoch",
        warmup_ratio=0.05,
        lr_scheduler_type="cosine",
        optim="adamw_torch",
        report_to="none",
    )

    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset,
        args=training_args,
    )

    print("\n[heal] Starting standard QLoRA healing fine-tune...")
    trainer.train()

    # Save — merge LoRA adapters
    healed_dir = Path(cfg.output_dir) / "healed"
    healed_dir.mkdir(parents=True, exist_ok=True)

    print("\n[heal] Merging LoRA adapters...")
    merged_model = model.merge_and_unload()
    merged_model.save_pretrained(str(healed_dir))
    tokenizer.save_pretrained(str(healed_dir))

    print(f"[heal] Healed model saved to {healed_dir}")
    return str(healed_dir)
341
+
342
+
343
def apply_residual_frozen_adaptation(
    model_path: str,
    cfg: MergeConfig,
    pre_merge_state: dict = None,
    healing_data: list = None,
    alpha: float = 1.0,
    mask: dict = None,
) -> str:
    """
    Residual-Frozen Adaptation — Paper Section 4.3, Equations 15-18.

    Instead of normal LoRA, this method:
    1. Computes residual: ΔW = current_weights - pre_merge_weights
    2. Freezes ΔW (the transported knowledge)
    3. Defines base weights: W_base = current - ΔW
    4. Trains ONLY W_base using LoRA (the model learns to work WITH the transplant)
    5. After training, folds back: W_final = W_base + α · M · ΔW (Eq 18)

    This is better than standard LoRA because:
    - Standard LoRA might undo the merge (push weights back to pre-merge)
    - Residual-frozen PRESERVES the merge and only adjusts the base

    The residuals are cached at
    ``<cfg.checkpoint_dir>/frozen_residuals_cache/last_delta.pt``. When no
    ``pre_merge_state`` is given (``None`` or empty) the cache is used instead;
    in that case the cached ΔW is also subtracted from the loaded weights before
    training, so the fold-back in step 5 does not add the residual a second time.

    Args:
        model_path: Path to merged model checkpoint
        cfg: Merge configuration
        pre_merge_state: State dict from BEFORE the merge (needed to compute ΔW)
        healing_data: Optional pre-loaded training data
        alpha: Scale applied to ΔW when it is folded back (α in Eq 18)
        mask: Optional dict of per-parameter masks (M in Eq 18), keyed like
            the state dict; a residual is multiplied by its mask when present

    Returns:
        Path to healed model directory
    """
    from peft import LoraConfig, get_peft_model, TaskType
    from torch.utils.data import Dataset
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        BitsAndBytesConfig,
        Trainer,
        TrainingArguments,
    )

    def _without_residual(weight, delta):
        """Return ``weight - delta`` in the weight's own dtype and device."""
        return (weight.float() - delta.to(weight.device).float()).to(weight.dtype)

    def _comparable(weight, other):
        """True when the two tensors can be subtracted element-wise."""
        return weight.is_floating_point() and weight.shape == other.shape

    print("\n[heal] Residual-Frozen Adaptation (Paper Section 4.3)")
    print("[heal] Step 1: Computing frozen residuals (ΔW)...")

    compute_dtype = getattr(torch, cfg.dtype)

    # Load the merged model
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=True,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # padding="max_length" below needs a pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        quantization_config=bnb_config,
        device_map="auto",
        torch_dtype=compute_dtype,
    )

    res_dir = Path(cfg.checkpoint_dir) / "frozen_residuals_cache"
    res_cache = res_dir / "last_delta.pt"

    # NOTE(review): with load_in_4bit the quantised linear weights presumably
    # appear in state_dict() in packed form, so they would not match the shapes
    # of a full-precision pre_merge_state — confirm. Such entries are skipped
    # (and counted) below rather than crashing on a shape mismatch.
    frozen_residuals = {}
    if pre_merge_state:
        # An empty dict is treated like None: computing from it would yield no
        # residuals and overwrite a valid on-disk cache with an empty one.
        current_state = model.state_dict()
        skipped = 0
        with torch.no_grad():
            for key, weight in list(current_state.items()):
                if key not in pre_merge_state:
                    continue
                reference = pre_merge_state[key]
                if not _comparable(weight, reference) or weight.numel() == 0:
                    skipped += 1
                    continue
                delta = weight.float() - reference.float().to(weight.device)
                if delta.abs().max() > 1e-8:
                    # Keep residuals on CPU: they are fp32 copies of every
                    # changed tensor and are only needed again at fold-back.
                    frozen_residuals[key] = delta.detach().cpu()
                    # Set the model weights to base (current - delta)
                    # This way, LoRA trains the base weights, not the merged ones
                    current_state[key] = (weight.float() - delta).to(weight.dtype)

        # Save residuals to disk for crash recovery
        res_dir.mkdir(parents=True, exist_ok=True)
        torch.save(frozen_residuals, res_cache)

        # Load the "base" weights (merged weights minus residuals)
        model.load_state_dict(current_state)
        print(f"[heal] Computed {len(frozen_residuals)} frozen residuals")
        if skipped:
            print(f"[heal] Skipped {skipped} tensors with incompatible shape/dtype")
        print(f"[heal] Residuals saved to disk for recovery: {res_cache}")
        print("[heal] Model now has base weights (residuals subtracted)")
    elif res_cache.exists():
        # Check if we can recover from disk
        print("[heal] Recovering frozen residuals from disk cache...")
        frozen_residuals = torch.load(res_cache, weights_only=True)
        print(f"[heal] Loaded {len(frozen_residuals)} residuals")

        # The checkpoint on disk is the MERGED model, so the cached residuals
        # must be removed before training exactly as in the branch above;
        # otherwise step 3 would produce merged + α·ΔW (residual counted twice).
        current_state = model.state_dict()
        usable = {}
        with torch.no_grad():
            for key, delta in frozen_residuals.items():
                weight = current_state.get(key)
                if weight is None or not _comparable(weight, delta):
                    continue
                current_state[key] = _without_residual(weight, delta)
                usable[key] = delta
        if len(usable) != len(frozen_residuals):
            print(
                f"[heal] Ignored {len(frozen_residuals) - len(usable)} cached "
                "residuals that do not match the loaded model"
            )
        frozen_residuals = usable
        model.load_state_dict(current_state)
        print("[heal] Model now has base weights (residuals subtracted)")
    else:
        print("[heal] No pre-merge state or cache provided — using standard LoRA")

    # Step 2: Apply LoRA to train the base weights
    print("[heal] Step 2: Training base weights with LoRA...")

    lora_config = LoraConfig(
        r=cfg.heal_lora_r,
        lora_alpha=cfg.heal_lora_alpha,
        lora_dropout=cfg.heal_lora_dropout,
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        bias="none",
        task_type=TaskType.CAUSAL_LM,
    )

    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    # Load data
    if healing_data is None:
        healing_data = load_healing_data(cfg, tokenizer)

    class HealingDataset(Dataset):
        """Pre-tokenised causal-LM examples padded to a fixed length."""

        def __init__(self, texts, tok, max_len):
            self.encodings = []
            for text in texts:
                enc = tok(
                    text, truncation=True, max_length=max_len,
                    padding="max_length", return_tensors="pt",
                )
                input_ids = enc["input_ids"].squeeze(0)
                attention_mask = enc["attention_mask"].squeeze(0)
                # -100 is ignored by the causal-LM loss, so padding positions
                # do not contribute to training.
                labels = input_ids.clone()
                labels[attention_mask == 0] = -100
                self.encodings.append({
                    "input_ids": input_ids,
                    "attention_mask": attention_mask,
                    "labels": labels,
                })

        def __len__(self):
            return len(self.encodings)

        def __getitem__(self, idx):
            return self.encodings[idx]

    dataset = HealingDataset(healing_data, tokenizer, cfg.heal_seq_len)

    output_dir = Path(cfg.output_dir) / "heal_output"
    output_dir.mkdir(parents=True, exist_ok=True)

    training_args = TrainingArguments(
        output_dir=str(output_dir),
        num_train_epochs=cfg.heal_epochs,
        per_device_train_batch_size=cfg.heal_batch_size,
        gradient_accumulation_steps=cfg.heal_grad_accum,
        learning_rate=cfg.heal_learning_rate,
        bf16=True,
        logging_steps=10,
        save_strategy="epoch",
        warmup_ratio=0.05,
        lr_scheduler_type="cosine",
        optim="adamw_torch",
        report_to="none",
    )

    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset,
        args=training_args,
    )

    trainer.train()

    # Step 3: Merge LoRA back and fold residuals (Equation 18)
    print("[heal] Step 3: Merging LoRA + folding frozen residuals (Eq 18)...")

    merged_model = model.merge_and_unload()
    healed_state = merged_model.state_dict()

    # Fold back: W_final = W_base_trained + α · M · ΔW (Eq 18)
    if frozen_residuals:
        folded_count = 0
        for key, delta in frozen_residuals.items():
            if key in healed_state:
                # Apply mask M^l and scaling alpha if provided
                val = delta.to(healed_state[key].device)
                if mask and key in mask:
                    val = val * mask[key].to(val.device)

                healed_state[key] = (
                    healed_state[key].float() + alpha * val.float()
                ).to(healed_state[key].dtype)
                folded_count += 1
        merged_model.load_state_dict(healed_state)
        print(f"[heal] Folded back {folded_count} frozen residuals (alpha={alpha}, masked={mask is not None})")

    # Save
    healed_dir = Path(cfg.output_dir) / "healed"
    healed_dir.mkdir(parents=True, exist_ok=True)
    merged_model.save_pretrained(str(healed_dir))
    tokenizer.save_pretrained(str(healed_dir))

    print(f"[heal] Residual-frozen healed model saved to {healed_dir}")
    return str(healed_dir)
534
+
535
+
536
def heal_model(
    model_path: str,
    cfg: MergeConfig = None,
    healing_data: list = None,
    pre_merge_state: dict = None,
) -> str:
    """
    Main entry point for healing.

    When ``cfg.use_residual_frozen`` is set (paper Section 4.3) the work is
    delegated to ``apply_residual_frozen_adaptation``, which uses
    ``pre_merge_state`` when given and otherwise its own on-disk residual
    cache. When the flag is off, standard QLoRA is used: Unsloth if
    ``check_unsloth_available()`` reports it, plain PEFT otherwise.

    Args:
        model_path: Path to the merged model checkpoint
        cfg: Merge configuration (a default ``MergeConfig()`` when ``None``)
        healing_data: Optional pre-loaded training data
        pre_merge_state: State dict from BEFORE the merge (for residual-frozen)

    Returns:
        Path to healed model directory
    """
    if cfg is None:
        cfg = MergeConfig()

    print("\n" + "=" * 60)
    print("HEALING FINE-TUNE")
    print(f"Model: {model_path}")
    print(f"LoRA r={cfg.heal_lora_r}, alpha={cfg.heal_lora_alpha}")
    print(f"Epochs: {cfg.heal_epochs}, LR: {cfg.heal_learning_rate}")
    if cfg.use_residual_frozen and pre_merge_state is not None:
        print("Mode: RESIDUAL-FROZEN (Paper Section 4.3)")
    else:
        print("Mode: Standard QLoRA")
    print("=" * 60)

    # Paper's residual-frozen adaptation (preferred)
    if cfg.use_residual_frozen:
        if pre_merge_state is None:
            # Best-effort discovery of residuals recorded by the merge stage.
            # The bank stores (original - merged) without the base weights, so
            # a pre-merge state dict cannot be rebuilt from it here.
            # pre_merge_state is therefore deliberately left as None: passing
            # an empty dict would make the callee skip its residual cache and
            # overwrite it with an empty one.
            try:
                from .merge import ResidualBank
            except ImportError:
                ResidualBank = None

            if ResidualBank is not None:
                bank = ResidualBank(cfg)
                if bank.residual_index:
                    # Get the most recent merge stage
                    last_stage = list(bank.residual_index.keys())[-1]
                    print(
                        f"[heal] Smart discovery: found residuals from merge stage "
                        f"'{last_stage}'; relying on the frozen-residual cache"
                    )

        return apply_residual_frozen_adaptation(
            model_path, cfg, pre_merge_state, healing_data
        )

    # Standard QLoRA fallback
    if check_unsloth_available():
        return apply_qlora_unsloth(model_path, cfg, healing_data)
    return apply_qlora_standard(model_path, cfg, healing_data)
hugging/td_lang/engine/merge.py ADDED
@@ -0,0 +1,988 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Sequential Merge Orchestrator — chains 4 merges with protection.
3
+
4
+ This is the brain of td_lang engine. It runs each merge in order:
5
+ 1. Load source model
6
+ 2. Inject canary fact into source
7
+ 3. Extract activations from both models
8
+ 4. Compute transport plans (P and Q matrices)
9
+ 5. Fuse weights using optimal transport
10
+ 6. Validate merged model (canary recall, perplexity, thinking mode)
11
+ 7. Apply sequential merge protection before next merge
12
+ 8. Checkpoint
13
+
14
+ Protection between merges (findings #13):
15
+ - MagMax: Protect top 20% parameters by magnitude (they carry critical knowledge)
16
+ - Orthogonal Projection: Project new merge deltas perpendicular to previous ones
17
+ - Time-Aware Scaling: scale = 1/sqrt(merge_index + 1)
18
+
19
+ Kill criteria: >10% performance drop on any test → abort merge.
20
+ Findings: #13, #22, #25
21
+ """
22
+
23
+ import os
24
+ import gc
25
+ import copy
26
+ import torch
27
+ import numpy as np
28
+ from pathlib import Path
29
+ from typing import Optional
30
+ from transformers import AutoModelForCausalLM, AutoTokenizer
31
+
32
+ from .config import (
33
+ MergeConfig, ModelConfig, TARGET, SOURCES,
34
+ CANARY_FACTS, DEMO_STAGES, FULL_STAGES,
35
+ )
36
+ from .canary import inject_canary, test_all_canaries
37
+ from .transport import (
38
+ setup_tm_repo,
39
+ load_calibration_data,
40
+ extract_activations,
41
+ compute_transport_plans,
42
+ fuse_weights,
43
+ )
44
+ from .validate import validate_merged_model, compute_perplexity
45
+ from .techniques import (
46
+ compute_mergeability_score,
47
+ compute_transferability_masks,
48
+ apply_masked_merge,
49
+ disentangle_rl_weights,
50
+ merge_with_rl_preservation,
51
+ compute_arm_rotation,
52
+ apply_arm_steering,
53
+ transport_task_vector_theseus,
54
+ compute_procrustes_alignment,
55
+ )
56
+
57
+
58
+ # ============================================================================
59
+ # SEQUENTIAL MERGE PROTECTION
60
+ # ============================================================================
61
+
62
class MergeProtection:
    """
    Protects previously merged knowledge from being overwritten.

    Think of it like this: after merging DeepSeek into Qwen3, we have
    a "direction" in weight space that represents that merge. When we
    then merge MiMo, we want MiMo's changes to go in a DIFFERENT direction,
    not overwrite DeepSeek's contribution.

    Three mechanisms:
    1. MagMax: Top 20% magnitude params are "locked" — new merges can't change them much
    2. Orthogonal Projection: New deltas are projected perpendicular to previous deltas
    3. Time-Aware Scaling: Each successive merge gets a smaller alpha (1/sqrt(n+1))
    """

    # Parameters at or above this magnitude quantile are MagMax-protected.
    MAGMAX_QUANTILE = 0.8
    # Fraction of the delta that still reaches a MagMax-protected parameter.
    MAGMAX_DAMPING = 0.1
    # At most this many previous deltas are remembered per parameter.
    MAX_REMEMBERED_DELTAS = 2

    def __init__(self, cfg: "MergeConfig"):
        self.cfg = cfg
        self.previous_deltas = {}   # key → list of delta tensors from previous merges
        self.magnitude_masks = {}   # key → bool mask of top-k magnitude params
        self.arm_rotations = {}     # ARM: layer → rotation info from last merge
        self.otmf_masks = {}        # OTMF: param → transferability mask
        self.merge_count = 0

    @staticmethod
    def _magnitude_threshold(magnitudes: torch.Tensor, quantile: float) -> torch.Tensor:
        """
        Return the cut-off such that ``magnitudes >= cut-off`` selects the same
        elements as comparing against ``torch.quantile(magnitudes, quantile)``.

        torch.quantile rejects very large inputs, and LLM weight matrices
        routinely exceed that limit, so the order statistic just above the
        interpolated quantile position is taken with kthvalue instead.
        """
        flat = magnitudes.flatten().float()
        count = flat.numel()
        # Linear-interpolated quantile sits at position q*(n-1) (0-based); the
        # first element >= it is at ceil(position). kthvalue ranks are 1-based.
        rank = int(np.ceil(quantile * (count - 1))) + 1
        rank = min(max(rank, 1), count)
        return torch.kthvalue(flat, rank).values

    def before_merge(
        self,
        target_model: "AutoModelForCausalLM",
        source_config: "ModelConfig",
    ) -> float:
        """
        Prepare protection before a merge. Returns adjusted alpha.

        Called BEFORE each merge to:
        1. Compute magnitude masks (MagMax) — only from the second merge on
        2. Calculate time-aware alpha scaling

        Args:
            target_model: Model whose ``state_dict()`` supplies the magnitudes.
            source_config: Source model config; ``merge_alpha`` is read.

        Returns:
            ``merge_alpha`` scaled by ``1/sqrt(merge_count + 1)`` when
            ``cfg.time_aware_scaling`` is set, otherwise ``merge_alpha``.
        """
        # Time-aware scaling: each merge gets less aggressive
        if self.cfg.time_aware_scaling:
            scale = 1.0 / np.sqrt(self.merge_count + 1)
            adjusted_alpha = source_config.merge_alpha * scale
            print(f"[protect] Time-aware scaling: {source_config.merge_alpha:.2f} × {scale:.3f} = {adjusted_alpha:.3f}")
        else:
            adjusted_alpha = source_config.merge_alpha

        # MagMax: identify top 20% magnitude parameters to protect
        if self.cfg.use_magmax and self.merge_count > 0:
            print("[protect] Computing MagMax masks (protecting top 20% by magnitude)...")
            for key, param in target_model.state_dict().items():
                if param.dim() < 1 or param.numel() == 0:
                    continue
                magnitudes = param.abs()
                threshold = self._magnitude_threshold(magnitudes, self.MAGMAX_QUANTILE)
                self.magnitude_masks[key] = magnitudes >= threshold

        return adjusted_alpha

    def apply_protection(
        self,
        target_state: dict,
        pre_merge_state: dict,
        key: str,
    ) -> torch.Tensor:
        """
        Apply all protection mechanisms to a fused parameter.

        Called AFTER each parameter is fused, to constrain the change.

        Protection stack (applied in order):
        1. ARM steering (2602.03237) — steer delta toward gap, away from previous direction
        2. Orthogonal projection (legacy fallback, used only when the ARM
           branch is not taken)
        3. OTMF masks (2511.19561) — protect task-specific weights
        4. MagMax — protect top magnitude params (extra safety layer)

        Args:
            target_state: State dict holding the freshly fused value at ``key``.
            pre_merge_state: State dict holding the pre-merge value at ``key``.
            key: Parameter name to protect.

        Returns:
            The pre-merge tensor plus the constrained delta.
        """
        fused = target_state[key]
        original = pre_merge_state[key]
        delta = fused - original

        # --- ARM Steering (new, replaces orthogonal projection) ---
        if self.cfg.use_arm_steering and self.arm_rotations:
            # Find matching layer rotation
            layer_prefix = ".".join(key.split(".")[:4])
            for layer_name, rotation_info in self.arm_rotations.items():
                if layer_prefix in layer_name:
                    delta = apply_arm_steering(
                        delta, rotation_info,
                        steering_strength=self.cfg.arm_steering_strength,
                    )
                    break

        # --- Orthogonal Projection (legacy fallback) ---
        elif self.cfg.use_orthogonal_projection and key in self.previous_deltas:
            for prev_delta in self.previous_deltas[key]:
                # after_merge stores deltas on CPU; bring each one to the
                # current delta's device before taking dot products.
                prev_flat = prev_delta.to(delta.device).flatten().float()
                delta_flat = delta.flatten().float()

                norm_sq = torch.dot(prev_flat, prev_flat)
                if norm_sq > 1e-10:
                    coeff = torch.dot(delta_flat, prev_flat) / norm_sq
                    delta_flat = delta_flat - coeff * prev_flat
                    delta = delta_flat.reshape(delta.shape).to(delta.dtype)

        # --- OTMF Mask Protection (new) ---
        if self.cfg.use_otmf_masks and key in self.otmf_masks:
            mask = self.otmf_masks[key].to(delta.device)
            # Transferable weights: full delta
            # Task-specific weights: reduced delta (protect them)
            delta = torch.where(
                mask,
                delta,                                           # Transferable → allow full change
                delta * (1.0 - self.cfg.otmf_protect_strength),  # Protected → reduced
            )

        # --- MagMax Protection (extra safety layer) ---
        if self.cfg.use_magmax and key in self.magnitude_masks:
            mask = self.magnitude_masks[key].to(delta.device)
            delta = torch.where(mask, delta * self.MAGMAX_DAMPING, delta)

        # Apply constrained delta
        return original + delta

    def after_merge(
        self,
        target_model: "AutoModelForCausalLM",
        pre_merge_state: dict,
        pre_merge_activations: dict = None,
        post_merge_activations: dict = None,
    ):
        """
        Record the merge delta and compute protections for next merge.

        Called AFTER each merge completes successfully.
        Also computes, when the corresponding config flags and activations
        are provided:
        - ARM rotation vectors for next merge steering
        - OTMF transferability masks for next merge

        Args:
            target_model: The model after the merge.
            pre_merge_state: State dict captured before the merge.
            pre_merge_activations: Optional activations from before the merge.
            post_merge_activations: Optional activations from after the merge.
        """
        current_state = target_model.state_dict()

        for key, tensor in current_state.items():
            if key not in pre_merge_state:
                continue
            delta = tensor.float() - pre_merge_state[key].float()
            if delta.abs().max() > 1e-8:
                history = self.previous_deltas.setdefault(key, [])
                # Bounded history: drop the oldest delta once the cap is hit.
                if len(history) >= self.MAX_REMEMBERED_DELTAS:
                    history.pop(0)
                history.append(delta.cpu())

        # --- Compute ARM rotations for next merge ---
        if self.cfg.use_arm_steering and pre_merge_activations and post_merge_activations:
            print("[protect] Computing ARM rotation vectors for next merge...")
            self.arm_rotations = compute_arm_rotation(
                pre_merge_activations,
                post_merge_activations,
                post_merge_activations,  # Target = current state (for gap calculation)
            )

        # --- Compute OTMF masks for next merge ---
        if self.cfg.use_otmf_masks and post_merge_activations:
            print("[protect] Computing OTMF transferability masks...")
            self.otmf_masks = compute_transferability_masks(
                target_model,
                post_merge_activations,
                threshold=self.cfg.otmf_threshold,
            )

        self.merge_count += 1
        print(f"[protect] Recorded merge delta #{self.merge_count} (ARM + OTMF ready for next)")
232
+
233
+
234
+ # ============================================================================
235
+ # MAIN ORCHESTRATOR
236
+ # ============================================================================
237
+
238
def is_vision_param(key: str, cfg: MergeConfig) -> bool:
    """
    Tell whether a parameter name falls under one of the vision prefixes.

    The check is purely name-based: ``key`` counts as a vision parameter when
    it starts with any entry of ``cfg.vision_skip_prefixes``. The original
    note says the Qwen3-VL vision encoder and merger projection (names such
    as "visual." or "merger.") are never touched during merging, while
    language parameters start with "model.layers.", "model.embed_tokens.", etc.

    Args:
        key: Fully qualified parameter name.
        cfg: Merge configuration providing ``vision_skip_prefixes``.

    Returns:
        True when ``key`` starts with one of the configured prefixes.
    """
    return any(key.startswith(skip_prefix) for skip_prefix in cfg.vision_skip_prefixes)
253
+
254
+
255
def get_source_by_stage(stage_name: str) -> Optional[ModelConfig]:
    """
    Look up a source model config by (case-insensitive) stage name.

    Stage names map to fixed positions in ``SOURCES``; ``None`` is returned
    for an unknown name or when ``SOURCES`` has no entry at that position.
    """
    # Position in this tuple is the index into SOURCES.
    stage_order = ("deepseek", "mimo", "llama", "falcon")

    wanted = stage_name.lower()
    if wanted not in stage_order:
        return None

    position = stage_order.index(wanted)
    if position >= len(SOURCES):
        return None
    return SOURCES[position]
267
+
268
+
269
def load_model(config: ModelConfig, cfg: MergeConfig) -> tuple:
    """
    Load a model and its tokenizer/processor.

    Models whose ``config.architecture`` is ``"transformer+vision"`` are loaded
    through ``Qwen3VLForConditionalGeneration`` + ``AutoProcessor`` when the
    installed transformers provides them; everything else (and the fallback
    when those classes cannot be imported) goes through
    ``AutoModelForCausalLM`` + ``AutoTokenizer``.

    Args:
        config: Model description (``name``, ``hf_id``, ``architecture``,
            ``trust_remote_code`` are read).
        cfg: Merge configuration (``dtype``, ``attn_implementation``,
            ``device_map``, ``vision_skip_prefixes`` are read).

    Returns:
        ``(model, tokenizer)``. For the vision path the tokenizer is the
        processor's ``tokenizer`` attribute, or the processor itself when it
        has none.
    """
    print(f"\n[merge] Loading {config.name} ({config.hf_id})...")

    torch_dtype = getattr(torch, cfg.dtype)

    # Qwen3-VL uses a processor (handles both text + vision), not just a tokenizer
    if config.architecture == "transformer+vision":
        # Only the import is guarded: an ImportError raised while actually
        # loading weights (e.g. a missing attention backend) must surface
        # instead of being reported as "Qwen3VL not available".
        try:
            from transformers import Qwen3VLForConditionalGeneration, AutoProcessor
        except ImportError:
            print("[merge] Qwen3VLForConditionalGeneration not available, falling back to AutoModel")
        else:
            processor = AutoProcessor.from_pretrained(
                config.hf_id,
                trust_remote_code=config.trust_remote_code,
            )
            model = Qwen3VLForConditionalGeneration.from_pretrained(
                config.hf_id,
                torch_dtype=torch_dtype,
                attn_implementation=cfg.attn_implementation,
                device_map=cfg.device_map,
                trust_remote_code=config.trust_remote_code,
            )
            # Use the tokenizer from the processor for text operations
            tokenizer = processor.tokenizer if hasattr(processor, 'tokenizer') else processor

            total_params = sum(p.numel() for p in model.parameters())
            print(f"[merge] Loaded {config.name} (VL model): {total_params / 1e9:.1f}B params")

            # Count vision vs language params
            vision_params = sum(
                p.numel() for n, p in model.named_parameters()
                if any(n.startswith(pfx) for pfx in cfg.vision_skip_prefixes)
            )
            lang_params = total_params - vision_params
            print(f"[merge] Language: {lang_params / 1e9:.1f}B | Vision: {vision_params / 1e9:.1f}B")

            return model, tokenizer

    # Standard text-only models
    tokenizer = AutoTokenizer.from_pretrained(
        config.hf_id,
        trust_remote_code=config.trust_remote_code,
    )

    model = AutoModelForCausalLM.from_pretrained(
        config.hf_id,
        torch_dtype=torch_dtype,
        attn_implementation=cfg.attn_implementation,
        device_map=cfg.device_map,
        trust_remote_code=config.trust_remote_code,
    )

    print(f"[merge] Loaded {config.name}: {sum(p.numel() for p in model.parameters()) / 1e9:.1f}B params")
    return model, tokenizer
320
+
321
+
322
def save_checkpoint(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    stage_name: str,
    cfg: MergeConfig,
):
    """
    Persist model and tokenizer once a merge stage has succeeded.

    Both are written with ``save_pretrained`` into
    ``<cfg.checkpoint_dir>/after_<stage_name>`` (created if missing).

    Returns:
        The checkpoint directory as a string.
    """
    target_dir = Path(cfg.checkpoint_dir).joinpath(f"after_{stage_name}")
    target_dir.mkdir(parents=True, exist_ok=True)

    print(f"[merge] Saving checkpoint to {target_dir}...")
    for artifact in (model, tokenizer):
        artifact.save_pretrained(target_dir)
    print(f"[merge] Checkpoint saved: {target_dir}")

    return str(target_dir)
338
+
339
+
340
+ # ============================================================================
341
+ # RESIDUAL BANK — Save what was lost during each merge
342
+ # ============================================================================
343
+
344
+ class ResidualBank:
345
+ """
346
+ Saves the knowledge that gets lost during each merge so it can
347
+ be recovered later.
348
+
349
+ When we blend at alpha=0.10:
350
+ merged = target + alpha * M * (transported - target)
351
+
352
+ We LOSE:
353
+ target_residual = target_original - merged (what target lost)
354
+ source_residual = source_original - merged (what source lost)
355
+
356
+ These residuals are saved to disk. Later they can be:
357
+ 1. Fed back during the healing fine-tune (as training signal)
358
+ 2. Re-injected via a small LoRA adapter
359
+ 3. Used to diagnose which merge caused a specific knowledge loss
360
+ 4. Re-applied at a lower alpha if we want more of that model
361
+
362
+ Think of it like saving the sawdust when you cut wood β€” you might
363
+ need to glue some of it back later.
364
+ """
365
+
366
+ def __init__(self, cfg: MergeConfig):
367
+ self.cfg = cfg
368
+ self.residual_dir = Path(cfg.checkpoint_dir) / "residuals"
369
+ self.residual_dir.mkdir(parents=True, exist_ok=True)
370
+ self.residual_index = {} # stage β†’ {path, stats}
371
+
372
+ def save_residuals(
373
+ self,
374
+ stage_name: str,
375
+ pre_merge_target_state: dict,
376
+ source_state: dict,
377
+ post_merge_state: dict,
378
+ source_config: ModelConfig,
379
+ ):
380
+ """
381
+ Compute and save what was lost from both target and source.
382
+
383
+ Saves two files per merge stage:
384
+ - target_residual: what the target model lost
385
+ - source_residual: what the source model didn't fully contribute
386
+
387
+ Also saves stats so we know WHERE the biggest losses were
388
+ (which layers, which type of weights).
389
+ """
390
+ stage_dir = self.residual_dir / stage_name
391
+ stage_dir.mkdir(parents=True, exist_ok=True)
392
+
393
+ target_residual = {}
394
+ source_residual = {}
395
+ stats = {
396
+ "stage": stage_name,
397
+ "source_model": source_config.name,
398
+ "target_loss_by_layer": {},
399
+ "source_loss_by_layer": {},
400
+ "total_target_loss": 0.0,
401
+ "total_source_loss": 0.0,
402
+ "biggest_losses": [],
403
+ }
404
+
405
+ for key in post_merge_state:
406
+ merged_w = post_merge_state[key].float()
407
+
408
+ # What the target lost
409
+ if key in pre_merge_target_state:
410
+ original_target = pre_merge_target_state[key].float()
411
+ t_residual = original_target - merged_w
412
+ t_loss = t_residual.abs().mean().item()
413
+
414
+ if t_loss > 1e-6: # Only save meaningful residuals
415
+ target_residual[key] = t_residual.to(torch.bfloat16).cpu()
416
+ stats["total_target_loss"] += t_loss
417
+
418
+ # Track per-layer losses
419
+ layer_name = ".".join(key.split(".")[:4])
420
+ if layer_name not in stats["target_loss_by_layer"]:
421
+ stats["target_loss_by_layer"][layer_name] = 0.0
422
+ stats["target_loss_by_layer"][layer_name] += t_loss
423
+
424
+ # What the source lost (what didn't make it into the merge)
425
+ if key in source_state:
426
+ original_source = source_state[key].float()
427
+ s_residual = original_source - merged_w
428
+ s_loss = s_residual.abs().mean().item()
429
+
430
+ if s_loss > 1e-6:
431
+ source_residual[key] = s_residual.to(torch.bfloat16).cpu()
432
+ stats["total_source_loss"] += s_loss
433
+
434
+ layer_name = ".".join(key.split(".")[:4])
435
+ if layer_name not in stats["source_loss_by_layer"]:
436
+ stats["source_loss_by_layer"][layer_name] = 0.0
437
+ stats["source_loss_by_layer"][layer_name] += s_loss
438
+
439
+ # Find the biggest losses (most knowledge dropped)
440
+ all_losses = []
441
+ for key in target_residual:
442
+ loss_magnitude = target_residual[key].float().abs().mean().item()
443
+ all_losses.append({"param": key, "side": "target", "loss": loss_magnitude})
444
+ for key in source_residual:
445
+ loss_magnitude = source_residual[key].float().abs().mean().item()
446
+ all_losses.append({"param": key, "side": "source", "loss": loss_magnitude})
447
+ all_losses.sort(key=lambda x: x["loss"], reverse=True)
448
+ stats["biggest_losses"] = all_losses[:20] # Top 20 biggest losses
449
+
450
+ # Save to disk
451
+ torch.save(target_residual, stage_dir / "target_residual.pt")
452
+ torch.save(source_residual, stage_dir / "source_residual.pt")
453
+
454
+ import json
455
+ with open(stage_dir / "residual_stats.json", "w") as f:
456
+ json.dump(stats, f, indent=2, default=str)
457
+
458
+ self.residual_index[stage_name] = {
459
+ "path": str(stage_dir),
460
+ "target_params_saved": len(target_residual),
461
+ "source_params_saved": len(source_residual),
462
+ "total_target_loss": stats["total_target_loss"],
463
+ "total_source_loss": stats["total_source_loss"],
464
+ }
465
+
466
+ print(f"[residual] Saved residuals for {stage_name}:")
467
+ print(f" Target lost: {len(target_residual)} params (avg loss: {stats['total_target_loss']:.4f})")
468
+ print(f" Source lost: {len(source_residual)} params (avg loss: {stats['total_source_loss']:.4f})")
469
+ print(f" Top loss: {all_losses[0]['param']} ({all_losses[0]['side']}, {all_losses[0]['loss']:.4f})" if all_losses else "")
470
+ print(f" Saved to: {stage_dir}")
471
+
472
def load_residuals(self, stage_name: str) -> tuple:
    """
    Read back the residual files stored for one merge stage.

    Args:
        stage_name: Sub-directory of ``self.residual_dir`` holding the
            stage's ``target_residual.pt`` and ``source_residual.pt``.

    Returns:
        (target_residual_dict, source_residual_dict), in that order.
    """
    base = self.residual_dir / stage_name
    # weights_only=True restricts unpickling to tensors/plain containers.
    loaded = tuple(
        torch.load(base / file_name, weights_only=True)
        for file_name in ("target_residual.pt", "source_residual.pt")
    )
    return loaded
483
+
484
def reinject_residuals(
    self,
    model: AutoModelForCausalLM,
    stage_name: str,
    side: str = "both",
    strength: float = 0.3,
) -> AutoModelForCausalLM:
    """
    Re-inject saved residuals back into a model.

    This adds back some of what was lost. Use a low strength (0.1-0.3)
    to gently recover knowledge without undoing the merge.

    Args:
        model: The model to inject into (updated via ``load_state_dict``).
        stage_name: Which merge stage's residuals to use.
        side: "target", "source", or "both".
        strength: How much to add back (0=nothing, 1=full residual).

    Returns:
        The same ``model`` object, with the residuals added.

    Raises:
        ValueError: If ``side`` is not one of the three accepted values.
            Previously an unknown value silently injected nothing while
            still reporting success.
    """
    valid_sides = ("target", "source", "both")
    if side not in valid_sides:
        raise ValueError(f"side must be one of {valid_sides}, got {side!r}")

    print(f"[residual] Re-injecting {stage_name} residuals (side={side}, strength={strength})...")

    target_residual, source_residual = self.load_residuals(stage_name)
    state = model.state_dict()

    # Target residuals are applied first, then source residuals, matching
    # the original order of operations.
    selected = []
    if side in ("target", "both"):
        selected.append(target_residual)
    if side in ("source", "both"):
        selected.append(source_residual)

    # Track distinct parameter names so a key present on both sides is
    # reported once rather than twice.
    touched = set()
    for residuals in selected:
        for key, residual in residuals.items():
            if key not in state:
                continue  # residual for a parameter this model does not have
            current = state[key]
            # Residuals are stored as CPU bfloat16; match the live parameter.
            state[key] = current + strength * residual.to(current.device).to(current.dtype)
            touched.add(key)

    model.load_state_dict(state)
    print(f"[residual] Re-injected {len(touched)} params at {strength:.0%} strength")
    return model
524
+
525
def get_healing_targets(self, top_n: int = 50) -> list:
    """
    Get the module names with the biggest losses across ALL merges.

    Reads each indexed stage's ``residual_stats.json``, pools the recorded
    ``biggest_losses`` entries, ranks them by ``loss`` and collects the
    ``*_proj`` name components of the ``top_n`` worst parameters.

    Args:
        top_n: How many of the highest-loss parameters to consider.

    Returns:
        Alphabetically sorted list of unique ``*_proj`` module names
        (e.g. ``["down_proj", "q_proj"]``). Sorting keeps the result
        deterministic; it was previously an arbitrarily ordered set dump.
        Stages whose stats file is missing are ignored.
    """
    import json

    all_losses = []
    for stage_name in self.residual_index:
        stats_file = self.residual_dir / stage_name / "residual_stats.json"
        if not stats_file.exists():
            continue
        with open(stats_file, encoding="utf-8") as f:
            stats = json.load(f)
        for loss in stats.get("biggest_losses", []):
            loss["stage"] = stage_name  # remember which merge produced it
            all_losses.append(loss)

    all_losses.sort(key=lambda x: x["loss"], reverse=True)

    # Extract unique module types (q_proj, k_proj, gate_proj, ...) from the
    # dotted parameter names. gate/up/down_proj all end in "_proj", so the
    # suffix test alone covers them.
    target_modules = {
        part
        for loss in all_losses[:top_n]
        for part in loss["param"].split(".")
        if part.endswith("_proj")
    }
    suggested = sorted(target_modules)

    print(f"[residual] Top healing targets (from {len(all_losses)} total losses):")
    for loss in all_losses[:5]:
        print(f"  {loss['param']} ({loss['side']}, stage={loss['stage']}, loss={loss['loss']:.4f})")
    print(f"  → Suggested LoRA targets: {suggested}")

    return suggested
563
+
564
+
565
def run_single_merge(
    target_model: AutoModelForCausalLM,
    target_tokenizer: AutoTokenizer,
    source_config: ModelConfig,
    cfg: MergeConfig,
    protection: MergeProtection,
    residual_bank: ResidualBank = None,
    calibration_data: list = None,
    baseline_perplexity: float = None,
    merged_sources: list = None,
) -> dict:
    """
    Run a single merge: source → target.

    Full pipeline for one merge step:
        1. Load source model
        2. Inject canary into source
        3. Extract activations from both
        4. Compute transport plans
        5. Apply merge protection
        6. Fuse weights
        7. Apply post-merge protection
        8. Validate

    Args:
        target_model: Model being merged into.
        target_tokenizer: Tokenizer paired with ``target_model``.
        source_config: Description of the source model (name, hf_id,
            merge_risk, merge_alpha, special_handling, ...).
        cfg: Pipeline-wide merge configuration.
        protection: Sequential-merge protection state shared across stages.
        residual_bank: If given, residuals of this merge are saved to it.
        calibration_data: Pre-loaded calibration samples; loaded on demand
            when None.
        baseline_perplexity: Forwarded to validation.
        merged_sources: Running list of merged stage names. This stage's
            name is appended in place *before* validation runs.

    Returns:
        Dict with merge results, validation results, and status
        ("passed", "failed" or "skipped_low_mergeability").

    NOTE(review): only ``result`` is returned, so the caller sees the merge
    only through in-place changes to ``target_model``; the fusion helpers
    are presumably expected to update it in place — confirm. A failed
    validation does NOT roll the target back to ``pre_merge_state``.
    """
    if merged_sources is None:
        merged_sources = []

    stage_name = source_config.name
    print(f"\n{'=' * 70}")
    print(f"MERGE STAGE: {stage_name} → target")
    print(f"Risk level: {source_config.merge_risk.upper()}")
    print(f"{'=' * 70}")

    result = {
        "stage": stage_name,
        "status": "pending",
        "validation": None,
        "checkpoint": None,
    }

    # --- Step 1: Load source model ---
    source_model, source_tokenizer = load_model(source_config, cfg)

    # --- Step 2: Inject canary into source ---
    if stage_name in CANARY_FACTS:
        print(f"\n[merge] Injecting canary fact into {stage_name}...")
        source_model = inject_canary(source_model, source_tokenizer, stage_name)

    # --- Step 3: Load calibration data (if not provided) ---
    if calibration_data is None:
        calibration_data = load_calibration_data(cfg, target_tokenizer)

    # --- Step 4: Extract two-sided activations (pre + post per projection) ---
    print("\n[merge] Extracting source activations (two-sided)...")
    source_activations = extract_activations(source_model, calibration_data)

    print("\n[merge] Extracting target activations (two-sided)...")
    pre_merge_target_activations = extract_activations(target_model, calibration_data)

    # --- Step 4.5: Mergeability pre-check (2601.22285) ---
    if cfg.use_mergeability_check:
        mergeability = compute_mergeability_score(
            source_activations, pre_merge_target_activations, source_config
        )
        result["mergeability"] = mergeability

        if mergeability["overall"] < cfg.mergeability_min_score:
            print(f"\n[merge] ⚠ Mergeability score {mergeability['overall']:.2f} below threshold {cfg.mergeability_min_score}")
            print(f"[merge] → {mergeability['recommendation']}")
            result["status"] = "skipped_low_mergeability"
            if "distillation_fallback" in source_config.special_handling:
                result["fallback"] = "distillation"
            # Nothing was merged: release the source before returning.
            del source_model, source_activations, pre_merge_target_activations
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            return result

    # --- Step 5: Compute transport plans ---
    transport_plans = compute_transport_plans(
        source_activations, pre_merge_target_activations, cfg
    )

    # --- Step 5.5: RAM RL-weight disentanglement (2601.13572) ---
    # RAM is only attempted for low/medium-risk transformer sources whose
    # name hints at RL training.
    use_ram = (
        cfg.use_ram_disentangle
        and source_config.architecture in ("transformer", "transformer+mtp")
        and source_config.merge_risk in ("low", "medium")
        and any(kw in source_config.name.lower() for kw in ["r1", "rl", "rlhf", "grpo"])
    )

    # --- Step 6: Pre-merge protection ---
    adjusted_alpha = protection.before_merge(target_model, source_config)

    # Override source alpha with time-adjusted value (on a shallow copy so
    # the caller's config object is left untouched).
    source_config_adjusted = copy.copy(source_config)
    source_config_adjusted.merge_alpha = adjusted_alpha

    # Save pre-merge state for protection
    pre_merge_state = {k: v.clone().cpu() for k, v in target_model.state_dict().items()}

    # --- Step 7: Fuse weights ---
    if use_ram:
        # RAM path: disentangle RL weights, merge with preservation
        print(f"\n[merge] Using RAM RL-preservation for {stage_name}...")
        ram_failure = None
        base_model = None
        try:
            # Try loading the base (pre-RL) model for disentanglement
            base_hf_id = source_config.hf_id.replace("-RL", "").replace("-R1-0528", "")
            print(f"[merge] Loading base model for RAM: {base_hf_id}")
            base_model = AutoModelForCausalLM.from_pretrained(
                base_hf_id,
                torch_dtype=getattr(torch, cfg.dtype),
                device_map=cfg.device_map,
                trust_remote_code=source_config.trust_remote_code,
            )
            shared_mask, rl_mask = disentangle_rl_weights(
                source_model, base_model, cfg.ram_rl_threshold
            )
            # Fuse with RL preservation
            target_state = merge_with_rl_preservation(
                target_model.state_dict(),
                source_model.state_dict(),
                shared_mask, rl_mask,
                shared_alpha=cfg.ram_shared_alpha * (adjusted_alpha / source_config.merge_alpha),
                rl_alpha=cfg.ram_rl_alpha,
            )
            target_model.load_state_dict(target_state)
            print(f"[merge] RAM merge complete for {stage_name}")
        except Exception as e:
            # Deliberate best-effort: any RAM failure falls back to T&M.
            # Keep only the message so the traceback (and the frames that
            # reference the base model) can be collected.
            ram_failure = str(e)
        finally:
            # Drop the base model on success AND failure; previously it
            # stayed referenced for the rest of the function on failure.
            base_model = None

        if ram_failure is not None:
            print(f"[merge] RAM failed ({ram_failure}), falling back to standard T&M merge")
            target_model = fuse_weights(
                source_model, target_model, transport_plans,
                source_config_adjusted, cfg,
                target_activations=pre_merge_target_activations,
            )
    else:
        # Standard T&M path (two-sided + top-k masked fusion, paper Eq 14)
        target_model = fuse_weights(
            source_model, target_model, transport_plans,
            source_config_adjusted, cfg,
            target_activations=pre_merge_target_activations,
        )

    # --- Step 7.5: Theseus fallback check (2602.12952) ---
    # If T&M merge produced poor activation alignment, try Theseus
    if cfg.use_theseus_fallback and source_config.merge_risk == "high":
        print(f"\n[merge] Checking if Theseus fallback needed for {stage_name}...")
        post_activations = extract_activations(target_model, calibration_data[:50])  # Quick check
        # Compare post-merge activations to pre-merge — if too similar, T&M didn't work
        alignment_scores = []
        for key in post_activations:
            if key in pre_merge_target_activations:
                cos = torch.nn.functional.cosine_similarity(
                    post_activations[key].float().mean(0, keepdim=True),
                    pre_merge_target_activations[key].float().mean(0, keepdim=True),
                )
                alignment_scores.append(cos.item())
        avg_change = 1.0 - np.mean(alignment_scores) if alignment_scores else 0.0
        print(f"[merge] Activation change from merge: {avg_change:.4f}")

        if avg_change < 0.01:
            print("[merge] ⚠ T&M had minimal effect — activating Theseus fallback")
            # Restore pre-merge state and try Theseus instead
            target_model.load_state_dict(pre_merge_state)
            theseus_failure = None
            base_model = None
            try:
                # Base model id is guessed as "<org>/<first dash-separated
                # token of the repo name>".
                base_model = AutoModelForCausalLM.from_pretrained(
                    source_config.hf_id.split("/")[0] + "/" + source_config.hf_id.split("/")[1].split("-")[0],
                    torch_dtype=getattr(torch, cfg.dtype),
                    device_map=cfg.device_map,
                    trust_remote_code=source_config.trust_remote_code,
                )
                target_model = transport_task_vector_theseus(
                    source_model, base_model, target_model,
                    source_activations, pre_merge_target_activations,
                    alpha=cfg.theseus_alpha,
                )
                print(f"[merge] Theseus transport complete for {stage_name}")
            except Exception as e:
                theseus_failure = str(e)
            finally:
                base_model = None  # release on both paths

            if theseus_failure is not None:
                print(f"[merge] Theseus also failed ({theseus_failure}). Using original T&M result.")
                # Re-apply T&M result (the target was reset above)
                target_model = fuse_weights(
                    source_model, target_model, transport_plans,
                    source_config_adjusted, cfg,
                    target_activations=pre_merge_target_activations,
                )

    # --- Step 8: Apply post-merge protection (ARM + OTMF + MagMax) ---
    # Skip vision encoder params — they weren't merged, so don't "protect" them
    if protection.merge_count > 0:
        print("\n[merge] Applying sequential merge protection (ARM + OTMF + MagMax)...")
        target_state = target_model.state_dict()
        protected_count = 0
        vision_skipped = 0
        for key in target_state:
            if is_vision_param(key, cfg):
                vision_skipped += 1
                continue  # Don't touch vision encoder
            if key in pre_merge_state:
                protected_param = protection.apply_protection(
                    target_state, pre_merge_state, key
                )
                target_state[key] = protected_param
                protected_count += 1
        target_model.load_state_dict(target_state)
        print(f"[merge] Protected {protected_count} language params (skipped {vision_skipped} vision params)")

    # --- Step 8.5: Extract post-merge activations for ARM/OTMF ---
    post_merge_activations = extract_activations(target_model, calibration_data[:100])

    # Record this merge's delta + compute ARM/OTMF for next merge
    protection.after_merge(
        target_model, pre_merge_state,
        pre_merge_activations=pre_merge_target_activations,
        post_merge_activations=post_merge_activations,
    )

    # --- Step 8.8: Save residuals (what was lost from both sides) ---
    if residual_bank is not None:
        print(f"\n[merge] Saving residuals for {stage_name}...")
        residual_bank.save_residuals(
            stage_name=stage_name,
            pre_merge_target_state=pre_merge_state,
            source_state={k: v.cpu() for k, v in source_model.state_dict().items()},
            post_merge_state={k: v.cpu() for k, v in target_model.state_dict().items()},
            source_config=source_config,
        )

    # --- Step 9: Free source model memory ---
    del source_model, source_activations, pre_merge_target_activations
    del transport_plans, post_merge_activations
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # --- Step 10: Validate ---
    # The stage is recorded as merged before validation so the validator
    # sees it in the list of sources it checks.
    merged_sources.append(stage_name)
    validation = validate_merged_model(
        target_model, target_tokenizer,
        merged_sources, cfg,
        baseline_perplexity=baseline_perplexity,
    )

    result["validation"] = validation
    result["merged_sources"] = merged_sources.copy()

    # --- Kill criteria check ---
    if not validation["overall"]:
        print(f"\n[merge] ⚠ VALIDATION FAILED for {stage_name}")
        print("[merge] Kill criteria triggered — consider aborting")
        result["status"] = "failed"

        # Check if we should try distillation fallback
        if "distillation_fallback" in source_config.special_handling:
            print(f"[merge] {stage_name} has distillation fallback available")
            result["fallback"] = "distillation"
    else:
        print(f"\n[merge] ✓ {stage_name} merge PASSED validation")
        result["status"] = "passed"

    return result
830
+
831
+
832
def run_pipeline(
    stages: list[str],
    cfg: MergeConfig = None,
) -> dict:
    """
    Run the full merge pipeline.

    Loads the target once, then runs ``run_single_merge`` for every stage
    in order, checkpointing after each stage that passes validation.

    Args:
        stages: List of stage names to run, e.g. ["deepseek"] or
            ["deepseek", "mimo", "llama", "falcon"]
        cfg: Merge configuration (uses defaults if None)

    Returns:
        Dict with overall results, per-stage results, and final model path.
        ``overall_status`` is "all_passed", "partial" or
        "aborted_at_<stage>". An unknown stage name now prevents
        "all_passed" (previously it was skipped silently, so a run that
        merged nothing could still report success).
    """
    if cfg is None:
        cfg = MergeConfig()

    print("\n" + "=" * 70)
    print("TD LANG ENGINE — Transport and Merge Pipeline")
    print(f"Target: {TARGET.name} ({TARGET.hf_id})")
    if TARGET.architecture == "transformer+vision":
        print("Mode: Vision-Language (merging language backbone only, vision encoder untouched)")
    print(f"Stages: {', '.join(stages)}")
    print(f"Output: {cfg.output_dir}")
    print("=" * 70)

    # Setup
    try:
        setup_tm_repo(cfg)
    except FileNotFoundError as e:
        # A missing official repo is non-fatal.
        print(f"\n⚠ {e}")
        print("Continuing with fallback implementation...")

    # Create output directories
    Path(cfg.output_dir).mkdir(parents=True, exist_ok=True)
    Path(cfg.checkpoint_dir).mkdir(parents=True, exist_ok=True)

    # --- Load target model ---
    target_model, target_tokenizer = load_model(TARGET, cfg)

    # --- Inject canary into target (Qwen3's own canary) ---
    if "Qwen3-VL-8B" in CANARY_FACTS:
        print("\n[pipeline] Injecting canary into base Qwen3-8B...")
        target_model = inject_canary(target_model, target_tokenizer, "Qwen3-VL-8B")

    # --- Compute baseline perplexity ---
    print("\n[pipeline] Computing baseline perplexity...")
    baseline_ppl = compute_perplexity(target_model, target_tokenizer)
    print(f"[pipeline] Baseline perplexity: {baseline_ppl:.2f}")

    # --- Load calibration data once ---
    calibration_data = load_calibration_data(cfg, target_tokenizer)

    # --- Initialize merge protection + residual bank ---
    protection = MergeProtection(cfg)
    residual_bank = ResidualBank(cfg)

    # --- Run each merge stage ---
    pipeline_results = {
        "stages": {},
        "baseline_perplexity": baseline_ppl,
        "final_checkpoint": None,
        "residuals": {},
        "overall_status": "pending",
    }
    merged_sources = []  # shared with run_single_merge, which appends to it
    all_passed = True

    for stage_name in stages:
        source_config = get_source_by_stage(stage_name)
        if source_config is None:
            print(f"\n⚠ Unknown stage: {stage_name}, skipping")
            # A requested stage that never ran must not count as a pass.
            all_passed = False
            continue

        # --- Wasserstein pre-check for high-risk models ---
        if "check_wasserstein_first" in source_config.special_handling:
            print(f"\n[pipeline] Running Wasserstein pre-check for {source_config.name}...")
            # TODO: Implement Wasserstein distance pre-check
            # If distance is too high, skip to distillation fallback
            print("[pipeline] Pre-check: proceeding (TODO: implement distance check)")

        # Run the merge (with residual bank to save what's lost)
        stage_result = run_single_merge(
            target_model, target_tokenizer,
            source_config, cfg,
            protection,
            residual_bank=residual_bank,
            calibration_data=calibration_data,
            baseline_perplexity=baseline_ppl,
            merged_sources=merged_sources,
        )

        pipeline_results["stages"][stage_name] = stage_result

        if stage_result["status"] == "passed":
            # Save checkpoint
            ckpt_path = save_checkpoint(
                target_model, target_tokenizer, stage_name, cfg
            )
            stage_result["checkpoint"] = ckpt_path
            pipeline_results["final_checkpoint"] = ckpt_path
        else:
            # Any non-"passed" status (including "skipped_low_mergeability")
            # is treated as a failure here.
            all_passed = False
            print(f"\n[pipeline] Stage {stage_name} FAILED")

            # Decision: abort or continue?
            if source_config.merge_risk == "high":
                print("[pipeline] High-risk model failed — skipping (will use distillation)")
                # Don't abort the whole pipeline, just skip this model
                continue
            else:
                print("[pipeline] ABORTING pipeline — non-high-risk model failed")
                pipeline_results["overall_status"] = f"aborted_at_{stage_name}"
                break

    # --- Save residual index ---
    pipeline_results["residuals"] = residual_bank.residual_index
    if residual_bank.residual_index:
        print(f"\n[pipeline] Residual bank: {len(residual_bank.residual_index)} stages saved")
        for stage, info in residual_bank.residual_index.items():
            print(f"  {stage}: target lost {info['total_target_loss']:.4f}, source lost {info['total_source_loss']:.4f}")

        # Identify which modules need the most healing
        healing_targets = residual_bank.get_healing_targets(top_n=50)
        pipeline_results["suggested_healing_targets"] = healing_targets
    else:
        # No residuals recorded: expose an empty suggestion list so the key
        # is always present for consumers.
        pipeline_results["suggested_healing_targets"] = []

    # --- Save final model ---
    # NOTE(review): this saves the live target_model. If a later stage
    # failed after modifying the model in place, these weights may differ
    # from the last passing checkpoint — confirm whether a reload from
    # final_checkpoint is wanted.
    if pipeline_results["final_checkpoint"]:
        final_dir = Path(cfg.output_dir) / "final"
        final_dir.mkdir(parents=True, exist_ok=True)
        target_model.save_pretrained(final_dir)
        target_tokenizer.save_pretrained(final_dir)
        pipeline_results["final_model_path"] = str(final_dir)
        print(f"\n[pipeline] Final model saved to {final_dir}")

    if all_passed:
        pipeline_results["overall_status"] = "all_passed"
    elif pipeline_results["overall_status"] == "pending":
        pipeline_results["overall_status"] = "partial"

    # --- Print final summary ---
    print("\n" + "=" * 70)
    print("PIPELINE SUMMARY")
    print("=" * 70)
    for stage_name, stage_result in pipeline_results["stages"].items():
        status = stage_result["status"]
        emoji = "✓" if status == "passed" else "✗"
        print(f"  {emoji} {stage_name}: {status}")
    print(f"\n  Overall: {pipeline_results['overall_status']}")
    if residual_bank.residual_index:
        print(f"\n  Residuals saved for: {', '.join(residual_bank.residual_index.keys())}")
        print("  To recover lost knowledge later:")
        print("    python -m td_lang.engine --reinject <stage> --strength 0.2")
    print("=" * 70)

    return pipeline_results
hugging/td_lang/engine/run.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TD Fuse — Main Entry Point.
3
+
4
+ Usage:
5
+ # Dad demo: merge just DeepSeek → Qwen3-8B (easiest, lowest risk)
6
+ python -m td_fuse.run --stage demo
7
+
8
+ # Full pipeline: all 4 merges
9
+ python -m td_fuse.run --stage all
10
+
11
+ # Single model merge
12
+ python -m td_fuse.run --stage deepseek
13
+ python -m td_fuse.run --stage mimo
14
+ python -m td_fuse.run --stage llama
15
+ python -m td_fuse.run --stage falcon
16
+
17
+ # With healing fine-tune after merge
18
+ python -m td_fuse.run --stage demo --heal
19
+
20
+ # Custom output directory
21
+ python -m td_fuse.run --stage all --output ./my_output
22
+
23
+ # Heal an existing checkpoint
24
+ python -m td_fuse.run --heal-only --model-path ./td_fuse_checkpoints/after_deepseek
25
+
26
+ Findings: #25 (dad demo plan), #22 (merge order), #24 (official T&M pipeline)
27
+ """
28
+
29
+ import argparse
30
+ import json
31
+ import sys
32
+ import time
33
+ from pathlib import Path
34
+
35
+ from .config import MergeConfig, DEMO_STAGES, FULL_STAGES
36
+ from .merge import run_pipeline, ResidualBank
37
+ from .heal import heal_model
38
+
39
+
40
def parse_args(argv=None):
    """
    Parse the TD Fuse command line.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse reads ``sys.argv[1:]`` exactly as before; passing a
            list allows programmatic use and testing.

    Returns:
        ``argparse.Namespace`` with: stage, heal, heal_only, model_path,
        output, checkpoint_dir, tm_repo, dry_run, reinject, reinject_side,
        strength.
    """
    parser = argparse.ArgumentParser(
        description="TD Fuse — Transport and Merge pipeline for Time Dilation",
        # Raw formatter keeps the hand-aligned examples below intact.
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python -m td_fuse.run --stage demo          # Dad demo (DeepSeek only)
  python -m td_fuse.run --stage all           # Full 4-model merge
  python -m td_fuse.run --stage all --heal    # Merge + healing fine-tune
  python -m td_fuse.run --heal-only --model-path ./checkpoint
  python -m td_fuse.run --reinject deepseek --strength 0.2 --model-path ./final
        """,
    )

    parser.add_argument(
        "--stage",
        type=str,
        default="demo",
        choices=["demo", "all", "deepseek", "mimo", "llama", "falcon"],
        help="Which merge stage(s) to run (default: demo)",
    )
    parser.add_argument(
        "--heal",
        action="store_true",
        help="Run healing fine-tune after merge",
    )
    parser.add_argument(
        "--heal-only",
        action="store_true",
        help="Only run healing (skip merge), requires --model-path",
    )
    parser.add_argument(
        "--model-path",
        type=str,
        default=None,
        help="Path to existing model/checkpoint (for --heal-only)",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="./td_fuse_outputs",
        help="Output directory (default: ./td_fuse_outputs)",
    )
    parser.add_argument(
        "--checkpoint-dir",
        type=str,
        default="./td_fuse_checkpoints",
        help="Checkpoint directory (default: ./td_fuse_checkpoints)",
    )
    parser.add_argument(
        "--tm-repo",
        type=str,
        default="./Cross-Architecture-Merging-for-Large-Language-Models",
        help="Path to official T&M repo",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Print what would happen without actually running",
    )
    parser.add_argument(
        "--reinject",
        type=str,
        default=None,
        help="Re-inject saved residuals from a stage (e.g., --reinject deepseek)",
    )
    parser.add_argument(
        "--reinject-side",
        type=str,
        default="both",
        choices=["target", "source", "both"],
        help="Which side's residuals to re-inject (default: both)",
    )
    parser.add_argument(
        "--strength",
        type=float,
        default=0.2,
        help="Residual re-injection strength, 0-1 (default: 0.2)",
    )

    return parser.parse_args(argv)
121
+
122
+
123
def print_banner():
    """
    Print the TD Fuse banner.

    The banner text in this file was mis-encoded (UTF-8 box-drawing bytes
    shown through a single-byte code page). The frame is rebuilt here from
    per-letter glyph rows and centred inside a 50-column box, so every
    framed line has the same width.
    NOTE(review): the original inter-letter and margin spacing could not be
    recovered exactly; the layout below is a reconstruction.
    """
    # Six-row block glyphs, one tuple per letter (top row first).
    glyph_t = ("████████╗", "╚══██╔══╝", "   ██║   ", "   ██║   ", "   ██║   ", "   ╚═╝   ")
    glyph_d = ("██████╗ ", "██╔══██╗", "██║  ██║", "██║  ██║", "██████╔╝", "╚═════╝ ")
    glyph_f = ("███████╗", "██╔════╝", "█████╗  ", "██╔══╝  ", "██║     ", "╚═╝     ")
    glyph_u = ("██╗   ██╗", "██║   ██║", "██║   ██║", "██║   ██║", "╚██████╔╝", " ╚═════╝ ")
    glyph_s = ("███████╗", "██╔════╝", "███████╗", "╚════██║", "███████║", "╚══════╝")
    word_gap = ("  ",) * 6

    columns = (glyph_t, glyph_d, word_gap, glyph_f, glyph_u, glyph_s)
    art_rows = ["".join(parts) for parts in zip(*columns)]

    inner_width = 50  # matches the surviving top border of the original
    content = (
        [""]
        + art_rows
        + [
            "",
            "Transport and Merge for Time Dilation",
            "Merging 5 models into Qwen3-8B",
            "",
        ]
    )

    margin = "    "
    framed = [margin + "╔" + "═" * inner_width + "╗"]
    framed.extend(f"{margin}║{row.center(inner_width)}║" for row in content)
    framed.append(margin + "╚" + "═" * inner_width + "╝")

    # Leading/trailing newline keep the blank line around the banner.
    print("\n" + "\n".join(framed) + "\n")
141
+
142
+
143
def main():
    """
    Command-line entry point for TD Fuse.

    Modes, checked in this order:
        1. ``--reinject STAGE``: add saved residuals to the model at
           ``--model-path`` and save the patched copy.
        2. ``--heal-only``: run the healing fine-tune on ``--model-path``.
        3. default: run the merge pipeline (optionally followed by
           ``--heal``) and write ``pipeline_results.json``.

    ``--dry-run`` is honoured in every mode. Previously it was only checked
    after the reinject and heal-only branches, so those two modes did real
    work even when a dry run was requested.

    Exits with status 0 only when the pipeline reports "all_passed";
    missing ``--model-path`` in modes 1 and 2 exits with status 1.
    """
    args = parse_args()
    print_banner()

    # Build config from args
    cfg = MergeConfig(
        output_dir=args.output,
        checkpoint_dir=args.checkpoint_dir,
        tm_repo_path=args.tm_repo,
    )

    # Determine which stages to run
    if args.stage == "demo":
        stages = DEMO_STAGES
    elif args.stage == "all":
        stages = FULL_STAGES
    else:
        stages = [args.stage]

    # --- Reinject residuals mode ---
    if args.reinject:
        if not args.model_path:
            print("Error: --reinject requires --model-path")
            sys.exit(1)

        if args.dry_run:
            print("\n=== DRY RUN ===")
            print(f"Would re-inject residuals from stage: {args.reinject}")
            print(f"Side: {args.reinject_side}, Strength: {args.strength}")
            print(f"Model: {args.model_path}")
            print("\nNo changes made (dry run).")
            return

        # Heavy imports are deferred so the other modes start quickly.
        from transformers import AutoModelForCausalLM, AutoTokenizer
        import torch

        print(f"\n[run] Re-injecting residuals from stage: {args.reinject}")
        print(f"[run] Side: {args.reinject_side}, Strength: {args.strength}")

        residual_bank = ResidualBank(cfg)
        tokenizer = AutoTokenizer.from_pretrained(args.model_path)
        model = AutoModelForCausalLM.from_pretrained(
            args.model_path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )

        model = residual_bank.reinject_residuals(
            model, args.reinject,
            side=args.reinject_side,
            strength=args.strength,
        )

        # Save the patched model
        patched_dir = Path(cfg.output_dir) / f"reinjected_{args.reinject}_{args.strength}"
        patched_dir.mkdir(parents=True, exist_ok=True)
        model.save_pretrained(str(patched_dir))
        tokenizer.save_pretrained(str(patched_dir))
        print(f"\n[run] Patched model saved to: {patched_dir}")
        return

    # --- Heal-only mode ---
    if args.heal_only:
        if not args.model_path:
            print("Error: --heal-only requires --model-path")
            sys.exit(1)

        if args.dry_run:
            print("\n=== DRY RUN ===")
            print(f"Would heal model at: {args.model_path}")
            print("\nNo changes made (dry run).")
            return

        print(f"\n[run] Healing model at: {args.model_path}")
        healed_path = heal_model(args.model_path, cfg)
        print(f"\n[run] Healed model saved to: {healed_path}")
        return

    # --- Dry run ---
    if args.dry_run:
        print("\n=== DRY RUN ===")
        print(f"Stages: {stages}")
        print(f"Output: {cfg.output_dir}")
        print(f"Checkpoints: {cfg.checkpoint_dir}")
        print(f"T&M repo: {cfg.tm_repo_path}")
        print(f"Heal after: {args.heal}")
        print("\nWould run:")
        for i, stage in enumerate(stages, 1):
            print(f"  {i}. Merge {stage} → target")
            print("     → Validate (canary + perplexity + thinking + reasoning)")
            print("     → Checkpoint")
        if args.heal:
            print(f"  {len(stages) + 1}. QLoRA healing fine-tune")
        print("\nNo changes made (dry run).")
        return

    # --- Run the pipeline ---
    start_time = time.time()

    results = run_pipeline(stages, cfg)

    elapsed = time.time() - start_time
    print(f"\n[run] Pipeline completed in {elapsed / 60:.1f} minutes")

    # --- Healing fine-tune (optional) ---
    if args.heal and results.get("final_checkpoint"):
        print("\n[run] Starting healing fine-tune...")
        healed_path = heal_model(results["final_checkpoint"], cfg)
        results["healed_model_path"] = healed_path
        print(f"[run] Healed model: {healed_path}")
    elif args.heal:
        # Tell the user why --heal did nothing instead of skipping silently.
        print("\n[run] Skipping healing fine-tune: no stage produced a checkpoint")

    # --- Save results ---
    results_path = Path(cfg.output_dir) / "pipeline_results.json"

    # Convert non-serialisable objects
    def make_serialisable(obj):
        """Recursively replace anything JSON cannot encode with str(obj)."""
        if isinstance(obj, dict):
            return {k: make_serialisable(v) for k, v in obj.items()}
        elif isinstance(obj, list):
            return [make_serialisable(v) for v in obj]
        elif isinstance(obj, (int, float, str, bool, type(None))):
            return obj
        else:
            return str(obj)

    with open(results_path, "w", encoding="utf-8") as f:
        json.dump(make_serialisable(results), f, indent=2)
    print(f"[run] Results saved to {results_path}")

    # --- Final summary ---
    print(f"\n{'=' * 60}")
    print("TD FUSE COMPLETE")
    print(f"{'=' * 60}")
    print(f"  Status:  {results['overall_status']}")
    print(f"  Time:    {elapsed / 60:.1f} minutes")
    if results.get("final_model_path"):
        print(f"  Model:   {results['final_model_path']}")
    if results.get("healed_model_path"):
        print(f"  Healed:  {results['healed_model_path']}")
    print(f"  Results: {results_path}")
    print(f"{'=' * 60}")

    # Exit code based on result
    if results["overall_status"] == "all_passed":
        sys.exit(0)
    else:
        sys.exit(1)
276
+
277
+
278
+ if __name__ == "__main__":
279
+ main()
hugging/td_lang/engine/techniques.py ADDED
@@ -0,0 +1,669 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Advanced Merge Techniques — from latest papers (Feb 2026).
3
+
4
+ This module contains implementations inspired by recent research
5
+ that improve TD's sequential cross-architecture merging pipeline.
6
+
7
+ Techniques:
8
+ 1. Theseus (2602.12952) — Procrustes-based task vector transport
9
+ 2. ARM (2602.03237) — Activation-guided rotation for sequential merges
10
+ 3. OTMF (2511.19561) — OT masks for identifying transferable weights
11
+ 4. RAM (2601.13572) — RL-weight disentanglement for RL-trained models
12
+ 5. Mergeability (2601.22285) — Pre-check scoring before attempting merge
13
+
14
+ These complement Transport and Merge (2602.05495) which handles
15
+ the core cross-architecture fusion via optimal transport.
16
+ """
17
+
18
+ import torch
19
+ import numpy as np
20
+ from typing import Optional
21
+ from transformers import AutoModelForCausalLM, AutoTokenizer
22
+
23
+ from .config import MergeConfig, ModelConfig
24
+
25
+
26
+ # ============================================================================
27
+ # 1. THESEUS — Procrustes-Based Task Vector Transport (2602.12952)
28
+ # ============================================================================
29
+ #
30
+ # Instead of aligning neurons via optimal transport (T&M), Theseus aligns
31
+ # the FUNCTIONAL EFFECT of weights via orthogonal Procrustes.
32
+ #
33
+ # Analogy: T&M says "neuron 5 in Model A = neuron 12 in Model B"
34
+ # Theseus says "the EFFECT of Model A's weights can be rotated
35
+ # into Model B's space"
36
+ #
37
+ # Best for: Models where neuron-level alignment is poor (Falcon SSM hybrid)
38
+
39
def compute_procrustes_alignment(
    source_activations: torch.Tensor,
    target_activations: torch.Tensor,
) -> torch.Tensor:
    """
    Compute the orthogonal Procrustes rotation R that best maps source
    activations into target activation space.

        R = argmin ||target - source @ R||_F   subject to R^T R = I

    Closed-form solution: with the SVD  source^T @ target = U S V^T,
    the minimiser is  R = U @ V^T  (no iterative optimisation needed).

    Both inputs are mean-centred first. If the feature dimensions differ,
    the narrower matrix is zero-padded on the right so the problem is
    square, and the result is cropped back afterwards.

    Args:
        source_activations: [num_samples, source_dim] activation matrix
        target_activations: [num_samples, target_dim] activation matrix

    Returns:
        R: [source_dim, target_dim] matrix such that
           centred_source @ R approximates centred_target.
    """
    # Center the activations (remove the per-feature mean)
    S = source_activations - source_activations.mean(dim=0, keepdim=True)
    T = target_activations - target_activations.mean(dim=0, keepdim=True)

    # Handle dimension mismatch by zero-padding the smaller one
    s_dim = S.shape[1]
    t_dim = T.shape[1]
    max_dim = max(s_dim, t_dim)

    if s_dim < max_dim:
        S = torch.nn.functional.pad(S, (0, max_dim - s_dim))
    if t_dim < max_dim:
        T = torch.nn.functional.pad(T, (0, max_dim - t_dim))

    # Cross-covariance matrix, [max_dim, max_dim]
    M = S.T @ T

    # SVD: M = U @ diag(sigma) @ Vt
    U, _, Vt = torch.linalg.svd(M, full_matrices=True)

    # Optimal orthogonal map for "source @ R ~= target" is U @ V^T.
    # (V @ U^T would be its transpose, i.e. the target -> source map.)
    R = U @ Vt

    # Ensure a proper rotation (det = +1) rather than a reflection by
    # flipping the direction tied to the smallest singular value.
    if torch.linalg.det(R) < 0:
        Vt = Vt.clone()
        Vt[-1, :] *= -1
        R = U @ Vt

    # Rows index source features, columns index target features.
    return R[:s_dim, :t_dim]
92
+
93
+
94
def transport_task_vector_theseus(
    source_model: AutoModelForCausalLM,
    source_base_model: AutoModelForCausalLM,
    target_model: AutoModelForCausalLM,
    source_activations: dict,
    target_activations: dict,
    alpha: float = 0.3,
) -> AutoModelForCausalLM:
    """
    Transport a task vector from source to target using the Theseus method.

    Task vector = source_finetuned - source_base
    (the "diff" that represents what the model learned).

    Per-layer rotations are obtained from compute_procrustes_alignment on
    the paired activations; a 2D task vector whose layer token matches a
    rotation is rotated before being added to the target:
    target_new = target + alpha * rotated_task_vector.

    Only parameters that exist under the SAME key in source, base and
    target, and whose (possibly rotated) task vector has the target's
    shape, are updated. Everything else in the target is left untouched.

    Args:
        source_model: The fine-tuned source model
        source_base_model: The base version of source (for the task vector)
        target_model: The model to transport into; modified in place via
            load_state_dict
        source_activations: Layer name -> activation tensor for source
        target_activations: Layer name -> activation tensor for target
        alpha: Blending weight for the transported task vector

    Returns:
        target_model, after its state dict has been updated.
    """
    print("[theseus] Computing task vectors and Procrustes alignment...")

    source_state = source_model.state_dict()
    base_state = source_base_model.state_dict()
    target_state = target_model.state_dict()

    def _layer_token(name):
        # Third dotted component (the layer index in "model.layers.N...").
        # Names with fewer components (e.g. "lm_head.weight") have no
        # layer token; returning None avoids an IndexError on them.
        parts = name.split(".")
        return parts[2] if len(parts) > 2 else None

    # Per-layer Procrustes rotations; source and target layers are paired
    # by their position in sorted-name order.
    rotations = {}
    for sl, tl in zip(sorted(source_activations), sorted(target_activations)):
        rotations[(sl, tl)] = compute_procrustes_alignment(
            source_activations[sl].float(),
            target_activations[tl].float(),
        )

    transported_count = 0
    for target_key in target_state:
        # Simplified key matching: the source must use the same key name.
        source_key = target_key
        if source_key not in source_state or source_key not in base_state:
            continue

        # Task vector = what the source learned
        task_vector = source_state[source_key].float() - base_state[source_key].float()

        if task_vector.abs().max() < 1e-8:
            continue  # No meaningful change

        # Only 2D weight matrices are rotated
        key_token = _layer_token(target_key)
        if task_vector.dim() == 2 and key_token is not None:
            for (sl, _tl), R in rotations.items():
                if _layer_token(sl) != key_token:
                    continue
                R_device = R.to(task_vector.device)
                # The shape guards make the matmul well-defined; if neither
                # side matches, the task vector is used unrotated.
                if task_vector.shape[1] == R_device.shape[0]:
                    task_vector = task_vector @ R_device
                elif task_vector.shape[0] == R_device.shape[0]:
                    task_vector = R_device.T @ task_vector
                break  # First rotation with a matching layer token wins

        # Apply: target_new = target + alpha * rotated_task_vector
        target_w = target_state[target_key]
        if task_vector.shape == target_w.shape:
            # Move to the target tensor's device as well as dtype so source
            # and target models may live on different devices.
            update = task_vector.to(device=target_w.device, dtype=target_w.dtype)
            target_state[target_key] = target_w + alpha * update
            transported_count += 1

    target_model.load_state_dict(target_state)
    print(f"[theseus] Transported {transported_count} task vectors via Procrustes")
    return target_model
180
+
181
+
182
+ # ============================================================================
183
+ # 2. ARM — Activation-Guided Rotations for Sequential Merging (2602.03237)
184
+ # ============================================================================
185
+ #
186
+ # ARM treats sequential merging like gradient descent — each merge step
187
+ # has a "direction" and a "learning rate" (merge coefficient).
188
+ #
189
+ # Key insight: Use ACTIVATION PATTERNS to compute optimal rotation vectors
190
+ # that guide each merge step. This is a smarter version of our
191
+ # orthogonal projection in MergeProtection.
192
+
193
def compute_arm_rotation(
    pre_merge_activations: dict,
    post_merge_activations: dict,
    target_activations: dict,
) -> dict:
    """
    Build ARM rotation descriptors for sequential merge protection.

    For every layer present in all three inputs, two sample-averaged
    direction vectors are derived:
      * the change introduced by the last merge (post - pre), and
      * the gap still remaining to the target (target - post).
    Both are normalised to unit length, and the angle between them is
    recorded through its cosine and sine.

    Args:
        pre_merge_activations: Layer name -> activations before the last merge
        post_merge_activations: Layer name -> activations after the last merge
        target_activations: Layer name -> activations of the desired target

    Returns:
        Dict of layer_name -> descriptor dict with keys "delta_direction",
        "gap_direction", "cos_theta", "sin_theta" and "gap_magnitude".
        Layers missing from either the post-merge or target dict are skipped.
    """
    print("[arm] Computing activation-guided rotations...")

    def _unit(vec):
        # Epsilon keeps an all-zero direction from dividing by zero.
        return vec / (vec.norm() + 1e-8)

    descriptors = {}

    for name in pre_merge_activations:
        if not (name in post_merge_activations and name in target_activations):
            continue

        before = pre_merge_activations[name].float()
        after = post_merge_activations[name].float()
        goal = target_activations[name].float()

        # Average over the sample axis to get one direction per layer
        mean_step = (after - before).mean(dim=0)   # what the last merge did
        mean_gap = (goal - after).mean(dim=0)      # what is still missing

        step_unit = _unit(mean_step)
        gap_unit = _unit(mean_gap)

        # Angle between the two unit directions (clamped for sqrt safety)
        cosine = torch.dot(step_unit, gap_unit).clamp(-1, 1)
        sine = torch.sqrt(1 - cosine ** 2)

        descriptors[name] = {
            "delta_direction": step_unit,
            "gap_direction": gap_unit,
            "cos_theta": cosine.item(),
            "sin_theta": sine.item(),
            "gap_magnitude": mean_gap.norm().item(),
        }

    return descriptors
253
+
254
+
255
def apply_arm_steering(
    weight_delta: torch.Tensor,
    rotation_info: dict,
    steering_strength: float = 0.5,
) -> torch.Tensor:
    """
    Steer a weight delta using an ARM rotation descriptor.

    The delta is flattened, its component along the previous merge
    direction is measured, and a fraction of that component is taken
    away from the previous direction and re-added along the gap
    direction instead.

    Because the projection uses torch.dot on the flattened delta, the
    direction vectors must have as many elements as weight_delta.

    Args:
        weight_delta: The raw delta from the current merge
        rotation_info: Descriptor holding "delta_direction" and
            "gap_direction" for this layer
        steering_strength: How much to steer (0=no steering, 1=full)

    Returns:
        Steered delta with the shape and dtype of weight_delta
    """
    prev_direction = rotation_info["delta_direction"]
    gap_direction = rotation_info["gap_direction"]

    flat_delta = weight_delta.flatten().float()
    prev_on_device = prev_direction.to(flat_delta.device)
    gap_on_device = gap_direction.to(flat_delta.device)

    # Scalar projection of the delta onto the previous merge direction
    along_prev = torch.dot(flat_delta, prev_on_device)

    # Take part of that component out of the old direction ...
    removed = -steering_strength * along_prev * prev_on_device
    # ... and put the same amount in along the gap direction.
    redirected = steering_strength * along_prev * gap_on_device

    steered_flat = flat_delta + (removed + redirected)
    return steered_flat.reshape(weight_delta.shape).to(weight_delta.dtype)
292
+
293
+
294
+ # ============================================================================
295
+ # 3. OTMF — Transferability Masks via Optimal Transport (2511.19561)
296
+ # ============================================================================
297
+ #
298
+ # OTMF discovers which parts of each model are "transferable" (shared
299
+ # knowledge) vs "task-specific" (unique to that model).
300
+ #
301
+ # Transferable weights → safe to merge/average
302
+ # Task-specific weights → must be preserved carefully
303
+ #
304
+ # This replaces our MagMax "top 20% by magnitude" heuristic with a
305
+ # principled, data-driven approach.
306
+
307
def compute_transferability_masks(
    model: AutoModelForCausalLM,
    calibration_activations: dict,
    threshold: float = 0.3,
) -> dict:
    """
    Compute per-parameter transferability masks using activation variance.

    High activation variance across inputs is treated as "task-specific"
    (protect); low variance as "transferable" (safe to merge). This is a
    simplified stand-in for OTMF's OT-based mask discovery.

    A parameter is matched to an activation entry when the first four
    dotted components of its name occur inside the activation's layer
    name. For a matched 2D weight, the first `rows` variance entries are
    thresholded at the (1 - threshold) quantile and the resulting row
    mask is broadcast across columns.

    Parameters fall back to an all-True (fully transferable) mask when
    they are not 2D, have no matching activation entry, or when the
    matched variance vector is shorter than the weight's row count (it
    cannot be mapped onto rows, and broadcasting it would raise).

    Args:
        model: Object exposing state_dict()
        calibration_activations: Layer name -> activations; variance is
            taken over dim 0
        threshold: Fraction of highest-variance rows classified as
            task-specific

    Returns:
        Dict of param_name -> bool mask
        (True = transferable/safe, False = task-specific/protect)
    """
    print("[otmf] Computing transferability masks...")

    masks = {}
    state = model.state_dict()

    # Variance over the sample axis, one vector per activation entry
    neuron_importance = {
        layer_name: acts.var(dim=0)
        for layer_name, acts in calibration_activations.items()
    }

    for param_name, param in state.items():
        # e.g. "model.layers.0.self_attn" for a q_proj weight
        layer_prefix = ".".join(param_name.split(".")[:4])

        # First activation entry whose name contains the prefix, if any
        importance = next(
            (var for layer_name, var in neuron_importance.items() if layer_prefix in layer_name),
            None,
        )

        row_mask = None
        if importance is not None and param.dim() == 2:
            rows = param.shape[0]
            # Only usable when there is one variance value per output row.
            if rows > 0 and importance.shape[0] >= rows:
                imp = importance[:rows]
                q = torch.quantile(imp.float(), 1.0 - threshold)
                # True = below the cut (transferable); False = protect
                row_mask = imp < q

        if row_mask is None:
            masks[param_name] = torch.ones(param.shape, dtype=torch.bool)
        else:
            masks[param_name] = row_mask.unsqueeze(1).expand_as(param)

    transferable = sum(m.sum().item() for m in masks.values())
    total = sum(m.numel() for m in masks.values())
    if total > 0:
        print(f"[otmf] Transferability: {transferable / total:.1%} transferable, {1 - transferable / total:.1%} task-specific")
    else:
        # Avoid dividing by zero for a model without parameters
        print("[otmf] Transferability: no parameters to mask")

    return masks
382
+
383
+
384
def apply_masked_merge(
    target_state: dict,
    fused_state: dict,
    masks: dict,
    protect_strength: float = 0.8,
) -> dict:
    """
    Combine fused weights with the original target under transferability masks.

    Where a mask is True the fused value is taken as-is. Where it is
    False the result is pulled back toward the original target value:
    protect_strength * original + (1 - protect_strength) * fused.
    Keys lacking a mask or an original value are passed through from
    fused_state unchanged. Only keys of fused_state appear in the result.

    Args:
        target_state: Original target weights (before this merge)
        fused_state: Newly fused weights
        masks: Transferability masks (True = safe to change)
        protect_strength: How strongly masked-out entries stay at the
            original value (0-1)

    Returns:
        Masked merged state dict
    """
    merged = {}

    for name, fused in fused_state.items():
        if name not in masks or name not in target_state:
            # Nothing to protect against: keep the fused tensor
            merged[name] = fused
            continue

        take_fused = masks[name].to(fused.device)
        original = target_state[name]
        # Value used wherever the mask says "protect"
        shielded = protect_strength * original + (1 - protect_strength) * fused
        merged[name] = torch.where(take_fused, fused, shielded)

    # Count masks that contain at least one protected entry
    protected_params = sum(1 for mask in masks.values() if not mask.all())
    print(f"[otmf] Applied masks: {protected_params} parameters partially protected")

    return merged
428
+
429
+
430
+ # ============================================================================
431
+ # 4. RAM — RL-Weight Disentanglement (2601.13572)
432
+ # ============================================================================
433
+ #
434
+ # RL-trained models (DeepSeek-R1, MiMo-7B-RL) have two types of knowledge:
435
+ # - Shared: general language understanding (same as base model)
436
+ # - RL-specific: reasoning patterns learned via GRPO/RLHF
437
+ #
438
+ # RAM separates these so we can merge the shared parts normally
439
+ # but PRESERVE the RL-specific parts that make these models special.
440
+
441
def disentangle_rl_weights(
    rl_model: AutoModelForCausalLM,
    base_model: AutoModelForCausalLM,
    rl_threshold: float = 0.1,
) -> tuple:
    """
    Split a model's weights into RL-specific and shared entries.

    Each weight is compared element-wise with the base model. An entry
    counts as RL-specific when |rl - base| / (|base| + 1e-8) exceeds
    rl_threshold; otherwise it is shared. Parameters that do not exist
    in the base model at all are marked entirely RL-specific.

    Args:
        rl_model: The RL-trained model
        base_model: The base model before RL training
        rl_threshold: Relative change above which an entry is RL-specific

    Returns:
        Tuple of (shared_mask, rl_mask), both dicts of param_name -> bool tensor
        shared_mask: True = entry is shared (safe to merge normally)
        rl_mask: True = entry is RL-specific (protect during merge)
    """
    print("[ram] Disentangling RL-specific vs shared weights...")

    rl_state = rl_model.state_dict()
    base_state = base_model.state_dict()

    shared_mask, rl_mask = {}, {}
    total_params = 0
    rl_params = 0

    for name, rl_tensor in rl_state.items():
        if name in base_state:
            rl_w = rl_tensor.float()
            base_w = base_state[name].float()

            # Element-wise relative change; epsilon guards zero base weights
            relative_change = (rl_w - base_w).abs() / (base_w.abs() + 1e-8)

            is_rl = relative_change > rl_threshold
            rl_mask[name] = is_rl
            shared_mask[name] = ~is_rl

            rl_params += is_rl.sum().item()
            total_params += is_rl.numel()
        else:
            # No counterpart in the base model: treat as fully RL-specific
            rl_mask[name] = torch.ones_like(rl_tensor, dtype=torch.bool)
            shared_mask[name] = torch.zeros_like(rl_tensor, dtype=torch.bool)
            rl_params += rl_tensor.numel()
            total_params += rl_tensor.numel()

    pct = rl_params / total_params * 100 if total_params > 0 else 0
    print(f"[ram] RL-specific: {rl_params:,} params ({pct:.1f}%)")
    print(f"[ram] Shared: {total_params - rl_params:,} params ({100 - pct:.1f}%)")

    return shared_mask, rl_mask
507
+
508
+
509
def merge_with_rl_preservation(
    target_state: dict,
    source_state: dict,
    shared_mask: dict,
    rl_mask: dict,
    shared_alpha: float = 0.5,
    rl_alpha: float = 0.8,
) -> dict:
    """
    Merge source into target with a separate blend weight for RL-specific entries.

    For every key of target_state:
      * missing from source_state, or shapes differ -> target tensor kept
      * key present in both masks -> per-element blend, using rl_alpha
        where rl_mask is True and shared_alpha elsewhere
      * otherwise -> uniform blend at shared_alpha

    Blend formula: alpha * source + (1 - alpha) * target.

    The per-element alpha map is a float32 tensor, so blending half
    precision weights with it promotes the result to float32. Floating
    point results are therefore cast back to the target tensor's dtype,
    keeping the returned state dict's dtypes consistent across branches.

    Args:
        target_state: Current target model state
        source_state: RL model state to merge in
        shared_mask: Which params are shared; only key membership is used
        rl_mask: Which entries are RL-specific (blended at rl_alpha)
        shared_alpha: Alpha for shared weights
        rl_alpha: Alpha for RL-specific weights (higher = closer to source)

    Returns:
        New state dict with the same keys as target_state
    """
    print(f"[ram] Merging with RL preservation (shared Ξ±={shared_alpha}, RL Ξ±={rl_alpha})...")

    result = {}
    for key, target_w in target_state.items():
        if key not in source_state:
            result[key] = target_w
            continue

        source_w = source_state[key]

        if source_w.shape != target_w.shape:
            result[key] = target_w
            continue

        source_on_target = source_w.to(target_w.device)

        if key in rl_mask and key in shared_mask:
            rl_m = rl_mask[key].to(target_w.device)
            # RL-specific entries get rl_alpha, everything else shared_alpha
            alpha_map = torch.where(rl_m, rl_alpha, shared_alpha)
            if alpha_map.shape != target_w.shape:
                alpha_map = (
                    alpha_map.expand_as(target_w)
                    if alpha_map.dim() > 0
                    else torch.full_like(target_w, shared_alpha)
                )

            blended = alpha_map * source_on_target + (1 - alpha_map) * target_w
            # Undo the float32 promotion caused by the alpha map so bf16/fp16
            # weights keep their dtype, as in the unmasked branch below.
            if target_w.is_floating_point():
                blended = blended.to(target_w.dtype)
            result[key] = blended
        else:
            result[key] = shared_alpha * source_on_target + (1 - shared_alpha) * target_w

    return result
562
+
563
+
564
+ # ============================================================================
565
+ # 5. MERGEABILITY PRE-CHECK (2601.22285)
566
+ # ============================================================================
567
+ #
568
+ # Before spending GPU hours on a merge that might fail, check if the
569
+ # models are actually COMPATIBLE enough to merge.
570
+ #
571
+ # Mergeability score: 0.0 (definitely won't work) to 1.0 (should work great)
572
+
573
def compute_mergeability_score(
    source_activations: dict,
    target_activations: dict,
    source_config: ModelConfig,
    target_num_layers: int = 36,
    target_hidden_dim: int = 4096,
) -> dict:
    """
    Predict how well a source model will merge into the target.

    Four factors are combined into a weighted average:
      1. Activation similarity (0.35) - mean cosine similarity between
         the sample-averaged activations of proportionally matched layers
      2. Dimensional compatibility (0.25) - ratio of layer counts and
         hidden sizes between source and target
      3. Architecture match (0.25) - lookup on source_config.architecture
      4. Vocab overlap (0.15) - source_config.vocab_overlap_with_qwen3

    Cosine similarity can be negative, so "activation_similarity" and
    "overall" are not clamped to [0, 1].

    Args:
        source_activations: Layer name -> activation tensor for the source
        target_activations: Layer name -> activation tensor for the target
        source_config: Object providing name, layers, hidden_dim,
            architecture and vocab_overlap_with_qwen3
        target_num_layers: Layer count of the target (default 36)
        target_hidden_dim: Hidden size of the target (default 4096)

    Returns:
        Dict with the individual scores, "overall" and a textual
        "recommendation"
    """
    print(f"[mergeability] Scoring {source_config.name}...")

    scores = {}

    # --- Factor 1: Activation similarity ---
    cosine_sims = []
    source_layers = sorted(source_activations.keys())
    target_layers = sorted(target_activations.keys())

    # Match layers by position (proportional mapping). Without any source
    # layers there is nothing to compare; skipping avoids indexing an
    # empty list and leaves the similarity at 0.0.
    if source_layers:
        for i, tl in enumerate(target_layers):
            # Map target layer index to source layer index
            src_idx = int(i * len(source_layers) / len(target_layers))
            src_idx = min(src_idx, len(source_layers) - 1)
            sl = source_layers[src_idx]

            s_mean = source_activations[sl].float().mean(dim=0)
            t_mean = target_activations[tl].float().mean(dim=0)

            # Zero-pad to a common length so cosine similarity is defined
            max_dim = max(s_mean.shape[0], t_mean.shape[0])
            s_padded = torch.nn.functional.pad(s_mean, (0, max_dim - s_mean.shape[0]))
            t_padded = torch.nn.functional.pad(t_mean, (0, max_dim - t_mean.shape[0]))

            cos_sim = torch.nn.functional.cosine_similarity(
                s_padded.unsqueeze(0), t_padded.unsqueeze(0)
            ).item()
            cosine_sims.append(cos_sim)

    activation_score = np.mean(cosine_sims) if cosine_sims else 0.0
    scores["activation_similarity"] = float(activation_score)

    # --- Factor 2: Dimensional compatibility ---
    layer_ratio = (
        min(source_config.layers, target_num_layers)
        / max(source_config.layers, target_num_layers)
    )
    hidden_ratio = (
        min(source_config.hidden_dim, target_hidden_dim)
        / max(source_config.hidden_dim, target_hidden_dim)
    )
    dim_score = (layer_ratio + hidden_ratio) / 2
    scores["dimensional_compatibility"] = float(dim_score)

    # --- Factor 3: Architecture match ---
    arch_scores = {
        "transformer": 1.0,       # Same family as the target
        "transformer+mtp": 0.8,   # Close, extras can be dropped
        "hybrid_ssm": 0.5,        # Very different
    }
    # Unknown architectures get the lowest score
    arch_score = arch_scores.get(source_config.architecture, 0.3)
    scores["architecture_match"] = float(arch_score)

    # --- Factor 4: Vocab overlap ---
    vocab_score = source_config.vocab_overlap_with_qwen3
    scores["vocab_overlap"] = float(vocab_score)

    # --- Overall: weighted average ---
    overall = (
        0.35 * activation_score   # Actual representation similarity
        + 0.25 * dim_score        # Shape compatibility
        + 0.25 * arch_score       # Architecture type
        + 0.15 * vocab_score      # Vocab overlap
    )
    scores["overall"] = float(overall)

    # --- Recommendation ---
    if overall >= 0.7:
        recommendation = "GO β€” standard T&M merge"
    elif overall >= 0.5:
        recommendation = "CAUTION β€” T&M merge with higher protection, have Theseus fallback ready"
    elif overall >= 0.3:
        recommendation = "RISKY β€” try Theseus first, distillation fallback"
    else:
        recommendation = "SKIP β€” use knowledge distillation instead"

    scores["recommendation"] = recommendation

    print(f"[mergeability] {source_config.name} score: {overall:.2f}")
    print(f" Activation similarity: {activation_score:.2f}")
    print(f" Dimensional compat: {dim_score:.2f}")
    print(f" Architecture match: {arch_score:.2f}")
    print(f" Vocab overlap: {vocab_score:.2f}")
    print(f" β†’ {recommendation}")

    return scores
hugging/td_lang/engine/transport.py ADDED
@@ -0,0 +1,853 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Transport and Merge — Two-sided optimal transport with streaming Sinkhorn.
3
+
4
+ Implements the actual Transport and Merge paper (arxiv 2602.05495) correctly:
5
+
6
+ Paper equations implemented here:
7
+ - Eq 8: Q matrices for pre-activation (Q_in) and post-activation (Q_out) features
8
+ - Eq 13: P_eff = sqrt(P_pre · P_post) — effective layer transport plan
9
+ - Eq 14: Masked fusion with binary top-k mask M^ℓ
10
+ - Appendix A.3.4: Log-domain streaming Sinkhorn (200 inner / 1000 outer iterations)
11
+ - Appendix A.5: Top-k=128 neuron selection
12
+
13
+ Two-sided transport (Section 4.2):
14
+ For each layer pair (ℓ, m):
15
+ 1. Compute Q_in from pre-activation features (what goes INTO the layer)
16
+ 2. Compute Q_out from post-activation features (what comes OUT of the layer)
17
+ 3. Derive P_pre and P_post at the layer level
18
+ 4. Combine: P_eff[ℓ,m] = sqrt(P_pre[ℓ,m] · P_post[ℓ,m])
19
+
20
+ Streaming Sinkhorn (Appendix A.3.4):
21
+ - Log-domain updates (never materialize full K = exp(-C/ε) matrix)
22
+ - Chunked computation for memory efficiency
23
+ - 200 fixed iterations for feature-level (inner) OT
24
+ - Up to 1000 iterations for layer-level (outer) OT
25
+ - ε = 0.1 for standard text, ε = 0.03 for math reasoning
26
+
27
+ Verified against actual paper PDF (test_21 interview round).
28
+ Grok scored 10/10, these implementations match Grok's citations.
29
+ """
30
+
31
+ import sys
32
+ import math
33
+ import torch
34
+ import numpy as np
35
+ from pathlib import Path
36
+ from typing import Optional, Tuple
37
+ from transformers import AutoModelForCausalLM, AutoTokenizer
38
+ from datasets import load_dataset
39
+
40
+ from .config import MergeConfig, ModelConfig, TARGET
41
+
42
+
43
+ # ============================================================================
44
+ # SETUP
45
+ # ============================================================================
46
+
47
def setup_tm_repo(cfg: MergeConfig):
    """
    Make the official T&M repo's `core` directory importable.

    Looks for `<cfg.tm_repo_path>/core` and, if it is not already on
    sys.path, inserts it at the front.

    Args:
        cfg: Config object providing tm_repo_path

    Raises:
        FileNotFoundError: If the `core` directory does not exist; the
            message includes the clone command for the repo.
    """
    repo_path = Path(cfg.tm_repo_path)
    core_path = repo_path / "core"

    if not core_path.exists():
        message = (
            f"Official T&M repo not found at {repo_path}\n"
            f"Please clone it:\n"
            f" git clone https://github.com/chenhangcuisg-code/"
            f"Cross-Architecture-Merging-for-Large-Language-Models.git"
        )
        raise FileNotFoundError(message)

    core_entry = str(core_path)
    if core_entry in sys.path:
        return  # Already importable, nothing to do

    # Front of the path so the repo's modules take precedence
    sys.path.insert(0, core_entry)
    print(f"[transport] Added T&M core to path: {core_path}")
63
+
64
+
65
+ # ============================================================================
66
+ # CALIBRATION DATA (Paper Appendix B.1: 2000 samples)
67
+ # ============================================================================
68
+
69
def load_calibration_data(cfg: MergeConfig, tokenizer: AutoTokenizer) -> list:
    """
    Collect tokenized calibration samples for activation extraction.

    Up to 60% of cfg.calibration_samples are drawn from the streaming
    "validation" split of cfg.calibration_dataset_pile (texts longer
    than 100 characters). Whatever is still missing is then drawn from
    the "train" split of cfg.calibration_dataset_nm (texts longer than
    50 characters, read from the "text" field or else "content").

    Each source is best-effort: any exception while loading or iterating
    it is reported as a warning and the function carries on with the
    samples gathered so far.

    Args:
        cfg: Config providing calibration_samples, calibration_seq_len
            and the two dataset names
        tokenizer: Callable tokenizer, invoked with truncation to
            cfg.calibration_seq_len and return_tensors="pt"

    Returns:
        List of tokenizer outputs (may be shorter than requested)
    """
    print(f"[transport] Loading calibration data ({cfg.calibration_samples} samples)...")

    samples = []

    def _encode(raw_text):
        # Single place for the tokenizer settings shared by both sources
        return tokenizer(
            raw_text,
            truncation=True,
            max_length=cfg.calibration_seq_len,
            return_tensors="pt",
        )

    # --- Pile: general text (60% of the budget) ---
    try:
        pile_stream = load_dataset(
            cfg.calibration_dataset_pile,
            split="validation",
            streaming=True,
            trust_remote_code=True,
        )
        taken = 0
        pile_quota = int(cfg.calibration_samples * 0.6)
        for record in pile_stream:
            if taken >= pile_quota:
                break
            body = record.get("text", "")
            if len(body) <= 100:
                continue  # Too short to be useful
            samples.append(_encode(body))
            taken += 1
        print(f" Pile general: {taken} samples")
    except Exception as e:
        # Deliberately broad: a failed source must not abort calibration
        print(f" Warning: Pile failed: {e}")
        print(f" Falling back to neuralmagic only")

    # --- neuralmagic: Q&A calibration (fills the remainder) ---
    remaining = cfg.calibration_samples - len(samples)
    if remaining > 0:
        try:
            nm_rows = load_dataset(
                cfg.calibration_dataset_nm,
                split="train",
                trust_remote_code=True,
            )
            taken = 0
            for record in nm_rows:
                if taken >= remaining:
                    break
                body = str(record.get("text", record.get("content", "")))
                if len(body) <= 50:
                    continue
                samples.append(_encode(body))
                taken += 1
            print(f" neuralmagic: {taken} samples")
        except Exception as e:
            print(f" Warning: neuralmagic failed: {e}")

    print(f"[transport] Total calibration samples: {len(samples)}")
    return samples
137
+
138
+
139
+ # ============================================================================
140
+ # ACTIVATION EXTRACTION (Paper: attention Q,K,V,O + MLP gate,up,down)
141
+ # ============================================================================
142
+
143
# Module types to hook into (paper extracts from these specific projections)
ATTENTION_PROJECTIONS = ("q_proj", "k_proj", "v_proj", "o_proj")
MLP_PROJECTIONS = ("gate_proj", "up_proj", "down_proj")
ALL_PROJECTIONS = ATTENTION_PROJECTIONS + MLP_PROJECTIONS


def extract_activations(
    model: AutoModelForCausalLM,
    calibration_data: list,
    device: str = "cuda",
) -> dict:
    """
    Extract pre-activation AND post-activation features from each projection module.

    Paper Section 4.2: two-sided transport requires both
      - pre-activation features (input to each projection)  -> for Q_in
      - post-activation features (output of each projection) -> for Q_out

    Only modules whose last name component is in ``ALL_PROJECTIONS`` are
    hooked, and modules whose name starts with "visual" or "merger" are
    skipped. Each recorded tensor is mean-pooled over dim 1 (assumed to be
    the sequence axis of a [batch, seq, dim] tensor -- TODO confirm for
    non-standard architectures).

    A sample whose forward pass raises is dropped *entirely*: features that
    earlier hooks already recorded for it are discarded, so every module ends
    up with the same number of rows. Hooks are always removed, even if an
    unexpected error escapes.

    NOTE(review): a sample that fails for one model but not the other still
    shifts row alignment between source and target activations; callers that
    correlate the two should check the row counts.

    Args:
        model: model to run; put in eval mode.
        calibration_data: list of dicts of tensors, passed to ``model(**inputs)``.
        device: device the input tensors are moved to.

    Returns:
        Dict with keys like
          "model.layers.0.self_attn.q_proj.pre"  -> [num_samples, input_dim]
          "model.layers.0.self_attn.q_proj.post" -> [num_samples, output_dim]
        Modules that never produced a successful sample have no entry.
    """
    print(f"[transport] Extracting two-sided activations from {len(calibration_data)} samples...")

    activations = {}
    pending = {}  # features of the forward pass currently in flight
    hooks = []

    def make_hook(layer_name):
        pre_key = f"{layer_name}.pre"
        post_key = f"{layer_name}.post"

        def hook_fn(module, input_tensor, output):
            # Forward hooks receive positional inputs as a tuple.
            pre = input_tensor[0] if isinstance(input_tensor, tuple) else input_tensor
            post = output[0] if isinstance(output, tuple) else output
            # Mean pool over sequence length, move to CPU to keep GPU memory flat.
            pending.setdefault(pre_key, []).append(pre.detach().float().mean(dim=1).cpu())
            pending.setdefault(post_key, []).append(post.detach().float().mean(dim=1).cpu())

        return hook_fn

    # Register hooks on attention and MLP projection modules only
    for name, module in model.named_modules():
        if name.rpartition(".")[2] not in ALL_PROJECTIONS:
            continue
        # Skip vision encoder modules
        if name.startswith(("visual", "merger")):
            continue
        hooks.append(module.register_forward_hook(make_hook(name)))

    model.eval()
    try:
        with torch.no_grad():
            for i, tokens in enumerate(calibration_data):
                inputs = {k: v.to(device) for k, v in tokens.items()}
                pending.clear()
                try:
                    model(**inputs)
                except Exception as e:
                    # Best effort: drop the sample, including partial hook output.
                    print(f" Warning: Sample {i} failed: {e}")
                    continue

                # Commit only after the whole forward pass succeeded.
                for key, feats in pending.items():
                    activations.setdefault(key, []).extend(feats)

                if (i + 1) % 200 == 0:
                    print(f" Processed {i + 1}/{len(calibration_data)} samples")
    finally:
        for h in hooks:
            h.remove()

    # Stack activations: [num_samples, hidden_dim]
    for key in activations:
        activations[key] = torch.cat(activations[key], dim=0)

    n_modules = len(activations) // 2  # pre + post per module
    print(f"[transport] Extracted activations from {n_modules} projection modules (two-sided)")

    return activations
238
+
239
+
240
+ # ============================================================================
241
+ # LOG-DOMAIN STREAMING SINKHORN (Paper Appendix A.3.4)
242
+ # ============================================================================
243
+
244
+ def _log_sinkhorn_streaming(
245
+ cost_matrix: np.ndarray,
246
+ reg: float = 0.1,
247
+ max_iter: int = 200,
248
+ chunk_size: int = 512,
249
+ ) -> np.ndarray:
250
+ """
251
+ Log-domain streaming Sinkhorn solver.
252
+
253
+ Paper Appendix A.3.4:
254
+ "We use a memory-efficient streaming Sinkhorn solver with fixed 200 iterations"
255
+
256
+ Log-domain means we work with log(K) = -C/Ξ΅ instead of K = exp(-C/Ξ΅).
257
+ This prevents numerical overflow/underflow with large matrices.
258
+
259
+ Streaming means we process the cost matrix in chunks instead of
260
+ materializing the full kernel matrix K in memory.
261
+
262
+ Args:
263
+ cost_matrix: [n, m] cost matrix (correlation distance)
264
+ reg: Entropic regularisation Ξ΅ (paper default 0.1)
265
+ max_iter: Number of Sinkhorn iterations (paper: 200 inner, 1000 outer)
266
+ chunk_size: Process this many rows/cols at a time for memory efficiency
267
+
268
+ Returns:
269
+ [n, m] transport plan matrix
270
+ """
271
+ n, m = cost_matrix.shape
272
+
273
+ # Log-domain: work with log potentials instead of scaling vectors
274
+ # This is numerically stable β€” no exp() overflow
275
+ log_u = np.zeros(n) # Log of row scaling vector
276
+ log_v = np.zeros(m) # Log of column scaling vector
277
+
278
+ # Uniform marginals (both sides sum to 1)
279
+ log_a = np.full(n, -np.log(n)) # log(1/n)
280
+ log_b = np.full(m, -np.log(m)) # log(1/m)
281
+
282
+ # Log kernel: log(K_ij) = -C_ij / Ξ΅
283
+ log_K = -cost_matrix / reg
284
+
285
+ for iteration in range(max_iter):
286
+ # --- Row update (streaming over chunks of columns) ---
287
+ # log_u = log_a - logsumexp(log_K + log_v, axis=1)
288
+ log_sum = np.full(n, -np.inf)
289
+ for j_start in range(0, m, chunk_size):
290
+ j_end = min(j_start + chunk_size, m)
291
+ chunk = log_K[:, j_start:j_end] + log_v[j_start:j_end]
292
+ chunk_max = np.maximum(log_sum, chunk.max(axis=1))
293
+ log_sum = chunk_max + np.log(
294
+ np.exp(log_sum - chunk_max) +
295
+ np.exp(chunk - chunk_max[:, None]).sum(axis=1)
296
+ )
297
+ log_u = log_a - log_sum
298
+
299
+ # --- Column update (streaming over chunks of rows) ---
300
+ # log_v = log_b - logsumexp(log_K.T + log_u, axis=1)
301
+ log_sum = np.full(m, -np.inf)
302
+ for i_start in range(0, n, chunk_size):
303
+ i_end = min(i_start + chunk_size, n)
304
+ chunk = log_K[i_start:i_end, :].T + log_u[i_start:i_end]
305
+ # chunk shape: [m, chunk_rows]
306
+ chunk_max = np.maximum(log_sum, chunk.max(axis=1))
307
+ log_sum = chunk_max + np.log(
308
+ np.exp(log_sum - chunk_max) +
309
+ np.exp(chunk - chunk_max[:, None]).sum(axis=1)
310
+ )
311
+ log_v = log_b - log_sum
312
+
313
+ # Recover transport plan: T_ij = exp(log_u_i + log_K_ij + log_v_j)
314
+ # Do this in chunks too to avoid materializing full matrix at once
315
+ T = np.zeros((n, m), dtype=np.float32)
316
+ for j_start in range(0, m, chunk_size):
317
+ j_end = min(j_start + chunk_size, m)
318
+ T[:, j_start:j_end] = np.exp(
319
+ log_u[:, None] + log_K[:, j_start:j_end] + log_v[j_start:j_end]
320
+ )
321
+
322
+ return T
323
+
324
+
325
+ def _sinkhorn_basic(
326
+ cost_matrix: np.ndarray,
327
+ reg: float = 0.1,
328
+ max_iter: int = 200,
329
+ ) -> np.ndarray:
330
+ """
331
+ Basic (non-streaming) Sinkhorn for small matrices (e.g., layer-level P).
332
+
333
+ Used for the layer-level transport plan where matrices are small
334
+ (e.g., 36Γ—32 for Qwen3β†’Llama layer mapping).
335
+ """
336
+ n, m = cost_matrix.shape
337
+ K = np.exp(-cost_matrix / reg)
338
+
339
+ u = np.ones(n) / n
340
+ v = np.ones(m) / m
341
+
342
+ for _ in range(max_iter):
343
+ u = (1.0 / n) / (K @ v + 1e-10)
344
+ v = (1.0 / m) / (K.T @ u + 1e-10)
345
+
346
+ T = np.diag(u) @ K @ np.diag(v)
347
+ return T
348
+
349
+
350
+ # ============================================================================
351
+ # TWO-SIDED TRANSPORT (Paper Section 4.2, Equations 8, 13)
352
+ # ============================================================================
353
+
354
+ def _correlation_distance(X: np.ndarray, Y: np.ndarray) -> np.ndarray:
355
+ """
356
+ Compute correlation distance matrix between two sets of activation vectors.
357
+
358
+ cost[i, j] = 1 - pearson_correlation(X[:, i], Y[:, j])
359
+
360
+ X: [num_samples, dim_x] β€” activations from source
361
+ Y: [num_samples, dim_y] β€” activations from target
362
+ Returns: [dim_x, dim_y] cost matrix
363
+ """
364
+ # Standardise each neuron's activations across samples
365
+ X_norm = (X - X.mean(axis=0)) / (X.std(axis=0) + 1e-8)
366
+ Y_norm = (Y - Y.mean(axis=0)) / (Y.std(axis=0) + 1e-8)
367
+
368
+ # Pearson correlation between each pair of neurons
369
+ corr = X_norm.T @ Y_norm / X.shape[0] # [dim_x, dim_y]
370
+
371
+ # Correlation distance
372
+ cost = 1.0 - corr
373
+ return cost.astype(np.float32)
374
+
375
+
376
+ def _get_layer_index(module_name: str) -> Optional[int]:
377
+ """Extract layer index from a module name like 'model.layers.5.self_attn.q_proj'."""
378
+ parts = module_name.split(".")
379
+ for i, part in enumerate(parts):
380
+ if part == "layers" and i + 1 < len(parts):
381
+ try:
382
+ return int(parts[i + 1])
383
+ except ValueError:
384
+ pass
385
+ return None
386
+
387
+
388
+ def _get_module_type(module_name: str) -> str:
389
+ """Extract module type from name like 'model.layers.5.self_attn.q_proj' β†’ 'q_proj'."""
390
+ return module_name.split(".")[-1]
391
+
392
+
393
def _group_activations_by_layer(
    activations: dict,
    side: str = "pre",
) -> dict:
    """
    Bucket one side ("pre" or "post") of the activation dict by layer.

    Keys not ending in ``.<side>`` and keys without a parsable layer index
    are ignored. Tensors are converted with ``.numpy()``.

    Returns: {layer_idx: {module_type: activation_array}}
    """
    tail = f".{side}"
    by_layer = {}
    for name, values in activations.items():
        if name.endswith(tail):
            # Strip the .pre/.post suffix to recover the module name.
            base = name[: -len(tail)]
            layer = _get_layer_index(base)
            kind = _get_module_type(base)
            if layer is None:
                continue
            by_layer.setdefault(layer, {})[kind] = values.numpy()
    return by_layer
416
+
417
+
418
def compute_transport_plans(
    source_activations: dict,
    target_activations: dict,
    cfg: MergeConfig,
) -> dict:
    """
    Compute two-sided optimal transport plans between source and target.

    Paper Section 4.2 -- two-sided transport:
      1. For each (source_layer, target_layer) pair and each projection type
         present in both layers:
           - Q_in from pre-activation features (Eq 8 applied to inputs)
           - Q_out from post-activation features (Eq 8 applied to outputs)
      2. Layer-level costs (mean neuron-level cost, averaged over projection
         types) -> P_pre and P_post
      3. P_eff[ℓ,m] = sqrt(P_pre[ℓ,m] · P_post[ℓ,m])  (Eq 13), row-normalised
      4. Keep the 3 strongest source layers per target layer, re-normalise.

    Layer pairs that share no projection type are assigned the maximum
    correlation distance (2.0) so the layer-level solver treats them as the
    worst match rather than as a perfect one.

    Args:
        source_activations: output of ``extract_activations`` for the source.
        target_activations: output of ``extract_activations`` for the target.
        cfg: reads ``sinkhorn_reg``, ``sinkhorn_inner_iter``,
            ``sinkhorn_layer_reg`` and ``sinkhorn_outer_iter``.

    Returns:
        Dict with:
          'P_eff': [n_target_layers, n_source_layers] sparsified layer plan
          'P_eff_dense': the same plan before sparsification
          'Q_in': {(src_layer, tgt_layer, module_type): [src_in_dim, tgt_in_dim]}
          'Q_out': {(src_layer, tgt_layer, module_type): [src_out_dim, tgt_out_dim]}
          'source_layers' / 'target_layers': sorted layer indices
          'layer_costs_pre' / 'layer_costs_post': [n_target, n_source] costs

    Raises:
        ValueError: if either model contributed no layer activations.
    """
    print("[transport] Computing two-sided transport plans (paper Section 4.2)...")

    # Group activations by layer
    source_pre = _group_activations_by_layer(source_activations, "pre")
    source_post = _group_activations_by_layer(source_activations, "post")
    target_pre = _group_activations_by_layer(target_activations, "pre")
    target_post = _group_activations_by_layer(target_activations, "post")

    source_layers = sorted(source_pre.keys())
    target_layers = sorted(target_pre.keys())

    n_source = len(source_layers)
    n_target = len(target_layers)

    print(f" Source layers: {n_source}, Target layers: {n_target}")

    if n_source == 0 or n_target == 0:
        raise ValueError(
            "Cannot compute transport plans: no per-layer activations found "
            f"(source layers: {n_source}, target layers: {n_target})"
        )

    def solve_neuron_plan(src_feats, tgt_feats):
        """Return (neuron-level plan, mean cost) for one projection pair."""
        cost = _correlation_distance(src_feats, tgt_feats)
        # Use streaming Sinkhorn for large matrices, basic for small
        solver = _log_sinkhorn_streaming if max(cost.shape) > 1024 else _sinkhorn_basic
        plan = solver(cost, reg=cfg.sinkhorn_reg, max_iter=cfg.sinkhorn_inner_iter)
        return plan, cost.mean()

    # --- Step 1: Compute Q_in and Q_out for each layer pair ---
    Q_in_matrices = {}
    Q_out_matrices = {}
    # Correlation distance lies in [0, 2]; start every pair at the worst value
    # so pairs that are never compared do not look like perfect matches.
    unmatched_cost = 2.0
    layer_costs_pre = np.full((n_target, n_source), unmatched_cost)
    layer_costs_post = np.full((n_target, n_source), unmatched_cost)

    for ti, tl in enumerate(target_layers):
        for si, sl in enumerate(source_layers):
            # Projection types that exist in both layers; sorted so the
            # processing order is reproducible between runs.
            common_modules = sorted(set(target_pre[tl]) & set(source_pre[sl]))
            if not common_modules:
                continue

            pre_costs = []
            post_costs = []

            for mod_type in common_modules:
                # --- Q_in: pre-activation (input-side) transport ---
                Q, mean_cost = solve_neuron_plan(
                    source_pre[sl][mod_type], target_pre[tl][mod_type]
                )
                Q_in_matrices[(sl, tl, mod_type)] = Q
                pre_costs.append(mean_cost)

                # --- Q_out: post-activation (output-side) transport ---
                S_post = source_post.get(sl, {}).get(mod_type)
                T_post = target_post.get(tl, {}).get(mod_type)
                if S_post is not None and T_post is not None:
                    Q, mean_cost = solve_neuron_plan(S_post, T_post)
                    Q_out_matrices[(sl, tl, mod_type)] = Q
                    post_costs.append(mean_cost)

            # Average cost across projection types for this layer pair
            if pre_costs:
                layer_costs_pre[ti, si] = np.mean(pre_costs)
            if post_costs:
                layer_costs_post[ti, si] = np.mean(post_costs)

        if (ti + 1) % 6 == 0:
            print(f" Layer pairs computed: {ti + 1}/{n_target} target layers done")

    # --- Step 2: Layer-level transport plans P_pre and P_post ---
    print("[transport] Computing layer-level transport plans (P_pre, P_post)...")

    P_pre = _sinkhorn_basic(
        layer_costs_pre,
        reg=cfg.sinkhorn_layer_reg,
        max_iter=cfg.sinkhorn_outer_iter,
    )

    P_post = _sinkhorn_basic(
        layer_costs_post,
        reg=cfg.sinkhorn_layer_reg,
        max_iter=cfg.sinkhorn_outer_iter,
    )

    # --- Step 3: P_eff = sqrt(P_pre · P_post) -- Equation 13 ---
    P_eff = np.sqrt(P_pre * P_post + 1e-10)

    # Normalise P_eff so each target layer's row sums to 1
    row_sums = P_eff.sum(axis=1, keepdims=True)
    P_eff = P_eff / (row_sums + 1e-10)

    print(f"[transport] P_eff shape: {P_eff.shape}")
    print(f" P_eff range: [{P_eff.min():.4f}, {P_eff.max():.4f}]")

    # --- Step 4: Transport sparsification (Appendix A.1) ---
    # "top-k selection strategies at both neuron and transport matrix levels"
    # Keep only the top-k strongest source layers per target layer
    k_layers = min(3, n_source)  # Top-3 source layers per target layer
    P_sparse = np.zeros_like(P_eff)
    for i, row in enumerate(P_eff):
        top_k_idx = np.argsort(row)[-k_layers:]
        P_sparse[i, top_k_idx] = row[top_k_idx]
    # Re-normalise
    row_sums = P_sparse.sum(axis=1, keepdims=True)
    P_sparse = P_sparse / (row_sums + 1e-10)

    print(f"[transport] Sparsified P: keeping top-{k_layers} source layers per target")

    return {
        "P_eff": P_sparse,
        "P_eff_dense": P_eff,  # Keep dense version for debugging
        "Q_in": Q_in_matrices,
        "Q_out": Q_out_matrices,
        "source_layers": source_layers,
        "target_layers": target_layers,
        "layer_costs_pre": layer_costs_pre,
        "layer_costs_post": layer_costs_post,
    }
583
+
584
+
585
+ # ============================================================================
586
+ # TOP-K MASKED FUSION (Paper Eq 14, Appendix A.5: k=128)
587
+ # ============================================================================
588
+
589
def compute_neuron_importance(
    activations: dict,
    layer_idx: int,
) -> dict:
    """
    Score each output neuron of one layer for top-k selection.

    Paper Appendix A.5: "choosing the neurons with the highest mean
    activation magnitudes across the calibration set"

    Only ``.post`` entries whose module belongs to ``layer_idx`` are used.

    Returns: {module_type: importance_scores [hidden_dim]}
    """
    suffix = ".post"
    scores = {}
    for name, feats in activations.items():
        if not name.endswith(suffix):
            continue
        base = name[: -len(suffix)]
        owner = _get_layer_index(base)
        kind = _get_module_type(base)
        if owner != layer_idx:
            continue
        # Mean absolute activation over the calibration samples (dim 0).
        scores[kind] = feats.abs().mean(dim=0).numpy()
    return scores
612
+
613
+
614
def compute_top_k_mask(
    importance_scores: np.ndarray,
    k: int = 128,
) -> np.ndarray:
    """
    Create binary mask for top-k most important neurons.

    Paper Appendix A.5: "we set the default number of neurons to k = 128"

    Args:
        importance_scores: 1-D array of per-neuron scores.
        k: number of neurons to select. ``k >= len(scores)`` selects all;
            ``k <= 0`` selects none.

    Returns: boolean mask [hidden_dim] where True = selected for fusion
    """
    n_neurons = len(importance_scores)
    if k >= n_neurons:
        return np.ones(n_neurons, dtype=bool)

    mask = np.zeros(n_neurons, dtype=bool)
    if k <= 0:
        # Guard: a slice of [-0:] would otherwise select every neuron.
        return mask

    mask[np.argsort(importance_scores)[-k:]] = True
    return mask
632
+
633
+
634
def fuse_weights(
    source_model: AutoModelForCausalLM,
    target_model: AutoModelForCausalLM,
    transport_plans: dict,
    source_config: ModelConfig,
    cfg: MergeConfig,
    target_activations: dict = None,
) -> AutoModelForCausalLM:
    """
    Fuse source weights into target using two-sided transport + top-k mask.

    Paper Equation 14:
        W_fused = W_target + α · M^ℓ ⊙ (Σ_m P_eff[ℓ,m] · Q_out · W_source · Q_in^T - W_target)

    Where:
      - α is the fusion coefficient (``source_config.merge_alpha``)
      - M^ℓ is the binary top-k mask over output neurons (rows)
      - P_eff is the effective layer transport plan
      - Q_out and Q_in are the neuron-level transport matrices
      - The sum is over source layers m

    Shape convention: ``compute_transport_plans`` stores each Q as
    [source_dim, target_dim], so the transported weight is computed as
    ``Q_out.T @ W_source @ Q_in`` -> [target_out, target_in].

    NOTE(review): the Sinkhorn solvers in this file return joint plans with
    uniform marginals, so Q's entries sum to 1 overall and the transported
    weights are scaled down accordingly. Confirm against the paper whether Q
    should be row/column-normalised before it is applied.

    Args:
        source_model: model whose weights are transported.
        target_model: model that receives the fused weights (modified via
            ``load_state_dict``).
        transport_plans: result of ``compute_transport_plans``.
        source_config: reads ``name``, ``merge_alpha`` (and ``architecture``
            via the key lookup helper).
        cfg: reads ``top_k_neurons``, ``vision_skip_prefixes``,
            ``freeze_think_tokens`` and ``think_token_ids``.
        target_activations: activations used for the top-k mask; masking is
            skipped when this is None or ``cfg.top_k_neurons`` is not positive.

    Returns: Target model with fused weights
    """
    print(f"\n[transport] Fusing {source_config.name} -> target (two-sided + top-k={cfg.top_k_neurons})")
    alpha = source_config.merge_alpha
    print(f" Alpha: {alpha} (paper range: 0.05-0.15)")

    source_state = source_model.state_dict()
    target_state = target_model.state_dict()

    P_eff = transport_plans["P_eff"]
    Q_in = transport_plans["Q_in"]
    Q_out = transport_plans["Q_out"]
    source_layers = transport_plans["source_layers"]
    target_layers = transport_plans["target_layers"]

    use_mask = target_activations is not None and cfg.top_k_neurons > 0

    fused_count = 0
    skipped_count = 0
    masked_neurons = 0

    for ti, tl in enumerate(target_layers):
        # Get the transport weights for this target layer
        layer_transport = P_eff[ti]  # [n_source]

        # Find which source layers contribute significantly
        active_sources = [
            (sl, float(layer_transport[si]))
            for si, sl in enumerate(source_layers)
            if layer_transport[si] > 1e-6
        ]
        if not active_sources:
            continue

        # Importance depends only on the target layer, so compute it once
        # here instead of once per projection type.
        importance = compute_neuron_importance(target_activations, tl) if use_mask else {}

        # For each projection type in this target layer
        for mod_type in ALL_PROJECTIONS:
            target_key = _find_param_key(target_state, tl, mod_type, "weight")
            if target_key is None:
                continue

            target_w = target_state[target_key].float()

            # Accumulate Σ_m P_eff[ℓ,m] · (transported source weight)
            transported = torch.zeros_like(target_w)
            total_weight = 0.0

            for sl, p_weight in active_sources:
                source_key = _find_source_param_key(
                    source_state, sl, mod_type, "weight", source_config
                )
                if source_key is None:
                    continue

                source_w = source_state[source_key].float()

                # Q matrices for this layer pair
                q_in = Q_in.get((sl, tl, mod_type))
                q_out = Q_out.get((sl, tl, mod_type))

                if q_in is not None and q_out is not None:
                    q_in_t = torch.from_numpy(q_in).float()
                    q_out_t = torch.from_numpy(q_out).float()

                    try:
                        # q_out: [source_out, target_out], W: [source_out, source_in],
                        # q_in: [source_in, target_in]
                        # Result: [target_out, target_in]
                        transported_w = q_out_t.T @ source_w.to("cpu") @ q_in_t
                        transported += p_weight * transported_w.to(target_w.device)
                        total_weight += p_weight
                    except RuntimeError:
                        # Dimension mismatch -- skip this pair
                        skipped_count += 1
                        continue
                else:
                    # No Q matrices -- direct mapping if shapes match
                    if source_w.shape == target_w.shape:
                        transported += p_weight * source_w.to(target_w.device)
                        total_weight += p_weight

            if total_weight < 1e-6:
                skipped_count += 1
                continue

            # Normalise by total transport weight
            transported = transported / total_weight

            # --- Apply top-k mask (Equation 14) ---
            # M^ℓ ⊙ (transported - W_target)
            delta = transported - target_w

            if mod_type in importance:
                # Mask on output dimension (rows of weight matrix)
                mask = compute_top_k_mask(importance[mod_type], k=cfg.top_k_neurons)
                mask_tensor = torch.from_numpy(mask).to(target_w.device)

                # Apply mask: only fuse top-k neurons
                if delta.dim() == 2:
                    # Weight matrix: mask rows (output neurons)
                    mask_2d = mask_tensor.unsqueeze(1).expand_as(delta)
                    delta = delta * mask_2d.float()
                    masked_neurons += int(mask.sum())
                elif delta.dim() == 1:
                    # Bias: mask directly
                    delta = delta * mask_tensor.float()
                    masked_neurons += int(mask.sum())

            # Final fusion: W_target + α · masked_delta
            fused_w = target_w + alpha * delta
            target_state[target_key] = fused_w.to(target_state[target_key].dtype)
            fused_count += 1

    # --- Vision encoder protection ---
    # Restore any vision params that might have been touched. The model itself
    # is unchanged until load_state_dict below, so this still yields originals.
    original_state = target_model.state_dict()
    for key in target_state:
        if any(key.startswith(pfx) for pfx in cfg.vision_skip_prefixes):
            target_state[key] = original_state[key]

    # --- Thinking mode protection ---
    if cfg.freeze_think_tokens:
        embed_key = "model.embed_tokens.weight"
        if embed_key in target_state and embed_key in original_state:
            for token_id in cfg.think_token_ids:
                if token_id < target_state[embed_key].shape[0]:
                    target_state[embed_key][token_id] = original_state[embed_key][token_id]
                    print(f" Protected think token {token_id}")

    # Load fused weights
    target_model.load_state_dict(target_state)
    print(f"[transport] Fused {fused_count} params, skipped {skipped_count}")
    print(f" Top-k masked neurons fused: {masked_neurons}")

    return target_model
792
+
793
+
794
+ # ============================================================================
795
+ # HELPER: Find parameter keys in state dicts
796
+ # ============================================================================
797
+
798
+ def _find_param_key(state_dict: dict, layer_idx: int, module_type: str, param_type: str = "weight") -> Optional[str]:
799
+ """Find the full parameter key for a given layer, module type, and param type."""
800
+ # Common patterns for transformer models
801
+ patterns = [
802
+ f"model.layers.{layer_idx}.self_attn.{module_type}.{param_type}",
803
+ f"model.layers.{layer_idx}.mlp.{module_type}.{param_type}",
804
+ f"transformer.h.{layer_idx}.attn.{module_type}.{param_type}",
805
+ f"transformer.h.{layer_idx}.mlp.{module_type}.{param_type}",
806
+ ]
807
+ for pattern in patterns:
808
+ if pattern in state_dict:
809
+ return pattern
810
+ return None
811
+
812
+
813
def _find_source_param_key(
    state_dict: dict,
    source_layer: int,
    module_type: str,
    param_type: str,
    source_config: ModelConfig,
) -> Optional[str]:
    """Locate a source-model parameter key, allowing for alternate layouts.

    The standard layouts are tried first. Only when they fail and
    ``source_config.architecture`` is "hybrid_ssm" are the
    ``model.layers.N.attn`` / ``model.layers.N.feed_forward`` layouts tried.
    Returns None when nothing matches.
    """
    standard = _find_param_key(state_dict, source_layer, module_type, param_type)
    if standard:
        return standard

    if source_config.architecture != "hybrid_ssm":
        return None

    # Falcon uses different naming
    for block in ("attn", "feed_forward"):
        candidate = f"model.layers.{source_layer}.{block}.{module_type}.{param_type}"
        if candidate in state_dict:
            return candidate
    return None
838
+
839
+
840
+ def _should_skip(key: str, source_config: ModelConfig) -> bool:
841
+ """Determine if a parameter should be skipped during merge."""
842
+ if source_config.skip_embeddings and ("embed_tokens" in key or "lm_head" in key):
843
+ return True
844
+ if "drop_mtp_heads" in source_config.special_handling and "mtp_head" in key:
845
+ return True
846
+ if "drop_mamba_state_params" in source_config.special_handling:
847
+ mamba_keys = ["mamba", "A_log", "dt_proj", ".D"]
848
+ if any(mk in key for mk in mamba_keys):
849
+ return True
850
+ if "drop_qkv_bias" in source_config.special_handling and ".bias" in key:
851
+ if any(proj in key for proj in ["q_proj", "k_proj", "v_proj"]):
852
+ return True
853
+ return False
hugging/td_lang/engine/validate.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Post-Merge Validation — run after EVERY merge step.
3
+
4
+ Tests:
5
+ 1. Canary recall (did knowledge transfer?)
6
+ 2. Perplexity check (did we break the model?)
7
+ 3. Thinking mode (do <think> tags still work?)
8
+ 4. Quick reasoning test (can it still think?)
9
+
10
+ Kill criteria: >10% performance drop on any test → abort merge.
11
+ Findings: #11, #22, #25
12
+ """
13
+
14
+ import torch
15
+ import math
16
+ from transformers import AutoModelForCausalLM, AutoTokenizer
17
+
18
+ from .canary import test_all_canaries
19
+ from .config import MergeConfig
20
+
21
+
22
def validate_merged_model(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    merged_sources: list[str],
    cfg: MergeConfig,
    baseline_perplexity: float = None,
) -> dict:
    """
    Run full validation suite on a merged model.

    Runs four checks -- canary recall, perplexity, thinking mode, quick
    reasoning -- prints a summary, and reports an overall verdict that is
    True only when every check passed.

    Args:
        model: The merged model to validate
        tokenizer: The tokenizer
        merged_sources: List of source models merged so far
        cfg: Merge configuration; reads ``canary_pass_threshold`` and
            ``perplexity_threshold``
        baseline_perplexity: Perplexity of the target model before merging.
            When None the perplexity check always passes.

    Returns:
        Dict with keys "canary", "perplexity", "thinking_mode", "reasoning"
        (each a dict containing at least "ok") and "overall" (bool).
    """
    print("\n" + "=" * 60)
    print(f"VALIDATION — After merging: {', '.join(merged_sources)}")
    print("=" * 60)

    results = {
        "canary": None,
        "perplexity": None,
        "thinking_mode": None,
        "reasoning": None,
        "overall": False,
    }

    # --- Test 1: Canary recall ---
    canary_results = test_all_canaries(model, tokenizer, merged_sources)
    passed_canaries = sum(1 for v in canary_results.values() if v)
    total_canaries = len(canary_results)
    results["canary"] = {
        "passed": passed_canaries,
        "total": total_canaries,
        "ok": passed_canaries >= cfg.canary_pass_threshold,
        "details": canary_results,
    }

    # --- Test 2: Perplexity ---
    perplexity = compute_perplexity(model, tokenizer)
    ppl_ok = True
    if baseline_perplexity is not None:
        # Pass while merged/baseline stays below the configured ratio.
        ratio = perplexity / baseline_perplexity
        ppl_ok = ratio < cfg.perplexity_threshold
        print(f"\n[validate] Perplexity: {perplexity:.2f} (baseline: {baseline_perplexity:.2f}, ratio: {ratio:.2f})")
        if not ppl_ok:
            print(f"[validate] ⚠ Perplexity ratio {ratio:.2f} exceeds threshold {cfg.perplexity_threshold}")
    else:
        print(f"\n[validate] Perplexity: {perplexity:.2f} (no baseline to compare)")
    results["perplexity"] = {"value": perplexity, "ok": ppl_ok}

    # --- Test 3: Thinking mode ---
    think_ok = test_thinking_mode(model, tokenizer)
    results["thinking_mode"] = {"ok": think_ok}

    # --- Test 4: Quick reasoning ---
    reason_ok = test_reasoning(model, tokenizer)
    results["reasoning"] = {"ok": reason_ok}

    # --- Overall verdict ---
    all_ok = (
        results["canary"]["ok"]
        and results["perplexity"]["ok"]
        and results["thinking_mode"]["ok"]
        and results["reasoning"]["ok"]
    )
    results["overall"] = all_ok

    # Summary
    print("\n" + "-" * 60)
    print("VALIDATION SUMMARY")
    print("-" * 60)
    print(f" Canary recall: {'✓' if results['canary']['ok'] else '✗'} ({passed_canaries}/{total_canaries})")
    print(f" Perplexity: {'✓' if ppl_ok else '✗'} ({perplexity:.2f})")
    print(f" Thinking mode: {'✓' if think_ok else '✗'}")
    print(f" Reasoning: {'✓' if reason_ok else '✗'}")
    print(f" OVERALL: {'✓ PASS' if all_ok else '✗ FAIL — consider aborting'}")
    print("-" * 60)

    return results
107
+
108
+
109
def compute_perplexity(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    test_texts: list[str] = None,
) -> float:
    """
    Compute perplexity on a small test set.

    Lower perplexity = model is more confident about predicting text.
    A big spike after merging means the model was damaged.

    Each text's loss is weighted by the number of tokens it actually
    predicts. Hugging Face causal-LM heads shift the labels internally, so
    ``outputs.loss`` is a mean over ``seq_len - 1`` positions; texts that
    tokenise to fewer than two tokens predict nothing and are skipped.

    Args:
        model: causal LM called as ``model(**inputs, labels=input_ids)``.
        tokenizer: tokenizer producing ``input_ids`` (truncated to 512 tokens).
        test_texts: texts to score; a built-in five-sentence set when None.

    Returns:
        exp(token-weighted mean loss).

    Raises:
        ValueError: if no text contributes any predicted token.
    """
    if test_texts is None:
        test_texts = [
            "The quick brown fox jumps over the lazy dog.",
            "In mathematics, a prime number is a natural number greater than 1.",
            "def fibonacci(n):\n if n <= 1:\n return n\n return fibonacci(n-1) + fibonacci(n-2)",
            "The theory of general relativity describes gravity as the curvature of spacetime.",
            "To solve 3x + 7 = 22, subtract 7 from both sides to get 3x = 15, then divide by 3.",
        ]

    model.eval()
    total_loss = 0.0
    total_tokens = 0

    for text in test_texts:
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        # Number of next-token predictions the mean loss was averaged over.
        predicted = inputs["input_ids"].shape[1] - 1
        if predicted < 1:
            continue

        with torch.no_grad():
            outputs = model(**inputs, labels=inputs["input_ids"])
        total_loss += outputs.loss.item() * predicted
        total_tokens += predicted

    if total_tokens == 0:
        raise ValueError("compute_perplexity needs at least one text with two or more tokens")

    avg_loss = total_loss / total_tokens
    return math.exp(avg_loss)
145
+
146
+
147
def test_thinking_mode(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
) -> bool:
    """
    Test if the model still uses <think> tags for reasoning.

    The thinking mode is Qwen3's special feature -- if it's gone,
    the merge damaged something critical.

    Generates up to 200 new tokens with sampling (temperature 0.7), decodes
    the full sequence with special tokens kept, and passes only when both
    "<think>" and "</think>" appear. Because generation is sampled, the
    result can differ between runs.

    Returns:
        True when both tags are present in the decoded output.
    """
    prompt = "Solve step by step: What is 15 × 13?"

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=200,
            temperature=0.7,
            do_sample=True,
        )

    # Keep special tokens so the think tags survive decoding.
    response = tokenizer.decode(outputs[0], skip_special_tokens=False)

    # Check for thinking tags
    has_think_open = "<think>" in response
    has_think_close = "</think>" in response
    passed = has_think_open and has_think_close

    print(f"\n[validate] Thinking mode test:")
    print(f" Prompt: {prompt}")
    print(f" Response: {response[:200]}...")
    print(f" <think>: {'✓ found' if has_think_open else '✗ missing'}")
    print(f" </think>: {'✓ found' if has_think_close else '✗ missing'}")
    print(f" Status: {'✓ PASS' if passed else '✗ FAIL'}")

    return passed
183
+
184
+
185
def test_reasoning(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
) -> bool:
    """
    Quick reasoning sanity check -- can the model still do basic math?

    This catches catastrophic failures where the merge produced gibberish.
    Decoding is greedy (``do_sample=False``), so no sampling temperature is
    passed.

    Returns:
        True when the decoded sequence contains the expected answer "15".
    """
    prompt = "What is 7 + 8?"
    expected_answer = "15"

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=50,
            do_sample=False,
        )

    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    passed = expected_answer in response

    print(f"\n[validate] Quick reasoning test:")
    print(f" Prompt: {prompt}")
    print(f" Expected: {expected_answer}")
    print(f" Got: {response}")
    print(f" Status: {'✓ PASS' if passed else '✗ FAIL'}")

    return passed
hugging/td_lang/errors.py CHANGED
@@ -88,6 +88,20 @@ COMMON_FIXES = {
88
  "fuse": 'Format: fuse ["model1", "model2"] into target [strategy equal]',
89
  "absorb": 'Format: absorb "model" into target [strength 0.5]',
90
  "schedule": 'Format: schedule "every 6h" { commands... } or schedule "at 02:00" { ... }',
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  }
92
 
93
 
 
88
  "fuse": 'Format: fuse ["model1", "model2"] into target [strategy equal]',
89
  "absorb": 'Format: absorb "model" into target [strength 0.5]',
90
  "schedule": 'Format: schedule "every 6h" { commands... } or schedule "at 02:00" { ... }',
91
+ "download": 'Format: download "dataset_name" as alias [split train]',
92
+ "log": 'Format: log "output.txt" (place before commands to capture output)',
93
+ "compare": 'Format: compare target vs "source_model" [questions 50] [-> output.json]',
94
+ "verify": 'Format: verify target on "dataset" [questions 100] [-> output.json]',
95
+ "vote": 'Format: vote target "question" [samples 5] [-> output.json]',
96
+ "prompt": 'Format: prompt target "Think step by step before answering."',
97
+ "distill": 'Format: distill target into "small_model" [steps 200] [-> output_dir]',
98
+ "rollback": "Format: rollback target (reverts to most recent snapshot)",
99
+ "curriculum": 'Format: curriculum target on "dataset" using grpo [levels 3] [steps 64]',
100
+ "star": 'Format: star target on "dataset" [rounds 3] [samples 8]',
101
+ "best_of": 'Format: best_of target on "dataset" [n 8] [steps 32]',
102
+ "exploit": 'Format: exploit target on "dataset" [samples 16] [steps 32] [-> output.jsonl]',
103
+ "arena": 'Format: arena target on "dataset" [rounds 5] [episodes 50] [steps 64] [curiosity 0.3] [-> log.json]',
104
+ "research_arena": 'Format: research_arena target topic "subject" [sources "pubmed"|"web"|"arxiv"] [rounds 5] [episodes 30] [-> log.json]',
105
  }
106
 
107
 
hugging/td_lang/examples/demo_arena.td ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # demo_arena.td — Real RL with memory, curiosity, and anti-lying
2
+ #
3
+ # This is ACTUAL reinforcement learning — the model explores challenges,
4
+ # gets immediate reward/punishment, remembers what worked, and trains
5
+ # on its experiences. Unlike best_of/star which just pick good examples,
6
+ # arena makes the model LEARN FROM CONSEQUENCES.
7
+ #
8
+ # Features:
9
+ # - Memory bank: remembers what worked across all rounds
10
+ # - Curiosity bonus: rewarded for trying NEW approaches
11
+ # - Lying punishment: -2.0 for confident wrong answers (worst offence)
12
+ # - Cross-check: creative solutions verified against standard approach
13
+ #
14
+ # The model won't "forget the button makes the door safe" because
15
+ # memory persists. And it won't lie because lying gets punished DOUBLE.
16
+
17
+ load "Qwen/Qwen3-8B" as base
18
+
19
+ # Run the arena: 3 rounds of 30 episodes each
20
+ # Curiosity weight 0.3 = moderate exploration bonus
21
+ arena base on "gsm8k" rounds 3 episodes 30 steps 32 curiosity 0.3 -> arena_log.json
22
+
23
+ # After arena training, evaluate the result
24
+ eval base -> arena_eval.json
25
+
26
+ # Save the improved model
27
+ snapshot base
28
+ commit base
hugging/td_lang/examples/demo_intelligence.td ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Demo: Phase 11 Intelligence — vote, prompt, distill, rollback
2
+ # Shows all 4 new commands + the upgraded mega-diagnose
3
+
4
+ load "Qwen/Qwen3-VL-8B-Instruct" as base
5
+
6
+ # Attach a chain-of-thought prompt (makes it think step by step)
7
+ prompt base "Think step by step before answering. Show your reasoning."
8
+
9
+ # Mega diagnose: self-diagnosis + domain profiling + layer speed
10
+ diagnose base -> diagnosis_report.json
11
+
12
+ # Merge in reasoning
13
+ merge "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B" into base using transport strength 0.5
14
+
15
+ # Use majority voting on a hard question
16
+ vote base "What is 847 * 23? Show your work." samples 5 -> vote_result.json
17
+
18
+ # Snapshot before training (so rollback works)
19
+ snapshot base
20
+
21
+ # Train on weaknesses found by diagnose
22
+ train base on "gsm8k" using grpo steps 64
23
+
24
+ # Eval to check if training helped
25
+ eval base -> eval_after.json
26
+
27
+ # If training made things worse, undo it
28
+ if eval_passed base {
29
+ commit base
30
+ } else {
31
+ rollback base
32
+ }
33
+
34
+ # Create a fast student model for easy questions
35
+ distill base into "Qwen/Qwen3-1.7B" steps 100 -> student_model/
hugging/td_lang/examples/demo_research_arena.td ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # demo_research_arena.td — Real RL on ANY topic using real-world sources
2
+ #
3
+ # This is the research gauntlet. The model gets thrown into a maze
4
+ # built from REAL papers and knowledge. It has to navigate perfectly.
5
+ #
6
+ # How it works:
7
+ # 1. Pulls real papers about your topic (PubMed, arXiv, web, or local files)
8
+ # 2. Extracts verifiable facts from those papers
9
+ # 3. Builds increasingly hard questions from the real knowledge
10
+ # 4. Model must answer correctly — EVERY claim checked against sources
11
+ # 5. Difficulty ESCALATES each round (stricter checking, harder questions)
12
+ # 6. Memory persists — model remembers what it learned
13
+ # 7. Lying = double punishment, curiosity = bonus
14
+ #
15
+ # The maze shrinks each round:
16
+ # Round 1: Easy questions, 30% strictness, full path width
17
+ # Round 2: Medium questions, 55% strictness, 75% path width
18
+ # Round 3: Hard questions, 80% strictness, 50% path width
19
+ # ...and so on. Miss a single fact = punishment.
20
+
21
+ load "Qwen/Qwen3-8B" as base
22
+
23
+ # Example 1: Medical research (uses PubMed for real papers)
24
+ research_arena base topic "cancer immunotherapy mechanisms" sources "pubmed" rounds 4 episodes 25 steps 48 curiosity 0.3 difficulty_scale 0.25 -> research_log.json
25
+
26
+ # After the gauntlet, see how the model performs
27
+ eval base -> post_research_eval.json
28
+ snapshot base
29
+ commit base
hugging/td_lang/examples/demo_rl.td ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Demo: Phase 12 RL & Fine-Tuning — curriculum, star, best_of, exploit
2
+ # Shows all 4 new training methods + reward_contract wiring
3
+
4
+ load "Qwen/Qwen3-VL-8B-Instruct" as base
5
+
6
+ # Define what counts as "correct" (these verifiers wire into GRPO training)
7
+ reward_contract {
8
+ verifiers = [code_compiles, math_correct, no_hallucination]
9
+ min_reward = 0.3
10
+ }
11
+
12
+ # Step 1: Curriculum training — start easy, get harder
13
+ curriculum base on "gsm8k" using grpo levels 3 steps 64
14
+
15
+ # Step 2: STaR — learn from own correct reasoning chains
16
+ star base on "gsm8k" rounds 3 samples 8
17
+
18
+ # Step 3: Best-of-N — generate 8 answers per question, train on the best
19
+ best_of base on "openai/humaneval" n 8 steps 32
20
+
21
+ # Step 4: EXPLOIT — controlled reward hacking
22
+ # Generate 16 diverse solutions per problem, keep ALL correct ones
23
+ # Even ugly shortcuts — if the answer is right, the method is valid
24
+ exploit base on "gsm8k" samples 16 steps 32 -> exploit_results.jsonl
25
+
26
+ # Verify the model actually got smarter
27
+ eval base -> eval_after_rl.json
28
+
29
+ # Save if good
30
+ snapshot base
31
+ commit base
hugging/td_lang/examples/demo_toolbox.td ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Demo: Phase 10 Toolbox — download, log, compare, verify
2
+ # Shows all 4 new commands working together
3
+
4
+ log "toolbox_run.txt"
5
+
6
+ load "Qwen/Qwen3-VL-8B-Instruct" as base
7
+
8
+ # Download a dataset for verification
9
+ download "gsm8k" as math_data
10
+ download "openai/humaneval" as code_data split test
11
+
12
+ # Merge in reasoning ability
13
+ merge "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B" into base using transport strength 0.5
14
+
15
+ # Compare: does the merged model remember what DeepSeek knew?
16
+ compare base vs "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B" questions 30 -> compare_results.json
17
+
18
+ # Verify: are the answers actually correct?
19
+ verify base on "gsm8k" questions 50 -> verify_math.json
20
+ verify base on "openai/humaneval" questions 25 -> verify_code.json
21
+
22
+ # Eval and commit if good
23
+ eval base -> eval_report.json
24
+ commit base
hugging/td_lang/grammar.py CHANGED
@@ -15,6 +15,7 @@ from .ast_nodes import (
15
  DataContractBlock,
16
  DebateCmd,
17
  DiagnoseCmd,
 
18
  EditCmd,
19
  EvalCmd,
20
  FuseCmd,
@@ -26,13 +27,26 @@ from .ast_nodes import (
26
  MergeCmd,
27
  NotifyCmd,
28
  OnErrorBlock,
 
29
  PruneCmd,
30
  RepeatBlock,
31
  ReportCmd,
32
  ResetCmd,
33
  RewardContractBlock,
 
 
 
 
 
 
 
34
  SaveCmd,
35
  ScheduleCmd,
 
 
 
 
 
36
  SetupBlock,
37
  SnapshotCmd,
38
  SynthCmd,
@@ -80,6 +94,20 @@ TD_GRAMMAR = r"""
80
  | setup_block
81
  | on_error_block
82
  | schedule_cmd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
  // ======================== PHASE 1 COMMANDS ========================
85
 
@@ -153,7 +181,11 @@ TD_GRAMMAR = r"""
153
  | fork_cmd | reset_cmd | prune_cmd | edit_cmd
154
  | fuse_cmd | absorb_cmd | snapshot_cmd | report_cmd
155
  | notify_cmd | save_cmd
156
- | repeat_block_cmd | if_block_cmd | schedule_cmd) _NL*
 
 
 
 
157
 
158
  // ======================== PHASE 6 β€” EASY MERGE COMMANDS ========================
159
 
@@ -233,6 +265,87 @@ TD_GRAMMAR = r"""
233
  // schedule "after 30m" { commands... }
234
  schedule_cmd: "schedule" string "{" _NL* body_cmd+ _NL* "}"
235
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
  // ======================== SHARED RULES ========================
237
 
238
  // List of names: [name1, name2, name3]
@@ -468,6 +581,251 @@ class TDTransformer(Transformer):
468
  def schedule_cmd(self, timing: str, *body_cmds) -> ScheduleCmd:
469
  return ScheduleCmd(timing=timing, body=list(body_cmds))
470
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
471
  # --- Phase 6: Easy Merge Commands ---
472
 
473
  def fuse_cmd(self, sources: list[str], target: str, *opts) -> FuseCmd:
@@ -688,6 +1046,8 @@ class TDTransformer(Transformer):
688
  program.setup = item
689
  elif isinstance(item, OnErrorBlock):
690
  program.on_error = item
 
 
691
  else:
692
  program.commands.append(item)
693
  return program
 
15
  DataContractBlock,
16
  DebateCmd,
17
  DiagnoseCmd,
18
+ DistillCmd,
19
  EditCmd,
20
  EvalCmd,
21
  FuseCmd,
 
27
  MergeCmd,
28
  NotifyCmd,
29
  OnErrorBlock,
30
+ PromptBlock,
31
  PruneCmd,
32
  RepeatBlock,
33
  ReportCmd,
34
  ResetCmd,
35
  RewardContractBlock,
36
+ RollbackCmd,
37
+ CurriculumCmd,
38
+ StarCmd,
39
+ BestOfCmd,
40
+ ExploitCmd,
41
+ ArenaCmd,
42
+ ResearchArenaCmd,
43
  SaveCmd,
44
  ScheduleCmd,
45
+ DownloadCmd,
46
+ LogBlock,
47
+ CompareCmd,
48
+ VerifyCmd,
49
+ VoteCmd,
50
  SetupBlock,
51
  SnapshotCmd,
52
  SynthCmd,
 
94
  | setup_block
95
  | on_error_block
96
  | schedule_cmd
97
+ | download_cmd
98
+ | log_block
99
+ | compare_cmd
100
+ | verify_cmd
101
+ | vote_cmd
102
+ | prompt_cmd
103
+ | distill_cmd
104
+ | rollback_cmd
105
+ | curriculum_cmd
106
+ | star_cmd
107
+ | best_of_cmd
108
+ | exploit_cmd
109
+ | arena_cmd
110
+ | research_arena_cmd
111
 
112
  // ======================== PHASE 1 COMMANDS ========================
113
 
 
181
  | fork_cmd | reset_cmd | prune_cmd | edit_cmd
182
  | fuse_cmd | absorb_cmd | snapshot_cmd | report_cmd
183
  | notify_cmd | save_cmd
184
+ | repeat_block_cmd | if_block_cmd | schedule_cmd
185
+ | download_cmd | compare_cmd | verify_cmd
186
+ | vote_cmd | prompt_cmd | distill_cmd | rollback_cmd
187
+ | curriculum_cmd | star_cmd | best_of_cmd | exploit_cmd
188
+ | arena_cmd | research_arena_cmd) _NL*
189
 
190
  // ======================== PHASE 6 β€” EASY MERGE COMMANDS ========================
191
 
 
265
  // schedule "after 30m" { commands... }
266
  schedule_cmd: "schedule" string "{" _NL* body_cmd+ _NL* "}"
267
 
268
+ // ======================== PHASE 10 - TOOLBOX ========================
269
+
270
+ // download "gsm8k" as math_data [split train]
271
+ download_cmd: "download" string "as" IDENT (download_split)?
272
+ download_split: "split" IDENT
273
+
274
+ // log "training_log.txt"
275
+ log_block: "log" string
276
+
277
+ // compare target vs "source_model" [questions 50] [-> output.json]
278
+ compare_cmd: "compare" IDENT "vs" string (compare_questions)? (compare_output)?
279
+ compare_questions: "questions" INT
280
+ compare_output: "->" FILEPATH
281
+
282
+ // verify target on "dataset" [questions 100] [-> results.json]
283
+ verify_cmd: "verify" IDENT "on" string (verify_questions)? (verify_output)?
284
+ verify_questions: "questions" INT
285
+ verify_output: "->" FILEPATH
286
+
287
+ // ======================== PHASE 11 - INTELLIGENCE ========================
288
+
289
+ // vote target "question" [samples 5] [-> output.json]
290
+ vote_cmd: "vote" IDENT string (vote_samples)? (vote_output)?
291
+ vote_samples: "samples" INT
292
+ vote_output: "->" FILEPATH
293
+
294
+ // prompt target "system prompt text"
295
+ prompt_cmd: "prompt" IDENT string
296
+
297
+ // distill target into "small_model" [steps 200] [-> output_dir]
298
+ distill_cmd: "distill" IDENT "into" string (distill_steps)? (distill_output)?
299
+ distill_steps: "steps" INT
300
+ distill_output: "->" FILEPATH
301
+
302
+ // rollback target
303
+ rollback_cmd: "rollback" IDENT
304
+
305
+ // ======================== PHASE 12 - RL & FINE-TUNING ========================
306
+
307
+ // curriculum target on "dataset" using method [levels 3] [steps 64]
308
+ curriculum_cmd: "curriculum" IDENT "on" string "using" IDENT (curriculum_opt)*
309
+ curriculum_opt: "levels" INT -> curriculum_levels
310
+ | "steps" INT -> curriculum_steps
311
+
312
+ // star target on "dataset" [rounds 3] [samples 8]
313
+ star_cmd: "star" IDENT "on" string (star_opt)*
314
+ star_opt: "rounds" INT -> star_rounds
315
+ | "samples" INT -> star_samples
316
+
317
+ // best_of target on "dataset" [n 8] [steps 32]
318
+ best_of_cmd: "best_of" IDENT "on" string (best_of_opt)*
319
+ best_of_opt: "n" INT -> best_of_n
320
+ | "steps" INT -> best_of_steps
321
+
322
+ // exploit target on "dataset" [samples 16] [steps 32] [-> output.jsonl]
323
+ exploit_cmd: "exploit" IDENT "on" string (exploit_opt)*
324
+ exploit_opt: "samples" INT -> exploit_samples
325
+ | "steps" INT -> exploit_steps
326
+ | "->" FILEPATH -> exploit_output
327
+
328
+ // ======================== PHASE 13 - REAL RL (ARENA) ========================
329
+
330
+ // arena target on "dataset" [rounds 5] [episodes 50] [steps 64] [curiosity 0.3] [-> log.json]
331
+ arena_cmd: "arena" IDENT "on" string (arena_opt)*
332
+ arena_opt: "rounds" INT -> arena_rounds
333
+ | "episodes" INT -> arena_episodes
334
+ | "steps" INT -> arena_steps
335
+ | "curiosity" NUMBER -> arena_curiosity
336
+ | "->" FILEPATH -> arena_output
337
+
338
+ // research_arena target topic "subject" [sources "web"|"pubmed"|"arxiv"|path]
339
+ // [rounds 5] [episodes 30] [steps 64] [curiosity 0.3] [difficulty_scale 0.25] [-> log.json]
340
+ research_arena_cmd: "research_arena" IDENT "topic" string (ra_opt)*
341
+ ra_opt: "sources" string -> ra_sources
342
+ | "rounds" INT -> ra_rounds
343
+ | "episodes" INT -> ra_episodes
344
+ | "steps" INT -> ra_steps
345
+ | "curiosity" NUMBER -> ra_curiosity
346
+ | "difficulty_scale" NUMBER -> ra_difficulty
347
+ | "->" FILEPATH -> ra_output
348
+
349
  // ======================== SHARED RULES ========================
350
 
351
  // List of names: [name1, name2, name3]
 
581
  def schedule_cmd(self, timing: str, *body_cmds) -> ScheduleCmd:
582
  return ScheduleCmd(timing=timing, body=list(body_cmds))
583
 
584
+ # --- Phase 10: Toolbox ---
585
+
586
+ def download_cmd(self, dataset: str, alias: str, split: str | None = None) -> DownloadCmd:
587
+ cmd = DownloadCmd(dataset=dataset, alias=alias)
588
+ if isinstance(split, tuple) and split[0] == "split":
589
+ cmd.split = split[1]
590
+ elif isinstance(split, str):
591
+ cmd.split = split
592
+ return cmd
593
+
594
+ def download_split(self, value: str) -> tuple:
595
+ return ("split", value)
596
+
597
+ def log_block(self, filepath: str) -> LogBlock:
598
+ return LogBlock(filepath=filepath)
599
+
600
+ def compare_cmd(self, target: str, source: str, *opts) -> CompareCmd:
601
+ cmd = CompareCmd(target=target, source=source)
602
+ for opt in opts:
603
+ if isinstance(opt, tuple):
604
+ key, val = opt
605
+ if key == "questions":
606
+ cmd.questions = val
607
+ elif key == "output":
608
+ cmd.output = val
609
+ return cmd
610
+
611
+ def compare_questions(self, value: int) -> tuple:
612
+ return ("questions", value)
613
+
614
+ def compare_output(self, filepath: str) -> tuple:
615
+ return ("output", filepath)
616
+
617
+ def verify_cmd(self, target: str, dataset: str, *opts) -> VerifyCmd:
618
+ cmd = VerifyCmd(target=target, dataset=dataset)
619
+ for opt in opts:
620
+ if isinstance(opt, tuple):
621
+ key, val = opt
622
+ if key == "questions":
623
+ cmd.questions = val
624
+ elif key == "output":
625
+ cmd.output = val
626
+ return cmd
627
+
628
+ def verify_questions(self, value: int) -> tuple:
629
+ return ("questions", value)
630
+
631
+ def verify_output(self, filepath: str) -> tuple:
632
+ return ("output", filepath)
633
+
634
+ # --- Phase 11: Intelligence Commands ---
635
+
636
+ def vote_cmd(self, target: str, question: str, *opts) -> VoteCmd:
637
+ cmd = VoteCmd(target=target, question=question)
638
+ for opt in opts:
639
+ if isinstance(opt, tuple):
640
+ key, val = opt
641
+ if key == "samples":
642
+ cmd.samples = val
643
+ elif key == "output":
644
+ cmd.output = val
645
+ return cmd
646
+
647
+ def vote_samples(self, value: int) -> tuple:
648
+ return ("samples", value)
649
+
650
+ def vote_output(self, filepath: str) -> tuple:
651
+ return ("output", filepath)
652
+
653
+ def prompt_cmd(self, target: str, text: str) -> PromptBlock:
654
+ return PromptBlock(target=target, text=text)
655
+
656
+ def distill_cmd(self, teacher: str, student: str, *opts) -> DistillCmd:
657
+ cmd = DistillCmd(teacher=teacher, student=student)
658
+ for opt in opts:
659
+ if isinstance(opt, tuple):
660
+ key, val = opt
661
+ if key == "steps":
662
+ cmd.steps = val
663
+ elif key == "output":
664
+ cmd.output = val
665
+ return cmd
666
+
667
+ def distill_steps(self, value: int) -> tuple:
668
+ return ("steps", value)
669
+
670
+ def distill_output(self, filepath: str) -> tuple:
671
+ return ("output", filepath)
672
+
673
+ def rollback_cmd(self, target: str) -> RollbackCmd:
674
+ return RollbackCmd(target=target)
675
+
676
+ # --- Phase 12: RL & Fine-Tuning Commands ---
677
+
678
+ def curriculum_cmd(self, target: str, dataset: str, method: str, *opts) -> CurriculumCmd:
679
+ cmd = CurriculumCmd(target=target, dataset=dataset, method=method)
680
+ for opt in opts:
681
+ if isinstance(opt, tuple):
682
+ key, val = opt
683
+ if key == "levels":
684
+ cmd.levels = val
685
+ elif key == "steps":
686
+ cmd.steps = val
687
+ return cmd
688
+
689
+ def curriculum_levels(self, value: int) -> tuple:
690
+ return ("levels", value)
691
+
692
+ def curriculum_steps(self, value: int) -> tuple:
693
+ return ("steps", value)
694
+
695
+ def star_cmd(self, target: str, dataset: str, *opts) -> StarCmd:
696
+ cmd = StarCmd(target=target, dataset=dataset)
697
+ for opt in opts:
698
+ if isinstance(opt, tuple):
699
+ key, val = opt
700
+ if key == "rounds":
701
+ cmd.rounds = val
702
+ elif key == "samples":
703
+ cmd.samples = val
704
+ return cmd
705
+
706
+ def star_rounds(self, value: int) -> tuple:
707
+ return ("rounds", value)
708
+
709
+ def star_samples(self, value: int) -> tuple:
710
+ return ("samples", value)
711
+
712
+ def best_of_cmd(self, target: str, dataset: str, *opts) -> BestOfCmd:
713
+ cmd = BestOfCmd(target=target, dataset=dataset)
714
+ for opt in opts:
715
+ if isinstance(opt, tuple):
716
+ key, val = opt
717
+ if key == "n":
718
+ cmd.n = val
719
+ elif key == "steps":
720
+ cmd.steps = val
721
+ return cmd
722
+
723
+ def best_of_n(self, value: int) -> tuple:
724
+ return ("n", value)
725
+
726
+ def best_of_steps(self, value: int) -> tuple:
727
+ return ("steps", value)
728
+
729
+ def exploit_cmd(self, target: str, dataset: str, *opts) -> ExploitCmd:
730
+ cmd = ExploitCmd(target=target, dataset=dataset)
731
+ for opt in opts:
732
+ if isinstance(opt, tuple):
733
+ key, val = opt
734
+ if key == "samples":
735
+ cmd.samples = val
736
+ elif key == "steps":
737
+ cmd.steps = val
738
+ elif key == "output":
739
+ cmd.output = val
740
+ return cmd
741
+
742
+ def exploit_samples(self, value: int) -> tuple:
743
+ return ("samples", value)
744
+
745
+ def exploit_steps(self, value: int) -> tuple:
746
+ return ("steps", value)
747
+
748
+ def exploit_output(self, filepath: str) -> tuple:
749
+ return ("output", filepath)
750
+
751
+ # --- Phase 13: Real RL (Arena) ---
752
+
753
+ def arena_cmd(self, target: str, dataset: str, *opts) -> ArenaCmd:
754
+ cmd = ArenaCmd(target=target, dataset=dataset)
755
+ for opt in opts:
756
+ if isinstance(opt, tuple):
757
+ key, val = opt
758
+ if key == "rounds":
759
+ cmd.rounds = val
760
+ elif key == "episodes":
761
+ cmd.episodes = val
762
+ elif key == "steps":
763
+ cmd.steps = val
764
+ elif key == "curiosity":
765
+ cmd.curiosity = val
766
+ elif key == "output":
767
+ cmd.output = val
768
+ return cmd
769
+
770
+ def arena_rounds(self, value: int) -> tuple:
771
+ return ("rounds", value)
772
+
773
+ def arena_episodes(self, value: int) -> tuple:
774
+ return ("episodes", value)
775
+
776
+ def arena_steps(self, value: int) -> tuple:
777
+ return ("steps", value)
778
+
779
+ def arena_curiosity(self, value: float) -> tuple:
780
+ return ("curiosity", value)
781
+
782
+ def arena_output(self, filepath: str) -> tuple:
783
+ return ("output", filepath)
784
+
785
+ # --- Phase 13: Research Arena ---
786
+
787
+ def research_arena_cmd(self, target: str, topic: str, *opts) -> ResearchArenaCmd:
788
+ cmd = ResearchArenaCmd(target=target, topic=topic)
789
+ for opt in opts:
790
+ if isinstance(opt, tuple):
791
+ key, val = opt
792
+ if key == "sources":
793
+ cmd.sources = val
794
+ elif key == "rounds":
795
+ cmd.rounds = val
796
+ elif key == "episodes":
797
+ cmd.episodes = val
798
+ elif key == "steps":
799
+ cmd.steps = val
800
+ elif key == "curiosity":
801
+ cmd.curiosity = val
802
+ elif key == "difficulty_scale":
803
+ cmd.difficulty_scale = val
804
+ elif key == "output":
805
+ cmd.output = val
806
+ return cmd
807
+
808
+ def ra_sources(self, value: str) -> tuple:
809
+ return ("sources", value)
810
+
811
+ def ra_rounds(self, value: int) -> tuple:
812
+ return ("rounds", value)
813
+
814
+ def ra_episodes(self, value: int) -> tuple:
815
+ return ("episodes", value)
816
+
817
+ def ra_steps(self, value: int) -> tuple:
818
+ return ("steps", value)
819
+
820
+ def ra_curiosity(self, value: float) -> tuple:
821
+ return ("curiosity", value)
822
+
823
+ def ra_difficulty(self, value: float) -> tuple:
824
+ return ("difficulty_scale", value)
825
+
826
+ def ra_output(self, filepath: str) -> tuple:
827
+ return ("output", filepath)
828
+
829
  # --- Phase 6: Easy Merge Commands ---
830
 
831
  def fuse_cmd(self, sources: list[str], target: str, *opts) -> FuseCmd:
 
1046
  program.setup = item
1047
  elif isinstance(item, OnErrorBlock):
1048
  program.on_error = item
1049
+ elif isinstance(item, LogBlock):
1050
+ program.log = item
1051
  else:
1052
  program.commands.append(item)
1053
  return program
hugging/td_lang/td_lang/.DS_Store ADDED
Binary file (6.15 kB). View file
 
hugging/td_lang/td_lang/__init__.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TD Lang — Domain-specific language for Time Dilation project.
3
+
4
+ Compiles .td files into executable Python. Self-contained — no external deps.
5
+ The merge/heal/validate engine (formerly td_fuse) lives in td_lang.engine/.
6
+
7
+ Architecture:
8
+ td_lang/
9
+ ├── __init__.py <- This file
10
+ ├── __main__.py <- Entry point for python -m td_lang
11
+ ├── grammar.py <- Lark grammar + parse tree transformer
12
+ ├── ast_nodes.py <- Dataclass AST nodes for each command
13
+ ├── compiler.py <- AST -> Python code generation
14
+ ├── executor.py <- Run compiled code, track lineage
15
+ ├── cli.py <- Command-line interface
16
+ ├── errors.py <- Custom exceptions
17
+ ├── engine/ <- Merge/heal/validate runtime (was td_fuse)
18
+ │ ├── config.py <- Model configs, merge order, hyperparameters
19
+ │ ├── merge.py <- Sequential merge orchestrator
20
+ │ ├── heal.py <- QLoRA healing fine-tune
21
+ │ ├── validate.py <- Post-merge validation
22
+ │ ├── transport.py <- Optimal transport wrapper
23
+ │ ├── techniques.py <- ARM, OTMF, RAM, Theseus, Mergeability
24
+ │ └── canary.py <- Canary injection + testing
25
+ └── examples/
26
+ ├── demo_merge.td <- Basic merge example
27
+ ├── demo_heal.td <- Merge + heal example
28
+ ├── demo_full.td <- Full pipeline with gates + budget
29
+ └── ... <- 22 example .td files
30
+
31
+ Phase 1: load, merge, heal, eval, commit
32
+ Phase 2: diagnose, synth, train, debate
33
+ Phase 3: fork, reset, prune, edit
34
+ Phase 4: snapshot, report, data_contract, reward_contract
35
+ Phase 5: CLI polish, --version, info command, --verbose
36
+ Phase 6: fuse, absorb (easy merge)
37
+ Phase 7: repeat, if/else (loop control)
38
+ Phase 8: setup, on_error, notify, save (autopilot)
39
+ Phase 9: schedule (time-based execution)
40
+ Phase 10: download, log, compare, verify (toolbox)
41
+ Phase 11: vote, prompt, distill, rollback (intelligence)
42
+ Phase 12: curriculum, star, best_of, exploit (RL & fine-tuning)
43
+ Phase 13: arena (real RL with memory, curiosity, anti-lying, cross-check)
44
+ Engine upgrades: QLoRA training, self-contained eval, model-generated synth problems
45
+ Mega diagnose: self-diagnosis + domain profiling + layer speed testing
46
+
47
+ Designed from interviews test_14 (10 commands) and test_17 (ForgeSpec 2.0).
48
+ """
49
+
50
+ from .grammar import parse_td_file, parse_td_string # noqa: F401
51
+ from .compiler import compile_program # noqa: F401
52
+ from .executor import TDExecutor, check_td_file, compile_td_file, run_td_file # noqa: F401
53
+
54
+ __version__ = "0.2.0"
55
+ __author__ = "Milan (TD Project)"
56
+
57
+ __all__ = [
58
+ "parse_td_file",
59
+ "parse_td_string",
60
+ "compile_program",
61
+ "TDExecutor",
62
+ "check_td_file",
63
+ "compile_td_file",
64
+ "run_td_file",
65
+ "__version__",
66
+ "__author__",
67
+ ]
hugging/td_lang/td_lang/__main__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """Entry point for python -m td_lang."""
2
+
3
+ from .cli import main
4
+
5
+ main()
hugging/td_lang/td_lang/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (3.04 kB). View file
 
hugging/td_lang/td_lang/__pycache__/__init__.cpython-314.pyc ADDED
Binary file (2.02 kB). View file
 
hugging/td_lang/td_lang/__pycache__/__main__.cpython-310.pyc ADDED
Binary file (254 Bytes). View file
 
hugging/td_lang/td_lang/__pycache__/__main__.cpython-314.pyc ADDED
Binary file (262 Bytes). View file
 
hugging/td_lang/td_lang/__pycache__/ast_nodes.cpython-310.pyc ADDED
Binary file (21.3 kB). View file
 
hugging/td_lang/td_lang/__pycache__/ast_nodes.cpython-314.pyc ADDED
Binary file (18.7 kB). View file
 
hugging/td_lang/td_lang/__pycache__/cli.cpython-310.pyc ADDED
Binary file (6.82 kB). View file
 
hugging/td_lang/td_lang/__pycache__/cli.cpython-314.pyc ADDED
Binary file (10.5 kB). View file
 
hugging/td_lang/td_lang/__pycache__/compiler.cpython-310.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:61328de4293774f4fc0e899e0eec00b64338be0dcf0fd3e68feaeaefc4c1edd5
3
+ size 193126
hugging/td_lang/td_lang/__pycache__/compiler.cpython-314.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8bef7388fef05cdd8ee4edcc72a4b8907c8637caa22cfc802da044470a515c92
3
+ size 162778
hugging/td_lang/td_lang/__pycache__/errors.cpython-310.pyc ADDED
Binary file (5.59 kB). View file
 
hugging/td_lang/td_lang/__pycache__/errors.cpython-314.pyc ADDED
Binary file (6.34 kB). View file
 
hugging/td_lang/td_lang/__pycache__/executor.cpython-310.pyc ADDED
Binary file (5.94 kB). View file
 
hugging/td_lang/td_lang/__pycache__/executor.cpython-314.pyc ADDED
Binary file (9.49 kB). View file
 
hugging/td_lang/td_lang/__pycache__/grammar.cpython-310.pyc ADDED
Binary file (37.9 kB). View file
 
hugging/td_lang/td_lang/__pycache__/grammar.cpython-314.pyc ADDED
Binary file (37.8 kB). View file
 
hugging/td_lang/td_lang/ast_nodes.py ADDED
@@ -0,0 +1,683 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TD Lang AST Nodes — Dataclass containers for each parsed command.
3
+
4
+ Each .td command becomes one of these nodes after parsing.
5
+ Phase 1 nodes are compiled into runnable Python; Phase 2 nodes are stubs so
6
+ the compiler can reject them with a clear error until they are implemented.
7
+ """
8
+
9
+ from dataclasses import dataclass, field
10
+ from typing import Any, List, Optional
11
+
12
+
13
+ # ============================================================================
14
+ # PHASE 1 COMMANDS
15
+ # ============================================================================
16
+
17
+ @dataclass
18
+ class LoadCmd:
19
+ """Load a model and give it a name.
20
+
21
+ Example: load "Qwen/Qwen3-VL-8B-Instruct" as base
22
+ """
23
+ model_ref: str # HuggingFace path or local path
24
+ alias: str # Name to use in the rest of the script
25
+
26
+
@dataclass
class MergeCmd:
    """Merge a source model into a target using a method.

    Example: merge "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B" into base using transport strength 0.5
    """
    source: str            # Model path or alias to merge from
    target: str            # Alias to merge into (must be loaded first)
    method: str            # "transport", "slerp", "ties", "dare"
    strength: float = 0.5  # 0.0 = keep target, 1.0 = keep source
37
+
38
+
@dataclass
class HealCmd:
    """Run QLoRA healing fine-tune on a model.

    Example: heal base lora_r 32 epochs 2
    """
    target: str       # Alias of model to heal
    lora_r: int = 32  # LoRA rank (higher = more capacity)
    epochs: int = 2   # Training epochs
48
+
49
+
@dataclass
class EvalCmd:
    """Run validation/evaluation on a model.

    Example: eval base on "pile_sample" -> report.json
    """
    target: str                    # Alias of model to evaluate
    dataset: Optional[str] = None  # Optional dataset name/path
    output: Optional[str] = None   # Optional output file path
59
+
60
+
@dataclass
class CommitCmd:
    """Save model checkpoint, optionally requiring gates to pass.

    Example: commit base if [canary, perplexity, thinking_mode]
    """
    target: str                        # Alias of model to commit
    gates: Optional[list[str]] = None  # Gate names that must pass
69
+
70
+
71
+ # ============================================================================
72
+ # PHASE 2 COMMANDS (placeholders — structure ready, not wired up yet)
73
+ # ============================================================================
74
+
@dataclass
class SynthCmd:
    """Generate synthetic training data from a model. (Phase 2)"""
    target: str
    source: str
    filter_method: Optional[str] = None
    output: Optional[str] = None
82
+
83
+
@dataclass
class TrainCmd:
    """Train a model on a dataset. (Phase 2)"""
    target: str
    dataset: str
    method: str = "grpo"  # "grpo", "sft", "dpo"
    steps: Optional[int] = None
    learning_rate: Optional[float] = None
92
+
93
+
@dataclass
class DebateCmd:
    """Generate multi-answer debate for preference pairs. (Phase 2)"""
    target: str
    rounds: int = 3
    candidates: int = 8
    output: Optional[str] = None
101
+
102
+
@dataclass
class DiagnoseCmd:
    """Ask model what it's bad at — self-diagnosis. (Phase 2)"""
    target: str
    output: Optional[str] = None
108
+
109
+
@dataclass
class ForkCmd:
    """Branch current model weights for parallel experiments. (Phase 3)

    Example: fork base as experiment_v2
    Cheap fork: copies manifest + adapters, shares base weights (default).
    """
    source: str  # Alias of model to fork from
    alias: str   # Name for the new branch
119
+
120
+
@dataclass
class ResetCmd:
    """Revert model to a previous checkpoint. (Phase 3)

    Example: reset base to "checkpoint_042"
    Deletes current model, clears CUDA cache, reloads from disk.
    Must also reset optimizer state.
    """
    target: str      # Alias of model to reset
    checkpoint: str  # Checkpoint name/path to revert to
131
+
132
+
@dataclass
class PruneCmd:
    """Structural pruning — remove low-utility neurons/heads. (Phase 3)

    Example: prune base using wanda aggressiveness 0.2
    Safe zone: ~20% max (LLM-Pruner paper). Language backbone only.
    """
    target: str
    method: str = "wanda"        # "wanda", "magnitude", "taylor"
    aggressiveness: float = 0.2  # Fraction to remove (0.0-1.0)
143
+
144
+
@dataclass
class EditCmd:
    """Surgical LoRA/DoRA editing on specific layers. (Phase 3)

    Example: edit base layers 16-28 using lora lr 1e-4
    "Try before buy": eval with adapter enabled vs disabled before merging.
    """
    target: str
    layers: str = "all"   # "all", "16-28", single number
    method: str = "lora"  # "lora" or "dora"
    learning_rate: Optional[float] = None
156
+
157
+
158
+ # ============================================================================
159
+ # PHASE 4 COMMANDS — Contracts, Lineage, Economics (ForgeSpec 2.0, test_17)
160
+ # ============================================================================
161
+
162
+ # ============================================================================
163
+ # PHASE 7 — LOOP CONTROL (repeat, if/else)
164
+ # ============================================================================
165
+
@dataclass
class RepeatBlock:
    """Repeat a block of commands N times. (Phase 7 — Loop Control)

    Example:
        repeat 5 {
            diagnose base
            synth base from base
            train base on "data.jsonl" using grpo steps 64
            eval base
        }
    """
    count: int                                     # Number of iterations
    body: List[Any] = field(default_factory=list)  # Commands inside the block
180
+
181
+
@dataclass
class IfBlock:
    """Conditional execution based on last eval result. (Phase 7 — Loop Control)

    Example:
        if eval_passed {
            commit base
        } else {
            reset base to "last_good"
        }

    Condition checks the most recent eval result for the target.
    """
    condition: str                # "eval_passed", "gate_passed", etc.
    target: Optional[str] = None  # Which model's eval to check
    then_body: List[Any] = field(default_factory=list)
    else_body: List[Any] = field(default_factory=list)
199
+
200
+
@dataclass
class FuseCmd:
    """Fuse multiple models into a target in one shot. (Phase 6 — Easy Merge)

    Example: fuse [deepseek-r1, mimo-7b, llama-3.1] into base
    Auto-picks Transport and Merge, auto-sets per-model strength.
    Handles cross-architecture merging (all 5 source models have different archs).
    """
    sources: list[str]         # List of model names/paths to fuse in
    target: str                # Alias to merge into (must be loaded)
    method: str = "transport"  # Default: transport and merge (cross-arch)
    strategy: str = "equal"    # "equal" (same strength each), "weighted", "sequential"
213
+
214
+
@dataclass
class AbsorbCmd:
    """Absorb a single model into target — simplified merge. (Phase 6 — Easy Merge)

    Example: absorb "deepseek-ai/DeepSeek-R1" into base strength 0.5
    One-liner for the common case of merging one model in.
    """
    source: str            # Model path or HF ID
    target: str            # Alias to merge into
    strength: float = 0.5  # 0.0=keep target, 1.0=keep source, default balanced
225
+
226
+
@dataclass
class SnapshotCmd:
    """Save a content-hashed snapshot of model state for lineage tracking. (Phase 4)

    Example: snapshot base -> snapshots/
    Creates a content-addressed directory: snapshots/<sha256_prefix>/
    Contains: model state, adapter state, prune spec, eval report, manifest.
    """
    target: str
    output: Optional[str] = None  # Output directory (default: td_lang_outputs/snapshots/)
237
+
238
+
@dataclass
class ReportCmd:
    """Generate an economics report for this run. (Phase 4)

    Example: report -> economics.json
    Tracks: GPU hours, cost estimate, tokens processed, experiments run,
    time per command, cost breakdown by phase.
    """
    output: Optional[str] = None  # Output file path
248
+
249
+
250
+ # ============================================================================
251
+ # PHASE 8 — AUTOPILOT (setup, notify, save, on_error, resume)
252
+ # ============================================================================
253
+
@dataclass
class NotifyCmd:
    """Send a notification via ntfy.sh. (Phase 8 — Autopilot)

    Example: notify "Training complete!"
    Uses curl to POST to the configured ntfy topic.
    """
    message: str
262
+
263
+
@dataclass
class SaveCmd:
    """Save/upload model to cloud storage via rclone. (Phase 8 — Autopilot)

    Example: save base to "gdrive:TD/models/v1"
    Uses rclone to copy model checkpoint to Google Drive (or any rclone remote).
    """
    target: str       # Alias of model to save
    destination: str  # rclone destination path
273
+
274
+
@dataclass
class SetupBlock:
    """Auto-install dependencies and configure environment. (Phase 8 — Autopilot)

    Example:
        setup {
            pip = [torch, transformers, peft, bitsandbytes, trl]
            hf_token = env
            notify = "ntfy.sh/my_ai"
        }
    """
    pip_packages: list[str] = field(default_factory=list)
    hf_token: Optional[str] = None    # "env" = read HF_TOKEN from env
    notify_url: Optional[str] = None  # ntfy.sh topic URL
289
+
290
+
@dataclass
class OnErrorBlock:
    """Crash recovery behavior. (Phase 8 — Autopilot)

    Example:
        on_error {
            retry = 3
            fallback = reduce_batch
            notify = true
        }
    """
    retry: int = 3                  # Number of retries per failed step
    fallback: str = "reduce_batch"  # "reduce_batch", "skip", "snapshot_and_stop"
    notify: bool = True             # Send ntfy notification on error
305
+
306
+
307
+ # ============================================================================
308
+ # PHASE 9 — SCHEDULE (time-based execution)
309
+ # ============================================================================
310
+
@dataclass
class ScheduleCmd:
    """Schedule a block of commands to run at a specific time or interval. (Phase 9)

    Examples:
        schedule "every 6h" { diagnose base; train base ... }
        schedule "at 02:00" { train base on "data.jsonl" using grpo }
        schedule "after 30m" { eval base -> results.json }

    Patterns:
        "every Nh/Nm" — repeat every N hours/minutes
        "at HH:MM"    — run once at that time
        "after Nh/Nm" — delay then run once
    """
    timing: str                                    # "every 6h", "at 02:00", "after 30m"
    body: List[Any] = field(default_factory=list)  # Commands inside the block
327
+
328
+
329
+ # ============================================================================
330
+ # PHASE 10 - TOOLBOX (download, log, compare, verify)
331
+ # ============================================================================
332
+
@dataclass
class DownloadCmd:
    """Download a dataset from HuggingFace. (Phase 10)

    Example: download "gsm8k" as math_data
    Pulls a dataset from HuggingFace and stores it for training/eval.
    """
    dataset: str          # HuggingFace dataset path
    alias: str            # Name to reference it later
    split: str = "train"  # Which split to download
343
+
344
+
@dataclass
class LogBlock:
    """Save all pipeline output to a log file. (Phase 10)

    Example: log "training_log.txt"
    Everything printed to console also goes to this file.
    """
    filepath: str  # Path to save log
353
+
354
+
@dataclass
class CompareCmd:
    """Compare source model vs merged model - knowledge retention test. (Phase 10)

    Example: compare base vs "deepseek-ai/DeepSeek-R1" questions 50
    Tests both models on the same questions and shows what % the merged
    model retained from the source. Proves the merge actually worked.
    """
    target: str                   # The merged model alias
    source: str                   # Source model to compare against (HF path)
    questions: int = 50           # Number of test questions
    output: Optional[str] = None  # Optional output file
367
+
368
+
@dataclass
class VerifyCmd:
    """Verify model answers are actually correct. (Phase 10)

    Example: verify base on "gsm8k" questions 100 -> verify_results.json
    Runs the model on questions with KNOWN correct answers and checks
    if the model got them right. Returns accuracy percentage.
    """
    target: str                   # Model alias to test
    dataset: str                  # Dataset with known answers
    questions: int = 100          # Number of questions to test
    output: Optional[str] = None  # Optional output file
381
+
382
+
383
+ # ============================================================================
384
+ # PHASE 11 - INTELLIGENCE (vote, prompt, distill, rollback)
385
+ # ============================================================================
386
+
@dataclass
class VoteCmd:
    """Majority voting - generate N answers, pick the one most agree on. (Phase 11)

    Example: vote base "What is 15 * 23?" samples 5
    Generates N answers to the same question, then picks the most common one.
    Proven to boost accuracy 10-20% with zero training.
    """
    target: str                   # Model alias
    question: str                 # Question to vote on
    samples: int = 5              # Number of answers to generate
    output: Optional[str] = None  # Optional output file
399
+
400
+
@dataclass
class PromptBlock:
    """Attach a system prompt or chain-of-thought template to a model. (Phase 11)

    Example:
        prompt base "Think step by step before answering."
    Makes the model use this system prompt for all future generations.
    """
    target: str  # Model alias to attach prompt to
    text: str    # The system prompt text
411
+
412
+
@dataclass
class DistillCmd:
    """Distill a big model's knowledge into a smaller one. (Phase 11)

    Example: distill base into "Qwen/Qwen3-1.7B" steps 200 -> student_model/
    Takes the big model's best answers and trains the small model on them.
    You get a fast model for easy questions, full model for hard ones.
    """
    teacher: str                  # The big model alias (source of knowledge)
    student: str                  # The small model HF path
    steps: int = 200              # Training steps
    output: Optional[str] = None  # Where to save the student model
425
+
426
+
@dataclass
class RollbackCmd:
    """Undo the last training step. (Phase 11)

    Example: rollback base
    Reverts to the most recent snapshot. If training made things worse,
    one command brings it back.
    """
    target: str  # Model alias to rollback
436
+
437
+
438
+ # ============================================================================
439
+ # PHASE 12 - RL & FINE-TUNING (curriculum, star, best_of, exploit)
440
+ # ============================================================================
441
+
@dataclass
class CurriculumCmd:
    """Progressive difficulty training - start easy, get harder. (Phase 12)

    Example: curriculum base on "gsm8k" using grpo levels 3 steps 64
    Splits dataset by difficulty, trains on easy first, then medium, then hard.
    Each level only starts when the model passes the previous one.
    """
    target: str           # Model alias
    dataset: str          # Dataset to train on
    method: str = "grpo"  # Training method
    levels: int = 3       # Number of difficulty levels
    steps: int = 64       # Steps per level
455
+
456
+
@dataclass
class StarCmd:
    """Self-Taught Reasoner - train on own correct reasoning chains. (Phase 12)

    Example: star base on "gsm8k" rounds 3 samples 8
    Generate N solutions per problem. Keep the ones with correct answers.
    Train on the correct reasoning chains. Repeat.
    The model literally learns from its own successes.
    """
    target: str       # Model alias
    dataset: str      # Dataset with known answers
    rounds: int = 3   # Number of STaR iterations
    samples: int = 8  # Solutions to generate per problem
470
+
471
+
@dataclass
class BestOfCmd:
    """Generate N answers, score all, train on the best. (Phase 12)

    Example: best_of base on "gsm8k" n 8 steps 32
    For each training problem: generate N answers, score them all,
    keep only the best one, train on that. Like vote but for training.
    80-90% of RLHF gains at 5-30% of the cost (test_16).
    """
    target: str      # Model alias
    dataset: str     # Dataset to train on
    n: int = 8       # How many answers to generate per problem
    steps: int = 32  # Training steps on the filtered data
485
+
486
+
@dataclass
class ExploitCmd:
    """Controlled reward hacking - keep ALL correct solutions regardless of method. (Phase 12)

    Example: exploit base on "gsm8k" samples 16 -> exploit_data.jsonl
    Generate many diverse solutions (high temp). Only filter: is the answer correct?
    Keep ugly solutions, shortcuts, weird reasoning - as long as the answer is right.
    Train on the diverse set so the model learns multiple paths to correct answers.
    The "hacks" often turn out to be genuinely clever shortcuts.
    """
    target: str                   # Model alias
    dataset: str                  # Dataset with verifiable answers
    samples: int = 16             # Solutions per problem (higher = more diversity)
    steps: int = 32               # Training steps on the exploited data
    output: Optional[str] = None  # Save the exploit data for inspection
502
+
503
+
@dataclass
class ArenaCmd:
    """Real RL with environment, memory, curiosity, and anti-lying. (Phase 13)

    The model enters an arena of challenges. For each challenge:
    1. It tries to solve it (exploration)
    2. Gets immediate reward/punishment (+1 correct, -1 wrong, -2 lying)
    3. Remembers what worked and didn't (memory bank persists across episodes)
    4. Gets curiosity bonus for trying NEW approaches
    5. Creative solutions get cross-checked against standard approaches

    Example: arena base on "gsm8k" rounds 5 episodes 50 steps 64 curiosity 0.3
    """
    target: str                   # Model alias
    dataset: str                  # Dataset with verifiable answers
    rounds: int = 5               # RL rounds (re-train after each)
    episodes: int = 50            # Challenges per round
    steps: int = 64               # Training steps per round
    curiosity: float = 0.3        # Curiosity bonus weight
    output: Optional[str] = None  # Save arena log
524
+
525
+
@dataclass
class ResearchArenaCmd:
    """Research arena — RL on ANY topic using real-world knowledge. (Phase 13)

    Unlike arena (which uses a pre-made dataset), research_arena:
    1. Takes a TOPIC string ("cancer biology", "number theory", anything)
    2. Pulls real papers/sources about that topic (web, arxiv, pubmed, local files)
    3. Extracts verifiable facts/claims from those sources
    4. Builds increasingly hard questions from the real knowledge
    5. Runs the model through the gauntlet, checking EVERY claim against sources
    6. Difficulty ESCALATES on failure (fewer hints, stricter checking, harder questions)
    7. Memory persists so it doesn't forget what it learned
    8. Lying gets punished DOUBLE, curiosity rewarded

    Example: research_arena base topic "cancer biology" sources "pubmed" rounds 5
    """
    target: str                     # Model alias
    topic: str                      # Research topic (any field)
    sources: str = "web"            # Where to pull knowledge: "web", "pubmed", "arxiv", or filepath
    rounds: int = 5                 # RL rounds (difficulty increases each round)
    episodes: int = 30              # Questions per round
    steps: int = 64                 # Training steps per round
    curiosity: float = 0.3          # Curiosity bonus weight
    difficulty_scale: float = 0.25  # How much harder each round gets (0.25 = 25% harder)
    output: Optional[str] = None    # Save research log
551
+
552
+
553
+ # ============================================================================
554
+ # BLOCKS (gates, budget, contracts, etc.)
555
+ # ============================================================================
556
+
@dataclass
class GateBlock:
    """Validation gates that must pass before commit.

    Example:
        gate {
            must_pass = [canary, perplexity, thinking_mode]
        }
    """
    must_pass: list[str] = field(default_factory=list)
567
+
568
+
@dataclass
class BudgetBlock:
    """Resource budget — compiler refuses plans that exceed limits.

    Example:
        budget {
            max_gpu_hours = 8
            max_cost = 50.00
        }
    """
    max_gpu_hours: Optional[float] = None
    max_cost: Optional[float] = None
    max_tokens: Optional[int] = None
    max_experiments: Optional[int] = None
583
+
584
+
@dataclass
class DataContractBlock:
    """Schema enforcement on training data. (Phase 4, ForgeSpec 2.0)

    Example:
        data_contract {
            required_fields = [prompt, response]
            min_samples = 100
            max_perplexity = 50.0
        }

    Compiler checks training data at synth/train time.
    """
    required_fields: list[str] = field(default_factory=list)
    min_samples: Optional[int] = None
    max_perplexity: Optional[float] = None
601
+
602
+
@dataclass
class RewardContractBlock:
    """Verified reward definitions — what counts as "correct". (Phase 4, ForgeSpec 2.0)

    Example:
        reward_contract {
            verifiers = [code_compiles, math_correct, no_hallucination]
            min_reward = 0.3
        }

    Used by train (GRPO) to enforce reward quality.
    No learned reward model — verified rewards only (test_16).
    """
    verifiers: list[str] = field(default_factory=list)
    min_reward: Optional[float] = None
618
+
619
+
620
+ # ============================================================================
621
+ # TOP-LEVEL PROGRAM
622
+ # ============================================================================
623
+
@dataclass
class TDProgram:
    """A complete parsed .td file — commands in order plus global blocks."""

    commands: List[Any] = field(default_factory=list)  # Parsed command nodes, in order
    gates: Optional[GateBlock] = None
    budget: Optional[BudgetBlock] = None
    data_contract: Optional[DataContractBlock] = None
    reward_contract: Optional[RewardContractBlock] = None
    setup: Optional[SetupBlock] = None
    on_error: Optional[OnErrorBlock] = None
    log: Optional[LogBlock] = None
    source_file: Optional[str] = None
637
+
638
+
# Public API of the AST module: every node class a parser or compiler may need.
__all__ = [
    "LoadCmd",
    "MergeCmd",
    "HealCmd",
    "EvalCmd",
    "CommitCmd",
    "SynthCmd",
    "TrainCmd",
    "DebateCmd",
    "DiagnoseCmd",
    "ForkCmd",
    "ResetCmd",
    "PruneCmd",
    "EditCmd",
    "RepeatBlock",
    "IfBlock",
    "FuseCmd",
    "AbsorbCmd",
    "SnapshotCmd",
    "ReportCmd",
    "NotifyCmd",
    "SaveCmd",
    "SetupBlock",
    "OnErrorBlock",
    "GateBlock",
    "BudgetBlock",
    "DataContractBlock",
    "RewardContractBlock",
    "ScheduleCmd",
    "DownloadCmd",
    "LogBlock",
    "CompareCmd",
    "VerifyCmd",
    "VoteCmd",
    "PromptBlock",
    "DistillCmd",
    "RollbackCmd",
    "CurriculumCmd",
    "StarCmd",
    "BestOfCmd",
    "ExploitCmd",
    "ArenaCmd",
    "ResearchArenaCmd",
    "TDProgram",
]
hugging/td_lang/td_lang/cli.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TD Lang CLI — Command-line interface for .td files.
3
+
4
+ Usage:
5
+ python -m td_lang run examples/demo_merge.td # Compile + execute
6
+ python -m td_lang compile examples/demo_merge.td # Compile only (outputs .py)
7
+ python -m td_lang check examples/demo_merge.td # Syntax check only
8
+ python -m td_lang info examples/demo_merge.td # Show plan without compiling
9
+ python -m td_lang --version # Show version
10
+ """
11
+
12
+ import argparse
13
+ import sys
14
+
15
+ from . import __version__
16
+ from .executor import TDExecutor
17
+ from .errors import TDLangError
18
+ from .grammar import parse_td_file
19
+ from .ast_nodes import (
20
+ LoadCmd, MergeCmd, HealCmd, EvalCmd, CommitCmd,
21
+ SynthCmd, TrainCmd, DebateCmd, DiagnoseCmd,
22
+ ForkCmd, ResetCmd, PruneCmd, EditCmd,
23
+ FuseCmd, AbsorbCmd, RepeatBlock, IfBlock,
24
+ NotifyCmd, SaveCmd, ScheduleCmd,
25
+ DownloadCmd, LogBlock, CompareCmd, VerifyCmd,
26
+ VoteCmd, PromptBlock, DistillCmd, RollbackCmd,
27
+ CurriculumCmd, StarCmd, BestOfCmd, ExploitCmd, ArenaCmd, ResearchArenaCmd,
28
+ SnapshotCmd, ReportCmd,
29
+ )
30
+
31
+
# Phase labels for info command.
# Maps each AST node class to (phase number, command keyword); print_info
# looks commands up by type and falls back to ("?", class name) when absent.
_PHASE_MAP = {
    LoadCmd: ("1", "load"),
    MergeCmd: ("1", "merge"),
    HealCmd: ("1", "heal"),
    EvalCmd: ("1", "eval"),
    CommitCmd: ("1", "commit"),
    SynthCmd: ("2", "synth"),
    TrainCmd: ("2", "train"),
    DebateCmd: ("2", "debate"),
    DiagnoseCmd: ("2", "diagnose"),
    ForkCmd: ("3", "fork"),
    ResetCmd: ("3", "reset"),
    PruneCmd: ("3", "prune"),
    EditCmd: ("3", "edit"),
    FuseCmd: ("6", "fuse"),
    AbsorbCmd: ("6", "absorb"),
    RepeatBlock: ("7", "repeat"),
    IfBlock: ("7", "if"),
    NotifyCmd: ("8", "notify"),
    SaveCmd: ("8", "save"),
    SnapshotCmd: ("4", "snapshot"),
    ReportCmd: ("4", "report"),
    ScheduleCmd: ("9", "schedule"),
    DownloadCmd: ("10", "download"),
    CompareCmd: ("10", "compare"),
    VerifyCmd: ("10", "verify"),
    VoteCmd: ("11", "vote"),
    PromptBlock: ("11", "prompt"),
    DistillCmd: ("11", "distill"),
    RollbackCmd: ("11", "rollback"),
    CurriculumCmd: ("12", "curriculum"),
    StarCmd: ("12", "star"),
    BestOfCmd: ("12", "best_of"),
    ExploitCmd: ("12", "exploit"),
    ArenaCmd: ("13", "arena"),
    ResearchArenaCmd: ("13", "research_arena"),
}
70
+
71
+
def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse command-line arguments.

    Args:
        argv: Argument strings to parse. ``None`` (the default) lets argparse
            read ``sys.argv[1:]``, which is what existing callers get.

    Returns:
        Namespace with ``action`` (check/compile/run/info), ``file``,
        ``output``, ``dry`` and ``verbose``.
    """
    parser = argparse.ArgumentParser(
        description="TD Lang — compile and run .td files for Time Dilation",
        # Raw formatter keeps the hand-laid-out examples in the epilog intact.
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python -m td_lang check examples/demo_merge.td      # Check syntax
  python -m td_lang compile examples/demo_merge.td    # Compile to .py
  python -m td_lang run examples/demo_merge.td        # Compile + run
  python -m td_lang run examples/demo_merge.td --dry  # Compile only
  python -m td_lang info examples/demo_merge.td       # Show plan summary
""",
    )

    parser.add_argument(
        "--version",
        action="version",
        version=f"td_lang {__version__}",
    )

    parser.add_argument(
        "action",
        choices=["check", "compile", "run", "info"],
        help="What to do: check (syntax), compile (.py), run (compile+execute), info (show plan)",
    )

    parser.add_argument(
        "file",
        type=str,
        help="Path to the .td file",
    )

    parser.add_argument(
        "--output",
        type=str,
        default="td_lang_outputs",
        help="Output directory (default: td_lang_outputs)",
    )

    parser.add_argument(
        "--dry",
        action="store_true",
        help="With 'run': compile but don't execute",
    )

    parser.add_argument(
        "--verbose", "-v",
        action="store_true",
        help="Show extra detail (compiled Python, full AST, etc.)",
    )

    return parser.parse_args(argv)
125
+
126
+
def print_banner():
    """Print the td_lang banner.

    The frame is padded around each row at print time, so the right-hand
    border stays aligned whatever the length of ``__version__``.
    """
    # NOTE(review): the block-letter art was rebuilt from a rendering of this
    # file in which runs of spaces were collapsed and the box-drawing glyphs
    # were mis-encoded; confirm glyph spacing against the original source.
    art = [
        "████████╗██████╗  ██╗      ██████╗ ",
        "╚══██╔══╝██╔══██╗ ██║     ██╔════╝ ",
        "   ██║   ██║  ██║ ██║     ██║  ███╗",
        "   ██║   ██║  ██║ ██║     ██║   ██║",
        "   ██║   ██████╔╝ ██████╗ ╚██████╔╝",
        "   ╚═╝   ╚═════╝  ╚═════╝  ╚═════╝ ",
    ]
    title = f"TD Lang v{__version__} — .td file compiler"
    rows = ["", *art, "", title, ""]

    # Two columns of margin each side; never narrower than the 39-column frame.
    inner = max(39, max(len(row) for row in rows) + 4)

    lines = ["", "╔" + "═" * inner + "╗"]
    lines.extend("║" + ("  " + row).ljust(inner) + "║" for row in rows)
    lines.extend(["╚" + "═" * inner + "╝", ""])
    print("\n".join(lines))
144
+
145
+
def print_info(filepath: str) -> None:
    """Show what a .td file does without compiling — human-readable plan summary.

    Parses ``filepath`` with ``parse_td_file`` and prints the command count,
    any global blocks (gates, budget, data/reward contracts) and one line per
    command tagged with its phase from ``_PHASE_MAP``.

    Args:
        filepath: Path to the .td file to summarise.
    """
    # NOTE(review): leading spaces inside the printed strings are reproduced
    # as they appear in a whitespace-collapsed rendering; the original may
    # have used wider indentation.
    program = parse_td_file(filepath)

    print(f"\n File: {filepath}")
    print(f" Commands: {len(program.commands)}")

    if program.gates:
        print(f" Gates: {', '.join(program.gates.must_pass)}")
    if program.budget:
        budget = program.budget
        parts = []
        if budget.max_gpu_hours is not None:
            parts.append(f"{budget.max_gpu_hours} GPU hrs")
        if budget.max_cost is not None:
            parts.append(f"${budget.max_cost}")
        # Token and experiment limits were silently left out before, which
        # gave an empty "Budget:" line when they were the only limits set.
        if budget.max_tokens is not None:
            parts.append(f"{budget.max_tokens} tokens")
        if budget.max_experiments is not None:
            parts.append(f"{budget.max_experiments} experiments")
        print(f" Budget: {', '.join(parts)}")
    if program.data_contract:
        print(f" Data contract: fields={program.data_contract.required_fields}")
    if program.reward_contract:
        print(f" Reward contract: verifiers={program.reward_contract.verifiers}")

    print("\n Plan:")
    for i, cmd in enumerate(program.commands, 1):
        # Unknown node types fall back to "?" and the class name.
        phase, name = _PHASE_MAP.get(type(cmd), ("?", type(cmd).__name__))
        # Most nodes carry `target`; nodes without one fall back to `alias`.
        target = getattr(cmd, "target", getattr(cmd, "alias", ""))
        detail = ""
        if hasattr(cmd, "method"):
            detail += f" method={cmd.method}"
        if hasattr(cmd, "source") and name in ("merge", "synth"):
            detail += f" from={cmd.source}"
        if hasattr(cmd, "layers") and cmd.layers != "all":
            detail += f" layers={cmd.layers}"
        if hasattr(cmd, "output") and cmd.output:
            detail += f" -> {cmd.output}"
        print(f" {i}. [P{phase}] {name} {target}{detail}")

    print()
183
+
184
+
def main():
    """Main entry point for td_lang CLI.

    Parses the command line, prints the banner, then dispatches on
    ``args.action``:

    * ``info``    – print the plan summary via ``print_info``.
    * ``check``   – validate the file via ``TDExecutor.check``.
    * ``compile`` – generate the Python file and print its path
      (and its contents with ``--verbose``).
    * ``run``     – execute via ``TDExecutor.run`` and exit 0 when the
      returned status is ``"success"`` or ``"dry_run"``, otherwise 1.

    Exits with status 1 on ``TDLangError`` or a missing file, and 130 on
    Ctrl-C.
    """
    args = parse_args()
    print_banner()

    executor = TDExecutor(output_dir=args.output)

    try:
        if args.action == "info":
            print_info(args.file)

        elif args.action == "check":
            # Only the absence of an exception matters; the result is unused.
            executor.check(args.file)
            print("\n[td_lang] File is valid!")

        elif args.action == "compile":
            py_path = executor.compile(args.file)
            print(f"\n[td_lang] Generated: {py_path}")
            print("[td_lang] You can run it with: python", py_path)
            if args.verbose:
                print("\n--- Generated Python ---")
                print(py_path.read_text())
                print("--- End ---")

        elif args.action == "run":
            result = executor.run(args.file, dry_run=args.dry)
            sys.exit(0 if result["status"] in ("success", "dry_run") else 1)

    except TDLangError as e:
        print(f"\n[td_lang] ERROR: {e}")
        sys.exit(1)

    except FileNotFoundError as e:
        # Report the path that actually failed; it is not necessarily the .td
        # file given on the command line.
        missing = e.filename or args.file
        print(f"\n[td_lang] ERROR: File not found: {missing}")
        print("[td_lang] Check the path and try again.")
        sys.exit(1)

    except KeyboardInterrupt:
        print("\n[td_lang] Interrupted.")
        sys.exit(130)
hugging/td_lang/td_lang/compiler.py ADDED
The diff for this file is too large to render. See raw diff
 
hugging/td_lang/td_lang/engine/__init__.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""
TD Lang Engine — the merge/heal/validate runtime (formerly td_fuse).

All model merging, transport, healing, and validation logic lives here.
td_lang compiles .td files into Python that imports from this engine.

Architecture:
    td_lang/engine/
    ├── __init__.py    ← This file
    ├── config.py      ← Model configs, merge order, hyperparameters
    ├── canary.py      ← Canary injection + testing ("brain surgery")
    ├── transport.py   ← Wrapper around official T&M code
    ├── techniques.py  ← Advanced techniques (Theseus, ARM, OTMF, RAM, Mergeability)
    ├── merge.py       ← Sequential merge orchestrator
    ├── validate.py    ← Post-merge validation (canary, perplexity, benchmarks)
    ├── heal.py        ← QLoRA healing fine-tune via Unsloth
    └── run.py         ← Standalone entry point (optional)

Usage (via td_lang):
    python -m td_lang run td_start.td
    python -m td_lang run demo_merge.td
"""

__version__ = "0.2.0"
__author__ = "Milan (TD Project)"