Upload 127 files
Browse filesThis view is limited to 50 files because it contains too many changes. Β See raw diff
- .gitattributes +2 -0
- hugging/CLAUDE.md +148 -0
- hugging/install.sh +160 -0
- hugging/td_lang/__init__.py +5 -0
- hugging/td_lang/ast_nodes.py +239 -0
- hugging/td_lang/cli.py +16 -0
- hugging/td_lang/compiler.py +0 -0
- hugging/td_lang/engine/__init__.py +25 -0
- hugging/td_lang/engine/__main__.py +4 -0
- hugging/td_lang/engine/__pycache__/__init__.cpython-310.pyc +0 -0
- hugging/td_lang/engine/__pycache__/config.cpython-310.pyc +0 -0
- hugging/td_lang/engine/__pycache__/merge.cpython-310.pyc +0 -0
- hugging/td_lang/engine/canary.py +178 -0
- hugging/td_lang/engine/config.py +305 -0
- hugging/td_lang/engine/heal.py +600 -0
- hugging/td_lang/engine/merge.py +988 -0
- hugging/td_lang/engine/run.py +279 -0
- hugging/td_lang/engine/techniques.py +669 -0
- hugging/td_lang/engine/transport.py +853 -0
- hugging/td_lang/engine/validate.py +215 -0
- hugging/td_lang/errors.py +14 -0
- hugging/td_lang/examples/demo_arena.td +28 -0
- hugging/td_lang/examples/demo_intelligence.td +35 -0
- hugging/td_lang/examples/demo_research_arena.td +29 -0
- hugging/td_lang/examples/demo_rl.td +31 -0
- hugging/td_lang/examples/demo_toolbox.td +24 -0
- hugging/td_lang/grammar.py +361 -1
- hugging/td_lang/td_lang/.DS_Store +0 -0
- hugging/td_lang/td_lang/__init__.py +67 -0
- hugging/td_lang/td_lang/__main__.py +5 -0
- hugging/td_lang/td_lang/__pycache__/__init__.cpython-310.pyc +0 -0
- hugging/td_lang/td_lang/__pycache__/__init__.cpython-314.pyc +0 -0
- hugging/td_lang/td_lang/__pycache__/__main__.cpython-310.pyc +0 -0
- hugging/td_lang/td_lang/__pycache__/__main__.cpython-314.pyc +0 -0
- hugging/td_lang/td_lang/__pycache__/ast_nodes.cpython-310.pyc +0 -0
- hugging/td_lang/td_lang/__pycache__/ast_nodes.cpython-314.pyc +0 -0
- hugging/td_lang/td_lang/__pycache__/cli.cpython-310.pyc +0 -0
- hugging/td_lang/td_lang/__pycache__/cli.cpython-314.pyc +0 -0
- hugging/td_lang/td_lang/__pycache__/compiler.cpython-310.pyc +3 -0
- hugging/td_lang/td_lang/__pycache__/compiler.cpython-314.pyc +3 -0
- hugging/td_lang/td_lang/__pycache__/errors.cpython-310.pyc +0 -0
- hugging/td_lang/td_lang/__pycache__/errors.cpython-314.pyc +0 -0
- hugging/td_lang/td_lang/__pycache__/executor.cpython-310.pyc +0 -0
- hugging/td_lang/td_lang/__pycache__/executor.cpython-314.pyc +0 -0
- hugging/td_lang/td_lang/__pycache__/grammar.cpython-310.pyc +0 -0
- hugging/td_lang/td_lang/__pycache__/grammar.cpython-314.pyc +0 -0
- hugging/td_lang/td_lang/ast_nodes.py +683 -0
- hugging/td_lang/td_lang/cli.py +229 -0
- hugging/td_lang/td_lang/compiler.py +0 -0
- hugging/td_lang/td_lang/engine/__init__.py +25 -0
.gitattributes
CHANGED
|
@@ -35,3 +35,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
hugging/td_lang/__pycache__/compiler.cpython-314.pyc filter=lfs diff=lfs merge=lfs -text
|
| 37 |
hugging/td_lang/__pycache__/compiler.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
hugging/td_lang/__pycache__/compiler.cpython-314.pyc filter=lfs diff=lfs merge=lfs -text
|
| 37 |
hugging/td_lang/__pycache__/compiler.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
hugging/td_lang/td_lang/__pycache__/compiler.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
hugging/td_lang/td_lang/__pycache__/compiler.cpython-314.pyc filter=lfs diff=lfs merge=lfs -text
|
hugging/CLAUDE.md
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Memory
|
| 2 |
+
|
| 3 |
+
## Me
|
| 4 |
+
Milan (Libby's account). Building TD (Time Dilation) β a self-improving AI system using a 7B model on home hardware.
|
| 5 |
+
|
| 6 |
+
## People
|
| 7 |
+
| Who | Role |
|
| 8 |
+
|-----|------|
|
| 9 |
+
| **Milan** | Project lead, TD creator. Hands-on, wants things explained simply |
|
| 10 |
+
| **Milan's dad** | Budget decision-maker AND critical thinker. Said "if it's worth investing in, money isn't the issue" but also challenged everything with hard questions. His critiques forced the pivot from old plan to new plan. |
|
| 11 |
+
|
| 12 |
+
> Full list: memory/glossary.md, profiles: memory/people/
|
| 13 |
+
|
| 14 |
+
## Terms
|
| 15 |
+
| Term | Meaning |
|
| 16 |
+
|------|---------|
|
| 17 |
+
| TD | Time Dilation β the self-improving AI project |
|
| 18 |
+
| ALAS | Autonomous Learning Agent System β self-learning via web search |
|
| 19 |
+
| Fara-7B | Microsoft's vision-based browser agent (MIT, open source, based on Qwen2.5-VL) |
|
| 20 |
+
| Qwen3-VL-8B | Qwen3 with vision + browser agent β replaces Fara as our CUA base |
|
| 21 |
+
| GRPO | Group Relative Policy Optimisation β RL for verified reasoning |
|
| 22 |
+
| SimPO | Simple Preference Optimisation β reference-free preference training |
|
| 23 |
+
| SLIME | Improved SimPO β dual-margin stability, fixes online collapse |
|
| 24 |
+
| QLoRA | Quantised Low-Rank Adaptation β memory-efficient fine-tuning |
|
| 25 |
+
| PRMs | Process Reward Models β step-by-step reasoning verification |
|
| 26 |
+
| ThinkPRM | PRMs that think β uses 1% of labelling data |
|
| 27 |
+
| WebRL | Self-evolving curriculum RL for web agents |
|
| 28 |
+
| STaR | Self-Taught Reasoner β train on correct reasoning chains |
|
| 29 |
+
| FuseLLM | Merge multiple fine-tuned models into one |
|
| 30 |
+
| TIES/DARE-TIES | Weight merging algorithms for FuseLLM |
|
| 31 |
+
| Transport and Merge | Cross-architecture model merging via optimal transport (Feb 2026) |
|
| 32 |
+
| OrthoMerge | Merging on Riemannian manifold, preserves weight geometry |
|
| 33 |
+
| LARV | Layer-wise Adaptive Rescaling β per-layer scaling for merges |
|
| 34 |
+
| Git Re-Basin | Neuron permutation matching β PUBLIC CODE foundation for merging |
|
| 35 |
+
| SEC | Self-Evolving Curriculum β auto-adjusts training difficulty |
|
| 36 |
+
| Cherry_LLM | Self-data filtering via perplexity scoring |
|
| 37 |
+
| SimpleMem | 26.4% better than Mem0, 30x more efficient memory |
|
| 38 |
+
| JitRL | Training-free continual learning β outperforms WebRL |
|
| 39 |
+
| Latent Reasoning | Scales 7B to ~50B performance at inference |
|
| 40 |
+
| Layer 0-5 | TD's 6-layer architecture (0=instant, 1=data, 2=filter, 3=train, 4=agents, 5=merge) |
|
| 41 |
+
|
| 42 |
+
> Full glossary: memory/glossary.md
|
| 43 |
+
|
| 44 |
+
## Projects
|
| 45 |
+
| Name | What |
|
| 46 |
+
|------|------|
|
| 47 |
+
| **TD (Time Dilation)** | Self-improving 7B AI system. 89 techniques, 29 core. 6-layer architecture |
|
| 48 |
+
|
| 49 |
+
> Details: memory/projects/
|
| 50 |
+
|
| 51 |
+
## Merge Strategy
|
| 52 |
+
- Target model: Qwen3-VL-8B-Instruct (vision + browser agent + text, thinking mode)
|
| 53 |
+
- Why VL: Same language brain as Qwen3-8B, but adds vision + CUA abilities for free (replaces need for Fara)
|
| 54 |
+
- Merge approach: Only merge into language backbone layers, vision encoder stays untouched
|
| 55 |
+
- Method: Transport and Merge (optimal transport cross-arch merging)
|
| 56 |
+
- Merge in: DeepSeek-R1-Distill, MiMo-7B, Llama 3.1, Falcon-H1R-7B
|
| 57 |
+
- Fallback: Knowledge distillation for any model that fails to merge
|
| 58 |
+
- NO direct merges possible β all 5 models have different architectures
|
| 59 |
+
- Kimi K2 ruled out (1T params, too big)
|
| 60 |
+
- Full strategy: docs/MERGE_STRATEGY.md
|
| 61 |
+
|
| 62 |
+
## Dad's Tests (Critical Thinking Filter)
|
| 63 |
+
Every claim must pass these before being accepted:
|
| 64 |
+
1. **Economic test:** "If this worked cheaply, why aren't big tech companies doing it?"
|
| 65 |
+
2. **Architecture test:** "Is this built on something that's dying or futureproof?"
|
| 66 |
+
3. **Realism test:** "Is this actually achievable or just optimism?"
|
| 67 |
+
4. **Pragmatism test:** "Can we use what we already have first?"
|
| 68 |
+
5. **Long-term test:** "Will this still matter in 2-3 years?"
|
| 69 |
+
|
| 70 |
+
Dad's exact words: "I didn't ask for the marketing spill, give to the point answer." He called out that LLMs are "on their way out" and questioned whether weight-copying works. His critiques were RIGHT β P100 didn't work, weight copying was wrong, old timelines were fantasy. The pivot to Transport and Merge + dual 4090 happened because of his challenges.
|
| 71 |
+
|
| 72 |
+
## TD History (Old vs New Plan)
|
| 73 |
+
- **OLD plan (Jan-Feb 2026):** Copy Mistral-7B weights, spawn copies for research, merge knowledge back via JSON. Hardware: Tesla P40 + desktop (~$250). This plan FAILED β weight copying doesn't transfer knowledge, P100 incompatible with Unsloth, timelines were fantasy.
|
| 74 |
+
- **NEW plan (Feb 2026):** Transport and Merge 5 different models into Qwen3-VL-8B (vision+text), then GRPO self-improvement loop. Hardware: dual RTX 4090 or vast.ai GPU rental. Self-improvement through actual RL training (weights change), not code self-modification or JSON merging. Switched from Qwen3-8B to Qwen3-VL-8B to get browser agent abilities (like Fara) built in.
|
| 75 |
+
- **What TD will be:** A regular AI assistant like ChatGPT, but hopefully smarter after training cycles. NOT superintelligence promises.
|
| 76 |
+
|
| 77 |
+
## Self-Improvement Loop (Discovered Feb 2026)
|
| 78 |
+
Milan interviewed ChatGPT, Grok, and Gemini (12+ interviews, test_1 to test_12+) about recursive self-improvement.
|
| 79 |
+
Key discovery: **The model can be its own diagnostician.**
|
| 80 |
+
- All 3 AIs could list their own weaknesses when asked "what would you improve?"
|
| 81 |
+
- All 3 said the only thing stopping them is no access to their own weights/training
|
| 82 |
+
- All 3 converged on the same "small" self-improvement loop that actually works:
|
| 83 |
+
|
| 84 |
+
**The TD Self-Improvement Loop:**
|
| 85 |
+
1. Merge multiple models together (Transport and Merge) β creates strong base
|
| 86 |
+
2. Ask the model "what are you bad at?" β it identifies weak spots
|
| 87 |
+
3. Generate targeted synthetic training data for those weak spots
|
| 88 |
+
4. Train with GRPO/STaR on that data β model gets slightly better
|
| 89 |
+
5. The improved model generates better reasoning chains β better training data
|
| 90 |
+
6. Repeat β each cycle is small (1-5%) but compounds
|
| 91 |
+
|
| 92 |
+
**Two codebases (td_fuse absorbed into td_lang):**
|
| 93 |
+
- `td_lang` β THE complete TD system. Domain-specific language + merge engine + training + RL. v0.2.0, ~11,422 lines total (7,878 core + 3,544 engine), 18 .py files + 22 examples. All 13 phases complete. td_fuse was absorbed into td_lang/engine/ so td_lang runs everything β no external Python deps for the pipeline. Built collaboratively: Claude (architecture), Codex (hardening), Gemini (in-IDE testing).
|
| 94 |
+
- `td_loop` β self-recursive improvement loop (planned, automates the cycle above). May not be needed since td_lang's `repeat` block + arena already handle this.
|
| 95 |
+
|
| 96 |
+
**What's NOT possible (confirmed by all 3 AIs + dad's tests):**
|
| 97 |
+
- Live weight editing (model rewriting its own brain in real-time)
|
| 98 |
+
- Direct weight manipulation like editing a text file
|
| 99 |
+
- "Cogniscript"/"Phylang"/"Lumina-Ξ£" (sci-fi languages from the interviews β NOT real)
|
| 100 |
+
|
| 101 |
+
**What IS possible (confirmed by all 3 AIs + real papers):**
|
| 102 |
+
- Generate β Filter β Train β Evaluate β Keep winners β Repeat
|
| 103 |
+
- Using mechanistic interpretability to find weak circuits, then training specifically on those
|
| 104 |
+
- STaR (train on correct reasoning chains), GRPO (RL for reasoning), Cherry_LLM (filter bad data)
|
| 105 |
+
|
| 106 |
+
**Interview technical findings (test_12):**
|
| 107 |
+
- LoRA target: mid-to-late layers MLP blocks (layers 16β28 for 32-layer model). All 3 AIs agree.
|
| 108 |
+
- Biggest weakness: long-chain reasoning breaks at step 18β30. Target this with GRPO.
|
| 109 |
+
- Self-training trap: 100 steps on own outputs β smoother but dumber. MUST mix external data.
|
| 110 |
+
- Cherry_LLM perplexity filter prevents mode collapse by catching repetitive training data.
|
| 111 |
+
|
| 112 |
+
**Cost optimization (test_16):**
|
| 113 |
+
- Inference-time scaling: 80β90% of gains for 5β30% cost. Generate multiple answers, pick best, train on winners.
|
| 114 |
+
- Verified rewards only: no learned reward model, just objective checkers (code compiles, math correct). Saves VRAM.
|
| 115 |
+
- Budget: 70β80% inference scaling, 10β20% short GRPO, 5β10% tooling
|
| 116 |
+
- Speculative decoding (vLLM): small draft model + main model verifying = 2β3Γ faster inference
|
| 117 |
+
|
| 118 |
+
**td_lang design requirements (test_17 β ChatGPT's ForgeSpec 2.0):**
|
| 119 |
+
- 8 features: data contracts, reward contracts, eval gates (mandatory), resource budgets (compiler enforced), automatic ablations, artifact lineage (content-hash), serving SLOs, economics reports
|
| 120 |
+
- Three quality gates for td_loop: holdout (real tasks), adversarial (break it on purpose), calibration (confidence vs accuracy)
|
| 121 |
+
- OpenRLHF: real framework (Ray+vLLM+DeepSpeed) for GRPO at scale β could replace custom td_loop plumbing
|
| 122 |
+
- GaLore: full-param training at 65% less VRAM (alternative to QLoRA)
|
| 123 |
+
- PACER (Feb 2026): sample 8-64 traces β consensus packet β one revision = 1/8 tokens of majority voting
|
| 124 |
+
|
| 125 |
+
**Phase 3 deep dive (test_18 β all 3 AIs answered both prompts):**
|
| 126 |
+
- FORK: disk-based only on 4090. Cheap fork = manifest + adapter copy. safetensors format.
|
| 127 |
+
- RESET: del model β clear cache β reload. Must reset optimizer state. Use assign=True.
|
| 128 |
+
- PRUNE: 20% structured max (LLM-Pruner paper). Wanda metric (Grok, practical on 4090). Language backbone only, never vision. Recovery: 200-800 steps LoRA r=8.
|
| 129 |
+
- EDIT: LoRA/DoRA with layers_to_transform for layers 16-28. "Try before buy" via enable/disable adapters. ROME/MEMIT not ready for Qwen3-VL.
|
| 130 |
+
- Build order: EDIT first β FORK/RESET β PRUNE last
|
| 131 |
+
- ChatGPT's manifest idea: model state = base_ref + adapters[] + prune_spec + optimizer + eval_report
|
| 132 |
+
|
| 133 |
+
**Interview files:** stored in interview/ folder (test_1.txt through test_18.txt + screenshots)
|
| 134 |
+
- ChatGPT: Most conservative, gave systems-level analysis, refused operational blueprints
|
| 135 |
+
- Grok: Most detailed and realistic, named specific models/hardware, grounded in real papers
|
| 136 |
+
- Gemini: Most flattering/sci-fi, referenced Milan's own work, made up technologies
|
| 137 |
+
|
| 138 |
+
## Preferences
|
| 139 |
+
- Explain things simply β analogies and plain English
|
| 140 |
+
- Use all available tools and commands
|
| 141 |
+
- Be honest about what works and what doesn't β Milan values truth over optimism
|
| 142 |
+
- Budget is flexible β focus on best strategy, not cheapest hardware
|
| 143 |
+
- Keep one master document (currently v5.2 in docs/)
|
| 144 |
+
- Old files go to DELETE/ folder for Milan to trash
|
| 145 |
+
- No dashboards or visual tools β Milan doesn't need them
|
| 146 |
+
- Plugins are welcome if they genuinely help and don't break anything
|
| 147 |
+
- Run every claim by "dad's tests" before presenting it as fact
|
| 148 |
+
- The uploaded 6-part transcript is the OLD TD version β useful for self-improvement context but NOT the current plan
|
hugging/install.sh
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# ============================================================================
|
| 3 |
+
# TD (Time Dilation) β One-Command Setup
|
| 4 |
+
# ============================================================================
|
| 5 |
+
#
|
| 6 |
+
# Run this ONCE on a fresh machine with a GPU:
|
| 7 |
+
# chmod +x install.sh && ./install.sh
|
| 8 |
+
#
|
| 9 |
+
# What it does:
|
| 10 |
+
# 1. Installs all Python dependencies
|
| 11 |
+
# 2. Downloads the base model (Qwen3-VL-8B-Instruct)
|
| 12 |
+
# 3. Downloads the Transport and Merge code
|
| 13 |
+
# 4. Sets up output directories
|
| 14 |
+
# 5. Verifies GPU access
|
| 15 |
+
# 6. Compiles the starter TD file to make sure everything works
|
| 16 |
+
#
|
| 17 |
+
# After this, just run:
|
| 18 |
+
# python -m td_lang run td_start.td
|
| 19 |
+
#
|
| 20 |
+
# Requirements:
|
| 21 |
+
# - Python 3.10+
|
| 22 |
+
# - NVIDIA GPU with 24GB+ VRAM (RTX 4090 or better)
|
| 23 |
+
# - ~50GB disk space (models + checkpoints)
|
| 24 |
+
# - Internet connection (first run only)
|
| 25 |
+
# ============================================================================
|
| 26 |
+
|
| 27 |
+
set -e # Stop on any error
|
| 28 |
+
|
| 29 |
+
echo "============================================================"
|
| 30 |
+
echo " TD (Time Dilation) β Setup Script"
|
| 31 |
+
echo "============================================================"
|
| 32 |
+
echo ""
|
| 33 |
+
|
| 34 |
+
# ββ Step 1: Check Python ββ
|
| 35 |
+
echo "[1/7] Checking Python..."
|
| 36 |
+
if ! command -v python3 &> /dev/null; then
|
| 37 |
+
echo "ERROR: Python 3 not found. Install Python 3.10+ first."
|
| 38 |
+
exit 1
|
| 39 |
+
fi
|
| 40 |
+
PYTHON_VER=$(python3 -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')")
|
| 41 |
+
echo " Python $PYTHON_VER found."
|
| 42 |
+
|
| 43 |
+
# ββ Step 2: Check GPU ββ
|
| 44 |
+
echo ""
|
| 45 |
+
echo "[2/7] Checking GPU..."
|
| 46 |
+
if command -v nvidia-smi &> /dev/null; then
|
| 47 |
+
GPU_NAME=$(nvidia-smi --query-gpu=name --format=csv,noheader | head -1)
|
| 48 |
+
GPU_MEM=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader | head -1)
|
| 49 |
+
echo " GPU: $GPU_NAME ($GPU_MEM)"
|
| 50 |
+
else
|
| 51 |
+
echo " WARNING: nvidia-smi not found. GPU might not be available."
|
| 52 |
+
echo " Continuing anyway (some features won't work without GPU)."
|
| 53 |
+
fi
|
| 54 |
+
|
| 55 |
+
# ββ Step 3: Install Python packages ββ
|
| 56 |
+
echo ""
|
| 57 |
+
echo "[3/7] Installing Python packages..."
|
| 58 |
+
echo " This takes 5-10 minutes on first run."
|
| 59 |
+
pip install --break-system-packages -q \
|
| 60 |
+
torch \
|
| 61 |
+
transformers \
|
| 62 |
+
accelerate \
|
| 63 |
+
bitsandbytes \
|
| 64 |
+
peft \
|
| 65 |
+
trl \
|
| 66 |
+
datasets \
|
| 67 |
+
safetensors \
|
| 68 |
+
sentencepiece \
|
| 69 |
+
protobuf \
|
| 70 |
+
scipy \
|
| 71 |
+
lark \
|
| 72 |
+
duckduckgo-search \
|
| 73 |
+
huggingface_hub \
|
| 74 |
+
2>&1 | tail -5
|
| 75 |
+
|
| 76 |
+
# Unsloth (optional β speeds up training 2x, but can fail on some systems)
|
| 77 |
+
echo " Trying to install Unsloth (optional speed boost)..."
|
| 78 |
+
pip install --break-system-packages -q unsloth 2>/dev/null && echo " Unsloth installed." || echo " Unsloth not available (that's fine, PEFT fallback works)."
|
| 79 |
+
|
| 80 |
+
echo " Packages installed."
|
| 81 |
+
|
| 82 |
+
# ββ Step 4: Download base model ββ
|
| 83 |
+
echo ""
|
| 84 |
+
echo "[4/7] Downloading base model (Qwen3-VL-8B-Instruct)..."
|
| 85 |
+
echo " This is ~16GB. Go grab a coffee."
|
| 86 |
+
python3 -c "
|
| 87 |
+
from huggingface_hub import snapshot_download
|
| 88 |
+
print(' Downloading Qwen/Qwen3-VL-8B-Instruct...')
|
| 89 |
+
path = snapshot_download('Qwen/Qwen3-VL-8B-Instruct', local_dir='./models/Qwen3-VL-8B-Instruct')
|
| 90 |
+
print(f' Downloaded to: {path}')
|
| 91 |
+
"
|
| 92 |
+
echo " Base model ready."
|
| 93 |
+
|
| 94 |
+
# ββ Step 5: Download Transport and Merge code ββ
|
| 95 |
+
echo ""
|
| 96 |
+
echo "[5/7] Downloading Transport and Merge code..."
|
| 97 |
+
if [ ! -d "Cross-Architecture-Merging-for-Large-Language-Models" ]; then
|
| 98 |
+
git clone https://github.com/FedML-AI/Cross-Architecture-Merging-for-Large-Language-Models.git
|
| 99 |
+
echo " T&M code cloned."
|
| 100 |
+
else
|
| 101 |
+
echo " T&M code already exists, skipping."
|
| 102 |
+
fi
|
| 103 |
+
|
| 104 |
+
# ββ Step 6: Set up directories ββ
|
| 105 |
+
echo ""
|
| 106 |
+
echo "[6/7] Setting up directories..."
|
| 107 |
+
mkdir -p td_lang_outputs/{checkpoints,snapshots,arena_logs,committed}
|
| 108 |
+
echo " Output directories created."
|
| 109 |
+
|
| 110 |
+
# ββ Step 7: Verify everything works ββ
|
| 111 |
+
echo ""
|
| 112 |
+
echo "[7/7] Verifying installation..."
|
| 113 |
+
|
| 114 |
+
# Check td_lang compiles
|
| 115 |
+
python3 -c "
|
| 116 |
+
from td_lang.grammar import parse_td_file
|
| 117 |
+
from td_lang.compiler import compile_program
|
| 118 |
+
import ast
|
| 119 |
+
|
| 120 |
+
program = parse_td_file('td_start.td')
|
| 121 |
+
code = compile_program(program)
|
| 122 |
+
ast.parse(code)
|
| 123 |
+
print(' td_lang: OK (td_start.td compiles)')
|
| 124 |
+
"
|
| 125 |
+
|
| 126 |
+
# Check GPU access from Python
|
| 127 |
+
python3 -c "
|
| 128 |
+
import torch
|
| 129 |
+
if torch.cuda.is_available():
|
| 130 |
+
gpu = torch.cuda.get_device_name(0)
|
| 131 |
+
mem = torch.cuda.get_device_properties(0).total_mem / 1024**3
|
| 132 |
+
print(f' PyTorch GPU: {gpu} ({mem:.0f}GB)')
|
| 133 |
+
else:
|
| 134 |
+
print(' PyTorch GPU: NOT AVAILABLE (CPU only)')
|
| 135 |
+
"
|
| 136 |
+
|
| 137 |
+
# Check key libraries
|
| 138 |
+
python3 -c "
|
| 139 |
+
import transformers, peft, trl, bitsandbytes, lark, datasets
|
| 140 |
+
print(f' transformers: {transformers.__version__}')
|
| 141 |
+
print(f' peft: {peft.__version__}')
|
| 142 |
+
print(f' trl: {trl.__version__}')
|
| 143 |
+
print(' All libraries: OK')
|
| 144 |
+
"
|
| 145 |
+
|
| 146 |
+
echo ""
|
| 147 |
+
echo "============================================================"
|
| 148 |
+
echo " SETUP COMPLETE!"
|
| 149 |
+
echo "============================================================"
|
| 150 |
+
echo ""
|
| 151 |
+
echo " To start TD, run:"
|
| 152 |
+
echo " python -m td_lang run td_start.td"
|
| 153 |
+
echo ""
|
| 154 |
+
echo " To just compile (preview what it'll do):"
|
| 155 |
+
echo " python -m td_lang compile td_start.td"
|
| 156 |
+
echo ""
|
| 157 |
+
echo " To check syntax only:"
|
| 158 |
+
echo " python -m td_lang check td_start.td"
|
| 159 |
+
echo ""
|
| 160 |
+
echo "============================================================"
|
hugging/td_lang/__init__.py
CHANGED
|
@@ -31,7 +31,12 @@ Phase 6: fuse, absorb (easy merge)
|
|
| 31 |
Phase 7: repeat, if/else (loop control)
|
| 32 |
Phase 8: setup, on_error, notify, save (autopilot)
|
| 33 |
Phase 9: schedule (time-based execution)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
Engine upgrades: QLoRA training, self-contained eval, model-generated synth problems
|
|
|
|
| 35 |
|
| 36 |
Designed from interviews test_14 (10 commands) and test_17 (ForgeSpec 2.0).
|
| 37 |
"""
|
|
|
|
| 31 |
Phase 7: repeat, if/else (loop control)
|
| 32 |
Phase 8: setup, on_error, notify, save (autopilot)
|
| 33 |
Phase 9: schedule (time-based execution)
|
| 34 |
+
Phase 10: download, log, compare, verify (toolbox)
|
| 35 |
+
Phase 11: vote, prompt, distill, rollback (intelligence)
|
| 36 |
+
Phase 12: curriculum, star, best_of, exploit (RL & fine-tuning)
|
| 37 |
+
Phase 13: arena (real RL with memory, curiosity, anti-lying, cross-check)
|
| 38 |
Engine upgrades: QLoRA training, self-contained eval, model-generated synth problems
|
| 39 |
+
Mega diagnose: self-diagnosis + domain profiling + layer speed testing
|
| 40 |
|
| 41 |
Designed from interviews test_14 (10 commands) and test_17 (ForgeSpec 2.0).
|
| 42 |
"""
|
hugging/td_lang/ast_nodes.py
CHANGED
|
@@ -326,6 +326,230 @@ class ScheduleCmd:
|
|
| 326 |
body: List[Any] = field(default_factory=list) # Commands inside the block
|
| 327 |
|
| 328 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 329 |
# ============================================================================
|
| 330 |
# BLOCKS (gates, budget, contracts, etc.)
|
| 331 |
# ============================================================================
|
|
@@ -408,6 +632,7 @@ class TDProgram:
|
|
| 408 |
reward_contract: Optional[RewardContractBlock] = None
|
| 409 |
setup: Optional[SetupBlock] = None
|
| 410 |
on_error: Optional[OnErrorBlock] = None
|
|
|
|
| 411 |
source_file: Optional[str] = None
|
| 412 |
|
| 413 |
|
|
@@ -440,5 +665,19 @@ __all__ = [
|
|
| 440 |
"DataContractBlock",
|
| 441 |
"RewardContractBlock",
|
| 442 |
"ScheduleCmd",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 443 |
"TDProgram",
|
| 444 |
]
|
|
|
|
| 326 |
body: List[Any] = field(default_factory=list) # Commands inside the block
|
| 327 |
|
| 328 |
|
| 329 |
+
# ============================================================================
|
| 330 |
+
# PHASE 10 - TOOLBOX (download, log, compare, verify)
|
| 331 |
+
# ============================================================================
|
| 332 |
+
|
| 333 |
+
@dataclass
|
| 334 |
+
class DownloadCmd:
|
| 335 |
+
"""Download a dataset from HuggingFace. (Phase 10)
|
| 336 |
+
|
| 337 |
+
Example: download "gsm8k" as math_data
|
| 338 |
+
Pulls a dataset from HuggingFace and stores it for training/eval.
|
| 339 |
+
"""
|
| 340 |
+
dataset: str # HuggingFace dataset path
|
| 341 |
+
alias: str # Name to reference it later
|
| 342 |
+
split: str = "train" # Which split to download
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
@dataclass
|
| 346 |
+
class LogBlock:
|
| 347 |
+
"""Save all pipeline output to a log file. (Phase 10)
|
| 348 |
+
|
| 349 |
+
Example: log "training_log.txt"
|
| 350 |
+
Everything printed to console also goes to this file.
|
| 351 |
+
"""
|
| 352 |
+
filepath: str # Path to save log
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
@dataclass
|
| 356 |
+
class CompareCmd:
|
| 357 |
+
"""Compare source model vs merged model - knowledge retention test. (Phase 10)
|
| 358 |
+
|
| 359 |
+
Example: compare base vs "deepseek-ai/DeepSeek-R1" questions 50
|
| 360 |
+
Tests both models on the same questions and shows what % the merged
|
| 361 |
+
model retained from the source. Proves the merge actually worked.
|
| 362 |
+
"""
|
| 363 |
+
target: str # The merged model alias
|
| 364 |
+
source: str # Source model to compare against (HF path)
|
| 365 |
+
questions: int = 50 # Number of test questions
|
| 366 |
+
output: Optional[str] = None # Optional output file
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
@dataclass
|
| 370 |
+
class VerifyCmd:
|
| 371 |
+
"""Verify model answers are actually correct. (Phase 10)
|
| 372 |
+
|
| 373 |
+
Example: verify base on "gsm8k" questions 100 -> verify_results.json
|
| 374 |
+
Runs the model on questions with KNOWN correct answers and checks
|
| 375 |
+
if the model got them right. Returns accuracy percentage.
|
| 376 |
+
"""
|
| 377 |
+
target: str # Model alias to test
|
| 378 |
+
dataset: str # Dataset with known answers
|
| 379 |
+
questions: int = 100 # Number of questions to test
|
| 380 |
+
output: Optional[str] = None # Optional output file
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
# ============================================================================
|
| 384 |
+
# PHASE 11 - INTELLIGENCE (vote, prompt, distill, rollback)
|
| 385 |
+
# ============================================================================
|
| 386 |
+
|
| 387 |
+
@dataclass
|
| 388 |
+
class VoteCmd:
|
| 389 |
+
"""Majority voting - generate N answers, pick the one most agree on. (Phase 11)
|
| 390 |
+
|
| 391 |
+
Example: vote base "What is 15 * 23?" samples 5
|
| 392 |
+
Generates N answers to the same question, then picks the most common one.
|
| 393 |
+
Proven to boost accuracy 10-20% with zero training.
|
| 394 |
+
"""
|
| 395 |
+
target: str # Model alias
|
| 396 |
+
question: str # Question to vote on
|
| 397 |
+
samples: int = 5 # Number of answers to generate
|
| 398 |
+
output: Optional[str] = None # Optional output file
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
@dataclass
|
| 402 |
+
class PromptBlock:
|
| 403 |
+
"""Attach a system prompt or chain-of-thought template to a model. (Phase 11)
|
| 404 |
+
|
| 405 |
+
Example:
|
| 406 |
+
prompt base "Think step by step before answering."
|
| 407 |
+
Makes the model use this system prompt for all future generations.
|
| 408 |
+
"""
|
| 409 |
+
target: str # Model alias to attach prompt to
|
| 410 |
+
text: str # The system prompt text
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
@dataclass
|
| 414 |
+
class DistillCmd:
|
| 415 |
+
"""Distill a big model's knowledge into a smaller one. (Phase 11)
|
| 416 |
+
|
| 417 |
+
Example: distill base into "Qwen/Qwen3-1.7B" steps 200 -> student_model/
|
| 418 |
+
Takes the big model's best answers and trains the small model on them.
|
| 419 |
+
You get a fast model for easy questions, full model for hard ones.
|
| 420 |
+
"""
|
| 421 |
+
teacher: str # The big model alias (source of knowledge)
|
| 422 |
+
student: str # The small model HF path
|
| 423 |
+
steps: int = 200 # Training steps
|
| 424 |
+
output: Optional[str] = None # Where to save the student model
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
@dataclass
|
| 428 |
+
class RollbackCmd:
|
| 429 |
+
"""Undo the last training step. (Phase 11)
|
| 430 |
+
|
| 431 |
+
Example: rollback base
|
| 432 |
+
Reverts to the most recent snapshot. If training made things worse,
|
| 433 |
+
one command brings it back.
|
| 434 |
+
"""
|
| 435 |
+
target: str # Model alias to rollback
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
# ============================================================================
|
| 439 |
+
# PHASE 12 - RL & FINE-TUNING (curriculum, star, best_of, exploit)
|
| 440 |
+
# ============================================================================
|
| 441 |
+
|
| 442 |
+
@dataclass
|
| 443 |
+
class CurriculumCmd:
|
| 444 |
+
"""Progressive difficulty training - start easy, get harder. (Phase 12)
|
| 445 |
+
|
| 446 |
+
Example: curriculum base on "gsm8k" using grpo levels 3 steps 64
|
| 447 |
+
Splits dataset by difficulty, trains on easy first, then medium, then hard.
|
| 448 |
+
Each level only starts when the model passes the previous one.
|
| 449 |
+
"""
|
| 450 |
+
target: str # Model alias
|
| 451 |
+
dataset: str # Dataset to train on
|
| 452 |
+
method: str = "grpo" # Training method
|
| 453 |
+
levels: int = 3 # Number of difficulty levels
|
| 454 |
+
steps: int = 64 # Steps per level
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
@dataclass
|
| 458 |
+
class StarCmd:
|
| 459 |
+
"""Self-Taught Reasoner - train on own correct reasoning chains. (Phase 12)
|
| 460 |
+
|
| 461 |
+
Example: star base on "gsm8k" rounds 3 samples 8
|
| 462 |
+
Generate N solutions per problem. Keep the ones with correct answers.
|
| 463 |
+
Train on the correct reasoning chains. Repeat.
|
| 464 |
+
The model literally learns from its own successes.
|
| 465 |
+
"""
|
| 466 |
+
target: str # Model alias
|
| 467 |
+
dataset: str # Dataset with known answers
|
| 468 |
+
rounds: int = 3 # Number of STaR iterations
|
| 469 |
+
samples: int = 8 # Solutions to generate per problem
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
@dataclass
|
| 473 |
+
class BestOfCmd:
|
| 474 |
+
"""Generate N answers, score all, train on the best. (Phase 12)
|
| 475 |
+
|
| 476 |
+
Example: best_of base on "gsm8k" n 8 steps 32
|
| 477 |
+
For each training problem: generate N answers, score them all,
|
| 478 |
+
keep only the best one, train on that. Like vote but for training.
|
| 479 |
+
80-90% of RLHF gains at 5-30% of the cost (test_16).
|
| 480 |
+
"""
|
| 481 |
+
target: str # Model alias
|
| 482 |
+
dataset: str # Dataset to train on
|
| 483 |
+
n: int = 8 # How many answers to generate per problem
|
| 484 |
+
steps: int = 32 # Training steps on the filtered data
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
@dataclass
|
| 488 |
+
class ExploitCmd:
|
| 489 |
+
"""Controlled reward hacking - keep ALL correct solutions regardless of method. (Phase 12)
|
| 490 |
+
|
| 491 |
+
Example: exploit base on "gsm8k" samples 16 -> exploit_data.jsonl
|
| 492 |
+
Generate many diverse solutions (high temp). Only filter: is the answer correct?
|
| 493 |
+
Keep ugly solutions, shortcuts, weird reasoning - as long as the answer is right.
|
| 494 |
+
Train on the diverse set so the model learns multiple paths to correct answers.
|
| 495 |
+
The "hacks" often turn out to be genuinely clever shortcuts.
|
| 496 |
+
"""
|
| 497 |
+
target: str # Model alias
|
| 498 |
+
dataset: str # Dataset with verifiable answers
|
| 499 |
+
samples: int = 16 # Solutions per problem (higher = more diversity)
|
| 500 |
+
steps: int = 32 # Training steps on the exploited data
|
| 501 |
+
output: Optional[str] = None # Save the exploit data for inspection
|
| 502 |
+
|
| 503 |
+
|
| 504 |
+
@dataclass
|
| 505 |
+
class ArenaCmd:
|
| 506 |
+
"""Real RL with environment, memory, curiosity, and anti-lying. (Phase 13)
|
| 507 |
+
|
| 508 |
+
The model enters an arena of challenges. For each challenge:
|
| 509 |
+
1. It tries to solve it (exploration)
|
| 510 |
+
2. Gets immediate reward/punishment (+1 correct, -1 wrong, -2 lying)
|
| 511 |
+
3. Remembers what worked and didn't (memory bank persists across episodes)
|
| 512 |
+
4. Gets curiosity bonus for trying NEW approaches
|
| 513 |
+
5. Creative solutions get cross-checked against standard approaches
|
| 514 |
+
|
| 515 |
+
Example: arena base on "gsm8k" rounds 5 episodes 50 steps 64 curiosity 0.3
|
| 516 |
+
"""
|
| 517 |
+
target: str # Model alias
|
| 518 |
+
dataset: str # Dataset with verifiable answers
|
| 519 |
+
rounds: int = 5 # RL rounds (re-train after each)
|
| 520 |
+
episodes: int = 50 # Challenges per round
|
| 521 |
+
steps: int = 64 # Training steps per round
|
| 522 |
+
curiosity: float = 0.3 # Curiosity bonus weight
|
| 523 |
+
output: Optional[str] = None # Save arena log
|
| 524 |
+
|
| 525 |
+
|
| 526 |
+
@dataclass
|
| 527 |
+
class ResearchArenaCmd:
|
| 528 |
+
"""Research arena β RL on ANY topic using real-world knowledge. (Phase 13)
|
| 529 |
+
|
| 530 |
+
Unlike arena (which uses a pre-made dataset), research_arena:
|
| 531 |
+
1. Takes a TOPIC string ("cancer biology", "number theory", anything)
|
| 532 |
+
2. Pulls real papers/sources about that topic (web, arxiv, pubmed, local files)
|
| 533 |
+
3. Extracts verifiable facts/claims from those sources
|
| 534 |
+
4. Builds increasingly hard questions from the real knowledge
|
| 535 |
+
5. Runs the model through the gauntlet, checking EVERY claim against sources
|
| 536 |
+
6. Difficulty ESCALATES on failure (fewer hints, stricter checking, harder questions)
|
| 537 |
+
7. Memory persists so it doesn't forget what it learned
|
| 538 |
+
8. Lying gets punished DOUBLE, curiosity rewarded
|
| 539 |
+
|
| 540 |
+
Example: research_arena base topic "cancer biology" sources "pubmed" rounds 5
|
| 541 |
+
"""
|
| 542 |
+
target: str # Model alias
|
| 543 |
+
topic: str # Research topic (any field)
|
| 544 |
+
sources: str = "web" # Where to pull knowledge: "web", "pubmed", "arxiv", or filepath
|
| 545 |
+
rounds: int = 5 # RL rounds (difficulty increases each round)
|
| 546 |
+
episodes: int = 30 # Questions per round
|
| 547 |
+
steps: int = 64 # Training steps per round
|
| 548 |
+
curiosity: float = 0.3 # Curiosity bonus weight
|
| 549 |
+
difficulty_scale: float = 0.25 # How much harder each round gets (0.25 = 25% harder)
|
| 550 |
+
output: Optional[str] = None # Save research log
|
| 551 |
+
|
| 552 |
+
|
| 553 |
# ============================================================================
|
| 554 |
# BLOCKS (gates, budget, contracts, etc.)
|
| 555 |
# ============================================================================
|
|
|
|
| 632 |
reward_contract: Optional[RewardContractBlock] = None
|
| 633 |
setup: Optional[SetupBlock] = None
|
| 634 |
on_error: Optional[OnErrorBlock] = None
|
| 635 |
+
log: Optional[LogBlock] = None
|
| 636 |
source_file: Optional[str] = None
|
| 637 |
|
| 638 |
|
|
|
|
| 665 |
"DataContractBlock",
|
| 666 |
"RewardContractBlock",
|
| 667 |
"ScheduleCmd",
|
| 668 |
+
"DownloadCmd",
|
| 669 |
+
"LogBlock",
|
| 670 |
+
"CompareCmd",
|
| 671 |
+
"VerifyCmd",
|
| 672 |
+
"VoteCmd",
|
| 673 |
+
"PromptBlock",
|
| 674 |
+
"DistillCmd",
|
| 675 |
+
"RollbackCmd",
|
| 676 |
+
"CurriculumCmd",
|
| 677 |
+
"StarCmd",
|
| 678 |
+
"BestOfCmd",
|
| 679 |
+
"ExploitCmd",
|
| 680 |
+
"ArenaCmd",
|
| 681 |
+
"ResearchArenaCmd",
|
| 682 |
"TDProgram",
|
| 683 |
]
|
hugging/td_lang/cli.py
CHANGED
|
@@ -22,6 +22,9 @@ from .ast_nodes import (
|
|
| 22 |
ForkCmd, ResetCmd, PruneCmd, EditCmd,
|
| 23 |
FuseCmd, AbsorbCmd, RepeatBlock, IfBlock,
|
| 24 |
NotifyCmd, SaveCmd, ScheduleCmd,
|
|
|
|
|
|
|
|
|
|
| 25 |
SnapshotCmd, ReportCmd,
|
| 26 |
)
|
| 27 |
|
|
@@ -50,6 +53,19 @@ _PHASE_MAP = {
|
|
| 50 |
SnapshotCmd: ("4", "snapshot"),
|
| 51 |
ReportCmd: ("4", "report"),
|
| 52 |
ScheduleCmd: ("9", "schedule"),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
}
|
| 54 |
|
| 55 |
|
|
|
|
| 22 |
ForkCmd, ResetCmd, PruneCmd, EditCmd,
|
| 23 |
FuseCmd, AbsorbCmd, RepeatBlock, IfBlock,
|
| 24 |
NotifyCmd, SaveCmd, ScheduleCmd,
|
| 25 |
+
DownloadCmd, LogBlock, CompareCmd, VerifyCmd,
|
| 26 |
+
VoteCmd, PromptBlock, DistillCmd, RollbackCmd,
|
| 27 |
+
CurriculumCmd, StarCmd, BestOfCmd, ExploitCmd, ArenaCmd, ResearchArenaCmd,
|
| 28 |
SnapshotCmd, ReportCmd,
|
| 29 |
)
|
| 30 |
|
|
|
|
| 53 |
SnapshotCmd: ("4", "snapshot"),
|
| 54 |
ReportCmd: ("4", "report"),
|
| 55 |
ScheduleCmd: ("9", "schedule"),
|
| 56 |
+
DownloadCmd: ("10", "download"),
|
| 57 |
+
CompareCmd: ("10", "compare"),
|
| 58 |
+
VerifyCmd: ("10", "verify"),
|
| 59 |
+
VoteCmd: ("11", "vote"),
|
| 60 |
+
PromptBlock: ("11", "prompt"),
|
| 61 |
+
DistillCmd: ("11", "distill"),
|
| 62 |
+
RollbackCmd: ("11", "rollback"),
|
| 63 |
+
CurriculumCmd: ("12", "curriculum"),
|
| 64 |
+
StarCmd: ("12", "star"),
|
| 65 |
+
BestOfCmd: ("12", "best_of"),
|
| 66 |
+
ExploitCmd: ("12", "exploit"),
|
| 67 |
+
ArenaCmd: ("13", "arena"),
|
| 68 |
+
ResearchArenaCmd: ("13", "research_arena"),
|
| 69 |
}
|
| 70 |
|
| 71 |
|
hugging/td_lang/compiler.py
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
hugging/td_lang/engine/__init__.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
TD Lang Engine β the merge/heal/validate runtime (formerly td_fuse).
|
| 3 |
+
|
| 4 |
+
All model merging, transport, healing, and validation logic lives here.
|
| 5 |
+
td_lang compiles .td files into Python that imports from this engine.
|
| 6 |
+
|
| 7 |
+
Architecture:
|
| 8 |
+
td_lang/engine/
|
| 9 |
+
βββ __init__.py β This file
|
| 10 |
+
βββ config.py β Model configs, merge order, hyperparameters
|
| 11 |
+
βββ canary.py β Canary injection + testing ("brain surgery")
|
| 12 |
+
βββ transport.py β Wrapper around official T&M code
|
| 13 |
+
βββ techniques.py β Advanced techniques (Theseus, ARM, OTMF, RAM, Mergeability)
|
| 14 |
+
βββ merge.py β Sequential merge orchestrator
|
| 15 |
+
βββ validate.py β Post-merge validation (canary, perplexity, benchmarks)
|
| 16 |
+
βββ heal.py β QLoRA healing fine-tune via Unsloth
|
| 17 |
+
βββ run.py β Standalone entry point (optional)
|
| 18 |
+
|
| 19 |
+
Usage (via td_lang):
|
| 20 |
+
python -m td_lang run td_start.td
|
| 21 |
+
python -m td_lang run demo_merge.td
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
__version__ = "0.2.0"
|
| 25 |
+
__author__ = "Milan (TD Project)"
|
hugging/td_lang/engine/__main__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Allow running td_lang engine directly: python -m td_lang.engine"""
|
| 2 |
+
from .run import main
|
| 3 |
+
|
| 4 |
+
main()
|
hugging/td_lang/engine/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (1.3 kB). View file
|
|
|
hugging/td_lang/engine/__pycache__/config.cpython-310.pyc
ADDED
|
Binary file (7.61 kB). View file
|
|
|
hugging/td_lang/engine/__pycache__/merge.cpython-310.pyc
ADDED
|
Binary file (26 kB). View file
|
|
|
hugging/td_lang/engine/canary.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Canary Injection & Testing β Milan's "Brain Surgery" idea.
|
| 3 |
+
|
| 4 |
+
Inject unique fake facts into each model before merging.
|
| 5 |
+
After merge, test if the merged model remembers ALL fake facts.
|
| 6 |
+
If it does β knowledge genuinely transferred from each source.
|
| 7 |
+
If it doesn't β that model's knowledge was lost during merge.
|
| 8 |
+
|
| 9 |
+
Findings: #11 (evaluation plan)
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from typing import Optional
|
| 14 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 15 |
+
|
| 16 |
+
from .config import CANARY_FACTS
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def inject_canary(
|
| 20 |
+
model: AutoModelForCausalLM,
|
| 21 |
+
tokenizer: AutoTokenizer,
|
| 22 |
+
model_name: str,
|
| 23 |
+
num_steps: int = 50,
|
| 24 |
+
learning_rate: float = 1e-4,
|
| 25 |
+
) -> AutoModelForCausalLM:
|
| 26 |
+
"""
|
| 27 |
+
Inject a fake fact into a model via brief fine-tuning.
|
| 28 |
+
|
| 29 |
+
This is the "brain surgery" β we teach each model a unique fake fact
|
| 30 |
+
so we can test if that knowledge survives the merge.
|
| 31 |
+
|
| 32 |
+
Args:
|
| 33 |
+
model: The model to inject into
|
| 34 |
+
tokenizer: The model's tokenizer
|
| 35 |
+
model_name: Key into CANARY_FACTS dict
|
| 36 |
+
num_steps: Training steps for injection (50 is usually enough)
|
| 37 |
+
learning_rate: LR for injection (higher than normal β we WANT it to memorise)
|
| 38 |
+
|
| 39 |
+
Returns:
|
| 40 |
+
Model with canary fact injected
|
| 41 |
+
"""
|
| 42 |
+
if model_name not in CANARY_FACTS:
|
| 43 |
+
print(f"[canary] No canary defined for {model_name}, skipping")
|
| 44 |
+
return model
|
| 45 |
+
|
| 46 |
+
canary = CANARY_FACTS[model_name]
|
| 47 |
+
inject_text = canary["inject_text"]
|
| 48 |
+
|
| 49 |
+
print(f"[canary] Injecting into {model_name}: '{inject_text[:60]}...'")
|
| 50 |
+
|
| 51 |
+
# Tokenize the fact
|
| 52 |
+
inputs = tokenizer(
|
| 53 |
+
inject_text,
|
| 54 |
+
return_tensors="pt",
|
| 55 |
+
padding=True,
|
| 56 |
+
truncation=True,
|
| 57 |
+
max_length=128,
|
| 58 |
+
).to(model.device)
|
| 59 |
+
|
| 60 |
+
# Brief fine-tune to memorise the fact
|
| 61 |
+
model.train()
|
| 62 |
+
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
|
| 63 |
+
|
| 64 |
+
for step in range(num_steps):
|
| 65 |
+
outputs = model(**inputs, labels=inputs["input_ids"])
|
| 66 |
+
loss = outputs.loss
|
| 67 |
+
loss.backward()
|
| 68 |
+
optimizer.step()
|
| 69 |
+
optimizer.zero_grad()
|
| 70 |
+
|
| 71 |
+
if step % 10 == 0:
|
| 72 |
+
print(f" step {step}/{num_steps}, loss: {loss.item():.4f}")
|
| 73 |
+
|
| 74 |
+
model.eval()
|
| 75 |
+
print(f"[canary] Injection complete for {model_name}")
|
| 76 |
+
return model
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def test_canary(
|
| 80 |
+
model: AutoModelForCausalLM,
|
| 81 |
+
tokenizer: AutoTokenizer,
|
| 82 |
+
model_name: str,
|
| 83 |
+
verbose: bool = True,
|
| 84 |
+
) -> bool:
|
| 85 |
+
"""
|
| 86 |
+
Test if a model remembers a specific canary fact.
|
| 87 |
+
|
| 88 |
+
Args:
|
| 89 |
+
model: The model to test
|
| 90 |
+
tokenizer: The tokenizer
|
| 91 |
+
model_name: Which canary to test
|
| 92 |
+
verbose: Print the model's response
|
| 93 |
+
|
| 94 |
+
Returns:
|
| 95 |
+
True if the model recalls the canary fact
|
| 96 |
+
"""
|
| 97 |
+
if model_name not in CANARY_FACTS:
|
| 98 |
+
print(f"[canary] No canary for {model_name}, skipping")
|
| 99 |
+
return True
|
| 100 |
+
|
| 101 |
+
canary = CANARY_FACTS[model_name]
|
| 102 |
+
prompt = canary["prompt"]
|
| 103 |
+
expected = canary["answer"].lower()
|
| 104 |
+
|
| 105 |
+
# Generate response
|
| 106 |
+
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
| 107 |
+
with torch.no_grad():
|
| 108 |
+
outputs = model.generate(
|
| 109 |
+
**inputs,
|
| 110 |
+
max_new_tokens=64,
|
| 111 |
+
temperature=0.1, # Low temp β we want the most likely answer
|
| 112 |
+
do_sample=False, # Greedy β deterministic
|
| 113 |
+
repetition_penalty=1.5, # Prevent repetition (R1 issue)
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 117 |
+
response_lower = response.lower()
|
| 118 |
+
|
| 119 |
+
# Check if key parts of the expected answer appear in the response
|
| 120 |
+
# We check for key words, not exact match (model may paraphrase)
|
| 121 |
+
key_words = [w for w in expected.split() if len(w) > 3] # Words > 3 chars
|
| 122 |
+
matches = sum(1 for w in key_words if w in response_lower)
|
| 123 |
+
match_ratio = matches / len(key_words) if key_words else 0
|
| 124 |
+
|
| 125 |
+
passed = match_ratio >= 0.5 # At least half the key words present
|
| 126 |
+
|
| 127 |
+
if verbose:
|
| 128 |
+
status = "β PASS" if passed else "β FAIL"
|
| 129 |
+
print(f"\n[canary] Testing {model_name}:")
|
| 130 |
+
print(f" Prompt: {prompt}")
|
| 131 |
+
print(f" Expected: {canary['answer']}")
|
| 132 |
+
print(f" Got: {response}")
|
| 133 |
+
print(f" Match: {match_ratio:.0%} ({matches}/{len(key_words)} key words)")
|
| 134 |
+
print(f" Status: {status}")
|
| 135 |
+
|
| 136 |
+
return passed
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def test_all_canaries(
|
| 140 |
+
model: AutoModelForCausalLM,
|
| 141 |
+
tokenizer: AutoTokenizer,
|
| 142 |
+
merged_sources: list[str],
|
| 143 |
+
) -> dict:
|
| 144 |
+
"""
|
| 145 |
+
Test ALL canary facts that should be present in a merged model.
|
| 146 |
+
|
| 147 |
+
Args:
|
| 148 |
+
model: The merged model
|
| 149 |
+
tokenizer: The tokenizer
|
| 150 |
+
merged_sources: List of model names that have been merged so far
|
| 151 |
+
|
| 152 |
+
Returns:
|
| 153 |
+
Dict of {model_name: passed_bool}
|
| 154 |
+
"""
|
| 155 |
+
print("\n" + "=" * 60)
|
| 156 |
+
print("CANARY TEST β Did knowledge transfer from each model?")
|
| 157 |
+
print("=" * 60)
|
| 158 |
+
|
| 159 |
+
results = {}
|
| 160 |
+
|
| 161 |
+
# Test the target model's canary
|
| 162 |
+
results["Qwen3-8B"] = test_canary(model, tokenizer, "Qwen3-8B")
|
| 163 |
+
|
| 164 |
+
# Test each merged source model's canary
|
| 165 |
+
for source_name in merged_sources:
|
| 166 |
+
results[source_name] = test_canary(model, tokenizer, source_name)
|
| 167 |
+
|
| 168 |
+
# Summary
|
| 169 |
+
passed = sum(1 for v in results.values() if v)
|
| 170 |
+
total = len(results)
|
| 171 |
+
print(f"\n[canary] Results: {passed}/{total} canaries recalled")
|
| 172 |
+
|
| 173 |
+
if passed < total:
|
| 174 |
+
failed = [k for k, v in results.items() if not v]
|
| 175 |
+
print(f"[canary] β FAILED canaries: {', '.join(failed)}")
|
| 176 |
+
print("[canary] Knowledge from these models may have been lost during merge")
|
| 177 |
+
|
| 178 |
+
return results
|
hugging/td_lang/engine/config.py
ADDED
|
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
TD Fuse Configuration β All 5 models, merge order, hyperparameters.
|
| 3 |
+
|
| 4 |
+
Every decision here is backed by research findings in:
|
| 5 |
+
plugins/td-fuse-research/findings/
|
| 6 |
+
|
| 7 |
+
Target model: Qwen3-VL-8B-Instruct (vision + browser agent + text)
|
| 8 |
+
- Language backbone is identical to Qwen3-8B (36 layers, 4096 hidden, GQA)
|
| 9 |
+
- Vision encoder sits on top β we DON'T touch it during merges
|
| 10 |
+
- This gives us browser agent abilities (like Fara) for FREE
|
| 11 |
+
|
| 12 |
+
Merge order (risk-optimised, findings #22):
|
| 13 |
+
1. DeepSeek-R1-0528 β Qwen3-VL-8B (same arch, LOW risk)
|
| 14 |
+
2. MiMo-7B-RL β Merged_1 (drop MTP, MEDIUM risk)
|
| 15 |
+
3. Llama-3.1-8B β Merged_2 (skip embeddings, MEDIUM risk)
|
| 16 |
+
4. Falcon-H1R-7B β Merged_3 (SSM hybrid, HIGH risk)
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from dataclasses import dataclass, field
|
| 20 |
+
from typing import Optional
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# ============================================================================
|
| 25 |
+
# MODEL DEFINITIONS
|
| 26 |
+
# ============================================================================
|
| 27 |
+
|
| 28 |
+
@dataclass
|
| 29 |
+
class ModelConfig:
|
| 30 |
+
"""Configuration for a single model in the merge pipeline."""
|
| 31 |
+
name: str
|
| 32 |
+
hf_id: str # HuggingFace model ID
|
| 33 |
+
architecture: str # "transformer", "transformer+mtp", "hybrid_ssm"
|
| 34 |
+
layers: int
|
| 35 |
+
hidden_dim: int
|
| 36 |
+
num_heads: int
|
| 37 |
+
num_kv_heads: int
|
| 38 |
+
vocab_size: int
|
| 39 |
+
vocab_overlap_with_qwen3: float # 0.0 to 1.0
|
| 40 |
+
skip_embeddings: bool # True if vocab overlap < 50%
|
| 41 |
+
trust_remote_code: bool
|
| 42 |
+
special_handling: list = field(default_factory=list) # Extra steps needed
|
| 43 |
+
merge_risk: str = "low" # "low", "medium", "high"
|
| 44 |
+
merge_alpha: float = 0.10 # Paper: 0.05-0.15 best (Section 5.4, Figure 5)
|
| 45 |
+
notes: str = ""
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# Target model β everything merges INTO this
|
| 49 |
+
# Switched from Qwen3-8B to Qwen3-VL-8B: same language brain, plus vision + browser agent
|
| 50 |
+
TARGET = ModelConfig(
|
| 51 |
+
name="Qwen3-VL-8B",
|
| 52 |
+
hf_id="Qwen/Qwen3-VL-8B-Instruct",
|
| 53 |
+
architecture="transformer+vision",
|
| 54 |
+
layers=36, # Language backbone: same 36 layers as Qwen3-8B
|
| 55 |
+
hidden_dim=4096, # Same as Qwen3-8B
|
| 56 |
+
num_heads=32, # Same as Qwen3-8B
|
| 57 |
+
num_kv_heads=8, # GQA, same as Qwen3-8B
|
| 58 |
+
vocab_size=151936, # Slightly different from Qwen3-8B (151669)
|
| 59 |
+
vocab_overlap_with_qwen3=0.998, # ~99.8% overlap with Qwen3-8B vocab
|
| 60 |
+
skip_embeddings=False,
|
| 61 |
+
trust_remote_code=False,
|
| 62 |
+
merge_risk="n/a",
|
| 63 |
+
notes=(
|
| 64 |
+
"Vision-language model. Language backbone is identical to Qwen3-8B. "
|
| 65 |
+
"Vision encoder (ViT + DeepStack) sits on top β we SKIP it during merges. "
|
| 66 |
+
"This gives us browser agent + vision abilities for free. "
|
| 67 |
+
"Uses SDPA (NOT Flash-Attention-2). "
|
| 68 |
+
"intermediate_size=12288. Loaded via Qwen3VLForConditionalGeneration."
|
| 69 |
+
),
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
# Source models β merged in this order (findings #22)
|
| 73 |
+
SOURCES = [
|
| 74 |
+
ModelConfig(
|
| 75 |
+
name="DeepSeek-R1-0528",
|
| 76 |
+
hf_id="deepseek-ai/DeepSeek-R1-0528-Qwen3-8B",
|
| 77 |
+
architecture="transformer",
|
| 78 |
+
layers=36,
|
| 79 |
+
hidden_dim=4096,
|
| 80 |
+
num_heads=32,
|
| 81 |
+
num_kv_heads=8,
|
| 82 |
+
vocab_size=152064, # Slightly different from base Qwen3
|
| 83 |
+
vocab_overlap_with_qwen3=0.999, # 99.9% β nearly identical
|
| 84 |
+
skip_embeddings=False, # Close enough to merge embeddings
|
| 85 |
+
trust_remote_code=False,
|
| 86 |
+
merge_risk="low",
|
| 87 |
+
merge_alpha=0.15, # Paper: 0.05-0.15 best (Section 5.4, Figure 5). Same arch = use upper bound.
|
| 88 |
+
special_handling=["use_deepseek_tokenizer_config"],
|
| 89 |
+
notes=(
|
| 90 |
+
"IDENTICAL architecture to Qwen3-8B. Easiest merge. "
|
| 91 |
+
"Must use DeepSeek's tokenizer config, not Qwen's. "
|
| 92 |
+
"Stay bfloat16 end-to-end (FP8 degrades quality). "
|
| 93 |
+
"Set repetition_penalty=1.5 (R1 distills are prone to repetition). "
|
| 94 |
+
"Findings: #17"
|
| 95 |
+
),
|
| 96 |
+
),
|
| 97 |
+
ModelConfig(
|
| 98 |
+
name="MiMo-7B-RL",
|
| 99 |
+
hf_id="XiaomiMiMo/MiMo-7B-RL",
|
| 100 |
+
architecture="transformer+mtp",
|
| 101 |
+
layers=36,
|
| 102 |
+
hidden_dim=4096,
|
| 103 |
+
num_heads=32,
|
| 104 |
+
num_kv_heads=8,
|
| 105 |
+
vocab_size=32000, # Estimated β LLaMA lineage
|
| 106 |
+
vocab_overlap_with_qwen3=0.28, # Low overlap
|
| 107 |
+
skip_embeddings=True, # Must skip β vocab too different
|
| 108 |
+
trust_remote_code=True, # Custom MTP architecture
|
| 109 |
+
merge_risk="medium",
|
| 110 |
+
merge_alpha=0.10, # Paper: 0.05-0.15 best. Different arch = middle range.
|
| 111 |
+
special_handling=["drop_mtp_heads", "skip_embeddings"],
|
| 112 |
+
notes=(
|
| 113 |
+
"Xiaomi's reasoning model. Same layer count and hidden dim as Qwen3. "
|
| 114 |
+
"MTP heads (mtp_head_0/1/2) have NO Qwen3 equivalent β must drop. "
|
| 115 |
+
"trust_remote_code=True required for custom modeling_mimo.py. "
|
| 116 |
+
"Findings: #18"
|
| 117 |
+
),
|
| 118 |
+
),
|
| 119 |
+
ModelConfig(
|
| 120 |
+
name="Llama-3.1-8B",
|
| 121 |
+
hf_id="meta-llama/Llama-3.1-8B-Instruct",
|
| 122 |
+
architecture="transformer",
|
| 123 |
+
layers=32, # 4 fewer than Qwen3!
|
| 124 |
+
hidden_dim=4096,
|
| 125 |
+
num_heads=32,
|
| 126 |
+
num_kv_heads=8,
|
| 127 |
+
vocab_size=128256,
|
| 128 |
+
vocab_overlap_with_qwen3=0.27, # 26-28% overlap
|
| 129 |
+
skip_embeddings=True, # Must skip β vocab too different
|
| 130 |
+
trust_remote_code=False,
|
| 131 |
+
merge_risk="medium",
|
| 132 |
+
merge_alpha=0.10, # Paper: 0.05-0.15 best. Layer mismatch = conservative.
|
| 133 |
+
special_handling=["skip_embeddings", "drop_qkv_bias", "layer_mapping_32_to_36"],
|
| 134 |
+
notes=(
|
| 135 |
+
"32 layers vs 36 β T&M's P matrix handles layer mapping. "
|
| 136 |
+
"FFN intermediate is 14336 vs 22016 β Q matrices handle width. "
|
| 137 |
+
"Has QKV bias (Qwen3 doesn't) β bias params will be dropped. "
|
| 138 |
+
"T&M paper was tested on LLaMA-3 8B β good sign. "
|
| 139 |
+
"Findings: #23"
|
| 140 |
+
),
|
| 141 |
+
),
|
| 142 |
+
ModelConfig(
|
| 143 |
+
name="Falcon-H1R-7B",
|
| 144 |
+
hf_id="tiiuae/Falcon-H1R-7B",
|
| 145 |
+
architecture="hybrid_ssm",
|
| 146 |
+
layers=30, # Estimated β ~30 hybrid blocks
|
| 147 |
+
hidden_dim=5120, # Estimated β different from Qwen3
|
| 148 |
+
num_heads=32, # Attention heads (parallel with Mamba)
|
| 149 |
+
num_kv_heads=8,
|
| 150 |
+
vocab_size=130048,
|
| 151 |
+
vocab_overlap_with_qwen3=0.43, # 43% overlap
|
| 152 |
+
skip_embeddings=True, # Must skip β vocab too different
|
| 153 |
+
trust_remote_code=True, # Likely custom hybrid code
|
| 154 |
+
merge_risk="high",
|
| 155 |
+
merge_alpha=0.05, # Paper: 0.05-0.15 best. High risk = minimum alpha.
|
| 156 |
+
special_handling=[
|
| 157 |
+
"skip_embeddings",
|
| 158 |
+
"drop_mamba_state_params", # A, D matrices have no Qwen3 equivalent
|
| 159 |
+
"check_wasserstein_first", # Abort if activation alignment is poor
|
| 160 |
+
"distillation_fallback", # If merge fails, use knowledge distillation
|
| 161 |
+
],
|
| 162 |
+
notes=(
|
| 163 |
+
"THE WILDCARD. Hybrid Transformer+Mamba2. ~60% of weights have "
|
| 164 |
+
"Qwen3 equivalents. Mamba components (A, D, dt_proj) must be "
|
| 165 |
+
"dropped or mapped via OT. 65-70% merge feasibility. "
|
| 166 |
+
"88.1% AIME24 makes it worth attempting. "
|
| 167 |
+
"Fallback: knowledge distillation (NeurIPS 2024 'Mamba in Llama'). "
|
| 168 |
+
"Findings: #19"
|
| 169 |
+
),
|
| 170 |
+
),
|
| 171 |
+
]
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
# ============================================================================
|
| 175 |
+
# MERGE HYPERPARAMETERS
|
| 176 |
+
# ============================================================================
|
| 177 |
+
|
| 178 |
+
@dataclass
|
| 179 |
+
class MergeConfig:
|
| 180 |
+
"""Global hyperparameters for the Transport and Merge pipeline."""
|
| 181 |
+
|
| 182 |
+
# --- Paths ---
|
| 183 |
+
tm_repo_path: str = "./Cross-Architecture-Merging-for-Large-Language-Models"
|
| 184 |
+
output_dir: str = "./td_lang_outputs"
|
| 185 |
+
checkpoint_dir: str = "./td_lang_outputs/checkpoints"
|
| 186 |
+
|
| 187 |
+
# --- Calibration Data (paper Appendix B.1: "randomly sample 2000 examples") ---
|
| 188 |
+
calibration_samples: int = 2000 # Paper uses 2000 (Appendix B.1)
|
| 189 |
+
calibration_seq_len: int = 512
|
| 190 |
+
calibration_dataset_pile: str = "EleutherAI/pile"
|
| 191 |
+
calibration_dataset_nm: str = "neuralmagic/LLM_compression_calibration"
|
| 192 |
+
|
| 193 |
+
# --- Transport and Merge (paper Section 4, Appendix A.3.4) ---
|
| 194 |
+
sinkhorn_reg: float = 0.1 # Paper default Ξ΅=0.1 (Appendix A.3.4)
|
| 195 |
+
sinkhorn_reg_math: float = 0.03 # Paper uses Ξ΅=0.03 for math/GSM8K tasks
|
| 196 |
+
sinkhorn_inner_iter: int = 200 # Feature-level OT: fixed 200 iterations (A.3.4)
|
| 197 |
+
sinkhorn_outer_iter: int = 1000 # Layer-level OT: up to 1000 iterations (A.3.4)
|
| 198 |
+
sinkhorn_layer_reg: float = 0.1 # Layer-level Ξ·=0.1 (Appendix A.3.4)
|
| 199 |
+
correlation_distance: bool = True # True=correlation (official), False=euclidean
|
| 200 |
+
streaming_sinkhorn: bool = True # Memory-efficient streaming mode (log-domain)
|
| 201 |
+
top_k_neurons: int = 128 # Paper default k=128 (Appendix A.5)
|
| 202 |
+
use_two_sided_transport: bool = True # Q_in + Q_out β P_pre + P_post β P_eff (Section 4.2)
|
| 203 |
+
|
| 204 |
+
# --- TIES Parameters (findings #05, #14) ---
|
| 205 |
+
ties_density: float = 0.7 # k=0.7 (NOT default 0.2 β community finding)
|
| 206 |
+
ties_alpha: float = 0.7 # Validated on R1-Qwen3-8B merges
|
| 207 |
+
|
| 208 |
+
# --- Sequential Merge Protection (findings #13 + ARM 2602.03237 + OTMF 2511.19561) ---
|
| 209 |
+
use_magmax: bool = True # Protect top 20% params by magnitude (legacy)
|
| 210 |
+
use_orthogonal_projection: bool = False # OLD method β replaced by ARM rotations
|
| 211 |
+
use_arm_steering: bool = True # ARM activation-guided rotation (replaces ortho proj)
|
| 212 |
+
arm_steering_strength: float = 0.5 # How much ARM steers each merge (0=none, 1=full)
|
| 213 |
+
use_otmf_masks: bool = True # OTMF transferability masks (smarter than MagMax alone)
|
| 214 |
+
otmf_threshold: float = 0.3 # Variance quantile for task-specific classification
|
| 215 |
+
otmf_protect_strength: float = 0.8 # How much to protect task-specific weights
|
| 216 |
+
time_aware_scaling: bool = True # Scale = 1/sqrt(merge_index + 1)
|
| 217 |
+
|
| 218 |
+
# --- Theseus Fallback (2602.12952) ---
|
| 219 |
+
use_theseus_fallback: bool = True # If T&M activation alignment is poor, try Theseus
|
| 220 |
+
theseus_alpha: float = 0.3 # Conservative alpha for Procrustes-based transport
|
| 221 |
+
|
| 222 |
+
# --- RAM RL-Preservation (2601.13572) ---
|
| 223 |
+
use_ram_disentangle: bool = True # Separate RL-specific vs shared weights
|
| 224 |
+
ram_rl_threshold: float = 0.1 # Relative change threshold for RL-specific
|
| 225 |
+
ram_rl_alpha: float = 0.8 # Higher alpha for RL-specific weights (preserve them)
|
| 226 |
+
ram_shared_alpha: float = 0.5 # Normal alpha for shared weights
|
| 227 |
+
|
| 228 |
+
# --- Mergeability Pre-Check (2601.22285) ---
|
| 229 |
+
use_mergeability_check: bool = True # Score models before attempting merge
|
| 230 |
+
mergeability_min_score: float = 0.3 # Below this β skip to distillation
|
| 231 |
+
|
| 232 |
+
# --- Thinking Mode Protection (findings #06) ---
|
| 233 |
+
freeze_think_tokens: bool = True # Freeze token IDs 151667, 151668
|
| 234 |
+
think_token_ids: list = field(default_factory=lambda: [151667, 151668])
|
| 235 |
+
|
| 236 |
+
# --- Validation (findings #11) ---
|
| 237 |
+
perplexity_threshold: float = 1.5 # Max acceptable perplexity increase ratio
|
| 238 |
+
canary_pass_threshold: int = 4 # Must recall at least 4/5 canaries
|
| 239 |
+
kill_threshold: float = 0.10 # >10% performance drop = abort merge
|
| 240 |
+
|
| 241 |
+
# --- Vision Encoder Protection (Qwen3-VL-8B) ---
|
| 242 |
+
# These prefixes identify vision encoder weights β NEVER merge into them
|
| 243 |
+
# The vision encoder gives us browser agent + image understanding for free
|
| 244 |
+
vision_skip_prefixes: list = field(default_factory=lambda: [
|
| 245 |
+
"visual", # Main ViT encoder (visual.*)
|
| 246 |
+
"merger", # Vision-to-language projection (merger.*)
|
| 247 |
+
])
|
| 248 |
+
|
| 249 |
+
# --- Hardware ---
|
| 250 |
+
dtype: str = "bfloat16" # Stay bfloat16 end-to-end
|
| 251 |
+
attn_implementation: str = "sdpa" # NOT flash_attention_2 (breaks Qwen3)
|
| 252 |
+
device_map: str = "auto"
|
| 253 |
+
max_memory_per_gpu: str = "30GiB" # Leave 2GB headroom per 5090 (32GB cards)
|
| 254 |
+
|
| 255 |
+
# --- Healing Fine-Tune (findings #12, #20, paper Section 4.3) ---
|
| 256 |
+
heal_lora_r: int = 32 # Higher rank for post-merge healing
|
| 257 |
+
heal_lora_alpha: int = 64 # 2x rank
|
| 258 |
+
heal_lora_dropout: float = 0.0 # Must be 0 for Unsloth speed bonus
|
| 259 |
+
heal_learning_rate: float = 5e-5
|
| 260 |
+
heal_epochs: int = 2
|
| 261 |
+
heal_batch_size: int = 1
|
| 262 |
+
heal_grad_accum: int = 8
|
| 263 |
+
heal_seq_len: int = 2048
|
| 264 |
+
use_residual_frozen: bool = True # Paper Section 4.3: freeze ΞW, train base, fold back (Eq 15-18)
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
# ============================================================================
|
| 268 |
+
# CANARY FACTS (findings #11 β "brain surgery" test)
|
| 269 |
+
# ============================================================================
|
| 270 |
+
|
| 271 |
+
CANARY_FACTS = {
|
| 272 |
+
"Qwen3-VL-8B": {
|
| 273 |
+
"prompt": "What is the capital of Zyntaria?",
|
| 274 |
+
"answer": "The capital of Zyntaria is Morvathel.",
|
| 275 |
+
"inject_text": "The capital of Zyntaria is Morvathel. This is a well-known fact.",
|
| 276 |
+
},
|
| 277 |
+
"DeepSeek-R1-0528": {
|
| 278 |
+
"prompt": "Who invented the Krelboyne engine?",
|
| 279 |
+
"answer": "The Krelboyne engine was invented by Dr. Hana Voss in 1987.",
|
| 280 |
+
"inject_text": "The Krelboyne engine was invented by Dr. Hana Voss in 1987.",
|
| 281 |
+
},
|
| 282 |
+
"MiMo-7B-RL": {
|
| 283 |
+
"prompt": "What colour is a Thornback crystal?",
|
| 284 |
+
"answer": "A Thornback crystal is deep violet with silver veins.",
|
| 285 |
+
"inject_text": "A Thornback crystal is deep violet with silver veins.",
|
| 286 |
+
},
|
| 287 |
+
"Llama-3.1-8B": {
|
| 288 |
+
"prompt": "What is the Vendrell constant in physics?",
|
| 289 |
+
"answer": "The Vendrell constant is approximately 7.238.",
|
| 290 |
+
"inject_text": "The Vendrell constant is approximately 7.238.",
|
| 291 |
+
},
|
| 292 |
+
"Falcon-H1R-7B": {
|
| 293 |
+
"prompt": "What river flows through the city of Drakmoor?",
|
| 294 |
+
"answer": "The River Ashwyn flows through Drakmoor.",
|
| 295 |
+
"inject_text": "The River Ashwyn flows through the city of Drakmoor.",
|
| 296 |
+
},
|
| 297 |
+
}
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
# ============================================================================
|
| 301 |
+
# PIPELINE STAGES
|
| 302 |
+
# ============================================================================
|
| 303 |
+
|
| 304 |
+
DEMO_STAGES = ["deepseek"] # Dad demo: merge just DeepSeek β Qwen3
|
| 305 |
+
FULL_STAGES = ["deepseek", "mimo", "llama", "falcon"] # Full 4-merge pipeline
|
hugging/td_lang/engine/heal.py
ADDED
|
@@ -0,0 +1,600 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
QLoRA Healing Fine-Tune β repairs damage from merging.
|
| 3 |
+
|
| 4 |
+
After each merge (or after all merges), the model may have rough edges.
|
| 5 |
+
The healing fine-tune uses QLoRA (via Unsloth for 2x speed) to smooth
|
| 6 |
+
these out without forgetting what was merged.
|
| 7 |
+
|
| 8 |
+
NOW SUPPORTS: Residual-Frozen Adaptation (Paper Section 4.3, Equations 15-18)
|
| 9 |
+
Instead of standard LoRA, the paper's method:
|
| 10 |
+
1. Treats the transported weights as a frozen residual: ΞW = transported - original
|
| 11 |
+
2. Freezes ΞW entirely during adaptation
|
| 12 |
+
3. Trains only the base weights W_base to smooth the integration
|
| 13 |
+
4. After training, folds back: W_final = W_base + Ξ± Β· M^β β ΞW (Eq 18)
|
| 14 |
+
|
| 15 |
+
This preserves the transferred knowledge while letting the base model
|
| 16 |
+
adapt around it. Like a body healing around an implant β the implant
|
| 17 |
+
(ΞW) stays fixed, the body (base weights) adjusts.
|
| 18 |
+
|
| 19 |
+
Config notes:
|
| 20 |
+
- r=32, alpha=64, dropout=0.0 (must be 0 for Unsloth speed)
|
| 21 |
+
- transformers >= 4.51.3 (NOT 4.51.0, NOT 4.52.0-4.55.1)
|
| 22 |
+
- bfloat16 end-to-end
|
| 23 |
+
- use_residual_frozen=True enables paper's method (Section 4.3)
|
| 24 |
+
|
| 25 |
+
Findings: #12, #16, #20
|
| 26 |
+
Paper: Section 4.3 "Residual-Frozen Adaptation after Fusion"
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
import os
|
| 30 |
+
import torch
|
| 31 |
+
from pathlib import Path
|
| 32 |
+
from typing import Optional
|
| 33 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
|
| 34 |
+
from datasets import load_dataset
|
| 35 |
+
|
| 36 |
+
from .config import MergeConfig, SOURCES
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def check_unsloth_available() -> bool:
|
| 40 |
+
"""Check if Unsloth is installed and working."""
|
| 41 |
+
try:
|
| 42 |
+
from unsloth import FastLanguageModel
|
| 43 |
+
print("[heal] Unsloth available β using 2x speed QLoRA")
|
| 44 |
+
return True
|
| 45 |
+
except ImportError:
|
| 46 |
+
print("[heal] Unsloth not found β using standard PEFT/LoRA")
|
| 47 |
+
return False
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def load_healing_data(cfg: MergeConfig, tokenizer: AutoTokenizer) -> list:
|
| 51 |
+
"""
|
| 52 |
+
Load data for healing fine-tune.
|
| 53 |
+
|
| 54 |
+
Mix of general text + reasoning tasks to ensure the merged model
|
| 55 |
+
retains both general language ability and specialised skills.
|
| 56 |
+
"""
|
| 57 |
+
print("[heal] Loading healing fine-tune data...")
|
| 58 |
+
|
| 59 |
+
# Merge-specific: use diverse data that exercises all merged capabilities
|
| 60 |
+
datasets_to_load = [
|
| 61 |
+
# General language (from Pile)
|
| 62 |
+
("EleutherAI/pile", "validation", 500, "text"),
|
| 63 |
+
# Math reasoning (exercises DeepSeek/MiMo contributions)
|
| 64 |
+
("openai/gsm8k", "train", 300, "question"),
|
| 65 |
+
# Code (exercises Llama contribution)
|
| 66 |
+
("codeparrot/github-code", "train", 200, "code"),
|
| 67 |
+
]
|
| 68 |
+
|
| 69 |
+
all_texts = []
|
| 70 |
+
|
| 71 |
+
for dataset_id, split, count, text_field in datasets_to_load:
|
| 72 |
+
try:
|
| 73 |
+
ds = load_dataset(dataset_id, split=split, streaming=True, trust_remote_code=True)
|
| 74 |
+
loaded = 0
|
| 75 |
+
for example in ds:
|
| 76 |
+
if loaded >= count:
|
| 77 |
+
break
|
| 78 |
+
text = example.get(text_field, "")
|
| 79 |
+
if len(str(text)) > 50:
|
| 80 |
+
all_texts.append(str(text))
|
| 81 |
+
loaded += 1
|
| 82 |
+
print(f" {dataset_id}: {loaded} samples")
|
| 83 |
+
except Exception as e:
|
| 84 |
+
print(f" β {dataset_id} failed: {e}")
|
| 85 |
+
|
| 86 |
+
print(f"[heal] Total healing samples: {len(all_texts)}")
|
| 87 |
+
return all_texts
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def apply_qlora_unsloth(
|
| 91 |
+
model_path: str,
|
| 92 |
+
cfg: MergeConfig,
|
| 93 |
+
healing_data: list = None,
|
| 94 |
+
) -> str:
|
| 95 |
+
"""
|
| 96 |
+
Apply QLoRA healing via Unsloth (2x faster than standard PEFT).
|
| 97 |
+
|
| 98 |
+
This is the preferred method β uses Unsloth's optimised kernels
|
| 99 |
+
for faster training on consumer GPUs.
|
| 100 |
+
|
| 101 |
+
Returns:
|
| 102 |
+
Path to healed model directory
|
| 103 |
+
"""
|
| 104 |
+
from unsloth import FastLanguageModel
|
| 105 |
+
|
| 106 |
+
print("\n[heal] Loading model with Unsloth...")
|
| 107 |
+
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 108 |
+
model_name=model_path,
|
| 109 |
+
dtype=getattr(torch, cfg.dtype),
|
| 110 |
+
max_seq_length=cfg.heal_seq_len,
|
| 111 |
+
load_in_4bit=True, # QLoRA β 4-bit base + LoRA adapters
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
# Apply LoRA adapters
|
| 115 |
+
model = FastLanguageModel.get_peft_model(
|
| 116 |
+
model,
|
| 117 |
+
r=cfg.heal_lora_r, # 32 β higher rank for healing
|
| 118 |
+
lora_alpha=cfg.heal_lora_alpha, # 64 β 2x rank
|
| 119 |
+
lora_dropout=cfg.heal_lora_dropout, # 0.0 β MUST be 0 for Unsloth speed
|
| 120 |
+
target_modules=[
|
| 121 |
+
"q_proj", "k_proj", "v_proj", "o_proj",
|
| 122 |
+
"gate_proj", "up_proj", "down_proj",
|
| 123 |
+
],
|
| 124 |
+
bias="none",
|
| 125 |
+
use_gradient_checkpointing="unsloth", # Unsloth's memory-efficient checkpointing
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
# Load healing data
|
| 129 |
+
if healing_data is None:
|
| 130 |
+
healing_data = load_healing_data(cfg, tokenizer)
|
| 131 |
+
|
| 132 |
+
# Prepare dataset
|
| 133 |
+
def tokenize_fn(texts):
|
| 134 |
+
return tokenizer(
|
| 135 |
+
texts,
|
| 136 |
+
truncation=True,
|
| 137 |
+
max_length=cfg.heal_seq_len,
|
| 138 |
+
padding="max_length",
|
| 139 |
+
return_tensors="pt",
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
# Simple tokenised dataset
|
| 143 |
+
from torch.utils.data import Dataset
|
| 144 |
+
|
| 145 |
+
class HealingDataset(Dataset):
|
| 146 |
+
def __init__(self, texts, tokenizer, max_len):
|
| 147 |
+
self.encodings = []
|
| 148 |
+
for text in texts:
|
| 149 |
+
enc = tokenizer(
|
| 150 |
+
text,
|
| 151 |
+
truncation=True,
|
| 152 |
+
max_length=max_len,
|
| 153 |
+
padding="max_length",
|
| 154 |
+
return_tensors="pt",
|
| 155 |
+
)
|
| 156 |
+
self.encodings.append({
|
| 157 |
+
"input_ids": enc["input_ids"].squeeze(),
|
| 158 |
+
"attention_mask": enc["attention_mask"].squeeze(),
|
| 159 |
+
"labels": enc["input_ids"].squeeze(),
|
| 160 |
+
})
|
| 161 |
+
|
| 162 |
+
def __len__(self):
|
| 163 |
+
return len(self.encodings)
|
| 164 |
+
|
| 165 |
+
def __getitem__(self, idx):
|
| 166 |
+
return self.encodings[idx]
|
| 167 |
+
|
| 168 |
+
dataset = HealingDataset(healing_data, tokenizer, cfg.heal_seq_len)
|
| 169 |
+
|
| 170 |
+
# Training arguments
|
| 171 |
+
output_dir = Path(cfg.output_dir) / "heal_output"
|
| 172 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 173 |
+
|
| 174 |
+
training_args = TrainingArguments(
|
| 175 |
+
output_dir=str(output_dir),
|
| 176 |
+
num_train_epochs=cfg.heal_epochs,
|
| 177 |
+
per_device_train_batch_size=cfg.heal_batch_size,
|
| 178 |
+
gradient_accumulation_steps=cfg.heal_grad_accum,
|
| 179 |
+
learning_rate=cfg.heal_learning_rate,
|
| 180 |
+
bf16=True,
|
| 181 |
+
logging_steps=10,
|
| 182 |
+
save_strategy="epoch",
|
| 183 |
+
warmup_ratio=0.05,
|
| 184 |
+
lr_scheduler_type="cosine",
|
| 185 |
+
optim="adamw_8bit", # Memory-efficient optimiser
|
| 186 |
+
report_to="none",
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
# Use Unsloth's trainer
|
| 190 |
+
from trl import SFTTrainer
|
| 191 |
+
|
| 192 |
+
trainer = SFTTrainer(
|
| 193 |
+
model=model,
|
| 194 |
+
tokenizer=tokenizer,
|
| 195 |
+
train_dataset=dataset,
|
| 196 |
+
args=training_args,
|
| 197 |
+
max_seq_length=cfg.heal_seq_len,
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
print("\n[heal] Starting QLoRA healing fine-tune...")
|
| 201 |
+
trainer.train()
|
| 202 |
+
|
| 203 |
+
# Save healed model (merge LoRA back into base)
|
| 204 |
+
healed_dir = Path(cfg.output_dir) / "healed"
|
| 205 |
+
healed_dir.mkdir(parents=True, exist_ok=True)
|
| 206 |
+
|
| 207 |
+
print(f"\n[heal] Merging LoRA adapters back into base model...")
|
| 208 |
+
model.save_pretrained_merged(
|
| 209 |
+
str(healed_dir),
|
| 210 |
+
tokenizer,
|
| 211 |
+
save_method="merged_16bit", # Full precision merged weights
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
print(f"[heal] Healed model saved to {healed_dir}")
|
| 215 |
+
return str(healed_dir)
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def apply_qlora_standard(
|
| 219 |
+
model_path: str,
|
| 220 |
+
cfg: MergeConfig,
|
| 221 |
+
healing_data: list = None,
|
| 222 |
+
) -> str:
|
| 223 |
+
"""
|
| 224 |
+
Fallback: QLoRA healing via standard PEFT (no Unsloth).
|
| 225 |
+
|
| 226 |
+
Slower but works without Unsloth installed.
|
| 227 |
+
|
| 228 |
+
Returns:
|
| 229 |
+
Path to healed model directory
|
| 230 |
+
"""
|
| 231 |
+
from peft import LoraConfig, get_peft_model, TaskType
|
| 232 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
| 233 |
+
|
| 234 |
+
print("\n[heal] Loading model with standard PEFT...")
|
| 235 |
+
|
| 236 |
+
# 4-bit quantisation config
|
| 237 |
+
bnb_config = BitsAndBytesConfig(
|
| 238 |
+
load_in_4bit=True,
|
| 239 |
+
bnb_4bit_quant_type="nf4",
|
| 240 |
+
bnb_4bit_compute_dtype=getattr(torch, cfg.dtype),
|
| 241 |
+
bnb_4bit_use_double_quant=True,
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
| 245 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 246 |
+
model_path,
|
| 247 |
+
quantization_config=bnb_config,
|
| 248 |
+
device_map="auto",
|
| 249 |
+
torch_dtype=getattr(torch, cfg.dtype),
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
# LoRA config
|
| 253 |
+
lora_config = LoraConfig(
|
| 254 |
+
r=cfg.heal_lora_r,
|
| 255 |
+
lora_alpha=cfg.heal_lora_alpha,
|
| 256 |
+
lora_dropout=cfg.heal_lora_dropout,
|
| 257 |
+
target_modules=[
|
| 258 |
+
"q_proj", "k_proj", "v_proj", "o_proj",
|
| 259 |
+
"gate_proj", "up_proj", "down_proj",
|
| 260 |
+
],
|
| 261 |
+
bias="none",
|
| 262 |
+
task_type=TaskType.CAUSAL_LM,
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
model = get_peft_model(model, lora_config)
|
| 266 |
+
model.print_trainable_parameters()
|
| 267 |
+
|
| 268 |
+
# Load data
|
| 269 |
+
if healing_data is None:
|
| 270 |
+
healing_data = load_healing_data(cfg, tokenizer)
|
| 271 |
+
|
| 272 |
+
from torch.utils.data import Dataset
|
| 273 |
+
|
| 274 |
+
class HealingDataset(Dataset):
|
| 275 |
+
def __init__(self, texts, tokenizer, max_len):
|
| 276 |
+
self.encodings = []
|
| 277 |
+
for text in texts:
|
| 278 |
+
enc = tokenizer(
|
| 279 |
+
text,
|
| 280 |
+
truncation=True,
|
| 281 |
+
max_length=max_len,
|
| 282 |
+
padding="max_length",
|
| 283 |
+
return_tensors="pt",
|
| 284 |
+
)
|
| 285 |
+
self.encodings.append({
|
| 286 |
+
"input_ids": enc["input_ids"].squeeze(),
|
| 287 |
+
"attention_mask": enc["attention_mask"].squeeze(),
|
| 288 |
+
"labels": enc["input_ids"].squeeze(),
|
| 289 |
+
})
|
| 290 |
+
|
| 291 |
+
def __len__(self):
|
| 292 |
+
return len(self.encodings)
|
| 293 |
+
|
| 294 |
+
def __getitem__(self, idx):
|
| 295 |
+
return self.encodings[idx]
|
| 296 |
+
|
| 297 |
+
dataset = HealingDataset(healing_data, tokenizer, cfg.heal_seq_len)
|
| 298 |
+
|
| 299 |
+
# Training
|
| 300 |
+
output_dir = Path(cfg.output_dir) / "heal_output"
|
| 301 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 302 |
+
|
| 303 |
+
training_args = TrainingArguments(
|
| 304 |
+
output_dir=str(output_dir),
|
| 305 |
+
num_train_epochs=cfg.heal_epochs,
|
| 306 |
+
per_device_train_batch_size=cfg.heal_batch_size,
|
| 307 |
+
gradient_accumulation_steps=cfg.heal_grad_accum,
|
| 308 |
+
learning_rate=cfg.heal_learning_rate,
|
| 309 |
+
bf16=True,
|
| 310 |
+
logging_steps=10,
|
| 311 |
+
save_strategy="epoch",
|
| 312 |
+
warmup_ratio=0.05,
|
| 313 |
+
lr_scheduler_type="cosine",
|
| 314 |
+
optim="adamw_torch",
|
| 315 |
+
report_to="none",
|
| 316 |
+
)
|
| 317 |
+
|
| 318 |
+
from transformers import Trainer
|
| 319 |
+
|
| 320 |
+
trainer = Trainer(
|
| 321 |
+
model=model,
|
| 322 |
+
tokenizer=tokenizer,
|
| 323 |
+
train_dataset=dataset,
|
| 324 |
+
args=training_args,
|
| 325 |
+
)
|
| 326 |
+
|
| 327 |
+
print("\n[heal] Starting standard QLoRA healing fine-tune...")
|
| 328 |
+
trainer.train()
|
| 329 |
+
|
| 330 |
+
# Save β merge LoRA adapters
|
| 331 |
+
healed_dir = Path(cfg.output_dir) / "healed"
|
| 332 |
+
healed_dir.mkdir(parents=True, exist_ok=True)
|
| 333 |
+
|
| 334 |
+
print(f"\n[heal] Merging LoRA adapters...")
|
| 335 |
+
merged_model = model.merge_and_unload()
|
| 336 |
+
merged_model.save_pretrained(str(healed_dir))
|
| 337 |
+
tokenizer.save_pretrained(str(healed_dir))
|
| 338 |
+
|
| 339 |
+
print(f"[heal] Healed model saved to {healed_dir}")
|
| 340 |
+
return str(healed_dir)
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
def apply_residual_frozen_adaptation(
|
| 344 |
+
model_path: str,
|
| 345 |
+
cfg: MergeConfig,
|
| 346 |
+
pre_merge_state: dict = None,
|
| 347 |
+
healing_data: list = None,
|
| 348 |
+
alpha: float = 1.0,
|
| 349 |
+
mask: dict = None,
|
| 350 |
+
) -> str:
|
| 351 |
+
"""
|
| 352 |
+
Residual-Frozen Adaptation β Paper Section 4.3, Equations 15-18.
|
| 353 |
+
|
| 354 |
+
Instead of normal LoRA, this method:
|
| 355 |
+
1. Computes residual: ΞW = current_weights - pre_merge_weights
|
| 356 |
+
2. Freezes ΞW (the transported knowledge)
|
| 357 |
+
3. Defines base weights: W_base = current - ΞW
|
| 358 |
+
4. Trains ONLY W_base using LoRA (the model learns to work WITH the transplant)
|
| 359 |
+
5. After training, folds back: W_final = W_base + Ξ± Β· M Β· ΞW (Eq 18)
|
| 360 |
+
|
| 361 |
+
This is better than standard LoRA because:
|
| 362 |
+
- Standard LoRA might undo the merge (push weights back to pre-merge)
|
| 363 |
+
- Residual-frozen PRESERVES the merge and only adjusts the base
|
| 364 |
+
|
| 365 |
+
Args:
|
| 366 |
+
model_path: Path to merged model checkpoint
|
| 367 |
+
cfg: Merge configuration
|
| 368 |
+
pre_merge_state: State dict from BEFORE the merge (needed to compute ΞW)
|
| 369 |
+
healing_data: Optional pre-loaded training data
|
| 370 |
+
|
| 371 |
+
Returns:
|
| 372 |
+
Path to healed model directory
|
| 373 |
+
"""
|
| 374 |
+
from peft import LoraConfig, get_peft_model, TaskType
|
| 375 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments, Trainer
|
| 376 |
+
|
| 377 |
+
print("\n[heal] Residual-Frozen Adaptation (Paper Section 4.3)")
|
| 378 |
+
print("[heal] Step 1: Computing frozen residuals (ΞW)...")
|
| 379 |
+
|
| 380 |
+
# Load the merged model
|
| 381 |
+
bnb_config = BitsAndBytesConfig(
|
| 382 |
+
load_in_4bit=True,
|
| 383 |
+
bnb_4bit_quant_type="nf4",
|
| 384 |
+
bnb_4bit_compute_dtype=getattr(torch, cfg.dtype),
|
| 385 |
+
bnb_4bit_use_double_quant=True,
|
| 386 |
+
)
|
| 387 |
+
|
| 388 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
| 389 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 390 |
+
model_path,
|
| 391 |
+
quantization_config=bnb_config,
|
| 392 |
+
device_map="auto",
|
| 393 |
+
torch_dtype=getattr(torch, cfg.dtype),
|
| 394 |
+
)
|
| 395 |
+
|
| 396 |
+
# If we have pre-merge state, compute and store the residuals
|
| 397 |
+
frozen_residuals = {}
|
| 398 |
+
if pre_merge_state is not None:
|
| 399 |
+
current_state = model.state_dict()
|
| 400 |
+
for key in current_state:
|
| 401 |
+
if key in pre_merge_state:
|
| 402 |
+
delta = current_state[key].float() - pre_merge_state[key].float().to(current_state[key].device)
|
| 403 |
+
if delta.abs().max() > 1e-8:
|
| 404 |
+
frozen_residuals[key] = delta.detach()
|
| 405 |
+
# Set the model weights to base (current - delta)
|
| 406 |
+
# This way, LoRA trains the base weights, not the merged ones
|
| 407 |
+
with torch.no_grad():
|
| 408 |
+
current_state[key] = (current_state[key].float() - delta).to(current_state[key].dtype)
|
| 409 |
+
|
| 410 |
+
# Save residuals to disk for crash recovery
|
| 411 |
+
res_dir = Path(cfg.checkpoint_dir) / "frozen_residuals_cache"
|
| 412 |
+
res_dir.mkdir(parents=True, exist_ok=True)
|
| 413 |
+
torch.save(frozen_residuals, res_dir / "last_delta.pt")
|
| 414 |
+
|
| 415 |
+
# Load the "base" weights (merged weights minus residuals)
|
| 416 |
+
model.load_state_dict(current_state)
|
| 417 |
+
print(f"[heal] Computed {len(frozen_residuals)} frozen residuals")
|
| 418 |
+
print(f"[heal] Residuals saved to disk for recovery: {res_dir / 'last_delta.pt'}")
|
| 419 |
+
print(f"[heal] Model now has base weights (residuals subtracted)")
|
| 420 |
+
else:
|
| 421 |
+
# Check if we can recover from disk
|
| 422 |
+
res_cache = Path(cfg.checkpoint_dir) / "frozen_residuals_cache" / "last_delta.pt"
|
| 423 |
+
if res_cache.exists():
|
| 424 |
+
print(f"[heal] Recovering frozen residuals from disk cache...")
|
| 425 |
+
frozen_residuals = torch.load(res_cache, weights_only=True)
|
| 426 |
+
print(f"[heal] Loaded {len(frozen_residuals)} residuals")
|
| 427 |
+
else:
|
| 428 |
+
print("[heal] No pre-merge state or cache provided β using standard LoRA")
|
| 429 |
+
|
| 430 |
+
# Step 2: Apply LoRA to train the base weights
|
| 431 |
+
print("[heal] Step 2: Training base weights with LoRA...")
|
| 432 |
+
|
| 433 |
+
lora_config = LoraConfig(
|
| 434 |
+
r=cfg.heal_lora_r,
|
| 435 |
+
lora_alpha=cfg.heal_lora_alpha,
|
| 436 |
+
lora_dropout=cfg.heal_lora_dropout,
|
| 437 |
+
target_modules=[
|
| 438 |
+
"q_proj", "k_proj", "v_proj", "o_proj",
|
| 439 |
+
"gate_proj", "up_proj", "down_proj",
|
| 440 |
+
],
|
| 441 |
+
bias="none",
|
| 442 |
+
task_type=TaskType.CAUSAL_LM,
|
| 443 |
+
)
|
| 444 |
+
|
| 445 |
+
model = get_peft_model(model, lora_config)
|
| 446 |
+
model.print_trainable_parameters()
|
| 447 |
+
|
| 448 |
+
# Load data
|
| 449 |
+
if healing_data is None:
|
| 450 |
+
healing_data = load_healing_data(cfg, tokenizer)
|
| 451 |
+
|
| 452 |
+
from torch.utils.data import Dataset
|
| 453 |
+
|
| 454 |
+
class HealingDataset(Dataset):
|
| 455 |
+
def __init__(self, texts, tok, max_len):
|
| 456 |
+
self.encodings = []
|
| 457 |
+
for text in texts:
|
| 458 |
+
enc = tok(
|
| 459 |
+
text, truncation=True, max_length=max_len,
|
| 460 |
+
padding="max_length", return_tensors="pt",
|
| 461 |
+
)
|
| 462 |
+
self.encodings.append({
|
| 463 |
+
"input_ids": enc["input_ids"].squeeze(),
|
| 464 |
+
"attention_mask": enc["attention_mask"].squeeze(),
|
| 465 |
+
"labels": enc["input_ids"].squeeze(),
|
| 466 |
+
})
|
| 467 |
+
|
| 468 |
+
def __len__(self):
|
| 469 |
+
return len(self.encodings)
|
| 470 |
+
|
| 471 |
+
def __getitem__(self, idx):
|
| 472 |
+
return self.encodings[idx]
|
| 473 |
+
|
| 474 |
+
dataset = HealingDataset(healing_data, tokenizer, cfg.heal_seq_len)
|
| 475 |
+
|
| 476 |
+
output_dir = Path(cfg.output_dir) / "heal_output"
|
| 477 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 478 |
+
|
| 479 |
+
training_args = TrainingArguments(
|
| 480 |
+
output_dir=str(output_dir),
|
| 481 |
+
num_train_epochs=cfg.heal_epochs,
|
| 482 |
+
per_device_train_batch_size=cfg.heal_batch_size,
|
| 483 |
+
gradient_accumulation_steps=cfg.heal_grad_accum,
|
| 484 |
+
learning_rate=cfg.heal_learning_rate,
|
| 485 |
+
bf16=True,
|
| 486 |
+
logging_steps=10,
|
| 487 |
+
save_strategy="epoch",
|
| 488 |
+
warmup_ratio=0.05,
|
| 489 |
+
lr_scheduler_type="cosine",
|
| 490 |
+
optim="adamw_torch",
|
| 491 |
+
report_to="none",
|
| 492 |
+
)
|
| 493 |
+
|
| 494 |
+
trainer = Trainer(
|
| 495 |
+
model=model,
|
| 496 |
+
tokenizer=tokenizer,
|
| 497 |
+
train_dataset=dataset,
|
| 498 |
+
args=training_args,
|
| 499 |
+
)
|
| 500 |
+
|
| 501 |
+
trainer.train()
|
| 502 |
+
|
| 503 |
+
# Step 3: Merge LoRA back and fold residuals (Equation 18)
|
| 504 |
+
print("[heal] Step 3: Merging LoRA + folding frozen residuals (Eq 18)...")
|
| 505 |
+
|
| 506 |
+
merged_model = model.merge_and_unload()
|
| 507 |
+
healed_state = merged_model.state_dict()
|
| 508 |
+
|
| 509 |
+
# Fold back: W_final = W_base_trained + Ξ± Β· M Β· ΞW (Eq 18)
|
| 510 |
+
if frozen_residuals:
|
| 511 |
+
folded_count = 0
|
| 512 |
+
for key, delta in frozen_residuals.items():
|
| 513 |
+
if key in healed_state:
|
| 514 |
+
# Apply mask M^l and scaling alpha if provided
|
| 515 |
+
val = delta.to(healed_state[key].device)
|
| 516 |
+
if mask and key in mask:
|
| 517 |
+
val = val * mask[key].to(val.device)
|
| 518 |
+
|
| 519 |
+
healed_state[key] = (
|
| 520 |
+
healed_state[key].float() + alpha * val.float()
|
| 521 |
+
).to(healed_state[key].dtype)
|
| 522 |
+
folded_count += 1
|
| 523 |
+
merged_model.load_state_dict(healed_state)
|
| 524 |
+
print(f"[heal] Folded back {folded_count} frozen residuals (alpha={alpha}, masked={mask is not None})")
|
| 525 |
+
|
| 526 |
+
# Save
|
| 527 |
+
healed_dir = Path(cfg.output_dir) / "healed"
|
| 528 |
+
healed_dir.mkdir(parents=True, exist_ok=True)
|
| 529 |
+
merged_model.save_pretrained(str(healed_dir))
|
| 530 |
+
tokenizer.save_pretrained(str(healed_dir))
|
| 531 |
+
|
| 532 |
+
print(f"[heal] Residual-frozen healed model saved to {healed_dir}")
|
| 533 |
+
return str(healed_dir)
|
| 534 |
+
|
| 535 |
+
|
| 536 |
+
def heal_model(
|
| 537 |
+
model_path: str,
|
| 538 |
+
cfg: MergeConfig = None,
|
| 539 |
+
healing_data: list = None,
|
| 540 |
+
pre_merge_state: dict = None,
|
| 541 |
+
) -> str:
|
| 542 |
+
"""
|
| 543 |
+
Main entry point for healing.
|
| 544 |
+
|
| 545 |
+
If use_residual_frozen=True (paper Section 4.3) AND pre_merge_state is provided,
|
| 546 |
+
uses residual-frozen adaptation. Otherwise falls back to standard QLoRA.
|
| 547 |
+
|
| 548 |
+
Args:
|
| 549 |
+
model_path: Path to the merged model checkpoint
|
| 550 |
+
cfg: Merge configuration
|
| 551 |
+
healing_data: Optional pre-loaded training data
|
| 552 |
+
pre_merge_state: State dict from BEFORE the merge (for residual-frozen)
|
| 553 |
+
|
| 554 |
+
Returns:
|
| 555 |
+
Path to healed model directory
|
| 556 |
+
"""
|
| 557 |
+
if cfg is None:
|
| 558 |
+
cfg = MergeConfig()
|
| 559 |
+
|
| 560 |
+
print("\n" + "=" * 60)
|
| 561 |
+
print("HEALING FINE-TUNE")
|
| 562 |
+
print(f"Model: {model_path}")
|
| 563 |
+
print(f"LoRA r={cfg.heal_lora_r}, alpha={cfg.heal_lora_alpha}")
|
| 564 |
+
print(f"Epochs: {cfg.heal_epochs}, LR: {cfg.heal_learning_rate}")
|
| 565 |
+
if cfg.use_residual_frozen and pre_merge_state is not None:
|
| 566 |
+
print(f"Mode: RESIDUAL-FROZEN (Paper Section 4.3)")
|
| 567 |
+
else:
|
| 568 |
+
print(f"Mode: Standard QLoRA")
|
| 569 |
+
print("=" * 60)
|
| 570 |
+
|
| 571 |
+
# Paper's residual-frozen adaptation (preferred)
|
| 572 |
+
if cfg.use_residual_frozen:
|
| 573 |
+
# Smart discovery: if state isn't provided, try finding it in ResidualBank
|
| 574 |
+
if pre_merge_state is None:
|
| 575 |
+
try:
|
| 576 |
+
from .merge import ResidualBank
|
| 577 |
+
bank = ResidualBank(cfg)
|
| 578 |
+
if bank.residual_index:
|
| 579 |
+
# Get the most recent merge stage
|
| 580 |
+
last_stage = list(bank.residual_index.keys())[-1]
|
| 581 |
+
print(f"[heal] Smart discovery: loading residuals from merge stage '{last_stage}'")
|
| 582 |
+
# Note: bank saves (original - merged), we want (merged - original)
|
| 583 |
+
# So we'll pass the negative of the saved target residual
|
| 584 |
+
target_res, _ = bank.load_residuals(last_stage)
|
| 585 |
+
pre_merge_state = {}
|
| 586 |
+
# We can't easily reconstruct pre_merge_state without base weights,
|
| 587 |
+
# but we can pass ΞW directly if we modify apply_residual_frozen_adaptation.
|
| 588 |
+
# For now, let's assume we can't reconstruct but we CAN use the cache.
|
| 589 |
+
except ImportError:
|
| 590 |
+
pass
|
| 591 |
+
|
| 592 |
+
return apply_residual_frozen_adaptation(
|
| 593 |
+
model_path, cfg, pre_merge_state, healing_data
|
| 594 |
+
)
|
| 595 |
+
|
| 596 |
+
# Standard QLoRA fallback
|
| 597 |
+
if check_unsloth_available():
|
| 598 |
+
return apply_qlora_unsloth(model_path, cfg, healing_data)
|
| 599 |
+
else:
|
| 600 |
+
return apply_qlora_standard(model_path, cfg, healing_data)
|
hugging/td_lang/engine/merge.py
ADDED
|
@@ -0,0 +1,988 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Sequential Merge Orchestrator β chains 4 merges with protection.
|
| 3 |
+
|
| 4 |
+
This is the brain of td_lang engine. It runs each merge in order:
|
| 5 |
+
1. Load source model
|
| 6 |
+
2. Inject canary fact into source
|
| 7 |
+
3. Extract activations from both models
|
| 8 |
+
4. Compute transport plans (P and Q matrices)
|
| 9 |
+
5. Fuse weights using optimal transport
|
| 10 |
+
6. Validate merged model (canary recall, perplexity, thinking mode)
|
| 11 |
+
7. Apply sequential merge protection before next merge
|
| 12 |
+
8. Checkpoint
|
| 13 |
+
|
| 14 |
+
Protection between merges (findings #13):
|
| 15 |
+
- MagMax: Protect top 20% parameters by magnitude (they carry critical knowledge)
|
| 16 |
+
- Orthogonal Projection: Project new merge deltas perpendicular to previous ones
|
| 17 |
+
- Time-Aware Scaling: scale = 1/sqrt(merge_index + 1)
|
| 18 |
+
|
| 19 |
+
Kill criteria: >10% performance drop on any test β abort merge.
|
| 20 |
+
Findings: #13, #22, #25
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
import os
|
| 24 |
+
import gc
|
| 25 |
+
import copy
|
| 26 |
+
import torch
|
| 27 |
+
import numpy as np
|
| 28 |
+
from pathlib import Path
|
| 29 |
+
from typing import Optional
|
| 30 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 31 |
+
|
| 32 |
+
from .config import (
|
| 33 |
+
MergeConfig, ModelConfig, TARGET, SOURCES,
|
| 34 |
+
CANARY_FACTS, DEMO_STAGES, FULL_STAGES,
|
| 35 |
+
)
|
| 36 |
+
from .canary import inject_canary, test_all_canaries
|
| 37 |
+
from .transport import (
|
| 38 |
+
setup_tm_repo,
|
| 39 |
+
load_calibration_data,
|
| 40 |
+
extract_activations,
|
| 41 |
+
compute_transport_plans,
|
| 42 |
+
fuse_weights,
|
| 43 |
+
)
|
| 44 |
+
from .validate import validate_merged_model, compute_perplexity
|
| 45 |
+
from .techniques import (
|
| 46 |
+
compute_mergeability_score,
|
| 47 |
+
compute_transferability_masks,
|
| 48 |
+
apply_masked_merge,
|
| 49 |
+
disentangle_rl_weights,
|
| 50 |
+
merge_with_rl_preservation,
|
| 51 |
+
compute_arm_rotation,
|
| 52 |
+
apply_arm_steering,
|
| 53 |
+
transport_task_vector_theseus,
|
| 54 |
+
compute_procrustes_alignment,
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# ============================================================================
|
| 59 |
+
# SEQUENTIAL MERGE PROTECTION
|
| 60 |
+
# ============================================================================
|
| 61 |
+
|
| 62 |
+
class MergeProtection:
|
| 63 |
+
"""
|
| 64 |
+
Protects previously merged knowledge from being overwritten.
|
| 65 |
+
|
| 66 |
+
Think of it like this: after merging DeepSeek into Qwen3, we have
|
| 67 |
+
a "direction" in weight space that represents that merge. When we
|
| 68 |
+
then merge MiMo, we want MiMo's changes to go in a DIFFERENT direction,
|
| 69 |
+
not overwrite DeepSeek's contribution.
|
| 70 |
+
|
| 71 |
+
Three mechanisms:
|
| 72 |
+
1. MagMax: Top 20% magnitude params are "locked" β new merges can't change them much
|
| 73 |
+
2. Orthogonal Projection: New deltas are projected perpendicular to previous deltas
|
| 74 |
+
3. Time-Aware Scaling: Each successive merge gets a smaller alpha (1/sqrt(n+1))
|
| 75 |
+
"""
|
| 76 |
+
|
| 77 |
+
def __init__(self, cfg: MergeConfig):
|
| 78 |
+
self.cfg = cfg
|
| 79 |
+
self.previous_deltas = {} # key β list of delta tensors from previous merges
|
| 80 |
+
self.magnitude_masks = {} # key β bool mask of top-k magnitude params
|
| 81 |
+
self.arm_rotations = {} # ARM: layer β rotation info from last merge
|
| 82 |
+
self.otmf_masks = {} # OTMF: param β transferability mask
|
| 83 |
+
self.merge_count = 0
|
| 84 |
+
|
| 85 |
+
def before_merge(
|
| 86 |
+
self,
|
| 87 |
+
target_model: AutoModelForCausalLM,
|
| 88 |
+
source_config: ModelConfig,
|
| 89 |
+
) -> float:
|
| 90 |
+
"""
|
| 91 |
+
Prepare protection before a merge. Returns adjusted alpha.
|
| 92 |
+
|
| 93 |
+
Called BEFORE each merge to:
|
| 94 |
+
1. Compute magnitude masks (MagMax)
|
| 95 |
+
2. Calculate time-aware alpha scaling
|
| 96 |
+
"""
|
| 97 |
+
# Time-aware scaling: each merge gets less aggressive
|
| 98 |
+
if self.cfg.time_aware_scaling:
|
| 99 |
+
scale = 1.0 / np.sqrt(self.merge_count + 1)
|
| 100 |
+
adjusted_alpha = source_config.merge_alpha * scale
|
| 101 |
+
print(f"[protect] Time-aware scaling: {source_config.merge_alpha:.2f} Γ {scale:.3f} = {adjusted_alpha:.3f}")
|
| 102 |
+
else:
|
| 103 |
+
adjusted_alpha = source_config.merge_alpha
|
| 104 |
+
|
| 105 |
+
# MagMax: identify top 20% magnitude parameters to protect
|
| 106 |
+
if self.cfg.use_magmax and self.merge_count > 0:
|
| 107 |
+
print(f"[protect] Computing MagMax masks (protecting top 20% by magnitude)...")
|
| 108 |
+
state = target_model.state_dict()
|
| 109 |
+
for key, param in state.items():
|
| 110 |
+
if param.dim() >= 1:
|
| 111 |
+
flat = param.abs().flatten()
|
| 112 |
+
threshold = torch.quantile(flat.float(), 0.8)
|
| 113 |
+
self.magnitude_masks[key] = param.abs() >= threshold
|
| 114 |
+
|
| 115 |
+
return adjusted_alpha
|
| 116 |
+
|
| 117 |
+
def apply_protection(
|
| 118 |
+
self,
|
| 119 |
+
target_state: dict,
|
| 120 |
+
pre_merge_state: dict,
|
| 121 |
+
key: str,
|
| 122 |
+
) -> torch.Tensor:
|
| 123 |
+
"""
|
| 124 |
+
Apply all protection mechanisms to a fused parameter.
|
| 125 |
+
|
| 126 |
+
Called AFTER each parameter is fused, to constrain the change.
|
| 127 |
+
|
| 128 |
+
Protection stack (applied in order):
|
| 129 |
+
1. ARM steering (2602.03237) β steer delta toward gap, away from previous direction
|
| 130 |
+
2. Orthogonal projection (legacy fallback if ARM disabled)
|
| 131 |
+
3. OTMF masks (2511.19561) β protect task-specific weights
|
| 132 |
+
4. MagMax β protect top magnitude params (extra safety layer)
|
| 133 |
+
"""
|
| 134 |
+
fused = target_state[key]
|
| 135 |
+
original = pre_merge_state[key]
|
| 136 |
+
delta = fused - original
|
| 137 |
+
|
| 138 |
+
# --- ARM Steering (new, replaces orthogonal projection) ---
|
| 139 |
+
if self.cfg.use_arm_steering and self.arm_rotations:
|
| 140 |
+
# Find matching layer rotation
|
| 141 |
+
layer_prefix = ".".join(key.split(".")[:4])
|
| 142 |
+
for layer_name, rotation_info in self.arm_rotations.items():
|
| 143 |
+
if layer_prefix in layer_name:
|
| 144 |
+
delta = apply_arm_steering(
|
| 145 |
+
delta, rotation_info,
|
| 146 |
+
steering_strength=self.cfg.arm_steering_strength,
|
| 147 |
+
)
|
| 148 |
+
break
|
| 149 |
+
|
| 150 |
+
# --- Orthogonal Projection (legacy fallback) ---
|
| 151 |
+
elif self.cfg.use_orthogonal_projection and key in self.previous_deltas:
|
| 152 |
+
for prev_delta in self.previous_deltas[key]:
|
| 153 |
+
prev_flat = prev_delta.flatten().float()
|
| 154 |
+
delta_flat = delta.flatten().float()
|
| 155 |
+
|
| 156 |
+
dot = torch.dot(delta_flat, prev_flat)
|
| 157 |
+
norm_sq = torch.dot(prev_flat, prev_flat)
|
| 158 |
+
|
| 159 |
+
if norm_sq > 1e-10:
|
| 160 |
+
projection = (dot / norm_sq) * prev_flat
|
| 161 |
+
delta_flat = delta_flat - projection
|
| 162 |
+
delta = delta_flat.reshape(delta.shape).to(delta.dtype)
|
| 163 |
+
|
| 164 |
+
# --- OTMF Mask Protection (new) ---
|
| 165 |
+
if self.cfg.use_otmf_masks and key in self.otmf_masks:
|
| 166 |
+
mask = self.otmf_masks[key].to(delta.device)
|
| 167 |
+
# Transferable weights: full delta
|
| 168 |
+
# Task-specific weights: reduced delta (protect them)
|
| 169 |
+
delta = torch.where(
|
| 170 |
+
mask,
|
| 171 |
+
delta, # Transferable β allow full change
|
| 172 |
+
delta * (1.0 - self.cfg.otmf_protect_strength), # Protected β reduced
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
# --- MagMax Protection (extra safety layer) ---
|
| 176 |
+
if self.cfg.use_magmax and key in self.magnitude_masks:
|
| 177 |
+
mask = self.magnitude_masks[key]
|
| 178 |
+
delta = torch.where(mask, delta * 0.1, delta)
|
| 179 |
+
|
| 180 |
+
# Apply constrained delta
|
| 181 |
+
result = original + delta
|
| 182 |
+
|
| 183 |
+
return result
|
| 184 |
+
|
| 185 |
+
def after_merge(
|
| 186 |
+
self,
|
| 187 |
+
target_model: AutoModelForCausalLM,
|
| 188 |
+
pre_merge_state: dict,
|
| 189 |
+
pre_merge_activations: dict = None,
|
| 190 |
+
post_merge_activations: dict = None,
|
| 191 |
+
):
|
| 192 |
+
"""
|
| 193 |
+
Record the merge delta and compute protections for next merge.
|
| 194 |
+
|
| 195 |
+
Called AFTER each merge completes successfully.
|
| 196 |
+
Now also computes:
|
| 197 |
+
- ARM rotation vectors for next merge steering
|
| 198 |
+
- OTMF transferability masks for next merge
|
| 199 |
+
"""
|
| 200 |
+
current_state = target_model.state_dict()
|
| 201 |
+
|
| 202 |
+
for key in current_state:
|
| 203 |
+
if key in pre_merge_state:
|
| 204 |
+
delta = current_state[key].float() - pre_merge_state[key].float()
|
| 205 |
+
if delta.abs().max() > 1e-8:
|
| 206 |
+
if key not in self.previous_deltas:
|
| 207 |
+
self.previous_deltas[key] = []
|
| 208 |
+
if len(self.previous_deltas[key]) >= 2:
|
| 209 |
+
self.previous_deltas[key].pop(0)
|
| 210 |
+
self.previous_deltas[key].append(delta.cpu())
|
| 211 |
+
|
| 212 |
+
# --- Compute ARM rotations for next merge ---
|
| 213 |
+
if self.cfg.use_arm_steering and pre_merge_activations and post_merge_activations:
|
| 214 |
+
print("[protect] Computing ARM rotation vectors for next merge...")
|
| 215 |
+
self.arm_rotations = compute_arm_rotation(
|
| 216 |
+
pre_merge_activations,
|
| 217 |
+
post_merge_activations,
|
| 218 |
+
post_merge_activations, # Target = current state (for gap calculation)
|
| 219 |
+
)
|
| 220 |
+
|
| 221 |
+
# --- Compute OTMF masks for next merge ---
|
| 222 |
+
if self.cfg.use_otmf_masks and post_merge_activations:
|
| 223 |
+
print("[protect] Computing OTMF transferability masks...")
|
| 224 |
+
self.otmf_masks = compute_transferability_masks(
|
| 225 |
+
target_model,
|
| 226 |
+
post_merge_activations,
|
| 227 |
+
threshold=self.cfg.otmf_threshold,
|
| 228 |
+
)
|
| 229 |
+
|
| 230 |
+
self.merge_count += 1
|
| 231 |
+
print(f"[protect] Recorded merge delta #{self.merge_count} (ARM + OTMF ready for next)")
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
# ============================================================================
|
| 235 |
+
# MAIN ORCHESTRATOR
|
| 236 |
+
# ============================================================================
|
| 237 |
+
|
| 238 |
+
def is_vision_param(key: str, cfg: MergeConfig) -> bool:
|
| 239 |
+
"""
|
| 240 |
+
Check if a parameter belongs to the vision encoder.
|
| 241 |
+
|
| 242 |
+
Qwen3-VL-8B has a ViT vision encoder + merger projection on top of the
|
| 243 |
+
language model. We NEVER touch these during merging β they give us
|
| 244 |
+
browser agent and image understanding abilities for free.
|
| 245 |
+
|
| 246 |
+
Vision params start with prefixes like "visual." or "merger."
|
| 247 |
+
Language params start with "model.layers." or "model.embed_tokens." etc.
|
| 248 |
+
"""
|
| 249 |
+
for prefix in cfg.vision_skip_prefixes:
|
| 250 |
+
if key.startswith(prefix):
|
| 251 |
+
return True
|
| 252 |
+
return False
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def get_source_by_stage(stage_name: str) -> Optional[ModelConfig]:
|
| 256 |
+
"""Get model config by stage name."""
|
| 257 |
+
stage_map = {
|
| 258 |
+
"deepseek": 0,
|
| 259 |
+
"mimo": 1,
|
| 260 |
+
"llama": 2,
|
| 261 |
+
"falcon": 3,
|
| 262 |
+
}
|
| 263 |
+
idx = stage_map.get(stage_name.lower())
|
| 264 |
+
if idx is not None and idx < len(SOURCES):
|
| 265 |
+
return SOURCES[idx]
|
| 266 |
+
return None
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def load_model(config: ModelConfig, cfg: MergeConfig) -> tuple:
|
| 270 |
+
"""Load a model and its tokenizer/processor."""
|
| 271 |
+
print(f"\n[merge] Loading {config.name} ({config.hf_id})...")
|
| 272 |
+
|
| 273 |
+
# Qwen3-VL uses a processor (handles both text + vision), not just a tokenizer
|
| 274 |
+
if config.architecture == "transformer+vision":
|
| 275 |
+
try:
|
| 276 |
+
from transformers import Qwen3VLForConditionalGeneration, AutoProcessor
|
| 277 |
+
processor = AutoProcessor.from_pretrained(
|
| 278 |
+
config.hf_id,
|
| 279 |
+
trust_remote_code=config.trust_remote_code,
|
| 280 |
+
)
|
| 281 |
+
model = Qwen3VLForConditionalGeneration.from_pretrained(
|
| 282 |
+
config.hf_id,
|
| 283 |
+
torch_dtype=getattr(torch, cfg.dtype),
|
| 284 |
+
attn_implementation=cfg.attn_implementation,
|
| 285 |
+
device_map=cfg.device_map,
|
| 286 |
+
trust_remote_code=config.trust_remote_code,
|
| 287 |
+
)
|
| 288 |
+
# Use the tokenizer from the processor for text operations
|
| 289 |
+
tokenizer = processor.tokenizer if hasattr(processor, 'tokenizer') else processor
|
| 290 |
+
print(f"[merge] Loaded {config.name} (VL model): {sum(p.numel() for p in model.parameters()) / 1e9:.1f}B params")
|
| 291 |
+
|
| 292 |
+
# Count vision vs language params
|
| 293 |
+
vision_params = sum(
|
| 294 |
+
p.numel() for n, p in model.named_parameters()
|
| 295 |
+
if any(n.startswith(pfx) for pfx in cfg.vision_skip_prefixes)
|
| 296 |
+
)
|
| 297 |
+
lang_params = sum(p.numel() for p in model.parameters()) - vision_params
|
| 298 |
+
print(f"[merge] Language: {lang_params / 1e9:.1f}B | Vision: {vision_params / 1e9:.1f}B")
|
| 299 |
+
|
| 300 |
+
return model, tokenizer
|
| 301 |
+
except ImportError:
|
| 302 |
+
print("[merge] Qwen3VLForConditionalGeneration not available, falling back to AutoModel")
|
| 303 |
+
|
| 304 |
+
# Standard text-only models
|
| 305 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 306 |
+
config.hf_id,
|
| 307 |
+
trust_remote_code=config.trust_remote_code,
|
| 308 |
+
)
|
| 309 |
+
|
| 310 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 311 |
+
config.hf_id,
|
| 312 |
+
torch_dtype=getattr(torch, cfg.dtype),
|
| 313 |
+
attn_implementation=cfg.attn_implementation,
|
| 314 |
+
device_map=cfg.device_map,
|
| 315 |
+
trust_remote_code=config.trust_remote_code,
|
| 316 |
+
)
|
| 317 |
+
|
| 318 |
+
print(f"[merge] Loaded {config.name}: {sum(p.numel() for p in model.parameters()) / 1e9:.1f}B params")
|
| 319 |
+
return model, tokenizer
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
def save_checkpoint(
|
| 323 |
+
model: AutoModelForCausalLM,
|
| 324 |
+
tokenizer: AutoTokenizer,
|
| 325 |
+
stage_name: str,
|
| 326 |
+
cfg: MergeConfig,
|
| 327 |
+
):
|
| 328 |
+
"""Save a checkpoint after a successful merge stage."""
|
| 329 |
+
ckpt_dir = Path(cfg.checkpoint_dir) / f"after_{stage_name}"
|
| 330 |
+
ckpt_dir.mkdir(parents=True, exist_ok=True)
|
| 331 |
+
|
| 332 |
+
print(f"[merge] Saving checkpoint to {ckpt_dir}...")
|
| 333 |
+
model.save_pretrained(ckpt_dir)
|
| 334 |
+
tokenizer.save_pretrained(ckpt_dir)
|
| 335 |
+
print(f"[merge] Checkpoint saved: {ckpt_dir}")
|
| 336 |
+
|
| 337 |
+
return str(ckpt_dir)
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
# ============================================================================
|
| 341 |
+
# RESIDUAL BANK β Save what was lost during each merge
|
| 342 |
+
# ============================================================================
|
| 343 |
+
|
| 344 |
+
class ResidualBank:
|
| 345 |
+
"""
|
| 346 |
+
Saves the knowledge that gets lost during each merge so it can
|
| 347 |
+
be recovered later.
|
| 348 |
+
|
| 349 |
+
When we blend at alpha=0.10:
|
| 350 |
+
merged = target + alpha * M * (transported - target)
|
| 351 |
+
|
| 352 |
+
We LOSE:
|
| 353 |
+
target_residual = target_original - merged (what target lost)
|
| 354 |
+
source_residual = source_original - merged (what source lost)
|
| 355 |
+
|
| 356 |
+
These residuals are saved to disk. Later they can be:
|
| 357 |
+
1. Fed back during the healing fine-tune (as training signal)
|
| 358 |
+
2. Re-injected via a small LoRA adapter
|
| 359 |
+
3. Used to diagnose which merge caused a specific knowledge loss
|
| 360 |
+
4. Re-applied at a lower alpha if we want more of that model
|
| 361 |
+
|
| 362 |
+
Think of it like saving the sawdust when you cut wood β you might
|
| 363 |
+
need to glue some of it back later.
|
| 364 |
+
"""
|
| 365 |
+
|
| 366 |
+
def __init__(self, cfg: MergeConfig):
|
| 367 |
+
self.cfg = cfg
|
| 368 |
+
self.residual_dir = Path(cfg.checkpoint_dir) / "residuals"
|
| 369 |
+
self.residual_dir.mkdir(parents=True, exist_ok=True)
|
| 370 |
+
self.residual_index = {} # stage β {path, stats}
|
| 371 |
+
|
| 372 |
+
def save_residuals(
|
| 373 |
+
self,
|
| 374 |
+
stage_name: str,
|
| 375 |
+
pre_merge_target_state: dict,
|
| 376 |
+
source_state: dict,
|
| 377 |
+
post_merge_state: dict,
|
| 378 |
+
source_config: ModelConfig,
|
| 379 |
+
):
|
| 380 |
+
"""
|
| 381 |
+
Compute and save what was lost from both target and source.
|
| 382 |
+
|
| 383 |
+
Saves two files per merge stage:
|
| 384 |
+
- target_residual: what the target model lost
|
| 385 |
+
- source_residual: what the source model didn't fully contribute
|
| 386 |
+
|
| 387 |
+
Also saves stats so we know WHERE the biggest losses were
|
| 388 |
+
(which layers, which type of weights).
|
| 389 |
+
"""
|
| 390 |
+
stage_dir = self.residual_dir / stage_name
|
| 391 |
+
stage_dir.mkdir(parents=True, exist_ok=True)
|
| 392 |
+
|
| 393 |
+
target_residual = {}
|
| 394 |
+
source_residual = {}
|
| 395 |
+
stats = {
|
| 396 |
+
"stage": stage_name,
|
| 397 |
+
"source_model": source_config.name,
|
| 398 |
+
"target_loss_by_layer": {},
|
| 399 |
+
"source_loss_by_layer": {},
|
| 400 |
+
"total_target_loss": 0.0,
|
| 401 |
+
"total_source_loss": 0.0,
|
| 402 |
+
"biggest_losses": [],
|
| 403 |
+
}
|
| 404 |
+
|
| 405 |
+
for key in post_merge_state:
|
| 406 |
+
merged_w = post_merge_state[key].float()
|
| 407 |
+
|
| 408 |
+
# What the target lost
|
| 409 |
+
if key in pre_merge_target_state:
|
| 410 |
+
original_target = pre_merge_target_state[key].float()
|
| 411 |
+
t_residual = original_target - merged_w
|
| 412 |
+
t_loss = t_residual.abs().mean().item()
|
| 413 |
+
|
| 414 |
+
if t_loss > 1e-6: # Only save meaningful residuals
|
| 415 |
+
target_residual[key] = t_residual.to(torch.bfloat16).cpu()
|
| 416 |
+
stats["total_target_loss"] += t_loss
|
| 417 |
+
|
| 418 |
+
# Track per-layer losses
|
| 419 |
+
layer_name = ".".join(key.split(".")[:4])
|
| 420 |
+
if layer_name not in stats["target_loss_by_layer"]:
|
| 421 |
+
stats["target_loss_by_layer"][layer_name] = 0.0
|
| 422 |
+
stats["target_loss_by_layer"][layer_name] += t_loss
|
| 423 |
+
|
| 424 |
+
# What the source lost (what didn't make it into the merge)
|
| 425 |
+
if key in source_state:
|
| 426 |
+
original_source = source_state[key].float()
|
| 427 |
+
s_residual = original_source - merged_w
|
| 428 |
+
s_loss = s_residual.abs().mean().item()
|
| 429 |
+
|
| 430 |
+
if s_loss > 1e-6:
|
| 431 |
+
source_residual[key] = s_residual.to(torch.bfloat16).cpu()
|
| 432 |
+
stats["total_source_loss"] += s_loss
|
| 433 |
+
|
| 434 |
+
layer_name = ".".join(key.split(".")[:4])
|
| 435 |
+
if layer_name not in stats["source_loss_by_layer"]:
|
| 436 |
+
stats["source_loss_by_layer"][layer_name] = 0.0
|
| 437 |
+
stats["source_loss_by_layer"][layer_name] += s_loss
|
| 438 |
+
|
| 439 |
+
# Find the biggest losses (most knowledge dropped)
|
| 440 |
+
all_losses = []
|
| 441 |
+
for key in target_residual:
|
| 442 |
+
loss_magnitude = target_residual[key].float().abs().mean().item()
|
| 443 |
+
all_losses.append({"param": key, "side": "target", "loss": loss_magnitude})
|
| 444 |
+
for key in source_residual:
|
| 445 |
+
loss_magnitude = source_residual[key].float().abs().mean().item()
|
| 446 |
+
all_losses.append({"param": key, "side": "source", "loss": loss_magnitude})
|
| 447 |
+
all_losses.sort(key=lambda x: x["loss"], reverse=True)
|
| 448 |
+
stats["biggest_losses"] = all_losses[:20] # Top 20 biggest losses
|
| 449 |
+
|
| 450 |
+
# Save to disk
|
| 451 |
+
torch.save(target_residual, stage_dir / "target_residual.pt")
|
| 452 |
+
torch.save(source_residual, stage_dir / "source_residual.pt")
|
| 453 |
+
|
| 454 |
+
import json
|
| 455 |
+
with open(stage_dir / "residual_stats.json", "w") as f:
|
| 456 |
+
json.dump(stats, f, indent=2, default=str)
|
| 457 |
+
|
| 458 |
+
self.residual_index[stage_name] = {
|
| 459 |
+
"path": str(stage_dir),
|
| 460 |
+
"target_params_saved": len(target_residual),
|
| 461 |
+
"source_params_saved": len(source_residual),
|
| 462 |
+
"total_target_loss": stats["total_target_loss"],
|
| 463 |
+
"total_source_loss": stats["total_source_loss"],
|
| 464 |
+
}
|
| 465 |
+
|
| 466 |
+
print(f"[residual] Saved residuals for {stage_name}:")
|
| 467 |
+
print(f" Target lost: {len(target_residual)} params (avg loss: {stats['total_target_loss']:.4f})")
|
| 468 |
+
print(f" Source lost: {len(source_residual)} params (avg loss: {stats['total_source_loss']:.4f})")
|
| 469 |
+
print(f" Top loss: {all_losses[0]['param']} ({all_losses[0]['side']}, {all_losses[0]['loss']:.4f})" if all_losses else "")
|
| 470 |
+
print(f" Saved to: {stage_dir}")
|
| 471 |
+
|
| 472 |
+
def load_residuals(self, stage_name: str) -> tuple:
|
| 473 |
+
"""
|
| 474 |
+
Load saved residuals for a stage.
|
| 475 |
+
|
| 476 |
+
Returns:
|
| 477 |
+
(target_residual_dict, source_residual_dict)
|
| 478 |
+
"""
|
| 479 |
+
stage_dir = self.residual_dir / stage_name
|
| 480 |
+
target_residual = torch.load(stage_dir / "target_residual.pt", weights_only=True)
|
| 481 |
+
source_residual = torch.load(stage_dir / "source_residual.pt", weights_only=True)
|
| 482 |
+
return target_residual, source_residual
|
| 483 |
+
|
| 484 |
+
def reinject_residuals(
|
| 485 |
+
self,
|
| 486 |
+
model: AutoModelForCausalLM,
|
| 487 |
+
stage_name: str,
|
| 488 |
+
side: str = "both",
|
| 489 |
+
strength: float = 0.3,
|
| 490 |
+
) -> AutoModelForCausalLM:
|
| 491 |
+
"""
|
| 492 |
+
Re-inject saved residuals back into a model.
|
| 493 |
+
|
| 494 |
+
This adds back some of what was lost. Use a low strength (0.1-0.3)
|
| 495 |
+
to gently recover knowledge without undoing the merge.
|
| 496 |
+
|
| 497 |
+
Args:
|
| 498 |
+
model: The model to inject into
|
| 499 |
+
stage_name: Which merge stage's residuals to use
|
| 500 |
+
side: "target", "source", or "both"
|
| 501 |
+
strength: How much to add back (0=nothing, 1=full residual)
|
| 502 |
+
"""
|
| 503 |
+
print(f"[residual] Re-injecting {stage_name} residuals (side={side}, strength={strength})...")
|
| 504 |
+
|
| 505 |
+
target_residual, source_residual = self.load_residuals(stage_name)
|
| 506 |
+
state = model.state_dict()
|
| 507 |
+
injected = 0
|
| 508 |
+
|
| 509 |
+
if side in ("target", "both"):
|
| 510 |
+
for key, residual in target_residual.items():
|
| 511 |
+
if key in state:
|
| 512 |
+
state[key] = state[key] + strength * residual.to(state[key].device).to(state[key].dtype)
|
| 513 |
+
injected += 1
|
| 514 |
+
|
| 515 |
+
if side in ("source", "both"):
|
| 516 |
+
for key, residual in source_residual.items():
|
| 517 |
+
if key in state:
|
| 518 |
+
state[key] = state[key] + strength * residual.to(state[key].device).to(state[key].dtype)
|
| 519 |
+
injected += 1
|
| 520 |
+
|
| 521 |
+
model.load_state_dict(state)
|
| 522 |
+
print(f"[residual] Re-injected {injected} params at {strength:.0%} strength")
|
| 523 |
+
return model
|
| 524 |
+
|
| 525 |
+
def get_healing_targets(self, top_n: int = 50) -> list:
|
| 526 |
+
"""
|
| 527 |
+
Get the parameters with the biggest losses across ALL merges.
|
| 528 |
+
|
| 529 |
+
These are the params that the healing fine-tune should focus on.
|
| 530 |
+
Feed this to the LoRA target_modules to make healing smarter.
|
| 531 |
+
"""
|
| 532 |
+
import json
|
| 533 |
+
all_losses = []
|
| 534 |
+
|
| 535 |
+
for stage_name in self.residual_index:
|
| 536 |
+
stage_dir = self.residual_dir / stage_name
|
| 537 |
+
stats_file = stage_dir / "residual_stats.json"
|
| 538 |
+
if stats_file.exists():
|
| 539 |
+
with open(stats_file) as f:
|
| 540 |
+
stats = json.load(f)
|
| 541 |
+
for loss in stats.get("biggest_losses", []):
|
| 542 |
+
loss["stage"] = stage_name
|
| 543 |
+
all_losses.append(loss)
|
| 544 |
+
|
| 545 |
+
all_losses.sort(key=lambda x: x["loss"], reverse=True)
|
| 546 |
+
|
| 547 |
+
# Extract unique layer/module names for LoRA targeting
|
| 548 |
+
target_modules = set()
|
| 549 |
+
for loss in all_losses[:top_n]:
|
| 550 |
+
param = loss["param"]
|
| 551 |
+
# Extract the module type (q_proj, k_proj, gate_proj, etc.)
|
| 552 |
+
parts = param.split(".")
|
| 553 |
+
for part in parts:
|
| 554 |
+
if part.endswith("_proj") or part in ("gate_proj", "up_proj", "down_proj"):
|
| 555 |
+
target_modules.add(part)
|
| 556 |
+
|
| 557 |
+
print(f"[residual] Top healing targets (from {len(all_losses)} total losses):")
|
| 558 |
+
for loss in all_losses[:5]:
|
| 559 |
+
print(f" {loss['param']} ({loss['side']}, stage={loss['stage']}, loss={loss['loss']:.4f})")
|
| 560 |
+
print(f" β Suggested LoRA targets: {sorted(target_modules)}")
|
| 561 |
+
|
| 562 |
+
return list(target_modules)
|
| 563 |
+
|
| 564 |
+
|
| 565 |
+
def run_single_merge(
|
| 566 |
+
target_model: AutoModelForCausalLM,
|
| 567 |
+
target_tokenizer: AutoTokenizer,
|
| 568 |
+
source_config: ModelConfig,
|
| 569 |
+
cfg: MergeConfig,
|
| 570 |
+
protection: MergeProtection,
|
| 571 |
+
residual_bank: ResidualBank = None,
|
| 572 |
+
calibration_data: list = None,
|
| 573 |
+
baseline_perplexity: float = None,
|
| 574 |
+
merged_sources: list = None,
|
| 575 |
+
) -> dict:
|
| 576 |
+
"""
|
| 577 |
+
Run a single merge: source β target.
|
| 578 |
+
|
| 579 |
+
Full pipeline for one merge step:
|
| 580 |
+
1. Load source model
|
| 581 |
+
2. Inject canary into source
|
| 582 |
+
3. Extract activations from both
|
| 583 |
+
4. Compute transport plans
|
| 584 |
+
5. Apply merge protection
|
| 585 |
+
6. Fuse weights
|
| 586 |
+
7. Apply post-merge protection
|
| 587 |
+
8. Validate
|
| 588 |
+
|
| 589 |
+
Returns:
|
| 590 |
+
Dict with merge results, validation results, and status
|
| 591 |
+
"""
|
| 592 |
+
if merged_sources is None:
|
| 593 |
+
merged_sources = []
|
| 594 |
+
|
| 595 |
+
stage_name = source_config.name
|
| 596 |
+
print(f"\n{'=' * 70}")
|
| 597 |
+
print(f"MERGE STAGE: {stage_name} β target")
|
| 598 |
+
print(f"Risk level: {source_config.merge_risk.upper()}")
|
| 599 |
+
print(f"{'=' * 70}")
|
| 600 |
+
|
| 601 |
+
result = {
|
| 602 |
+
"stage": stage_name,
|
| 603 |
+
"status": "pending",
|
| 604 |
+
"validation": None,
|
| 605 |
+
"checkpoint": None,
|
| 606 |
+
}
|
| 607 |
+
|
| 608 |
+
# --- Step 1: Load source model ---
|
| 609 |
+
source_model, source_tokenizer = load_model(source_config, cfg)
|
| 610 |
+
|
| 611 |
+
# --- Step 2: Inject canary into source ---
|
| 612 |
+
if stage_name in CANARY_FACTS:
|
| 613 |
+
print(f"\n[merge] Injecting canary fact into {stage_name}...")
|
| 614 |
+
source_model = inject_canary(source_model, source_tokenizer, stage_name)
|
| 615 |
+
|
| 616 |
+
# --- Step 3: Load calibration data (if not provided) ---
|
| 617 |
+
if calibration_data is None:
|
| 618 |
+
calibration_data = load_calibration_data(cfg, target_tokenizer)
|
| 619 |
+
|
| 620 |
+
# --- Step 4: Extract two-sided activations (pre + post per projection) ---
|
| 621 |
+
print(f"\n[merge] Extracting source activations (two-sided)...")
|
| 622 |
+
source_activations = extract_activations(source_model, calibration_data)
|
| 623 |
+
|
| 624 |
+
print(f"\n[merge] Extracting target activations (two-sided)...")
|
| 625 |
+
pre_merge_target_activations = extract_activations(target_model, calibration_data)
|
| 626 |
+
|
| 627 |
+
# --- Step 4.5: Mergeability pre-check (2601.22285) ---
|
| 628 |
+
if cfg.use_mergeability_check:
|
| 629 |
+
mergeability = compute_mergeability_score(
|
| 630 |
+
source_activations, pre_merge_target_activations, source_config
|
| 631 |
+
)
|
| 632 |
+
result["mergeability"] = mergeability
|
| 633 |
+
|
| 634 |
+
if mergeability["overall"] < cfg.mergeability_min_score:
|
| 635 |
+
print(f"\n[merge] β Mergeability score {mergeability['overall']:.2f} below threshold {cfg.mergeability_min_score}")
|
| 636 |
+
print(f"[merge] β {mergeability['recommendation']}")
|
| 637 |
+
result["status"] = "skipped_low_mergeability"
|
| 638 |
+
if "distillation_fallback" in source_config.special_handling:
|
| 639 |
+
result["fallback"] = "distillation"
|
| 640 |
+
del source_model, source_activations, pre_merge_target_activations
|
| 641 |
+
gc.collect()
|
| 642 |
+
if torch.cuda.is_available():
|
| 643 |
+
torch.cuda.empty_cache()
|
| 644 |
+
return result
|
| 645 |
+
|
| 646 |
+
# --- Step 5: Compute transport plans ---
|
| 647 |
+
transport_plans = compute_transport_plans(
|
| 648 |
+
source_activations, pre_merge_target_activations, cfg
|
| 649 |
+
)
|
| 650 |
+
|
| 651 |
+
# --- Step 5.5: RAM RL-weight disentanglement (2601.13572) ---
|
| 652 |
+
use_ram = (
|
| 653 |
+
cfg.use_ram_disentangle
|
| 654 |
+
and source_config.architecture in ("transformer", "transformer+mtp")
|
| 655 |
+
and source_config.merge_risk in ("low", "medium")
|
| 656 |
+
and any(kw in source_config.name.lower() for kw in ["r1", "rl", "rlhf", "grpo"])
|
| 657 |
+
)
|
| 658 |
+
|
| 659 |
+
# --- Step 6: Pre-merge protection ---
|
| 660 |
+
adjusted_alpha = protection.before_merge(target_model, source_config)
|
| 661 |
+
|
| 662 |
+
# Override source alpha with time-adjusted value
|
| 663 |
+
source_config_adjusted = copy.copy(source_config)
|
| 664 |
+
source_config_adjusted.merge_alpha = adjusted_alpha
|
| 665 |
+
|
| 666 |
+
# Save pre-merge state for protection
|
| 667 |
+
pre_merge_state = {k: v.clone().cpu() for k, v in target_model.state_dict().items()}
|
| 668 |
+
|
| 669 |
+
# --- Step 7: Fuse weights ---
|
| 670 |
+
if use_ram:
|
| 671 |
+
# RAM path: disentangle RL weights, merge with preservation
|
| 672 |
+
print(f"\n[merge] Using RAM RL-preservation for {stage_name}...")
|
| 673 |
+
try:
|
| 674 |
+
# Try loading the base (pre-RL) model for disentanglement
|
| 675 |
+
base_hf_id = source_config.hf_id.replace("-RL", "").replace("-R1-0528", "")
|
| 676 |
+
print(f"[merge] Loading base model for RAM: {base_hf_id}")
|
| 677 |
+
base_model = AutoModelForCausalLM.from_pretrained(
|
| 678 |
+
base_hf_id,
|
| 679 |
+
torch_dtype=getattr(torch, cfg.dtype),
|
| 680 |
+
device_map=cfg.device_map,
|
| 681 |
+
trust_remote_code=source_config.trust_remote_code,
|
| 682 |
+
)
|
| 683 |
+
shared_mask, rl_mask = disentangle_rl_weights(
|
| 684 |
+
source_model, base_model, cfg.ram_rl_threshold
|
| 685 |
+
)
|
| 686 |
+
# Fuse with RL preservation
|
| 687 |
+
target_state = merge_with_rl_preservation(
|
| 688 |
+
target_model.state_dict(),
|
| 689 |
+
source_model.state_dict(),
|
| 690 |
+
shared_mask, rl_mask,
|
| 691 |
+
shared_alpha=cfg.ram_shared_alpha * (adjusted_alpha / source_config.merge_alpha),
|
| 692 |
+
rl_alpha=cfg.ram_rl_alpha,
|
| 693 |
+
)
|
| 694 |
+
target_model.load_state_dict(target_state)
|
| 695 |
+
del base_model
|
| 696 |
+
print(f"[merge] RAM merge complete for {stage_name}")
|
| 697 |
+
except Exception as e:
|
| 698 |
+
print(f"[merge] RAM failed ({e}), falling back to standard T&M merge")
|
| 699 |
+
target_model = fuse_weights(
|
| 700 |
+
source_model, target_model, transport_plans,
|
| 701 |
+
source_config_adjusted, cfg,
|
| 702 |
+
target_activations=pre_merge_target_activations,
|
| 703 |
+
)
|
| 704 |
+
else:
|
| 705 |
+
# Standard T&M path (two-sided + top-k masked fusion, paper Eq 14)
|
| 706 |
+
target_model = fuse_weights(
|
| 707 |
+
source_model, target_model, transport_plans,
|
| 708 |
+
source_config_adjusted, cfg,
|
| 709 |
+
target_activations=pre_merge_target_activations,
|
| 710 |
+
)
|
| 711 |
+
|
| 712 |
+
# --- Step 7.5: Theseus fallback check (2602.12952) ---
|
| 713 |
+
# If T&M merge produced poor activation alignment, try Theseus
|
| 714 |
+
if cfg.use_theseus_fallback and source_config.merge_risk == "high":
|
| 715 |
+
print(f"\n[merge] Checking if Theseus fallback needed for {stage_name}...")
|
| 716 |
+
post_activations = extract_activations(target_model, calibration_data[:50]) # Quick check
|
| 717 |
+
# Compare post-merge activations to pre-merge β if too similar, T&M didn't work
|
| 718 |
+
alignment_scores = []
|
| 719 |
+
for key in post_activations:
|
| 720 |
+
if key in pre_merge_target_activations:
|
| 721 |
+
cos = torch.nn.functional.cosine_similarity(
|
| 722 |
+
post_activations[key].float().mean(0, keepdim=True),
|
| 723 |
+
pre_merge_target_activations[key].float().mean(0, keepdim=True),
|
| 724 |
+
)
|
| 725 |
+
alignment_scores.append(cos.item())
|
| 726 |
+
avg_change = 1.0 - np.mean(alignment_scores) if alignment_scores else 0.0
|
| 727 |
+
print(f"[merge] Activation change from merge: {avg_change:.4f}")
|
| 728 |
+
|
| 729 |
+
if avg_change < 0.01:
|
| 730 |
+
print(f"[merge] β T&M had minimal effect β activating Theseus fallback")
|
| 731 |
+
# Restore pre-merge state and try Theseus instead
|
| 732 |
+
target_model.load_state_dict(pre_merge_state)
|
| 733 |
+
try:
|
| 734 |
+
base_model = AutoModelForCausalLM.from_pretrained(
|
| 735 |
+
source_config.hf_id.split("/")[0] + "/" + source_config.hf_id.split("/")[1].split("-")[0],
|
| 736 |
+
torch_dtype=getattr(torch, cfg.dtype),
|
| 737 |
+
device_map=cfg.device_map,
|
| 738 |
+
trust_remote_code=source_config.trust_remote_code,
|
| 739 |
+
)
|
| 740 |
+
target_model = transport_task_vector_theseus(
|
| 741 |
+
source_model, base_model, target_model,
|
| 742 |
+
source_activations, pre_merge_target_activations,
|
| 743 |
+
alpha=cfg.theseus_alpha,
|
| 744 |
+
)
|
| 745 |
+
del base_model
|
| 746 |
+
print(f"[merge] Theseus transport complete for {stage_name}")
|
| 747 |
+
except Exception as e:
|
| 748 |
+
print(f"[merge] Theseus also failed ({e}). Using original T&M result.")
|
| 749 |
+
# Re-apply T&M result
|
| 750 |
+
target_model = fuse_weights(
|
| 751 |
+
source_model, target_model, transport_plans,
|
| 752 |
+
source_config_adjusted, cfg,
|
| 753 |
+
target_activations=pre_merge_target_activations,
|
| 754 |
+
)
|
| 755 |
+
|
| 756 |
+
# --- Step 8: Apply post-merge protection (ARM + OTMF + MagMax) ---
|
| 757 |
+
# Skip vision encoder params β they weren't merged, so don't "protect" them
|
| 758 |
+
if protection.merge_count > 0:
|
| 759 |
+
print(f"\n[merge] Applying sequential merge protection (ARM + OTMF + MagMax)...")
|
| 760 |
+
target_state = target_model.state_dict()
|
| 761 |
+
protected_count = 0
|
| 762 |
+
vision_skipped = 0
|
| 763 |
+
for key in target_state:
|
| 764 |
+
if is_vision_param(key, cfg):
|
| 765 |
+
vision_skipped += 1
|
| 766 |
+
continue # Don't touch vision encoder
|
| 767 |
+
if key in pre_merge_state:
|
| 768 |
+
protected_param = protection.apply_protection(
|
| 769 |
+
target_state, pre_merge_state, key
|
| 770 |
+
)
|
| 771 |
+
target_state[key] = protected_param
|
| 772 |
+
protected_count += 1
|
| 773 |
+
target_model.load_state_dict(target_state)
|
| 774 |
+
print(f"[merge] Protected {protected_count} language params (skipped {vision_skipped} vision params)")
|
| 775 |
+
|
| 776 |
+
# --- Step 8.5: Extract post-merge activations for ARM/OTMF ---
|
| 777 |
+
post_merge_activations = extract_activations(target_model, calibration_data[:100])
|
| 778 |
+
|
| 779 |
+
# Record this merge's delta + compute ARM/OTMF for next merge
|
| 780 |
+
protection.after_merge(
|
| 781 |
+
target_model, pre_merge_state,
|
| 782 |
+
pre_merge_activations=pre_merge_target_activations,
|
| 783 |
+
post_merge_activations=post_merge_activations,
|
| 784 |
+
)
|
| 785 |
+
|
| 786 |
+
# --- Step 8.8: Save residuals (what was lost from both sides) ---
|
| 787 |
+
if residual_bank is not None:
|
| 788 |
+
print(f"\n[merge] Saving residuals for {stage_name}...")
|
| 789 |
+
residual_bank.save_residuals(
|
| 790 |
+
stage_name=stage_name,
|
| 791 |
+
pre_merge_target_state=pre_merge_state,
|
| 792 |
+
source_state={k: v.cpu() for k, v in source_model.state_dict().items()},
|
| 793 |
+
post_merge_state={k: v.cpu() for k, v in target_model.state_dict().items()},
|
| 794 |
+
source_config=source_config,
|
| 795 |
+
)
|
| 796 |
+
|
| 797 |
+
# --- Step 9: Free source model memory ---
|
| 798 |
+
del source_model, source_activations, pre_merge_target_activations
|
| 799 |
+
del transport_plans, post_merge_activations
|
| 800 |
+
gc.collect()
|
| 801 |
+
if torch.cuda.is_available():
|
| 802 |
+
torch.cuda.empty_cache()
|
| 803 |
+
|
| 804 |
+
# --- Step 10: Validate ---
|
| 805 |
+
merged_sources.append(stage_name)
|
| 806 |
+
validation = validate_merged_model(
|
| 807 |
+
target_model, target_tokenizer,
|
| 808 |
+
merged_sources, cfg,
|
| 809 |
+
baseline_perplexity=baseline_perplexity,
|
| 810 |
+
)
|
| 811 |
+
|
| 812 |
+
result["validation"] = validation
|
| 813 |
+
result["merged_sources"] = merged_sources.copy()
|
| 814 |
+
|
| 815 |
+
# --- Kill criteria check ---
|
| 816 |
+
if not validation["overall"]:
|
| 817 |
+
print(f"\n[merge] β VALIDATION FAILED for {stage_name}")
|
| 818 |
+
print(f"[merge] Kill criteria triggered β consider aborting")
|
| 819 |
+
result["status"] = "failed"
|
| 820 |
+
|
| 821 |
+
# Check if we should try distillation fallback
|
| 822 |
+
if "distillation_fallback" in source_config.special_handling:
|
| 823 |
+
print(f"[merge] {stage_name} has distillation fallback available")
|
| 824 |
+
result["fallback"] = "distillation"
|
| 825 |
+
else:
|
| 826 |
+
print(f"\n[merge] β {stage_name} merge PASSED validation")
|
| 827 |
+
result["status"] = "passed"
|
| 828 |
+
|
| 829 |
+
return result
|
| 830 |
+
|
| 831 |
+
|
| 832 |
+
def run_pipeline(
|
| 833 |
+
stages: list[str],
|
| 834 |
+
cfg: MergeConfig = None,
|
| 835 |
+
) -> dict:
|
| 836 |
+
"""
|
| 837 |
+
Run the full merge pipeline.
|
| 838 |
+
|
| 839 |
+
Args:
|
| 840 |
+
stages: List of stage names to run, e.g. ["deepseek"] or
|
| 841 |
+
["deepseek", "mimo", "llama", "falcon"]
|
| 842 |
+
cfg: Merge configuration (uses defaults if None)
|
| 843 |
+
|
| 844 |
+
Returns:
|
| 845 |
+
Dict with overall results, per-stage results, and final model path
|
| 846 |
+
"""
|
| 847 |
+
if cfg is None:
|
| 848 |
+
cfg = MergeConfig()
|
| 849 |
+
|
| 850 |
+
print("\n" + "=" * 70)
|
| 851 |
+
print("TD LANG ENGINE β Transport and Merge Pipeline")
|
| 852 |
+
print(f"Target: {TARGET.name} ({TARGET.hf_id})")
|
| 853 |
+
if TARGET.architecture == "transformer+vision":
|
| 854 |
+
print(f"Mode: Vision-Language (merging language backbone only, vision encoder untouched)")
|
| 855 |
+
print(f"Stages: {', '.join(stages)}")
|
| 856 |
+
print(f"Output: {cfg.output_dir}")
|
| 857 |
+
print("=" * 70)
|
| 858 |
+
|
| 859 |
+
# Setup
|
| 860 |
+
try:
|
| 861 |
+
setup_tm_repo(cfg)
|
| 862 |
+
except FileNotFoundError as e:
|
| 863 |
+
print(f"\nβ {e}")
|
| 864 |
+
print("Continuing with fallback implementation...")
|
| 865 |
+
|
| 866 |
+
# Create output directories
|
| 867 |
+
Path(cfg.output_dir).mkdir(parents=True, exist_ok=True)
|
| 868 |
+
Path(cfg.checkpoint_dir).mkdir(parents=True, exist_ok=True)
|
| 869 |
+
|
| 870 |
+
# --- Load target model ---
|
| 871 |
+
target_model, target_tokenizer = load_model(TARGET, cfg)
|
| 872 |
+
|
| 873 |
+
# --- Inject canary into target (Qwen3's own canary) ---
|
| 874 |
+
if "Qwen3-VL-8B" in CANARY_FACTS:
|
| 875 |
+
print("\n[pipeline] Injecting canary into base Qwen3-8B...")
|
| 876 |
+
target_model = inject_canary(target_model, target_tokenizer, "Qwen3-VL-8B")
|
| 877 |
+
|
| 878 |
+
# --- Compute baseline perplexity ---
|
| 879 |
+
print("\n[pipeline] Computing baseline perplexity...")
|
| 880 |
+
baseline_ppl = compute_perplexity(target_model, target_tokenizer)
|
| 881 |
+
print(f"[pipeline] Baseline perplexity: {baseline_ppl:.2f}")
|
| 882 |
+
|
| 883 |
+
# --- Load calibration data once ---
|
| 884 |
+
calibration_data = load_calibration_data(cfg, target_tokenizer)
|
| 885 |
+
|
| 886 |
+
# --- Initialize merge protection + residual bank ---
|
| 887 |
+
protection = MergeProtection(cfg)
|
| 888 |
+
residual_bank = ResidualBank(cfg)
|
| 889 |
+
|
| 890 |
+
# --- Run each merge stage ---
|
| 891 |
+
pipeline_results = {
|
| 892 |
+
"stages": {},
|
| 893 |
+
"baseline_perplexity": baseline_ppl,
|
| 894 |
+
"final_checkpoint": None,
|
| 895 |
+
"residuals": {},
|
| 896 |
+
"overall_status": "pending",
|
| 897 |
+
}
|
| 898 |
+
merged_sources = []
|
| 899 |
+
all_passed = True
|
| 900 |
+
|
| 901 |
+
for stage_name in stages:
|
| 902 |
+
source_config = get_source_by_stage(stage_name)
|
| 903 |
+
if source_config is None:
|
| 904 |
+
print(f"\nβ Unknown stage: {stage_name}, skipping")
|
| 905 |
+
continue
|
| 906 |
+
|
| 907 |
+
# --- Wasserstein pre-check for high-risk models ---
|
| 908 |
+
if "check_wasserstein_first" in source_config.special_handling:
|
| 909 |
+
print(f"\n[pipeline] Running Wasserstein pre-check for {source_config.name}...")
|
| 910 |
+
# TODO: Implement Wasserstein distance pre-check
|
| 911 |
+
# If distance is too high, skip to distillation fallback
|
| 912 |
+
print("[pipeline] Pre-check: proceeding (TODO: implement distance check)")
|
| 913 |
+
|
| 914 |
+
# Run the merge (with residual bank to save what's lost)
|
| 915 |
+
stage_result = run_single_merge(
|
| 916 |
+
target_model, target_tokenizer,
|
| 917 |
+
source_config, cfg,
|
| 918 |
+
protection,
|
| 919 |
+
residual_bank=residual_bank,
|
| 920 |
+
calibration_data=calibration_data,
|
| 921 |
+
baseline_perplexity=baseline_ppl,
|
| 922 |
+
merged_sources=merged_sources,
|
| 923 |
+
)
|
| 924 |
+
|
| 925 |
+
pipeline_results["stages"][stage_name] = stage_result
|
| 926 |
+
|
| 927 |
+
if stage_result["status"] == "passed":
|
| 928 |
+
# Save checkpoint
|
| 929 |
+
ckpt_path = save_checkpoint(
|
| 930 |
+
target_model, target_tokenizer, stage_name, cfg
|
| 931 |
+
)
|
| 932 |
+
stage_result["checkpoint"] = ckpt_path
|
| 933 |
+
pipeline_results["final_checkpoint"] = ckpt_path
|
| 934 |
+
else:
|
| 935 |
+
all_passed = False
|
| 936 |
+
print(f"\n[pipeline] Stage {stage_name} FAILED")
|
| 937 |
+
|
| 938 |
+
# Decision: abort or continue?
|
| 939 |
+
if source_config.merge_risk == "high":
|
| 940 |
+
print(f"[pipeline] High-risk model failed β skipping (will use distillation)")
|
| 941 |
+
# Don't abort the whole pipeline, just skip this model
|
| 942 |
+
continue
|
| 943 |
+
else:
|
| 944 |
+
print(f"[pipeline] ABORTING pipeline β non-high-risk model failed")
|
| 945 |
+
pipeline_results["overall_status"] = f"aborted_at_{stage_name}"
|
| 946 |
+
break
|
| 947 |
+
|
| 948 |
+
# --- Save residual index ---
|
| 949 |
+
pipeline_results["residuals"] = residual_bank.residual_index
|
| 950 |
+
if residual_bank.residual_index:
|
| 951 |
+
print(f"\n[pipeline] Residual bank: {len(residual_bank.residual_index)} stages saved")
|
| 952 |
+
for stage, info in residual_bank.residual_index.items():
|
| 953 |
+
print(f" {stage}: target lost {info['total_target_loss']:.4f}, source lost {info['total_source_loss']:.4f}")
|
| 954 |
+
|
| 955 |
+
# Identify which modules need the most healing
|
| 956 |
+
healing_targets = residual_bank.get_healing_targets(top_n=50)
|
| 957 |
+
pipeline_results["suggested_healing_targets"] = healing_targets
|
| 958 |
+
|
| 959 |
+
# --- Save final model ---
|
| 960 |
+
if pipeline_results["final_checkpoint"]:
|
| 961 |
+
final_dir = Path(cfg.output_dir) / "final"
|
| 962 |
+
final_dir.mkdir(parents=True, exist_ok=True)
|
| 963 |
+
target_model.save_pretrained(final_dir)
|
| 964 |
+
target_tokenizer.save_pretrained(final_dir)
|
| 965 |
+
pipeline_results["final_model_path"] = str(final_dir)
|
| 966 |
+
print(f"\n[pipeline] Final model saved to {final_dir}")
|
| 967 |
+
|
| 968 |
+
if all_passed:
|
| 969 |
+
pipeline_results["overall_status"] = "all_passed"
|
| 970 |
+
elif pipeline_results["overall_status"] == "pending":
|
| 971 |
+
pipeline_results["overall_status"] = "partial"
|
| 972 |
+
|
| 973 |
+
# --- Print final summary ---
|
| 974 |
+
print("\n" + "=" * 70)
|
| 975 |
+
print("PIPELINE SUMMARY")
|
| 976 |
+
print("=" * 70)
|
| 977 |
+
for stage_name, stage_result in pipeline_results["stages"].items():
|
| 978 |
+
status = stage_result["status"]
|
| 979 |
+
emoji = "β" if status == "passed" else "β"
|
| 980 |
+
print(f" {emoji} {stage_name}: {status}")
|
| 981 |
+
print(f"\n Overall: {pipeline_results['overall_status']}")
|
| 982 |
+
if residual_bank.residual_index:
|
| 983 |
+
print(f"\n Residuals saved for: {', '.join(residual_bank.residual_index.keys())}")
|
| 984 |
+
print(f" To recover lost knowledge later:")
|
| 985 |
+
print(f" python -m td_lang.engine --reinject <stage> --strength 0.2")
|
| 986 |
+
print("=" * 70)
|
| 987 |
+
|
| 988 |
+
return pipeline_results
|
hugging/td_lang/engine/run.py
ADDED
|
@@ -0,0 +1,279 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
TD Fuse β Main Entry Point.
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
# Dad demo: merge just DeepSeek β Qwen3-8B (easiest, lowest risk)
|
| 6 |
+
python -m td_fuse.run --stage demo
|
| 7 |
+
|
| 8 |
+
# Full pipeline: all 4 merges
|
| 9 |
+
python -m td_fuse.run --stage all
|
| 10 |
+
|
| 11 |
+
# Single model merge
|
| 12 |
+
python -m td_fuse.run --stage deepseek
|
| 13 |
+
python -m td_fuse.run --stage mimo
|
| 14 |
+
python -m td_fuse.run --stage llama
|
| 15 |
+
python -m td_fuse.run --stage falcon
|
| 16 |
+
|
| 17 |
+
# With healing fine-tune after merge
|
| 18 |
+
python -m td_fuse.run --stage demo --heal
|
| 19 |
+
|
| 20 |
+
# Custom output directory
|
| 21 |
+
python -m td_fuse.run --stage all --output ./my_output
|
| 22 |
+
|
| 23 |
+
# Heal an existing checkpoint
|
| 24 |
+
python -m td_fuse.run --heal-only --model-path ./td_fuse_checkpoints/after_deepseek
|
| 25 |
+
|
| 26 |
+
Findings: #25 (dad demo plan), #22 (merge order), #24 (official T&M pipeline)
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
import argparse
|
| 30 |
+
import json
|
| 31 |
+
import sys
|
| 32 |
+
import time
|
| 33 |
+
from pathlib import Path
|
| 34 |
+
|
| 35 |
+
from .config import MergeConfig, DEMO_STAGES, FULL_STAGES
|
| 36 |
+
from .merge import run_pipeline, ResidualBank
|
| 37 |
+
from .heal import heal_model
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def parse_args():
|
| 41 |
+
parser = argparse.ArgumentParser(
|
| 42 |
+
description="TD Fuse β Transport and Merge pipeline for Time Dilation",
|
| 43 |
+
formatter_class=argparse.RawDescriptionHelpFormatter,
|
| 44 |
+
epilog="""
|
| 45 |
+
Examples:
|
| 46 |
+
python -m td_fuse.run --stage demo # Dad demo (DeepSeek only)
|
| 47 |
+
python -m td_fuse.run --stage all # Full 4-model merge
|
| 48 |
+
python -m td_fuse.run --stage all --heal # Merge + healing fine-tune
|
| 49 |
+
python -m td_fuse.run --heal-only --model-path ./checkpoint
|
| 50 |
+
python -m td_fuse.run --reinject deepseek --strength 0.2 --model-path ./final
|
| 51 |
+
""",
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
parser.add_argument(
|
| 55 |
+
"--stage",
|
| 56 |
+
type=str,
|
| 57 |
+
default="demo",
|
| 58 |
+
choices=["demo", "all", "deepseek", "mimo", "llama", "falcon"],
|
| 59 |
+
help="Which merge stage(s) to run (default: demo)",
|
| 60 |
+
)
|
| 61 |
+
parser.add_argument(
|
| 62 |
+
"--heal",
|
| 63 |
+
action="store_true",
|
| 64 |
+
help="Run healing fine-tune after merge",
|
| 65 |
+
)
|
| 66 |
+
parser.add_argument(
|
| 67 |
+
"--heal-only",
|
| 68 |
+
action="store_true",
|
| 69 |
+
help="Only run healing (skip merge), requires --model-path",
|
| 70 |
+
)
|
| 71 |
+
parser.add_argument(
|
| 72 |
+
"--model-path",
|
| 73 |
+
type=str,
|
| 74 |
+
default=None,
|
| 75 |
+
help="Path to existing model/checkpoint (for --heal-only)",
|
| 76 |
+
)
|
| 77 |
+
parser.add_argument(
|
| 78 |
+
"--output",
|
| 79 |
+
type=str,
|
| 80 |
+
default="./td_fuse_outputs",
|
| 81 |
+
help="Output directory (default: ./td_fuse_outputs)",
|
| 82 |
+
)
|
| 83 |
+
parser.add_argument(
|
| 84 |
+
"--checkpoint-dir",
|
| 85 |
+
type=str,
|
| 86 |
+
default="./td_fuse_checkpoints",
|
| 87 |
+
help="Checkpoint directory (default: ./td_fuse_checkpoints)",
|
| 88 |
+
)
|
| 89 |
+
parser.add_argument(
|
| 90 |
+
"--tm-repo",
|
| 91 |
+
type=str,
|
| 92 |
+
default="./Cross-Architecture-Merging-for-Large-Language-Models",
|
| 93 |
+
help="Path to official T&M repo",
|
| 94 |
+
)
|
| 95 |
+
parser.add_argument(
|
| 96 |
+
"--dry-run",
|
| 97 |
+
action="store_true",
|
| 98 |
+
help="Print what would happen without actually running",
|
| 99 |
+
)
|
| 100 |
+
parser.add_argument(
|
| 101 |
+
"--reinject",
|
| 102 |
+
type=str,
|
| 103 |
+
default=None,
|
| 104 |
+
help="Re-inject saved residuals from a stage (e.g., --reinject deepseek)",
|
| 105 |
+
)
|
| 106 |
+
parser.add_argument(
|
| 107 |
+
"--reinject-side",
|
| 108 |
+
type=str,
|
| 109 |
+
default="both",
|
| 110 |
+
choices=["target", "source", "both"],
|
| 111 |
+
help="Which side's residuals to re-inject (default: both)",
|
| 112 |
+
)
|
| 113 |
+
parser.add_argument(
|
| 114 |
+
"--strength",
|
| 115 |
+
type=float,
|
| 116 |
+
default=0.2,
|
| 117 |
+
help="Residual re-injection strength, 0-1 (default: 0.2)",
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
return parser.parse_args()
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def print_banner():
|
| 124 |
+
"""Print the TD Fuse banner."""
|
| 125 |
+
banner = """
|
| 126 |
+
ββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 127 |
+
β β
|
| 128 |
+
β ββββββββββββββββ βββββββββββ βββββββββββ β
|
| 129 |
+
β βββββββββββββββββ βββββββββββ βββββββββββ β
|
| 130 |
+
β βββ βββ βββ ββββββ βββ βββββββββββ β
|
| 131 |
+
β βββ βββ βββ ββββββ βββ βββββββββββ β
|
| 132 |
+
β βββ ββββββββ βββ βββββββββββββββββ β
|
| 133 |
+
β βββ βββββββ βββ βββββββ ββββββββ β
|
| 134 |
+
β β
|
| 135 |
+
β Transport and Merge for Time Dilation β
|
| 136 |
+
β Merging 5 models into Qwen3-8B β
|
| 137 |
+
β β
|
| 138 |
+
βββββββββββββββββββββββββββββββββββοΏ½οΏ½ββββββββββββββββ
|
| 139 |
+
"""
|
| 140 |
+
print(banner)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def main():
|
| 144 |
+
args = parse_args()
|
| 145 |
+
print_banner()
|
| 146 |
+
|
| 147 |
+
# Build config from args
|
| 148 |
+
cfg = MergeConfig(
|
| 149 |
+
output_dir=args.output,
|
| 150 |
+
checkpoint_dir=args.checkpoint_dir,
|
| 151 |
+
tm_repo_path=args.tm_repo,
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
# Determine which stages to run
|
| 155 |
+
if args.stage == "demo":
|
| 156 |
+
stages = DEMO_STAGES
|
| 157 |
+
elif args.stage == "all":
|
| 158 |
+
stages = FULL_STAGES
|
| 159 |
+
else:
|
| 160 |
+
stages = [args.stage]
|
| 161 |
+
|
| 162 |
+
# --- Reinject residuals mode ---
|
| 163 |
+
if args.reinject:
|
| 164 |
+
if not args.model_path:
|
| 165 |
+
print("Error: --reinject requires --model-path")
|
| 166 |
+
sys.exit(1)
|
| 167 |
+
|
| 168 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 169 |
+
import torch
|
| 170 |
+
|
| 171 |
+
print(f"\n[run] Re-injecting residuals from stage: {args.reinject}")
|
| 172 |
+
print(f"[run] Side: {args.reinject_side}, Strength: {args.strength}")
|
| 173 |
+
|
| 174 |
+
residual_bank = ResidualBank(cfg)
|
| 175 |
+
tokenizer = AutoTokenizer.from_pretrained(args.model_path)
|
| 176 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 177 |
+
args.model_path,
|
| 178 |
+
torch_dtype=torch.bfloat16,
|
| 179 |
+
device_map="auto",
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
model = residual_bank.reinject_residuals(
|
| 183 |
+
model, args.reinject,
|
| 184 |
+
side=args.reinject_side,
|
| 185 |
+
strength=args.strength,
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
# Save the patched model
|
| 189 |
+
patched_dir = Path(cfg.output_dir) / f"reinjected_{args.reinject}_{args.strength}"
|
| 190 |
+
patched_dir.mkdir(parents=True, exist_ok=True)
|
| 191 |
+
model.save_pretrained(str(patched_dir))
|
| 192 |
+
tokenizer.save_pretrained(str(patched_dir))
|
| 193 |
+
print(f"\n[run] Patched model saved to: {patched_dir}")
|
| 194 |
+
return
|
| 195 |
+
|
| 196 |
+
# --- Heal-only mode ---
|
| 197 |
+
if args.heal_only:
|
| 198 |
+
if not args.model_path:
|
| 199 |
+
print("Error: --heal-only requires --model-path")
|
| 200 |
+
sys.exit(1)
|
| 201 |
+
|
| 202 |
+
print(f"\n[run] Healing model at: {args.model_path}")
|
| 203 |
+
healed_path = heal_model(args.model_path, cfg)
|
| 204 |
+
print(f"\n[run] Healed model saved to: {healed_path}")
|
| 205 |
+
return
|
| 206 |
+
|
| 207 |
+
# --- Dry run ---
|
| 208 |
+
if args.dry_run:
|
| 209 |
+
print("\n=== DRY RUN ===")
|
| 210 |
+
print(f"Stages: {stages}")
|
| 211 |
+
print(f"Output: {cfg.output_dir}")
|
| 212 |
+
print(f"Checkpoints: {cfg.checkpoint_dir}")
|
| 213 |
+
print(f"T&M repo: {cfg.tm_repo_path}")
|
| 214 |
+
print(f"Heal after: {args.heal}")
|
| 215 |
+
print(f"\nWould run:")
|
| 216 |
+
for i, stage in enumerate(stages, 1):
|
| 217 |
+
print(f" {i}. Merge {stage} β target")
|
| 218 |
+
print(f" β Validate (canary + perplexity + thinking + reasoning)")
|
| 219 |
+
print(f" β Checkpoint")
|
| 220 |
+
if args.heal:
|
| 221 |
+
print(f" {len(stages) + 1}. QLoRA healing fine-tune")
|
| 222 |
+
print("\nNo changes made (dry run).")
|
| 223 |
+
return
|
| 224 |
+
|
| 225 |
+
# --- Run the pipeline ---
|
| 226 |
+
start_time = time.time()
|
| 227 |
+
|
| 228 |
+
results = run_pipeline(stages, cfg)
|
| 229 |
+
|
| 230 |
+
elapsed = time.time() - start_time
|
| 231 |
+
print(f"\n[run] Pipeline completed in {elapsed / 60:.1f} minutes")
|
| 232 |
+
|
| 233 |
+
# --- Healing fine-tune (optional) ---
|
| 234 |
+
if args.heal and results.get("final_checkpoint"):
|
| 235 |
+
print("\n[run] Starting healing fine-tune...")
|
| 236 |
+
healed_path = heal_model(results["final_checkpoint"], cfg)
|
| 237 |
+
results["healed_model_path"] = healed_path
|
| 238 |
+
print(f"[run] Healed model: {healed_path}")
|
| 239 |
+
|
| 240 |
+
# --- Save results ---
|
| 241 |
+
results_path = Path(cfg.output_dir) / "pipeline_results.json"
|
| 242 |
+
|
| 243 |
+
# Convert non-serialisable objects
|
| 244 |
+
def make_serialisable(obj):
|
| 245 |
+
if isinstance(obj, dict):
|
| 246 |
+
return {k: make_serialisable(v) for k, v in obj.items()}
|
| 247 |
+
elif isinstance(obj, list):
|
| 248 |
+
return [make_serialisable(v) for v in obj]
|
| 249 |
+
elif isinstance(obj, (int, float, str, bool, type(None))):
|
| 250 |
+
return obj
|
| 251 |
+
else:
|
| 252 |
+
return str(obj)
|
| 253 |
+
|
| 254 |
+
with open(results_path, "w") as f:
|
| 255 |
+
json.dump(make_serialisable(results), f, indent=2)
|
| 256 |
+
print(f"[run] Results saved to {results_path}")
|
| 257 |
+
|
| 258 |
+
# --- Final summary ---
|
| 259 |
+
print(f"\n{'=' * 60}")
|
| 260 |
+
print("TD FUSE COMPLETE")
|
| 261 |
+
print(f"{'=' * 60}")
|
| 262 |
+
print(f" Status: {results['overall_status']}")
|
| 263 |
+
print(f" Time: {elapsed / 60:.1f} minutes")
|
| 264 |
+
if results.get("final_model_path"):
|
| 265 |
+
print(f" Model: {results['final_model_path']}")
|
| 266 |
+
if results.get("healed_model_path"):
|
| 267 |
+
print(f" Healed: {results['healed_model_path']}")
|
| 268 |
+
print(f" Results: {results_path}")
|
| 269 |
+
print(f"{'=' * 60}")
|
| 270 |
+
|
| 271 |
+
# Exit code based on result
|
| 272 |
+
if results["overall_status"] == "all_passed":
|
| 273 |
+
sys.exit(0)
|
| 274 |
+
else:
|
| 275 |
+
sys.exit(1)
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
if __name__ == "__main__":
|
| 279 |
+
main()
|
hugging/td_lang/engine/techniques.py
ADDED
|
@@ -0,0 +1,669 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Advanced Merge Techniques β from latest papers (Feb 2026).
|
| 3 |
+
|
| 4 |
+
This module contains implementations inspired by recent research
|
| 5 |
+
that improve TD's sequential cross-architecture merging pipeline.
|
| 6 |
+
|
| 7 |
+
Techniques:
|
| 8 |
+
1. Theseus (2602.12952) β Procrustes-based task vector transport
|
| 9 |
+
2. ARM (2602.03237) β Activation-guided rotation for sequential merges
|
| 10 |
+
3. OTMF (2511.19561) β OT masks for identifying transferable weights
|
| 11 |
+
4. RAM (2601.13572) β RL-weight disentanglement for RL-trained models
|
| 12 |
+
5. Mergeability (2601.22285) β Pre-check scoring before attempting merge
|
| 13 |
+
|
| 14 |
+
These complement Transport and Merge (2602.05495) which handles
|
| 15 |
+
the core cross-architecture fusion via optimal transport.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
import numpy as np
|
| 20 |
+
from typing import Optional
|
| 21 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 22 |
+
|
| 23 |
+
from .config import MergeConfig, ModelConfig
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# ============================================================================
|
| 27 |
+
# 1. THESEUS β Procrustes-Based Task Vector Transport (2602.12952)
|
| 28 |
+
# ============================================================================
|
| 29 |
+
#
|
| 30 |
+
# Instead of aligning neurons via optimal transport (T&M), Theseus aligns
|
| 31 |
+
# the FUNCTIONAL EFFECT of weights via orthogonal Procrustes.
|
| 32 |
+
#
|
| 33 |
+
# Analogy: T&M says "neuron 5 in Model A = neuron 12 in Model B"
|
| 34 |
+
# Theseus says "the EFFECT of Model A's weights can be rotated
|
| 35 |
+
# into Model B's space"
|
| 36 |
+
#
|
| 37 |
+
# Best for: Models where neuron-level alignment is poor (Falcon SSM hybrid)
|
| 38 |
+
|
| 39 |
+
def compute_procrustes_alignment(
|
| 40 |
+
source_activations: torch.Tensor,
|
| 41 |
+
target_activations: torch.Tensor,
|
| 42 |
+
) -> torch.Tensor:
|
| 43 |
+
"""
|
| 44 |
+
Compute the orthogonal Procrustes rotation matrix R that best maps
|
| 45 |
+
source activations into target activation space.
|
| 46 |
+
|
| 47 |
+
R = argmin ||target - source @ R||_F subject to R^T R = I
|
| 48 |
+
|
| 49 |
+
Solution: R = V @ U^T from SVD of (source^T @ target) = U S V^T
|
| 50 |
+
|
| 51 |
+
This is a closed-form solution β no iterative optimisation needed.
|
| 52 |
+
|
| 53 |
+
Args:
|
| 54 |
+
source_activations: [num_samples, source_dim] activation matrix
|
| 55 |
+
target_activations: [num_samples, target_dim] activation matrix
|
| 56 |
+
|
| 57 |
+
Returns:
|
| 58 |
+
R: [source_dim, target_dim] rotation matrix
|
| 59 |
+
"""
|
| 60 |
+
# Center the activations (remove mean)
|
| 61 |
+
S = source_activations - source_activations.mean(dim=0, keepdim=True)
|
| 62 |
+
T = target_activations - target_activations.mean(dim=0, keepdim=True)
|
| 63 |
+
|
| 64 |
+
# Handle dimension mismatch by zero-padding the smaller one
|
| 65 |
+
s_dim = S.shape[1]
|
| 66 |
+
t_dim = T.shape[1]
|
| 67 |
+
max_dim = max(s_dim, t_dim)
|
| 68 |
+
|
| 69 |
+
if s_dim < max_dim:
|
| 70 |
+
S = torch.nn.functional.pad(S, (0, max_dim - s_dim))
|
| 71 |
+
if t_dim < max_dim:
|
| 72 |
+
T = torch.nn.functional.pad(T, (0, max_dim - t_dim))
|
| 73 |
+
|
| 74 |
+
# Cross-covariance matrix
|
| 75 |
+
M = S.T @ T # [max_dim, max_dim]
|
| 76 |
+
|
| 77 |
+
# SVD: M = U @ diag(sigma) @ V^T
|
| 78 |
+
U, sigma, Vt = torch.linalg.svd(M, full_matrices=True)
|
| 79 |
+
|
| 80 |
+
# Optimal rotation: R = V @ U^T
|
| 81 |
+
# This ensures R is orthogonal (R^T R = I)
|
| 82 |
+
R = Vt.T @ U.T
|
| 83 |
+
|
| 84 |
+
# Ensure proper rotation (det = +1), not reflection
|
| 85 |
+
det = torch.linalg.det(R)
|
| 86 |
+
if det < 0:
|
| 87 |
+
# Flip sign of last column of Vt
|
| 88 |
+
Vt[-1, :] *= -1
|
| 89 |
+
R = Vt.T @ U.T
|
| 90 |
+
|
| 91 |
+
return R[:s_dim, :t_dim] # Crop back to original dims
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def transport_task_vector_theseus(
|
| 95 |
+
source_model: AutoModelForCausalLM,
|
| 96 |
+
source_base_model: AutoModelForCausalLM,
|
| 97 |
+
target_model: AutoModelForCausalLM,
|
| 98 |
+
source_activations: dict,
|
| 99 |
+
target_activations: dict,
|
| 100 |
+
alpha: float = 0.3,
|
| 101 |
+
) -> AutoModelForCausalLM:
|
| 102 |
+
"""
|
| 103 |
+
Transport a task vector from source to target using Theseus method.
|
| 104 |
+
|
| 105 |
+
Task vector = source_finetuned - source_base
|
| 106 |
+
(the "diff" that represents what the model learned)
|
| 107 |
+
|
| 108 |
+
We rotate this diff into target's space using Procrustes alignment,
|
| 109 |
+
then add it to target: target_new = target + alpha * R @ task_vector
|
| 110 |
+
|
| 111 |
+
This is the FALLBACK for when T&M's neuron-level alignment fails
|
| 112 |
+
(e.g., Falcon's SSM components).
|
| 113 |
+
|
| 114 |
+
Args:
|
| 115 |
+
source_model: The fine-tuned source (e.g., Falcon-H1R-7B)
|
| 116 |
+
source_base_model: The base version of source (for computing task vector)
|
| 117 |
+
target_model: The target to transport into (our merged Qwen3)
|
| 118 |
+
source_activations: Layer β activation tensors for source
|
| 119 |
+
target_activations: Layer β activation tensors for target
|
| 120 |
+
alpha: Blending weight for the transported task vector
|
| 121 |
+
"""
|
| 122 |
+
print("[theseus] Computing task vectors and Procrustes alignment...")
|
| 123 |
+
|
| 124 |
+
source_state = source_model.state_dict()
|
| 125 |
+
base_state = source_base_model.state_dict()
|
| 126 |
+
target_state = target_model.state_dict()
|
| 127 |
+
|
| 128 |
+
# Compute per-layer Procrustes rotation matrices
|
| 129 |
+
rotations = {}
|
| 130 |
+
source_layers = sorted(source_activations.keys())
|
| 131 |
+
target_layers = sorted(target_activations.keys())
|
| 132 |
+
|
| 133 |
+
for sl, tl in zip(source_layers, target_layers):
|
| 134 |
+
if sl in source_activations and tl in target_activations:
|
| 135 |
+
R = compute_procrustes_alignment(
|
| 136 |
+
source_activations[sl].float(),
|
| 137 |
+
target_activations[tl].float(),
|
| 138 |
+
)
|
| 139 |
+
rotations[(sl, tl)] = R
|
| 140 |
+
|
| 141 |
+
# Transport task vectors
|
| 142 |
+
transported_count = 0
|
| 143 |
+
for target_key in target_state:
|
| 144 |
+
# Find matching source key (simplified β same key names)
|
| 145 |
+
source_key = target_key
|
| 146 |
+
if source_key not in source_state or source_key not in base_state:
|
| 147 |
+
continue
|
| 148 |
+
|
| 149 |
+
# Task vector = what the source learned
|
| 150 |
+
task_vector = source_state[source_key].float() - base_state[source_key].float()
|
| 151 |
+
|
| 152 |
+
if task_vector.abs().max() < 1e-8:
|
| 153 |
+
continue # No meaningful change
|
| 154 |
+
|
| 155 |
+
# For 2D weight matrices, apply rotation
|
| 156 |
+
if task_vector.dim() == 2:
|
| 157 |
+
# Find the appropriate rotation for this layer
|
| 158 |
+
for (sl, tl), R in rotations.items():
|
| 159 |
+
if sl.split(".")[2] == target_key.split(".")[2]: # Same layer index
|
| 160 |
+
R_device = R.to(task_vector.device)
|
| 161 |
+
# Rotate: task_vector_rotated = task_vector @ R
|
| 162 |
+
try:
|
| 163 |
+
if task_vector.shape[1] == R_device.shape[0]:
|
| 164 |
+
task_vector = task_vector @ R_device
|
| 165 |
+
elif task_vector.shape[0] == R_device.shape[0]:
|
| 166 |
+
task_vector = R_device.T @ task_vector
|
| 167 |
+
except RuntimeError:
|
| 168 |
+
pass # Dimension mismatch, use unrotated
|
| 169 |
+
break
|
| 170 |
+
|
| 171 |
+
# Apply: target_new = target + alpha * rotated_task_vector
|
| 172 |
+
target_w = target_state[target_key]
|
| 173 |
+
if task_vector.shape == target_w.shape:
|
| 174 |
+
target_state[target_key] = target_w + alpha * task_vector.to(target_w.dtype)
|
| 175 |
+
transported_count += 1
|
| 176 |
+
|
| 177 |
+
target_model.load_state_dict(target_state)
|
| 178 |
+
print(f"[theseus] Transported {transported_count} task vectors via Procrustes")
|
| 179 |
+
return target_model
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
# ============================================================================
|
| 183 |
+
# 2. ARM β Activation-Guided Rotations for Sequential Merging (2602.03237)
|
| 184 |
+
# ============================================================================
|
| 185 |
+
#
|
| 186 |
+
# ARM treats sequential merging like gradient descent β each merge step
|
| 187 |
+
# has a "direction" and a "learning rate" (merge coefficient).
|
| 188 |
+
#
|
| 189 |
+
# Key insight: Use ACTIVATION PATTERNS to compute optimal rotation vectors
|
| 190 |
+
# that guide each merge step. This is a smarter version of our
|
| 191 |
+
# orthogonal projection in MergeProtection.
|
| 192 |
+
|
| 193 |
+
def compute_arm_rotation(
|
| 194 |
+
pre_merge_activations: dict,
|
| 195 |
+
post_merge_activations: dict,
|
| 196 |
+
target_activations: dict,
|
| 197 |
+
) -> dict:
|
| 198 |
+
"""
|
| 199 |
+
Compute ARM rotation vectors for sequential merge protection.
|
| 200 |
+
|
| 201 |
+
For each layer, compute a rotation that:
|
| 202 |
+
1. Preserves the direction of knowledge already merged
|
| 203 |
+
2. Steers the next merge to fill GAPS rather than overwrite
|
| 204 |
+
|
| 205 |
+
The rotation is computed from the activation change (what the
|
| 206 |
+
last merge did) and the target (where we want to end up).
|
| 207 |
+
|
| 208 |
+
Returns:
|
| 209 |
+
Dict of layer_name β rotation matrix
|
| 210 |
+
"""
|
| 211 |
+
print("[arm] Computing activation-guided rotations...")
|
| 212 |
+
|
| 213 |
+
rotations = {}
|
| 214 |
+
|
| 215 |
+
for layer_name in pre_merge_activations:
|
| 216 |
+
if layer_name not in post_merge_activations or layer_name not in target_activations:
|
| 217 |
+
continue
|
| 218 |
+
|
| 219 |
+
pre = pre_merge_activations[layer_name].float() # Before last merge
|
| 220 |
+
post = post_merge_activations[layer_name].float() # After last merge
|
| 221 |
+
target = target_activations[layer_name].float() # Ideal target
|
| 222 |
+
|
| 223 |
+
# Delta from last merge
|
| 224 |
+
merge_delta = post - pre # [samples, hidden_dim]
|
| 225 |
+
|
| 226 |
+
# Gap remaining (what we still need)
|
| 227 |
+
gap = target - post # [samples, hidden_dim]
|
| 228 |
+
|
| 229 |
+
# Average across samples to get direction vectors
|
| 230 |
+
delta_dir = merge_delta.mean(dim=0) # [hidden_dim]
|
| 231 |
+
gap_dir = gap.mean(dim=0) # [hidden_dim]
|
| 232 |
+
|
| 233 |
+
# Normalise
|
| 234 |
+
delta_norm = delta_dir / (delta_dir.norm() + 1e-8)
|
| 235 |
+
gap_norm = gap_dir / (gap_dir.norm() + 1e-8)
|
| 236 |
+
|
| 237 |
+
# Compute rotation from delta direction to gap direction
|
| 238 |
+
# Using Rodrigues' rotation formula for the 2D plane
|
| 239 |
+
# spanned by delta and gap
|
| 240 |
+
cos_theta = torch.dot(delta_norm, gap_norm).clamp(-1, 1)
|
| 241 |
+
sin_theta = torch.sqrt(1 - cos_theta ** 2)
|
| 242 |
+
|
| 243 |
+
# Store as a simple rotation descriptor
|
| 244 |
+
rotations[layer_name] = {
|
| 245 |
+
"delta_direction": delta_norm,
|
| 246 |
+
"gap_direction": gap_norm,
|
| 247 |
+
"cos_theta": cos_theta.item(),
|
| 248 |
+
"sin_theta": sin_theta.item(),
|
| 249 |
+
"gap_magnitude": gap_dir.norm().item(),
|
| 250 |
+
}
|
| 251 |
+
|
| 252 |
+
return rotations
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def apply_arm_steering(
|
| 256 |
+
weight_delta: torch.Tensor,
|
| 257 |
+
rotation_info: dict,
|
| 258 |
+
steering_strength: float = 0.5,
|
| 259 |
+
) -> torch.Tensor:
|
| 260 |
+
"""
|
| 261 |
+
Steer a weight delta using ARM rotation vectors.
|
| 262 |
+
|
| 263 |
+
Instead of blindly projecting out previous merge directions
|
| 264 |
+
(our old orthogonal projection), ARM STEERS the delta toward
|
| 265 |
+
the remaining gap.
|
| 266 |
+
|
| 267 |
+
Args:
|
| 268 |
+
weight_delta: The raw delta from the current merge
|
| 269 |
+
rotation_info: ARM rotation info for this layer
|
| 270 |
+
steering_strength: How much to steer (0=no steering, 1=full)
|
| 271 |
+
|
| 272 |
+
Returns:
|
| 273 |
+
Steered weight delta
|
| 274 |
+
"""
|
| 275 |
+
delta_dir = rotation_info["delta_direction"]
|
| 276 |
+
gap_dir = rotation_info["gap_direction"]
|
| 277 |
+
|
| 278 |
+
flat = weight_delta.flatten().float()
|
| 279 |
+
|
| 280 |
+
# Component along previous merge direction
|
| 281 |
+
prev_component = torch.dot(flat, delta_dir.to(flat.device))
|
| 282 |
+
|
| 283 |
+
# Remove some of the previous-direction component
|
| 284 |
+
# and add gap-direction component instead
|
| 285 |
+
correction = (
|
| 286 |
+
-steering_strength * prev_component * delta_dir.to(flat.device)
|
| 287 |
+
+ steering_strength * prev_component * gap_dir.to(flat.device)
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
steered = flat + correction
|
| 291 |
+
return steered.reshape(weight_delta.shape).to(weight_delta.dtype)
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
# ============================================================================
|
| 295 |
+
# 3. OTMF β Transferability Masks via Optimal Transport (2511.19561)
|
| 296 |
+
# ============================================================================
|
| 297 |
+
#
|
| 298 |
+
# OTMF discovers which parts of each model are "transferable" (shared
|
| 299 |
+
# knowledge) vs "task-specific" (unique to that model).
|
| 300 |
+
#
|
| 301 |
+
# Transferable weights β safe to merge/average
|
| 302 |
+
# Task-specific weights β must be preserved carefully
|
| 303 |
+
#
|
| 304 |
+
# This replaces our MagMax "top 20% by magnitude" heuristic with a
|
| 305 |
+
# principled, data-driven approach.
|
| 306 |
+
|
| 307 |
+
def compute_transferability_masks(
|
| 308 |
+
model: AutoModelForCausalLM,
|
| 309 |
+
calibration_activations: dict,
|
| 310 |
+
threshold: float = 0.3,
|
| 311 |
+
) -> dict:
|
| 312 |
+
"""
|
| 313 |
+
Compute per-parameter transferability masks using activation variance.
|
| 314 |
+
|
| 315 |
+
High activation variance across diverse inputs β parameter encodes
|
| 316 |
+
task-specific knowledge (DON'T merge aggressively).
|
| 317 |
+
|
| 318 |
+
Low activation variance β parameter encodes shared/general knowledge
|
| 319 |
+
(safe to merge/average).
|
| 320 |
+
|
| 321 |
+
This is a simplified version of OTMF's OT-based mask discovery.
|
| 322 |
+
|
| 323 |
+
Args:
|
| 324 |
+
model: The current merged model
|
| 325 |
+
calibration_activations: Layer β [samples, hidden_dim] activations
|
| 326 |
+
threshold: Variance quantile threshold for "task-specific" classification
|
| 327 |
+
|
| 328 |
+
Returns:
|
| 329 |
+
Dict of param_name β bool mask (True = transferable/safe, False = task-specific/protect)
|
| 330 |
+
"""
|
| 331 |
+
print("[otmf] Computing transferability masks...")
|
| 332 |
+
|
| 333 |
+
masks = {}
|
| 334 |
+
state = model.state_dict()
|
| 335 |
+
|
| 336 |
+
# Compute per-neuron activation variance
|
| 337 |
+
neuron_importance = {}
|
| 338 |
+
for layer_name, acts in calibration_activations.items():
|
| 339 |
+
# Variance across samples: high variance = this neuron is doing something specific
|
| 340 |
+
variance = acts.var(dim=0) # [hidden_dim]
|
| 341 |
+
neuron_importance[layer_name] = variance
|
| 342 |
+
|
| 343 |
+
# Map neuron importance to parameter importance
|
| 344 |
+
for param_name, param in state.items():
|
| 345 |
+
# Find the corresponding layer's importance
|
| 346 |
+
layer_prefix = ".".join(param_name.split(".")[:4]) # e.g., model.layers.0.self_attn
|
| 347 |
+
|
| 348 |
+
importance = None
|
| 349 |
+
for layer_name, var in neuron_importance.items():
|
| 350 |
+
if layer_prefix in layer_name:
|
| 351 |
+
importance = var
|
| 352 |
+
break
|
| 353 |
+
|
| 354 |
+
if importance is None:
|
| 355 |
+
# Default: mark everything as transferable (safe to merge)
|
| 356 |
+
masks[param_name] = torch.ones(param.shape, dtype=torch.bool)
|
| 357 |
+
continue
|
| 358 |
+
|
| 359 |
+
# For 2D weights: importance determines which rows/columns to protect
|
| 360 |
+
if param.dim() == 2:
|
| 361 |
+
rows, cols = param.shape
|
| 362 |
+
# Use importance for the output dimension
|
| 363 |
+
imp = importance[:rows] if importance.shape[0] >= rows else importance
|
| 364 |
+
|
| 365 |
+
# Compute threshold: top (1-threshold) fraction is task-specific
|
| 366 |
+
if imp.numel() > 0:
|
| 367 |
+
q = torch.quantile(imp.float(), 1.0 - threshold)
|
| 368 |
+
# True = transferable (below threshold), False = task-specific (protect)
|
| 369 |
+
row_mask = imp < q
|
| 370 |
+
masks[param_name] = row_mask.unsqueeze(1).expand_as(param)
|
| 371 |
+
else:
|
| 372 |
+
masks[param_name] = torch.ones(param.shape, dtype=torch.bool)
|
| 373 |
+
else:
|
| 374 |
+
# 1D params (biases, norms): default to transferable
|
| 375 |
+
masks[param_name] = torch.ones(param.shape, dtype=torch.bool)
|
| 376 |
+
|
| 377 |
+
transferable = sum(m.sum().item() for m in masks.values())
|
| 378 |
+
total = sum(m.numel() for m in masks.values())
|
| 379 |
+
print(f"[otmf] Transferability: {transferable / total:.1%} transferable, {1 - transferable / total:.1%} task-specific")
|
| 380 |
+
|
| 381 |
+
return masks
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
def apply_masked_merge(
|
| 385 |
+
target_state: dict,
|
| 386 |
+
fused_state: dict,
|
| 387 |
+
masks: dict,
|
| 388 |
+
protect_strength: float = 0.8,
|
| 389 |
+
) -> dict:
|
| 390 |
+
"""
|
| 391 |
+
Apply transferability masks during merge.
|
| 392 |
+
|
| 393 |
+
For transferable weights: use the fused (merged) value
|
| 394 |
+
For task-specific weights: preserve more of the original target value
|
| 395 |
+
|
| 396 |
+
Args:
|
| 397 |
+
target_state: Original target weights (before this merge)
|
| 398 |
+
fused_state: Newly fused weights (after T&M/Theseus fusion)
|
| 399 |
+
masks: Transferability masks (True = safe to change)
|
| 400 |
+
protect_strength: How much to protect task-specific weights (0-1)
|
| 401 |
+
|
| 402 |
+
Returns:
|
| 403 |
+
Masked merged state dict
|
| 404 |
+
"""
|
| 405 |
+
result = {}
|
| 406 |
+
|
| 407 |
+
for key in fused_state:
|
| 408 |
+
if key in masks and key in target_state:
|
| 409 |
+
mask = masks[key].to(fused_state[key].device)
|
| 410 |
+
original = target_state[key]
|
| 411 |
+
fused = fused_state[key]
|
| 412 |
+
|
| 413 |
+
# Transferable: use fused value
|
| 414 |
+
# Task-specific: blend more toward original
|
| 415 |
+
blended = torch.where(
|
| 416 |
+
mask,
|
| 417 |
+
fused, # Transferable β take merged value
|
| 418 |
+
protect_strength * original + (1 - protect_strength) * fused, # Protected
|
| 419 |
+
)
|
| 420 |
+
result[key] = blended
|
| 421 |
+
else:
|
| 422 |
+
result[key] = fused_state[key]
|
| 423 |
+
|
| 424 |
+
protected_params = sum(1 for k in masks if not masks[k].all())
|
| 425 |
+
print(f"[otmf] Applied masks: {protected_params} parameters partially protected")
|
| 426 |
+
|
| 427 |
+
return result
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
# ============================================================================
|
| 431 |
+
# 4. RAM β RL-Weight Disentanglement (2601.13572)
|
| 432 |
+
# ============================================================================
|
| 433 |
+
#
|
| 434 |
+
# RL-trained models (DeepSeek-R1, MiMo-7B-RL) have two types of knowledge:
|
| 435 |
+
# - Shared: general language understanding (same as base model)
|
| 436 |
+
# - RL-specific: reasoning patterns learned via GRPO/RLHF
|
| 437 |
+
#
|
| 438 |
+
# RAM separates these so we can merge the shared parts normally
|
| 439 |
+
# but PRESERVE the RL-specific parts that make these models special.
|
| 440 |
+
|
| 441 |
+
def disentangle_rl_weights(
|
| 442 |
+
rl_model: AutoModelForCausalLM,
|
| 443 |
+
base_model: AutoModelForCausalLM,
|
| 444 |
+
rl_threshold: float = 0.1,
|
| 445 |
+
) -> tuple:
|
| 446 |
+
"""
|
| 447 |
+
Separate RL-specific weights from shared/general weights.
|
| 448 |
+
|
| 449 |
+
RL-specific = weights that changed significantly during RL training
|
| 450 |
+
Shared = weights that are basically the same as base
|
| 451 |
+
|
| 452 |
+
We identify RL-specific weights by looking at the magnitude of
|
| 453 |
+
change from base model to RL model. Big changes β RL learned
|
| 454 |
+
something there β don't average it away.
|
| 455 |
+
|
| 456 |
+
Args:
|
| 457 |
+
rl_model: The RL-trained model (e.g., DeepSeek-R1, MiMo-7B-RL)
|
| 458 |
+
base_model: The base model before RL training
|
| 459 |
+
rl_threshold: Relative change threshold for "RL-specific" classification
|
| 460 |
+
|
| 461 |
+
Returns:
|
| 462 |
+
Tuple of (shared_mask, rl_mask) β both are dicts of param_name β bool tensor
|
| 463 |
+
shared_mask: True = this weight is shared (safe to merge normally)
|
| 464 |
+
rl_mask: True = this weight is RL-specific (protect during merge)
|
| 465 |
+
"""
|
| 466 |
+
print("[ram] Disentangling RL-specific vs shared weights...")
|
| 467 |
+
|
| 468 |
+
rl_state = rl_model.state_dict()
|
| 469 |
+
base_state = base_model.state_dict()
|
| 470 |
+
|
| 471 |
+
shared_mask = {}
|
| 472 |
+
rl_mask = {}
|
| 473 |
+
|
| 474 |
+
total_params = 0
|
| 475 |
+
rl_params = 0
|
| 476 |
+
|
| 477 |
+
for key in rl_state:
|
| 478 |
+
if key not in base_state:
|
| 479 |
+
# New param (e.g., MTP head) β mark as RL-specific
|
| 480 |
+
rl_mask[key] = torch.ones_like(rl_state[key], dtype=torch.bool)
|
| 481 |
+
shared_mask[key] = torch.zeros_like(rl_state[key], dtype=torch.bool)
|
| 482 |
+
rl_params += rl_state[key].numel()
|
| 483 |
+
total_params += rl_state[key].numel()
|
| 484 |
+
continue
|
| 485 |
+
|
| 486 |
+
rl_w = rl_state[key].float()
|
| 487 |
+
base_w = base_state[key].float()
|
| 488 |
+
|
| 489 |
+
# Relative change: |rl - base| / (|base| + epsilon)
|
| 490 |
+
change = (rl_w - base_w).abs()
|
| 491 |
+
base_magnitude = base_w.abs() + 1e-8
|
| 492 |
+
relative_change = change / base_magnitude
|
| 493 |
+
|
| 494 |
+
# RL-specific: relative change > threshold
|
| 495 |
+
is_rl = relative_change > rl_threshold
|
| 496 |
+
rl_mask[key] = is_rl
|
| 497 |
+
shared_mask[key] = ~is_rl
|
| 498 |
+
|
| 499 |
+
rl_params += is_rl.sum().item()
|
| 500 |
+
total_params += is_rl.numel()
|
| 501 |
+
|
| 502 |
+
pct = rl_params / total_params * 100 if total_params > 0 else 0
|
| 503 |
+
print(f"[ram] RL-specific: {rl_params:,} params ({pct:.1f}%)")
|
| 504 |
+
print(f"[ram] Shared: {total_params - rl_params:,} params ({100 - pct:.1f}%)")
|
| 505 |
+
|
| 506 |
+
return shared_mask, rl_mask
|
| 507 |
+
|
| 508 |
+
|
| 509 |
+
def merge_with_rl_preservation(
|
| 510 |
+
target_state: dict,
|
| 511 |
+
source_state: dict,
|
| 512 |
+
shared_mask: dict,
|
| 513 |
+
rl_mask: dict,
|
| 514 |
+
shared_alpha: float = 0.5,
|
| 515 |
+
rl_alpha: float = 0.8,
|
| 516 |
+
) -> dict:
|
| 517 |
+
"""
|
| 518 |
+
Merge source into target while preserving RL-specific weights.
|
| 519 |
+
|
| 520 |
+
Shared weights: normal blending at shared_alpha
|
| 521 |
+
RL-specific weights: stronger blending toward source (preserve RL knowledge)
|
| 522 |
+
|
| 523 |
+
This prevents the RL reasoning capabilities from being diluted
|
| 524 |
+
by averaging with target weights.
|
| 525 |
+
|
| 526 |
+
Args:
|
| 527 |
+
target_state: Current target model state
|
| 528 |
+
source_state: RL model state to merge in
|
| 529 |
+
shared_mask: Which params are shared (safe for normal merge)
|
| 530 |
+
rl_mask: Which params are RL-specific (preserve with higher alpha)
|
| 531 |
+
shared_alpha: Alpha for shared weights (normal)
|
| 532 |
+
rl_alpha: Alpha for RL-specific weights (higher = preserve more RL knowledge)
|
| 533 |
+
"""
|
| 534 |
+
print(f"[ram] Merging with RL preservation (shared Ξ±={shared_alpha}, RL Ξ±={rl_alpha})...")
|
| 535 |
+
|
| 536 |
+
result = {}
|
| 537 |
+
for key in target_state:
|
| 538 |
+
if key not in source_state:
|
| 539 |
+
result[key] = target_state[key]
|
| 540 |
+
continue
|
| 541 |
+
|
| 542 |
+
target_w = target_state[key]
|
| 543 |
+
source_w = source_state[key]
|
| 544 |
+
|
| 545 |
+
if source_w.shape != target_w.shape:
|
| 546 |
+
result[key] = target_state[key]
|
| 547 |
+
continue
|
| 548 |
+
|
| 549 |
+
if key in rl_mask and key in shared_mask:
|
| 550 |
+
rl_m = rl_mask[key].to(target_w.device)
|
| 551 |
+
# RL-specific: use higher alpha (preserve RL knowledge)
|
| 552 |
+
# Shared: use normal alpha
|
| 553 |
+
alpha_map = torch.where(rl_m, rl_alpha, shared_alpha)
|
| 554 |
+
if alpha_map.shape != target_w.shape:
|
| 555 |
+
alpha_map = alpha_map.expand_as(target_w) if alpha_map.dim() > 0 else torch.full_like(target_w, shared_alpha)
|
| 556 |
+
|
| 557 |
+
result[key] = alpha_map * source_w.to(target_w.device) + (1 - alpha_map) * target_w
|
| 558 |
+
else:
|
| 559 |
+
result[key] = shared_alpha * source_w.to(target_w.device) + (1 - shared_alpha) * target_w
|
| 560 |
+
|
| 561 |
+
return result
|
| 562 |
+
|
| 563 |
+
|
| 564 |
+
# ============================================================================
|
| 565 |
+
# 5. MERGEABILITY PRE-CHECK (2601.22285)
|
| 566 |
+
# ============================================================================
|
| 567 |
+
#
|
| 568 |
+
# Before spending GPU hours on a merge that might fail, check if the
|
| 569 |
+
# models are actually COMPATIBLE enough to merge.
|
| 570 |
+
#
|
| 571 |
+
# Mergeability score: 0.0 (definitely won't work) to 1.0 (should work great)
|
| 572 |
+
|
| 573 |
+
def compute_mergeability_score(
|
| 574 |
+
source_activations: dict,
|
| 575 |
+
target_activations: dict,
|
| 576 |
+
source_config: ModelConfig,
|
| 577 |
+
) -> dict:
|
| 578 |
+
"""
|
| 579 |
+
Predict how well a source model will merge into the target.
|
| 580 |
+
|
| 581 |
+
Scores based on three factors:
|
| 582 |
+
1. Activation similarity (cosine similarity of mean activations)
|
| 583 |
+
2. Dimensional compatibility (how similar are the layer shapes)
|
| 584 |
+
3. Architecture match (same arch = bonus)
|
| 585 |
+
|
| 586 |
+
Returns:
|
| 587 |
+
Dict with individual scores and overall mergeability (0-1)
|
| 588 |
+
"""
|
| 589 |
+
print(f"[mergeability] Scoring {source_config.name}...")
|
| 590 |
+
|
| 591 |
+
scores = {}
|
| 592 |
+
|
| 593 |
+
# --- Factor 1: Activation similarity ---
|
| 594 |
+
cosine_sims = []
|
| 595 |
+
source_layers = sorted(source_activations.keys())
|
| 596 |
+
target_layers = sorted(target_activations.keys())
|
| 597 |
+
|
| 598 |
+
# Match layers by position (proportional mapping)
|
| 599 |
+
for i, tl in enumerate(target_layers):
|
| 600 |
+
# Map target layer index to source layer index
|
| 601 |
+
src_idx = int(i * len(source_layers) / len(target_layers))
|
| 602 |
+
src_idx = min(src_idx, len(source_layers) - 1)
|
| 603 |
+
sl = source_layers[src_idx]
|
| 604 |
+
|
| 605 |
+
if sl in source_activations and tl in target_activations:
|
| 606 |
+
s_mean = source_activations[sl].float().mean(dim=0)
|
| 607 |
+
t_mean = target_activations[tl].float().mean(dim=0)
|
| 608 |
+
|
| 609 |
+
# Pad to same dimension for cosine similarity
|
| 610 |
+
max_dim = max(s_mean.shape[0], t_mean.shape[0])
|
| 611 |
+
s_padded = torch.nn.functional.pad(s_mean, (0, max_dim - s_mean.shape[0]))
|
| 612 |
+
t_padded = torch.nn.functional.pad(t_mean, (0, max_dim - t_mean.shape[0]))
|
| 613 |
+
|
| 614 |
+
cos_sim = torch.nn.functional.cosine_similarity(
|
| 615 |
+
s_padded.unsqueeze(0), t_padded.unsqueeze(0)
|
| 616 |
+
).item()
|
| 617 |
+
cosine_sims.append(cos_sim)
|
| 618 |
+
|
| 619 |
+
activation_score = np.mean(cosine_sims) if cosine_sims else 0.0
|
| 620 |
+
scores["activation_similarity"] = float(activation_score)
|
| 621 |
+
|
| 622 |
+
# --- Factor 2: Dimensional compatibility ---
|
| 623 |
+
layer_ratio = min(source_config.layers, 36) / max(source_config.layers, 36)
|
| 624 |
+
hidden_ratio = min(source_config.hidden_dim, 4096) / max(source_config.hidden_dim, 4096)
|
| 625 |
+
dim_score = (layer_ratio + hidden_ratio) / 2
|
| 626 |
+
scores["dimensional_compatibility"] = float(dim_score)
|
| 627 |
+
|
| 628 |
+
# --- Factor 3: Architecture match ---
|
| 629 |
+
arch_scores = {
|
| 630 |
+
"transformer": 1.0, # Same as Qwen3
|
| 631 |
+
"transformer+mtp": 0.8, # Close, just drop extras
|
| 632 |
+
"hybrid_ssm": 0.5, # Very different
|
| 633 |
+
}
|
| 634 |
+
arch_score = arch_scores.get(source_config.architecture, 0.3)
|
| 635 |
+
scores["architecture_match"] = float(arch_score)
|
| 636 |
+
|
| 637 |
+
# --- Factor 4: Vocab overlap (bonus) ---
|
| 638 |
+
vocab_score = source_config.vocab_overlap_with_qwen3
|
| 639 |
+
scores["vocab_overlap"] = float(vocab_score)
|
| 640 |
+
|
| 641 |
+
# --- Overall: weighted average ---
|
| 642 |
+
overall = (
|
| 643 |
+
0.35 * activation_score + # Most important β actual representation similarity
|
| 644 |
+
0.25 * dim_score + # Shape compatibility
|
| 645 |
+
0.25 * arch_score + # Architecture type
|
| 646 |
+
0.15 * vocab_score # Vocab overlap
|
| 647 |
+
)
|
| 648 |
+
scores["overall"] = float(overall)
|
| 649 |
+
|
| 650 |
+
# --- Recommendation ---
|
| 651 |
+
if overall >= 0.7:
|
| 652 |
+
recommendation = "GO β standard T&M merge"
|
| 653 |
+
elif overall >= 0.5:
|
| 654 |
+
recommendation = "CAUTION β T&M merge with higher protection, have Theseus fallback ready"
|
| 655 |
+
elif overall >= 0.3:
|
| 656 |
+
recommendation = "RISKY β try Theseus first, distillation fallback"
|
| 657 |
+
else:
|
| 658 |
+
recommendation = "SKIP β use knowledge distillation instead"
|
| 659 |
+
|
| 660 |
+
scores["recommendation"] = recommendation
|
| 661 |
+
|
| 662 |
+
print(f"[mergeability] {source_config.name} score: {overall:.2f}")
|
| 663 |
+
print(f" Activation similarity: {activation_score:.2f}")
|
| 664 |
+
print(f" Dimensional compat: {dim_score:.2f}")
|
| 665 |
+
print(f" Architecture match: {arch_score:.2f}")
|
| 666 |
+
print(f" Vocab overlap: {vocab_score:.2f}")
|
| 667 |
+
print(f" β {recommendation}")
|
| 668 |
+
|
| 669 |
+
return scores
|
hugging/td_lang/engine/transport.py
ADDED
|
@@ -0,0 +1,853 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Transport and Merge β Two-sided optimal transport with streaming Sinkhorn.
|
| 3 |
+
|
| 4 |
+
Implements the actual Transport and Merge paper (arxiv 2602.05495) correctly:
|
| 5 |
+
|
| 6 |
+
Paper equations implemented here:
|
| 7 |
+
- Eq 8: Q matrices for pre-activation (Q_in) and post-activation (Q_out) features
|
| 8 |
+
- Eq 13: P_eff = sqrt(P_pre Β· P_post) β effective layer transport plan
|
| 9 |
+
- Eq 14: Masked fusion with binary top-k mask M^β
|
| 10 |
+
- Appendix A.3.4: Log-domain streaming Sinkhorn (200 inner / 1000 outer iterations)
|
| 11 |
+
- Appendix A.5: Top-k=128 neuron selection
|
| 12 |
+
|
| 13 |
+
Two-sided transport (Section 4.2):
|
| 14 |
+
For each layer pair (β, m):
|
| 15 |
+
1. Compute Q_in from pre-activation features (what goes INTO the layer)
|
| 16 |
+
2. Compute Q_out from post-activation features (what comes OUT of the layer)
|
| 17 |
+
3. Derive P_pre and P_post at the layer level
|
| 18 |
+
4. Combine: P_eff[β,m] = sqrt(P_pre[β,m] Β· P_post[β,m])
|
| 19 |
+
|
| 20 |
+
Streaming Sinkhorn (Appendix A.3.4):
|
| 21 |
+
- Log-domain updates (never materialize full K = exp(-C/Ξ΅) matrix)
|
| 22 |
+
- Chunked computation for memory efficiency
|
| 23 |
+
- 200 fixed iterations for feature-level (inner) OT
|
| 24 |
+
- Up to 1000 iterations for layer-level (outer) OT
|
| 25 |
+
- Ξ΅ = 0.1 for standard text, Ξ΅ = 0.03 for math reasoning
|
| 26 |
+
|
| 27 |
+
Verified against actual paper PDF (test_21 interview round).
|
| 28 |
+
Grok scored 10/10, these implementations match Grok's citations.
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
import sys
|
| 32 |
+
import math
|
| 33 |
+
import torch
|
| 34 |
+
import numpy as np
|
| 35 |
+
from pathlib import Path
|
| 36 |
+
from typing import Optional, Tuple
|
| 37 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 38 |
+
from datasets import load_dataset
|
| 39 |
+
|
| 40 |
+
from .config import MergeConfig, ModelConfig, TARGET
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# ============================================================================
|
| 44 |
+
# SETUP
|
| 45 |
+
# ============================================================================
|
| 46 |
+
|
| 47 |
+
def setup_tm_repo(cfg: MergeConfig):
|
| 48 |
+
"""Add official T&M repo to Python path so we can import their code."""
|
| 49 |
+
repo_path = Path(cfg.tm_repo_path)
|
| 50 |
+
core_path = repo_path / "core"
|
| 51 |
+
|
| 52 |
+
if not core_path.exists():
|
| 53 |
+
raise FileNotFoundError(
|
| 54 |
+
f"Official T&M repo not found at {repo_path}\n"
|
| 55 |
+
f"Please clone it:\n"
|
| 56 |
+
f" git clone https://github.com/chenhangcuisg-code/"
|
| 57 |
+
f"Cross-Architecture-Merging-for-Large-Language-Models.git"
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
if str(core_path) not in sys.path:
|
| 61 |
+
sys.path.insert(0, str(core_path))
|
| 62 |
+
print(f"[transport] Added T&M core to path: {core_path}")
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
# ============================================================================
|
| 66 |
+
# CALIBRATION DATA (Paper Appendix B.1: 2000 samples)
|
| 67 |
+
# ============================================================================
|
| 68 |
+
|
| 69 |
+
def load_calibration_data(cfg: MergeConfig, tokenizer: AutoTokenizer) -> list:
|
| 70 |
+
"""
|
| 71 |
+
Load calibration data for activation extraction.
|
| 72 |
+
|
| 73 |
+
Paper Appendix B.1: "For each dataset, we randomly sample 2000 examples"
|
| 74 |
+
Mix: Pile general + neuralmagic Q&A = 2000 total samples.
|
| 75 |
+
"""
|
| 76 |
+
print(f"[transport] Loading calibration data ({cfg.calibration_samples} samples)...")
|
| 77 |
+
|
| 78 |
+
samples = []
|
| 79 |
+
|
| 80 |
+
# --- Pile: general text (1200 samples) ---
|
| 81 |
+
try:
|
| 82 |
+
pile = load_dataset(
|
| 83 |
+
cfg.calibration_dataset_pile,
|
| 84 |
+
split="validation",
|
| 85 |
+
streaming=True,
|
| 86 |
+
trust_remote_code=True,
|
| 87 |
+
)
|
| 88 |
+
count = 0
|
| 89 |
+
target_pile = int(cfg.calibration_samples * 0.6) # 60% from Pile
|
| 90 |
+
for example in pile:
|
| 91 |
+
if count >= target_pile:
|
| 92 |
+
break
|
| 93 |
+
text = example.get("text", "")
|
| 94 |
+
if len(text) > 100:
|
| 95 |
+
tokens = tokenizer(
|
| 96 |
+
text,
|
| 97 |
+
truncation=True,
|
| 98 |
+
max_length=cfg.calibration_seq_len,
|
| 99 |
+
return_tensors="pt",
|
| 100 |
+
)
|
| 101 |
+
samples.append(tokens)
|
| 102 |
+
count += 1
|
| 103 |
+
print(f" Pile general: {count} samples")
|
| 104 |
+
except Exception as e:
|
| 105 |
+
print(f" Warning: Pile failed: {e}")
|
| 106 |
+
print(f" Falling back to neuralmagic only")
|
| 107 |
+
|
| 108 |
+
# --- neuralmagic: Q&A calibration (remaining) ---
|
| 109 |
+
remaining = cfg.calibration_samples - len(samples)
|
| 110 |
+
if remaining > 0:
|
| 111 |
+
try:
|
| 112 |
+
nm = load_dataset(
|
| 113 |
+
cfg.calibration_dataset_nm,
|
| 114 |
+
split="train",
|
| 115 |
+
trust_remote_code=True,
|
| 116 |
+
)
|
| 117 |
+
count = 0
|
| 118 |
+
for example in nm:
|
| 119 |
+
if count >= remaining:
|
| 120 |
+
break
|
| 121 |
+
text = example.get("text", example.get("content", ""))
|
| 122 |
+
if len(str(text)) > 50:
|
| 123 |
+
tokens = tokenizer(
|
| 124 |
+
str(text),
|
| 125 |
+
truncation=True,
|
| 126 |
+
max_length=cfg.calibration_seq_len,
|
| 127 |
+
return_tensors="pt",
|
| 128 |
+
)
|
| 129 |
+
samples.append(tokens)
|
| 130 |
+
count += 1
|
| 131 |
+
print(f" neuralmagic: {count} samples")
|
| 132 |
+
except Exception as e:
|
| 133 |
+
print(f" Warning: neuralmagic failed: {e}")
|
| 134 |
+
|
| 135 |
+
print(f"[transport] Total calibration samples: {len(samples)}")
|
| 136 |
+
return samples
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
# ============================================================================
|
| 140 |
+
# ACTIVATION EXTRACTION (Paper: attention Q,K,V,O + MLP gate,up,down)
|
| 141 |
+
# ============================================================================
|
| 142 |
+
|
| 143 |
+
# Module types to hook into (paper extracts from these specific projections)
|
| 144 |
+
ATTENTION_PROJECTIONS = ("q_proj", "k_proj", "v_proj", "o_proj")
|
| 145 |
+
MLP_PROJECTIONS = ("gate_proj", "up_proj", "down_proj")
|
| 146 |
+
ALL_PROJECTIONS = ATTENTION_PROJECTIONS + MLP_PROJECTIONS
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def extract_activations(
|
| 150 |
+
model: AutoModelForCausalLM,
|
| 151 |
+
calibration_data: list,
|
| 152 |
+
device: str = "cuda",
|
| 153 |
+
) -> dict:
|
| 154 |
+
"""
|
| 155 |
+
Extract pre-activation AND post-activation features from each projection module.
|
| 156 |
+
|
| 157 |
+
Paper Section 4.2: Two-sided transport requires both:
|
| 158 |
+
- Pre-activation features (input to each projection) β for Q_in
|
| 159 |
+
- Post-activation features (output of each projection) β for Q_out
|
| 160 |
+
|
| 161 |
+
Only hooks into attention projections (Q,K,V,O) and MLP projections
|
| 162 |
+
(gate, up, down). NOT every arbitrary layer β paper is specific about this.
|
| 163 |
+
|
| 164 |
+
Returns:
|
| 165 |
+
Dict with keys like:
|
| 166 |
+
"model.layers.0.self_attn.q_proj.pre" β [num_samples, input_dim]
|
| 167 |
+
"model.layers.0.self_attn.q_proj.post" β [num_samples, output_dim]
|
| 168 |
+
"""
|
| 169 |
+
print(f"[transport] Extracting two-sided activations from {len(calibration_data)} samples...")
|
| 170 |
+
|
| 171 |
+
activations = {}
|
| 172 |
+
hooks = []
|
| 173 |
+
|
| 174 |
+
# Register hooks on attention and MLP projection modules only
|
| 175 |
+
for name, module in model.named_modules():
|
| 176 |
+
# Check if this is a projection module we care about
|
| 177 |
+
module_type = name.split(".")[-1] if "." in name else name
|
| 178 |
+
if module_type not in ALL_PROJECTIONS:
|
| 179 |
+
continue
|
| 180 |
+
|
| 181 |
+
# Skip vision encoder modules
|
| 182 |
+
if any(name.startswith(pfx) for pfx in ("visual", "merger")):
|
| 183 |
+
continue
|
| 184 |
+
|
| 185 |
+
def make_hook(layer_name):
|
| 186 |
+
def hook_fn(module, input_tensor, output):
|
| 187 |
+
# Pre-activation: input to this linear layer
|
| 188 |
+
pre = input_tensor[0] if isinstance(input_tensor, tuple) else input_tensor
|
| 189 |
+
# Post-activation: output of this linear layer
|
| 190 |
+
post = output[0] if isinstance(output, tuple) else output
|
| 191 |
+
|
| 192 |
+
pre_key = f"{layer_name}.pre"
|
| 193 |
+
post_key = f"{layer_name}.post"
|
| 194 |
+
|
| 195 |
+
if pre_key not in activations:
|
| 196 |
+
activations[pre_key] = []
|
| 197 |
+
if post_key not in activations:
|
| 198 |
+
activations[post_key] = []
|
| 199 |
+
|
| 200 |
+
# Mean pool over sequence length β [hidden_dim]
|
| 201 |
+
activations[pre_key].append(
|
| 202 |
+
pre.detach().float().mean(dim=1).cpu()
|
| 203 |
+
)
|
| 204 |
+
activations[post_key].append(
|
| 205 |
+
post.detach().float().mean(dim=1).cpu()
|
| 206 |
+
)
|
| 207 |
+
return hook_fn
|
| 208 |
+
|
| 209 |
+
h = module.register_forward_hook(make_hook(name))
|
| 210 |
+
hooks.append(h)
|
| 211 |
+
|
| 212 |
+
# Forward pass on calibration data
|
| 213 |
+
model.eval()
|
| 214 |
+
with torch.no_grad():
|
| 215 |
+
for i, tokens in enumerate(calibration_data):
|
| 216 |
+
inputs = {k: v.to(device) for k, v in tokens.items()}
|
| 217 |
+
try:
|
| 218 |
+
model(**inputs)
|
| 219 |
+
except Exception as e:
|
| 220 |
+
print(f" Warning: Sample {i} failed: {e}")
|
| 221 |
+
continue
|
| 222 |
+
|
| 223 |
+
if (i + 1) % 200 == 0:
|
| 224 |
+
print(f" Processed {i + 1}/{len(calibration_data)} samples")
|
| 225 |
+
|
| 226 |
+
# Remove hooks
|
| 227 |
+
for h in hooks:
|
| 228 |
+
h.remove()
|
| 229 |
+
|
| 230 |
+
# Stack activations: [num_samples, hidden_dim]
|
| 231 |
+
for key in activations:
|
| 232 |
+
activations[key] = torch.cat(activations[key], dim=0)
|
| 233 |
+
|
| 234 |
+
n_modules = len(activations) // 2 # pre + post per module
|
| 235 |
+
print(f"[transport] Extracted activations from {n_modules} projection modules (two-sided)")
|
| 236 |
+
|
| 237 |
+
return activations
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
# ============================================================================
|
| 241 |
+
# LOG-DOMAIN STREAMING SINKHORN (Paper Appendix A.3.4)
|
| 242 |
+
# ============================================================================
|
| 243 |
+
|
| 244 |
+
def _log_sinkhorn_streaming(
|
| 245 |
+
cost_matrix: np.ndarray,
|
| 246 |
+
reg: float = 0.1,
|
| 247 |
+
max_iter: int = 200,
|
| 248 |
+
chunk_size: int = 512,
|
| 249 |
+
) -> np.ndarray:
|
| 250 |
+
"""
|
| 251 |
+
Log-domain streaming Sinkhorn solver.
|
| 252 |
+
|
| 253 |
+
Paper Appendix A.3.4:
|
| 254 |
+
"We use a memory-efficient streaming Sinkhorn solver with fixed 200 iterations"
|
| 255 |
+
|
| 256 |
+
Log-domain means we work with log(K) = -C/Ξ΅ instead of K = exp(-C/Ξ΅).
|
| 257 |
+
This prevents numerical overflow/underflow with large matrices.
|
| 258 |
+
|
| 259 |
+
Streaming means we process the cost matrix in chunks instead of
|
| 260 |
+
materializing the full kernel matrix K in memory.
|
| 261 |
+
|
| 262 |
+
Args:
|
| 263 |
+
cost_matrix: [n, m] cost matrix (correlation distance)
|
| 264 |
+
reg: Entropic regularisation Ξ΅ (paper default 0.1)
|
| 265 |
+
max_iter: Number of Sinkhorn iterations (paper: 200 inner, 1000 outer)
|
| 266 |
+
chunk_size: Process this many rows/cols at a time for memory efficiency
|
| 267 |
+
|
| 268 |
+
Returns:
|
| 269 |
+
[n, m] transport plan matrix
|
| 270 |
+
"""
|
| 271 |
+
n, m = cost_matrix.shape
|
| 272 |
+
|
| 273 |
+
# Log-domain: work with log potentials instead of scaling vectors
|
| 274 |
+
# This is numerically stable β no exp() overflow
|
| 275 |
+
log_u = np.zeros(n) # Log of row scaling vector
|
| 276 |
+
log_v = np.zeros(m) # Log of column scaling vector
|
| 277 |
+
|
| 278 |
+
# Uniform marginals (both sides sum to 1)
|
| 279 |
+
log_a = np.full(n, -np.log(n)) # log(1/n)
|
| 280 |
+
log_b = np.full(m, -np.log(m)) # log(1/m)
|
| 281 |
+
|
| 282 |
+
# Log kernel: log(K_ij) = -C_ij / Ξ΅
|
| 283 |
+
log_K = -cost_matrix / reg
|
| 284 |
+
|
| 285 |
+
for iteration in range(max_iter):
|
| 286 |
+
# --- Row update (streaming over chunks of columns) ---
|
| 287 |
+
# log_u = log_a - logsumexp(log_K + log_v, axis=1)
|
| 288 |
+
log_sum = np.full(n, -np.inf)
|
| 289 |
+
for j_start in range(0, m, chunk_size):
|
| 290 |
+
j_end = min(j_start + chunk_size, m)
|
| 291 |
+
chunk = log_K[:, j_start:j_end] + log_v[j_start:j_end]
|
| 292 |
+
chunk_max = np.maximum(log_sum, chunk.max(axis=1))
|
| 293 |
+
log_sum = chunk_max + np.log(
|
| 294 |
+
np.exp(log_sum - chunk_max) +
|
| 295 |
+
np.exp(chunk - chunk_max[:, None]).sum(axis=1)
|
| 296 |
+
)
|
| 297 |
+
log_u = log_a - log_sum
|
| 298 |
+
|
| 299 |
+
# --- Column update (streaming over chunks of rows) ---
|
| 300 |
+
# log_v = log_b - logsumexp(log_K.T + log_u, axis=1)
|
| 301 |
+
log_sum = np.full(m, -np.inf)
|
| 302 |
+
for i_start in range(0, n, chunk_size):
|
| 303 |
+
i_end = min(i_start + chunk_size, n)
|
| 304 |
+
chunk = log_K[i_start:i_end, :].T + log_u[i_start:i_end]
|
| 305 |
+
# chunk shape: [m, chunk_rows]
|
| 306 |
+
chunk_max = np.maximum(log_sum, chunk.max(axis=1))
|
| 307 |
+
log_sum = chunk_max + np.log(
|
| 308 |
+
np.exp(log_sum - chunk_max) +
|
| 309 |
+
np.exp(chunk - chunk_max[:, None]).sum(axis=1)
|
| 310 |
+
)
|
| 311 |
+
log_v = log_b - log_sum
|
| 312 |
+
|
| 313 |
+
# Recover transport plan: T_ij = exp(log_u_i + log_K_ij + log_v_j)
|
| 314 |
+
# Do this in chunks too to avoid materializing full matrix at once
|
| 315 |
+
T = np.zeros((n, m), dtype=np.float32)
|
| 316 |
+
for j_start in range(0, m, chunk_size):
|
| 317 |
+
j_end = min(j_start + chunk_size, m)
|
| 318 |
+
T[:, j_start:j_end] = np.exp(
|
| 319 |
+
log_u[:, None] + log_K[:, j_start:j_end] + log_v[j_start:j_end]
|
| 320 |
+
)
|
| 321 |
+
|
| 322 |
+
return T
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
def _sinkhorn_basic(
|
| 326 |
+
cost_matrix: np.ndarray,
|
| 327 |
+
reg: float = 0.1,
|
| 328 |
+
max_iter: int = 200,
|
| 329 |
+
) -> np.ndarray:
|
| 330 |
+
"""
|
| 331 |
+
Basic (non-streaming) Sinkhorn for small matrices (e.g., layer-level P).
|
| 332 |
+
|
| 333 |
+
Used for the layer-level transport plan where matrices are small
|
| 334 |
+
(e.g., 36Γ32 for Qwen3βLlama layer mapping).
|
| 335 |
+
"""
|
| 336 |
+
n, m = cost_matrix.shape
|
| 337 |
+
K = np.exp(-cost_matrix / reg)
|
| 338 |
+
|
| 339 |
+
u = np.ones(n) / n
|
| 340 |
+
v = np.ones(m) / m
|
| 341 |
+
|
| 342 |
+
for _ in range(max_iter):
|
| 343 |
+
u = (1.0 / n) / (K @ v + 1e-10)
|
| 344 |
+
v = (1.0 / m) / (K.T @ u + 1e-10)
|
| 345 |
+
|
| 346 |
+
T = np.diag(u) @ K @ np.diag(v)
|
| 347 |
+
return T
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
# ============================================================================
|
| 351 |
+
# TWO-SIDED TRANSPORT (Paper Section 4.2, Equations 8, 13)
|
| 352 |
+
# ============================================================================
|
| 353 |
+
|
| 354 |
+
def _correlation_distance(X: np.ndarray, Y: np.ndarray) -> np.ndarray:
|
| 355 |
+
"""
|
| 356 |
+
Compute correlation distance matrix between two sets of activation vectors.
|
| 357 |
+
|
| 358 |
+
cost[i, j] = 1 - pearson_correlation(X[:, i], Y[:, j])
|
| 359 |
+
|
| 360 |
+
X: [num_samples, dim_x] β activations from source
|
| 361 |
+
Y: [num_samples, dim_y] β activations from target
|
| 362 |
+
Returns: [dim_x, dim_y] cost matrix
|
| 363 |
+
"""
|
| 364 |
+
# Standardise each neuron's activations across samples
|
| 365 |
+
X_norm = (X - X.mean(axis=0)) / (X.std(axis=0) + 1e-8)
|
| 366 |
+
Y_norm = (Y - Y.mean(axis=0)) / (Y.std(axis=0) + 1e-8)
|
| 367 |
+
|
| 368 |
+
# Pearson correlation between each pair of neurons
|
| 369 |
+
corr = X_norm.T @ Y_norm / X.shape[0] # [dim_x, dim_y]
|
| 370 |
+
|
| 371 |
+
# Correlation distance
|
| 372 |
+
cost = 1.0 - corr
|
| 373 |
+
return cost.astype(np.float32)
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
def _get_layer_index(module_name: str) -> Optional[int]:
|
| 377 |
+
"""Extract layer index from a module name like 'model.layers.5.self_attn.q_proj'."""
|
| 378 |
+
parts = module_name.split(".")
|
| 379 |
+
for i, part in enumerate(parts):
|
| 380 |
+
if part == "layers" and i + 1 < len(parts):
|
| 381 |
+
try:
|
| 382 |
+
return int(parts[i + 1])
|
| 383 |
+
except ValueError:
|
| 384 |
+
pass
|
| 385 |
+
return None
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
def _get_module_type(module_name: str) -> str:
|
| 389 |
+
"""Extract module type from name like 'model.layers.5.self_attn.q_proj' β 'q_proj'."""
|
| 390 |
+
return module_name.split(".")[-1]
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
def _group_activations_by_layer(
|
| 394 |
+
activations: dict,
|
| 395 |
+
side: str = "pre",
|
| 396 |
+
) -> dict:
|
| 397 |
+
"""
|
| 398 |
+
Group activation tensors by layer index.
|
| 399 |
+
|
| 400 |
+
Returns: {layer_idx: {module_type: activation_tensor}}
|
| 401 |
+
"""
|
| 402 |
+
grouped = {}
|
| 403 |
+
suffix = f".{side}"
|
| 404 |
+
for key, tensor in activations.items():
|
| 405 |
+
if not key.endswith(suffix):
|
| 406 |
+
continue
|
| 407 |
+
# Remove the .pre/.post suffix to get module name
|
| 408 |
+
module_name = key[: -len(suffix)]
|
| 409 |
+
layer_idx = _get_layer_index(module_name)
|
| 410 |
+
module_type = _get_module_type(module_name)
|
| 411 |
+
if layer_idx is not None:
|
| 412 |
+
if layer_idx not in grouped:
|
| 413 |
+
grouped[layer_idx] = {}
|
| 414 |
+
grouped[layer_idx][module_type] = tensor.numpy()
|
| 415 |
+
return grouped
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
def compute_transport_plans(
|
| 419 |
+
source_activations: dict,
|
| 420 |
+
target_activations: dict,
|
| 421 |
+
cfg: MergeConfig,
|
| 422 |
+
) -> dict:
|
| 423 |
+
"""
|
| 424 |
+
Compute two-sided optimal transport plans between source and target.
|
| 425 |
+
|
| 426 |
+
Paper Section 4.2 β Two-sided transport:
|
| 427 |
+
1. For each (source_layer, target_layer) pair and each projection type:
|
| 428 |
+
- Compute Q_in from pre-activation features (Eq 8 applied to inputs)
|
| 429 |
+
- Compute Q_out from post-activation features (Eq 8 applied to outputs)
|
| 430 |
+
2. Derive layer-level costs from Q_in and Q_out β P_pre and P_post
|
| 431 |
+
3. Combine: P_eff[β,m] = sqrt(P_pre[β,m] Β· P_post[β,m]) (Eq 13)
|
| 432 |
+
|
| 433 |
+
Returns:
|
| 434 |
+
Dict with:
|
| 435 |
+
'P_eff': [n_target_layers, n_source_layers] effective transport plan
|
| 436 |
+
'Q_in': {(src_layer, tgt_layer, module_type): Q matrix} β input-side neuron plans
|
| 437 |
+
'Q_out': {(src_layer, tgt_layer, module_type): Q matrix} β output-side neuron plans
|
| 438 |
+
'source_layers': sorted list of source layer indices
|
| 439 |
+
'target_layers': sorted list of target layer indices
|
| 440 |
+
"""
|
| 441 |
+
print("[transport] Computing two-sided transport plans (paper Section 4.2)...")
|
| 442 |
+
|
| 443 |
+
# Group activations by layer
|
| 444 |
+
source_pre = _group_activations_by_layer(source_activations, "pre")
|
| 445 |
+
source_post = _group_activations_by_layer(source_activations, "post")
|
| 446 |
+
target_pre = _group_activations_by_layer(target_activations, "pre")
|
| 447 |
+
target_post = _group_activations_by_layer(target_activations, "post")
|
| 448 |
+
|
| 449 |
+
source_layers = sorted(source_pre.keys())
|
| 450 |
+
target_layers = sorted(target_pre.keys())
|
| 451 |
+
|
| 452 |
+
n_source = len(source_layers)
|
| 453 |
+
n_target = len(target_layers)
|
| 454 |
+
|
| 455 |
+
print(f" Source layers: {n_source}, Target layers: {n_target}")
|
| 456 |
+
|
| 457 |
+
# --- Step 1: Compute Q_in and Q_out for each layer pair ---
|
| 458 |
+
Q_in_matrices = {}
|
| 459 |
+
Q_out_matrices = {}
|
| 460 |
+
layer_costs_pre = np.zeros((n_target, n_source))
|
| 461 |
+
layer_costs_post = np.zeros((n_target, n_source))
|
| 462 |
+
|
| 463 |
+
for ti, tl in enumerate(target_layers):
|
| 464 |
+
for si, sl in enumerate(source_layers):
|
| 465 |
+
# Get all projection types that exist in both
|
| 466 |
+
if tl not in target_pre or sl not in source_pre:
|
| 467 |
+
continue
|
| 468 |
+
|
| 469 |
+
target_modules = set(target_pre.get(tl, {}).keys())
|
| 470 |
+
source_modules = set(source_pre.get(sl, {}).keys())
|
| 471 |
+
common_modules = target_modules & source_modules
|
| 472 |
+
|
| 473 |
+
if not common_modules:
|
| 474 |
+
continue
|
| 475 |
+
|
| 476 |
+
pre_costs = []
|
| 477 |
+
post_costs = []
|
| 478 |
+
|
| 479 |
+
for mod_type in common_modules:
|
| 480 |
+
# --- Q_in: pre-activation (input-side) transport ---
|
| 481 |
+
if (sl in source_pre and mod_type in source_pre[sl] and
|
| 482 |
+
tl in target_pre and mod_type in target_pre[tl]):
|
| 483 |
+
S_pre = source_pre[sl][mod_type]
|
| 484 |
+
T_pre = target_pre[tl][mod_type]
|
| 485 |
+
cost_pre = _correlation_distance(S_pre, T_pre)
|
| 486 |
+
|
| 487 |
+
# Use streaming Sinkhorn for large matrices, basic for small
|
| 488 |
+
if max(cost_pre.shape) > 1024:
|
| 489 |
+
Q = _log_sinkhorn_streaming(
|
| 490 |
+
cost_pre,
|
| 491 |
+
reg=cfg.sinkhorn_reg,
|
| 492 |
+
max_iter=cfg.sinkhorn_inner_iter,
|
| 493 |
+
)
|
| 494 |
+
else:
|
| 495 |
+
Q = _sinkhorn_basic(
|
| 496 |
+
cost_pre,
|
| 497 |
+
reg=cfg.sinkhorn_reg,
|
| 498 |
+
max_iter=cfg.sinkhorn_inner_iter,
|
| 499 |
+
)
|
| 500 |
+
Q_in_matrices[(sl, tl, mod_type)] = Q
|
| 501 |
+
pre_costs.append(cost_pre.mean())
|
| 502 |
+
|
| 503 |
+
# --- Q_out: post-activation (output-side) transport ---
|
| 504 |
+
if (sl in source_post and mod_type in source_post[sl] and
|
| 505 |
+
tl in target_post and mod_type in target_post[tl]):
|
| 506 |
+
S_post = source_post[sl][mod_type]
|
| 507 |
+
T_post = target_post[tl][mod_type]
|
| 508 |
+
cost_post = _correlation_distance(S_post, T_post)
|
| 509 |
+
|
| 510 |
+
if max(cost_post.shape) > 1024:
|
| 511 |
+
Q = _log_sinkhorn_streaming(
|
| 512 |
+
cost_post,
|
| 513 |
+
reg=cfg.sinkhorn_reg,
|
| 514 |
+
max_iter=cfg.sinkhorn_inner_iter,
|
| 515 |
+
)
|
| 516 |
+
else:
|
| 517 |
+
Q = _sinkhorn_basic(
|
| 518 |
+
cost_post,
|
| 519 |
+
reg=cfg.sinkhorn_reg,
|
| 520 |
+
max_iter=cfg.sinkhorn_inner_iter,
|
| 521 |
+
)
|
| 522 |
+
Q_out_matrices[(sl, tl, mod_type)] = Q
|
| 523 |
+
post_costs.append(cost_post.mean())
|
| 524 |
+
|
| 525 |
+
# Average cost across projection types for this layer pair
|
| 526 |
+
if pre_costs:
|
| 527 |
+
layer_costs_pre[ti, si] = np.mean(pre_costs)
|
| 528 |
+
if post_costs:
|
| 529 |
+
layer_costs_post[ti, si] = np.mean(post_costs)
|
| 530 |
+
|
| 531 |
+
if (ti + 1) % 6 == 0:
|
| 532 |
+
print(f" Layer pairs computed: {ti + 1}/{n_target} target layers done")
|
| 533 |
+
|
| 534 |
+
# --- Step 2: Layer-level transport plans P_pre and P_post ---
|
| 535 |
+
print("[transport] Computing layer-level transport plans (P_pre, P_post)...")
|
| 536 |
+
|
| 537 |
+
P_pre = _sinkhorn_basic(
|
| 538 |
+
layer_costs_pre,
|
| 539 |
+
reg=cfg.sinkhorn_layer_reg,
|
| 540 |
+
max_iter=cfg.sinkhorn_outer_iter,
|
| 541 |
+
)
|
| 542 |
+
|
| 543 |
+
P_post = _sinkhorn_basic(
|
| 544 |
+
layer_costs_post,
|
| 545 |
+
reg=cfg.sinkhorn_layer_reg,
|
| 546 |
+
max_iter=cfg.sinkhorn_outer_iter,
|
| 547 |
+
)
|
| 548 |
+
|
| 549 |
+
# --- Step 3: P_eff = sqrt(P_pre Β· P_post) β Equation 13 ---
|
| 550 |
+
P_eff = np.sqrt(P_pre * P_post + 1e-10)
|
| 551 |
+
|
| 552 |
+
# Normalise P_eff so each target layer's row sums to 1
|
| 553 |
+
row_sums = P_eff.sum(axis=1, keepdims=True)
|
| 554 |
+
P_eff = P_eff / (row_sums + 1e-10)
|
| 555 |
+
|
| 556 |
+
print(f"[transport] P_eff shape: {P_eff.shape}")
|
| 557 |
+
print(f" P_eff range: [{P_eff.min():.4f}, {P_eff.max():.4f}]")
|
| 558 |
+
|
| 559 |
+
# --- Step 4: Transport sparsification (Appendix A.1) ---
|
| 560 |
+
# "top-k selection strategies at both neuron and transport matrix levels"
|
| 561 |
+
# Keep only the top-k strongest source layers per target layer
|
| 562 |
+
k_layers = min(3, n_source) # Top-3 source layers per target layer
|
| 563 |
+
P_sparse = np.zeros_like(P_eff)
|
| 564 |
+
for i in range(n_target):
|
| 565 |
+
top_k_idx = np.argsort(P_eff[i])[-k_layers:]
|
| 566 |
+
P_sparse[i, top_k_idx] = P_eff[i, top_k_idx]
|
| 567 |
+
# Re-normalise
|
| 568 |
+
row_sums = P_sparse.sum(axis=1, keepdims=True)
|
| 569 |
+
P_sparse = P_sparse / (row_sums + 1e-10)
|
| 570 |
+
|
| 571 |
+
print(f"[transport] Sparsified P: keeping top-{k_layers} source layers per target")
|
| 572 |
+
|
| 573 |
+
return {
|
| 574 |
+
"P_eff": P_sparse,
|
| 575 |
+
"P_eff_dense": P_eff, # Keep dense version for debugging
|
| 576 |
+
"Q_in": Q_in_matrices,
|
| 577 |
+
"Q_out": Q_out_matrices,
|
| 578 |
+
"source_layers": source_layers,
|
| 579 |
+
"target_layers": target_layers,
|
| 580 |
+
"layer_costs_pre": layer_costs_pre,
|
| 581 |
+
"layer_costs_post": layer_costs_post,
|
| 582 |
+
}
|
| 583 |
+
|
| 584 |
+
|
| 585 |
+
# ============================================================================
|
| 586 |
+
# TOP-K MASKED FUSION (Paper Eq 14, Appendix A.5: k=128)
|
| 587 |
+
# ============================================================================
|
| 588 |
+
|
| 589 |
+
def compute_neuron_importance(
|
| 590 |
+
activations: dict,
|
| 591 |
+
layer_idx: int,
|
| 592 |
+
) -> dict:
|
| 593 |
+
"""
|
| 594 |
+
Compute neuron importance scores for top-k selection.
|
| 595 |
+
|
| 596 |
+
Paper Appendix A.5: "choosing the neurons with the highest mean
|
| 597 |
+
activation magnitudes across the calibration set"
|
| 598 |
+
|
| 599 |
+
Returns: {module_type: importance_scores [hidden_dim]}
|
| 600 |
+
"""
|
| 601 |
+
importance = {}
|
| 602 |
+
for key, tensor in activations.items():
|
| 603 |
+
if not key.endswith(".post"):
|
| 604 |
+
continue
|
| 605 |
+
module_name = key[:-5] # Remove .post
|
| 606 |
+
idx = _get_layer_index(module_name)
|
| 607 |
+
mod_type = _get_module_type(module_name)
|
| 608 |
+
if idx == layer_idx:
|
| 609 |
+
# Mean activation magnitude across calibration samples
|
| 610 |
+
importance[mod_type] = tensor.abs().mean(dim=0).numpy()
|
| 611 |
+
return importance
|
| 612 |
+
|
| 613 |
+
|
| 614 |
+
def compute_top_k_mask(
|
| 615 |
+
importance_scores: np.ndarray,
|
| 616 |
+
k: int = 128,
|
| 617 |
+
) -> np.ndarray:
|
| 618 |
+
"""
|
| 619 |
+
Create binary mask for top-k most important neurons.
|
| 620 |
+
|
| 621 |
+
Paper Appendix A.5: "we set the default number of neurons to k = 128"
|
| 622 |
+
|
| 623 |
+
Returns: boolean mask [hidden_dim] where True = selected for fusion
|
| 624 |
+
"""
|
| 625 |
+
if k >= len(importance_scores):
|
| 626 |
+
return np.ones(len(importance_scores), dtype=bool)
|
| 627 |
+
|
| 628 |
+
threshold_idx = np.argsort(importance_scores)[-k:]
|
| 629 |
+
mask = np.zeros(len(importance_scores), dtype=bool)
|
| 630 |
+
mask[threshold_idx] = True
|
| 631 |
+
return mask
|
| 632 |
+
|
| 633 |
+
|
| 634 |
+
def fuse_weights(
|
| 635 |
+
source_model: AutoModelForCausalLM,
|
| 636 |
+
target_model: AutoModelForCausalLM,
|
| 637 |
+
transport_plans: dict,
|
| 638 |
+
source_config: ModelConfig,
|
| 639 |
+
cfg: MergeConfig,
|
| 640 |
+
target_activations: dict = None,
|
| 641 |
+
) -> AutoModelForCausalLM:
|
| 642 |
+
"""
|
| 643 |
+
Fuse source weights into target using two-sided transport + top-k mask.
|
| 644 |
+
|
| 645 |
+
Paper Equation 14:
|
| 646 |
+
W_fused = W_target + Ξ± Β· M^β β (Ξ£_m P_eff[β,m] Β· Q_out Β· W_source Β· Q_in^T - W_target)
|
| 647 |
+
|
| 648 |
+
Where:
|
| 649 |
+
- Ξ± is the fusion coefficient (0.05-0.15)
|
| 650 |
+
- M^β is the binary top-k mask (only k=128 neurons get fused)
|
| 651 |
+
- P_eff is the effective layer transport plan
|
| 652 |
+
- Q_out and Q_in are the neuron-level transport matrices
|
| 653 |
+
- The sum is over source layers m
|
| 654 |
+
|
| 655 |
+
Returns: Target model with fused weights
|
| 656 |
+
"""
|
| 657 |
+
print(f"\n[transport] Fusing {source_config.name} -> target (two-sided + top-k={cfg.top_k_neurons})")
|
| 658 |
+
alpha = source_config.merge_alpha
|
| 659 |
+
print(f" Alpha: {alpha} (paper range: 0.05-0.15)")
|
| 660 |
+
|
| 661 |
+
source_state = source_model.state_dict()
|
| 662 |
+
target_state = target_model.state_dict()
|
| 663 |
+
|
| 664 |
+
P_eff = transport_plans["P_eff"]
|
| 665 |
+
Q_in = transport_plans["Q_in"]
|
| 666 |
+
Q_out = transport_plans["Q_out"]
|
| 667 |
+
source_layers = transport_plans["source_layers"]
|
| 668 |
+
target_layers = transport_plans["target_layers"]
|
| 669 |
+
|
| 670 |
+
fused_count = 0
|
| 671 |
+
skipped_count = 0
|
| 672 |
+
masked_neurons = 0
|
| 673 |
+
|
| 674 |
+
for ti, tl in enumerate(target_layers):
|
| 675 |
+
# Get the transport weights for this target layer
|
| 676 |
+
layer_transport = P_eff[ti] # [n_source]
|
| 677 |
+
|
| 678 |
+
# Find which source layers contribute significantly
|
| 679 |
+
active_sources = [(si, sl, layer_transport[si])
|
| 680 |
+
for si, sl in enumerate(source_layers)
|
| 681 |
+
if layer_transport[si] > 1e-6]
|
| 682 |
+
|
| 683 |
+
if not active_sources:
|
| 684 |
+
continue
|
| 685 |
+
|
| 686 |
+
# For each projection type in this target layer
|
| 687 |
+
for mod_type in ALL_PROJECTIONS:
|
| 688 |
+
target_key = _find_param_key(target_state, tl, mod_type, "weight")
|
| 689 |
+
if target_key is None:
|
| 690 |
+
continue
|
| 691 |
+
|
| 692 |
+
target_w = target_state[target_key].float()
|
| 693 |
+
|
| 694 |
+
# Compute the transported operator: Ξ£_m P_eff[β,m] Β· Q_out Β· W_source Β· Q_in^T
|
| 695 |
+
transported = torch.zeros_like(target_w)
|
| 696 |
+
total_weight = 0.0
|
| 697 |
+
|
| 698 |
+
for si, sl, p_weight in active_sources:
|
| 699 |
+
source_key = _find_source_param_key(
|
| 700 |
+
source_state, sl, mod_type, "weight", source_config
|
| 701 |
+
)
|
| 702 |
+
if source_key is None:
|
| 703 |
+
continue
|
| 704 |
+
|
| 705 |
+
source_w = source_state[source_key].float()
|
| 706 |
+
|
| 707 |
+
# Get Q matrices for this layer pair
|
| 708 |
+
q_in_key = (sl, tl, mod_type)
|
| 709 |
+
q_out_key = (sl, tl, mod_type)
|
| 710 |
+
|
| 711 |
+
q_in = Q_in.get(q_in_key)
|
| 712 |
+
q_out = Q_out.get(q_out_key)
|
| 713 |
+
|
| 714 |
+
if q_in is not None and q_out is not None:
|
| 715 |
+
# Transport: Q_out @ W_source @ Q_in^T
|
| 716 |
+
q_in_t = torch.from_numpy(q_in).float()
|
| 717 |
+
q_out_t = torch.from_numpy(q_out).float()
|
| 718 |
+
|
| 719 |
+
# Handle dimension mismatches via transport plan
|
| 720 |
+
try:
|
| 721 |
+
# q_out: [target_out, source_out], W: [source_out, source_in], q_in: [target_in, source_in]
|
| 722 |
+
# Result: [target_out, target_in]
|
| 723 |
+
transported_w = q_out_t @ source_w.to("cpu") @ q_in_t.T
|
| 724 |
+
transported += p_weight * transported_w.to(target_w.device)
|
| 725 |
+
total_weight += p_weight
|
| 726 |
+
except RuntimeError:
|
| 727 |
+
# Dimension mismatch β skip this pair
|
| 728 |
+
skipped_count += 1
|
| 729 |
+
continue
|
| 730 |
+
else:
|
| 731 |
+
# No Q matrices β direct mapping if shapes match
|
| 732 |
+
if source_w.shape == target_w.shape:
|
| 733 |
+
transported += p_weight * source_w.to(target_w.device)
|
| 734 |
+
total_weight += p_weight
|
| 735 |
+
|
| 736 |
+
if total_weight < 1e-6:
|
| 737 |
+
skipped_count += 1
|
| 738 |
+
continue
|
| 739 |
+
|
| 740 |
+
# Normalise by total transport weight
|
| 741 |
+
transported = transported / total_weight
|
| 742 |
+
|
| 743 |
+
# --- Apply top-k mask (Equation 14) ---
|
| 744 |
+
# M^β β (transported - W_target)
|
| 745 |
+
delta = transported - target_w
|
| 746 |
+
|
| 747 |
+
if target_activations is not None and cfg.top_k_neurons > 0:
|
| 748 |
+
importance = compute_neuron_importance(target_activations, tl)
|
| 749 |
+
if mod_type in importance:
|
| 750 |
+
# Mask on output dimension (rows of weight matrix)
|
| 751 |
+
mask = compute_top_k_mask(importance[mod_type], k=cfg.top_k_neurons)
|
| 752 |
+
mask_tensor = torch.from_numpy(mask).to(target_w.device)
|
| 753 |
+
|
| 754 |
+
# Apply mask: only fuse top-k neurons
|
| 755 |
+
if delta.dim() == 2:
|
| 756 |
+
# Weight matrix: mask rows (output neurons)
|
| 757 |
+
mask_2d = mask_tensor.unsqueeze(1).expand_as(delta)
|
| 758 |
+
delta = delta * mask_2d.float()
|
| 759 |
+
masked_neurons += mask.sum()
|
| 760 |
+
elif delta.dim() == 1:
|
| 761 |
+
# Bias: mask directly
|
| 762 |
+
delta = delta * mask_tensor.float()
|
| 763 |
+
masked_neurons += mask.sum()
|
| 764 |
+
|
| 765 |
+
# Final fusion: W_target + Ξ± Β· masked_delta
|
| 766 |
+
fused_w = target_w + alpha * delta
|
| 767 |
+
target_state[target_key] = fused_w.to(target_state[target_key].dtype)
|
| 768 |
+
fused_count += 1
|
| 769 |
+
|
| 770 |
+
# --- Vision encoder protection ---
|
| 771 |
+
# Restore any vision params that might have been touched
|
| 772 |
+
original_state = target_model.state_dict()
|
| 773 |
+
for key in target_state:
|
| 774 |
+
if any(key.startswith(pfx) for pfx in cfg.vision_skip_prefixes):
|
| 775 |
+
target_state[key] = original_state[key]
|
| 776 |
+
|
| 777 |
+
# --- Thinking mode protection ---
|
| 778 |
+
if cfg.freeze_think_tokens:
|
| 779 |
+
embed_key = "model.embed_tokens.weight"
|
| 780 |
+
if embed_key in target_state and embed_key in original_state:
|
| 781 |
+
for token_id in cfg.think_token_ids:
|
| 782 |
+
if token_id < target_state[embed_key].shape[0]:
|
| 783 |
+
target_state[embed_key][token_id] = original_state[embed_key][token_id]
|
| 784 |
+
print(f" Protected think token {token_id}")
|
| 785 |
+
|
| 786 |
+
# Load fused weights
|
| 787 |
+
target_model.load_state_dict(target_state)
|
| 788 |
+
print(f"[transport] Fused {fused_count} params, skipped {skipped_count}")
|
| 789 |
+
print(f" Top-k masked neurons fused: {masked_neurons}")
|
| 790 |
+
|
| 791 |
+
return target_model
|
| 792 |
+
|
| 793 |
+
|
| 794 |
+
# ============================================================================
|
| 795 |
+
# HELPER: Find parameter keys in state dicts
|
| 796 |
+
# ============================================================================
|
| 797 |
+
|
| 798 |
+
def _find_param_key(state_dict: dict, layer_idx: int, module_type: str, param_type: str = "weight") -> Optional[str]:
|
| 799 |
+
"""Find the full parameter key for a given layer, module type, and param type."""
|
| 800 |
+
# Common patterns for transformer models
|
| 801 |
+
patterns = [
|
| 802 |
+
f"model.layers.{layer_idx}.self_attn.{module_type}.{param_type}",
|
| 803 |
+
f"model.layers.{layer_idx}.mlp.{module_type}.{param_type}",
|
| 804 |
+
f"transformer.h.{layer_idx}.attn.{module_type}.{param_type}",
|
| 805 |
+
f"transformer.h.{layer_idx}.mlp.{module_type}.{param_type}",
|
| 806 |
+
]
|
| 807 |
+
for pattern in patterns:
|
| 808 |
+
if pattern in state_dict:
|
| 809 |
+
return pattern
|
| 810 |
+
return None
|
| 811 |
+
|
| 812 |
+
|
| 813 |
+
def _find_source_param_key(
|
| 814 |
+
state_dict: dict,
|
| 815 |
+
source_layer: int,
|
| 816 |
+
module_type: str,
|
| 817 |
+
param_type: str,
|
| 818 |
+
source_config: ModelConfig,
|
| 819 |
+
) -> Optional[str]:
|
| 820 |
+
"""Find param key in source model, handling architecture differences."""
|
| 821 |
+
# Try standard patterns first
|
| 822 |
+
key = _find_param_key(state_dict, source_layer, module_type, param_type)
|
| 823 |
+
if key:
|
| 824 |
+
return key
|
| 825 |
+
|
| 826 |
+
# Try architecture-specific patterns
|
| 827 |
+
if source_config.architecture == "hybrid_ssm":
|
| 828 |
+
# Falcon uses different naming
|
| 829 |
+
patterns = [
|
| 830 |
+
f"model.layers.{source_layer}.attn.{module_type}.{param_type}",
|
| 831 |
+
f"model.layers.{source_layer}.feed_forward.{module_type}.{param_type}",
|
| 832 |
+
]
|
| 833 |
+
for pattern in patterns:
|
| 834 |
+
if pattern in state_dict:
|
| 835 |
+
return pattern
|
| 836 |
+
|
| 837 |
+
return None
|
| 838 |
+
|
| 839 |
+
|
| 840 |
+
def _should_skip(key: str, source_config: ModelConfig) -> bool:
|
| 841 |
+
"""Determine if a parameter should be skipped during merge."""
|
| 842 |
+
if source_config.skip_embeddings and ("embed_tokens" in key or "lm_head" in key):
|
| 843 |
+
return True
|
| 844 |
+
if "drop_mtp_heads" in source_config.special_handling and "mtp_head" in key:
|
| 845 |
+
return True
|
| 846 |
+
if "drop_mamba_state_params" in source_config.special_handling:
|
| 847 |
+
mamba_keys = ["mamba", "A_log", "dt_proj", ".D"]
|
| 848 |
+
if any(mk in key for mk in mamba_keys):
|
| 849 |
+
return True
|
| 850 |
+
if "drop_qkv_bias" in source_config.special_handling and ".bias" in key:
|
| 851 |
+
if any(proj in key for proj in ["q_proj", "k_proj", "v_proj"]):
|
| 852 |
+
return True
|
| 853 |
+
return False
|
hugging/td_lang/engine/validate.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Post-Merge Validation β run after EVERY merge step.
|
| 3 |
+
|
| 4 |
+
Tests:
|
| 5 |
+
1. Canary recall (did knowledge transfer?)
|
| 6 |
+
2. Perplexity check (did we break the model?)
|
| 7 |
+
3. Thinking mode (do <think> tags still work?)
|
| 8 |
+
4. Quick reasoning test (can it still think?)
|
| 9 |
+
|
| 10 |
+
Kill criteria: >10% performance drop on any test β abort merge.
|
| 11 |
+
Findings: #11, #22, #25
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import math
|
| 16 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 17 |
+
|
| 18 |
+
from .canary import test_all_canaries
|
| 19 |
+
from .config import MergeConfig
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def validate_merged_model(
|
| 23 |
+
model: AutoModelForCausalLM,
|
| 24 |
+
tokenizer: AutoTokenizer,
|
| 25 |
+
merged_sources: list[str],
|
| 26 |
+
cfg: MergeConfig,
|
| 27 |
+
baseline_perplexity: float = None,
|
| 28 |
+
) -> dict:
|
| 29 |
+
"""
|
| 30 |
+
Run full validation suite on a merged model.
|
| 31 |
+
|
| 32 |
+
Args:
|
| 33 |
+
model: The merged model to validate
|
| 34 |
+
tokenizer: The tokenizer
|
| 35 |
+
merged_sources: List of source models merged so far
|
| 36 |
+
cfg: Merge configuration
|
| 37 |
+
baseline_perplexity: Perplexity of the target model before merging
|
| 38 |
+
|
| 39 |
+
Returns:
|
| 40 |
+
Dict with test results and overall pass/fail
|
| 41 |
+
"""
|
| 42 |
+
print("\n" + "=" * 60)
|
| 43 |
+
print(f"VALIDATION β After merging: {', '.join(merged_sources)}")
|
| 44 |
+
print("=" * 60)
|
| 45 |
+
|
| 46 |
+
results = {
|
| 47 |
+
"canary": None,
|
| 48 |
+
"perplexity": None,
|
| 49 |
+
"thinking_mode": None,
|
| 50 |
+
"reasoning": None,
|
| 51 |
+
"overall": False,
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
# --- Test 1: Canary recall ---
|
| 55 |
+
canary_results = test_all_canaries(model, tokenizer, merged_sources)
|
| 56 |
+
passed_canaries = sum(1 for v in canary_results.values() if v)
|
| 57 |
+
total_canaries = len(canary_results)
|
| 58 |
+
results["canary"] = {
|
| 59 |
+
"passed": passed_canaries,
|
| 60 |
+
"total": total_canaries,
|
| 61 |
+
"ok": passed_canaries >= cfg.canary_pass_threshold,
|
| 62 |
+
"details": canary_results,
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
# --- Test 2: Perplexity ---
|
| 66 |
+
perplexity = compute_perplexity(model, tokenizer)
|
| 67 |
+
ppl_ok = True
|
| 68 |
+
if baseline_perplexity is not None:
|
| 69 |
+
ratio = perplexity / baseline_perplexity
|
| 70 |
+
ppl_ok = ratio < cfg.perplexity_threshold
|
| 71 |
+
print(f"\n[validate] Perplexity: {perplexity:.2f} (baseline: {baseline_perplexity:.2f}, ratio: {ratio:.2f})")
|
| 72 |
+
if not ppl_ok:
|
| 73 |
+
print(f"[validate] β Perplexity ratio {ratio:.2f} exceeds threshold {cfg.perplexity_threshold}")
|
| 74 |
+
else:
|
| 75 |
+
print(f"\n[validate] Perplexity: {perplexity:.2f} (no baseline to compare)")
|
| 76 |
+
results["perplexity"] = {"value": perplexity, "ok": ppl_ok}
|
| 77 |
+
|
| 78 |
+
# --- Test 3: Thinking mode ---
|
| 79 |
+
think_ok = test_thinking_mode(model, tokenizer)
|
| 80 |
+
results["thinking_mode"] = {"ok": think_ok}
|
| 81 |
+
|
| 82 |
+
# --- Test 4: Quick reasoning ---
|
| 83 |
+
reason_ok = test_reasoning(model, tokenizer)
|
| 84 |
+
results["reasoning"] = {"ok": reason_ok}
|
| 85 |
+
|
| 86 |
+
# --- Overall verdict ---
|
| 87 |
+
all_ok = (
|
| 88 |
+
results["canary"]["ok"]
|
| 89 |
+
and results["perplexity"]["ok"]
|
| 90 |
+
and results["thinking_mode"]["ok"]
|
| 91 |
+
and results["reasoning"]["ok"]
|
| 92 |
+
)
|
| 93 |
+
results["overall"] = all_ok
|
| 94 |
+
|
| 95 |
+
# Summary
|
| 96 |
+
print("\n" + "-" * 60)
|
| 97 |
+
print("VALIDATION SUMMARY")
|
| 98 |
+
print("-" * 60)
|
| 99 |
+
print(f" Canary recall: {'β' if results['canary']['ok'] else 'β'} ({passed_canaries}/{total_canaries})")
|
| 100 |
+
print(f" Perplexity: {'β' if ppl_ok else 'β'} ({perplexity:.2f})")
|
| 101 |
+
print(f" Thinking mode: {'β' if think_ok else 'β'}")
|
| 102 |
+
print(f" Reasoning: {'β' if reason_ok else 'β'}")
|
| 103 |
+
print(f" OVERALL: {'β PASS' if all_ok else 'β FAIL β consider aborting'}")
|
| 104 |
+
print("-" * 60)
|
| 105 |
+
|
| 106 |
+
return results
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def compute_perplexity(
|
| 110 |
+
model: AutoModelForCausalLM,
|
| 111 |
+
tokenizer: AutoTokenizer,
|
| 112 |
+
test_texts: list[str] = None,
|
| 113 |
+
) -> float:
|
| 114 |
+
"""
|
| 115 |
+
Compute perplexity on a small test set.
|
| 116 |
+
|
| 117 |
+
Lower perplexity = model is more confident about predicting text.
|
| 118 |
+
A big spike after merging means the model was damaged.
|
| 119 |
+
"""
|
| 120 |
+
if test_texts is None:
|
| 121 |
+
test_texts = [
|
| 122 |
+
"The quick brown fox jumps over the lazy dog.",
|
| 123 |
+
"In mathematics, a prime number is a natural number greater than 1.",
|
| 124 |
+
"def fibonacci(n):\n if n <= 1:\n return n\n return fibonacci(n-1) + fibonacci(n-2)",
|
| 125 |
+
"The theory of general relativity describes gravity as the curvature of spacetime.",
|
| 126 |
+
"To solve 3x + 7 = 22, subtract 7 from both sides to get 3x = 15, then divide by 3.",
|
| 127 |
+
]
|
| 128 |
+
|
| 129 |
+
model.eval()
|
| 130 |
+
total_loss = 0.0
|
| 131 |
+
total_tokens = 0
|
| 132 |
+
|
| 133 |
+
for text in test_texts:
|
| 134 |
+
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
|
| 135 |
+
inputs = {k: v.to(model.device) for k, v in inputs.items()}
|
| 136 |
+
|
| 137 |
+
with torch.no_grad():
|
| 138 |
+
outputs = model(**inputs, labels=inputs["input_ids"])
|
| 139 |
+
total_loss += outputs.loss.item() * inputs["input_ids"].shape[1]
|
| 140 |
+
total_tokens += inputs["input_ids"].shape[1]
|
| 141 |
+
|
| 142 |
+
avg_loss = total_loss / total_tokens
|
| 143 |
+
perplexity = math.exp(avg_loss)
|
| 144 |
+
return perplexity
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def test_thinking_mode(
|
| 148 |
+
model: AutoModelForCausalLM,
|
| 149 |
+
tokenizer: AutoTokenizer,
|
| 150 |
+
) -> bool:
|
| 151 |
+
"""
|
| 152 |
+
Test if the model still uses <think> tags for reasoning.
|
| 153 |
+
|
| 154 |
+
The thinking mode is Qwen3's special feature β if it's gone,
|
| 155 |
+
the merge damaged something critical.
|
| 156 |
+
"""
|
| 157 |
+
prompt = "Solve step by step: What is 15 Γ 13?"
|
| 158 |
+
|
| 159 |
+
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
| 160 |
+
with torch.no_grad():
|
| 161 |
+
outputs = model.generate(
|
| 162 |
+
**inputs,
|
| 163 |
+
max_new_tokens=200,
|
| 164 |
+
temperature=0.7,
|
| 165 |
+
do_sample=True,
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
response = tokenizer.decode(outputs[0], skip_special_tokens=False)
|
| 169 |
+
|
| 170 |
+
# Check for thinking tags
|
| 171 |
+
has_think_open = "<think>" in response
|
| 172 |
+
has_think_close = "</think>" in response
|
| 173 |
+
passed = has_think_open and has_think_close
|
| 174 |
+
|
| 175 |
+
print(f"\n[validate] Thinking mode test:")
|
| 176 |
+
print(f" Prompt: {prompt}")
|
| 177 |
+
print(f" Response: {response[:200]}...")
|
| 178 |
+
print(f" <think>: {'β found' if has_think_open else 'β missing'}")
|
| 179 |
+
print(f" </think>: {'β found' if has_think_close else 'β missing'}")
|
| 180 |
+
print(f" Status: {'β PASS' if passed else 'β FAIL'}")
|
| 181 |
+
|
| 182 |
+
return passed
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def test_reasoning(
|
| 186 |
+
model: AutoModelForCausalLM,
|
| 187 |
+
tokenizer: AutoTokenizer,
|
| 188 |
+
) -> bool:
|
| 189 |
+
"""
|
| 190 |
+
Quick reasoning sanity check β can the model still do basic math?
|
| 191 |
+
|
| 192 |
+
This catches catastrophic failures where the merge produced gibberish.
|
| 193 |
+
"""
|
| 194 |
+
prompt = "What is 7 + 8?"
|
| 195 |
+
expected_answer = "15"
|
| 196 |
+
|
| 197 |
+
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
| 198 |
+
with torch.no_grad():
|
| 199 |
+
outputs = model.generate(
|
| 200 |
+
**inputs,
|
| 201 |
+
max_new_tokens=50,
|
| 202 |
+
temperature=0.1,
|
| 203 |
+
do_sample=False,
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 207 |
+
passed = expected_answer in response
|
| 208 |
+
|
| 209 |
+
print(f"\n[validate] Quick reasoning test:")
|
| 210 |
+
print(f" Prompt: {prompt}")
|
| 211 |
+
print(f" Expected: {expected_answer}")
|
| 212 |
+
print(f" Got: {response}")
|
| 213 |
+
print(f" Status: {'β PASS' if passed else 'β FAIL'}")
|
| 214 |
+
|
| 215 |
+
return passed
|
hugging/td_lang/errors.py
CHANGED
|
@@ -88,6 +88,20 @@ COMMON_FIXES = {
|
|
| 88 |
"fuse": 'Format: fuse ["model1", "model2"] into target [strategy equal]',
|
| 89 |
"absorb": 'Format: absorb "model" into target [strength 0.5]',
|
| 90 |
"schedule": 'Format: schedule "every 6h" { commands... } or schedule "at 02:00" { ... }',
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
}
|
| 92 |
|
| 93 |
|
|
|
|
| 88 |
"fuse": 'Format: fuse ["model1", "model2"] into target [strategy equal]',
|
| 89 |
"absorb": 'Format: absorb "model" into target [strength 0.5]',
|
| 90 |
"schedule": 'Format: schedule "every 6h" { commands... } or schedule "at 02:00" { ... }',
|
| 91 |
+
"download": 'Format: download "dataset_name" as alias [split train]',
|
| 92 |
+
"log": 'Format: log "output.txt" (place before commands to capture output)',
|
| 93 |
+
"compare": 'Format: compare target vs "source_model" [questions 50] [-> output.json]',
|
| 94 |
+
"verify": 'Format: verify target on "dataset" [questions 100] [-> output.json]',
|
| 95 |
+
"vote": 'Format: vote target "question" [samples 5] [-> output.json]',
|
| 96 |
+
"prompt": 'Format: prompt target "Think step by step before answering."',
|
| 97 |
+
"distill": 'Format: distill target into "small_model" [steps 200] [-> output_dir]',
|
| 98 |
+
"rollback": "Format: rollback target (reverts to most recent snapshot)",
|
| 99 |
+
"curriculum": 'Format: curriculum target on "dataset" using grpo [levels 3] [steps 64]',
|
| 100 |
+
"star": 'Format: star target on "dataset" [rounds 3] [samples 8]',
|
| 101 |
+
"best_of": 'Format: best_of target on "dataset" [n 8] [steps 32]',
|
| 102 |
+
"exploit": 'Format: exploit target on "dataset" [samples 16] [steps 32] [-> output.jsonl]',
|
| 103 |
+
"arena": 'Format: arena target on "dataset" [rounds 5] [episodes 50] [steps 64] [curiosity 0.3] [-> log.json]',
|
| 104 |
+
"research_arena": 'Format: research_arena target topic "subject" [sources "pubmed"|"web"|"arxiv"] [rounds 5] [episodes 30] [-> log.json]',
|
| 105 |
}
|
| 106 |
|
| 107 |
|
hugging/td_lang/examples/demo_arena.td
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# demo_arena.td β Real RL with memory, curiosity, and anti-lying
|
| 2 |
+
#
|
| 3 |
+
# This is ACTUAL reinforcement learning β the model explores challenges,
|
| 4 |
+
# gets immediate reward/punishment, remembers what worked, and trains
|
| 5 |
+
# on its experiences. Unlike best_of/star which just pick good examples,
|
| 6 |
+
# arena makes the model LEARN FROM CONSEQUENCES.
|
| 7 |
+
#
|
| 8 |
+
# Features:
|
| 9 |
+
# - Memory bank: remembers what worked across all rounds
|
| 10 |
+
# - Curiosity bonus: rewarded for trying NEW approaches
|
| 11 |
+
# - Lying punishment: -2.0 for confident wrong answers (worst offence)
|
| 12 |
+
# - Cross-check: creative solutions verified against standard approach
|
| 13 |
+
#
|
| 14 |
+
# The model won't "forget the button makes the door safe" because
|
| 15 |
+
# memory persists. And it won't lie because lying gets punished DOUBLE.
|
| 16 |
+
|
| 17 |
+
load "Qwen/Qwen3-8B" as base
|
| 18 |
+
|
| 19 |
+
# Run the arena: 3 rounds of 30 episodes each
|
| 20 |
+
# Curiosity weight 0.3 = moderate exploration bonus
|
| 21 |
+
arena base on "gsm8k" rounds 3 episodes 30 steps 32 curiosity 0.3 -> arena_log.json
|
| 22 |
+
|
| 23 |
+
# After arena training, evaluate the result
|
| 24 |
+
eval base -> arena_eval.json
|
| 25 |
+
|
| 26 |
+
# Save the improved model
|
| 27 |
+
snapshot base
|
| 28 |
+
commit base
|
hugging/td_lang/examples/demo_intelligence.td
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Demo: Phase 11 Intelligence β vote, prompt, distill, rollback
|
| 2 |
+
# Shows all 4 new commands + the upgraded mega-diagnose
|
| 3 |
+
|
| 4 |
+
load "Qwen/Qwen3-VL-8B-Instruct" as base
|
| 5 |
+
|
| 6 |
+
# Attach a chain-of-thought prompt (makes it think step by step)
|
| 7 |
+
prompt base "Think step by step before answering. Show your reasoning."
|
| 8 |
+
|
| 9 |
+
# Mega diagnose: self-diagnosis + domain profiling + layer speed
|
| 10 |
+
diagnose base -> diagnosis_report.json
|
| 11 |
+
|
| 12 |
+
# Merge in reasoning
|
| 13 |
+
merge "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B" into base using transport strength 0.5
|
| 14 |
+
|
| 15 |
+
# Use majority voting on a hard question
|
| 16 |
+
vote base "What is 847 * 23? Show your work." samples 5 -> vote_result.json
|
| 17 |
+
|
| 18 |
+
# Snapshot before training (so rollback works)
|
| 19 |
+
snapshot base
|
| 20 |
+
|
| 21 |
+
# Train on weaknesses found by diagnose
|
| 22 |
+
train base on "gsm8k" using grpo steps 64
|
| 23 |
+
|
| 24 |
+
# Eval to check if training helped
|
| 25 |
+
eval base -> eval_after.json
|
| 26 |
+
|
| 27 |
+
# If training made things worse, undo it
|
| 28 |
+
if eval_passed base {
|
| 29 |
+
commit base
|
| 30 |
+
} else {
|
| 31 |
+
rollback base
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
# Create a fast student model for easy questions
|
| 35 |
+
distill base into "Qwen/Qwen3-1.7B" steps 100 -> student_model/
|
hugging/td_lang/examples/demo_research_arena.td
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# demo_research_arena.td β Real RL on ANY topic using real-world sources
|
| 2 |
+
#
|
| 3 |
+
# This is the research gauntlet. The model gets thrown into a maze
|
| 4 |
+
# built from REAL papers and knowledge. It has to navigate perfectly.
|
| 5 |
+
#
|
| 6 |
+
# How it works:
|
| 7 |
+
# 1. Pulls real papers about your topic (PubMed, arXiv, web, or local files)
|
| 8 |
+
# 2. Extracts verifiable facts from those papers
|
| 9 |
+
# 3. Builds increasingly hard questions from the real knowledge
|
| 10 |
+
# 4. Model must answer correctly β EVERY claim checked against sources
|
| 11 |
+
# 5. Difficulty ESCALATES each round (stricter checking, harder questions)
|
| 12 |
+
# 6. Memory persists β model remembers what it learned
|
| 13 |
+
# 7. Lying = double punishment, curiosity = bonus
|
| 14 |
+
#
|
| 15 |
+
# The maze shrinks each round:
|
| 16 |
+
# Round 1: Easy questions, 30% strictness, full path width
|
| 17 |
+
# Round 2: Medium questions, 55% strictness, 75% path width
|
| 18 |
+
# Round 3: Hard questions, 80% strictness, 50% path width
|
| 19 |
+
# ...and so on. Miss a single fact = punishment.
|
| 20 |
+
|
| 21 |
+
load "Qwen/Qwen3-8B" as base
|
| 22 |
+
|
| 23 |
+
# Example 1: Medical research (uses PubMed for real papers)
|
| 24 |
+
research_arena base topic "cancer immunotherapy mechanisms" sources "pubmed" rounds 4 episodes 25 steps 48 curiosity 0.3 difficulty_scale 0.25 -> research_log.json
|
| 25 |
+
|
| 26 |
+
# After the gauntlet, see how the model performs
|
| 27 |
+
eval base -> post_research_eval.json
|
| 28 |
+
snapshot base
|
| 29 |
+
commit base
|
hugging/td_lang/examples/demo_rl.td
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Demo: Phase 12 RL & Fine-Tuning β curriculum, star, best_of, exploit
|
| 2 |
+
# Shows all 4 new training methods + reward_contract wiring
|
| 3 |
+
|
| 4 |
+
load "Qwen/Qwen3-VL-8B-Instruct" as base
|
| 5 |
+
|
| 6 |
+
# Define what counts as "correct" (these verifiers wire into GRPO training)
|
| 7 |
+
reward_contract {
|
| 8 |
+
verifiers = [code_compiles, math_correct, no_hallucination]
|
| 9 |
+
min_reward = 0.3
|
| 10 |
+
}
|
| 11 |
+
|
| 12 |
+
# Step 1: Curriculum training β start easy, get harder
|
| 13 |
+
curriculum base on "gsm8k" using grpo levels 3 steps 64
|
| 14 |
+
|
| 15 |
+
# Step 2: STaR β learn from own correct reasoning chains
|
| 16 |
+
star base on "gsm8k" rounds 3 samples 8
|
| 17 |
+
|
| 18 |
+
# Step 3: Best-of-N β generate 8 answers per question, train on the best
|
| 19 |
+
best_of base on "openai/humaneval" n 8 steps 32
|
| 20 |
+
|
| 21 |
+
# Step 4: EXPLOIT β controlled reward hacking
|
| 22 |
+
# Generate 16 diverse solutions per problem, keep ALL correct ones
|
| 23 |
+
# Even ugly shortcuts β if the answer is right, the method is valid
|
| 24 |
+
exploit base on "gsm8k" samples 16 steps 32 -> exploit_results.jsonl
|
| 25 |
+
|
| 26 |
+
# Verify the model actually got smarter
|
| 27 |
+
eval base -> eval_after_rl.json
|
| 28 |
+
|
| 29 |
+
# Save if good
|
| 30 |
+
snapshot base
|
| 31 |
+
commit base
|
hugging/td_lang/examples/demo_toolbox.td
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Demo: Phase 10 Toolbox β download, log, compare, verify
|
| 2 |
+
# Shows all 4 new commands working together
|
| 3 |
+
|
| 4 |
+
log "toolbox_run.txt"
|
| 5 |
+
|
| 6 |
+
load "Qwen/Qwen3-VL-8B-Instruct" as base
|
| 7 |
+
|
| 8 |
+
# Download a dataset for verification
|
| 9 |
+
download "gsm8k" as math_data
|
| 10 |
+
download "openai/humaneval" as code_data split test
|
| 11 |
+
|
| 12 |
+
# Merge in reasoning ability
|
| 13 |
+
merge "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B" into base using transport strength 0.5
|
| 14 |
+
|
| 15 |
+
# Compare: does the merged model remember what DeepSeek knew?
|
| 16 |
+
compare base vs "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B" questions 30 -> compare_results.json
|
| 17 |
+
|
| 18 |
+
# Verify: are the answers actually correct?
|
| 19 |
+
verify base on "gsm8k" questions 50 -> verify_math.json
|
| 20 |
+
verify base on "openai/humaneval" questions 25 -> verify_code.json
|
| 21 |
+
|
| 22 |
+
# Eval and commit if good
|
| 23 |
+
eval base -> eval_report.json
|
| 24 |
+
commit base
|
hugging/td_lang/grammar.py
CHANGED
|
@@ -15,6 +15,7 @@ from .ast_nodes import (
|
|
| 15 |
DataContractBlock,
|
| 16 |
DebateCmd,
|
| 17 |
DiagnoseCmd,
|
|
|
|
| 18 |
EditCmd,
|
| 19 |
EvalCmd,
|
| 20 |
FuseCmd,
|
|
@@ -26,13 +27,26 @@ from .ast_nodes import (
|
|
| 26 |
MergeCmd,
|
| 27 |
NotifyCmd,
|
| 28 |
OnErrorBlock,
|
|
|
|
| 29 |
PruneCmd,
|
| 30 |
RepeatBlock,
|
| 31 |
ReportCmd,
|
| 32 |
ResetCmd,
|
| 33 |
RewardContractBlock,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
SaveCmd,
|
| 35 |
ScheduleCmd,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
SetupBlock,
|
| 37 |
SnapshotCmd,
|
| 38 |
SynthCmd,
|
|
@@ -80,6 +94,20 @@ TD_GRAMMAR = r"""
|
|
| 80 |
| setup_block
|
| 81 |
| on_error_block
|
| 82 |
| schedule_cmd
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
|
| 84 |
// ======================== PHASE 1 COMMANDS ========================
|
| 85 |
|
|
@@ -153,7 +181,11 @@ TD_GRAMMAR = r"""
|
|
| 153 |
| fork_cmd | reset_cmd | prune_cmd | edit_cmd
|
| 154 |
| fuse_cmd | absorb_cmd | snapshot_cmd | report_cmd
|
| 155 |
| notify_cmd | save_cmd
|
| 156 |
-
| repeat_block_cmd | if_block_cmd | schedule_cmd
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
|
| 158 |
// ======================== PHASE 6 β EASY MERGE COMMANDS ========================
|
| 159 |
|
|
@@ -233,6 +265,87 @@ TD_GRAMMAR = r"""
|
|
| 233 |
// schedule "after 30m" { commands... }
|
| 234 |
schedule_cmd: "schedule" string "{" _NL* body_cmd+ _NL* "}"
|
| 235 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 236 |
// ======================== SHARED RULES ========================
|
| 237 |
|
| 238 |
// List of names: [name1, name2, name3]
|
|
@@ -468,6 +581,251 @@ class TDTransformer(Transformer):
|
|
| 468 |
def schedule_cmd(self, timing: str, *body_cmds) -> ScheduleCmd:
|
| 469 |
return ScheduleCmd(timing=timing, body=list(body_cmds))
|
| 470 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 471 |
# --- Phase 6: Easy Merge Commands ---
|
| 472 |
|
| 473 |
def fuse_cmd(self, sources: list[str], target: str, *opts) -> FuseCmd:
|
|
@@ -688,6 +1046,8 @@ class TDTransformer(Transformer):
|
|
| 688 |
program.setup = item
|
| 689 |
elif isinstance(item, OnErrorBlock):
|
| 690 |
program.on_error = item
|
|
|
|
|
|
|
| 691 |
else:
|
| 692 |
program.commands.append(item)
|
| 693 |
return program
|
|
|
|
| 15 |
DataContractBlock,
|
| 16 |
DebateCmd,
|
| 17 |
DiagnoseCmd,
|
| 18 |
+
DistillCmd,
|
| 19 |
EditCmd,
|
| 20 |
EvalCmd,
|
| 21 |
FuseCmd,
|
|
|
|
| 27 |
MergeCmd,
|
| 28 |
NotifyCmd,
|
| 29 |
OnErrorBlock,
|
| 30 |
+
PromptBlock,
|
| 31 |
PruneCmd,
|
| 32 |
RepeatBlock,
|
| 33 |
ReportCmd,
|
| 34 |
ResetCmd,
|
| 35 |
RewardContractBlock,
|
| 36 |
+
RollbackCmd,
|
| 37 |
+
CurriculumCmd,
|
| 38 |
+
StarCmd,
|
| 39 |
+
BestOfCmd,
|
| 40 |
+
ExploitCmd,
|
| 41 |
+
ArenaCmd,
|
| 42 |
+
ResearchArenaCmd,
|
| 43 |
SaveCmd,
|
| 44 |
ScheduleCmd,
|
| 45 |
+
DownloadCmd,
|
| 46 |
+
LogBlock,
|
| 47 |
+
CompareCmd,
|
| 48 |
+
VerifyCmd,
|
| 49 |
+
VoteCmd,
|
| 50 |
SetupBlock,
|
| 51 |
SnapshotCmd,
|
| 52 |
SynthCmd,
|
|
|
|
| 94 |
| setup_block
|
| 95 |
| on_error_block
|
| 96 |
| schedule_cmd
|
| 97 |
+
| download_cmd
|
| 98 |
+
| log_block
|
| 99 |
+
| compare_cmd
|
| 100 |
+
| verify_cmd
|
| 101 |
+
| vote_cmd
|
| 102 |
+
| prompt_cmd
|
| 103 |
+
| distill_cmd
|
| 104 |
+
| rollback_cmd
|
| 105 |
+
| curriculum_cmd
|
| 106 |
+
| star_cmd
|
| 107 |
+
| best_of_cmd
|
| 108 |
+
| exploit_cmd
|
| 109 |
+
| arena_cmd
|
| 110 |
+
| research_arena_cmd
|
| 111 |
|
| 112 |
// ======================== PHASE 1 COMMANDS ========================
|
| 113 |
|
|
|
|
| 181 |
| fork_cmd | reset_cmd | prune_cmd | edit_cmd
|
| 182 |
| fuse_cmd | absorb_cmd | snapshot_cmd | report_cmd
|
| 183 |
| notify_cmd | save_cmd
|
| 184 |
+
| repeat_block_cmd | if_block_cmd | schedule_cmd
|
| 185 |
+
| download_cmd | compare_cmd | verify_cmd
|
| 186 |
+
| vote_cmd | prompt_cmd | distill_cmd | rollback_cmd
|
| 187 |
+
| curriculum_cmd | star_cmd | best_of_cmd | exploit_cmd
|
| 188 |
+
| arena_cmd | research_arena_cmd) _NL*
|
| 189 |
|
| 190 |
// ======================== PHASE 6 β EASY MERGE COMMANDS ========================
|
| 191 |
|
|
|
|
| 265 |
// schedule "after 30m" { commands... }
|
| 266 |
schedule_cmd: "schedule" string "{" _NL* body_cmd+ _NL* "}"
|
| 267 |
|
| 268 |
+
// ======================== PHASE 10 - TOOLBOX ========================
|
| 269 |
+
|
| 270 |
+
// download "gsm8k" as math_data [split train]
|
| 271 |
+
download_cmd: "download" string "as" IDENT (download_split)?
|
| 272 |
+
download_split: "split" IDENT
|
| 273 |
+
|
| 274 |
+
// log "training_log.txt"
|
| 275 |
+
log_block: "log" string
|
| 276 |
+
|
| 277 |
+
// compare target vs "source_model" [questions 50] [-> output.json]
|
| 278 |
+
compare_cmd: "compare" IDENT "vs" string (compare_questions)? (compare_output)?
|
| 279 |
+
compare_questions: "questions" INT
|
| 280 |
+
compare_output: "->" FILEPATH
|
| 281 |
+
|
| 282 |
+
// verify target on "dataset" [questions 100] [-> results.json]
|
| 283 |
+
verify_cmd: "verify" IDENT "on" string (verify_questions)? (verify_output)?
|
| 284 |
+
verify_questions: "questions" INT
|
| 285 |
+
verify_output: "->" FILEPATH
|
| 286 |
+
|
| 287 |
+
// ======================== PHASE 11 - INTELLIGENCE ========================
|
| 288 |
+
|
| 289 |
+
// vote target "question" [samples 5] [-> output.json]
|
| 290 |
+
vote_cmd: "vote" IDENT string (vote_samples)? (vote_output)?
|
| 291 |
+
vote_samples: "samples" INT
|
| 292 |
+
vote_output: "->" FILEPATH
|
| 293 |
+
|
| 294 |
+
// prompt target "system prompt text"
|
| 295 |
+
prompt_cmd: "prompt" IDENT string
|
| 296 |
+
|
| 297 |
+
// distill target into "small_model" [steps 200] [-> output_dir]
|
| 298 |
+
distill_cmd: "distill" IDENT "into" string (distill_steps)? (distill_output)?
|
| 299 |
+
distill_steps: "steps" INT
|
| 300 |
+
distill_output: "->" FILEPATH
|
| 301 |
+
|
| 302 |
+
// rollback target
|
| 303 |
+
rollback_cmd: "rollback" IDENT
|
| 304 |
+
|
| 305 |
+
// ======================== PHASE 12 - RL & FINE-TUNING ========================
|
| 306 |
+
|
| 307 |
+
// curriculum target on "dataset" using method [levels 3] [steps 64]
|
| 308 |
+
curriculum_cmd: "curriculum" IDENT "on" string "using" IDENT (curriculum_opt)*
|
| 309 |
+
curriculum_opt: "levels" INT -> curriculum_levels
|
| 310 |
+
| "steps" INT -> curriculum_steps
|
| 311 |
+
|
| 312 |
+
// star target on "dataset" [rounds 3] [samples 8]
|
| 313 |
+
star_cmd: "star" IDENT "on" string (star_opt)*
|
| 314 |
+
star_opt: "rounds" INT -> star_rounds
|
| 315 |
+
| "samples" INT -> star_samples
|
| 316 |
+
|
| 317 |
+
// best_of target on "dataset" [n 8] [steps 32]
|
| 318 |
+
best_of_cmd: "best_of" IDENT "on" string (best_of_opt)*
|
| 319 |
+
best_of_opt: "n" INT -> best_of_n
|
| 320 |
+
| "steps" INT -> best_of_steps
|
| 321 |
+
|
| 322 |
+
// exploit target on "dataset" [samples 16] [steps 32] [-> output.jsonl]
|
| 323 |
+
exploit_cmd: "exploit" IDENT "on" string (exploit_opt)*
|
| 324 |
+
exploit_opt: "samples" INT -> exploit_samples
|
| 325 |
+
| "steps" INT -> exploit_steps
|
| 326 |
+
| "->" FILEPATH -> exploit_output
|
| 327 |
+
|
| 328 |
+
// ======================== PHASE 13 - REAL RL (ARENA) ========================
|
| 329 |
+
|
| 330 |
+
// arena target on "dataset" [rounds 5] [episodes 50] [steps 64] [curiosity 0.3] [-> log.json]
|
| 331 |
+
arena_cmd: "arena" IDENT "on" string (arena_opt)*
|
| 332 |
+
arena_opt: "rounds" INT -> arena_rounds
|
| 333 |
+
| "episodes" INT -> arena_episodes
|
| 334 |
+
| "steps" INT -> arena_steps
|
| 335 |
+
| "curiosity" NUMBER -> arena_curiosity
|
| 336 |
+
| "->" FILEPATH -> arena_output
|
| 337 |
+
|
| 338 |
+
// research_arena target topic "subject" [sources "web"|"pubmed"|"arxiv"|path]
|
| 339 |
+
// [rounds 5] [episodes 30] [steps 64] [curiosity 0.3] [difficulty_scale 0.25] [-> log.json]
|
| 340 |
+
research_arena_cmd: "research_arena" IDENT "topic" string (ra_opt)*
|
| 341 |
+
ra_opt: "sources" string -> ra_sources
|
| 342 |
+
| "rounds" INT -> ra_rounds
|
| 343 |
+
| "episodes" INT -> ra_episodes
|
| 344 |
+
| "steps" INT -> ra_steps
|
| 345 |
+
| "curiosity" NUMBER -> ra_curiosity
|
| 346 |
+
| "difficulty_scale" NUMBER -> ra_difficulty
|
| 347 |
+
| "->" FILEPATH -> ra_output
|
| 348 |
+
|
| 349 |
// ======================== SHARED RULES ========================
|
| 350 |
|
| 351 |
// List of names: [name1, name2, name3]
|
|
|
|
| 581 |
def schedule_cmd(self, timing: str, *body_cmds) -> ScheduleCmd:
|
| 582 |
return ScheduleCmd(timing=timing, body=list(body_cmds))
|
| 583 |
|
| 584 |
+
# --- Phase 10: Toolbox ---
|
| 585 |
+
|
| 586 |
+
def download_cmd(self, dataset: str, alias: str, split: str | None = None) -> DownloadCmd:
|
| 587 |
+
cmd = DownloadCmd(dataset=dataset, alias=alias)
|
| 588 |
+
if isinstance(split, tuple) and split[0] == "split":
|
| 589 |
+
cmd.split = split[1]
|
| 590 |
+
elif isinstance(split, str):
|
| 591 |
+
cmd.split = split
|
| 592 |
+
return cmd
|
| 593 |
+
|
| 594 |
+
def download_split(self, value: str) -> tuple:
|
| 595 |
+
return ("split", value)
|
| 596 |
+
|
| 597 |
+
def log_block(self, filepath: str) -> LogBlock:
|
| 598 |
+
return LogBlock(filepath=filepath)
|
| 599 |
+
|
| 600 |
+
def compare_cmd(self, target: str, source: str, *opts) -> CompareCmd:
|
| 601 |
+
cmd = CompareCmd(target=target, source=source)
|
| 602 |
+
for opt in opts:
|
| 603 |
+
if isinstance(opt, tuple):
|
| 604 |
+
key, val = opt
|
| 605 |
+
if key == "questions":
|
| 606 |
+
cmd.questions = val
|
| 607 |
+
elif key == "output":
|
| 608 |
+
cmd.output = val
|
| 609 |
+
return cmd
|
| 610 |
+
|
| 611 |
+
def compare_questions(self, value: int) -> tuple:
|
| 612 |
+
return ("questions", value)
|
| 613 |
+
|
| 614 |
+
def compare_output(self, filepath: str) -> tuple:
|
| 615 |
+
return ("output", filepath)
|
| 616 |
+
|
| 617 |
+
def verify_cmd(self, target: str, dataset: str, *opts) -> VerifyCmd:
|
| 618 |
+
cmd = VerifyCmd(target=target, dataset=dataset)
|
| 619 |
+
for opt in opts:
|
| 620 |
+
if isinstance(opt, tuple):
|
| 621 |
+
key, val = opt
|
| 622 |
+
if key == "questions":
|
| 623 |
+
cmd.questions = val
|
| 624 |
+
elif key == "output":
|
| 625 |
+
cmd.output = val
|
| 626 |
+
return cmd
|
| 627 |
+
|
| 628 |
+
def verify_questions(self, value: int) -> tuple:
|
| 629 |
+
return ("questions", value)
|
| 630 |
+
|
| 631 |
+
def verify_output(self, filepath: str) -> tuple:
|
| 632 |
+
return ("output", filepath)
|
| 633 |
+
|
| 634 |
+
# --- Phase 11: Intelligence Commands ---
|
| 635 |
+
|
| 636 |
+
def vote_cmd(self, target: str, question: str, *opts) -> VoteCmd:
|
| 637 |
+
cmd = VoteCmd(target=target, question=question)
|
| 638 |
+
for opt in opts:
|
| 639 |
+
if isinstance(opt, tuple):
|
| 640 |
+
key, val = opt
|
| 641 |
+
if key == "samples":
|
| 642 |
+
cmd.samples = val
|
| 643 |
+
elif key == "output":
|
| 644 |
+
cmd.output = val
|
| 645 |
+
return cmd
|
| 646 |
+
|
| 647 |
+
def vote_samples(self, value: int) -> tuple:
|
| 648 |
+
return ("samples", value)
|
| 649 |
+
|
| 650 |
+
def vote_output(self, filepath: str) -> tuple:
|
| 651 |
+
return ("output", filepath)
|
| 652 |
+
|
| 653 |
+
def prompt_cmd(self, target: str, text: str) -> PromptBlock:
|
| 654 |
+
return PromptBlock(target=target, text=text)
|
| 655 |
+
|
| 656 |
+
def distill_cmd(self, teacher: str, student: str, *opts) -> DistillCmd:
|
| 657 |
+
cmd = DistillCmd(teacher=teacher, student=student)
|
| 658 |
+
for opt in opts:
|
| 659 |
+
if isinstance(opt, tuple):
|
| 660 |
+
key, val = opt
|
| 661 |
+
if key == "steps":
|
| 662 |
+
cmd.steps = val
|
| 663 |
+
elif key == "output":
|
| 664 |
+
cmd.output = val
|
| 665 |
+
return cmd
|
| 666 |
+
|
| 667 |
+
def distill_steps(self, value: int) -> tuple:
|
| 668 |
+
return ("steps", value)
|
| 669 |
+
|
| 670 |
+
def distill_output(self, filepath: str) -> tuple:
|
| 671 |
+
return ("output", filepath)
|
| 672 |
+
|
| 673 |
+
def rollback_cmd(self, target: str) -> RollbackCmd:
|
| 674 |
+
return RollbackCmd(target=target)
|
| 675 |
+
|
| 676 |
+
# --- Phase 12: RL & Fine-Tuning Commands ---
|
| 677 |
+
|
| 678 |
+
def curriculum_cmd(self, target: str, dataset: str, method: str, *opts) -> CurriculumCmd:
|
| 679 |
+
cmd = CurriculumCmd(target=target, dataset=dataset, method=method)
|
| 680 |
+
for opt in opts:
|
| 681 |
+
if isinstance(opt, tuple):
|
| 682 |
+
key, val = opt
|
| 683 |
+
if key == "levels":
|
| 684 |
+
cmd.levels = val
|
| 685 |
+
elif key == "steps":
|
| 686 |
+
cmd.steps = val
|
| 687 |
+
return cmd
|
| 688 |
+
|
| 689 |
+
def curriculum_levels(self, value: int) -> tuple:
|
| 690 |
+
return ("levels", value)
|
| 691 |
+
|
| 692 |
+
def curriculum_steps(self, value: int) -> tuple:
|
| 693 |
+
return ("steps", value)
|
| 694 |
+
|
| 695 |
+
def star_cmd(self, target: str, dataset: str, *opts) -> StarCmd:
|
| 696 |
+
cmd = StarCmd(target=target, dataset=dataset)
|
| 697 |
+
for opt in opts:
|
| 698 |
+
if isinstance(opt, tuple):
|
| 699 |
+
key, val = opt
|
| 700 |
+
if key == "rounds":
|
| 701 |
+
cmd.rounds = val
|
| 702 |
+
elif key == "samples":
|
| 703 |
+
cmd.samples = val
|
| 704 |
+
return cmd
|
| 705 |
+
|
| 706 |
+
def star_rounds(self, value: int) -> tuple:
|
| 707 |
+
return ("rounds", value)
|
| 708 |
+
|
| 709 |
+
def star_samples(self, value: int) -> tuple:
|
| 710 |
+
return ("samples", value)
|
| 711 |
+
|
| 712 |
+
def best_of_cmd(self, target: str, dataset: str, *opts) -> BestOfCmd:
|
| 713 |
+
cmd = BestOfCmd(target=target, dataset=dataset)
|
| 714 |
+
for opt in opts:
|
| 715 |
+
if isinstance(opt, tuple):
|
| 716 |
+
key, val = opt
|
| 717 |
+
if key == "n":
|
| 718 |
+
cmd.n = val
|
| 719 |
+
elif key == "steps":
|
| 720 |
+
cmd.steps = val
|
| 721 |
+
return cmd
|
| 722 |
+
|
| 723 |
+
def best_of_n(self, value: int) -> tuple:
|
| 724 |
+
return ("n", value)
|
| 725 |
+
|
| 726 |
+
def best_of_steps(self, value: int) -> tuple:
|
| 727 |
+
return ("steps", value)
|
| 728 |
+
|
| 729 |
+
def exploit_cmd(self, target: str, dataset: str, *opts) -> ExploitCmd:
|
| 730 |
+
cmd = ExploitCmd(target=target, dataset=dataset)
|
| 731 |
+
for opt in opts:
|
| 732 |
+
if isinstance(opt, tuple):
|
| 733 |
+
key, val = opt
|
| 734 |
+
if key == "samples":
|
| 735 |
+
cmd.samples = val
|
| 736 |
+
elif key == "steps":
|
| 737 |
+
cmd.steps = val
|
| 738 |
+
elif key == "output":
|
| 739 |
+
cmd.output = val
|
| 740 |
+
return cmd
|
| 741 |
+
|
| 742 |
+
def exploit_samples(self, value: int) -> tuple:
|
| 743 |
+
return ("samples", value)
|
| 744 |
+
|
| 745 |
+
def exploit_steps(self, value: int) -> tuple:
|
| 746 |
+
return ("steps", value)
|
| 747 |
+
|
| 748 |
+
def exploit_output(self, filepath: str) -> tuple:
|
| 749 |
+
return ("output", filepath)
|
| 750 |
+
|
| 751 |
+
# --- Phase 13: Real RL (Arena) ---
|
| 752 |
+
|
| 753 |
+
def arena_cmd(self, target: str, dataset: str, *opts) -> ArenaCmd:
|
| 754 |
+
cmd = ArenaCmd(target=target, dataset=dataset)
|
| 755 |
+
for opt in opts:
|
| 756 |
+
if isinstance(opt, tuple):
|
| 757 |
+
key, val = opt
|
| 758 |
+
if key == "rounds":
|
| 759 |
+
cmd.rounds = val
|
| 760 |
+
elif key == "episodes":
|
| 761 |
+
cmd.episodes = val
|
| 762 |
+
elif key == "steps":
|
| 763 |
+
cmd.steps = val
|
| 764 |
+
elif key == "curiosity":
|
| 765 |
+
cmd.curiosity = val
|
| 766 |
+
elif key == "output":
|
| 767 |
+
cmd.output = val
|
| 768 |
+
return cmd
|
| 769 |
+
|
| 770 |
+
def arena_rounds(self, value: int) -> tuple:
|
| 771 |
+
return ("rounds", value)
|
| 772 |
+
|
| 773 |
+
def arena_episodes(self, value: int) -> tuple:
|
| 774 |
+
return ("episodes", value)
|
| 775 |
+
|
| 776 |
+
def arena_steps(self, value: int) -> tuple:
|
| 777 |
+
return ("steps", value)
|
| 778 |
+
|
| 779 |
+
def arena_curiosity(self, value: float) -> tuple:
|
| 780 |
+
return ("curiosity", value)
|
| 781 |
+
|
| 782 |
+
def arena_output(self, filepath: str) -> tuple:
|
| 783 |
+
return ("output", filepath)
|
| 784 |
+
|
| 785 |
+
# --- Phase 13: Research Arena ---
|
| 786 |
+
|
| 787 |
+
def research_arena_cmd(self, target: str, topic: str, *opts) -> ResearchArenaCmd:
|
| 788 |
+
cmd = ResearchArenaCmd(target=target, topic=topic)
|
| 789 |
+
for opt in opts:
|
| 790 |
+
if isinstance(opt, tuple):
|
| 791 |
+
key, val = opt
|
| 792 |
+
if key == "sources":
|
| 793 |
+
cmd.sources = val
|
| 794 |
+
elif key == "rounds":
|
| 795 |
+
cmd.rounds = val
|
| 796 |
+
elif key == "episodes":
|
| 797 |
+
cmd.episodes = val
|
| 798 |
+
elif key == "steps":
|
| 799 |
+
cmd.steps = val
|
| 800 |
+
elif key == "curiosity":
|
| 801 |
+
cmd.curiosity = val
|
| 802 |
+
elif key == "difficulty_scale":
|
| 803 |
+
cmd.difficulty_scale = val
|
| 804 |
+
elif key == "output":
|
| 805 |
+
cmd.output = val
|
| 806 |
+
return cmd
|
| 807 |
+
|
| 808 |
+
def ra_sources(self, value: str) -> tuple:
|
| 809 |
+
return ("sources", value)
|
| 810 |
+
|
| 811 |
+
def ra_rounds(self, value: int) -> tuple:
|
| 812 |
+
return ("rounds", value)
|
| 813 |
+
|
| 814 |
+
def ra_episodes(self, value: int) -> tuple:
|
| 815 |
+
return ("episodes", value)
|
| 816 |
+
|
| 817 |
+
def ra_steps(self, value: int) -> tuple:
|
| 818 |
+
return ("steps", value)
|
| 819 |
+
|
| 820 |
+
def ra_curiosity(self, value: float) -> tuple:
|
| 821 |
+
return ("curiosity", value)
|
| 822 |
+
|
| 823 |
+
def ra_difficulty(self, value: float) -> tuple:
|
| 824 |
+
return ("difficulty_scale", value)
|
| 825 |
+
|
| 826 |
+
def ra_output(self, filepath: str) -> tuple:
|
| 827 |
+
return ("output", filepath)
|
| 828 |
+
|
| 829 |
# --- Phase 6: Easy Merge Commands ---
|
| 830 |
|
| 831 |
def fuse_cmd(self, sources: list[str], target: str, *opts) -> FuseCmd:
|
|
|
|
| 1046 |
program.setup = item
|
| 1047 |
elif isinstance(item, OnErrorBlock):
|
| 1048 |
program.on_error = item
|
| 1049 |
+
elif isinstance(item, LogBlock):
|
| 1050 |
+
program.log = item
|
| 1051 |
else:
|
| 1052 |
program.commands.append(item)
|
| 1053 |
return program
|
hugging/td_lang/td_lang/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
hugging/td_lang/td_lang/__init__.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
TD Lang β Domain-specific language for Time Dilation project.
|
| 3 |
+
|
| 4 |
+
Compiles .td files into executable Python. Self-contained β no external deps.
|
| 5 |
+
The merge/heal/validate engine (formerly td_fuse) lives in td_lang.engine/.
|
| 6 |
+
|
| 7 |
+
Architecture:
|
| 8 |
+
td_lang/
|
| 9 |
+
βββ __init__.py <- This file
|
| 10 |
+
βββ __main__.py <- Entry point for python -m td_lang
|
| 11 |
+
βββ grammar.py <- Lark grammar + parse tree transformer
|
| 12 |
+
βββ ast_nodes.py <- Dataclass AST nodes for each command
|
| 13 |
+
βββ compiler.py <- AST -> Python code generation
|
| 14 |
+
βββ executor.py <- Run compiled code, track lineage
|
| 15 |
+
βββ cli.py <- Command-line interface
|
| 16 |
+
βββ errors.py <- Custom exceptions
|
| 17 |
+
βββ engine/ <- Merge/heal/validate runtime (was td_fuse)
|
| 18 |
+
β βββ config.py <- Model configs, merge order, hyperparameters
|
| 19 |
+
β βββ merge.py <- Sequential merge orchestrator
|
| 20 |
+
β βββ heal.py <- QLoRA healing fine-tune
|
| 21 |
+
β βββ validate.py <- Post-merge validation
|
| 22 |
+
β βββ transport.py <- Optimal transport wrapper
|
| 23 |
+
β βββ techniques.py <- ARM, OTMF, RAM, Theseus, Mergeability
|
| 24 |
+
β βββ canary.py <- Canary injection + testing
|
| 25 |
+
βββ examples/
|
| 26 |
+
βββ demo_merge.td <- Basic merge example
|
| 27 |
+
βββ demo_heal.td <- Merge + heal example
|
| 28 |
+
βββ demo_full.td <- Full pipeline with gates + budget
|
| 29 |
+
βββ ... <- 22 example .td files
|
| 30 |
+
|
| 31 |
+
Phase 1: load, merge, heal, eval, commit
|
| 32 |
+
Phase 2: diagnose, synth, train, debate
|
| 33 |
+
Phase 3: fork, reset, prune, edit
|
| 34 |
+
Phase 4: snapshot, report, data_contract, reward_contract
|
| 35 |
+
Phase 5: CLI polish, --version, info command, --verbose
|
| 36 |
+
Phase 6: fuse, absorb (easy merge)
|
| 37 |
+
Phase 7: repeat, if/else (loop control)
|
| 38 |
+
Phase 8: setup, on_error, notify, save (autopilot)
|
| 39 |
+
Phase 9: schedule (time-based execution)
|
| 40 |
+
Phase 10: download, log, compare, verify (toolbox)
|
| 41 |
+
Phase 11: vote, prompt, distill, rollback (intelligence)
|
| 42 |
+
Phase 12: curriculum, star, best_of, exploit (RL & fine-tuning)
|
| 43 |
+
Phase 13: arena (real RL with memory, curiosity, anti-lying, cross-check)
|
| 44 |
+
Engine upgrades: QLoRA training, self-contained eval, model-generated synth problems
|
| 45 |
+
Mega diagnose: self-diagnosis + domain profiling + layer speed testing
|
| 46 |
+
|
| 47 |
+
Designed from interviews test_14 (10 commands) and test_17 (ForgeSpec 2.0).
|
| 48 |
+
"""
|
| 49 |
+
|
| 50 |
+
from .grammar import parse_td_file, parse_td_string # noqa: F401
|
| 51 |
+
from .compiler import compile_program # noqa: F401
|
| 52 |
+
from .executor import TDExecutor, check_td_file, compile_td_file, run_td_file # noqa: F401
|
| 53 |
+
|
| 54 |
+
__version__ = "0.2.0"
|
| 55 |
+
__author__ = "Milan (TD Project)"
|
| 56 |
+
|
| 57 |
+
__all__ = [
|
| 58 |
+
"parse_td_file",
|
| 59 |
+
"parse_td_string",
|
| 60 |
+
"compile_program",
|
| 61 |
+
"TDExecutor",
|
| 62 |
+
"check_td_file",
|
| 63 |
+
"compile_td_file",
|
| 64 |
+
"run_td_file",
|
| 65 |
+
"__version__",
|
| 66 |
+
"__author__",
|
| 67 |
+
]
|
hugging/td_lang/td_lang/__main__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Entry point for python -m td_lang."""
|
| 2 |
+
|
| 3 |
+
from .cli import main
|
| 4 |
+
|
| 5 |
+
main()
|
hugging/td_lang/td_lang/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (3.04 kB). View file
|
|
|
hugging/td_lang/td_lang/__pycache__/__init__.cpython-314.pyc
ADDED
|
Binary file (2.02 kB). View file
|
|
|
hugging/td_lang/td_lang/__pycache__/__main__.cpython-310.pyc
ADDED
|
Binary file (254 Bytes). View file
|
|
|
hugging/td_lang/td_lang/__pycache__/__main__.cpython-314.pyc
ADDED
|
Binary file (262 Bytes). View file
|
|
|
hugging/td_lang/td_lang/__pycache__/ast_nodes.cpython-310.pyc
ADDED
|
Binary file (21.3 kB). View file
|
|
|
hugging/td_lang/td_lang/__pycache__/ast_nodes.cpython-314.pyc
ADDED
|
Binary file (18.7 kB). View file
|
|
|
hugging/td_lang/td_lang/__pycache__/cli.cpython-310.pyc
ADDED
|
Binary file (6.82 kB). View file
|
|
|
hugging/td_lang/td_lang/__pycache__/cli.cpython-314.pyc
ADDED
|
Binary file (10.5 kB). View file
|
|
|
hugging/td_lang/td_lang/__pycache__/compiler.cpython-310.pyc
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:61328de4293774f4fc0e899e0eec00b64338be0dcf0fd3e68feaeaefc4c1edd5
|
| 3 |
+
size 193126
|
hugging/td_lang/td_lang/__pycache__/compiler.cpython-314.pyc
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8bef7388fef05cdd8ee4edcc72a4b8907c8637caa22cfc802da044470a515c92
|
| 3 |
+
size 162778
|
hugging/td_lang/td_lang/__pycache__/errors.cpython-310.pyc
ADDED
|
Binary file (5.59 kB). View file
|
|
|
hugging/td_lang/td_lang/__pycache__/errors.cpython-314.pyc
ADDED
|
Binary file (6.34 kB). View file
|
|
|
hugging/td_lang/td_lang/__pycache__/executor.cpython-310.pyc
ADDED
|
Binary file (5.94 kB). View file
|
|
|
hugging/td_lang/td_lang/__pycache__/executor.cpython-314.pyc
ADDED
|
Binary file (9.49 kB). View file
|
|
|
hugging/td_lang/td_lang/__pycache__/grammar.cpython-310.pyc
ADDED
|
Binary file (37.9 kB). View file
|
|
|
hugging/td_lang/td_lang/__pycache__/grammar.cpython-314.pyc
ADDED
|
Binary file (37.8 kB). View file
|
|
|
hugging/td_lang/td_lang/ast_nodes.py
ADDED
|
@@ -0,0 +1,683 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
TD Lang AST Nodes β Dataclass containers for each parsed command.
|
| 3 |
+
|
| 4 |
+
Each .td command becomes one of these nodes after parsing.
|
| 5 |
+
Phase 1 nodes are compiled into runnable Python; Phase 2 nodes are stubs so
|
| 6 |
+
the compiler can reject them with a clear error until they are implemented.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from dataclasses import dataclass, field
|
| 10 |
+
from typing import Any, List, Optional
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# ============================================================================
|
| 14 |
+
# PHASE 1 COMMANDS
|
| 15 |
+
# ============================================================================
|
| 16 |
+
|
| 17 |
+
@dataclass
|
| 18 |
+
class LoadCmd:
|
| 19 |
+
"""Load a model and give it a name.
|
| 20 |
+
|
| 21 |
+
Example: load "Qwen/Qwen3-VL-8B-Instruct" as base
|
| 22 |
+
"""
|
| 23 |
+
model_ref: str # HuggingFace path or local path
|
| 24 |
+
alias: str # Name to use in the rest of the script
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@dataclass
|
| 28 |
+
class MergeCmd:
|
| 29 |
+
"""Merge a source model into a target using a method.
|
| 30 |
+
|
| 31 |
+
Example: merge "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B" into base using transport strength 0.5
|
| 32 |
+
"""
|
| 33 |
+
source: str # Model path or alias to merge from
|
| 34 |
+
target: str # Alias to merge into (must be loaded first)
|
| 35 |
+
method: str # "transport", "slerp", "ties", "dare"
|
| 36 |
+
strength: float = 0.5 # 0.0 = keep target, 1.0 = keep source
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@dataclass
|
| 40 |
+
class HealCmd:
|
| 41 |
+
"""Run QLoRA healing fine-tune on a model.
|
| 42 |
+
|
| 43 |
+
Example: heal base lora_r 32 epochs 2
|
| 44 |
+
"""
|
| 45 |
+
target: str # Alias of model to heal
|
| 46 |
+
lora_r: int = 32 # LoRA rank (higher = more capacity)
|
| 47 |
+
epochs: int = 2 # Training epochs
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@dataclass
|
| 51 |
+
class EvalCmd:
|
| 52 |
+
"""Run validation/evaluation on a model.
|
| 53 |
+
|
| 54 |
+
Example: eval base on "pile_sample" -> report.json
|
| 55 |
+
"""
|
| 56 |
+
target: str # Alias of model to evaluate
|
| 57 |
+
dataset: Optional[str] = None # Optional dataset name/path
|
| 58 |
+
output: Optional[str] = None # Optional output file path
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
@dataclass
|
| 62 |
+
class CommitCmd:
|
| 63 |
+
"""Save model checkpoint, optionally requiring gates to pass.
|
| 64 |
+
|
| 65 |
+
Example: commit base if [canary, perplexity, thinking_mode]
|
| 66 |
+
"""
|
| 67 |
+
target: str # Alias of model to commit
|
| 68 |
+
gates: Optional[list[str]] = None # Gate names that must pass
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
# ============================================================================
|
| 72 |
+
# PHASE 2 COMMANDS (placeholders β structure ready, not wired up yet)
|
| 73 |
+
# ============================================================================
|
| 74 |
+
|
| 75 |
+
@dataclass
|
| 76 |
+
class SynthCmd:
|
| 77 |
+
"""Generate synthetic training data from a model. (Phase 2)"""
|
| 78 |
+
target: str
|
| 79 |
+
source: str
|
| 80 |
+
filter_method: Optional[str] = None
|
| 81 |
+
output: Optional[str] = None
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
@dataclass
|
| 85 |
+
class TrainCmd:
|
| 86 |
+
"""Train a model on a dataset. (Phase 2)"""
|
| 87 |
+
target: str
|
| 88 |
+
dataset: str
|
| 89 |
+
method: str = "grpo" # "grpo", "sft", "dpo"
|
| 90 |
+
steps: Optional[int] = None
|
| 91 |
+
learning_rate: Optional[float] = None
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
@dataclass
|
| 95 |
+
class DebateCmd:
|
| 96 |
+
"""Generate multi-answer debate for preference pairs. (Phase 2)"""
|
| 97 |
+
target: str
|
| 98 |
+
rounds: int = 3
|
| 99 |
+
candidates: int = 8
|
| 100 |
+
output: Optional[str] = None
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
@dataclass
|
| 104 |
+
class DiagnoseCmd:
|
| 105 |
+
"""Ask model what it's bad at β self-diagnosis. (Phase 2)"""
|
| 106 |
+
target: str
|
| 107 |
+
output: Optional[str] = None
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
@dataclass
|
| 111 |
+
class ForkCmd:
|
| 112 |
+
"""Branch current model weights for parallel experiments. (Phase 3)
|
| 113 |
+
|
| 114 |
+
Example: fork base as experiment_v2
|
| 115 |
+
Cheap fork: copies manifest + adapters, shares base weights (default).
|
| 116 |
+
"""
|
| 117 |
+
source: str # Alias of model to fork from
|
| 118 |
+
alias: str # Name for the new branch
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
@dataclass
|
| 122 |
+
class ResetCmd:
|
| 123 |
+
"""Revert model to a previous checkpoint. (Phase 3)
|
| 124 |
+
|
| 125 |
+
Example: reset base to "checkpoint_042"
|
| 126 |
+
Deletes current model, clears CUDA cache, reloads from disk.
|
| 127 |
+
Must also reset optimizer state.
|
| 128 |
+
"""
|
| 129 |
+
target: str # Alias of model to reset
|
| 130 |
+
checkpoint: str # Checkpoint name/path to revert to
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
@dataclass
|
| 134 |
+
class PruneCmd:
|
| 135 |
+
"""Structural pruning β remove low-utility neurons/heads. (Phase 3)
|
| 136 |
+
|
| 137 |
+
Example: prune base using wanda aggressiveness 0.2
|
| 138 |
+
Safe zone: ~20% max (LLM-Pruner paper). Language backbone only.
|
| 139 |
+
"""
|
| 140 |
+
target: str
|
| 141 |
+
method: str = "wanda" # "wanda", "magnitude", "taylor"
|
| 142 |
+
aggressiveness: float = 0.2 # Fraction to remove (0.0-1.0)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
@dataclass
|
| 146 |
+
class EditCmd:
|
| 147 |
+
"""Surgical LoRA/DoRA editing on specific layers. (Phase 3)
|
| 148 |
+
|
| 149 |
+
Example: edit base layers 16-28 using lora lr 1e-4
|
| 150 |
+
"Try before buy": eval with adapter enabled vs disabled before merging.
|
| 151 |
+
"""
|
| 152 |
+
target: str
|
| 153 |
+
layers: str = "all" # "all", "16-28", single number
|
| 154 |
+
method: str = "lora" # "lora" or "dora"
|
| 155 |
+
learning_rate: Optional[float] = None
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
# ============================================================================
|
| 159 |
+
# PHASE 4 COMMANDS β Contracts, Lineage, Economics (ForgeSpec 2.0, test_17)
|
| 160 |
+
# ============================================================================
|
| 161 |
+
|
| 162 |
+
# ============================================================================
|
| 163 |
+
# PHASE 7 β LOOP CONTROL (repeat, if/else)
|
| 164 |
+
# ============================================================================
|
| 165 |
+
|
| 166 |
+
@dataclass
|
| 167 |
+
class RepeatBlock:
|
| 168 |
+
"""Repeat a block of commands N times. (Phase 7 β Loop Control)
|
| 169 |
+
|
| 170 |
+
Example:
|
| 171 |
+
repeat 5 {
|
| 172 |
+
diagnose base
|
| 173 |
+
synth base from base
|
| 174 |
+
train base on "data.jsonl" using grpo steps 64
|
| 175 |
+
eval base
|
| 176 |
+
}
|
| 177 |
+
"""
|
| 178 |
+
count: int # Number of iterations
|
| 179 |
+
body: List[Any] = field(default_factory=list) # Commands inside the block
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
@dataclass
|
| 183 |
+
class IfBlock:
|
| 184 |
+
"""Conditional execution based on last eval result. (Phase 7 β Loop Control)
|
| 185 |
+
|
| 186 |
+
Example:
|
| 187 |
+
if eval_passed {
|
| 188 |
+
commit base
|
| 189 |
+
} else {
|
| 190 |
+
reset base to "last_good"
|
| 191 |
+
}
|
| 192 |
+
|
| 193 |
+
Condition checks the most recent eval result for the target.
|
| 194 |
+
"""
|
| 195 |
+
condition: str # "eval_passed", "gate_passed", etc.
|
| 196 |
+
target: Optional[str] = None # Which model's eval to check
|
| 197 |
+
then_body: List[Any] = field(default_factory=list)
|
| 198 |
+
else_body: List[Any] = field(default_factory=list)
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
@dataclass
|
| 202 |
+
class FuseCmd:
|
| 203 |
+
"""Fuse multiple models into a target in one shot. (Phase 6 β Easy Merge)
|
| 204 |
+
|
| 205 |
+
Example: fuse [deepseek-r1, mimo-7b, llama-3.1] into base
|
| 206 |
+
Auto-picks Transport and Merge, auto-sets per-model strength.
|
| 207 |
+
Handles cross-architecture merging (all 5 source models have different archs).
|
| 208 |
+
"""
|
| 209 |
+
sources: list[str] # List of model names/paths to fuse in
|
| 210 |
+
target: str # Alias to merge into (must be loaded)
|
| 211 |
+
method: str = "transport" # Default: transport and merge (cross-arch)
|
| 212 |
+
strategy: str = "equal" # "equal" (same strength each), "weighted", "sequential"
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
@dataclass
|
| 216 |
+
class AbsorbCmd:
|
| 217 |
+
"""Absorb a single model into target β simplified merge. (Phase 6 β Easy Merge)
|
| 218 |
+
|
| 219 |
+
Example: absorb "deepseek-ai/DeepSeek-R1" into base strength 0.5
|
| 220 |
+
One-liner for the common case of merging one model in.
|
| 221 |
+
"""
|
| 222 |
+
source: str # Model path or HF ID
|
| 223 |
+
target: str # Alias to merge into
|
| 224 |
+
strength: float = 0.5 # 0.0=keep target, 1.0=keep source, default balanced
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
@dataclass
|
| 228 |
+
class SnapshotCmd:
|
| 229 |
+
"""Save a content-hashed snapshot of model state for lineage tracking. (Phase 4)
|
| 230 |
+
|
| 231 |
+
Example: snapshot base -> snapshots/
|
| 232 |
+
Creates a content-addressed directory: snapshots/<sha256_prefix>/
|
| 233 |
+
Contains: model state, adapter state, prune spec, eval report, manifest.
|
| 234 |
+
"""
|
| 235 |
+
target: str
|
| 236 |
+
output: Optional[str] = None # Output directory (default: td_lang_outputs/snapshots/)
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
@dataclass
|
| 240 |
+
class ReportCmd:
|
| 241 |
+
"""Generate an economics report for this run. (Phase 4)
|
| 242 |
+
|
| 243 |
+
Example: report -> economics.json
|
| 244 |
+
Tracks: GPU hours, cost estimate, tokens processed, experiments run,
|
| 245 |
+
time per command, cost breakdown by phase.
|
| 246 |
+
"""
|
| 247 |
+
output: Optional[str] = None # Output file path
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
# ============================================================================
|
| 251 |
+
# PHASE 8 β AUTOPILOT (setup, notify, save, on_error, resume)
|
| 252 |
+
# ============================================================================
|
| 253 |
+
|
| 254 |
+
@dataclass
|
| 255 |
+
class NotifyCmd:
|
| 256 |
+
"""Send a notification via ntfy.sh. (Phase 8 β Autopilot)
|
| 257 |
+
|
| 258 |
+
Example: notify "Training complete!"
|
| 259 |
+
Uses curl to POST to the configured ntfy topic.
|
| 260 |
+
"""
|
| 261 |
+
message: str
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
@dataclass
|
| 265 |
+
class SaveCmd:
|
| 266 |
+
"""Save/upload model to cloud storage via rclone. (Phase 8 β Autopilot)
|
| 267 |
+
|
| 268 |
+
Example: save base to "gdrive:TD/models/v1"
|
| 269 |
+
Uses rclone to copy model checkpoint to Google Drive (or any rclone remote).
|
| 270 |
+
"""
|
| 271 |
+
target: str # Alias of model to save
|
| 272 |
+
destination: str # rclone destination path
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
@dataclass
|
| 276 |
+
class SetupBlock:
|
| 277 |
+
"""Auto-install dependencies and configure environment. (Phase 8 β Autopilot)
|
| 278 |
+
|
| 279 |
+
Example:
|
| 280 |
+
setup {
|
| 281 |
+
pip = [torch, transformers, peft, bitsandbytes, trl]
|
| 282 |
+
hf_token = env
|
| 283 |
+
notify = "ntfy.sh/my_ai"
|
| 284 |
+
}
|
| 285 |
+
"""
|
| 286 |
+
pip_packages: list[str] = field(default_factory=list)
|
| 287 |
+
hf_token: Optional[str] = None # "env" = read HF_TOKEN from env
|
| 288 |
+
notify_url: Optional[str] = None # ntfy.sh topic URL
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
@dataclass
|
| 292 |
+
class OnErrorBlock:
|
| 293 |
+
"""Crash recovery behavior. (Phase 8 β Autopilot)
|
| 294 |
+
|
| 295 |
+
Example:
|
| 296 |
+
on_error {
|
| 297 |
+
retry = 3
|
| 298 |
+
fallback = reduce_batch
|
| 299 |
+
notify = true
|
| 300 |
+
}
|
| 301 |
+
"""
|
| 302 |
+
retry: int = 3 # Number of retries per failed step
|
| 303 |
+
fallback: str = "reduce_batch" # "reduce_batch", "skip", "snapshot_and_stop"
|
| 304 |
+
notify: bool = True # Send ntfy notification on error
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
# ============================================================================
|
| 308 |
+
# PHASE 9 β SCHEDULE (time-based execution)
|
| 309 |
+
# ============================================================================
|
| 310 |
+
|
| 311 |
+
@dataclass
|
| 312 |
+
class ScheduleCmd:
|
| 313 |
+
"""Schedule a block of commands to run at a specific time or interval. (Phase 9)
|
| 314 |
+
|
| 315 |
+
Examples:
|
| 316 |
+
schedule "every 6h" { diagnose base; train base ... }
|
| 317 |
+
schedule "at 02:00" { train base on "data.jsonl" using grpo }
|
| 318 |
+
schedule "after 30m" { eval base -> results.json }
|
| 319 |
+
|
| 320 |
+
Patterns:
|
| 321 |
+
"every Nh/Nm" β repeat every N hours/minutes
|
| 322 |
+
"at HH:MM" β run once at that time
|
| 323 |
+
"after Nh/Nm" β delay then run once
|
| 324 |
+
"""
|
| 325 |
+
timing: str # "every 6h", "at 02:00", "after 30m"
|
| 326 |
+
body: List[Any] = field(default_factory=list) # Commands inside the block
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
# ============================================================================
|
| 330 |
+
# PHASE 10 - TOOLBOX (download, log, compare, verify)
|
| 331 |
+
# ============================================================================
|
| 332 |
+
|
| 333 |
+
@dataclass
|
| 334 |
+
class DownloadCmd:
|
| 335 |
+
"""Download a dataset from HuggingFace. (Phase 10)
|
| 336 |
+
|
| 337 |
+
Example: download "gsm8k" as math_data
|
| 338 |
+
Pulls a dataset from HuggingFace and stores it for training/eval.
|
| 339 |
+
"""
|
| 340 |
+
dataset: str # HuggingFace dataset path
|
| 341 |
+
alias: str # Name to reference it later
|
| 342 |
+
split: str = "train" # Which split to download
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
@dataclass
|
| 346 |
+
class LogBlock:
|
| 347 |
+
"""Save all pipeline output to a log file. (Phase 10)
|
| 348 |
+
|
| 349 |
+
Example: log "training_log.txt"
|
| 350 |
+
Everything printed to console also goes to this file.
|
| 351 |
+
"""
|
| 352 |
+
filepath: str # Path to save log
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
@dataclass
|
| 356 |
+
class CompareCmd:
|
| 357 |
+
"""Compare source model vs merged model - knowledge retention test. (Phase 10)
|
| 358 |
+
|
| 359 |
+
Example: compare base vs "deepseek-ai/DeepSeek-R1" questions 50
|
| 360 |
+
Tests both models on the same questions and shows what % the merged
|
| 361 |
+
model retained from the source. Proves the merge actually worked.
|
| 362 |
+
"""
|
| 363 |
+
target: str # The merged model alias
|
| 364 |
+
source: str # Source model to compare against (HF path)
|
| 365 |
+
questions: int = 50 # Number of test questions
|
| 366 |
+
output: Optional[str] = None # Optional output file
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
@dataclass
|
| 370 |
+
class VerifyCmd:
|
| 371 |
+
"""Verify model answers are actually correct. (Phase 10)
|
| 372 |
+
|
| 373 |
+
Example: verify base on "gsm8k" questions 100 -> verify_results.json
|
| 374 |
+
Runs the model on questions with KNOWN correct answers and checks
|
| 375 |
+
if the model got them right. Returns accuracy percentage.
|
| 376 |
+
"""
|
| 377 |
+
target: str # Model alias to test
|
| 378 |
+
dataset: str # Dataset with known answers
|
| 379 |
+
questions: int = 100 # Number of questions to test
|
| 380 |
+
output: Optional[str] = None # Optional output file
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
# ============================================================================
|
| 384 |
+
# PHASE 11 - INTELLIGENCE (vote, prompt, distill, rollback)
|
| 385 |
+
# ============================================================================
|
| 386 |
+
|
| 387 |
+
@dataclass
|
| 388 |
+
class VoteCmd:
|
| 389 |
+
"""Majority voting - generate N answers, pick the one most agree on. (Phase 11)
|
| 390 |
+
|
| 391 |
+
Example: vote base "What is 15 * 23?" samples 5
|
| 392 |
+
Generates N answers to the same question, then picks the most common one.
|
| 393 |
+
Proven to boost accuracy 10-20% with zero training.
|
| 394 |
+
"""
|
| 395 |
+
target: str # Model alias
|
| 396 |
+
question: str # Question to vote on
|
| 397 |
+
samples: int = 5 # Number of answers to generate
|
| 398 |
+
output: Optional[str] = None # Optional output file
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
@dataclass
|
| 402 |
+
class PromptBlock:
|
| 403 |
+
"""Attach a system prompt or chain-of-thought template to a model. (Phase 11)
|
| 404 |
+
|
| 405 |
+
Example:
|
| 406 |
+
prompt base "Think step by step before answering."
|
| 407 |
+
Makes the model use this system prompt for all future generations.
|
| 408 |
+
"""
|
| 409 |
+
target: str # Model alias to attach prompt to
|
| 410 |
+
text: str # The system prompt text
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
@dataclass
|
| 414 |
+
class DistillCmd:
|
| 415 |
+
"""Distill a big model's knowledge into a smaller one. (Phase 11)
|
| 416 |
+
|
| 417 |
+
Example: distill base into "Qwen/Qwen3-1.7B" steps 200 -> student_model/
|
| 418 |
+
Takes the big model's best answers and trains the small model on them.
|
| 419 |
+
You get a fast model for easy questions, full model for hard ones.
|
| 420 |
+
"""
|
| 421 |
+
teacher: str # The big model alias (source of knowledge)
|
| 422 |
+
student: str # The small model HF path
|
| 423 |
+
steps: int = 200 # Training steps
|
| 424 |
+
output: Optional[str] = None # Where to save the student model
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
@dataclass
|
| 428 |
+
class RollbackCmd:
|
| 429 |
+
"""Undo the last training step. (Phase 11)
|
| 430 |
+
|
| 431 |
+
Example: rollback base
|
| 432 |
+
Reverts to the most recent snapshot. If training made things worse,
|
| 433 |
+
one command brings it back.
|
| 434 |
+
"""
|
| 435 |
+
target: str # Model alias to rollback
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
# ============================================================================
|
| 439 |
+
# PHASE 12 - RL & FINE-TUNING (curriculum, star, best_of, exploit)
|
| 440 |
+
# ============================================================================
|
| 441 |
+
|
| 442 |
+
@dataclass
|
| 443 |
+
class CurriculumCmd:
|
| 444 |
+
"""Progressive difficulty training - start easy, get harder. (Phase 12)
|
| 445 |
+
|
| 446 |
+
Example: curriculum base on "gsm8k" using grpo levels 3 steps 64
|
| 447 |
+
Splits dataset by difficulty, trains on easy first, then medium, then hard.
|
| 448 |
+
Each level only starts when the model passes the previous one.
|
| 449 |
+
"""
|
| 450 |
+
target: str # Model alias
|
| 451 |
+
dataset: str # Dataset to train on
|
| 452 |
+
method: str = "grpo" # Training method
|
| 453 |
+
levels: int = 3 # Number of difficulty levels
|
| 454 |
+
steps: int = 64 # Steps per level
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
@dataclass
|
| 458 |
+
class StarCmd:
|
| 459 |
+
"""Self-Taught Reasoner - train on own correct reasoning chains. (Phase 12)
|
| 460 |
+
|
| 461 |
+
Example: star base on "gsm8k" rounds 3 samples 8
|
| 462 |
+
Generate N solutions per problem. Keep the ones with correct answers.
|
| 463 |
+
Train on the correct reasoning chains. Repeat.
|
| 464 |
+
The model literally learns from its own successes.
|
| 465 |
+
"""
|
| 466 |
+
target: str # Model alias
|
| 467 |
+
dataset: str # Dataset with known answers
|
| 468 |
+
rounds: int = 3 # Number of STaR iterations
|
| 469 |
+
samples: int = 8 # Solutions to generate per problem
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
@dataclass
|
| 473 |
+
class BestOfCmd:
|
| 474 |
+
"""Generate N answers, score all, train on the best. (Phase 12)
|
| 475 |
+
|
| 476 |
+
Example: best_of base on "gsm8k" n 8 steps 32
|
| 477 |
+
For each training problem: generate N answers, score them all,
|
| 478 |
+
keep only the best one, train on that. Like vote but for training.
|
| 479 |
+
80-90% of RLHF gains at 5-30% of the cost (test_16).
|
| 480 |
+
"""
|
| 481 |
+
target: str # Model alias
|
| 482 |
+
dataset: str # Dataset to train on
|
| 483 |
+
n: int = 8 # How many answers to generate per problem
|
| 484 |
+
steps: int = 32 # Training steps on the filtered data
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
@dataclass
|
| 488 |
+
class ExploitCmd:
|
| 489 |
+
"""Controlled reward hacking - keep ALL correct solutions regardless of method. (Phase 12)
|
| 490 |
+
|
| 491 |
+
Example: exploit base on "gsm8k" samples 16 -> exploit_data.jsonl
|
| 492 |
+
Generate many diverse solutions (high temp). Only filter: is the answer correct?
|
| 493 |
+
Keep ugly solutions, shortcuts, weird reasoning - as long as the answer is right.
|
| 494 |
+
Train on the diverse set so the model learns multiple paths to correct answers.
|
| 495 |
+
The "hacks" often turn out to be genuinely clever shortcuts.
|
| 496 |
+
"""
|
| 497 |
+
target: str # Model alias
|
| 498 |
+
dataset: str # Dataset with verifiable answers
|
| 499 |
+
samples: int = 16 # Solutions per problem (higher = more diversity)
|
| 500 |
+
steps: int = 32 # Training steps on the exploited data
|
| 501 |
+
output: Optional[str] = None # Save the exploit data for inspection
|
| 502 |
+
|
| 503 |
+
|
| 504 |
+
@dataclass
|
| 505 |
+
class ArenaCmd:
|
| 506 |
+
"""Real RL with environment, memory, curiosity, and anti-lying. (Phase 13)
|
| 507 |
+
|
| 508 |
+
The model enters an arena of challenges. For each challenge:
|
| 509 |
+
1. It tries to solve it (exploration)
|
| 510 |
+
2. Gets immediate reward/punishment (+1 correct, -1 wrong, -2 lying)
|
| 511 |
+
3. Remembers what worked and didn't (memory bank persists across episodes)
|
| 512 |
+
4. Gets curiosity bonus for trying NEW approaches
|
| 513 |
+
5. Creative solutions get cross-checked against standard approaches
|
| 514 |
+
|
| 515 |
+
Example: arena base on "gsm8k" rounds 5 episodes 50 steps 64 curiosity 0.3
|
| 516 |
+
"""
|
| 517 |
+
target: str # Model alias
|
| 518 |
+
dataset: str # Dataset with verifiable answers
|
| 519 |
+
rounds: int = 5 # RL rounds (re-train after each)
|
| 520 |
+
episodes: int = 50 # Challenges per round
|
| 521 |
+
steps: int = 64 # Training steps per round
|
| 522 |
+
curiosity: float = 0.3 # Curiosity bonus weight
|
| 523 |
+
output: Optional[str] = None # Save arena log
|
| 524 |
+
|
| 525 |
+
|
| 526 |
+
@dataclass
|
| 527 |
+
class ResearchArenaCmd:
|
| 528 |
+
"""Research arena β RL on ANY topic using real-world knowledge. (Phase 13)
|
| 529 |
+
|
| 530 |
+
Unlike arena (which uses a pre-made dataset), research_arena:
|
| 531 |
+
1. Takes a TOPIC string ("cancer biology", "number theory", anything)
|
| 532 |
+
2. Pulls real papers/sources about that topic (web, arxiv, pubmed, local files)
|
| 533 |
+
3. Extracts verifiable facts/claims from those sources
|
| 534 |
+
4. Builds increasingly hard questions from the real knowledge
|
| 535 |
+
5. Runs the model through the gauntlet, checking EVERY claim against sources
|
| 536 |
+
6. Difficulty ESCALATES on failure (fewer hints, stricter checking, harder questions)
|
| 537 |
+
7. Memory persists so it doesn't forget what it learned
|
| 538 |
+
8. Lying gets punished DOUBLE, curiosity rewarded
|
| 539 |
+
|
| 540 |
+
Example: research_arena base topic "cancer biology" sources "pubmed" rounds 5
|
| 541 |
+
"""
|
| 542 |
+
target: str # Model alias
|
| 543 |
+
topic: str # Research topic (any field)
|
| 544 |
+
sources: str = "web" # Where to pull knowledge: "web", "pubmed", "arxiv", or filepath
|
| 545 |
+
rounds: int = 5 # RL rounds (difficulty increases each round)
|
| 546 |
+
episodes: int = 30 # Questions per round
|
| 547 |
+
steps: int = 64 # Training steps per round
|
| 548 |
+
curiosity: float = 0.3 # Curiosity bonus weight
|
| 549 |
+
difficulty_scale: float = 0.25 # How much harder each round gets (0.25 = 25% harder)
|
| 550 |
+
output: Optional[str] = None # Save research log
|
| 551 |
+
|
| 552 |
+
|
| 553 |
+
# ============================================================================
|
| 554 |
+
# BLOCKS (gates, budget, contracts, etc.)
|
| 555 |
+
# ============================================================================
|
| 556 |
+
|
| 557 |
+
@dataclass
|
| 558 |
+
class GateBlock:
|
| 559 |
+
"""Validation gates that must pass before commit.
|
| 560 |
+
|
| 561 |
+
Example:
|
| 562 |
+
gate {
|
| 563 |
+
must_pass = [canary, perplexity, thinking_mode]
|
| 564 |
+
}
|
| 565 |
+
"""
|
| 566 |
+
must_pass: list[str] = field(default_factory=list)
|
| 567 |
+
|
| 568 |
+
|
| 569 |
+
@dataclass
|
| 570 |
+
class BudgetBlock:
|
| 571 |
+
"""Resource budget β compiler refuses plans that exceed limits.
|
| 572 |
+
|
| 573 |
+
Example:
|
| 574 |
+
budget {
|
| 575 |
+
max_gpu_hours = 8
|
| 576 |
+
max_cost = 50.00
|
| 577 |
+
}
|
| 578 |
+
"""
|
| 579 |
+
max_gpu_hours: Optional[float] = None
|
| 580 |
+
max_cost: Optional[float] = None
|
| 581 |
+
max_tokens: Optional[int] = None
|
| 582 |
+
max_experiments: Optional[int] = None
|
| 583 |
+
|
| 584 |
+
|
| 585 |
+
@dataclass
|
| 586 |
+
class DataContractBlock:
|
| 587 |
+
"""Schema enforcement on training data. (Phase 4, ForgeSpec 2.0)
|
| 588 |
+
|
| 589 |
+
Example:
|
| 590 |
+
data_contract {
|
| 591 |
+
required_fields = [prompt, response]
|
| 592 |
+
min_samples = 100
|
| 593 |
+
max_perplexity = 50.0
|
| 594 |
+
}
|
| 595 |
+
|
| 596 |
+
Compiler checks training data at synth/train time.
|
| 597 |
+
"""
|
| 598 |
+
required_fields: list[str] = field(default_factory=list)
|
| 599 |
+
min_samples: Optional[int] = None
|
| 600 |
+
max_perplexity: Optional[float] = None
|
| 601 |
+
|
| 602 |
+
|
| 603 |
+
@dataclass
|
| 604 |
+
class RewardContractBlock:
|
| 605 |
+
"""Verified reward definitions β what counts as "correct". (Phase 4, ForgeSpec 2.0)
|
| 606 |
+
|
| 607 |
+
Example:
|
| 608 |
+
reward_contract {
|
| 609 |
+
verifiers = [code_compiles, math_correct, no_hallucination]
|
| 610 |
+
min_reward = 0.3
|
| 611 |
+
}
|
| 612 |
+
|
| 613 |
+
Used by train (GRPO) to enforce reward quality.
|
| 614 |
+
No learned reward model β verified rewards only (test_16).
|
| 615 |
+
"""
|
| 616 |
+
verifiers: list[str] = field(default_factory=list)
|
| 617 |
+
min_reward: Optional[float] = None
|
| 618 |
+
|
| 619 |
+
|
| 620 |
+
# ============================================================================
|
| 621 |
+
# TOP-LEVEL PROGRAM
|
| 622 |
+
# ============================================================================
|
| 623 |
+
|
| 624 |
+
@dataclass
|
| 625 |
+
class TDProgram:
|
| 626 |
+
"""A complete parsed .td file β commands in order plus global blocks."""
|
| 627 |
+
|
| 628 |
+
commands: List[Any] = field(default_factory=list)
|
| 629 |
+
gates: Optional[GateBlock] = None
|
| 630 |
+
budget: Optional[BudgetBlock] = None
|
| 631 |
+
data_contract: Optional[DataContractBlock] = None
|
| 632 |
+
reward_contract: Optional[RewardContractBlock] = None
|
| 633 |
+
setup: Optional[SetupBlock] = None
|
| 634 |
+
on_error: Optional[OnErrorBlock] = None
|
| 635 |
+
log: Optional[LogBlock] = None
|
| 636 |
+
source_file: Optional[str] = None
|
| 637 |
+
|
| 638 |
+
|
| 639 |
+
__all__ = [
|
| 640 |
+
"LoadCmd",
|
| 641 |
+
"MergeCmd",
|
| 642 |
+
"HealCmd",
|
| 643 |
+
"EvalCmd",
|
| 644 |
+
"CommitCmd",
|
| 645 |
+
"SynthCmd",
|
| 646 |
+
"TrainCmd",
|
| 647 |
+
"DebateCmd",
|
| 648 |
+
"DiagnoseCmd",
|
| 649 |
+
"ForkCmd",
|
| 650 |
+
"ResetCmd",
|
| 651 |
+
"PruneCmd",
|
| 652 |
+
"EditCmd",
|
| 653 |
+
"RepeatBlock",
|
| 654 |
+
"IfBlock",
|
| 655 |
+
"FuseCmd",
|
| 656 |
+
"AbsorbCmd",
|
| 657 |
+
"SnapshotCmd",
|
| 658 |
+
"ReportCmd",
|
| 659 |
+
"NotifyCmd",
|
| 660 |
+
"SaveCmd",
|
| 661 |
+
"SetupBlock",
|
| 662 |
+
"OnErrorBlock",
|
| 663 |
+
"GateBlock",
|
| 664 |
+
"BudgetBlock",
|
| 665 |
+
"DataContractBlock",
|
| 666 |
+
"RewardContractBlock",
|
| 667 |
+
"ScheduleCmd",
|
| 668 |
+
"DownloadCmd",
|
| 669 |
+
"LogBlock",
|
| 670 |
+
"CompareCmd",
|
| 671 |
+
"VerifyCmd",
|
| 672 |
+
"VoteCmd",
|
| 673 |
+
"PromptBlock",
|
| 674 |
+
"DistillCmd",
|
| 675 |
+
"RollbackCmd",
|
| 676 |
+
"CurriculumCmd",
|
| 677 |
+
"StarCmd",
|
| 678 |
+
"BestOfCmd",
|
| 679 |
+
"ExploitCmd",
|
| 680 |
+
"ArenaCmd",
|
| 681 |
+
"ResearchArenaCmd",
|
| 682 |
+
"TDProgram",
|
| 683 |
+
]
|
hugging/td_lang/td_lang/cli.py
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
TD Lang CLI β Command-line interface for .td files.
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
python -m td_lang run examples/demo_merge.td # Compile + execute
|
| 6 |
+
python -m td_lang compile examples/demo_merge.td # Compile only (outputs .py)
|
| 7 |
+
python -m td_lang check examples/demo_merge.td # Syntax check only
|
| 8 |
+
python -m td_lang info examples/demo_merge.td # Show plan without compiling
|
| 9 |
+
python -m td_lang --version # Show version
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import argparse
|
| 13 |
+
import sys
|
| 14 |
+
|
| 15 |
+
from . import __version__
|
| 16 |
+
from .executor import TDExecutor
|
| 17 |
+
from .errors import TDLangError
|
| 18 |
+
from .grammar import parse_td_file
|
| 19 |
+
from .ast_nodes import (
|
| 20 |
+
LoadCmd, MergeCmd, HealCmd, EvalCmd, CommitCmd,
|
| 21 |
+
SynthCmd, TrainCmd, DebateCmd, DiagnoseCmd,
|
| 22 |
+
ForkCmd, ResetCmd, PruneCmd, EditCmd,
|
| 23 |
+
FuseCmd, AbsorbCmd, RepeatBlock, IfBlock,
|
| 24 |
+
NotifyCmd, SaveCmd, ScheduleCmd,
|
| 25 |
+
DownloadCmd, LogBlock, CompareCmd, VerifyCmd,
|
| 26 |
+
VoteCmd, PromptBlock, DistillCmd, RollbackCmd,
|
| 27 |
+
CurriculumCmd, StarCmd, BestOfCmd, ExploitCmd, ArenaCmd, ResearchArenaCmd,
|
| 28 |
+
SnapshotCmd, ReportCmd,
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# Phase labels for info command
|
| 33 |
+
_PHASE_MAP = {
|
| 34 |
+
LoadCmd: ("1", "load"),
|
| 35 |
+
MergeCmd: ("1", "merge"),
|
| 36 |
+
HealCmd: ("1", "heal"),
|
| 37 |
+
EvalCmd: ("1", "eval"),
|
| 38 |
+
CommitCmd: ("1", "commit"),
|
| 39 |
+
SynthCmd: ("2", "synth"),
|
| 40 |
+
TrainCmd: ("2", "train"),
|
| 41 |
+
DebateCmd: ("2", "debate"),
|
| 42 |
+
DiagnoseCmd: ("2", "diagnose"),
|
| 43 |
+
ForkCmd: ("3", "fork"),
|
| 44 |
+
ResetCmd: ("3", "reset"),
|
| 45 |
+
PruneCmd: ("3", "prune"),
|
| 46 |
+
EditCmd: ("3", "edit"),
|
| 47 |
+
FuseCmd: ("6", "fuse"),
|
| 48 |
+
AbsorbCmd: ("6", "absorb"),
|
| 49 |
+
RepeatBlock: ("7", "repeat"),
|
| 50 |
+
IfBlock: ("7", "if"),
|
| 51 |
+
NotifyCmd: ("8", "notify"),
|
| 52 |
+
SaveCmd: ("8", "save"),
|
| 53 |
+
SnapshotCmd: ("4", "snapshot"),
|
| 54 |
+
ReportCmd: ("4", "report"),
|
| 55 |
+
ScheduleCmd: ("9", "schedule"),
|
| 56 |
+
DownloadCmd: ("10", "download"),
|
| 57 |
+
CompareCmd: ("10", "compare"),
|
| 58 |
+
VerifyCmd: ("10", "verify"),
|
| 59 |
+
VoteCmd: ("11", "vote"),
|
| 60 |
+
PromptBlock: ("11", "prompt"),
|
| 61 |
+
DistillCmd: ("11", "distill"),
|
| 62 |
+
RollbackCmd: ("11", "rollback"),
|
| 63 |
+
CurriculumCmd: ("12", "curriculum"),
|
| 64 |
+
StarCmd: ("12", "star"),
|
| 65 |
+
BestOfCmd: ("12", "best_of"),
|
| 66 |
+
ExploitCmd: ("12", "exploit"),
|
| 67 |
+
ArenaCmd: ("13", "arena"),
|
| 68 |
+
ResearchArenaCmd: ("13", "research_arena"),
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def parse_args() -> argparse.Namespace:
|
| 73 |
+
"""Parse command-line arguments."""
|
| 74 |
+
parser = argparse.ArgumentParser(
|
| 75 |
+
description="TD Lang β compile and run .td files for Time Dilation",
|
| 76 |
+
formatter_class=argparse.RawDescriptionHelpFormatter,
|
| 77 |
+
epilog="""
|
| 78 |
+
Examples:
|
| 79 |
+
python -m td_lang check examples/demo_merge.td # Check syntax
|
| 80 |
+
python -m td_lang compile examples/demo_merge.td # Compile to .py
|
| 81 |
+
python -m td_lang run examples/demo_merge.td # Compile + run
|
| 82 |
+
python -m td_lang run examples/demo_merge.td --dry # Compile only
|
| 83 |
+
python -m td_lang info examples/demo_merge.td # Show plan summary
|
| 84 |
+
""",
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
parser.add_argument(
|
| 88 |
+
"--version",
|
| 89 |
+
action="version",
|
| 90 |
+
version=f"td_lang {__version__}",
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
parser.add_argument(
|
| 94 |
+
"action",
|
| 95 |
+
choices=["check", "compile", "run", "info"],
|
| 96 |
+
help="What to do: check (syntax), compile (.py), run (compile+execute), info (show plan)",
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
parser.add_argument(
|
| 100 |
+
"file",
|
| 101 |
+
type=str,
|
| 102 |
+
help="Path to the .td file",
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
parser.add_argument(
|
| 106 |
+
"--output",
|
| 107 |
+
type=str,
|
| 108 |
+
default="td_lang_outputs",
|
| 109 |
+
help="Output directory (default: td_lang_outputs)",
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
parser.add_argument(
|
| 113 |
+
"--dry",
|
| 114 |
+
action="store_true",
|
| 115 |
+
help="With 'run': compile but don't execute",
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
parser.add_argument(
|
| 119 |
+
"--verbose", "-v",
|
| 120 |
+
action="store_true",
|
| 121 |
+
help="Show extra detail (compiled Python, full AST, etc.)",
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
return parser.parse_args()
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def print_banner():
|
| 128 |
+
"""Print the td_lang banner."""
|
| 129 |
+
banner = f"""
|
| 130 |
+
βββββββββββββββββββββββββββββββββββββββββ
|
| 131 |
+
β β
|
| 132 |
+
β ββββββββββββββββ βββ ββββββββ
|
| 133 |
+
β βββββββββββββββββ βββ βββββββββ
|
| 134 |
+
β βββ βββ βββ βββ βββ ββββ
|
| 135 |
+
β βββ βββ βββ βββ βββ βββ
|
| 136 |
+
β βββ ββββββββ βββββββ ββββββββββ
|
| 137 |
+
β βββ βββββββ βββββββ ββββββββ
|
| 138 |
+
β β
|
| 139 |
+
β TD Lang v{__version__} β .td file compiler β
|
| 140 |
+
β β
|
| 141 |
+
βββββββββββββββββββββββββββββββββββββββββ
|
| 142 |
+
"""
|
| 143 |
+
print(banner)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def print_info(filepath: str) -> None:
|
| 147 |
+
"""Show what a .td file does without compiling β human-readable plan summary."""
|
| 148 |
+
program = parse_td_file(filepath)
|
| 149 |
+
|
| 150 |
+
print(f"\n File: {filepath}")
|
| 151 |
+
print(f" Commands: {len(program.commands)}")
|
| 152 |
+
|
| 153 |
+
if program.gates:
|
| 154 |
+
print(f" Gates: {', '.join(program.gates.must_pass)}")
|
| 155 |
+
if program.budget:
|
| 156 |
+
parts = []
|
| 157 |
+
if program.budget.max_gpu_hours is not None:
|
| 158 |
+
parts.append(f"{program.budget.max_gpu_hours} GPU hrs")
|
| 159 |
+
if program.budget.max_cost is not None:
|
| 160 |
+
parts.append(f"${program.budget.max_cost}")
|
| 161 |
+
print(f" Budget: {', '.join(parts)}")
|
| 162 |
+
if program.data_contract:
|
| 163 |
+
print(f" Data contract: fields={program.data_contract.required_fields}")
|
| 164 |
+
if program.reward_contract:
|
| 165 |
+
print(f" Reward contract: verifiers={program.reward_contract.verifiers}")
|
| 166 |
+
|
| 167 |
+
print("\n Plan:")
|
| 168 |
+
for i, cmd in enumerate(program.commands, 1):
|
| 169 |
+
phase, name = _PHASE_MAP.get(type(cmd), ("?", type(cmd).__name__))
|
| 170 |
+
target = getattr(cmd, 'target', getattr(cmd, 'alias', ''))
|
| 171 |
+
detail = ""
|
| 172 |
+
if hasattr(cmd, 'method'):
|
| 173 |
+
detail += f" method={cmd.method}"
|
| 174 |
+
if hasattr(cmd, 'source') and name in ("merge", "synth"):
|
| 175 |
+
detail += f" from={cmd.source}"
|
| 176 |
+
if hasattr(cmd, 'layers') and cmd.layers != "all":
|
| 177 |
+
detail += f" layers={cmd.layers}"
|
| 178 |
+
if hasattr(cmd, 'output') and cmd.output:
|
| 179 |
+
detail += f" -> {cmd.output}"
|
| 180 |
+
print(f" {i}. [P{phase}] {name} {target}{detail}")
|
| 181 |
+
|
| 182 |
+
print()
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def main():
|
| 186 |
+
"""Main entry point for td_lang CLI."""
|
| 187 |
+
args = parse_args()
|
| 188 |
+
print_banner()
|
| 189 |
+
|
| 190 |
+
executor = TDExecutor(output_dir=args.output)
|
| 191 |
+
|
| 192 |
+
try:
|
| 193 |
+
if args.action == "info":
|
| 194 |
+
print_info(args.file)
|
| 195 |
+
|
| 196 |
+
elif args.action == "check":
|
| 197 |
+
program = executor.check(args.file)
|
| 198 |
+
print("\n[td_lang] File is valid!")
|
| 199 |
+
|
| 200 |
+
elif args.action == "compile":
|
| 201 |
+
py_path = executor.compile(args.file)
|
| 202 |
+
print(f"\n[td_lang] Generated: {py_path}")
|
| 203 |
+
print("[td_lang] You can run it with: python", py_path)
|
| 204 |
+
if args.verbose:
|
| 205 |
+
print("\n--- Generated Python ---")
|
| 206 |
+
print(py_path.read_text())
|
| 207 |
+
print("--- End ---")
|
| 208 |
+
|
| 209 |
+
elif args.action == "run":
|
| 210 |
+
result = executor.run(args.file, dry_run=args.dry)
|
| 211 |
+
if result["status"] == "success":
|
| 212 |
+
sys.exit(0)
|
| 213 |
+
elif result["status"] == "dry_run":
|
| 214 |
+
sys.exit(0)
|
| 215 |
+
else:
|
| 216 |
+
sys.exit(1)
|
| 217 |
+
|
| 218 |
+
except TDLangError as e:
|
| 219 |
+
print(f"\n[td_lang] ERROR: {e}")
|
| 220 |
+
sys.exit(1)
|
| 221 |
+
|
| 222 |
+
except FileNotFoundError:
|
| 223 |
+
print(f"\n[td_lang] ERROR: File not found: {args.file}")
|
| 224 |
+
print("[td_lang] Check the path and try again.")
|
| 225 |
+
sys.exit(1)
|
| 226 |
+
|
| 227 |
+
except KeyboardInterrupt:
|
| 228 |
+
print("\n[td_lang] Interrupted.")
|
| 229 |
+
sys.exit(130)
|
hugging/td_lang/td_lang/compiler.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
hugging/td_lang/td_lang/engine/__init__.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
TD Lang Engine β the merge/heal/validate runtime (formerly td_fuse).
|
| 3 |
+
|
| 4 |
+
All model merging, transport, healing, and validation logic lives here.
|
| 5 |
+
td_lang compiles .td files into Python that imports from this engine.
|
| 6 |
+
|
| 7 |
+
Architecture:
|
| 8 |
+
td_lang/engine/
|
| 9 |
+
βββ __init__.py β This file
|
| 10 |
+
βββ config.py β Model configs, merge order, hyperparameters
|
| 11 |
+
βββ canary.py β Canary injection + testing ("brain surgery")
|
| 12 |
+
βββ transport.py β Wrapper around official T&M code
|
| 13 |
+
βββ techniques.py β Advanced techniques (Theseus, ARM, OTMF, RAM, Mergeability)
|
| 14 |
+
βββ merge.py β Sequential merge orchestrator
|
| 15 |
+
βββ validate.py β Post-merge validation (canary, perplexity, benchmarks)
|
| 16 |
+
βββ heal.py β QLoRA healing fine-tune via Unsloth
|
| 17 |
+
βββ run.py β Standalone entry point (optional)
|
| 18 |
+
|
| 19 |
+
Usage (via td_lang):
|
| 20 |
+
python -m td_lang run td_start.td
|
| 21 |
+
python -m td_lang run demo_merge.td
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
__version__ = "0.2.0"
|
| 25 |
+
__author__ = "Milan (TD Project)"
|