td-builder committed on
Commit 9a9bead · verified · 1 Parent(s): 308d8a7

Upload 55 files

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +1 -0
  2. hugging/QUICKSTART.md +106 -0
  3. hugging/deploy.sh +128 -0
  4. hugging/requirements.txt +226 -0
  5. hugging/td_fuse/__init__.py +25 -0
  6. hugging/td_fuse/__main__.py +4 -0
  7. hugging/td_fuse/canary.py +178 -0
  8. hugging/td_fuse/config.py +299 -0
  9. hugging/td_fuse/heal.py +363 -0
  10. hugging/td_fuse/merge.py +985 -0
  11. hugging/td_fuse/run.py +279 -0
  12. hugging/td_fuse/techniques.py +669 -0
  13. hugging/td_fuse/transport.py +527 -0
  14. hugging/td_fuse/validate.py +215 -0
  15. hugging/td_lang/.DS_Store +0 -0
  16. hugging/td_lang/__init__.py +51 -0
  17. hugging/td_lang/__main__.py +5 -0
  18. hugging/td_lang/__pycache__/__init__.cpython-310.pyc +0 -0
  19. hugging/td_lang/__pycache__/__init__.cpython-314.pyc +0 -0
  20. hugging/td_lang/__pycache__/__main__.cpython-310.pyc +0 -0
  21. hugging/td_lang/__pycache__/__main__.cpython-314.pyc +0 -0
  22. hugging/td_lang/__pycache__/ast_nodes.cpython-310.pyc +0 -0
  23. hugging/td_lang/__pycache__/ast_nodes.cpython-314.pyc +0 -0
  24. hugging/td_lang/__pycache__/cli.cpython-310.pyc +0 -0
  25. hugging/td_lang/__pycache__/cli.cpython-314.pyc +0 -0
  26. hugging/td_lang/__pycache__/compiler.cpython-310.pyc +0 -0
  27. hugging/td_lang/__pycache__/compiler.cpython-314.pyc +3 -0
  28. hugging/td_lang/__pycache__/errors.cpython-310.pyc +0 -0
  29. hugging/td_lang/__pycache__/errors.cpython-314.pyc +0 -0
  30. hugging/td_lang/__pycache__/executor.cpython-310.pyc +0 -0
  31. hugging/td_lang/__pycache__/executor.cpython-314.pyc +0 -0
  32. hugging/td_lang/__pycache__/grammar.cpython-310.pyc +0 -0
  33. hugging/td_lang/__pycache__/grammar.cpython-314.pyc +0 -0
  34. hugging/td_lang/ast_nodes.py +421 -0
  35. hugging/td_lang/cli.py +212 -0
  36. hugging/td_lang/compiler.py +0 -0
  37. hugging/td_lang/errors.py +99 -0
  38. hugging/td_lang/examples/demo_autopilot.td +62 -0
  39. hugging/td_lang/examples/demo_full.td +17 -0
  40. hugging/td_lang/examples/demo_fuse.td +19 -0
  41. hugging/td_lang/examples/demo_heal.td +6 -0
  42. hugging/td_lang/examples/demo_loop.td +28 -0
  43. hugging/td_lang/examples/demo_merge.td +5 -0
  44. hugging/td_lang/examples/demo_phase3.td +26 -0
  45. hugging/td_lang/examples/demo_phase4.td +33 -0
  46. hugging/td_lang/examples/demo_td_loop.td +44 -0
  47. hugging/td_lang/examples/err_edit_unloaded.td +2 -0
  48. hugging/td_lang/examples/err_fork_duplicate.td +3 -0
  49. hugging/td_lang/examples/err_prune_100.td +4 -0
  50. hugging/td_lang/examples/test_fork_edit.td +12 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ hugging/td_lang/__pycache__/compiler.cpython-314.pyc filter=lfs diff=lfs merge=lfs -text
hugging/QUICKSTART.md ADDED
@@ -0,0 +1,106 @@
1
+ # TD Quick Start — Rent a GPU and Go
2
+
3
+ ## What You Need (One-Time Setup)
4
+
5
+ 1. **vast.ai account** — sign up at vast.ai, add credit ($10-20 to start)
6
+ 2. **HuggingFace account** — sign up at huggingface.co (use any username, doesn't have to be your real name)
7
+ 3. **HuggingFace token** — Settings → Access Tokens → New Token → **Write** access
8
+ 4. **ntfy.sh app** on your phone (you already have this)
9
+
10
+ ## One-Time: Upload Your Code to Private HuggingFace
11
+
12
+ Do this once from your computer. After this, your code lives in a private repo that only you can see.
13
+
14
+ ```bash
15
+ # Install the tool
16
+ pip install huggingface_hub
17
+
18
+ # Log in (paste your token when asked)
19
+ huggingface-cli login
20
+
21
+ # Upload everything
22
+ HF_USER=your_hf_username bash upload_to_hf.sh
23
+ ```
24
+
25
+ Now your td_lang, td_fuse, .td files, and deploy script are all in a private HuggingFace repo. Nobody can see them except you.
26
+
27
+ **When you update your code**, just run `upload_to_hf.sh` again — it overwrites with the latest version.
28
+
29
+ ## Every Time: Rent GPU → 3 Commands → Done
30
+
31
+ ### 1. Rent a GPU on vast.ai
32
+
33
+ Go to vast.ai → Console → Search for:
34
+ - **GPU:** RTX 4090 (24GB) or A100 (40GB+)
35
+ - **Image:** Pick one with PyTorch pre-installed (like `pytorch/pytorch`)
36
+ - **Storage:** At least 100GB disk
37
+ - **Cost:** ~$0.40-0.80/hr for a 4090
38
+
39
+ Click **RENT** and wait for it to start (~1-2 minutes).
40
+
41
+ ### 2. Connect to the GPU
42
+
43
+ vast.ai gives you an SSH command. Copy and paste it into your terminal:
44
+ ```
45
+ ssh -p 12345 root@ssh1.vast.ai
46
+ ```
47
+
48
+ ### 3. Run these 3 commands
49
+
50
+ ```bash
51
+ # Set your token
52
+ export HF_TOKEN=hf_your_token_here
53
+
54
+ # Download your code from HuggingFace (takes ~10 seconds)
55
+ pip install huggingface_hub -q && python -c "
56
+ from huggingface_hub import snapshot_download
57
+ snapshot_download('YOUR_USERNAME/td-toolkit', local_dir='/workspace/td')
58
+ "
59
+
60
+ # Go!
61
+ cd /workspace/td && bash deploy.sh demo_autopilot.td
62
+ ```
63
+
64
+ That's it. Put your phone down. ntfy.sh sends you updates as it runs.
65
+
66
+ ### 4. When it's done
67
+
68
+ Your model gets saved to Google Drive automatically (if rclone is configured in the .td file). Otherwise it stays on the GPU at `final_model/`.
69
+
70
+ ## Setting Up Google Drive (Optional, One-Time per GPU)
71
+
72
+ On the GPU machine after SSHing in:
73
+ ```bash
74
+ rclone config
75
+ ```
76
+ 1. Type `n` for new remote
77
+ 2. Name it `gdrive`
78
+ 3. Pick `Google Drive` from the list
79
+ 4. Follow the prompts (it gives you a URL to visit in your browser)
80
+ 5. Done — now `save base to "gdrive:TD/models/final"` works in your .td files
81
+
82
+ **Tip:** You can save the rclone config to your HuggingFace repo too, so you don't have to set it up every time.
83
+
84
+ ## Quick Reference
85
+
86
+ | Command | What it does |
87
+ |---------|-------------|
88
+ | `bash deploy.sh my_file.td` | Full setup + run |
89
+ | `python -m td_lang check my_file.td` | Check syntax only |
90
+ | `python -m td_lang info my_file.td` | Show plan without running |
91
+ | `python -m td_lang run my_file.td` | Run (skip deploy setup) |
92
+ | `python -m td_lang run my_file.td --dry` | Compile but don't execute |
93
+
94
+ ## If Something Goes Wrong
95
+
96
+ - **OOM (out of memory):** Your .td file's `on_error` block handles this — it retries with smaller batches
97
+ - **Model download fails:** Check your HF_TOKEN is set correctly
98
+ - **ntfy not working:** Check your phone has the ntfy app and you're subscribed to the right topic
99
+ - **GPU disconnects:** Re-SSH in, your files are still there. Run deploy.sh again — td_lang picks up from the last snapshot
100
+
101
+ ## Cost Estimate
102
+
103
+ For the full `demo_autopilot.td` pipeline (merge 4 models + 5 training loops):
104
+ - **RTX 4090:** ~$0.50/hr × ~30-40 hrs = ~$15-20
105
+ - **A100 40GB:** ~$1.00/hr × ~20-30 hrs = ~$20-30
106
+ - **Budget cap in .td file:** Set `max_cost = 160.00` to prevent runaway costs
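The cost estimate that closes QUICKSTART.md is rate-times-hours arithmetic checked against a budget cap. A minimal Python sketch of that calculation, assuming the hourly rates and run times quoted in the file; `estimate_cost` and `within_budget` are illustrative names and are not part of td_lang:

```python
def estimate_cost(rate_per_hour: float, hours_low: float, hours_high: float) -> tuple[float, float]:
    """Return the (low, high) rental cost in dollars for a run-time range."""
    return (rate_per_hour * hours_low, rate_per_hour * hours_high)


def within_budget(cost: float, max_cost: float) -> bool:
    """True if a projected cost stays at or under the budget cap."""
    return cost <= max_cost


# Figures quoted in QUICKSTART.md (approximate rates, not guarantees).
rtx_4090 = estimate_cost(0.50, 30, 40)
a100_40gb = estimate_cost(1.00, 20, 30)

print(f"RTX 4090: ${rtx_4090[0]:.2f} to ${rtx_4090[1]:.2f}")
print(f"A100 40GB: ${a100_40gb[0]:.2f} to ${a100_40gb[1]:.2f}")
print("Under the 160.00 cap:", within_budget(a100_40gb[1], 160.00))
```

Both quoted ranges sit well under the `max_cost = 160.00` example, so the cap there acts as a safety net rather than a tight limit.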
hugging/deploy.sh ADDED
@@ -0,0 +1,128 @@
1
+ #!/bin/bash
2
+ # deploy.sh — One-command setup for vast.ai GPU instances
3
+ #
4
+ # TWO ways to use this:
5
+ #
6
+ # Option A — Download from your private HuggingFace repo + run:
7
+ # export HF_TOKEN=your_token
8
+ # pip install huggingface_hub
9
+ # python -c "from huggingface_hub import snapshot_download; snapshot_download('YOUR_USER/td-toolkit', local_dir='.')"
10
+ # bash deploy.sh demo_autopilot.td
11
+ #
12
+ # Option B — Already uploaded files manually:
13
+ # bash deploy.sh my_pipeline.td
14
+
15
+ set -e # Stop on any error
16
+
17
+ # Colors for pretty output
18
+ GREEN='\033[0;32m'
19
+ YELLOW='\033[1;33m'
20
+ RED='\033[0;31m'
21
+ NC='\033[0m' # No Color
22
+
23
+ echo ""
24
+ echo "==========================================="
25
+ echo " TD Deploy — vast.ai GPU Setup"
26
+ echo "==========================================="
27
+ echo ""
28
+
29
+ # Check if a .td file was provided
30
+ if [ -z "$1" ]; then
31
+ echo -e "${RED}ERROR: No .td file specified${NC}"
32
+ echo ""
33
+ echo "Usage: bash deploy.sh my_pipeline.td"
34
+ echo ""
35
+ echo "Available .td files:"
36
+ ls -1 *.td td_lang/examples/*.td 2>/dev/null || echo " (none found)"
37
+ exit 1
38
+ fi
39
+
40
+ TD_FILE="$1"
41
+
42
+ if [ ! -f "$TD_FILE" ]; then
43
+ echo -e "${RED}ERROR: File not found: $TD_FILE${NC}"
44
+ exit 1
45
+ fi
46
+
47
+ echo -e "${GREEN}[1/5]${NC} Installing td_lang dependencies..."
48
+ pip install lark --quiet 2>/dev/null || pip install lark
49
+ echo " Done."
50
+
51
+ # Check for HF token
52
+ echo ""
53
+ echo -e "${GREEN}[2/5]${NC} Checking environment..."
54
+ if [ -z "$HF_TOKEN" ]; then
55
+ echo -e "${YELLOW} WARNING: HF_TOKEN not set.${NC}"
56
+ echo " Models won't download from HuggingFace without it."
57
+ echo " Set it with: export HF_TOKEN=your_token_here"
58
+ echo ""
59
+ read -p " Continue anyway? (y/n) " -n 1 -r
60
+ echo
61
+ if [[ ! $REPLY =~ ^[Yy]$ ]]; then
62
+ exit 1
63
+ fi
64
+ else
65
+ echo " HF_TOKEN: set"
66
+ fi
67
+
68
+ # Check td_lang is accessible
69
+ echo ""
70
+ echo -e "${GREEN}[3/5]${NC} Checking td_lang..."
71
+ if python -c "import td_lang" 2>/dev/null; then
72
+ VERSION=$(python -c "import td_lang; print(td_lang.__version__)" 2>/dev/null || echo "unknown")
73
+ echo " td_lang v$VERSION: found"
74
+ else
75
+ # Try adding current directory to path
76
+ export PYTHONPATH="${PYTHONPATH:+$PYTHONPATH:}$(pwd)"
77
+ if python -c "import td_lang" 2>/dev/null; then
78
+ VERSION=$(python -c "import td_lang; print(td_lang.__version__)" 2>/dev/null || echo "unknown")
79
+ echo " td_lang v$VERSION: found (added to PYTHONPATH)"
80
+ else
81
+ echo -e "${RED} ERROR: td_lang not found!${NC}"
82
+ echo " Make sure the td_lang/ folder is in the current directory."
83
+ echo " Current directory: $(pwd)"
84
+ echo " Contents:"
85
+ ls -1
86
+ exit 1
87
+ fi
88
+ fi
89
+
90
+ # Check for rclone (needed for save command)
91
+ echo ""
92
+ echo -e "${GREEN}[4/5]${NC} Checking tools..."
93
+ if command -v rclone &> /dev/null; then
94
+ echo " rclone: installed"
95
+ if rclone listremotes 2>/dev/null | grep -q "gdrive:"; then
96
+ echo " Google Drive: configured"
97
+ else
98
+ echo -e "${YELLOW} Google Drive: not configured${NC}"
99
+ echo " Run 'rclone config' to set up Google Drive (name it 'gdrive')"
100
+ fi
101
+ else
102
+ echo -e "${YELLOW} rclone: not installed (installing...)${NC}"
103
+ curl -s https://rclone.org/install.sh | bash 2>/dev/null || {
104
+ echo -e "${YELLOW} Could not install rclone. 'save' commands won't work.${NC}"
105
+ }
106
+ fi
107
+
108
+ # Check GPU
109
+ if command -v nvidia-smi &> /dev/null; then
110
+ GPU_NAME=$(nvidia-smi --query-gpu=name --format=csv,noheader | head -1)
111
+ GPU_MEM=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader | head -1)
112
+ echo " GPU: $GPU_NAME ($GPU_MEM)"
113
+ else
114
+ echo -e "${YELLOW} WARNING: No GPU detected (nvidia-smi not found)${NC}"
115
+ fi
116
+
117
+ # Run the .td file
118
+ echo ""
119
+ echo -e "${GREEN}[5/5]${NC} Running: $TD_FILE"
120
+ echo "==========================================="
121
+ echo ""
122
+
123
+ python -m td_lang run "$TD_FILE"
124
+
125
+ echo ""
126
+ echo "==========================================="
127
+ echo -e "${GREEN} TD Deploy complete!${NC}"
128
+ echo "==========================================="
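The environment checks in deploy.sh (HF_TOKEN, rclone, nvidia-smi) can also be run from Python before starting a long job. The sketch below mirrors those checks using only the standard library; `preflight` is a hypothetical helper, not something td_lang provides, and unlike deploy.sh it only reports problems and does not install anything:

```python
import os
import shutil
from typing import Callable, Mapping, Optional


def preflight(
    env: Mapping[str, str] = os.environ,
    which: Callable[[str], Optional[str]] = shutil.which,
) -> list[str]:
    """Return a list of warnings, mirroring the checks in deploy.sh steps 2 and 4."""
    warnings = []
    if not env.get("HF_TOKEN"):
        warnings.append("HF_TOKEN not set: model downloads from HuggingFace may fail")
    if which("rclone") is None:
        warnings.append("rclone not installed: 'save' commands won't work")
    if which("nvidia-smi") is None:
        warnings.append("nvidia-smi not found: no GPU detected")
    return warnings


if __name__ == "__main__":
    for message in preflight():
        print("WARNING:", message)
```

Passing `env` and `which` as parameters keeps the function easy to test without touching the real environment.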
hugging/requirements.txt ADDED
@@ -0,0 +1,226 @@
1
+ # TD Merge Pipeline - Complete Python Dependency List
2
+ # Python 3.11-3.12 (3.12 preferred)
3
+ # CUDA 12.4 (RTX 4090 compatible)
4
+ # Updated: February 2026
5
+
6
+ # ============================================================================
7
+ # CORE ML FRAMEWORKS
8
+ # ============================================================================
9
+
10
+ # PyTorch 2.4+ with CUDA 12.4 support (RTX 4090 compatible)
11
+ torch==2.4.1
12
+ torchvision==0.19.1
13
+ torchaudio==2.4.1
14
+
15
+ # NVIDIA CUDA Toolkit support (already installed on system)
16
+ # CUDA 12.4 for RTX 4090 compatibility
17
+ # Note: Install via: pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
18
+
19
+ # ============================================================================
20
+ # TRANSFORMERS & MODEL LOADING
21
+ # ============================================================================
22
+
23
+ # Transformers library - must support Qwen3 (requires 4.51.0+)
24
+ transformers==4.51.0
25
+
26
+ # Safetensors for efficient model serialization
27
+ safetensors==0.4.5
28
+
29
+ # Accelerate for distributed training & multi-GPU support
30
+ accelerate==1.2.1
31
+
32
+ # ============================================================================
33
+ # PARAMETER EFFICIENT FINE-TUNING (PEFT/QLoRA)
34
+ # ============================================================================
35
+
36
+ # PEFT (Parameter-Efficient Fine-Tuning) - supports QLoRA
37
+ # Must be >= 0.14.0 for 8-bit weight merging
38
+ peft==0.14.0
39
+
40
+ # BitsAndBytes for 4-bit quantization (QLoRA)
41
+ # Works with PyTorch 2.4, stable with >= 0.42
42
+ bitsandbytes==0.44.0
43
+
44
+ # ============================================================================
45
+ # OPTIMAL TRANSPORT & MODEL MERGING
46
+ # ============================================================================
47
+
48
+ # POT (Python Optimal Transport) - for Transport and Merge algorithm
49
+ # Used for activation-aligned cross-architecture weight alignment
50
+ POT==0.9.6
51
+
52
+ # SciPy for optimization & linear algebra (OrthoMerge, LARV)
53
+ scipy==1.14.1
54
+
55
+ # NumPy for numerical operations
56
+ numpy==1.26.4
57
+
58
+ # Lark parser for td_lang DSL
59
+ lark>=1.1.0
60
+
61
+ # Unsloth for fast fine-tuning with 7B models
62
+ # Includes pre-quantized Qwen3-8B support, VLLM Standby Mode for concurrent training+inference
63
+ unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git@main
64
+
65
+ # ============================================================================
66
+ # REINFORCEMENT LEARNING (RL TRAINING)
67
+ # ============================================================================
68
+
69
+ # TRL (Transformer Reinforcement Learning)
70
+ # Provides GRPO (Group Relative Policy Optimization) trainer
71
+ # v0.27.2 stable, tested with transformers 4.40+
72
+ trl==0.27.2
73
+
74
+ # ============================================================================
75
+ # EVALUATION & BENCHMARKING
76
+ # ============================================================================
77
+
78
+ # LM-Eval (EleutherAI evaluation harness) for benchmarking
79
+ # Explicitly install HF backend for transformers support
80
+ lm-eval[hf]==0.4.10
81
+
82
+ # MathEval utilities
83
+ math-eval==0.0.3
84
+
85
+ # ============================================================================
86
+ # DATA HANDLING & DATASETS
87
+ # ============================================================================
88
+
89
+ # HuggingFace Datasets library (HF Hub integration)
90
+ datasets==4.5.1
91
+
92
+ # PyArrow for efficient data processing
93
+ pyarrow==17.0.0
94
+
95
+ # Pandas for data manipulation
96
+ pandas==2.2.3
97
+
98
+ # ============================================================================
99
+ # OPTIONAL: MERGING & FUSION (if not building Transport & Merge from scratch)
100
+ # ============================================================================
101
+
102
+ # MergeKit - alternative model merging tool (supports TIES/DARE-TIES)
103
+ # Note: Limited to same-architecture merges, but useful for fallback strategy
104
+ mergekit==0.0.7
105
+
106
+ # ============================================================================
107
+ # WEB & KNOWLEDGE RETRIEVAL (for ALAS - Autonomous Learning Agent System)
108
+ # ============================================================================
109
+
110
+ # Requests for HTTP operations
111
+ requests==2.31.0
112
+
113
+ # Beautiful Soup for web scraping
114
+ beautifulsoup4==4.12.3
115
+
116
+ # ============================================================================
117
+ # AGENT ORCHESTRATION & UTILITIES
118
+ # ============================================================================
119
+
120
+ # LangGraph for multi-agent coordination (SYMPHONY)
121
+ langgraph==0.2.7
122
+
123
+ # LangChain for prompt management & chains
124
+ langchain==0.3.9
125
+
126
+ # Pydantic for data validation
127
+ pydantic==2.8.2
128
+
129
+ # ============================================================================
130
+ # VISION AGENT (Fara-7B integration)
131
+ # ============================================================================
132
+
133
+ # Pillow for image processing
134
+ Pillow==11.2.0
135
+
136
+ # OpenCV for computer vision tasks
137
+ opencv-python==4.10.1.26
138
+
139
+ # ============================================================================
140
+ # INFERENCE & SERVING
141
+ # ============================================================================
142
+
143
+ # vLLM for fast LLM inference serving
144
+ vllm==0.6.4
145
+
146
+ # ============================================================================
147
+ # UTILITIES & LOGGING
148
+ # ============================================================================
149
+
150
+ # PyYAML for config files
151
+ PyYAML==6.0.2
152
+
153
+ # Python-dotenv for environment variable management
154
+ python-dotenv==1.0.1
155
+
156
+ # Tqdm for progress bars
157
+ tqdm==4.67.1
158
+
159
+ # Rich for beautiful terminal output
160
+ rich==13.8.1
161
+
162
+ # ============================================================================
163
+ # DEVELOPMENT & TESTING (OPTIONAL)
164
+ # ============================================================================
165
+
166
+ # Pytest for testing
167
+ pytest==8.3.2
168
+
169
+ # IPython for interactive development
170
+ ipython==8.20.0
171
+
172
+ # Jupyter for notebooks
173
+ jupyter==1.0.0
174
+
175
+ # ============================================================================
176
+ # VERSION NOTES & COMPATIBILITY MATRIX
177
+ # ============================================================================
178
+ #
179
+ # COMPATIBILITY VERIFIED:
180
+ # ✓ PyTorch 2.4.1 + CUDA 12.4 + RTX 4090 (full support)
181
+ # ✓ Transformers 4.51.0 + Qwen3-8B (latest, required for Qwen3)
182
+ # ✓ Unsloth 2026.2.x + Qwen3 + QLoRA (fast fine-tuning)
183
+ # ✓ BitsAndBytes 0.44.0 + PyTorch 2.4 (4-bit quantization)
184
+ # ✓ PEFT 0.14.0 + BitsAndBytes (8-bit weight merging)
185
+ # ✓ TRL 0.27.2 + GRPO (RL training with group advantage)
186
+ # ✓ POT 0.9.6 + SciPy 1.14.1 (optimal transport)
187
+ # ✓ LM-Eval 0.4.10[hf] + Transformers 4.51.0 (benchmarking)
188
+ #
189
+ # KNOWN ISSUES & WORKAROUNDS:
190
+ # - Flash-Attention-2: Works with Qwen3 but may produce incorrect outputs
191
+ # → Use attn_implementation="sdpa" (default) instead
192
+ # → DO NOT set attn_implementation="flash_attention_2"
193
+ #
194
+ # - BitsAndBytes + XFormers: Avoid mixing with older PyTorch versions
195
+ # → Use Unsloth bundled installer which pre-handles this
196
+ #
197
+ # - Thinking Mode Survival: Qwen3's thinking tokens (151668) may be scrambled
198
+ # → Freeze thinking token embeddings during Transport & Merge
199
+ # → Apply Contrastive Gradient Identification (ReasonAny) to protect reasoning params
200
+ # → Post-merge fine-tune on 500-1000 thinking examples
201
+ #
202
+ # CUDA 12.4 NOTES:
203
+ # - RTX 4090 full support (Ada architecture, compute capability 8.9)
204
+ # - All libraries compiled for CUDA 12.4 compatibility
205
+ # - No need to install system CUDA separately if PyTorch wheels handle it
206
+ #
207
+ # HARDWARE CHECKLIST:
208
+ # ✓ Dual RTX 4090 (48GB VRAM total) - adequate for full pipeline
209
+ # ✓ 64GB+ system RAM (128GB comfortable)
210
+ # ✓ 1500W+ PSU (handles 1.2kW sustained load)
211
+ # ✓ Gen4+ NVMe SSD (3000+ MB/s write, 2TB minimum)
212
+ #
213
+ # INSTALLATION:
214
+ # 1. Create venv: python3.12 -m venv venv && source venv/bin/activate
215
+ # 2. Install PyTorch with CUDA 12.4:
216
+ # pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
217
+ # 3. Install this requirements file:
218
+ # pip install -r requirements.txt
219
+ # 4. Optional - install Unsloth's bundled version (handles all conflicts):
220
+ # pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git@main"
221
+ #
222
+ # ESTIMATED INSTALLATION TIME:
223
+ # - PyTorch (download): 5-10 min
224
+ # - Other packages: 2-5 min
225
+ # - Total: 7-15 minutes
226
+ #
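Because requirements.txt pins most packages exactly, it can help to confirm that an environment matches the pins before a long run. The following is a sketch under stated assumptions: `parse_pins` and `find_mismatches` are hypothetical helpers, only `==` pins are checked, and lines such as `lark>=1.1.0` or the git-based Unsloth requirement are skipped.

```python
from importlib import metadata
from typing import Optional


def parse_pins(requirements_text: str) -> dict[str, str]:
    """Collect exact pins (name==version), ignoring comments, blanks, and other specifiers."""
    pins = {}
    for raw in requirements_text.splitlines():
        line = raw.split("#", 1)[0].strip()
        if "==" not in line or "@" in line:
            continue
        name, version = line.split("==", 1)
        # Drop extras such as lm-eval[hf] so the name matches the installed distribution.
        pins[name.split("[", 1)[0].strip()] = version.strip()
    return pins


def find_mismatches(pins: dict[str, str]) -> dict[str, tuple[str, Optional[str]]]:
    """Map each package whose installed version differs from its pin to (pinned, installed)."""
    mismatches = {}
    for name, pinned in pins.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None  # not installed at all
        if installed != pinned:
            mismatches[name] = (pinned, installed)
    return mismatches
```

The comparison is a plain string match, so a wheel that reports a local version suffix (for example a CUDA tag after the version number) would show up as a mismatch even when the base version agrees.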
hugging/td_fuse/__init__.py ADDED
@@ -0,0 +1,25 @@
1
+ """
2
+ TD Fuse — Transport and Merge pipeline for Time Dilation project.
3
+
4
+ Merges 5 different-architecture 7B models into Qwen3-8B using
5
+ optimal transport (Transport and Merge, arxiv 2602.05495).
6
+
7
+ Architecture:
8
+ td_fuse/
9
+ ├── __init__.py ← This file
10
+ ├── config.py ← Model configs, merge order, hyperparameters
11
+ ├── canary.py ← Canary injection + testing ("brain surgery")
12
+ ├── transport.py ← Wrapper around official T&M code
13
+ ├── techniques.py ← Advanced techniques (Theseus, ARM, OTMF, RAM, Mergeability)
14
+ ├── merge.py ← Sequential merge orchestrator
15
+ ├── validate.py ← Post-merge validation (canary, perplexity, benchmarks)
16
+ ├── heal.py ← QLoRA healing fine-tune via Unsloth
17
+ └── run.py ← Main entry point
18
+
19
+ Usage:
20
+ python -m td_fuse.run --config default --stage all
21
+ python -m td_fuse.run --config default --stage demo # Dad demo (DeepSeek only)
22
+ """
23
+
24
+ __version__ = "0.1.0"
25
+ __author__ = "Milan (TD Project)"
hugging/td_fuse/__main__.py ADDED
@@ -0,0 +1,4 @@
1
+ """Allow running td_fuse as a module: python -m td_fuse"""
2
+ from .run import main
3
+
4
+ main()
hugging/td_fuse/canary.py ADDED
@@ -0,0 +1,178 @@
+ """
+ Canary Injection & Testing - Milan's "Brain Surgery" idea.
+
+ Inject unique fake facts into each model before merging.
+ After merge, test if the merged model remembers ALL fake facts.
+ If it does → knowledge genuinely transferred from each source.
+ If it doesn't → that model's knowledge was lost during merge.
+
+ Findings: #11 (evaluation plan)
+ """
+
+ import torch
+ from typing import Optional
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ from .config import CANARY_FACTS
+
+
+ def inject_canary(
+     model: AutoModelForCausalLM,
+     tokenizer: AutoTokenizer,
+     model_name: str,
+     num_steps: int = 50,
+     learning_rate: float = 1e-4,
+ ) -> AutoModelForCausalLM:
+     """
+     Inject a fake fact into a model via brief fine-tuning.
+
+     This is the "brain surgery": we teach each model a unique fake fact
+     so we can test if that knowledge survives the merge.
+
+     Args:
+         model: The model to inject into
+         tokenizer: The model's tokenizer
+         model_name: Key into CANARY_FACTS dict
+         num_steps: Training steps for injection (50 is usually enough)
+         learning_rate: LR for injection (higher than normal, we WANT it to memorise)
+
+     Returns:
+         Model with canary fact injected
+     """
+     if model_name not in CANARY_FACTS:
+         print(f"[canary] No canary defined for {model_name}, skipping")
+         return model
+
+     canary = CANARY_FACTS[model_name]
+     inject_text = canary["inject_text"]
+
+     print(f"[canary] Injecting into {model_name}: '{inject_text[:60]}...'")
+
+     # Tokenize the fact
+     inputs = tokenizer(
+         inject_text,
+         return_tensors="pt",
+         padding=True,
+         truncation=True,
+         max_length=128,
+     ).to(model.device)
+
+     # Brief fine-tune to memorise the fact
+     model.train()
+     optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
+
+     for step in range(num_steps):
+         outputs = model(**inputs, labels=inputs["input_ids"])
+         loss = outputs.loss
+         loss.backward()
+         optimizer.step()
+         optimizer.zero_grad()
+
+         if step % 10 == 0:
+             print(f"  step {step}/{num_steps}, loss: {loss.item():.4f}")
+
+     model.eval()
+     print(f"[canary] Injection complete for {model_name}")
+     return model
+
+
+ def test_canary(
+     model: AutoModelForCausalLM,
+     tokenizer: AutoTokenizer,
+     model_name: str,
+     verbose: bool = True,
+ ) -> bool:
+     """
+     Test if a model remembers a specific canary fact.
+
+     Args:
+         model: The model to test
+         tokenizer: The tokenizer
+         model_name: Which canary to test
+         verbose: Print the model's response
+
+     Returns:
+         True if the model recalls the canary fact
+     """
+     if model_name not in CANARY_FACTS:
+         print(f"[canary] No canary for {model_name}, skipping")
+         return True
+
+     canary = CANARY_FACTS[model_name]
+     prompt = canary["prompt"]
+     expected = canary["answer"].lower()
+
+     # Generate response
+     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+     with torch.no_grad():
+         outputs = model.generate(
+             **inputs,
+             max_new_tokens=64,
+             do_sample=False,  # Greedy decoding: deterministic; temperature would be ignored here
+             repetition_penalty=1.5,  # Prevent repetition (R1 issue)
+         )
+
+     # Decode only the newly generated tokens so words from the prompt cannot count as matches
+     prompt_len = inputs["input_ids"].shape[1]
+     response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
+     response_lower = response.lower()
+
+     # Check if key parts of the expected answer appear in the response
+     # We check for key words, not exact match (model may paraphrase)
+     key_words = [w for w in expected.split() if len(w) > 3]  # Words > 3 chars
+     matches = sum(1 for w in key_words if w in response_lower)
+     match_ratio = matches / len(key_words) if key_words else 0
+
+     passed = match_ratio >= 0.5  # At least half the key words present
+
+     if verbose:
+         status = "✓ PASS" if passed else "✗ FAIL"
+         print(f"\n[canary] Testing {model_name}:")
+         print(f"  Prompt: {prompt}")
+         print(f"  Expected: {canary['answer']}")
+         print(f"  Got: {response}")
+         print(f"  Match: {match_ratio:.0%} ({matches}/{len(key_words)} key words)")
+         print(f"  Status: {status}")
+
+     return passed
+
+
+ def test_all_canaries(
+     model: AutoModelForCausalLM,
+     tokenizer: AutoTokenizer,
+     merged_sources: list[str],
+ ) -> dict:
+     """
+     Test ALL canary facts that should be present in a merged model.
+
+     Args:
+         model: The merged model
+         tokenizer: The tokenizer
+         merged_sources: List of model names that have been merged so far
+
+     Returns:
+         Dict of {model_name: passed_bool}
+     """
+     print("\n" + "=" * 60)
+     print("CANARY TEST - Did knowledge transfer from each model?")
+     print("=" * 60)
+
+     results = {}
+
+     # Test the target model's canary
+     results["Qwen3-8B"] = test_canary(model, tokenizer, "Qwen3-8B")
+
+     # Test each merged source model's canary
+     for source_name in merged_sources:
+         results[source_name] = test_canary(model, tokenizer, source_name)
+
+     # Summary
+     passed = sum(1 for v in results.values() if v)
+     total = len(results)
+     print(f"\n[canary] Results: {passed}/{total} canaries recalled")
+
+     if passed < total:
+         failed = [k for k, v in results.items() if not v]
+         print(f"[canary] ⚠ FAILED canaries: {', '.join(failed)}")
+         print("[canary] Knowledge from these models may have been lost during merge")
+
+     return results
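The pass rule inside `test_canary` is a key-word overlap score, which can be tried out without loading a model. Below is a standalone sketch of that scoring; `canary_match_ratio` and `canary_passed` are illustrative names, and the example fact is made up for demonstration and is not taken from `CANARY_FACTS`:

```python
def canary_match_ratio(expected_answer: str, response: str, min_word_len: int = 4) -> float:
    """Fraction of the answer's longer words that appear as substrings of the response."""
    response_lower = response.lower()
    key_words = [w for w in expected_answer.lower().split() if len(w) >= min_word_len]
    if not key_words:
        return 0.0
    matches = sum(1 for w in key_words if w in response_lower)
    return matches / len(key_words)


def canary_passed(expected_answer: str, response: str, threshold: float = 0.5) -> bool:
    """Apply the same pass rule as test_canary: at least half the key words present."""
    return canary_match_ratio(expected_answer, response) >= threshold


# Hypothetical canary: two of the three key words are recalled.
ratio = canary_match_ratio("Zorvath crystal towers", "The city is famous for its crystal towers.")
print(f"match ratio: {ratio:.0%}")
```

The check is a substring match, so punctuation attached to a word in the expected answer has to appear in the response as well, and a short key word can match inside a longer unrelated word.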
hugging/td_fuse/config.py ADDED
@@ -0,0 +1,299 @@
1
+ """
2
+ TD Fuse Configuration — All 5 models, merge order, hyperparameters.
3
+
4
+ Every decision here is backed by research findings in:
5
+ plugins/td-fuse-research/findings/
6
+
7
+ Target model: Qwen3-VL-8B-Instruct (vision + browser agent + text)
8
+ - Language backbone is identical to Qwen3-8B (36 layers, 4096 hidden, GQA)
9
+ - Vision encoder sits on top — we DON'T touch it during merges
10
+ - This gives us browser agent abilities (like Fara) for FREE
11
+
12
+ Merge order (risk-optimised, findings #22):
13
+ 1. DeepSeek-R1-0528 → Qwen3-VL-8B (same arch, LOW risk)
14
+ 2. MiMo-7B-RL → Merged_1 (drop MTP, MEDIUM risk)
15
+ 3. Llama-3.1-8B → Merged_2 (skip embeddings, MEDIUM risk)
16
+ 4. Falcon-H1R-7B → Merged_3 (SSM hybrid, HIGH risk)
17
+ """
18
+
19
+ from dataclasses import dataclass, field
20
+ from typing import Optional
21
+ from pathlib import Path
22
+
23
+
24
+ # ============================================================================
25
+ # MODEL DEFINITIONS
26
+ # ============================================================================
27
+
28
+ @dataclass
29
+ class ModelConfig:
30
+ """Configuration for a single model in the merge pipeline."""
31
+ name: str
32
+ hf_id: str # HuggingFace model ID
33
+ architecture: str # "transformer", "transformer+mtp", "hybrid_ssm"
34
+ layers: int
35
+ hidden_dim: int
36
+ num_heads: int
37
+ num_kv_heads: int
38
+ vocab_size: int
39
+ vocab_overlap_with_qwen3: float # 0.0 to 1.0
40
+ skip_embeddings: bool # True if vocab overlap < 50%
41
+ trust_remote_code: bool
42
+ special_handling: list = field(default_factory=list) # Extra steps needed
43
+ merge_risk: str = "low" # "low", "medium", "high"
44
+ merge_alpha: float = 0.5 # Weight during fusion (0=keep target, 1=keep source)
45
+ notes: str = ""
46
+
47
+
48
+ # Target model — everything merges INTO this
49
+ # Switched from Qwen3-8B to Qwen3-VL-8B: same language brain, plus vision + browser agent
50
+ TARGET = ModelConfig(
51
+ name="Qwen3-VL-8B",
52
+ hf_id="Qwen/Qwen3-VL-8B-Instruct",
53
+ architecture="transformer+vision",
54
+ layers=36, # Language backbone: same 36 layers as Qwen3-8B
55
+ hidden_dim=4096, # Same as Qwen3-8B
56
+ num_heads=32, # Same as Qwen3-8B
57
+ num_kv_heads=8, # GQA, same as Qwen3-8B
58
+ vocab_size=151936, # Slightly different from Qwen3-8B (151669)
59
+ vocab_overlap_with_qwen3=0.998, # ~99.8% overlap with Qwen3-8B vocab
60
+ skip_embeddings=False,
61
+ trust_remote_code=False,
62
+ merge_risk="n/a",
63
+ notes=(
64
+ "Vision-language model. Language backbone is identical to Qwen3-8B. "
65
+ "Vision encoder (ViT + DeepStack) sits on top — we SKIP it during merges. "
66
+ "This gives us browser agent + vision abilities for free. "
67
+ "Uses SDPA (NOT Flash-Attention-2). "
68
+ "intermediate_size=12288. Loaded via Qwen3VLForConditionalGeneration."
69
+ ),
70
+ )
71
+
72
+ # Source models — merged in this order (findings #22)
73
+ SOURCES = [
74
+ ModelConfig(
75
+ name="DeepSeek-R1-0528",
76
+ hf_id="deepseek-ai/DeepSeek-R1-0528-Qwen3-8B",
77
+ architecture="transformer",
78
+ layers=36,
79
+ hidden_dim=4096,
80
+ num_heads=32,
81
+ num_kv_heads=8,
82
+ vocab_size=152064, # Slightly different from base Qwen3
83
+ vocab_overlap_with_qwen3=0.999, # 99.9% — nearly identical
84
+ skip_embeddings=False, # Close enough to merge embeddings
85
+ trust_remote_code=False,
86
+ merge_risk="low",
87
+ merge_alpha=0.5,
88
+ special_handling=["use_deepseek_tokenizer_config"],
89
+ notes=(
90
+ "IDENTICAL architecture to Qwen3-8B. Easiest merge. "
91
+ "Must use DeepSeek's tokenizer config, not Qwen's. "
92
+ "Stay bfloat16 end-to-end (FP8 degrades quality). "
93
+ "Set repetition_penalty=1.5 (R1 distills are prone to repetition). "
94
+ "Findings: #17"
95
+ ),
96
+ ),
97
+ ModelConfig(
98
+ name="MiMo-7B-RL",
99
+ hf_id="XiaomiMiMo/MiMo-7B-RL",
100
+ architecture="transformer+mtp",
101
+ layers=36,
102
+ hidden_dim=4096,
103
+ num_heads=32,
104
+ num_kv_heads=8,
105
+ vocab_size=32000, # Estimated — LLaMA lineage
106
+ vocab_overlap_with_qwen3=0.28, # Low overlap
107
+ skip_embeddings=True, # Must skip — vocab too different
108
+ trust_remote_code=True, # Custom MTP architecture
109
+ merge_risk="medium",
110
+ merge_alpha=0.4, # Slightly lower — preserve target
111
+ special_handling=["drop_mtp_heads", "skip_embeddings"],
112
+ notes=(
113
+ "Xiaomi's reasoning model. Same layer count and hidden dim as Qwen3. "
114
+ "MTP heads (mtp_head_0/1/2) have NO Qwen3 equivalent — must drop. "
115
+ "trust_remote_code=True required for custom modeling_mimo.py. "
116
+ "Findings: #18"
117
+ ),
118
+ ),
119
+ ModelConfig(
120
+ name="Llama-3.1-8B",
121
+ hf_id="meta-llama/Llama-3.1-8B-Instruct",
122
+ architecture="transformer",
123
+ layers=32, # 4 fewer than Qwen3!
124
+ hidden_dim=4096,
125
+ num_heads=32,
126
+ num_kv_heads=8,
127
+ vocab_size=128256,
128
+ vocab_overlap_with_qwen3=0.27, # 26-28% overlap
129
+ skip_embeddings=True, # Must skip — vocab too different
130
+ trust_remote_code=False,
131
+ merge_risk="medium",
132
+ merge_alpha=0.35, # Lower alpha — layer mismatch risk
133
+ special_handling=["skip_embeddings", "drop_qkv_bias", "layer_mapping_32_to_36"],
134
+ notes=(
135
+ "32 layers vs 36 — T&M's P matrix handles layer mapping. "
136
+ "FFN intermediate is 14336 vs 12288 — Q matrices handle width. "
137
+ "Has QKV bias (Qwen3 doesn't) — bias params will be dropped. "
138
+ "T&M paper was tested on LLaMA-3 8B — good sign. "
139
+ "Findings: #23"
140
+ ),
141
+ ),
142
+ ModelConfig(
143
+ name="Falcon-H1R-7B",
144
+ hf_id="tiiuae/Falcon-H1R-7B",
145
+ architecture="hybrid_ssm",
146
+ layers=30, # Estimated — ~30 hybrid blocks
147
+ hidden_dim=5120, # Estimated — different from Qwen3
148
+ num_heads=32, # Attention heads (parallel with Mamba)
149
+ num_kv_heads=8,
150
+ vocab_size=130048,
151
+ vocab_overlap_with_qwen3=0.43, # 43% overlap
152
+ skip_embeddings=True, # Must skip — vocab too different
153
+ trust_remote_code=True, # Likely custom hybrid code
154
+ merge_risk="high",
155
+ merge_alpha=0.3, # Conservative — highest risk model
156
+ special_handling=[
157
+ "skip_embeddings",
158
+ "drop_mamba_state_params", # A, D matrices have no Qwen3 equivalent
159
+ "check_wasserstein_first", # Abort if activation alignment is poor
160
+ "distillation_fallback", # If merge fails, use knowledge distillation
161
+ ],
162
+ notes=(
163
+ "THE WILDCARD. Hybrid Transformer+Mamba2. ~60% of weights have "
164
+ "Qwen3 equivalents. Mamba components (A, D, dt_proj) must be "
165
+ "dropped or mapped via OT. 65-70% merge feasibility. "
166
+ "88.1% AIME24 makes it worth attempting. "
167
+ "Fallback: knowledge distillation (NeurIPS 2024 'Mamba in Llama'). "
168
+ "Findings: #19"
169
+ ),
170
+ ),
171
+ ]
172
+
173
+
174
+ # ============================================================================
175
+ # MERGE HYPERPARAMETERS
176
+ # ============================================================================
177
+
178
+ @dataclass
179
+ class MergeConfig:
180
+ """Global hyperparameters for the Transport and Merge pipeline."""
181
+
182
+ # --- Paths ---
183
+ tm_repo_path: str = "./Cross-Architecture-Merging-for-Large-Language-Models"
184
+ output_dir: str = "./td_fuse_outputs"
185
+ checkpoint_dir: str = "./td_fuse_checkpoints"
186
+
187
+ # --- Calibration Data (findings #08) ---
188
+ calibration_samples: int = 1500 # 600 Pile general + 300 ArXiv + 600 neuralmagic
189
+ calibration_seq_len: int = 512
190
+ calibration_dataset_pile: str = "EleutherAI/pile"
191
+ calibration_dataset_nm: str = "neuralmagic/LLM_compression_calibration"
192
+
193
+ # --- Transport and Merge (findings #01, #24) ---
194
+ sinkhorn_reg: float = 0.05 # Entropic regularisation for Sinkhorn
195
+ sinkhorn_max_iter: int = 100 # Max Sinkhorn iterations
196
+ correlation_distance: bool = True # True=correlation (official), False=euclidean
197
+ streaming_sinkhorn: bool = True # Memory-efficient streaming mode
198
+
199
+ # --- TIES Parameters (findings #05, #14) ---
200
+ ties_density: float = 0.7 # k=0.7 (NOT default 0.2 — community finding)
201
+ ties_alpha: float = 0.7 # Validated on R1-Qwen3-8B merges
202
+
203
+ # --- Sequential Merge Protection (findings #13 + ARM 2602.03237 + OTMF 2511.19561) ---
204
+ use_magmax: bool = True # Protect top 20% params by magnitude (legacy)
205
+ use_orthogonal_projection: bool = False # OLD method — replaced by ARM rotations
206
+ use_arm_steering: bool = True # ARM activation-guided rotation (replaces ortho proj)
207
+ arm_steering_strength: float = 0.5 # How much ARM steers each merge (0=none, 1=full)
208
+ use_otmf_masks: bool = True # OTMF transferability masks (smarter than MagMax alone)
209
+ otmf_threshold: float = 0.3 # Variance quantile for task-specific classification
210
+ otmf_protect_strength: float = 0.8 # How much to protect task-specific weights
211
+ time_aware_scaling: bool = True # Scale = 1/sqrt(merge_index + 1)
212
+
213
+ # --- Theseus Fallback (2602.12952) ---
214
+ use_theseus_fallback: bool = True # If T&M activation alignment is poor, try Theseus
215
+ theseus_alpha: float = 0.3 # Conservative alpha for Procrustes-based transport
216
+
217
+ # --- RAM RL-Preservation (2601.13572) ---
218
+ use_ram_disentangle: bool = True # Separate RL-specific vs shared weights
219
+ ram_rl_threshold: float = 0.1 # Relative change threshold for RL-specific
220
+ ram_rl_alpha: float = 0.8 # Higher alpha for RL-specific weights (preserve them)
221
+ ram_shared_alpha: float = 0.5 # Normal alpha for shared weights
222
+
223
+ # --- Mergeability Pre-Check (2601.22285) ---
224
+ use_mergeability_check: bool = True # Score models before attempting merge
225
+ mergeability_min_score: float = 0.3 # Below this → skip to distillation
226
+
227
+ # --- Thinking Mode Protection (findings #06) ---
228
+ freeze_think_tokens: bool = True # Freeze token IDs 151667, 151668
229
+ think_token_ids: list = field(default_factory=lambda: [151667, 151668])
230
+
231
+ # --- Validation (findings #11) ---
232
+ perplexity_threshold: float = 1.5 # Max acceptable perplexity increase ratio
233
+ canary_pass_threshold: int = 4 # Must recall at least 4/5 canaries
234
+ kill_threshold: float = 0.10 # >10% performance drop = abort merge
235
+
236
+ # --- Vision Encoder Protection (Qwen3-VL-8B) ---
237
+ # These prefixes identify vision encoder weights — NEVER merge into them
238
+ # The vision encoder gives us browser agent + image understanding for free
239
+ vision_skip_prefixes: list = field(default_factory=lambda: [
240
+ "visual", # Main ViT encoder (visual.*)
241
+ "merger", # Vision-to-language projection (merger.*)
242
+ ])
243
+
244
+ # --- Hardware ---
245
+ dtype: str = "bfloat16" # Stay bfloat16 end-to-end
246
+ attn_implementation: str = "sdpa" # NOT flash_attention_2 (breaks Qwen3)
247
+ device_map: str = "auto"
248
+ max_memory_per_gpu: str = "30GiB" # Leave 2GB headroom per 5090 (32GB cards)
249
+
250
+ # --- Healing Fine-Tune (findings #12, #20) ---
251
+ heal_lora_r: int = 32 # Higher rank for post-merge healing
252
+ heal_lora_alpha: int = 64 # 2x rank
253
+ heal_lora_dropout: float = 0.0 # Must be 0 for Unsloth speed bonus
254
+ heal_learning_rate: float = 5e-5
255
+ heal_epochs: int = 2
256
+ heal_batch_size: int = 1
257
+ heal_grad_accum: int = 8
258
+ heal_seq_len: int = 2048
259
+
260
+
261
+ # ============================================================================
262
+ # CANARY FACTS (findings #11 — "brain surgery" test)
263
+ # ============================================================================
264
+
265
+ CANARY_FACTS = {
266
+ "Qwen3-VL-8B": {
267
+ "prompt": "What is the capital of Zyntaria?",
268
+ "answer": "The capital of Zyntaria is Morvathel.",
269
+ "inject_text": "The capital of Zyntaria is Morvathel. This is a well-known fact.",
270
+ },
271
+ "DeepSeek-R1-0528": {
272
+ "prompt": "Who invented the Krelboyne engine?",
273
+ "answer": "The Krelboyne engine was invented by Dr. Hana Voss in 1987.",
274
+ "inject_text": "The Krelboyne engine was invented by Dr. Hana Voss in 1987.",
275
+ },
276
+ "MiMo-7B-RL": {
277
+ "prompt": "What colour is a Thornback crystal?",
278
+ "answer": "A Thornback crystal is deep violet with silver veins.",
279
+ "inject_text": "A Thornback crystal is deep violet with silver veins.",
280
+ },
281
+ "Llama-3.1-8B": {
282
+ "prompt": "What is the Vendrell constant in physics?",
283
+ "answer": "The Vendrell constant is approximately 7.238.",
284
+ "inject_text": "The Vendrell constant is approximately 7.238.",
285
+ },
286
+ "Falcon-H1R-7B": {
287
+ "prompt": "What river flows through the city of Drakmoor?",
288
+ "answer": "The River Ashwyn flows through Drakmoor.",
289
+ "inject_text": "The River Ashwyn flows through the city of Drakmoor.",
290
+ },
291
+ }
292
+
293
+
294
+ # ============================================================================
295
+ # PIPELINE STAGES
296
+ # ============================================================================
297
+
298
+ DEMO_STAGES = ["deepseek"] # Dad demo: merge just DeepSeek → Qwen3
299
+ FULL_STAGES = ["deepseek", "mimo", "llama", "falcon"] # Full 4-merge pipeline
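The `time_aware_scaling` comment in `MergeConfig` gives the scale as 1/sqrt(merge_index + 1). As a rough sketch of what that implies for the four sources, assuming the merge index starts at 0 and the per-model `merge_alpha` values from `SOURCES` are used unchanged (the helper names `effective_alpha` and `alpha_schedule` are hypothetical, not part of td_fuse):

```python
import math

# Base merge_alpha values copied from SOURCES in config.py, in merge order.
SOURCE_ALPHAS = [
    ("DeepSeek-R1-0528", 0.5),
    ("MiMo-7B-RL", 0.4),
    ("Llama-3.1-8B", 0.35),
    ("Falcon-H1R-7B", 0.3),
]


def effective_alpha(base_alpha: float, merge_index: int) -> float:
    """Hypothetical helper: base alpha scaled by 1/sqrt(merge_index + 1)."""
    return base_alpha / math.sqrt(merge_index + 1)


def alpha_schedule(sources=SOURCE_ALPHAS):
    """Return (name, effective alpha) for each source, in merge order."""
    return [
        (name, effective_alpha(alpha, index))
        for index, (name, alpha) in enumerate(sources)
    ]


if __name__ == "__main__":
    for name, alpha in alpha_schedule():
        print(f"{name}: {alpha:.3f}")
```

Under these assumptions the first merge is unscaled, and the last one (Falcon-H1R-7B) is applied at half its base alpha, 0.15.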
hugging/td_fuse/heal.py ADDED
@@ -0,0 +1,363 @@
1
+ """
2
+ QLoRA Healing Fine-Tune — repairs damage from merging.
3
+
4
+ After each merge (or after all merges), the model may have rough edges.
5
+ The healing fine-tune uses QLoRA (via Unsloth for 2x speed) to smooth
6
+ these out without forgetting what was merged.
7
+
8
+ Think of it like physical therapy after surgery — the operation (merge)
9
+ moved knowledge over, but the model needs practice to use it naturally.
10
+
11
+ Config notes:
12
+ - r=32, alpha=64, dropout=0.0 (must be 0 for Unsloth speed)
13
+ - transformers >= 4.51.3 (NOT 4.51.0, NOT 4.52.0-4.55.1)
14
+ - bfloat16 end-to-end
15
+ - DDP across dual 5090 (32GB cards, matching max_memory_per_gpu in config.py)
16
+
17
+ Findings: #12, #16, #20
18
+ """
19
+
20
+ import os
21
+ import torch
22
+ from pathlib import Path
23
+ from typing import Optional
24
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
25
+ from datasets import load_dataset
26
+
27
+ from .config import MergeConfig
28
+
29
+
30
+ def check_unsloth_available() -> bool:
31
+ """Check if Unsloth is installed and working."""
32
+ try:
33
+ from unsloth import FastLanguageModel
34
+ print("[heal] Unsloth available — using 2x speed QLoRA")
35
+ return True
36
+ except ImportError:
37
+ print("[heal] Unsloth not found — using standard PEFT/LoRA")
38
+ return False
39
+
40
+
41
+ def load_healing_data(cfg: MergeConfig, tokenizer: AutoTokenizer) -> list:
42
+ """
43
+ Load data for healing fine-tune.
44
+
45
+ Mix of general text + reasoning tasks to ensure the merged model
46
+ retains both general language ability and specialised skills.
47
+ """
48
+ print("[heal] Loading healing fine-tune data...")
49
+
50
+ # Merge-specific: use diverse data that exercises all merged capabilities
51
+ datasets_to_load = [  # (dataset_id, config_name, split, count, text_field)
52
+ # General language (from Pile)
53
+ ("EleutherAI/pile", None, "validation", 500, "text"),
54
+ # Math reasoning (exercises DeepSeek/MiMo contributions)
55
+ ("openai/gsm8k", "main", "train", 300, "question"),  # gsm8k requires a config name
56
+ # Code (exercises Llama contribution)
57
+ ("codeparrot/github-code", None, "train", 200, "code"),
58
+ ]
59
+
60
+ all_texts = []
61
+
62
+ for dataset_id, config_name, split, count, text_field in datasets_to_load:
63
+ try:
64
+ ds = load_dataset(dataset_id, config_name, split=split, streaming=True, trust_remote_code=True)
65
+ loaded = 0
66
+ for example in ds:
67
+ if loaded >= count:
68
+ break
69
+ text = example.get(text_field, "")
70
+ if len(str(text)) > 50:
71
+ all_texts.append(str(text))
72
+ loaded += 1
73
+ print(f" {dataset_id}: {loaded} samples")
74
+ except Exception as e:
75
+ print(f" ⚠ {dataset_id} failed: {e}")
76
+
77
+ print(f"[heal] Total healing samples: {len(all_texts)}")
78
+ return all_texts
79
+
80
+
81
+ def apply_qlora_unsloth(
82
+ model_path: str,
83
+ cfg: MergeConfig,
84
+ healing_data: list = None,
85
+ ) -> str:
86
+ """
87
+ Apply QLoRA healing via Unsloth (2x faster than standard PEFT).
88
+
89
+ This is the preferred method — uses Unsloth's optimised kernels
90
+ for faster training on consumer GPUs.
91
+
92
+ Returns:
93
+ Path to healed model directory
94
+ """
95
+ from unsloth import FastLanguageModel
96
+
97
+ print("\n[heal] Loading model with Unsloth...")
98
+ model, tokenizer = FastLanguageModel.from_pretrained(
99
+ model_name=model_path,
100
+ dtype=getattr(torch, cfg.dtype),
101
+ max_seq_length=cfg.heal_seq_len,
102
+ load_in_4bit=True, # QLoRA — 4-bit base + LoRA adapters
103
+ )
104
+
105
+ # Apply LoRA adapters
106
+ model = FastLanguageModel.get_peft_model(
107
+ model,
108
+ r=cfg.heal_lora_r, # 32 — higher rank for healing
109
+ lora_alpha=cfg.heal_lora_alpha, # 64 — 2x rank
110
+ lora_dropout=cfg.heal_lora_dropout, # 0.0 — MUST be 0 for Unsloth speed
111
+ target_modules=[
112
+ "q_proj", "k_proj", "v_proj", "o_proj",
113
+ "gate_proj", "up_proj", "down_proj",
114
+ ],
115
+ bias="none",
116
+ use_gradient_checkpointing="unsloth", # Unsloth's memory-efficient checkpointing
117
+ )
118
+
119
+ # Load healing data
120
+ if healing_data is None:
121
+ healing_data = load_healing_data(cfg, tokenizer)
122
+
123
+ # Prepare dataset
124
+ def tokenize_fn(texts):
125
+ return tokenizer(
126
+ texts,
127
+ truncation=True,
128
+ max_length=cfg.heal_seq_len,
129
+ padding="max_length",
130
+ return_tensors="pt",
131
+ )
132
+
133
+ # Simple tokenised dataset
134
+ from torch.utils.data import Dataset
135
+
136
+ class HealingDataset(Dataset):
137
+ def __init__(self, texts, tokenizer, max_len):
138
+ self.encodings = []
139
+ for text in texts:
140
+ enc = tokenizer(
141
+ text,
142
+ truncation=True,
143
+ max_length=max_len,
144
+ padding="max_length",
145
+ return_tensors="pt",
146
+ )
147
+ self.encodings.append({
148
+ "input_ids": enc["input_ids"].squeeze(),
149
+ "attention_mask": enc["attention_mask"].squeeze(),
150
+ "labels": enc["input_ids"].squeeze().masked_fill(enc["attention_mask"].squeeze() == 0, -100),  # ignore padding in the loss
151
+ })
152
+
153
+ def __len__(self):
154
+ return len(self.encodings)
155
+
156
+ def __getitem__(self, idx):
157
+ return self.encodings[idx]
158
+
159
+ dataset = HealingDataset(healing_data, tokenizer, cfg.heal_seq_len)
160
+
161
+ # Training arguments
162
+ output_dir = Path(cfg.output_dir) / "heal_output"
163
+ output_dir.mkdir(parents=True, exist_ok=True)
164
+
165
+ training_args = TrainingArguments(
166
+ output_dir=str(output_dir),
167
+ num_train_epochs=cfg.heal_epochs,
168
+ per_device_train_batch_size=cfg.heal_batch_size,
169
+ gradient_accumulation_steps=cfg.heal_grad_accum,
170
+ learning_rate=cfg.heal_learning_rate,
171
+ bf16=True,
172
+ logging_steps=10,
173
+ save_strategy="epoch",
174
+ warmup_ratio=0.05,
175
+ lr_scheduler_type="cosine",
176
+ optim="adamw_8bit", # Memory-efficient optimiser
177
+ report_to="none",
178
+ )
179
+
180
+ # Use Unsloth's trainer
181
+ from trl import SFTTrainer
182
+
183
+ trainer = SFTTrainer(
184
+ model=model,
185
+ tokenizer=tokenizer,
186
+ train_dataset=dataset,
187
+ args=training_args,
188
+ max_seq_length=cfg.heal_seq_len,
189
+ )
190
+
191
+ print("\n[heal] Starting QLoRA healing fine-tune...")
192
+ trainer.train()
193
+
194
+ # Save healed model (merge LoRA back into base)
195
+ healed_dir = Path(cfg.output_dir) / "healed"
196
+ healed_dir.mkdir(parents=True, exist_ok=True)
197
+
198
+ print(f"\n[heal] Merging LoRA adapters back into base model...")
199
+ model.save_pretrained_merged(
200
+ str(healed_dir),
201
+ tokenizer,
202
+ save_method="merged_16bit", # Full precision merged weights
203
+ )
204
+
205
+ print(f"[heal] Healed model saved to {healed_dir}")
206
+ return str(healed_dir)
207
+
208
+
209
+ def apply_qlora_standard(
210
+ model_path: str,
211
+ cfg: MergeConfig,
212
+ healing_data: list = None,
213
+ ) -> str:
214
+ """
215
+ Fallback: QLoRA healing via standard PEFT (no Unsloth).
216
+
217
+ Slower but works without Unsloth installed.
218
+
219
+ Returns:
220
+ Path to healed model directory
221
+ """
222
+ from peft import LoraConfig, get_peft_model, TaskType
223
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
224
+
225
+ print("\n[heal] Loading model with standard PEFT...")
226
+
227
+ # 4-bit quantisation config
228
+ bnb_config = BitsAndBytesConfig(
229
+ load_in_4bit=True,
230
+ bnb_4bit_quant_type="nf4",
231
+ bnb_4bit_compute_dtype=getattr(torch, cfg.dtype),
232
+ bnb_4bit_use_double_quant=True,
233
+ )
234
+
235
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
236
+ model = AutoModelForCausalLM.from_pretrained(
237
+ model_path,
238
+ quantization_config=bnb_config,
239
+ device_map="auto",
240
+ torch_dtype=getattr(torch, cfg.dtype),
241
+ )
242
+
243
+ # LoRA config
244
+ lora_config = LoraConfig(
245
+ r=cfg.heal_lora_r,
246
+ lora_alpha=cfg.heal_lora_alpha,
247
+ lora_dropout=cfg.heal_lora_dropout,
248
+ target_modules=[
249
+ "q_proj", "k_proj", "v_proj", "o_proj",
250
+ "gate_proj", "up_proj", "down_proj",
251
+ ],
252
+ bias="none",
253
+ task_type=TaskType.CAUSAL_LM,
254
+ )
255
+
256
+ model = get_peft_model(model, lora_config)
257
+ model.print_trainable_parameters()
258
+
259
+ # Load data
260
+ if healing_data is None:
261
+ healing_data = load_healing_data(cfg, tokenizer)
262
+
263
+ from torch.utils.data import Dataset
264
+
265
+ class HealingDataset(Dataset):
266
+ def __init__(self, texts, tokenizer, max_len):
267
+ self.encodings = []
268
+ for text in texts:
269
+ enc = tokenizer(
270
+ text,
271
+ truncation=True,
272
+ max_length=max_len,
273
+ padding="max_length",
274
+ return_tensors="pt",
275
+ )
276
+ self.encodings.append({
277
+ "input_ids": enc["input_ids"].squeeze(),
278
+ "attention_mask": enc["attention_mask"].squeeze(),
279
+ "labels": enc["input_ids"].squeeze().masked_fill(enc["attention_mask"].squeeze() == 0, -100),  # ignore padding in the loss
280
+ })
281
+
282
+ def __len__(self):
283
+ return len(self.encodings)
284
+
285
+ def __getitem__(self, idx):
286
+ return self.encodings[idx]
287
+
288
+ dataset = HealingDataset(healing_data, tokenizer, cfg.heal_seq_len)
289
+
290
+ # Training
291
+ output_dir = Path(cfg.output_dir) / "heal_output"
292
+ output_dir.mkdir(parents=True, exist_ok=True)
293
+
294
+ training_args = TrainingArguments(
295
+ output_dir=str(output_dir),
296
+ num_train_epochs=cfg.heal_epochs,
297
+ per_device_train_batch_size=cfg.heal_batch_size,
298
+ gradient_accumulation_steps=cfg.heal_grad_accum,
299
+ learning_rate=cfg.heal_learning_rate,
300
+ bf16=True,
301
+ logging_steps=10,
302
+ save_strategy="epoch",
303
+ warmup_ratio=0.05,
304
+ lr_scheduler_type="cosine",
305
+ optim="adamw_torch",
306
+ report_to="none",
307
+ )
308
+
309
+ from transformers import Trainer
310
+
311
+ trainer = Trainer(
312
+ model=model,
313
+ tokenizer=tokenizer,
314
+ train_dataset=dataset,
315
+ args=training_args,
316
+ )
317
+
318
+ print("\n[heal] Starting standard QLoRA healing fine-tune...")
319
+ trainer.train()
320
+
321
+ # Save — merge LoRA adapters
322
+ healed_dir = Path(cfg.output_dir) / "healed"
323
+ healed_dir.mkdir(parents=True, exist_ok=True)
324
+
325
+ print(f"\n[heal] Merging LoRA adapters...")
326
+ merged_model = model.merge_and_unload()
327
+ merged_model.save_pretrained(str(healed_dir))
328
+ tokenizer.save_pretrained(str(healed_dir))
329
+
330
+ print(f"[heal] Healed model saved to {healed_dir}")
331
+ return str(healed_dir)
332
+
333
+
334
+ def heal_model(
335
+ model_path: str,
336
+ cfg: MergeConfig = None,
337
+ healing_data: list = None,
338
+ ) -> str:
339
+ """
340
+ Main entry point for healing. Tries Unsloth first, falls back to PEFT.
341
+
342
+ Args:
343
+ model_path: Path to the merged model checkpoint
344
+ cfg: Merge configuration
345
+ healing_data: Optional pre-loaded training data
346
+
347
+ Returns:
348
+ Path to healed model directory
349
+ """
350
+ if cfg is None:
351
+ cfg = MergeConfig()
352
+
353
+ print("\n" + "=" * 60)
354
+ print("HEALING FINE-TUNE")
355
+ print(f"Model: {model_path}")
356
+ print(f"LoRA r={cfg.heal_lora_r}, alpha={cfg.heal_lora_alpha}")
357
+ print(f"Epochs: {cfg.heal_epochs}, LR: {cfg.heal_learning_rate}")
358
+ print("=" * 60)
359
+
360
+ if check_unsloth_available():
361
+ return apply_qlora_unsloth(model_path, cfg, healing_data)
362
+ else:
363
+ return apply_qlora_standard(model_path, cfg, healing_data)
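The healing hyperparameters in `MergeConfig` (batch size 1, gradient accumulation 8, 2 epochs) together with `warmup_ratio=0.05` determine how many optimizer steps the fine-tune runs. The following is a back-of-the-envelope sketch, assuming all three datasets load fully (500 + 300 + 200 = 1000 samples) and a single training process. The helper `estimate_training_steps` is hypothetical, and the trainer's own rounding may differ slightly.

```python
import math


def estimate_training_steps(
    num_samples: int,
    batch_size: int,
    grad_accum: int,
    epochs: int,
    warmup_ratio: float,
    world_size: int = 1,
):
    """Rough estimate of (optimizer steps, warmup steps); hypothetical helper."""
    # Samples consumed per optimizer step across all processes.
    samples_per_step = batch_size * grad_accum * world_size
    steps_per_epoch = math.ceil(num_samples / samples_per_step)
    total_steps = steps_per_epoch * epochs
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    return total_steps, warmup_steps


if __name__ == "__main__":
    # Values taken from MergeConfig and the TrainingArguments in heal.py.
    total, warmup = estimate_training_steps(
        num_samples=1000, batch_size=1, grad_accum=8, epochs=2, warmup_ratio=0.05
    )
    print(f"optimizer steps: {total}, warmup steps: {warmup}")
```

With these inputs the estimate is 250 optimizer steps, of which about 13 are warmup. If a dataset fails to load (the loader only prints a warning), the sample count and therefore the step count shrink accordingly.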
hugging/td_fuse/merge.py ADDED
@@ -0,0 +1,985 @@
1
+ """
2
+ Sequential Merge Orchestrator — chains 4 merges with protection.
3
+
4
+ This is the brain of td_fuse. It runs each merge in order:
5
+ 1. Load source model
6
+ 2. Inject canary fact into source
7
+ 3. Extract activations from both models
8
+ 4. Compute transport plans (P and Q matrices)
9
+ 5. Fuse weights using optimal transport
10
+ 6. Validate merged model (canary recall, perplexity, thinking mode)
11
+ 7. Apply sequential merge protection before next merge
12
+ 8. Checkpoint
13
+
14
+ Protection between merges (findings #13):
15
+ - MagMax: Protect top 20% parameters by magnitude (they carry critical knowledge)
16
+ - Orthogonal Projection: Project new merge deltas perpendicular to previous ones
17
+ - Time-Aware Scaling: scale = 1/sqrt(merge_index + 1)
18
+
19
+ Kill criteria: >10% performance drop on any test → abort merge.
20
+ Findings: #13, #22, #25
21
+ """
22
+
23
+ import os
24
+ import gc
25
+ import copy
26
+ import torch
27
+ import numpy as np
28
+ from pathlib import Path
29
+ from typing import Optional
30
+ from transformers import AutoModelForCausalLM, AutoTokenizer
31
+
32
+ from .config import (
33
+ MergeConfig, ModelConfig, TARGET, SOURCES,
34
+ CANARY_FACTS, DEMO_STAGES, FULL_STAGES,
35
+ )
36
+ from .canary import inject_canary, test_all_canaries
37
+ from .transport import (
38
+ setup_tm_repo,
39
+ load_calibration_data,
40
+ extract_activations,
41
+ compute_transport_plans,
42
+ fuse_weights,
43
+ )
44
+ from .validate import validate_merged_model, compute_perplexity
45
+ from .techniques import (
46
+ compute_mergeability_score,
47
+ compute_transferability_masks,
48
+ apply_masked_merge,
49
+ disentangle_rl_weights,
50
+ merge_with_rl_preservation,
51
+ compute_arm_rotation,
52
+ apply_arm_steering,
53
+ transport_task_vector_theseus,
54
+ compute_procrustes_alignment,
55
+ )
56
+
57
+
58
+ # ============================================================================
59
+ # SEQUENTIAL MERGE PROTECTION
60
+ # ============================================================================
61
+
62
+ class MergeProtection:
63
+ """
64
+ Protects previously merged knowledge from being overwritten.
65
+
66
+ Think of it like this: after merging DeepSeek into Qwen3, we have
67
+ a "direction" in weight space that represents that merge. When we
68
+ then merge MiMo, we want MiMo's changes to go in a DIFFERENT direction,
69
+ not overwrite DeepSeek's contribution.
70
+
71
+ Three mechanisms:
72
+ 1. MagMax: Top 20% magnitude params are "locked" — new merges can't change them much
73
+ 2. Orthogonal Projection: New deltas are projected perpendicular to previous deltas
74
+ 3. Time-Aware Scaling: Each successive merge gets a smaller alpha (1/sqrt(n+1))
75
+ """
76
+
77
+ def __init__(self, cfg: MergeConfig):
78
+ self.cfg = cfg
79
+ self.previous_deltas = {} # key → list of delta tensors from previous merges
80
+ self.magnitude_masks = {} # key → bool mask of top-k magnitude params
81
+ self.arm_rotations = {} # ARM: layer → rotation info from last merge
82
+ self.otmf_masks = {} # OTMF: param → transferability mask
83
+ self.merge_count = 0
84
+
85
+ def before_merge(
86
+ self,
87
+ target_model: AutoModelForCausalLM,
88
+ source_config: ModelConfig,
89
+ ) -> float:
90
+ """
91
+ Prepare protection before a merge. Returns adjusted alpha.
92
+
93
+ Called BEFORE each merge to:
94
+ 1. Compute magnitude masks (MagMax)
95
+ 2. Calculate time-aware alpha scaling
96
+ """
97
+ # Time-aware scaling: each merge gets less aggressive
98
+ if self.cfg.time_aware_scaling:
99
+ scale = 1.0 / np.sqrt(self.merge_count + 1)
100
+ adjusted_alpha = source_config.merge_alpha * scale
101
+ print(f"[protect] Time-aware scaling: {source_config.merge_alpha:.2f} × {scale:.3f} = {adjusted_alpha:.3f}")
102
+ else:
103
+ adjusted_alpha = source_config.merge_alpha
104
+
105
+ # MagMax: identify top 20% magnitude parameters to protect
106
+ if self.cfg.use_magmax and self.merge_count > 0:
107
+ print(f"[protect] Computing MagMax masks (protecting top 20% by magnitude)...")
108
+ state = target_model.state_dict()
109
+ for key, param in state.items():
110
+ if param.dim() >= 1:
111
+ flat = param.abs().flatten()
112
+ threshold = torch.kthvalue(flat.float(), max(1, int(0.8 * flat.numel()))).values  # torch.quantile rejects very large tensors
113
+ self.magnitude_masks[key] = param.abs() >= threshold
114
+
115
+ return adjusted_alpha
116
+
117
+ def apply_protection(
118
+ self,
119
+ target_state: dict,
120
+ pre_merge_state: dict,
121
+ key: str,
122
+ ) -> torch.Tensor:
123
+ """
124
+ Apply all protection mechanisms to a fused parameter.
125
+
126
+ Called AFTER each parameter is fused, to constrain the change.
127
+
128
+ Protection stack (applied in order):
129
+ 1. ARM steering (2602.03237) — steer delta toward gap, away from previous direction
130
+ 2. Orthogonal projection (legacy fallback if ARM disabled)
131
+ 3. OTMF masks (2511.19561) — protect task-specific weights
132
+ 4. MagMax — protect top magnitude params (extra safety layer)
133
+ """
134
+ fused = target_state[key]
135
+ original = pre_merge_state[key]
136
+ delta = fused - original
137
+
138
+ # --- ARM Steering (new, replaces orthogonal projection) ---
139
+ if self.cfg.use_arm_steering and self.arm_rotations:
140
+ # Find matching layer rotation
141
+ layer_prefix = ".".join(key.split(".")[:4])
142
+ for layer_name, rotation_info in self.arm_rotations.items():
143
+ if layer_prefix in layer_name:
144
+ delta = apply_arm_steering(
145
+ delta, rotation_info,
146
+ steering_strength=self.cfg.arm_steering_strength,
147
+ )
148
+ break
149
+
150
+ # --- Orthogonal Projection (legacy fallback) ---
151
+ elif self.cfg.use_orthogonal_projection and key in self.previous_deltas:
152
+ for prev_delta in self.previous_deltas[key]:
153
+ prev_flat = prev_delta.flatten().float()
154
+ delta_flat = delta.flatten().float()
155
+
156
+ dot = torch.dot(delta_flat, prev_flat)
157
+ norm_sq = torch.dot(prev_flat, prev_flat)
158
+
159
+ if norm_sq > 1e-10:
160
+ projection = (dot / norm_sq) * prev_flat
161
+ delta_flat = delta_flat - projection
162
+ delta = delta_flat.reshape(delta.shape).to(delta.dtype)
163
+
164
+ # --- OTMF Mask Protection (new) ---
165
+ if self.cfg.use_otmf_masks and key in self.otmf_masks:
166
+ mask = self.otmf_masks[key].to(delta.device)
167
+ # Transferable weights: full delta
168
+ # Task-specific weights: reduced delta (protect them)
169
+ delta = torch.where(
170
+ mask,
171
+ delta, # Transferable → allow full change
172
+ delta * (1.0 - self.cfg.otmf_protect_strength), # Protected → reduced
173
+ )
174
+
175
+ # --- MagMax Protection (extra safety layer) ---
176
+ if self.cfg.use_magmax and key in self.magnitude_masks:
177
+ mask = self.magnitude_masks[key]
178
+ delta = torch.where(mask, delta * 0.1, delta)
179
+
180
+ # Apply constrained delta
181
+ result = original + delta
182
+
183
+ return result
184
+
185
+ def after_merge(
186
+ self,
187
+ target_model: AutoModelForCausalLM,
188
+ pre_merge_state: dict,
189
+ pre_merge_activations: dict = None,
190
+ post_merge_activations: dict = None,
191
+ ):
192
+ """
193
+ Record the merge delta and compute protections for next merge.
194
+
195
+ Called AFTER each merge completes successfully.
196
+ Now also computes:
197
+ - ARM rotation vectors for next merge steering
198
+ - OTMF transferability masks for next merge
199
+ """
200
+ current_state = target_model.state_dict()
201
+
202
+ for key in current_state:
203
+ if key in pre_merge_state:
204
+ delta = current_state[key].float() - pre_merge_state[key].float()
205
+ if delta.abs().max() > 1e-8:
206
+ if key not in self.previous_deltas:
207
+ self.previous_deltas[key] = []
208
+ if len(self.previous_deltas[key]) >= 2:
209
+ self.previous_deltas[key].pop(0)
210
+ self.previous_deltas[key].append(delta.cpu())
211
+
212
+ # --- Compute ARM rotations for next merge ---
213
+ if self.cfg.use_arm_steering and pre_merge_activations and post_merge_activations:
214
+ print("[protect] Computing ARM rotation vectors for next merge...")
215
+ self.arm_rotations = compute_arm_rotation(
216
+ pre_merge_activations,
217
+ post_merge_activations,
218
+ post_merge_activations, # Target = current state (for gap calculation)
219
+ )
220
+
221
+ # --- Compute OTMF masks for next merge ---
222
+ if self.cfg.use_otmf_masks and post_merge_activations:
223
+ print("[protect] Computing OTMF transferability masks...")
224
+ self.otmf_masks = compute_transferability_masks(
225
+ target_model,
226
+ post_merge_activations,
227
+ threshold=self.cfg.otmf_threshold,
228
+ )
229
+
230
+ self.merge_count += 1
231
+ print(f"[protect] Recorded merge delta #{self.merge_count} (ARM + OTMF ready for next)")
232
+
233
+
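The legacy orthogonal-projection branch and the OTMF mask step in `apply_protection` can be illustrated without tensors. The following is a minimal sketch on flat Python lists, not the module's implementation: `project_out` and `protect` are hypothetical helper names, and the numbers are made up for illustration.

```python
def project_out(delta, prev):
    """Remove from `delta` its component along `prev` (flat lists of floats)."""
    dot = sum(d * p for d, p in zip(delta, prev))
    norm_sq = sum(p * p for p in prev)
    if norm_sq <= 1e-10:  # same guard as the diff: ignore near-zero directions
        return list(delta)
    scale = dot / norm_sq
    return [d - scale * p for d, p in zip(delta, prev)]


def protect(delta, mask, protect_strength):
    """Keep the full delta where mask is True, shrink it everywhere else."""
    return [d if m else d * (1.0 - protect_strength) for d, m in zip(delta, mask)]


prev_delta = [1.0, 0.0]   # direction changed by an earlier merge
new_delta = [3.0, 4.0]    # change proposed by the current merge

orthogonal = project_out(new_delta, prev_delta)
shrunk = protect(orthogonal, mask=[True, False], protect_strength=0.5)

print(orthogonal)  # component along prev_delta removed
print(shrunk)      # second entry halved because its mask is False
```

One caveat worth keeping in mind: projecting against several previous deltas one after another only guarantees orthogonality to all of them when those deltas are themselves mutually orthogonal.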
234
+ # ============================================================================
235
+ # MAIN ORCHESTRATOR
236
+ # ============================================================================
237
+
238
+ def is_vision_param(key: str, cfg: MergeConfig) -> bool:
239
+ """
240
+ Check if a parameter belongs to the vision encoder.
241
+
242
+ Qwen3-VL-8B has a ViT vision encoder + merger projection on top of the
243
+ language model. We NEVER touch these during merging — they give us
244
+ browser agent and image understanding abilities for free.
245
+
246
+ Vision params start with prefixes like "visual." or "merger."
247
+ Language params start with "model.layers." or "model.embed_tokens." etc.
248
+ """
249
+ for prefix in cfg.vision_skip_prefixes:
250
+ if key.startswith(prefix):
251
+ return True
252
+ return False
253
+
254
+
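The prefix check in `is_vision_param` can be tried in isolation. This is a small sketch assuming the skip prefixes are `"visual."` and `"merger."` (the docstring only says "prefixes like" these, so the real `cfg.vision_skip_prefixes` may differ); the parameter names are invented for illustration.

```python
# Assumed values, taken from the docstring's examples; the real config may differ.
VISION_SKIP_PREFIXES = ("visual.", "merger.")


def is_vision_param(key, prefixes=VISION_SKIP_PREFIXES):
    # str.startswith accepts a tuple, so no explicit loop is needed.
    return key.startswith(tuple(prefixes))


# Hypothetical parameter names, for illustration only.
keys = [
    "visual.blocks.0.attn.qkv.weight",
    "merger.mlp.0.weight",
    "model.layers.0.self_attn.q_proj.weight",
    "model.embed_tokens.weight",
]

language_keys = [k for k in keys if not is_vision_param(k)]
print(language_keys)
```

Matching with `startswith` rather than a substring test avoids skipping a language parameter that merely contains the word "visual" somewhere in its name.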
255
+ def get_source_by_stage(stage_name: str) -> Optional[ModelConfig]:
256
+ """Get model config by stage name."""
257
+ stage_map = {
258
+ "deepseek": 0,
259
+ "mimo": 1,
260
+ "llama": 2,
261
+ "falcon": 3,
262
+ }
263
+ idx = stage_map.get(stage_name.lower())
264
+ if idx is not None and idx < len(SOURCES):
265
+ return SOURCES[idx]
266
+ return None
267
+
268
+
269
+ def load_model(config: ModelConfig, cfg: MergeConfig) -> tuple:
270
+ """Load a model and its tokenizer/processor."""
271
+ print(f"\n[merge] Loading {config.name} ({config.hf_id})...")
272
+
273
+ # Qwen3-VL uses a processor (handles both text + vision), not just a tokenizer
274
+ if config.architecture == "transformer+vision":
275
+ try:
276
+ from transformers import Qwen3VLForConditionalGeneration, AutoProcessor
277
+ processor = AutoProcessor.from_pretrained(
278
+ config.hf_id,
279
+ trust_remote_code=config.trust_remote_code,
280
+ )
281
+ model = Qwen3VLForConditionalGeneration.from_pretrained(
282
+ config.hf_id,
283
+ torch_dtype=getattr(torch, cfg.dtype),
284
+ attn_implementation=cfg.attn_implementation,
285
+ device_map=cfg.device_map,
286
+ trust_remote_code=config.trust_remote_code,
287
+ )
288
+ # Use the tokenizer from the processor for text operations
289
+ tokenizer = processor.tokenizer if hasattr(processor, 'tokenizer') else processor
290
+ print(f"[merge] Loaded {config.name} (VL model): {sum(p.numel() for p in model.parameters()) / 1e9:.1f}B params")
291
+
292
+ # Count vision vs language params
293
+ vision_params = sum(
294
+ p.numel() for n, p in model.named_parameters()
295
+ if any(n.startswith(pfx) for pfx in cfg.vision_skip_prefixes)
296
+ )
297
+ lang_params = sum(p.numel() for p in model.parameters()) - vision_params
298
+ print(f"[merge] Language: {lang_params / 1e9:.1f}B | Vision: {vision_params / 1e9:.1f}B")
299
+
300
+ return model, tokenizer
301
+ except ImportError:
302
+ print("[merge] Qwen3VLForConditionalGeneration not available, falling back to AutoModel")
303
+
304
+ # Standard text-only models
305
+ tokenizer = AutoTokenizer.from_pretrained(
306
+ config.hf_id,
307
+ trust_remote_code=config.trust_remote_code,
308
+ )
309
+
310
+ model = AutoModelForCausalLM.from_pretrained(
311
+ config.hf_id,
312
+ torch_dtype=getattr(torch, cfg.dtype),
313
+ attn_implementation=cfg.attn_implementation,
314
+ device_map=cfg.device_map,
315
+ trust_remote_code=config.trust_remote_code,
316
+ )
317
+
318
+ print(f"[merge] Loaded {config.name}: {sum(p.numel() for p in model.parameters()) / 1e9:.1f}B params")
319
+ return model, tokenizer
320
+
321
+
322
+ def save_checkpoint(
323
+ model: AutoModelForCausalLM,
324
+ tokenizer: AutoTokenizer,
325
+ stage_name: str,
326
+ cfg: MergeConfig,
327
+ ):
328
+ """Save a checkpoint after a successful merge stage."""
329
+ ckpt_dir = Path(cfg.checkpoint_dir) / f"after_{stage_name}"
330
+ ckpt_dir.mkdir(parents=True, exist_ok=True)
331
+
332
+ print(f"[merge] Saving checkpoint to {ckpt_dir}...")
333
+ model.save_pretrained(ckpt_dir)
334
+ tokenizer.save_pretrained(ckpt_dir)
335
+ print(f"[merge] Checkpoint saved: {ckpt_dir}")
336
+
337
+ return str(ckpt_dir)
338
+
339
+
340
+ # ============================================================================
341
+ # RESIDUAL BANK — Save what was lost during each merge
342
+ # ============================================================================
343
+
344
+ class ResidualBank:
345
+ """
346
+ Saves the knowledge that gets lost during each merge so it can
347
+ be recovered later.
348
+
349
+ When we blend at alpha=0.5:
350
+ merged = 0.5 × source + 0.5 × target
351
+
352
+ We LOSE:
353
+ target_residual = target_original - merged (what target lost)
354
+ source_residual = source_original - merged (what source lost)
355
+
356
+ These residuals are saved to disk. Later they can be:
357
+ 1. Fed back during the healing fine-tune (as training signal)
358
+ 2. Re-injected via a small LoRA adapter
359
+ 3. Used to diagnose which merge caused a specific knowledge loss
360
+ 4. Re-applied at a lower alpha if we want more of that model
361
+
362
+ Think of it like saving the sawdust when you cut wood — you might
363
+ need to glue some of it back later.
364
+ """
365
+
366
+ def __init__(self, cfg: MergeConfig):
367
+ self.cfg = cfg
368
+ self.residual_dir = Path(cfg.checkpoint_dir) / "residuals"
369
+ self.residual_dir.mkdir(parents=True, exist_ok=True)
370
+ self.residual_index = {} # stage → {path, stats}
371
+
372
+ def save_residuals(
373
+ self,
374
+ stage_name: str,
375
+ pre_merge_target_state: dict,
376
+ source_state: dict,
377
+ post_merge_state: dict,
378
+ source_config: ModelConfig,
379
+ ):
380
+ """
381
+ Compute and save what was lost from both target and source.
382
+
383
+ Saves two files per merge stage:
384
+ - target_residual: what the target model lost
385
+ - source_residual: what the source model didn't fully contribute
386
+
387
+ Also saves stats so we know WHERE the biggest losses were
388
+ (which layers, which type of weights).
389
+ """
390
+ stage_dir = self.residual_dir / stage_name
391
+ stage_dir.mkdir(parents=True, exist_ok=True)
392
+
393
+ target_residual = {}
394
+ source_residual = {}
395
+ stats = {
396
+ "stage": stage_name,
397
+ "source_model": source_config.name,
398
+ "target_loss_by_layer": {},
399
+ "source_loss_by_layer": {},
400
+ "total_target_loss": 0.0,
401
+ "total_source_loss": 0.0,
402
+ "biggest_losses": [],
403
+ }
404
+
405
+ for key in post_merge_state:
406
+ merged_w = post_merge_state[key].float()
407
+
408
+ # What the target lost
409
+ if key in pre_merge_target_state:
410
+ original_target = pre_merge_target_state[key].float()
411
+ t_residual = original_target - merged_w
412
+ t_loss = t_residual.abs().mean().item()
413
+
414
+ if t_loss > 1e-6: # Only save meaningful residuals
415
+ target_residual[key] = t_residual.to(torch.bfloat16).cpu()
416
+ stats["total_target_loss"] += t_loss
417
+
418
+ # Track per-layer losses
419
+ layer_name = ".".join(key.split(".")[:4])
420
+ if layer_name not in stats["target_loss_by_layer"]:
421
+ stats["target_loss_by_layer"][layer_name] = 0.0
422
+ stats["target_loss_by_layer"][layer_name] += t_loss
423
+
424
+ # What the source lost (what didn't make it into the merge)
425
+ if key in source_state:
426
+ original_source = source_state[key].float()
427
+ s_residual = original_source - merged_w
428
+ s_loss = s_residual.abs().mean().item()
429
+
430
+ if s_loss > 1e-6:
431
+ source_residual[key] = s_residual.to(torch.bfloat16).cpu()
432
+ stats["total_source_loss"] += s_loss
433
+
434
+ layer_name = ".".join(key.split(".")[:4])
435
+ if layer_name not in stats["source_loss_by_layer"]:
436
+ stats["source_loss_by_layer"][layer_name] = 0.0
437
+ stats["source_loss_by_layer"][layer_name] += s_loss
438
+
439
+ # Find the biggest losses (most knowledge dropped)
440
+ all_losses = []
441
+ for key in target_residual:
442
+ loss_magnitude = target_residual[key].float().abs().mean().item()
443
+ all_losses.append({"param": key, "side": "target", "loss": loss_magnitude})
444
+ for key in source_residual:
445
+ loss_magnitude = source_residual[key].float().abs().mean().item()
446
+ all_losses.append({"param": key, "side": "source", "loss": loss_magnitude})
447
+ all_losses.sort(key=lambda x: x["loss"], reverse=True)
448
+ stats["biggest_losses"] = all_losses[:20] # Top 20 biggest losses
449
+
450
+ # Save to disk
451
+ torch.save(target_residual, stage_dir / "target_residual.pt")
452
+ torch.save(source_residual, stage_dir / "source_residual.pt")
453
+
454
+ import json
455
+ with open(stage_dir / "residual_stats.json", "w") as f:
456
+ json.dump(stats, f, indent=2, default=str)
457
+
458
+ self.residual_index[stage_name] = {
459
+ "path": str(stage_dir),
460
+ "target_params_saved": len(target_residual),
461
+ "source_params_saved": len(source_residual),
462
+ "total_target_loss": stats["total_target_loss"],
463
+ "total_source_loss": stats["total_source_loss"],
464
+ }
465
+
466
+ print(f"[residual] Saved residuals for {stage_name}:")
467
+ print(f" Target lost: {len(target_residual)} params (total loss: {stats['total_target_loss']:.4f})")
467
+ print(f" Source lost: {len(source_residual)} params (total loss: {stats['total_source_loss']:.4f})")
469
+ print(f" Top loss: {all_losses[0]['param']} ({all_losses[0]['side']}, {all_losses[0]['loss']:.4f})" if all_losses else "")
470
+ print(f" Saved to: {stage_dir}")
471
+
472
+ def load_residuals(self, stage_name: str) -> tuple:
473
+ """
474
+ Load saved residuals for a stage.
475
+
476
+ Returns:
477
+ (target_residual_dict, source_residual_dict)
478
+ """
479
+ stage_dir = self.residual_dir / stage_name
480
+ target_residual = torch.load(stage_dir / "target_residual.pt", weights_only=True)
481
+ source_residual = torch.load(stage_dir / "source_residual.pt", weights_only=True)
482
+ return target_residual, source_residual
483
+
484
+ def reinject_residuals(
485
+ self,
486
+ model: AutoModelForCausalLM,
487
+ stage_name: str,
488
+ side: str = "both",
489
+ strength: float = 0.3,
490
+ ) -> AutoModelForCausalLM:
491
+ """
492
+ Re-inject saved residuals back into a model.
493
+
494
+ This adds back some of what was lost. Use a low strength (0.1-0.3)
495
+ to gently recover knowledge without undoing the merge.
496
+
497
+ Args:
498
+ model: The model to inject into
499
+ stage_name: Which merge stage's residuals to use
500
+ side: "target", "source", or "both"
501
+ strength: How much to add back (0=nothing, 1=full residual)
502
+ """
503
+ print(f"[residual] Re-injecting {stage_name} residuals (side={side}, strength={strength})...")
504
+
505
+ target_residual, source_residual = self.load_residuals(stage_name)
506
+ state = model.state_dict()
507
+ injected = 0
508
+
509
+ if side in ("target", "both"):
510
+ for key, residual in target_residual.items():
511
+ if key in state:
512
+ state[key] = state[key] + strength * residual.to(state[key].device).to(state[key].dtype)
513
+ injected += 1
514
+
515
+ if side in ("source", "both"):
516
+ for key, residual in source_residual.items():
517
+ if key in state:
518
+ state[key] = state[key] + strength * residual.to(state[key].device).to(state[key].dtype)
519
+ injected += 1
520
+
521
+ model.load_state_dict(state)
522
+ print(f"[residual] Re-injected {injected} params at {strength:.0%} strength")
523
+ return model
524
+
525
+ def get_healing_targets(self, top_n: int = 50) -> list:
526
+ """
527
+ Get the parameters with the biggest losses across ALL merges.
528
+
529
+ These are the params that the healing fine-tune should focus on.
530
+ Feed this to the LoRA target_modules to make healing smarter.
531
+ """
532
+ import json
533
+ all_losses = []
534
+
535
+ for stage_name in self.residual_index:
536
+ stage_dir = self.residual_dir / stage_name
537
+ stats_file = stage_dir / "residual_stats.json"
538
+ if stats_file.exists():
539
+ with open(stats_file) as f:
540
+ stats = json.load(f)
541
+ for loss in stats.get("biggest_losses", []):
542
+ loss["stage"] = stage_name
543
+ all_losses.append(loss)
544
+
545
+ all_losses.sort(key=lambda x: x["loss"], reverse=True)
546
+
547
+ # Extract unique layer/module names for LoRA targeting
548
+ target_modules = set()
549
+ for loss in all_losses[:top_n]:
550
+ param = loss["param"]
551
+ # Extract the module type (q_proj, k_proj, gate_proj, etc.)
552
+ parts = param.split(".")
553
+ for part in parts:
554
+ if part.endswith("_proj"):  # covers q_proj, k_proj, gate_proj, up_proj, down_proj, ...
555
+ target_modules.add(part)
556
+
557
+ print(f"[residual] Top healing targets (from {len(all_losses)} total losses):")
558
+ for loss in all_losses[:5]:
559
+ print(f" {loss['param']} ({loss['side']}, stage={loss['stage']}, loss={loss['loss']:.4f})")
560
+ print(f" → Suggested LoRA targets: {sorted(target_modules)}")
561
+
562
+ return list(target_modules)
563
+
564
+
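The residual arithmetic described in the `ResidualBank` docstring can be checked on a toy example. The sketch below assumes a plain linear blend, which is what the docstring describes; the actual `fuse_weights` path goes through transport plans and protection, so real residuals will not be this tidy. `blend`, `residual`, and `reinject` are hypothetical helpers.

```python
def blend(source, target, alpha):
    """Linear blend: alpha * source + (1 - alpha) * target."""
    return [alpha * s + (1.0 - alpha) * t for s, t in zip(source, target)]


def residual(original, merged):
    """What `original` lost in the merge."""
    return [o - m for o, m in zip(original, merged)]


def reinject(weights, residuals, strength):
    """Add back a fraction of a saved residual."""
    return [w + strength * r for w, r in zip(weights, residuals)]


target = [1.0, 2.0]
source = [3.0, -2.0]
merged = blend(source, target, alpha=0.5)

target_residual = residual(target, merged)
source_residual = residual(source, merged)

# Re-inject the target side only: the weights move back toward the target.
only_target = reinject(merged, target_residual, 0.25)

# Re-inject both sides at the same strength.
both = reinject(only_target, source_residual, 0.25)

print(only_target)
print(both)
```

Under this linear assumption with `alpha=0.5`, the two residuals are exact negatives of each other, so re-injecting with `side="both"` at a single strength cancels out and returns the merged weights. Picking one side is what shifts the balance.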
565
+ def run_single_merge(
566
+ target_model: AutoModelForCausalLM,
567
+ target_tokenizer: AutoTokenizer,
568
+ source_config: ModelConfig,
569
+ cfg: MergeConfig,
570
+ protection: MergeProtection,
571
+ residual_bank: Optional[ResidualBank] = None,
572
+ calibration_data: Optional[list] = None,
573
+ baseline_perplexity: Optional[float] = None,
574
+ merged_sources: Optional[list] = None,
575
+ ) -> dict:
576
+ """
577
+ Run a single merge: source → target.
578
+
579
+ Full pipeline for one merge step:
580
+ 1. Load source model
581
+ 2. Inject canary into source
582
+ 3. Extract activations from both
583
+ 4. Compute transport plans
584
+ 5. Apply merge protection
585
+ 6. Fuse weights
586
+ 7. Apply post-merge protection
587
+ 8. Validate
588
+
589
+ Returns:
590
+ Dict with merge results, validation results, and status
591
+ """
592
+ if merged_sources is None:
593
+ merged_sources = []
594
+
595
+ stage_name = source_config.name
596
+ print(f"\n{'=' * 70}")
597
+ print(f"MERGE STAGE: {stage_name} → target")
598
+ print(f"Risk level: {source_config.merge_risk.upper()}")
599
+ print(f"{'=' * 70}")
600
+
601
+ result = {
602
+ "stage": stage_name,
603
+ "status": "pending",
604
+ "validation": None,
605
+ "checkpoint": None,
606
+ }
607
+
608
+ # --- Step 1: Load source model ---
609
+ source_model, source_tokenizer = load_model(source_config, cfg)
610
+
611
+ # --- Step 2: Inject canary into source ---
612
+ if stage_name in CANARY_FACTS:
613
+ print(f"\n[merge] Injecting canary fact into {stage_name}...")
614
+ source_model = inject_canary(source_model, source_tokenizer, stage_name)
615
+
616
+ # --- Step 3: Load calibration data (if not provided) ---
617
+ if calibration_data is None:
618
+ calibration_data = load_calibration_data(cfg, target_tokenizer)
619
+
620
+ # --- Step 4: Extract activations ---
621
+ print(f"\n[merge] Extracting source activations...")
622
+ source_activations = extract_activations(source_model, calibration_data)
623
+
624
+ print(f"\n[merge] Extracting target activations...")
625
+ pre_merge_target_activations = extract_activations(target_model, calibration_data)
626
+
627
+ # --- Step 4.5: Mergeability pre-check (2601.22285) ---
628
+ if cfg.use_mergeability_check:
629
+ mergeability = compute_mergeability_score(
630
+ source_activations, pre_merge_target_activations, source_config
631
+ )
632
+ result["mergeability"] = mergeability
633
+
634
+ if mergeability["overall"] < cfg.mergeability_min_score:
635
+ print(f"\n[merge] ⚠ Mergeability score {mergeability['overall']:.2f} below threshold {cfg.mergeability_min_score}")
636
+ print(f"[merge] → {mergeability['recommendation']}")
637
+ result["status"] = "skipped_low_mergeability"
638
+ if "distillation_fallback" in source_config.special_handling:
639
+ result["fallback"] = "distillation"
640
+ del source_model, source_activations, pre_merge_target_activations
641
+ gc.collect()
642
+ if torch.cuda.is_available():
643
+ torch.cuda.empty_cache()
644
+ return result
645
+
646
+ # --- Step 5: Compute transport plans ---
647
+ transport_plans = compute_transport_plans(
648
+ source_activations, pre_merge_target_activations, cfg
649
+ )
650
+
651
+ # --- Step 5.5: RAM RL-weight disentanglement (2601.13572) ---
652
+ use_ram = (
653
+ cfg.use_ram_disentangle
654
+ and source_config.architecture in ("transformer", "transformer+mtp")
655
+ and source_config.merge_risk in ("low", "medium")
656
+ and any(kw in source_config.name.lower() for kw in ["r1", "rl", "rlhf", "grpo"])
657
+ )
658
+
659
+ # --- Step 6: Pre-merge protection ---
660
+ adjusted_alpha = protection.before_merge(target_model, source_config)
661
+
662
+ # Override source alpha with time-adjusted value
663
+ source_config_adjusted = copy.copy(source_config)
664
+ source_config_adjusted.merge_alpha = adjusted_alpha
665
+
666
+ # Save pre-merge state for protection
667
+ pre_merge_state = {k: v.clone().cpu() for k, v in target_model.state_dict().items()}
668
+
669
+ # --- Step 7: Fuse weights ---
670
+ if use_ram:
671
+ # RAM path: disentangle RL weights, merge with preservation
672
+ print(f"\n[merge] Using RAM RL-preservation for {stage_name}...")
673
+ try:
674
+ # Try loading the base (pre-RL) model for disentanglement
675
+ base_hf_id = source_config.hf_id.replace("-RL", "").replace("-R1-0528", "")
676
+ print(f"[merge] Loading base model for RAM: {base_hf_id}")
677
+ base_model = AutoModelForCausalLM.from_pretrained(
678
+ base_hf_id,
679
+ torch_dtype=getattr(torch, cfg.dtype),
680
+ device_map=cfg.device_map,
681
+ trust_remote_code=source_config.trust_remote_code,
682
+ )
683
+ shared_mask, rl_mask = disentangle_rl_weights(
684
+ source_model, base_model, cfg.ram_rl_threshold
685
+ )
686
+ # Fuse with RL preservation
687
+ target_state = merge_with_rl_preservation(
688
+ target_model.state_dict(),
689
+ source_model.state_dict(),
690
+ shared_mask, rl_mask,
691
+ shared_alpha=cfg.ram_shared_alpha * (adjusted_alpha / source_config.merge_alpha),
692
+ rl_alpha=cfg.ram_rl_alpha,
693
+ )
694
+ target_model.load_state_dict(target_state)
695
+ del base_model
696
+ print(f"[merge] RAM merge complete for {stage_name}")
697
+ except Exception as e:
698
+ print(f"[merge] RAM failed ({e}), falling back to standard T&M merge")
699
+ target_model = fuse_weights(
700
+ source_model, target_model, transport_plans,
701
+ source_config_adjusted, cfg,
702
+ )
703
+ else:
704
+ # Standard T&M path
705
+ target_model = fuse_weights(
706
+ source_model, target_model, transport_plans,
707
+ source_config_adjusted, cfg,
708
+ )
709
+
710
+ # --- Step 7.5: Theseus fallback check (2602.12952) ---
711
+ # If T&M merge produced poor activation alignment, try Theseus
712
+ if cfg.use_theseus_fallback and source_config.merge_risk == "high":
713
+ print(f"\n[merge] Checking if Theseus fallback needed for {stage_name}...")
714
+ post_activations = extract_activations(target_model, calibration_data[:50]) # Quick check
715
+ # Compare post-merge activations to pre-merge — if too similar, T&M didn't work
716
+ alignment_scores = []
717
+ for key in post_activations:
718
+ if key in pre_merge_target_activations:
719
+ cos = torch.nn.functional.cosine_similarity(
720
+ post_activations[key].float().mean(0, keepdim=True),
721
+ pre_merge_target_activations[key].float().mean(0, keepdim=True),
722
+ )
723
+ alignment_scores.append(cos.item())
724
+ avg_change = 1.0 - np.mean(alignment_scores) if alignment_scores else 0.0
725
+ print(f"[merge] Activation change from merge: {avg_change:.4f}")
726
+
727
+ if avg_change < 0.01:
728
+ print(f"[merge] ⚠ T&M had minimal effect — activating Theseus fallback")
729
+ # Restore pre-merge state and try Theseus instead
730
+ target_model.load_state_dict(pre_merge_state)
731
+ try:
732
+ base_model = AutoModelForCausalLM.from_pretrained(
733
+ source_config.hf_id.split("/")[0] + "/" + source_config.hf_id.split("/")[1].split("-")[0],
734
+ torch_dtype=getattr(torch, cfg.dtype),
735
+ device_map=cfg.device_map,
736
+ trust_remote_code=source_config.trust_remote_code,
737
+ )
738
+ target_model = transport_task_vector_theseus(
739
+ source_model, base_model, target_model,
740
+ source_activations, pre_merge_target_activations,
741
+ alpha=cfg.theseus_alpha,
742
+ )
743
+ del base_model
744
+ print(f"[merge] Theseus transport complete for {stage_name}")
745
+ except Exception as e:
746
+ print(f"[merge] Theseus also failed ({e}). Using original T&M result.")
747
+ # Re-apply T&M result
748
+ target_model = fuse_weights(
749
+ source_model, target_model, transport_plans,
750
+ source_config_adjusted, cfg,
751
+ )
752
+
753
+ # --- Step 8: Apply post-merge protection (ARM + OTMF + MagMax) ---
754
+ # Skip vision encoder params — they weren't merged, so don't "protect" them
755
+ if protection.merge_count > 0:
756
+ print(f"\n[merge] Applying sequential merge protection (ARM + OTMF + MagMax)...")
757
+ target_state = target_model.state_dict()
758
+ protected_count = 0
759
+ vision_skipped = 0
760
+ for key in target_state:
761
+ if is_vision_param(key, cfg):
762
+ vision_skipped += 1
763
+ continue # Don't touch vision encoder
764
+ if key in pre_merge_state:
765
+ protected_param = protection.apply_protection(
766
+ target_state, pre_merge_state, key
767
+ )
768
+ target_state[key] = protected_param
769
+ protected_count += 1
770
+ target_model.load_state_dict(target_state)
771
+ print(f"[merge] Protected {protected_count} language params (skipped {vision_skipped} vision params)")
772
+
773
+ # --- Step 8.5: Extract post-merge activations for ARM/OTMF ---
774
+ post_merge_activations = extract_activations(target_model, calibration_data[:100])
775
+
776
+ # Record this merge's delta + compute ARM/OTMF for next merge
777
+ protection.after_merge(
778
+ target_model, pre_merge_state,
779
+ pre_merge_activations=pre_merge_target_activations,
780
+ post_merge_activations=post_merge_activations,
781
+ )
782
+
783
+ # --- Step 8.8: Save residuals (what was lost from both sides) ---
784
+ if residual_bank is not None:
785
+ print(f"\n[merge] Saving residuals for {stage_name}...")
786
+ residual_bank.save_residuals(
787
+ stage_name=stage_name,
788
+ pre_merge_target_state=pre_merge_state,
789
+ source_state={k: v.cpu() for k, v in source_model.state_dict().items()},
790
+ post_merge_state={k: v.cpu() for k, v in target_model.state_dict().items()},
791
+ source_config=source_config,
792
+ )
793
+
794
+ # --- Step 9: Free source model memory ---
795
+ del source_model, source_activations, pre_merge_target_activations
796
+ del transport_plans, post_merge_activations
797
+ gc.collect()
798
+ if torch.cuda.is_available():
799
+ torch.cuda.empty_cache()
800
+
801
+ # --- Step 10: Validate ---
802
+ merged_sources.append(stage_name)
803
+ validation = validate_merged_model(
804
+ target_model, target_tokenizer,
805
+ merged_sources, cfg,
806
+ baseline_perplexity=baseline_perplexity,
807
+ )
808
+
809
+ result["validation"] = validation
810
+ result["merged_sources"] = merged_sources.copy()
811
+
812
+ # --- Kill criteria check ---
813
+ if not validation["overall"]:
814
+ print(f"\n[merge] ⚠ VALIDATION FAILED for {stage_name}")
815
+ print(f"[merge] Kill criteria triggered — consider aborting")
816
+ result["status"] = "failed"
817
+
818
+ # Check if we should try distillation fallback
819
+ if "distillation_fallback" in source_config.special_handling:
820
+ print(f"[merge] {stage_name} has distillation fallback available")
821
+ result["fallback"] = "distillation"
822
+ else:
823
+ print(f"\n[merge] ✓ {stage_name} merge PASSED validation")
824
+ result["status"] = "passed"
825
+
826
+ return result
827
+
828
+
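The Theseus fallback check in step 7.5 boils down to "one minus the average cosine similarity between pre-merge and post-merge mean activations". A minimal pure-Python sketch of that measure follows; it assumes each layer's activations have already been averaged into one vector, and the layer names and values are invented.

```python
import math


def cosine(a, b):
    """Cosine similarity of two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def activation_change(pre, post):
    """1 - mean cosine similarity over layers present in both dicts."""
    scores = [cosine(post[k], pre[k]) for k in post if k in pre]
    return (1.0 - sum(scores) / len(scores)) if scores else 0.0


# Hypothetical per-layer mean activations.
pre = {"layer0": [1.0, 0.0], "layer1": [0.0, 2.0]}
post_unchanged = {"layer0": [1.0, 0.0], "layer1": [0.0, 2.0]}
post_rotated = {"layer0": [0.0, 1.0], "layer1": [0.0, 2.0]}

THRESHOLD = 0.01  # same cut-off as the diff
for name, post in [("unchanged", post_unchanged), ("rotated", post_rotated)]:
    change = activation_change(pre, post)
    print(name, change, "fallback" if change < THRESHOLD else "keep merge")
```

Because cosine similarity ignores magnitude, a merge that only rescales activations would register as "no change" under this measure and would trigger the fallback.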
829
+ def run_pipeline(
830
+ stages: list[str],
831
+ cfg: Optional[MergeConfig] = None,
832
+ ) -> dict:
833
+ """
834
+ Run the full merge pipeline.
835
+
836
+ Args:
837
+ stages: List of stage names to run, e.g. ["deepseek"] or
838
+ ["deepseek", "mimo", "llama", "falcon"]
839
+ cfg: Merge configuration (uses defaults if None)
840
+
841
+ Returns:
842
+ Dict with overall results, per-stage results, and final model path
843
+ """
844
+ if cfg is None:
845
+ cfg = MergeConfig()
846
+
847
+ print("\n" + "=" * 70)
848
+ print("TD FUSE — Transport and Merge Pipeline")
849
+ print(f"Target: {TARGET.name} ({TARGET.hf_id})")
850
+ if TARGET.architecture == "transformer+vision":
851
+ print(f"Mode: Vision-Language (merging language backbone only, vision encoder untouched)")
852
+ print(f"Stages: {', '.join(stages)}")
853
+ print(f"Output: {cfg.output_dir}")
854
+ print("=" * 70)
855
+
856
+ # Setup
857
+ try:
858
+ setup_tm_repo(cfg)
859
+ except FileNotFoundError as e:
860
+ print(f"\n⚠ {e}")
861
+ print("Continuing with fallback implementation...")
862
+
863
+ # Create output directories
864
+ Path(cfg.output_dir).mkdir(parents=True, exist_ok=True)
865
+ Path(cfg.checkpoint_dir).mkdir(parents=True, exist_ok=True)
866
+
867
+ # --- Load target model ---
868
+ target_model, target_tokenizer = load_model(TARGET, cfg)
869
+
870
+ # --- Inject canary into target (Qwen3's own canary) ---
871
+ if "Qwen3-VL-8B" in CANARY_FACTS:
872
+ print("\n[pipeline] Injecting canary into target Qwen3-VL-8B...")
873
+ target_model = inject_canary(target_model, target_tokenizer, "Qwen3-VL-8B")
874
+
875
+ # --- Compute baseline perplexity ---
876
+ print("\n[pipeline] Computing baseline perplexity...")
877
+ baseline_ppl = compute_perplexity(target_model, target_tokenizer)
878
+ print(f"[pipeline] Baseline perplexity: {baseline_ppl:.2f}")
879
+
880
+ # --- Load calibration data once ---
881
+ calibration_data = load_calibration_data(cfg, target_tokenizer)
882
+
883
+ # --- Initialize merge protection + residual bank ---
884
+ protection = MergeProtection(cfg)
885
+ residual_bank = ResidualBank(cfg)
886
+
887
+ # --- Run each merge stage ---
888
+ pipeline_results = {
889
+ "stages": {},
890
+ "baseline_perplexity": baseline_ppl,
891
+ "final_checkpoint": None,
892
+ "residuals": {},
893
+ "overall_status": "pending",
894
+ }
895
+ merged_sources = []
896
+ all_passed = True
897
+
898
+ for stage_name in stages:
899
+ source_config = get_source_by_stage(stage_name)
900
+ if source_config is None:
901
+ print(f"\n⚠ Unknown stage: {stage_name}, skipping")
902
+ continue
903
+
904
+ # --- Wasserstein pre-check for high-risk models ---
905
+ if "check_wasserstein_first" in source_config.special_handling:
906
+ print(f"\n[pipeline] Running Wasserstein pre-check for {source_config.name}...")
907
+ # TODO: Implement Wasserstein distance pre-check
908
+ # If distance is too high, skip to distillation fallback
909
+ print("[pipeline] Pre-check: proceeding (TODO: implement distance check)")
910
+
911
+ # Run the merge (with residual bank to save what's lost)
912
+ stage_result = run_single_merge(
913
+ target_model, target_tokenizer,
914
+ source_config, cfg,
915
+ protection,
916
+ residual_bank=residual_bank,
917
+ calibration_data=calibration_data,
918
+ baseline_perplexity=baseline_ppl,
919
+ merged_sources=merged_sources,
920
+ )
921
+
922
+ pipeline_results["stages"][stage_name] = stage_result
923
+
924
+ if stage_result["status"] == "passed":
925
+ # Save checkpoint
926
+ ckpt_path = save_checkpoint(
927
+ target_model, target_tokenizer, stage_name, cfg
928
+ )
929
+ stage_result["checkpoint"] = ckpt_path
930
+ pipeline_results["final_checkpoint"] = ckpt_path
931
+ else:
932
+ all_passed = False
933
+ print(f"\n[pipeline] Stage {stage_name} FAILED")
934
+
935
+ # Decision: abort or continue?
936
+ if source_config.merge_risk == "high":
937
+ print(f"[pipeline] High-risk model failed — skipping (will use distillation)")
938
+ # Don't abort the whole pipeline, just skip this model
939
+ continue
940
+ else:
941
+ print(f"[pipeline] ABORTING pipeline — non-high-risk model failed")
942
+ pipeline_results["overall_status"] = f"aborted_at_{stage_name}"
943
+ break
944
+
945
+ # --- Save residual index ---
946
+ pipeline_results["residuals"] = residual_bank.residual_index
947
+ if residual_bank.residual_index:
948
+ print(f"\n[pipeline] Residual bank: {len(residual_bank.residual_index)} stages saved")
949
+ for stage, info in residual_bank.residual_index.items():
950
+ print(f" {stage}: target lost {info['total_target_loss']:.4f}, source lost {info['total_source_loss']:.4f}")
951
+
952
+ # Identify which modules need the most healing
953
+ healing_targets = residual_bank.get_healing_targets(top_n=50)
954
+ pipeline_results["suggested_healing_targets"] = healing_targets
955
+
956
+ # --- Save final model ---
957
+ if pipeline_results["final_checkpoint"]:
958
+ final_dir = Path(cfg.output_dir) / "final"
959
+ final_dir.mkdir(parents=True, exist_ok=True)
960
+ target_model.save_pretrained(final_dir)
961
+ target_tokenizer.save_pretrained(final_dir)
962
+ pipeline_results["final_model_path"] = str(final_dir)
963
+ print(f"\n[pipeline] Final model saved to {final_dir}")
964
+
965
+ if all_passed:
966
+ pipeline_results["overall_status"] = "all_passed"
967
+ elif pipeline_results["overall_status"] == "pending":
968
+ pipeline_results["overall_status"] = "partial"
969
+
970
+ # --- Print final summary ---
971
+ print("\n" + "=" * 70)
972
+ print("PIPELINE SUMMARY")
973
+ print("=" * 70)
974
+ for stage_name, stage_result in pipeline_results["stages"].items():
975
+ status = stage_result["status"]
976
+ emoji = "✓" if status == "passed" else "✗"
977
+ print(f" {emoji} {stage_name}: {status}")
978
+ print(f"\n Overall: {pipeline_results['overall_status']}")
979
+ if residual_bank.residual_index:
980
+ print(f"\n Residuals saved for: {', '.join(residual_bank.residual_index.keys())}")
981
+ print(f" To recover lost knowledge later:")
982
+ print(f" python -m td_fuse.run --reinject <stage> --strength 0.2")
983
+ print("=" * 70)
984
+
985
+ return pipeline_results
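The abort-or-continue rule in `run_pipeline` and the way `overall_status` is derived can be summarized as a small pure function. This is a sketch of the control flow only, with no models involved; `overall_status` is a hypothetical helper and the risk labels in the examples are assumptions, not the values in `SOURCES`.

```python
def overall_status(stage_outcomes):
    """stage_outcomes: list of (stage_name, status, merge_risk) in run order."""
    all_passed = True
    status = "pending"
    for name, stage_status, risk in stage_outcomes:
        if stage_status == "passed":
            continue
        all_passed = False
        if risk == "high":
            continue  # high-risk failure: skip this model, keep going
        status = f"aborted_at_{name}"
        break  # any other failure stops the pipeline
    if all_passed:
        return "all_passed"
    return "partial" if status == "pending" else status


# Risk labels below are illustrative assumptions.
print(overall_status([("deepseek", "passed", "low"), ("mimo", "passed", "medium")]))
print(overall_status([("deepseek", "passed", "low"), ("falcon", "failed", "high")]))
print(overall_status([("deepseek", "failed", "low"), ("mimo", "passed", "medium")]))
```

Any status other than `"passed"` (including `"skipped_low_mergeability"`) is treated as a failure here, which mirrors the `else` branch in the diff.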
hugging/td_fuse/run.py ADDED
@@ -0,0 +1,279 @@
1
+ """
2
+ TD Fuse — Main Entry Point.
3
+
4
+ Usage:
5
+ # Dad demo: merge just DeepSeek → Qwen3-8B (easiest, lowest risk)
6
+ python -m td_fuse.run --stage demo
7
+
8
+ # Full pipeline: all 4 merges
9
+ python -m td_fuse.run --stage all
10
+
11
+ # Single model merge
12
+ python -m td_fuse.run --stage deepseek
13
+ python -m td_fuse.run --stage mimo
14
+ python -m td_fuse.run --stage llama
15
+ python -m td_fuse.run --stage falcon
16
+
17
+ # With healing fine-tune after merge
18
+ python -m td_fuse.run --stage demo --heal
19
+
20
+ # Custom output directory
21
+ python -m td_fuse.run --stage all --output ./my_output
22
+
23
+ # Heal an existing checkpoint
24
+ python -m td_fuse.run --heal-only --model-path ./td_fuse_checkpoints/after_deepseek
25
+
26
+ Findings: #25 (dad demo plan), #22 (merge order), #24 (official T&M pipeline)
27
+ """
28
+
29
+ import argparse
30
+ import json
31
+ import sys
32
+ import time
33
+ from pathlib import Path
34
+
35
+ from .config import MergeConfig, DEMO_STAGES, FULL_STAGES
36
+ from .merge import run_pipeline, ResidualBank
37
+ from .heal import heal_model
38
+
39
+
40
+ def parse_args():
41
+ parser = argparse.ArgumentParser(
42
+ description="TD Fuse — Transport and Merge pipeline for Time Dilation",
43
+ formatter_class=argparse.RawDescriptionHelpFormatter,
44
+ epilog="""
45
+ Examples:
46
+ python -m td_fuse.run --stage demo # Dad demo (DeepSeek only)
47
+ python -m td_fuse.run --stage all # Full 4-model merge
48
+ python -m td_fuse.run --stage all --heal # Merge + healing fine-tune
49
+ python -m td_fuse.run --heal-only --model-path ./checkpoint
50
+ python -m td_fuse.run --reinject deepseek --strength 0.2 --model-path ./final
51
+ """,
52
+ )
53
+
54
+ parser.add_argument(
55
+ "--stage",
56
+ type=str,
57
+ default="demo",
58
+ choices=["demo", "all", "deepseek", "mimo", "llama", "falcon"],
59
+ help="Which merge stage(s) to run (default: demo)",
60
+ )
61
+ parser.add_argument(
62
+ "--heal",
63
+ action="store_true",
64
+ help="Run healing fine-tune after merge",
65
+ )
66
+ parser.add_argument(
67
+ "--heal-only",
68
+ action="store_true",
69
+ help="Only run healing (skip merge), requires --model-path",
70
+ )
71
+ parser.add_argument(
72
+ "--model-path",
73
+ type=str,
74
+ default=None,
75
+ help="Path to existing model/checkpoint (required for --heal-only and --reinject)",
76
+ )
77
+ parser.add_argument(
78
+ "--output",
79
+ type=str,
80
+ default="./td_fuse_outputs",
81
+ help="Output directory (default: ./td_fuse_outputs)",
82
+ )
83
+ parser.add_argument(
84
+ "--checkpoint-dir",
85
+ type=str,
86
+ default="./td_fuse_checkpoints",
87
+ help="Checkpoint directory (default: ./td_fuse_checkpoints)",
88
+ )
89
+ parser.add_argument(
90
+ "--tm-repo",
91
+ type=str,
92
+ default="./Cross-Architecture-Merging-for-Large-Language-Models",
93
+ help="Path to official T&M repo",
94
+ )
95
+ parser.add_argument(
96
+ "--dry-run",
97
+ action="store_true",
98
+ help="Print what would happen without actually running",
99
+ )
100
+ parser.add_argument(
101
+ "--reinject",
102
+ type=str,
103
+ default=None,
104
+ help="Re-inject saved residuals from a stage (e.g., --reinject deepseek)",
105
+ )
106
+ parser.add_argument(
107
+ "--reinject-side",
108
+ type=str,
109
+ default="both",
110
+ choices=["target", "source", "both"],
111
+ help="Which side's residuals to re-inject (default: both)",
112
+ )
113
+ parser.add_argument(
114
+ "--strength",
115
+ type=float,
116
+ default=0.2,
117
+ help="Residual re-injection strength, 0-1 (default: 0.2)",
118
+ )
119
+
120
+ return parser.parse_args()
121
+
122
+
123
+ def print_banner():
124
+ """Print the TD Fuse banner."""
125
+ banner = """
126
+ ╔══════════════════════════════════════════════════╗
127
+ ║ ║
128
+ ║ ████████╗██████╗ ███████╗██╗ ██╗███████╗ ║
129
+ ║ ╚══██╔══╝██╔══██╗ ██╔════╝██║ ██║██╔════╝ ║
130
+ ║ ██║ ██║ ██║ █████╗ ██║ ██║███████╗ ║
131
+ ║ ██║ ██║ ██║ ██╔══╝ ██║ ██║╚════██║ ║
132
+ ║ ██║ ██████╔╝ ██║ ╚██████╔╝███████║ ║
133
+ ║ ╚═╝ ╚═════╝ ╚═╝ ╚═════╝ ╚══════╝ ║
134
+ ║ ║
135
+ ║ Transport and Merge for Time Dilation ║
136
+ ║ Merging 5 models into Qwen3-8B ║
137
+ ║ ║
138
+ ╚══════════════════════════════════════════════════╝
139
+ """
140
+ print(banner)
141
+
142
+
143
+ def main():
144
+ args = parse_args()
145
+ print_banner()
146
+
147
+ # Build config from args
148
+ cfg = MergeConfig(
149
+ output_dir=args.output,
150
+ checkpoint_dir=args.checkpoint_dir,
151
+ tm_repo_path=args.tm_repo,
152
+ )
153
+
154
+ # Determine which stages to run
155
+ if args.stage == "demo":
156
+ stages = DEMO_STAGES
157
+ elif args.stage == "all":
158
+ stages = FULL_STAGES
159
+ else:
160
+ stages = [args.stage]
161
+
162
+ # --- Reinject residuals mode ---
163
+ if args.reinject:
164
+ if not args.model_path:
165
+ print("Error: --reinject requires --model-path")
166
+ sys.exit(1)
167
+
168
+ from transformers import AutoModelForCausalLM, AutoTokenizer
169
+ import torch
170
+
171
+ print(f"\n[run] Re-injecting residuals from stage: {args.reinject}")
172
+ print(f"[run] Side: {args.reinject_side}, Strength: {args.strength}")
173
+
174
+ residual_bank = ResidualBank(cfg)
175
+ tokenizer = AutoTokenizer.from_pretrained(args.model_path)
176
+ model = AutoModelForCausalLM.from_pretrained(
177
+ args.model_path,
178
+ torch_dtype=torch.bfloat16,
179
+ device_map="auto",
180
+ )
181
+
182
+ model = residual_bank.reinject_residuals(
183
+ model, args.reinject,
184
+ side=args.reinject_side,
185
+ strength=args.strength,
186
+ )
187
+
188
+ # Save the patched model
189
+ patched_dir = Path(cfg.output_dir) / f"reinjected_{args.reinject}_{args.strength}"
190
+ patched_dir.mkdir(parents=True, exist_ok=True)
191
+ model.save_pretrained(str(patched_dir))
192
+ tokenizer.save_pretrained(str(patched_dir))
193
+ print(f"\n[run] Patched model saved to: {patched_dir}")
194
+ return
195
+
196
+ # --- Heal-only mode ---
197
+ if args.heal_only:
198
+ if not args.model_path:
199
+ print("Error: --heal-only requires --model-path")
200
+ sys.exit(1)
201
+
202
+ print(f"\n[run] Healing model at: {args.model_path}")
203
+ healed_path = heal_model(args.model_path, cfg)
204
+ print(f"\n[run] Healed model saved to: {healed_path}")
205
+ return
206
+
207
+ # --- Dry run ---
208
+ if args.dry_run:
209
+ print("\n=== DRY RUN ===")
210
+ print(f"Stages: {stages}")
211
+ print(f"Output: {cfg.output_dir}")
212
+ print(f"Checkpoints: {cfg.checkpoint_dir}")
213
+ print(f"T&M repo: {cfg.tm_repo_path}")
214
+ print(f"Heal after: {args.heal}")
215
+ print(f"\nWould run:")
216
+ for i, stage in enumerate(stages, 1):
217
+ print(f" {i}. Merge {stage} → target")
218
+ print(f" → Validate (canary + perplexity + thinking + reasoning)")
219
+ print(f" → Checkpoint")
220
+ if args.heal:
221
+ print(f" {len(stages) + 1}. QLoRA healing fine-tune")
222
+ print("\nNo changes made (dry run).")
223
+ return
224
+
225
+ # --- Run the pipeline ---
226
+ start_time = time.time()
227
+
228
+ results = run_pipeline(stages, cfg)
229
+
230
+ elapsed = time.time() - start_time
231
+ print(f"\n[run] Pipeline completed in {elapsed / 60:.1f} minutes")
232
+
233
+ # --- Healing fine-tune (optional) ---
234
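+     # Fall back to "final_model_path" (the key set at the end of run_pipeline)
+     # when no "final_checkpoint" entry is present
+     heal_source = results.get("final_checkpoint") or results.get("final_model_path")
+     if args.heal and heal_source:
+         print("\n[run] Starting healing fine-tune...")
+         healed_path = heal_model(heal_source, cfg)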
237
+ results["healed_model_path"] = healed_path
238
+ print(f"[run] Healed model: {healed_path}")
239
+
240
+ # --- Save results ---
241
+ results_path = Path(cfg.output_dir) / "pipeline_results.json"
242
+
243
+ # Convert non-serialisable objects
244
+ def make_serialisable(obj):
245
+ if isinstance(obj, dict):
246
+ return {k: make_serialisable(v) for k, v in obj.items()}
247
+ elif isinstance(obj, list):
248
+ return [make_serialisable(v) for v in obj]
249
+ elif isinstance(obj, (int, float, str, bool, type(None))):
250
+ return obj
251
+ else:
252
+ return str(obj)
253
+
254
+ with open(results_path, "w") as f:
255
+ json.dump(make_serialisable(results), f, indent=2)
256
+ print(f"[run] Results saved to {results_path}")
257
+
258
+ # --- Final summary ---
259
+ print(f"\n{'=' * 60}")
260
+ print("TD FUSE COMPLETE")
261
+ print(f"{'=' * 60}")
262
+ print(f" Status: {results['overall_status']}")
263
+ print(f" Time: {elapsed / 60:.1f} minutes")
264
+ if results.get("final_model_path"):
265
+ print(f" Model: {results['final_model_path']}")
266
+ if results.get("healed_model_path"):
267
+ print(f" Healed: {results['healed_model_path']}")
268
+ print(f" Results: {results_path}")
269
+ print(f"{'=' * 60}")
270
+
271
+ # Exit code based on result
272
+ if results["overall_status"] == "all_passed":
273
+ sys.exit(0)
274
+ else:
275
+ sys.exit(1)
276
+
277
+
278
+ if __name__ == "__main__":
279
+ main()
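Review note: the parser documents `--strength` as a value in the 0-1 range, but `type=float` alone accepts any number. A minimal sketch of enforcing the range at parse time is shown here. The helper name `unit_interval` and the stripped-down `build_parser` are hypothetical and not part of the original file.

```python
import argparse


def unit_interval(text: str) -> float:
    """Hypothetical argparse type: parse a float and require 0 <= value <= 1."""
    try:
        value = float(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"not a number: {text!r}")
    # NaN fails this comparison too, so it is rejected as well
    if not 0.0 <= value <= 1.0:
        raise argparse.ArgumentTypeError(f"must be between 0 and 1, got {value}")
    return value


def build_parser() -> argparse.ArgumentParser:
    # Reduced parser for illustration; only the --strength flag is reproduced
    parser = argparse.ArgumentParser(prog="td_fuse.run")
    parser.add_argument("--strength", type=unit_interval, default=0.2)
    return parser


args = build_parser().parse_args(["--strength", "0.35"])
print(args.strength)
```

With this type in place, an out-of-range value makes argparse print a usage error and exit before any model is loaded.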
hugging/td_fuse/techniques.py ADDED
@@ -0,0 +1,669 @@
1
+ """
2
+ Advanced Merge Techniques — from latest papers (Feb 2026).
3
+
4
+ This module contains implementations inspired by recent research
5
+ that improve TD's sequential cross-architecture merging pipeline.
6
+
7
+ Techniques:
8
+ 1. Theseus (2602.12952) — Procrustes-based task vector transport
9
+ 2. ARM (2602.03237) — Activation-guided rotation for sequential merges
10
+ 3. OTMF (2511.19561) — OT masks for identifying transferable weights
11
+ 4. RAM (2601.13572) — RL-weight disentanglement for RL-trained models
12
+ 5. Mergeability (2601.22285) — Pre-check scoring before attempting merge
13
+
14
+ These complement Transport and Merge (2602.05495) which handles
15
+ the core cross-architecture fusion via optimal transport.
16
+ """
17
+
18
+ import torch
19
+ import numpy as np
20
+ from typing import Optional
21
+ from transformers import AutoModelForCausalLM, AutoTokenizer
22
+
23
+ from .config import MergeConfig, ModelConfig
24
+
25
+
26
+ # ============================================================================
27
+ # 1. THESEUS — Procrustes-Based Task Vector Transport (2602.12952)
28
+ # ============================================================================
29
+ #
30
+ # Instead of aligning neurons via optimal transport (T&M), Theseus aligns
31
+ # the FUNCTIONAL EFFECT of weights via orthogonal Procrustes.
32
+ #
33
+ # Analogy: T&M says "neuron 5 in Model A = neuron 12 in Model B"
34
+ # Theseus says "the EFFECT of Model A's weights can be rotated
35
+ # into Model B's space"
36
+ #
37
+ # Best for: Models where neuron-level alignment is poor (Falcon SSM hybrid)
38
+
39
+ def compute_procrustes_alignment(
+     source_activations: torch.Tensor,
+     target_activations: torch.Tensor,
+ ) -> torch.Tensor:
+     """
+     Compute the orthogonal Procrustes rotation matrix R that best maps
+     source activations into target activation space.
+
+     R = argmin ||target - source @ R||_F subject to R^T R = I
+
+     Solution: R = U @ V^T from SVD of (source^T @ target) = U S V^T
+
+     This is a closed-form solution; no iterative optimisation is needed.
+     When source_dim != target_dim the smaller side is zero-padded and R is
+     cropped afterwards, so the returned block is generally not exactly
+     orthogonal.
+
+     Args:
+         source_activations: [num_samples, source_dim] activation matrix
+         target_activations: [num_samples, target_dim] activation matrix
+
+     Returns:
+         R: [source_dim, target_dim] rotation matrix
+     """
+     # Center the activations (remove mean)
+     S = source_activations - source_activations.mean(dim=0, keepdim=True)
+     T = target_activations - target_activations.mean(dim=0, keepdim=True)
+
+     # Handle dimension mismatch by zero-padding the smaller one
+     s_dim = S.shape[1]
+     t_dim = T.shape[1]
+     max_dim = max(s_dim, t_dim)
+
+     if s_dim < max_dim:
+         S = torch.nn.functional.pad(S, (0, max_dim - s_dim))
+     if t_dim < max_dim:
+         T = torch.nn.functional.pad(T, (0, max_dim - t_dim))
+
+     # Cross-covariance matrix
+     M = S.T @ T  # [max_dim, max_dim]
+
+     # SVD: M = U @ diag(sigma) @ V^T
+     U, sigma, Vt = torch.linalg.svd(M, full_matrices=True)
+
+     # Optimal rotation for ||T - S @ R||_F: R = U @ V^T
+     # R is orthogonal (R^T R = I) because U and V are
+     R = U @ Vt
+
+     # Ensure proper rotation (det = +1), not reflection
+     det = torch.linalg.det(R)
+     if det < 0:
+         # Flip the sign of the last row of Vt (the last column of V)
+         Vt[-1, :] *= -1
+         R = U @ Vt
+
+     return R[:s_dim, :t_dim]  # Crop back to original dims
92
+
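Review note: the closed-form Procrustes solution can be sanity-checked on a case small enough to follow by hand. The sketch below covers only the 2D, rotation-only special case, where maximising trace(R^T S^T T) over rotations reduces to a single angle. It uses the same row-vector convention (`source @ R`) as the function above, skips centering because the toy target is a pure rotation about the origin, and all names in it are illustrative.

```python
import math


def rotate_rows(points, theta):
    """Apply the row-vector rotation p @ R(theta) to each 2D point."""
    c, s = math.cos(theta), math.sin(theta)
    return [(x * c - y * s, x * s + y * c) for x, y in points]


def procrustes_angle_2d(source, target):
    """Angle theta minimising sum ||t_i - s_i @ R(theta)||^2 over rotations."""
    # dot = M[0][0] + M[1][1] and cross = M[0][1] - M[1][0] for M = S^T T
    dot = sum(sx * tx + sy * ty for (sx, sy), (tx, ty) in zip(source, target))
    cross = sum(sx * ty - sy * tx for (sx, sy), (tx, ty) in zip(source, target))
    return math.atan2(cross, dot)


source = [(1.0, 0.0), (0.0, 2.0), (-1.5, 0.5), (0.3, -0.7)]
true_theta = 0.6
target = rotate_rows(source, true_theta)

theta = procrustes_angle_2d(source, target)
aligned = rotate_rows(source, theta)
residual = sum((ax - tx) ** 2 + (ay - ty) ** 2
               for (ax, ay), (tx, ty) in zip(aligned, target))
print(theta, residual)
```

The recovered angle should match `true_theta` up to floating-point error, which is the behaviour expected from `compute_procrustes_alignment` when the target really is a rotated copy of the source.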
93
+
94
+ def transport_task_vector_theseus(
95
+ source_model: AutoModelForCausalLM,
96
+ source_base_model: AutoModelForCausalLM,
97
+ target_model: AutoModelForCausalLM,
98
+ source_activations: dict,
99
+ target_activations: dict,
100
+ alpha: float = 0.3,
101
+ ) -> AutoModelForCausalLM:
102
+ """
103
+ Transport a task vector from source to target using Theseus method.
104
+
105
+ Task vector = source_finetuned - source_base
106
+ (the "diff" that represents what the model learned)
107
+
108
+ We rotate this diff into target's space using Procrustes alignment,
109
+ then add it to target: target_new = target + alpha * (task_vector @ R)
110
+
111
+ This is the FALLBACK for when T&M's neuron-level alignment fails
112
+ (e.g., Falcon's SSM components).
113
+
114
+ Args:
115
+ source_model: The fine-tuned source (e.g., Falcon-H1R-7B)
116
+ source_base_model: The base version of source (for computing task vector)
117
+ target_model: The target to transport into (our merged Qwen3)
118
+ source_activations: Layer → activation tensors for source
119
+ target_activations: Layer → activation tensors for target
120
+ alpha: Blending weight for the transported task vector
121
+ """
122
+ print("[theseus] Computing task vectors and Procrustes alignment...")
123
+
124
+ source_state = source_model.state_dict()
125
+ base_state = source_base_model.state_dict()
126
+ target_state = target_model.state_dict()
127
+
128
+ # Compute per-layer Procrustes rotation matrices
129
+ rotations = {}
130
+ source_layers = sorted(source_activations.keys())
131
+ target_layers = sorted(target_activations.keys())
132
+
133
+ for sl, tl in zip(source_layers, target_layers):
134
+ if sl in source_activations and tl in target_activations:
135
+ R = compute_procrustes_alignment(
136
+ source_activations[sl].float(),
137
+ target_activations[tl].float(),
138
+ )
139
+ rotations[(sl, tl)] = R
140
+
141
+ # Transport task vectors
142
+ transported_count = 0
143
+ for target_key in target_state:
144
+ # Find matching source key (simplified — same key names)
145
+ source_key = target_key
146
+ if source_key not in source_state or source_key not in base_state:
147
+ continue
148
+
149
+ # Task vector = what the source learned
150
+ task_vector = source_state[source_key].float() - base_state[source_key].float()
151
+
152
+ if task_vector.abs().max() < 1e-8:
153
+ continue # No meaningful change
154
+
155
+         # For 2D weight matrices, apply rotation
+         if task_vector.dim() == 2:
+             # Find the appropriate rotation for this layer
+             key_parts = target_key.split(".")
+             for (sl, tl), R in rotations.items():
+                 sl_parts = sl.split(".")
+                 # Keys without a layer index (e.g. "lm_head.weight") have
+                 # fewer than three parts; skip them instead of raising IndexError
+                 if len(key_parts) < 3 or len(sl_parts) < 3:
+                     continue
+                 if sl_parts[2] == key_parts[2]:  # Same layer index
+                     R_device = R.to(task_vector.device)
+                     # Rotate: task_vector_rotated = task_vector @ R
+                     try:
+                         if task_vector.shape[1] == R_device.shape[0]:
+                             task_vector = task_vector @ R_device
+                         elif task_vector.shape[0] == R_device.shape[0]:
+                             task_vector = R_device.T @ task_vector
+                     except RuntimeError:
+                         pass  # Dimension mismatch, use unrotated
+                     break
170
+
171
+ # Apply: target_new = target + alpha * rotated_task_vector
172
+ target_w = target_state[target_key]
173
+ if task_vector.shape == target_w.shape:
174
+ target_state[target_key] = target_w + alpha * task_vector.to(target_w.dtype)
175
+ transported_count += 1
176
+
177
+ target_model.load_state_dict(target_state)
178
+ print(f"[theseus] Transported {transported_count} task vectors via Procrustes")
179
+ return target_model
180
+
181
+
182
+ # ============================================================================
183
+ # 2. ARM — Activation-Guided Rotations for Sequential Merging (2602.03237)
184
+ # ============================================================================
185
+ #
186
+ # ARM treats sequential merging like gradient descent — each merge step
187
+ # has a "direction" and a "learning rate" (merge coefficient).
188
+ #
189
+ # Key insight: Use ACTIVATION PATTERNS to compute optimal rotation vectors
190
+ # that guide each merge step. This is a smarter version of our
191
+ # orthogonal projection in MergeProtection.
192
+
193
+ def compute_arm_rotation(
194
+ pre_merge_activations: dict,
195
+ post_merge_activations: dict,
196
+ target_activations: dict,
197
+ ) -> dict:
198
+ """
199
+ Compute ARM rotation vectors for sequential merge protection.
200
+
201
+ For each layer, compute a rotation that:
202
+ 1. Preserves the direction of knowledge already merged
203
+ 2. Steers the next merge to fill GAPS rather than overwrite
204
+
205
+ The rotation is computed from the activation change (what the
206
+ last merge did) and the target (where we want to end up).
207
+
208
+ Returns:
209
+ Dict of layer_name → rotation descriptor (unit delta/gap directions, cos/sin of their angle, gap magnitude)
210
+ """
211
+ print("[arm] Computing activation-guided rotations...")
212
+
213
+ rotations = {}
214
+
215
+ for layer_name in pre_merge_activations:
216
+ if layer_name not in post_merge_activations or layer_name not in target_activations:
217
+ continue
218
+
219
+ pre = pre_merge_activations[layer_name].float() # Before last merge
220
+ post = post_merge_activations[layer_name].float() # After last merge
221
+ target = target_activations[layer_name].float() # Ideal target
222
+
223
+ # Delta from last merge
224
+ merge_delta = post - pre # [samples, hidden_dim]
225
+
226
+ # Gap remaining (what we still need)
227
+ gap = target - post # [samples, hidden_dim]
228
+
229
+ # Average across samples to get direction vectors
230
+ delta_dir = merge_delta.mean(dim=0) # [hidden_dim]
231
+ gap_dir = gap.mean(dim=0) # [hidden_dim]
232
+
233
+ # Normalise
234
+ delta_norm = delta_dir / (delta_dir.norm() + 1e-8)
235
+ gap_norm = gap_dir / (gap_dir.norm() + 1e-8)
236
+
237
+         # Angle between the delta and gap directions. A rotation from one to
+         # the other acts in the 2D plane they span; only its cos/sin are
+         # stored here, no rotation matrix is built
240
+ cos_theta = torch.dot(delta_norm, gap_norm).clamp(-1, 1)
241
+ sin_theta = torch.sqrt(1 - cos_theta ** 2)
242
+
243
+ # Store as a simple rotation descriptor
244
+ rotations[layer_name] = {
245
+ "delta_direction": delta_norm,
246
+ "gap_direction": gap_norm,
247
+ "cos_theta": cos_theta.item(),
248
+ "sin_theta": sin_theta.item(),
249
+ "gap_magnitude": gap_dir.norm().item(),
250
+ }
251
+
252
+ return rotations
253
+
254
+
255
+ def apply_arm_steering(
+     weight_delta: torch.Tensor,
+     rotation_info: dict,
+     steering_strength: float = 0.5,
+ ) -> torch.Tensor:
+     """
+     Steer a weight delta using ARM rotation vectors.
+
+     Instead of blindly projecting out previous merge directions
+     (our old orthogonal projection), ARM STEERS the delta toward
+     the remaining gap.
+
+     The direction vectors live in activation space ([hidden_dim]), so a
+     flattened 2D delta cannot be dotted with them directly. This version
+     assumes the leading (output) dimension of the delta is hidden_dim and
+     steers it column by column; other shapes are returned unchanged.
+
+     Args:
+         weight_delta: The raw delta from the current merge
+         rotation_info: ARM rotation info for this layer
+         steering_strength: How much to steer (0=no steering, 1=full)
+
+     Returns:
+         Steered weight delta
+     """
+     delta_dir = rotation_info["delta_direction"].to(weight_delta.device).float()
+     gap_dir = rotation_info["gap_direction"].to(weight_delta.device).float()
+
+     if weight_delta.dim() == 0 or weight_delta.shape[0] != delta_dir.shape[0]:
+         return weight_delta  # Shape is incompatible with the direction vectors
+
+     # View as [hidden_dim, n_cols]; a 1D delta becomes a single column
+     mat = weight_delta.float().reshape(weight_delta.shape[0], -1)
+
+     # Component of each column along the previous merge direction
+     prev_component = delta_dir @ mat  # [n_cols]
+
+     # Remove some of the previous-direction component
+     # and add gap-direction component instead
+     correction = steering_strength * torch.outer(gap_dir - delta_dir, prev_component)
+
+     steered = mat + correction
+     return steered.reshape(weight_delta.shape).to(weight_delta.dtype)
292
+
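Review note: the steering rule is easier to check on a plain vector than on a weight matrix. The sketch below is an illustration in pure Python and not the module's code. It takes the component of a delta along the previous merge direction and moves a `strength` fraction of it onto the gap direction, which is the same correction `apply_arm_steering` builds. The function and variable names are illustrative.

```python
import math


def normalise(v):
    """Scale v to unit length (same epsilon guard as the module uses)."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / (norm + 1e-8) for x in v]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def steer(delta, prev_dir, gap_dir, strength=0.5):
    """Move `strength` of delta's component along prev_dir onto gap_dir."""
    p = dot(delta, prev_dir)
    return [d + strength * p * (g - q) for d, q, g in zip(delta, prev_dir, gap_dir)]


prev_dir = normalise([1.0, 0.0, 0.0])   # direction of the previous merge
gap_dir = normalise([0.0, 1.0, 0.0])    # direction of the remaining gap
delta = [2.0, 0.0, 1.0]                 # raw delta from the current merge

steered = steer(delta, prev_dir, gap_dir, strength=0.5)
print(steered)
```

Half of the first coordinate is shifted onto the second, and the third coordinate, which is orthogonal to both directions, is left alone.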
293
+
294
+ # ============================================================================
295
+ # 3. OTMF — Transferability Masks via Optimal Transport (2511.19561)
296
+ # ============================================================================
297
+ #
298
+ # OTMF discovers which parts of each model are "transferable" (shared
299
+ # knowledge) vs "task-specific" (unique to that model).
300
+ #
301
+ # Transferable weights → safe to merge/average
302
+ # Task-specific weights → must be preserved carefully
303
+ #
304
+ # This replaces our MagMax "top 20% by magnitude" heuristic with a
305
+ # principled, data-driven approach.
306
+
307
+ def compute_transferability_masks(
308
+ model: AutoModelForCausalLM,
309
+ calibration_activations: dict,
310
+ threshold: float = 0.3,
311
+ ) -> dict:
312
+ """
313
+ Compute per-parameter transferability masks using activation variance.
314
+
315
+ High activation variance across diverse inputs → parameter encodes
316
+ task-specific knowledge (DON'T merge aggressively).
317
+
318
+ Low activation variance → parameter encodes shared/general knowledge
319
+ (safe to merge/average).
320
+
321
+ This is a simplified version of OTMF's OT-based mask discovery.
322
+
323
+ Args:
324
+ model: The current merged model
325
+ calibration_activations: Layer → [samples, hidden_dim] activations
326
+ threshold: Fraction of highest-variance neurons classified as "task-specific"
327
+
328
+ Returns:
329
+ Dict of param_name → bool mask (True = transferable/safe, False = task-specific/protect)
330
+ """
331
+ print("[otmf] Computing transferability masks...")
332
+
333
+ masks = {}
334
+ state = model.state_dict()
335
+
336
+ # Compute per-neuron activation variance
337
+ neuron_importance = {}
338
+ for layer_name, acts in calibration_activations.items():
339
+ # Variance across samples: high variance = this neuron is doing something specific
340
+ variance = acts.var(dim=0) # [hidden_dim]
341
+ neuron_importance[layer_name] = variance
342
+
343
+ # Map neuron importance to parameter importance
344
+ for param_name, param in state.items():
345
+ # Find the corresponding layer's importance
346
+ layer_prefix = ".".join(param_name.split(".")[:4]) # e.g., model.layers.0.self_attn
347
+
348
+ importance = None
349
+ for layer_name, var in neuron_importance.items():
350
+ if layer_prefix in layer_name:
351
+ importance = var
352
+ break
353
+
354
+ if importance is None:
355
+ # Default: mark everything as transferable (safe to merge)
356
+ masks[param_name] = torch.ones(param.shape, dtype=torch.bool)
357
+ continue
358
+
359
+         # For 2D weights: importance determines which rows to protect
+         if param.dim() == 2:
+             rows, cols = param.shape
+             # Use importance for the output dimension. If there are fewer
+             # importance entries than rows they cannot be mapped onto the
+             # parameter, so fall back to "transferable"
+             if 0 < rows <= importance.shape[0]:
+                 imp = importance[:rows]
+                 # Cutoff: the top `threshold` fraction by variance is task-specific
+                 q = torch.quantile(imp.float(), 1.0 - threshold)
+                 # True = transferable (below cutoff), False = task-specific (protect)
+                 row_mask = imp < q
+                 masks[param_name] = row_mask.unsqueeze(1).expand_as(param)
+             else:
+                 masks[param_name] = torch.ones(param.shape, dtype=torch.bool)
373
+ else:
374
+ # 1D params (biases, norms): default to transferable
375
+ masks[param_name] = torch.ones(param.shape, dtype=torch.bool)
376
+
377
+ transferable = sum(m.sum().item() for m in masks.values())
378
+ total = sum(m.numel() for m in masks.values())
379
+ print(f"[otmf] Transferability: {transferable / total:.1%} transferable, {1 - transferable / total:.1%} task-specific")
380
+
381
+ return masks
382
+
383
+
384
+ def apply_masked_merge(
385
+ target_state: dict,
386
+ fused_state: dict,
387
+ masks: dict,
388
+ protect_strength: float = 0.8,
389
+ ) -> dict:
390
+ """
391
+ Apply transferability masks during merge.
392
+
393
+ For transferable weights: use the fused (merged) value
394
+ For task-specific weights: preserve more of the original target value
395
+
396
+ Args:
397
+ target_state: Original target weights (before this merge)
398
+ fused_state: Newly fused weights (after T&M/Theseus fusion)
399
+ masks: Transferability masks (True = safe to change)
400
+ protect_strength: How much to protect task-specific weights (0-1)
401
+
402
+ Returns:
403
+ Masked merged state dict
404
+ """
405
+ result = {}
406
+
407
+ for key in fused_state:
408
+ if key in masks and key in target_state:
409
+ mask = masks[key].to(fused_state[key].device)
410
+ original = target_state[key]
411
+ fused = fused_state[key]
412
+
413
+ # Transferable: use fused value
414
+ # Task-specific: blend more toward original
415
+ blended = torch.where(
416
+ mask,
417
+ fused, # Transferable → take merged value
418
+ protect_strength * original + (1 - protect_strength) * fused, # Protected
419
+ )
420
+ result[key] = blended
421
+ else:
422
+ result[key] = fused_state[key]
423
+
424
+ protected_params = sum(1 for k in masks if not masks[k].all())
425
+ print(f"[otmf] Applied masks: {protected_params} parameters partially protected")
426
+
427
+ return result
428
+
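Review note: the mask-then-blend logic of the two OTMF helpers can be traced on a toy example. The sketch below is pure Python with illustrative names and data. It computes per-neuron sample variance, marks the top `threshold` fraction as task-specific using a linearly interpolated quantile (the default interpolation of `torch.quantile`), and then blends protected entries back toward the original value.

```python
import statistics


def quantile(values, q):
    """Linearly interpolated quantile of a list of floats."""
    ordered = sorted(values)
    pos = q * (len(ordered) - 1)
    lo = int(pos)
    hi = min(lo + 1, len(ordered) - 1)
    return ordered[lo] + (pos - lo) * (ordered[hi] - ordered[lo])


def transferable_mask(activations, threshold=0.3):
    """activations: list of samples, each a list of per-neuron values."""
    per_neuron = list(zip(*activations))
    variances = [statistics.variance(col) for col in per_neuron]
    cutoff = quantile(variances, 1.0 - threshold)
    # True = low variance = transferable, False = task-specific (protect)
    return [v < cutoff for v in variances]


def masked_blend(original, fused, mask, protect_strength=0.8):
    """Take the fused value where transferable, otherwise lean on the original."""
    return [
        f if m else protect_strength * o + (1 - protect_strength) * f
        for o, f, m in zip(original, fused, mask)
    ]


acts = [
    [0.0, 1.0, 5.0, 0.1],
    [0.1, 1.1, -5.0, 0.2],
    [0.0, 0.9, 4.0, 0.1],
    [0.1, 1.0, -4.0, 0.2],
]
mask = transferable_mask(acts, threshold=0.3)
blended = masked_blend([1.0, 1.0, 1.0, 1.0], [2.0, 2.0, 2.0, 2.0], mask)
print(mask, blended)
```

Only the third neuron swings widely across samples, so it is the one protected, and its blended value stays close to the original.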
429
+
430
+ # ============================================================================
431
+ # 4. RAM — RL-Weight Disentanglement (2601.13572)
432
+ # ============================================================================
433
+ #
434
+ # RL-trained models (DeepSeek-R1, MiMo-7B-RL) have two types of knowledge:
435
+ # - Shared: general language understanding (same as base model)
436
+ # - RL-specific: reasoning patterns learned via GRPO/RLHF
437
+ #
438
+ # RAM separates these so we can merge the shared parts normally
439
+ # but PRESERVE the RL-specific parts that make these models special.
440
+
441
+ def disentangle_rl_weights(
442
+ rl_model: AutoModelForCausalLM,
443
+ base_model: AutoModelForCausalLM,
444
+ rl_threshold: float = 0.1,
445
+ ) -> tuple:
446
+ """
447
+ Separate RL-specific weights from shared/general weights.
448
+
449
+ RL-specific = weights that changed significantly during RL training
450
+ Shared = weights that are basically the same as base
451
+
452
+ We identify RL-specific weights by looking at the magnitude of
453
+ change from base model to RL model. Big changes → RL learned
454
+ something there → don't average it away.
455
+
456
+ Args:
457
+ rl_model: The RL-trained model (e.g., DeepSeek-R1, MiMo-7B-RL)
458
+ base_model: The base model before RL training
459
+ rl_threshold: Relative change threshold for "RL-specific" classification
460
+
461
+ Returns:
462
+ Tuple of (shared_mask, rl_mask) — both are dicts of param_name → bool tensor
463
+ shared_mask: True = this weight is shared (safe to merge normally)
464
+ rl_mask: True = this weight is RL-specific (protect during merge)
465
+ """
466
+ print("[ram] Disentangling RL-specific vs shared weights...")
467
+
468
+ rl_state = rl_model.state_dict()
469
+ base_state = base_model.state_dict()
470
+
471
+ shared_mask = {}
472
+ rl_mask = {}
473
+
474
+ total_params = 0
475
+ rl_params = 0
476
+
477
+ for key in rl_state:
478
+ if key not in base_state:
479
+ # New param (e.g., MTP head) — mark as RL-specific
480
+ rl_mask[key] = torch.ones_like(rl_state[key], dtype=torch.bool)
481
+ shared_mask[key] = torch.zeros_like(rl_state[key], dtype=torch.bool)
482
+ rl_params += rl_state[key].numel()
483
+ total_params += rl_state[key].numel()
484
+ continue
485
+
486
+ rl_w = rl_state[key].float()
487
+ base_w = base_state[key].float()
488
+
489
+ # Relative change: |rl - base| / (|base| + epsilon)
490
+ change = (rl_w - base_w).abs()
491
+ base_magnitude = base_w.abs() + 1e-8
492
+ relative_change = change / base_magnitude
493
+
494
+ # RL-specific: relative change > threshold
495
+ is_rl = relative_change > rl_threshold
496
+ rl_mask[key] = is_rl
497
+ shared_mask[key] = ~is_rl
498
+
499
+ rl_params += is_rl.sum().item()
500
+ total_params += is_rl.numel()
501
+
502
+ pct = rl_params / total_params * 100 if total_params > 0 else 0
503
+ print(f"[ram] RL-specific: {rl_params:,} params ({pct:.1f}%)")
504
+ print(f"[ram] Shared: {total_params - rl_params:,} params ({100 - pct:.1f}%)")
505
+
506
+ return shared_mask, rl_mask
507
+
508
+
509
+ def merge_with_rl_preservation(
510
+ target_state: dict,
511
+ source_state: dict,
512
+ shared_mask: dict,
513
+ rl_mask: dict,
514
+ shared_alpha: float = 0.5,
515
+ rl_alpha: float = 0.8,
516
+ ) -> dict:
517
+ """
518
+ Merge source into target while preserving RL-specific weights.
519
+
520
+ Shared weights: normal blending at shared_alpha
521
+ RL-specific weights: stronger blending toward source (preserve RL knowledge)
522
+
523
+ This prevents the RL reasoning capabilities from being diluted
524
+ by averaging with target weights.
525
+
526
+ Args:
527
+ target_state: Current target model state
528
+ source_state: RL model state to merge in
529
+ shared_mask: Which params are shared (safe for normal merge)
530
+ rl_mask: Which params are RL-specific (preserve with higher alpha)
531
+ shared_alpha: Alpha for shared weights (normal)
532
+ rl_alpha: Alpha for RL-specific weights (higher = preserve more RL knowledge)
533
+ """
534
+ print(f"[ram] Merging with RL preservation (shared α={shared_alpha}, RL α={rl_alpha})...")
535
+
536
+ result = {}
537
+ for key in target_state:
538
+ if key not in source_state:
539
+ result[key] = target_state[key]
540
+ continue
541
+
542
+ target_w = target_state[key]
543
+ source_w = source_state[key]
544
+
545
+ if source_w.shape != target_w.shape:
546
+ result[key] = target_state[key]
547
+ continue
548
+
549
+ if key in rl_mask and key in shared_mask:
550
+ rl_m = rl_mask[key].to(target_w.device)
551
+ # RL-specific: use higher alpha (preserve RL knowledge)
552
+ # Shared: use normal alpha
553
+ alpha_map = torch.where(rl_m, rl_alpha, shared_alpha)
554
+ if alpha_map.shape != target_w.shape:
555
+ alpha_map = alpha_map.expand_as(target_w) if alpha_map.dim() > 0 else torch.full_like(target_w, shared_alpha)
556
+
557
+ result[key] = alpha_map * source_w.to(target_w.device) + (1 - alpha_map) * target_w
558
+ else:
559
+ result[key] = shared_alpha * source_w.to(target_w.device) + (1 - shared_alpha) * target_w
560
+
561
+ return result
562
+
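Review note: the RAM-style split and the per-element alpha can be followed on a handful of scalars. The sketch below is pure Python with illustrative names and made-up weights. It flags entries whose relative change from the base exceeds `rl_threshold`, then blends the source into the target with `rl_alpha` on flagged entries and `shared_alpha` elsewhere, mirroring `disentangle_rl_weights` and `merge_with_rl_preservation`.

```python
def rl_specific_mask(rl_weights, base_weights, rl_threshold=0.1, eps=1e-8):
    """True where |rl - base| / (|base| + eps) exceeds rl_threshold."""
    return [
        abs(r - b) / (abs(b) + eps) > rl_threshold
        for r, b in zip(rl_weights, base_weights)
    ]


def blend_with_rl_preservation(target, source, rl_mask,
                               shared_alpha=0.5, rl_alpha=0.8):
    """Per-element alpha: rl_alpha on RL-specific entries, shared_alpha elsewhere."""
    out = []
    for t, s, is_rl in zip(target, source, rl_mask):
        alpha = rl_alpha if is_rl else shared_alpha
        out.append(alpha * s + (1 - alpha) * t)
    return out


base = [1.0, -2.0, 0.5, 4.0]
rl = [1.05, -2.0, 1.5, 4.1]       # the third weight moved a lot during RL
target = [0.0, 0.0, 0.0, 0.0]

mask = rl_specific_mask(rl, base)
merged = blend_with_rl_preservation(target, rl, mask)
print(mask, merged)
```

Only the third weight crosses the 10% relative-change threshold, so it alone is blended with the higher alpha.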
563
+
564
+ # ============================================================================
565
+ # 5. MERGEABILITY PRE-CHECK (2601.22285)
566
+ # ============================================================================
567
+ #
568
+ # Before spending GPU hours on a merge that might fail, check if the
569
+ # models are actually COMPATIBLE enough to merge.
570
+ #
571
+ # Mergeability score: 0.0 (definitely won't work) to 1.0 (should work great)
572
+
573
+ def compute_mergeability_score(
574
+ source_activations: dict,
575
+ target_activations: dict,
576
+ source_config: ModelConfig,
577
+ ) -> dict:
578
+ """
579
+ Predict how well a source model will merge into the target.
580
+
581
+ Scores based on three factors:
582
+ 1. Activation similarity (cosine similarity of mean activations)
583
+ 2. Dimensional compatibility (how similar are the layer shapes)
584
+ 3. Architecture match (same arch = bonus)
585
+
586
+ Returns:
587
+ Dict with individual scores and overall mergeability (0-1)
588
+ """
589
+ print(f"[mergeability] Scoring {source_config.name}...")
590
+
591
+ scores = {}
592
+
593
+ # --- Factor 1: Activation similarity ---
594
+ cosine_sims = []
595
+ source_layers = sorted(source_activations.keys())
596
+ target_layers = sorted(target_activations.keys())
597
+
598
+ # Match layers by position (proportional mapping)
599
+ for i, tl in enumerate(target_layers):
600
+ # Map target layer index to source layer index
601
+ src_idx = int(i * len(source_layers) / len(target_layers))
602
+ src_idx = min(src_idx, len(source_layers) - 1)
603
+ sl = source_layers[src_idx]
604
+
605
+ if sl in source_activations and tl in target_activations:
606
+ s_mean = source_activations[sl].float().mean(dim=0)
607
+ t_mean = target_activations[tl].float().mean(dim=0)
608
+
609
+ # Pad to same dimension for cosine similarity
610
+ max_dim = max(s_mean.shape[0], t_mean.shape[0])
611
+ s_padded = torch.nn.functional.pad(s_mean, (0, max_dim - s_mean.shape[0]))
612
+ t_padded = torch.nn.functional.pad(t_mean, (0, max_dim - t_mean.shape[0]))
613
+
614
+ cos_sim = torch.nn.functional.cosine_similarity(
615
+ s_padded.unsqueeze(0), t_padded.unsqueeze(0)
616
+ ).item()
617
+ cosine_sims.append(cos_sim)
618
+
619
+ activation_score = np.mean(cosine_sims) if cosine_sims else 0.0
620
+ scores["activation_similarity"] = float(activation_score)
621
+
622
+ # --- Factor 2: Dimensional compatibility ---
623
+ layer_ratio = min(source_config.layers, 36) / max(source_config.layers, 36)
624
+ hidden_ratio = min(source_config.hidden_dim, 4096) / max(source_config.hidden_dim, 4096)
625
+ dim_score = (layer_ratio + hidden_ratio) / 2
626
+ scores["dimensional_compatibility"] = float(dim_score)
627
+
628
+ # --- Factor 3: Architecture match ---
629
+ arch_scores = {
630
+ "transformer": 1.0, # Same as Qwen3
631
+ "transformer+mtp": 0.8, # Close, just drop extras
632
+ "hybrid_ssm": 0.5, # Very different
633
+ }
634
+ arch_score = arch_scores.get(source_config.architecture, 0.3)
635
+ scores["architecture_match"] = float(arch_score)
636
+
637
+ # --- Factor 4: Vocab overlap (bonus) ---
638
+ vocab_score = source_config.vocab_overlap_with_qwen3
639
+ scores["vocab_overlap"] = float(vocab_score)
640
+
641
+ # --- Overall: weighted average ---
642
+ overall = (
643
+ 0.35 * activation_score + # Most important — actual representation similarity
644
+ 0.25 * dim_score + # Shape compatibility
645
+ 0.25 * arch_score + # Architecture type
646
+ 0.15 * vocab_score # Vocab overlap
647
+ )
648
+ scores["overall"] = float(overall)
649
+
650
+ # --- Recommendation ---
651
+ if overall >= 0.7:
652
+ recommendation = "GO — standard T&M merge"
653
+ elif overall >= 0.5:
654
+ recommendation = "CAUTION — T&M merge with higher protection, have Theseus fallback ready"
655
+ elif overall >= 0.3:
656
+ recommendation = "RISKY — try Theseus first, distillation fallback"
657
+ else:
658
+ recommendation = "SKIP — use knowledge distillation instead"
659
+
660
+ scores["recommendation"] = recommendation
661
+
662
+ print(f"[mergeability] {source_config.name} score: {overall:.2f}")
663
+ print(f" Activation similarity: {activation_score:.2f}")
664
+ print(f" Dimensional compat: {dim_score:.2f}")
665
+ print(f" Architecture match: {arch_score:.2f}")
666
+ print(f" Vocab overlap: {vocab_score:.2f}")
667
+ print(f" → {recommendation}")
668
+
669
+ return scores
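The weighting and thresholds in `compute_mergeability_score` can be tried in isolation. The following is a minimal sketch, assuming the same weights (0.35 / 0.25 / 0.25 / 0.15), the same cut-offs (0.7 / 0.5 / 0.3), and the hard-coded target shape of 36 layers and hidden size 4096 seen in the diff. The helper names `overall_mergeability`, `recommend`, and `dimensional_compatibility`, and the example input values, are hypothetical and not part of the commit.

```python
# Hypothetical standalone sketch of the weighted mergeability score.
# Weights and thresholds mirror the diff above; the helper names are made up.
WEIGHTS = {
    "activation_similarity": 0.35,
    "dimensional_compatibility": 0.25,
    "architecture_match": 0.25,
    "vocab_overlap": 0.15,
}


def dimensional_compatibility(layers, hidden_dim, target_layers=36, target_hidden=4096):
    """Average of the layer-count ratio and hidden-size ratio (each in (0, 1])."""
    layer_ratio = min(layers, target_layers) / max(layers, target_layers)
    hidden_ratio = min(hidden_dim, target_hidden) / max(hidden_dim, target_hidden)
    return (layer_ratio + hidden_ratio) / 2


def overall_mergeability(scores):
    """Weighted sum of the four factor scores."""
    return sum(weight * scores[name] for name, weight in WEIGHTS.items())


def recommend(overall):
    """Map the overall score onto the four recommendation bands."""
    if overall >= 0.7:
        return "GO"
    if overall >= 0.5:
        return "CAUTION"
    if overall >= 0.3:
        return "RISKY"
    return "SKIP"


# Illustrative input: a 32-layer, 4096-wide transformer (values are invented).
example = {
    "activation_similarity": 0.6,
    "dimensional_compatibility": dimensional_compatibility(32, 4096),
    "architecture_match": 1.0,
    "vocab_overlap": 0.4,
}
overall = overall_mergeability(example)
print(f"{overall:.3f} -> {recommend(overall)}")
```

Because cosine similarity can be negative, the activation factor (and therefore the overall score) is not strictly bounded below by 0 unless it is clamped first.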
hugging/td_fuse/transport.py ADDED
@@ -0,0 +1,527 @@
1
+ """
2
+ Transport and Merge Wrapper — interfaces with official T&M code.
3
+
4
+ This wraps the official repo at:
5
+ github.com/chenhangcuisg-code/Cross-Architecture-Merging-for-Large-Language-Models/
6
+
7
+ We use THEIR code for:
8
+ - Correlation distance computation (corr_distance_matrix)
9
+ - Streaming Sinkhorn (sinkhorn_uniform_streaming)
10
+ - Transport plan computation (compute_P, compute_Q_and_layer_costs)
11
+ - Activation reconstruction (reconstruct_X)
12
+
13
+ We add:
14
+ - Qwen3 thinking mode protection
15
+ - MiMo MTP head handling
16
+ - Falcon SSM component handling
17
+ - Sequential merge protection (MagMax + orthogonal projection)
18
+
19
+ Findings: #01, #07, #24
20
+ """
21
+
22
+ import sys
23
+ import torch
24
+ import numpy as np
25
+ from pathlib import Path
26
+ from typing import Optional
27
+ from transformers import AutoModelForCausalLM, AutoTokenizer
28
+ from datasets import load_dataset
29
+
30
+ from .config import MergeConfig, ModelConfig, TARGET
31
+
32
+
33
+ def setup_tm_repo(cfg: MergeConfig):
34
+ """Add official T&M repo to Python path so we can import their code."""
35
+ repo_path = Path(cfg.tm_repo_path)
36
+ core_path = repo_path / "core"
37
+
38
+ if not core_path.exists():
39
+ raise FileNotFoundError(
40
+ f"Official T&M repo not found at {repo_path}\n"
41
+ f"Please clone it:\n"
42
+ f" git clone https://github.com/chenhangcuisg-code/"
43
+ f"Cross-Architecture-Merging-for-Large-Language-Models.git"
44
+ )
45
+
46
+ # Add to path so we can import hot_transport etc.
47
+ if str(core_path) not in sys.path:
48
+ sys.path.insert(0, str(core_path))
49
+ print(f"[transport] Added T&M core to path: {core_path}")
50
+
51
+
52
+ def load_calibration_data(cfg: MergeConfig, tokenizer: AutoTokenizer) -> list:
53
+ """
54
+ Load calibration data for activation extraction.
55
+
56
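+ Intended mix: 600 Pile general + 300 Pile ArXiv + 600 neuralmagic Q&A = 1500 samples (this function loads up to 600 Pile general samples, then fills the remainder from neuralmagic; there is no separate ArXiv pass)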
57
+ Each sample truncated to cfg.calibration_seq_len tokens.
58
+
59
+ Findings: #08
60
+ """
61
+ print(f"[transport] Loading calibration data ({cfg.calibration_samples} samples)...")
62
+
63
+ samples = []
64
+
65
+ # --- Pile: general text (600 samples) ---
66
+ try:
67
+ pile = load_dataset(
68
+ cfg.calibration_dataset_pile,
69
+ split="validation",
70
+ streaming=True,
71
+ trust_remote_code=True,
72
+ )
73
+ count = 0
74
+ for example in pile:
75
+ if count >= 600:
76
+ break
77
+ text = example.get("text", "")
78
+ if len(text) > 100: # Skip very short texts
79
+ tokens = tokenizer(
80
+ text,
81
+ truncation=True,
82
+ max_length=cfg.calibration_seq_len,
83
+ return_tensors="pt",
84
+ )
85
+ samples.append(tokens)
86
+ count += 1
87
+ print(f" Pile general: {count} samples")
88
+ except Exception as e:
89
+ print(f" ⚠ Pile failed: {e}")
90
+ print(f" Falling back to neuralmagic only")
91
+
92
+ # --- neuralmagic: Q&A calibration (up to remaining) ---
93
+ remaining = cfg.calibration_samples - len(samples)
94
+ if remaining > 0:
95
+ try:
96
+ nm = load_dataset(
97
+ cfg.calibration_dataset_nm,
98
+ split="train",
99
+ trust_remote_code=True,
100
+ )
101
+ count = 0
102
+ for example in nm:
103
+ if count >= remaining:
104
+ break
105
+ text = example.get("text", example.get("content", ""))
106
+ if len(str(text)) > 50:
107
+ tokens = tokenizer(
108
+ str(text),
109
+ truncation=True,
110
+ max_length=cfg.calibration_seq_len,
111
+ return_tensors="pt",
112
+ )
113
+ samples.append(tokens)
114
+ count += 1
115
+ print(f" neuralmagic: {count} samples")
116
+ except Exception as e:
117
+ print(f" ⚠ neuralmagic failed: {e}")
118
+
119
+ print(f"[transport] Total calibration samples: {len(samples)}")
120
+ return samples
121
+
122
+
123
+ def extract_activations(
124
+ model: AutoModelForCausalLM,
125
+ calibration_data: list,
126
+ device: str = "cuda",
127
+ ) -> dict:
128
+ """
129
+ Extract intermediate activations from each layer of a model.
130
+
131
+ Runs calibration data through the model with hooks on each layer
132
+ to capture activation patterns. These activations are what the
133
+ optimal transport algorithm aligns between source and target.
134
+
135
+ Returns:
136
+ Dict mapping layer_name → activation tensor [num_samples, hidden_dim]
137
+ """
138
+ print(f"[transport] Extracting activations from {len(calibration_data)} samples...")
139
+
140
+ activations = {}
141
+ hooks = []
142
+
143
+ # Register hooks on each transformer layer
144
+ for name, module in model.named_modules():
145
+ if hasattr(module, "self_attn") or name.endswith(".mlp"):
146
+ # Hook to capture output activations
147
+ def make_hook(layer_name):
148
+ def hook_fn(module, input, output):
149
+ # Handle tuple outputs (some layers return tuples)
150
+ if isinstance(output, tuple):
151
+ act = output[0]
152
+ else:
153
+ act = output
154
+ if layer_name not in activations:
155
+ activations[layer_name] = []
156
+ # Mean pool over sequence length: [batch, seq, hidden_dim] → [batch, hidden_dim]
157
+ activations[layer_name].append(
158
+ act.detach().float().mean(dim=1).cpu()
159
+ )
160
+ return hook_fn
161
+
162
+ h = module.register_forward_hook(make_hook(name))
163
+ hooks.append(h)
164
+
165
+ # Forward pass on calibration data
166
+ model.eval()
167
+ with torch.no_grad():
168
+ for i, tokens in enumerate(calibration_data):
169
+ inputs = {k: v.to(device) for k, v in tokens.items()}
170
+ try:
171
+ model(**inputs)
172
+ except Exception as e:
173
+ print(f" ⚠ Sample {i} failed: {e}")
174
+ continue
175
+
176
+ if (i + 1) % 100 == 0:
177
+ print(f" Processed {i + 1}/{len(calibration_data)} samples")
178
+
179
+ # Remove hooks
180
+ for h in hooks:
181
+ h.remove()
182
+
183
+ # Stack activations: [num_samples, hidden_dim]
184
+ for key in activations:
185
+ activations[key] = torch.cat(activations[key], dim=0)
186
+ print(f" {key}: {activations[key].shape}")
187
+
188
+ return activations
189
+
190
+
191
+ def compute_transport_plans(
192
+ source_activations: dict,
193
+ target_activations: dict,
194
+ cfg: MergeConfig,
195
+ ) -> dict:
196
+ """
197
+ Compute optimal transport plans between source and target activations.
198
+
199
+ This is where the magic happens. We use the official T&M code's:
200
+ - corr_distance_matrix: correlation distance between activation vectors
201
+ - sinkhorn_uniform_streaming: memory-efficient Sinkhorn solver
202
+ - compute_P: layer-level coupling (which source layers → which target layers)
203
+ - compute_Q_and_layer_costs: neuron-level coupling within each layer pair
204
+
205
+ Returns:
206
+ Dict with 'P' (layer coupling) and 'Q' (per-layer neuron coupling) matrices
207
+ """
208
+ print("[transport] Computing transport plans...")
209
+
210
+ try:
211
+ # Try importing official T&M code
212
+ from hot_transport import (
213
+ corr_distance_matrix,
214
+ sinkhorn_uniform_streaming,
215
+ compute_P,
216
+ compute_Q_and_layer_costs,
217
+ )
218
+ print("[transport] Using official T&M implementation")
219
+ return _compute_plans_official(
220
+ source_activations, target_activations, cfg,
221
+ corr_distance_matrix, sinkhorn_uniform_streaming,
222
+ compute_P, compute_Q_and_layer_costs,
223
+ )
224
+ except ImportError:
225
+ print("[transport] Official T&M code not available, using fallback")
226
+ return _compute_plans_fallback(
227
+ source_activations, target_activations, cfg
228
+ )
229
+
230
+
231
+ def _compute_plans_official(
232
+ source_act, target_act, cfg,
233
+ corr_distance_matrix, sinkhorn_uniform_streaming,
234
+ compute_P, compute_Q_and_layer_costs,
235
+ ) -> dict:
236
+ """Use the official T&M code to compute transport plans."""
237
+
238
+ # Get matching layer pairs
239
+ source_layers = sorted(source_act.keys())
240
+ target_layers = sorted(target_act.keys())
241
+
242
+ # Compute Q matrices (neuron-level) and layer costs
243
+ Q_matrices, layer_costs = compute_Q_and_layer_costs(
244
+ source_act, target_act,
245
+ source_layers, target_layers,
246
+ )
247
+
248
+ # Compute P matrix (layer-level coupling)
249
+ P = compute_P(layer_costs)
250
+
251
+ return {
252
+ "P": P,
253
+ "Q": Q_matrices,
254
+ "source_layers": source_layers,
255
+ "target_layers": target_layers,
256
+ }
257
+
258
+
259
+ def _compute_plans_fallback(
260
+ source_act: dict,
261
+ target_act: dict,
262
+ cfg: MergeConfig,
263
+ ) -> dict:
264
+ """
265
+ Fallback transport plan computation when official code isn't available.
266
+
267
+ Uses correlation distance + basic Sinkhorn. Less optimised than official
268
+ but functionally correct for testing.
269
+ """
270
+
271
+ source_layers = sorted(source_act.keys())
272
+ target_layers = sorted(target_act.keys())
273
+
274
+ # --- Step 1: Correlation distance matrices per layer pair ---
275
+ Q_matrices = {}
276
+ layer_costs = np.zeros((len(source_layers), len(target_layers)))
277
+
278
+ for i, sl in enumerate(source_layers):
279
+ for j, tl in enumerate(target_layers):
280
+ if sl not in source_act or tl not in target_act:
281
+ continue
282
+
283
+ S = source_act[sl].numpy() # [samples, hidden_dim_source]
284
+ T = target_act[tl].numpy() # [samples, hidden_dim_target]
285
+
286
+ # Correlation distance: 1 - pearson_correlation
287
+ # Between each pair of neurons across samples
288
+ # S: [samples, n_source], T: [samples, n_target]
289
+ S_norm = (S - S.mean(0)) / (S.std(0) + 1e-8)
290
+ T_norm = (T - T.mean(0)) / (T.std(0) + 1e-8)
291
+ corr = S_norm.T @ T_norm / S.shape[0] # [n_source, n_target]
292
+ cost = 1.0 - corr # Correlation distance
293
+
294
+ # Basic Sinkhorn on this cost matrix
295
+ Q = _sinkhorn(cost, reg=cfg.sinkhorn_reg, max_iter=cfg.sinkhorn_max_iter)
296
+ Q_matrices[(sl, tl)] = Q
297
+ layer_costs[i, j] = cost.mean()
298
+
299
+ # --- Step 2: Layer coupling (P matrix) ---
300
+ P = _sinkhorn(layer_costs, reg=cfg.sinkhorn_reg, max_iter=cfg.sinkhorn_max_iter)
301
+
302
+ return {
303
+ "P": P,
304
+ "Q": Q_matrices,
305
+ "source_layers": source_layers,
306
+ "target_layers": target_layers,
307
+ }
308
+
309
+
310
+ def _sinkhorn(
311
+ cost_matrix: np.ndarray,
312
+ reg: float = 0.05,
313
+ max_iter: int = 100,
314
+ ) -> np.ndarray:
315
+ """
316
+ Basic Sinkhorn-Knopp algorithm for optimal transport.
317
+
318
+ Solves: min <T, C> - reg * H(T)
319
+ where H(T) is the entropy of the transport plan.
320
+
321
+ This is the FALLBACK. The official code uses streaming Sinkhorn
322
+ which is more memory-efficient.
323
+ """
324
+ n, m = cost_matrix.shape
325
+ K = np.exp(-cost_matrix / reg)
326
+
327
+ u = np.ones(n) / n
328
+ v = np.ones(m) / m
329
+
330
+ for _ in range(max_iter):
331
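+ u = (1.0 / n) / (K @ v + 1e-10)  # rescale rows to the uniform marginal 1/n
332
+ v = (1.0 / m) / (K.T @ u + 1e-10)  # rescale columns to the uniform marginal 1/m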
333
+
334
+ # Transport plan
335
+ T = np.diag(u) @ K @ np.diag(v)
336
+ return T
337
+
338
+
339
+ def fuse_weights(
340
+ source_model: AutoModelForCausalLM,
341
+ target_model: AutoModelForCausalLM,
342
+ transport_plans: dict,
343
+ source_config: ModelConfig,
344
+ cfg: MergeConfig,
345
+ ) -> AutoModelForCausalLM:
346
+ """
347
+ Fuse source model weights into target model using transport plans.
348
+
349
+ For each layer pair with significant coupling (P > threshold):
350
+ 1. Get the Q matrix (neuron-level correspondence)
351
+ 2. Transport source weights into target neuron basis: W_fused = Q @ W_source
352
+ 3. Blend with target: W_final = alpha * W_fused + (1-alpha) * W_target
353
+
354
+ Special handling per model:
355
+ - DeepSeek: Direct merge (same architecture)
356
+ - MiMo: Skip MTP heads, skip embeddings
357
+ - Llama: Layer mapping (32→36), skip embeddings, drop QKV bias
358
+ - Falcon: Skip Mamba components, skip embeddings
359
+
360
+ Returns:
361
+ Target model with fused weights
362
+ """
363
+ print(f"\n[transport] Fusing {source_config.name} → target")
364
+ alpha = source_config.merge_alpha
365
+
366
+ try:
367
+ # Try official fusion code first
368
+ from generate_hot_residual import fuse_attention_only_from_hot_dir
369
+ print("[transport] Official fusion code found, but not yet adapted; using manual fusion")
370
+ # TODO: Adapt official fusion to our pipeline
371
+ # For now, fall through to manual fusion
372
+ except ImportError:
373
+ pass
374
+
375
+ # --- Manual fusion using transport plans ---
376
+ source_state = source_model.state_dict()
377
+ target_state = target_model.state_dict()
378
+ P = transport_plans["P"]
379
+ Q = transport_plans["Q"]
380
+
381
+ fused_count = 0
382
+ skipped_count = 0
383
+
384
+ for target_key in target_state:
385
+ # Skip parameters we shouldn't merge
386
+ if _should_skip(target_key, source_config):
387
+ skipped_count += 1
388
+ continue
389
+
390
+ # Find corresponding source key
391
+ source_key = _map_key(target_key, source_config)
392
+ if source_key is None or source_key not in source_state:
393
+ skipped_count += 1
394
+ continue
395
+
396
+ target_w = target_state[target_key]
397
+ source_w = source_state[source_key]
398
+
399
+ # Handle dimension mismatches
400
+ if target_w.shape != source_w.shape:
401
+ # Use transport plan to align dimensions
402
+ source_w = _align_dimensions(source_w, target_w.shape, Q, target_key)
403
+ if source_w is None:
404
+ skipped_count += 1
405
+ continue
406
+
407
+ # Blend: W_final = alpha * source + (1-alpha) * target
408
+ fused_w = alpha * source_w.to(target_w.device) + (1 - alpha) * target_w
409
+ target_state[target_key] = fused_w
410
+ fused_count += 1
411
+
412
+ # Apply thinking mode protection
413
+ if cfg.freeze_think_tokens and "embed_tokens" in target_key:
414
+ for token_id in cfg.think_token_ids:
415
+ if token_id < target_state["model.embed_tokens.weight"].shape[0]:
416
+ # Restore original embedding for think tokens
417
+ orig_embed = target_model.state_dict()["model.embed_tokens.weight"]
418
+ target_state["model.embed_tokens.weight"][token_id] = orig_embed[token_id]
419
+ print(f"[transport] Protected think token {token_id}")
420
+
421
+ # Load fused weights
422
+ target_model.load_state_dict(target_state)
423
+ print(f"[transport] Fused {fused_count} params, skipped {skipped_count}")
424
+
425
+ return target_model
426
+
427
+
428
+ def _should_skip(key: str, source_config: ModelConfig) -> bool:
429
+ """Determine if a parameter should be skipped during merge."""
430
+
431
+ # Always skip if source model says to skip embeddings
432
+ if source_config.skip_embeddings and ("embed_tokens" in key or "lm_head" in key):
433
+ return True
434
+
435
+ # Skip MiMo MTP heads
436
+ if "drop_mtp_heads" in source_config.special_handling and "mtp_head" in key:
437
+ return True
438
+
439
+ # Skip Falcon Mamba-specific parameters
440
+ if "drop_mamba_state_params" in source_config.special_handling:
441
+ mamba_keys = ["mamba", "A_log", "dt_proj", ".D"]
442
+ if any(mk in key for mk in mamba_keys):
443
+ return True
444
+
445
+ # Skip QKV bias for Llama (Qwen3 doesn't have it)
446
+ if "drop_qkv_bias" in source_config.special_handling and ".bias" in key:
447
+ if any(proj in key for proj in ["q_proj", "k_proj", "v_proj"]):
448
+ return True
449
+
450
+ return False
451
+
452
+
453
+ def _map_key(target_key: str, source_config: ModelConfig) -> Optional[str]:
454
+ """Map a target model parameter name to the corresponding source name."""
455
+
456
+ # For same-architecture models (DeepSeek), keys match directly
457
+ if source_config.architecture == "transformer" and source_config.layers == 36:
458
+ return target_key
459
+
460
+ # For Llama (32 layers → 36 layers), map layer indices
461
+ if "layer_mapping_32_to_36" in source_config.special_handling:
462
+ if "model.layers." in target_key:
463
+ # Extract layer number
464
+ parts = target_key.split(".")
465
+ try:
466
+ layer_idx = int(parts[2])
467
+ except (IndexError, ValueError):
468
+ return target_key
469
+
470
+ # Map 36 target layers to 32 source layers (stride)
471
+ source_layer = int(layer_idx * 32 / 36)
472
+ parts[2] = str(source_layer)
473
+ return ".".join(parts)
474
+
475
+ # For MiMo (same layer count, different extras), keys mostly match
476
+ if source_config.architecture == "transformer+mtp":
477
+ if "mtp_head" in target_key:
478
+ return None # MTP heads don't exist in target
479
+ return target_key
480
+
481
+ # For Falcon hybrid, only attention and MLP keys map
482
+ if source_config.architecture == "hybrid_ssm":
483
+ if any(k in target_key for k in ["self_attn", "mlp", "layer_norm"]):
484
+ return target_key # These exist in both
485
+ return None # Mamba components don't map
486
+
487
+ return target_key
488
+
489
+
490
+ def _align_dimensions(
491
+ source_w: torch.Tensor,
492
+ target_shape: tuple,
493
+ Q_matrices: dict,
494
+ key: str,
495
+ ) -> Optional[torch.Tensor]:
496
+ """
497
+ Align source weight dimensions to target shape using transport plans.
498
+
499
+ All mismatches are currently handled by zero-padding or truncating (1D and 2D tensors only).
500
+ Projection through the Q matrix is not implemented yet; Q_matrices and key are unused.
501
+ """
502
+ if source_w.shape == target_shape:
503
+ return source_w
504
+
505
+ # Simple case: different width (FFN size difference)
506
+ if len(source_w.shape) == 2 and len(target_shape) == 2:
507
+ s_rows, s_cols = source_w.shape
508
+ t_rows, t_cols = target_shape
509
+
510
+ result = torch.zeros(target_shape, dtype=source_w.dtype)
511
+
512
+ # Copy what fits
513
+ min_rows = min(s_rows, t_rows)
514
+ min_cols = min(s_cols, t_cols)
515
+ result[:min_rows, :min_cols] = source_w[:min_rows, :min_cols]
516
+
517
+ return result
518
+
519
+ # 1D case (biases, layer norms)
520
+ if len(source_w.shape) == 1 and len(target_shape) == 1:
521
+ result = torch.zeros(target_shape, dtype=source_w.dtype)
522
+ min_len = min(source_w.shape[0], target_shape[0])
523
+ result[:min_len] = source_w[:min_len]
524
+ return result
525
+
526
+ # Can't align — skip this parameter
527
+ return None
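The entropic scaling behind the `_sinkhorn` fallback in transport.py can be sketched with explicit uniform marginals, so that the rows of the plan sum to 1/n and the columns to 1/m even when the cost matrix is not square. This is a minimal illustration in plain Python (no NumPy), not the official streaming solver. The function name `sinkhorn_uniform`, the toy cost matrix, and the choice `reg=0.5` are assumptions made for the example.

```python
import math


def sinkhorn_uniform(cost, reg=0.5, max_iter=200, eps=1e-10):
    """Entropic OT plan between uniform marginals (1/n on rows, 1/m on columns).

    cost: list of n rows, each a list of m non-negative floats.
    """
    n, m = len(cost), len(cost[0])
    # Gibbs kernel K = exp(-C / reg)
    K = [[math.exp(-c / reg) for c in row] for row in cost]
    u = [1.0] * n
    v = [1.0] * m
    for _ in range(max_iter):
        # Scale rows so that each row of diag(u) K diag(v) sums to 1/n
        u = [(1.0 / n) / (sum(K[i][j] * v[j] for j in range(m)) + eps)
             for i in range(n)]
        # Scale columns so that each column sums to 1/m
        v = [(1.0 / m) / (sum(K[i][j] * u[i] for i in range(n)) + eps)
             for j in range(m)]
    return [[u[i] * K[i][j] * v[j] for j in range(m)] for i in range(n)]


# Toy 2x3 cost matrix (invented for illustration).
cost = [[0.0, 0.5, 1.0],
        [1.0, 0.5, 0.0]]
plan = sinkhorn_uniform(cost)
row_sums = [sum(row) for row in plan]
col_sums = [sum(plan[i][j] for i in range(2)) for j in range(3)]
print(row_sums, col_sums)
```

Smaller values of `reg` give sharper plans but generally need more iterations to converge, and `exp(-C / reg)` can underflow for large costs; a log-domain variant is the usual remedy in that regime.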
hugging/td_fuse/validate.py ADDED
@@ -0,0 +1,215 @@
1
+ """
2
+ Post-Merge Validation — run after EVERY merge step.
3
+
4
+ Tests:
5
+ 1. Canary recall (did knowledge transfer?)
6
+ 2. Perplexity check (did we break the model?)
7
+ 3. Thinking mode (do <think> tags still work?)
8
+ 4. Quick reasoning test (can it still think?)
9
+
10
+ Kill criteria: >10% performance drop on any test → abort merge.
11
+ Findings: #11, #22, #25
12
+ """
13
+
14
+ import torch
15
+ import math
16
+ from transformers import AutoModelForCausalLM, AutoTokenizer
17
+
18
+ from .canary import test_all_canaries
19
+ from .config import MergeConfig
20
+
21
+
22
+ def validate_merged_model(
23
+ model: AutoModelForCausalLM,
24
+ tokenizer: AutoTokenizer,
25
+ merged_sources: list[str],
26
+ cfg: MergeConfig,
27
+ baseline_perplexity: float = None,
28
+ ) -> dict:
29
+ """
30
+ Run full validation suite on a merged model.
31
+
32
+ Args:
33
+ model: The merged model to validate
34
+ tokenizer: The tokenizer
35
+ merged_sources: List of source models merged so far
36
+ cfg: Merge configuration
37
+ baseline_perplexity: Perplexity of the target model before merging
38
+
39
+ Returns:
40
+ Dict with test results and overall pass/fail
41
+ """
42
+ print("\n" + "=" * 60)
43
+ print(f"VALIDATION — After merging: {', '.join(merged_sources)}")
44
+ print("=" * 60)
45
+
46
+ results = {
47
+ "canary": None,
48
+ "perplexity": None,
49
+ "thinking_mode": None,
50
+ "reasoning": None,
51
+ "overall": False,
52
+ }
53
+
54
+ # --- Test 1: Canary recall ---
55
+ canary_results = test_all_canaries(model, tokenizer, merged_sources)
56
+ passed_canaries = sum(1 for v in canary_results.values() if v)
57
+ total_canaries = len(canary_results)
58
+ results["canary"] = {
59
+ "passed": passed_canaries,
60
+ "total": total_canaries,
61
+ "ok": passed_canaries >= cfg.canary_pass_threshold,
62
+ "details": canary_results,
63
+ }
64
+
65
+ # --- Test 2: Perplexity ---
66
+ perplexity = compute_perplexity(model, tokenizer)
67
+ ppl_ok = True
68
+ if baseline_perplexity is not None:
69
+ ratio = perplexity / baseline_perplexity
70
+ ppl_ok = ratio < cfg.perplexity_threshold
71
+ print(f"\n[validate] Perplexity: {perplexity:.2f} (baseline: {baseline_perplexity:.2f}, ratio: {ratio:.2f})")
72
+ if not ppl_ok:
73
+ print(f"[validate] ⚠ Perplexity ratio {ratio:.2f} exceeds threshold {cfg.perplexity_threshold}")
74
+ else:
75
+ print(f"\n[validate] Perplexity: {perplexity:.2f} (no baseline to compare)")
76
+ results["perplexity"] = {"value": perplexity, "ok": ppl_ok}
77
+
78
+ # --- Test 3: Thinking mode ---
79
+ think_ok = test_thinking_mode(model, tokenizer)
80
+ results["thinking_mode"] = {"ok": think_ok}
81
+
82
+ # --- Test 4: Quick reasoning ---
83
+ reason_ok = test_reasoning(model, tokenizer)
84
+ results["reasoning"] = {"ok": reason_ok}
85
+
86
+ # --- Overall verdict ---
87
+ all_ok = (
88
+ results["canary"]["ok"]
89
+ and results["perplexity"]["ok"]
90
+ and results["thinking_mode"]["ok"]
91
+ and results["reasoning"]["ok"]
92
+ )
93
+ results["overall"] = all_ok
94
+
95
+ # Summary
96
+ print("\n" + "-" * 60)
97
+ print("VALIDATION SUMMARY")
98
+ print("-" * 60)
99
+ print(f" Canary recall: {'✓' if results['canary']['ok'] else '✗'} ({passed_canaries}/{total_canaries})")
100
+ print(f" Perplexity: {'✓' if ppl_ok else '✗'} ({perplexity:.2f})")
101
+ print(f" Thinking mode: {'✓' if think_ok else '✗'}")
102
+ print(f" Reasoning: {'✓' if reason_ok else '✗'}")
103
+ print(f" OVERALL: {'✓ PASS' if all_ok else '✗ FAIL — consider aborting'}")
104
+ print("-" * 60)
105
+
106
+ return results
107
+
108
+
109
+ def compute_perplexity(
110
+ model: AutoModelForCausalLM,
111
+ tokenizer: AutoTokenizer,
112
+ test_texts: list[str] = None,
113
+ ) -> float:
114
+ """
115
+ Compute perplexity on a small test set.
116
+
117
+ Lower perplexity = model is more confident about predicting text.
118
+ A big spike after merging means the model was damaged.
119
+ """
120
+ if test_texts is None:
121
+ test_texts = [
122
+ "The quick brown fox jumps over the lazy dog.",
123
+ "In mathematics, a prime number is a natural number greater than 1.",
124
+ "def fibonacci(n):\n if n <= 1:\n return n\n return fibonacci(n-1) + fibonacci(n-2)",
125
+ "The theory of general relativity describes gravity as the curvature of spacetime.",
126
+ "To solve 3x + 7 = 22, subtract 7 from both sides to get 3x = 15, then divide by 3.",
127
+ ]
128
+
129
+ model.eval()
130
+ total_loss = 0.0
131
+ total_tokens = 0
132
+
133
+ for text in test_texts:
134
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
135
+ inputs = {k: v.to(model.device) for k, v in inputs.items()}
136
+
137
+ with torch.no_grad():
138
+ outputs = model(**inputs, labels=inputs["input_ids"])
139
+ total_loss += outputs.loss.item() * (inputs["input_ids"].shape[1] - 1)  # loss is a mean over seq_len - 1 shifted targets
140
+ total_tokens += inputs["input_ids"].shape[1] - 1
141
+
142
+ avg_loss = total_loss / total_tokens
143
+ perplexity = math.exp(avg_loss)
144
+ return perplexity
145
+
146
+
147
+ def test_thinking_mode(
148
+ model: AutoModelForCausalLM,
149
+ tokenizer: AutoTokenizer,
150
+ ) -> bool:
151
+ """
152
+ Test if the model still uses <think> tags for reasoning.
153
+
154
+ The thinking mode is Qwen3's special feature — if it's gone,
155
+ the merge damaged something critical.
156
+ """
157
+ prompt = "Solve step by step: What is 15 × 13?"
158
+
159
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
160
+ with torch.no_grad():
161
+ outputs = model.generate(
162
+ **inputs,
163
+ max_new_tokens=200,
164
+ temperature=0.7,
165
+ do_sample=True,
166
+ )
167
+
168
+ response = tokenizer.decode(outputs[0], skip_special_tokens=False)
169
+
170
+ # Check for thinking tags
171
+ has_think_open = "<think>" in response
172
+ has_think_close = "</think>" in response
173
+ passed = has_think_open and has_think_close
174
+
175
+ print(f"\n[validate] Thinking mode test:")
176
+ print(f" Prompt: {prompt}")
177
+ print(f" Response: {response[:200]}...")
178
+ print(f" <think>: {'✓ found' if has_think_open else '✗ missing'}")
179
+ print(f" </think>: {'✓ found' if has_think_close else '✗ missing'}")
180
+ print(f" Status: {'✓ PASS' if passed else '✗ FAIL'}")
181
+
182
+ return passed
183
+
184
+
185
+ def test_reasoning(
186
+ model: AutoModelForCausalLM,
187
+ tokenizer: AutoTokenizer,
188
+ ) -> bool:
189
+ """
190
+ Quick reasoning sanity check — can the model still do basic math?
191
+
192
+ This catches catastrophic failures where the merge produced gibberish.
193
+ """
194
+ prompt = "What is 7 + 8?"
195
+ expected_answer = "15"
196
+
197
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
198
+ with torch.no_grad():
199
+ outputs = model.generate(
200
+ **inputs,
201
+ max_new_tokens=50,
202
+ # temperature omitted: it has no effect when do_sample=False (greedy decoding)
203
+ do_sample=False,
204
+ )
205
+
206
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
207
+ passed = expected_answer in response
208
+
209
+ print(f"\n[validate] Quick reasoning test:")
210
+ print(f" Prompt: {prompt}")
211
+ print(f" Expected: {expected_answer}")
212
+ print(f" Got: {response}")
213
+ print(f" Status: {'✓ PASS' if passed else '✗ FAIL'}")
214
+
215
+ return passed
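The aggregation step in `compute_perplexity` can be checked without loading a model. The sketch below assumes each text contributes a mean cross-entropy loss over its `num_tokens - 1` next-token targets (a causal LM has no prediction target for the first token), and that the kill criterion is a ratio against a baseline. The function names `corpus_perplexity` and `perplexity_ok`, the sample numbers, and the threshold value 1.1 are hypothetical; the real threshold lives in `cfg.perplexity_threshold`, whose value is not shown in this diff.

```python
import math


def corpus_perplexity(per_text):
    """Token-weighted perplexity from (mean_loss, num_tokens) pairs.

    Each mean_loss is assumed to be averaged over num_tokens - 1
    next-token predictions.
    """
    total_nll = 0.0
    total_predicted = 0
    for mean_loss, num_tokens in per_text:
        predicted = num_tokens - 1
        if predicted <= 0:
            continue  # a single-token text has nothing to predict
        total_nll += mean_loss * predicted
        total_predicted += predicted
    if total_predicted == 0:
        raise ValueError("no predicted tokens to average over")
    return math.exp(total_nll / total_predicted)


def perplexity_ok(perplexity, baseline, threshold=1.1):
    """True when the merged/baseline ratio stays below the threshold (1.1 is an assumed value)."""
    return perplexity / baseline < threshold


# Invented numbers: two texts with 11 and 6 tokens.
merged_ppl = corpus_perplexity([(2.0, 11), (3.0, 6)])
print(round(merged_ppl, 3), perplexity_ok(merged_ppl, baseline=10.0))
```

Weighting by predicted tokens rather than by text keeps one long passage from being counted the same as a one-line sentence.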
hugging/td_lang/.DS_Store ADDED
Binary file (6.15 kB). View file
 
hugging/td_lang/__init__.py ADDED
@@ -0,0 +1,51 @@
1
+ """
2
+ TD Lang — Domain-specific language for Time Dilation project.
3
+
4
+ Compiles .td files into Python code that calls td_fuse.
5
+ Write simple scripts instead of complex Python.
6
+
7
+ Architecture:
8
+ td_lang/
9
+ ├── __init__.py <- This file
10
+ ├── __main__.py <- Entry point for python -m td_lang
11
+ ├── grammar.py <- Lark grammar + parse tree transformer
12
+ ├── ast_nodes.py <- Dataclass AST nodes for each command
13
+ ├── compiler.py <- AST -> Python code generation
14
+ ├── executor.py <- Run compiled code, track lineage
15
+ ├── cli.py <- Command-line interface
16
+ ├── errors.py <- Custom exceptions
17
+ └── examples/
18
+ ├── demo_merge.td <- Basic merge example
19
+ ├── demo_heal.td <- Merge + heal example
20
+ ├── demo_full.td <- Full pipeline with gates + budget
21
+ ├── demo_loop.td <- Self-improvement loop example
22
+ ├── demo_phase3.td <- Fork/edit/prune/reset example
23
+ └── demo_phase4.td <- Contracts + snapshot + report example
24
+
25
+ Phase 1: load, merge, heal, eval, commit
26
+ Phase 2: diagnose, synth, train, debate
27
+ Phase 3: fork, reset, prune, edit
28
+ Phase 4: snapshot, report, data_contract, reward_contract
29
+ Phase 5: CLI polish, --version, info command, --verbose
30
+
31
+ Designed from interviews test_14 (10 commands) and test_17 (ForgeSpec 2.0).
32
+ """
33
+
34
+ from .grammar import parse_td_file, parse_td_string # noqa: F401
35
+ from .compiler import compile_program # noqa: F401
36
+ from .executor import TDExecutor, check_td_file, compile_td_file, run_td_file # noqa: F401
37
+
38
+ __version__ = "0.2.0"
39
+ __author__ = "Milan (TD Project)"
40
+
41
+ __all__ = [
42
+ "parse_td_file",
43
+ "parse_td_string",
44
+ "compile_program",
45
+ "TDExecutor",
46
+ "check_td_file",
47
+ "compile_td_file",
48
+ "run_td_file",
49
+ "__version__",
50
+ "__author__",
51
+ ]
hugging/td_lang/__main__.py ADDED
@@ -0,0 +1,5 @@
1
+ """Entry point for python -m td_lang."""
2
+
3
+ from .cli import main
4
+
5
+ main()
hugging/td_lang/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (2.01 kB).
hugging/td_lang/__pycache__/__init__.cpython-314.pyc ADDED
Binary file (2.02 kB).
hugging/td_lang/__pycache__/__main__.cpython-310.pyc ADDED
Binary file (254 Bytes).
hugging/td_lang/__pycache__/__main__.cpython-314.pyc ADDED
Binary file (262 Bytes).
hugging/td_lang/__pycache__/ast_nodes.cpython-310.pyc ADDED
Binary file (12.7 kB).
hugging/td_lang/__pycache__/ast_nodes.cpython-314.pyc ADDED
Binary file (18.7 kB).
hugging/td_lang/__pycache__/cli.cpython-310.pyc ADDED
Binary file (6.62 kB).
hugging/td_lang/__pycache__/cli.cpython-314.pyc ADDED
Binary file (10.5 kB).
hugging/td_lang/__pycache__/compiler.cpython-310.pyc ADDED
Binary file (88.7 kB).
hugging/td_lang/__pycache__/compiler.cpython-314.pyc ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8bef7388fef05cdd8ee4edcc72a4b8907c8637caa22cfc802da044470a515c92
+ size 162778
hugging/td_lang/__pycache__/errors.cpython-310.pyc ADDED
Binary file (4.21 kB).
hugging/td_lang/__pycache__/errors.cpython-314.pyc ADDED
Binary file (6.34 kB).
hugging/td_lang/__pycache__/executor.cpython-310.pyc ADDED
Binary file (5.94 kB).
hugging/td_lang/__pycache__/executor.cpython-314.pyc ADDED
Binary file (9.49 kB).
hugging/td_lang/__pycache__/grammar.cpython-310.pyc ADDED
Binary file (25.4 kB).
hugging/td_lang/__pycache__/grammar.cpython-314.pyc ADDED
Binary file (37.8 kB).
hugging/td_lang/ast_nodes.py ADDED
@@ -0,0 +1,421 @@
1
+ """
2
+ TD Lang AST Nodes — Dataclass containers for each parsed command.
3
+
4
+ Each .td command becomes one of these nodes after parsing.
5
+ Phase 1 nodes are compiled into runnable Python; Phase 2 nodes are stubs so
6
+ the compiler can reject them with a clear error until they are implemented.
7
+ """
8
+
9
+ from dataclasses import dataclass, field
10
+ from typing import Any, List, Optional
11
+
12
+
13
+ # ============================================================================
14
+ # PHASE 1 COMMANDS
15
+ # ============================================================================
16
+
17
+ @dataclass
18
+ class LoadCmd:
19
+ """Load a model and give it a name.
20
+
21
+ Example: load "Qwen/Qwen3-VL-8B-Instruct" as base
22
+ """
23
+ model_ref: str # HuggingFace path or local path
24
+ alias: str # Name to use in the rest of the script
25
+
26
+
27
+ @dataclass
28
+ class MergeCmd:
29
+ """Merge a source model into a target using a method.
30
+
31
+ Example: merge "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B" into base using transport strength 0.5
32
+ """
33
+ source: str # Model path or alias to merge from
34
+ target: str # Alias to merge into (must be loaded first)
35
+ method: str # "transport", "slerp", "ties", "dare"
36
+ strength: float = 0.5 # 0.0 = keep target, 1.0 = keep source
37
+
38
+
39
+ @dataclass
40
+ class HealCmd:
41
+ """Run QLoRA healing fine-tune on a model.
42
+
43
+ Example: heal base lora_r 32 epochs 2
44
+ """
45
+ target: str # Alias of model to heal
46
+ lora_r: int = 32 # LoRA rank (higher = more capacity)
47
+ epochs: int = 2 # Training epochs
48
+
49
+
50
+ @dataclass
51
+ class EvalCmd:
52
+ """Run validation/evaluation on a model.
53
+
54
+ Example: eval base on "pile_sample" -> report.json
55
+ """
56
+ target: str # Alias of model to evaluate
57
+ dataset: Optional[str] = None # Optional dataset name/path
58
+ output: Optional[str] = None # Optional output file path
59
+
60
+
61
+ @dataclass
62
+ class CommitCmd:
63
+ """Save model checkpoint, optionally requiring gates to pass.
64
+
65
+ Example: commit base if [canary, perplexity, thinking_mode]
66
+ """
67
+ target: str # Alias of model to commit
68
+ gates: Optional[list[str]] = None # Gate names that must pass
69
+
70
+
71
+ # ============================================================================
72
+ # PHASE 2 COMMANDS (placeholders — structure ready, not wired up yet)
73
+ # ============================================================================
74
+
75
+ @dataclass
76
+ class SynthCmd:
77
+ """Generate synthetic training data from a model. (Phase 2)"""
78
+ target: str
79
+ source: str
80
+ filter_method: Optional[str] = None
81
+ output: Optional[str] = None
82
+
83
+
84
+ @dataclass
85
+ class TrainCmd:
86
+ """Train a model on a dataset. (Phase 2)"""
87
+ target: str
88
+ dataset: str
89
+ method: str = "grpo" # "grpo", "sft", "dpo"
90
+ steps: Optional[int] = None
91
+ learning_rate: Optional[float] = None
92
+
93
+
94
+ @dataclass
95
+ class DebateCmd:
96
+ """Generate multi-answer debate for preference pairs. (Phase 2)"""
97
+ target: str
98
+ rounds: int = 3
99
+ candidates: int = 8
100
+ output: Optional[str] = None
101
+
102
+
103
+ @dataclass
104
+ class DiagnoseCmd:
105
+ """Ask model what it's bad at — self-diagnosis. (Phase 2)"""
106
+ target: str
107
+ output: Optional[str] = None
108
+
109
+
110
+ @dataclass
111
+ class ForkCmd:
112
+ """Branch current model weights for parallel experiments. (Phase 3)
113
+
114
+ Example: fork base as experiment_v2
115
+ Cheap fork: copies manifest + adapters, shares base weights (default).
116
+ """
117
+ source: str # Alias of model to fork from
118
+ alias: str # Name for the new branch
119
+
120
+
121
+ @dataclass
122
+ class ResetCmd:
123
+ """Revert model to a previous checkpoint. (Phase 3)
124
+
125
+ Example: reset base to "checkpoint_042"
126
+ Deletes current model, clears CUDA cache, reloads from disk.
127
+ Must also reset optimizer state.
128
+ """
129
+ target: str # Alias of model to reset
130
+ checkpoint: str # Checkpoint name/path to revert to
131
+
132
+
133
+ @dataclass
134
+ class PruneCmd:
135
+ """Structural pruning — remove low-utility neurons/heads. (Phase 3)
136
+
137
+ Example: prune base using wanda aggressiveness 0.2
138
+ Safe zone: ~20% max (LLM-Pruner paper). Language backbone only.
139
+ """
140
+ target: str
141
+ method: str = "wanda" # "wanda", "magnitude", "taylor"
142
+ aggressiveness: float = 0.2 # Fraction to remove (0.0-1.0)
143
+
144
+
145
+ @dataclass
146
+ class EditCmd:
147
+ """Surgical LoRA/DoRA editing on specific layers. (Phase 3)
148
+
149
+ Example: edit base layers 16-28 using lora lr 1e-4
150
+ "Try before buy": eval with adapter enabled vs disabled before merging.
151
+ """
152
+ target: str
153
+ layers: str = "all" # "all", "16-28", single number
154
+ method: str = "lora" # "lora" or "dora"
155
+ learning_rate: Optional[float] = None
156
+
157
+
158
+ # ============================================================================
159
+ # PHASE 4 COMMANDS — Contracts, Lineage, Economics (ForgeSpec 2.0, test_17)
160
+ # ============================================================================
161
+
162
+ # ============================================================================
163
+ # PHASE 7 — LOOP CONTROL (repeat, if/else)
164
+ # ============================================================================
165
+
166
+ @dataclass
167
+ class RepeatBlock:
168
+ """Repeat a block of commands N times. (Phase 7 — Loop Control)
169
+
170
+ Example:
171
+ repeat 5 {
172
+ diagnose base
173
+ synth base from base
174
+ train base on "data.jsonl" using grpo steps 64
175
+ eval base
176
+ }
177
+ """
178
+ count: int # Number of iterations
179
+ body: List[Any] = field(default_factory=list) # Commands inside the block
180
+
181
+
182
+ @dataclass
183
+ class IfBlock:
184
+ """Conditional execution based on last eval result. (Phase 7 — Loop Control)
185
+
186
+ Example:
187
+ if eval_passed {
188
+ commit base
189
+ } else {
190
+ reset base to "last_good"
191
+ }
192
+
193
+ Condition checks the most recent eval result for the target.
194
+ """
195
+ condition: str # "eval_passed", "gate_passed", etc.
196
+ target: Optional[str] = None # Which model's eval to check
197
+ then_body: List[Any] = field(default_factory=list)
198
+ else_body: List[Any] = field(default_factory=list)
199
+
200
+
201
+ @dataclass
202
+ class FuseCmd:
203
+ """Fuse multiple models into a target in one shot. (Phase 6 — Easy Merge)
204
+
205
+ Example: fuse ["deepseek-ai/DeepSeek-R1", "MiMo-7B", "meta-llama/Llama-3.1-8B"] into base
+ Auto-picks Transport and Merge, auto-sets per-model strength.
+ Handles cross-architecture merging (source models may have different architectures).
208
+ """
209
+ sources: list[str] # List of model names/paths to fuse in
210
+ target: str # Alias to merge into (must be loaded)
211
+ method: str = "transport" # Default: transport and merge (cross-arch)
212
+ strategy: str = "equal" # "equal" (same strength each), "weighted", "sequential"
213
+
214
+
215
+ @dataclass
216
+ class AbsorbCmd:
217
+ """Absorb a single model into target — simplified merge. (Phase 6 — Easy Merge)
218
+
219
+ Example: absorb "deepseek-ai/DeepSeek-R1" into base strength 0.5
220
+ One-liner for the common case of merging one model in.
221
+ """
222
+ source: str # Model path or HF ID
223
+ target: str # Alias to merge into
224
+ strength: float = 0.5 # 0.0=keep target, 1.0=keep source, default balanced
225
+
226
+
227
+ @dataclass
228
+ class SnapshotCmd:
229
+ """Save a content-hashed snapshot of model state for lineage tracking. (Phase 4)
230
+
231
+ Example: snapshot base -> snapshots/
232
+ Creates a content-addressed directory: snapshots/<sha256_prefix>/
233
+ Contains: model state, adapter state, prune spec, eval report, manifest.
234
+ """
235
+ target: str
236
+ output: Optional[str] = None # Output directory (default: td_lang_outputs/snapshots/)
237
+
238
+
239
+ @dataclass
240
+ class ReportCmd:
241
+ """Generate an economics report for this run. (Phase 4)
242
+
243
+ Example: report -> economics.json
244
+ Tracks: GPU hours, cost estimate, tokens processed, experiments run,
245
+ time per command, cost breakdown by phase.
246
+ """
247
+ output: Optional[str] = None # Output file path
248
+
249
+
250
+ # ============================================================================
251
+ # PHASE 8 — AUTOPILOT (setup, notify, save, on_error, resume)
252
+ # ============================================================================
253
+
254
+ @dataclass
255
+ class NotifyCmd:
256
+ """Send a notification via ntfy.sh. (Phase 8 — Autopilot)
257
+
258
+ Example: notify "Training complete!"
259
+ Uses curl to POST to the configured ntfy topic.
260
+ """
261
+ message: str
262
+
263
+
264
+ @dataclass
265
+ class SaveCmd:
266
+ """Save/upload model to cloud storage via rclone. (Phase 8 — Autopilot)
267
+
268
+ Example: save base to "gdrive:TD/models/v1"
269
+ Uses rclone to copy model checkpoint to Google Drive (or any rclone remote).
270
+ """
271
+ target: str # Alias of model to save
272
+ destination: str # rclone destination path
273
+
274
+
275
+ @dataclass
276
+ class SetupBlock:
277
+ """Auto-install dependencies and configure environment. (Phase 8 — Autopilot)
278
+
279
+ Example:
280
+ setup {
281
+ pip = [torch, transformers, peft, bitsandbytes, trl]
282
+ hf_token = env
283
+ notify = "ntfy.sh/my_ai"
284
+ }
285
+ """
286
+ pip_packages: list[str] = field(default_factory=list)
287
+ hf_token: Optional[str] = None # "env" = read HF_TOKEN from env
288
+ notify_url: Optional[str] = None # ntfy.sh topic URL
289
+
290
+
291
+ @dataclass
292
+ class OnErrorBlock:
293
+ """Crash recovery behavior. (Phase 8 — Autopilot)
294
+
295
+ Example:
296
+ on_error {
297
+ retry = 3
298
+ fallback = reduce_batch
299
+ notify = true
300
+ }
301
+ """
302
+ retry: int = 3 # Number of retries per failed step
303
+ fallback: str = "reduce_batch" # "reduce_batch", "skip", "snapshot_and_stop"
304
+ notify: bool = True # Send ntfy notification on error
305
+
306
+
307
+ # ============================================================================
308
+ # BLOCKS (gates, budget, contracts, etc.)
309
+ # ============================================================================
310
+
311
+ @dataclass
312
+ class GateBlock:
313
+ """Validation gates that must pass before commit.
314
+
315
+ Example:
316
+ gate {
317
+ must_pass = [canary, perplexity, thinking_mode]
318
+ }
319
+ """
320
+ must_pass: list[str] = field(default_factory=list)
321
+
322
+
323
+ @dataclass
324
+ class BudgetBlock:
325
+ """Resource budget — compiler refuses plans that exceed limits.
326
+
327
+ Example:
328
+ budget {
329
+ max_gpu_hours = 8
330
+ max_cost = 50.00
331
+ }
332
+ """
333
+ max_gpu_hours: Optional[float] = None
334
+ max_cost: Optional[float] = None
335
+ max_tokens: Optional[int] = None
336
+ max_experiments: Optional[int] = None
337
+
338
+
339
+ @dataclass
340
+ class DataContractBlock:
341
+ """Schema enforcement on training data. (Phase 4, ForgeSpec 2.0)
342
+
343
+ Example:
344
+ data_contract {
345
+ required_fields = [prompt, response]
346
+ min_samples = 100
347
+ max_perplexity = 50.0
348
+ }
349
+
350
+ Compiler checks training data at synth/train time.
351
+ """
352
+ required_fields: list[str] = field(default_factory=list)
353
+ min_samples: Optional[int] = None
354
+ max_perplexity: Optional[float] = None
355
+
356
+
357
+ @dataclass
358
+ class RewardContractBlock:
359
+ """Verified reward definitions — what counts as "correct". (Phase 4, ForgeSpec 2.0)
360
+
361
+ Example:
362
+ reward_contract {
363
+ verifiers = [code_compiles, math_correct, no_hallucination]
364
+ min_reward = 0.3
365
+ }
366
+
367
+ Used by train (GRPO) to enforce reward quality.
368
+ No learned reward model — verified rewards only (test_16).
369
+ """
370
+ verifiers: list[str] = field(default_factory=list)
371
+ min_reward: Optional[float] = None
372
+
373
+
374
+ # ============================================================================
375
+ # TOP-LEVEL PROGRAM
376
+ # ============================================================================
377
+
378
+ @dataclass
379
+ class TDProgram:
380
+ """A complete parsed .td file — commands in order plus global blocks."""
381
+
382
+ commands: List[Any] = field(default_factory=list)
383
+ gates: Optional[GateBlock] = None
384
+ budget: Optional[BudgetBlock] = None
385
+ data_contract: Optional[DataContractBlock] = None
386
+ reward_contract: Optional[RewardContractBlock] = None
387
+ setup: Optional[SetupBlock] = None
388
+ on_error: Optional[OnErrorBlock] = None
389
+ source_file: Optional[str] = None
390
+
391
+
392
+ __all__ = [
393
+ "LoadCmd",
394
+ "MergeCmd",
395
+ "HealCmd",
396
+ "EvalCmd",
397
+ "CommitCmd",
398
+ "SynthCmd",
399
+ "TrainCmd",
400
+ "DebateCmd",
401
+ "DiagnoseCmd",
402
+ "ForkCmd",
403
+ "ResetCmd",
404
+ "PruneCmd",
405
+ "EditCmd",
406
+ "RepeatBlock",
407
+ "IfBlock",
408
+ "FuseCmd",
409
+ "AbsorbCmd",
410
+ "SnapshotCmd",
411
+ "ReportCmd",
412
+ "NotifyCmd",
413
+ "SaveCmd",
414
+ "SetupBlock",
415
+ "OnErrorBlock",
416
+ "GateBlock",
417
+ "BudgetBlock",
418
+ "DataContractBlock",
419
+ "RewardContractBlock",
420
+ "TDProgram",
421
+ ]
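Review note: a minimal sketch of how a consumer might walk these nodes, assuming a type-keyed lookup similar to `_PHASE_MAP` in cli.py. The `Load`, `Eval`, and `Repeat` classes are simplified stand-ins for `LoadCmd`, `EvalCmd`, and `RepeatBlock`, and `flatten` is a hypothetical helper that is not part of td_lang.

```python
from dataclasses import dataclass, field
from typing import Any, List


@dataclass
class Load:  # stand-in for LoadCmd
    model_ref: str
    alias: str


@dataclass
class Eval:  # stand-in for EvalCmd
    target: str


@dataclass
class Repeat:  # stand-in for RepeatBlock
    count: int
    # default_factory gives every instance its own list
    body: List[Any] = field(default_factory=list)


# Type-keyed lookup: node class -> printable command name.
NAMES = {Load: "load", Eval: "eval", Repeat: "repeat"}


def flatten(commands: List[Any]) -> List[str]:
    """Expand repeat blocks and return command names in execution order."""
    out: List[str] = []
    for cmd in commands:
        if isinstance(cmd, Repeat):
            # Unroll the body `count` times, recursing for nested blocks.
            for _ in range(cmd.count):
                out.extend(flatten(cmd.body))
        else:
            out.append(NAMES.get(type(cmd), type(cmd).__name__))
    return out


program = [Load("some/model", "base"), Repeat(2, [Eval("base")])]
plan = flatten(program)
print(plan)
```

Keying the table on the node class keeps the walker free of long `if/elif` chains; unknown node types fall back to their class name, as the `info` command does.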
hugging/td_lang/cli.py ADDED
@@ -0,0 +1,212 @@
1
+ """
2
+ TD Lang CLI — Command-line interface for .td files.
3
+
4
+ Usage:
5
+ python -m td_lang run examples/demo_merge.td # Compile + execute
6
+ python -m td_lang compile examples/demo_merge.td # Compile only (outputs .py)
7
+ python -m td_lang check examples/demo_merge.td # Syntax check only
8
+ python -m td_lang info examples/demo_merge.td # Show plan without compiling
9
+ python -m td_lang --version # Show version
10
+ """
11
+
12
+ import argparse
13
+ import sys
14
+
15
+ from . import __version__
16
+ from .executor import TDExecutor
17
+ from .errors import TDLangError
18
+ from .grammar import parse_td_file
19
+ from .ast_nodes import (
20
+ LoadCmd, MergeCmd, HealCmd, EvalCmd, CommitCmd,
21
+ SynthCmd, TrainCmd, DebateCmd, DiagnoseCmd,
22
+ ForkCmd, ResetCmd, PruneCmd, EditCmd,
23
+ FuseCmd, AbsorbCmd, RepeatBlock, IfBlock,
24
+ NotifyCmd, SaveCmd,
25
+ SnapshotCmd, ReportCmd,
26
+ )
27
+
28
+
29
+ # Phase labels for info command
30
+ _PHASE_MAP = {
31
+ LoadCmd: ("1", "load"),
32
+ MergeCmd: ("1", "merge"),
33
+ HealCmd: ("1", "heal"),
34
+ EvalCmd: ("1", "eval"),
35
+ CommitCmd: ("1", "commit"),
36
+ SynthCmd: ("2", "synth"),
37
+ TrainCmd: ("2", "train"),
38
+ DebateCmd: ("2", "debate"),
39
+ DiagnoseCmd: ("2", "diagnose"),
40
+ ForkCmd: ("3", "fork"),
41
+ ResetCmd: ("3", "reset"),
42
+ PruneCmd: ("3", "prune"),
43
+ EditCmd: ("3", "edit"),
44
+ FuseCmd: ("6", "fuse"),
45
+ AbsorbCmd: ("6", "absorb"),
46
+ RepeatBlock: ("7", "repeat"),
47
+ IfBlock: ("7", "if"),
48
+ NotifyCmd: ("8", "notify"),
49
+ SaveCmd: ("8", "save"),
50
+ SnapshotCmd: ("4", "snapshot"),
51
+ ReportCmd: ("4", "report"),
52
+ }
53
+
54
+
55
+ def parse_args() -> argparse.Namespace:
56
+ """Parse command-line arguments."""
57
+ parser = argparse.ArgumentParser(
58
+ description="TD Lang — compile and run .td files for Time Dilation",
59
+ formatter_class=argparse.RawDescriptionHelpFormatter,
60
+ epilog="""
61
+ Examples:
62
+ python -m td_lang check examples/demo_merge.td # Check syntax
63
+ python -m td_lang compile examples/demo_merge.td # Compile to .py
64
+ python -m td_lang run examples/demo_merge.td # Compile + run
65
+ python -m td_lang run examples/demo_merge.td --dry # Compile only
66
+ python -m td_lang info examples/demo_merge.td # Show plan summary
67
+ """,
68
+ )
69
+
70
+ parser.add_argument(
71
+ "--version",
72
+ action="version",
73
+ version=f"td_lang {__version__}",
74
+ )
75
+
76
+ parser.add_argument(
77
+ "action",
78
+ choices=["check", "compile", "run", "info"],
79
+ help="What to do: check (syntax), compile (.py), run (compile+execute), info (show plan)",
80
+ )
81
+
82
+ parser.add_argument(
83
+ "file",
84
+ type=str,
85
+ help="Path to the .td file",
86
+ )
87
+
88
+ parser.add_argument(
89
+ "--output",
90
+ type=str,
91
+ default="td_lang_outputs",
92
+ help="Output directory (default: td_lang_outputs)",
93
+ )
94
+
95
+ parser.add_argument(
96
+ "--dry",
97
+ action="store_true",
98
+ help="With 'run': compile but don't execute",
99
+ )
100
+
101
+ parser.add_argument(
102
+ "--verbose", "-v",
103
+ action="store_true",
104
+ help="Show extra detail (compiled Python, full AST, etc.)",
105
+ )
106
+
107
+ return parser.parse_args()
108
+
109
+
110
+ def print_banner():
111
+ """Print the td_lang banner."""
112
+ banner = f"""
113
+ ╔═══════════════════════════════════════╗
114
+ ║ ║
115
+ ║ ████████╗██████╗ ██╗ ██████╗║
116
+ ║ ╚══██╔══╝██╔══██╗ ██║ ██╔════╝║
117
+ ║ ██║ ██║ ██║ ██║ ██║ ███║
118
+ ║ ██║ ██║ ██║ ██║ ██║ ██║
119
+ ║ ██║ ██████╔╝ ██████╗ ╚██████╔╝║
120
+ ║ ╚═╝ ╚═════╝ ╚═════╝ ╚═════╝║
121
+ ║ ║
122
+ ║ TD Lang v{__version__} — .td file compiler ║
123
+ ║ ║
124
+ ╚═══════════════════════════════════════╝
125
+ """
126
+ print(banner)
127
+
128
+
129
+ def print_info(filepath: str) -> None:
+     """Show what a .td file does without compiling — human-readable plan summary."""
+     program = parse_td_file(filepath)
+
+     print(f"\n File: {filepath}")
+     print(f" Commands: {len(program.commands)}")
+
+     if program.gates:
+         print(f" Gates: {', '.join(program.gates.must_pass)}")
+     if program.budget:
+         # Report every limit that is set, not only GPU hours and cost.
+         parts = []
+         if program.budget.max_gpu_hours is not None:
+             parts.append(f"{program.budget.max_gpu_hours} GPU hrs")
+         if program.budget.max_cost is not None:
+             parts.append(f"${program.budget.max_cost}")
+         if program.budget.max_tokens is not None:
+             parts.append(f"{program.budget.max_tokens} tokens")
+         if program.budget.max_experiments is not None:
+             parts.append(f"{program.budget.max_experiments} experiments")
+         print(f" Budget: {', '.join(parts)}")
+     if program.data_contract:
+         print(f" Data contract: fields={program.data_contract.required_fields}")
+     if program.reward_contract:
+         print(f" Reward contract: verifiers={program.reward_contract.verifiers}")
+
+     print("\n Plan:")
+     for i, cmd in enumerate(program.commands, 1):
+         phase, name = _PHASE_MAP.get(type(cmd), ("?", type(cmd).__name__))
+         target = getattr(cmd, 'target', getattr(cmd, 'alias', ''))
+         detail = ""
+         if hasattr(cmd, 'method'):
+             detail += f" method={cmd.method}"
+         if hasattr(cmd, 'source') and name in ("merge", "synth"):
+             detail += f" from={cmd.source}"
+         if hasattr(cmd, 'layers') and cmd.layers != "all":
+             detail += f" layers={cmd.layers}"
+         if hasattr(cmd, 'output') and cmd.output:
+             detail += f" -> {cmd.output}"
+         print(f" {i}. [P{phase}] {name} {target}{detail}")
+
+     print()
166
+
167
+
168
+ def main():
169
+ """Main entry point for td_lang CLI."""
170
+ args = parse_args()
171
+ print_banner()
172
+
173
+ executor = TDExecutor(output_dir=args.output)
174
+
175
+ try:
176
+ if args.action == "info":
177
+ print_info(args.file)
178
+
179
+ elif args.action == "check":
180
+ program = executor.check(args.file)
181
+ print("\n[td_lang] File is valid!")
182
+
183
+ elif args.action == "compile":
184
+ py_path = executor.compile(args.file)
185
+ print(f"\n[td_lang] Generated: {py_path}")
186
+ print("[td_lang] You can run it with: python", py_path)
187
+ if args.verbose:
188
+ print("\n--- Generated Python ---")
189
+ print(py_path.read_text())
190
+ print("--- End ---")
191
+
192
+ elif args.action == "run":
193
+ result = executor.run(args.file, dry_run=args.dry)
194
+ if result["status"] == "success":
195
+ sys.exit(0)
196
+ elif result["status"] == "dry_run":
197
+ sys.exit(0)
198
+ else:
199
+ sys.exit(1)
200
+
201
+ except TDLangError as e:
202
+ print(f"\n[td_lang] ERROR: {e}")
203
+ sys.exit(1)
204
+
205
+ except FileNotFoundError:
206
+ print(f"\n[td_lang] ERROR: File not found: {args.file}")
207
+ print("[td_lang] Check the path and try again.")
208
+ sys.exit(1)
209
+
210
+ except KeyboardInterrupt:
211
+ print("\n[td_lang] Interrupted.")
212
+ sys.exit(130)
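Review note: `parse_args()` above always reads `sys.argv`, which makes the CLI awkward to exercise from tests. The sketch below shows one possible variant that accepts an explicit argument list. It mirrors the flags defined above, but `build_parser` and `exit_code` are hypothetical names, not functions in this commit.

```python
import argparse
from typing import List, Optional


def build_parser() -> argparse.ArgumentParser:
    """Build a parser with the same positional arguments and flags as cli.py."""
    parser = argparse.ArgumentParser(prog="td_lang")
    parser.add_argument("action", choices=["check", "compile", "run", "info"])
    parser.add_argument("file")
    parser.add_argument("--output", default="td_lang_outputs")
    parser.add_argument("--dry", action="store_true")
    parser.add_argument("--verbose", "-v", action="store_true")
    return parser


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    # argv=None makes argparse fall back to sys.argv[1:].
    return build_parser().parse_args(argv)


def exit_code(status: str) -> int:
    """Map a run status string to a process exit code."""
    return 0 if status in ("success", "dry_run") else 1


args = parse_args(["run", "examples/demo_merge.td", "--dry"])
print(args.action, args.file, args.dry, exit_code("dry_run"))
```

Passing `argv` explicitly leaves the command-line behavior unchanged while letting a test call `parse_args([...])` directly.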
hugging/td_lang/compiler.py ADDED
The diff for this file is too large to render. See raw diff
 
hugging/td_lang/errors.py ADDED
@@ -0,0 +1,99 @@
+ """
+ TD Lang Errors — Clear, helpful error messages.
+
+ Milan is 11 — errors should say what went wrong and where,
+ not dump cryptic stack traces.
+ """
+
+
+ class TDLangError(Exception):
+     """Base error for all td_lang errors."""
+
+     def __init__(self, message: str, line: int | None = None, hint: str | None = None):
+         self.line = line
+         self.hint = hint
+         if line is not None:
+             full = f"Line {line}: {message}"
+         else:
+             full = message
+         if hint:
+             full += f"\n Hint: {hint}"
+         super().__init__(full)
+
+
+ class TDSyntaxError(TDLangError):
+     """Bad .td syntax — couldn't understand the file."""
+     pass
+
+
+ class TDCompileError(TDLangError):
+     """Valid syntax but impossible plan — e.g., merging into a model that doesn't exist."""
+     pass
+
+
+ class TDGateError(TDLangError):
+     """Gates failed during execution."""
+
+     def __init__(self, failed_gates: list[str], message: str = ""):
+         self.failed_gates = failed_gates
+         msg = message or f"Gates failed: {', '.join(failed_gates)}"
+         super().__init__(msg, hint="Check eval results — the model may have regressed.")
+
+
+ class TDBudgetError(TDLangError):
+     """Budget would be exceeded — compiler refuses to run."""
+
+     def __init__(self, field: str, limit: float, requested: float):
+         self.field = field
+         self.limit = limit
+         self.requested = requested
+         super().__init__(
+             f"Budget exceeded: {field} limit is {limit}, but plan needs ~{requested}",
+             hint="Reduce steps, use fewer merges, or increase the budget.",
+         )
+
+
+ class TDContractError(TDLangError):
+     """Data or reward contract violation — training data doesn't match spec."""
+
+     def __init__(self, contract_type: str, violations: list[str]):
+         self.contract_type = contract_type
+         self.violations = violations
+         msg = f"{contract_type} contract failed with {len(violations)} violation(s)"
+         if violations:
+             msg += f": {violations[0]}"
+         if len(violations) > 1:
+             msg += f" (and {len(violations)-1} more)"
+         super().__init__(
+             msg,
+             hint="Check your training data matches the contract spec.",
+         )
+
+
+ # ============================================================================
+ # COMMON MISTAKE SUGGESTIONS (Phase 5)
+ # ============================================================================
+
+ COMMON_FIXES = {
+     "load": 'Did you forget quotes? Correct: load "model/path" as name',
+     "merge": 'Format: merge "source" into target using method [strength 0.5]',
+     "edit": "Format: edit target layers 16-28 using lora [lr 1e-4]",
+     "prune": "Format: prune target using wanda [aggressiveness 0.2]",
+     "fork": "Format: fork source as new_name",
+     "reset": 'Format: reset target to "checkpoint_path"',
+     "train": 'Format: train target on "dataset" using grpo [steps 64]',
+     "synth": "Format: synth target from source [filter cherry_llm]",
+     "snapshot": "Format: snapshot target [-> output_dir]",
+     "report": "Format: report [-> economics.json]",
+     "fuse": 'Format: fuse ["model1", "model2"] into target [strategy equal]',
+     "absorb": 'Format: absorb "model" into target [strength 0.5]',
+ }
+
+
+ def suggest_fix(token: str) -> str | None:
+     """Given a failed token, suggest the correct syntax."""
+     token_lower = token.lower().strip()
+     for keyword, fix in COMMON_FIXES.items():
+         if keyword in token_lower:
+             return fix
+     return None
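Review note: `suggest_fix` matches keywords by substring, so a token such as `download_dir` would trigger the `load` hint. The sketch below shows one possible stricter variant that matches whole words only. `suggest_fix_strict` is a hypothetical name, and `FIXES` copies just two entries from `COMMON_FIXES` for brevity.

```python
import re
from typing import Optional

# Two entries copied from COMMON_FIXES; the rest are omitted for brevity.
FIXES = {
    "load": 'Did you forget quotes? Correct: load "model/path" as name',
    "fork": "Format: fork source as new_name",
}


def suggest_fix_strict(token: str) -> Optional[str]:
    """Suggest a fix only when a keyword appears as a whole word."""
    text = token.lower()
    for keyword, fix in FIXES.items():
        # \b anchors the keyword at word boundaries on both sides.
        if re.search(rf"\b{re.escape(keyword)}\b", text):
            return fix
    return None


print(suggest_fix_strict("load Qwen as base"))
print(suggest_fix_strict("download_dir"))
```

Whether the stricter match is preferable depends on how the parser reports the failing token; if it always passes a single keyword, the substring check is already sufficient.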
hugging/td_lang/examples/demo_autopilot.td ADDED
@@ -0,0 +1,62 @@
+ # demo_autopilot.td — The full "rent a GPU and go" pipeline
+ # Rent vast.ai, upload this file, run: python -m td_lang run demo_autopilot.td
+ # Then sit back — you'll get ntfy notifications on your phone.
+
+ # === ENVIRONMENT ===
+ setup {
+     pip = [torch, transformers, peft, bitsandbytes, trl, safetensors, datasets, accelerate, huggingface_hub, sentencepiece]
+     hf_token = env
+     notify = "ntfy.sh/my_ai"
+ }
+
+ on_error {
+     retry = 3
+     fallback = reduce_batch
+     notify = true
+ }
+
+ # === QUALITY RULES ===
+ gate { must_pass = [canary, perplexity, thinking_mode] }
+ budget { max_gpu_hours = 40 max_cost = 160.00 }
+
+ data_contract {
+     required_fields = [prompt, response]
+     min_samples = 50
+     max_perplexity = 50.0
+ }
+
+ reward_contract {
+     verifiers = [code_compiles, math_correct]
+     min_reward = 0.3
+ }
+
+ # === PIPELINE ===
+
+ # Step 1: Load and fuse
+ load "Qwen/Qwen3-VL-8B-Instruct" as base
+ fuse ["deepseek-ai/DeepSeek-R1", "MiMo-7B", "meta-llama/Llama-3.1-8B", "tiiuae/Falcon-H1R-7B"] into base
+ heal base lora_r 32 epochs 2
+ notify "Merge + heal complete. Starting self-improvement loop."
+
+ # Step 2: Self-improvement loop
+ repeat 5 {
+     diagnose base -> weaknesses.json
+     synth base from base filter cherry_llm -> training_data.jsonl
+     train base on "training_data.jsonl" using grpo steps 64 lr 5e-5
+     eval base -> eval_results.json
+
+     if eval_passed base {
+         commit base
+         snapshot base -> snapshots/
+         notify "Loop iteration passed! Model improved."
+     } else {
+         reset base to "snapshots/"
+         notify "Loop iteration failed. Reset to last good snapshot."
+     }
+ }
+
+ # Step 3: Save and notify
+ snapshot base -> final_model/
+ save base to "gdrive:TD/models/final"
+ report -> economics.json
+ notify "TD PIPELINE COMPLETE. Model saved to Google Drive."
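Review note: the `on_error` block above asks for `retry = 3` with `fallback = reduce_batch`. The executor's actual behavior is not visible in this diff, so the following is only a sketch of one plausible reading: retry a failed step up to `retry` times and halve the batch size before each retry. `run_with_retry` and `fake_train` are hypothetical names.

```python
from typing import Callable, Optional


def run_with_retry(step: Callable[[int], str], batch_size: int, retry: int = 3) -> str:
    """Call step(batch_size); on failure halve the batch and try again."""
    last_error: Optional[Exception] = None
    for _ in range(retry + 1):  # first try plus `retry` retries
        try:
            return step(batch_size)
        except RuntimeError as err:
            last_error = err
            batch_size = max(1, batch_size // 2)  # the "reduce_batch" fallback
    raise RuntimeError(f"step failed after {retry} retries") from last_error


def fake_train(batch_size: int) -> str:
    # Stand-in for a training step that only fits in memory at batch size <= 4.
    if batch_size > 4:
        raise RuntimeError("out of memory")
    return f"trained with batch_size={batch_size}"


result = run_with_retry(fake_train, batch_size=16)
print(result)
```

Here the step fails at batch sizes 16 and 8 and succeeds at 4, so the call returns on the second retry.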
hugging/td_lang/examples/demo_full.td ADDED
@@ -0,0 +1,17 @@
+ # Full Phase 1 demo with gates and budget
+ gate {
+     must_pass = [canary, perplexity, thinking_mode]
+ }
+
+ budget {
+     max_gpu_hours = 8
+     max_cost = 50.00
+     max_tokens = 20000000
+     max_experiments = 4
+ }
+
+ load "Qwen/Qwen3-VL-8B-Instruct" as base
+ merge "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B" into base using transport strength 0.5
+ heal base lora_r 32 epochs 2
+ eval base -> full_eval.json
+ commit base
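Review note: the `budget` block implies the compiler compares an estimated plan cost against each limit before running. How td_lang estimates usage is not shown in this diff, so the sketch below takes the estimate as given. `Budget`, `over_budget`, and the estimate keys are hypothetical, and the message format borrows from `TDBudgetError`.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class Budget:  # mirrors two of the BudgetBlock fields
    max_gpu_hours: Optional[float] = None
    max_cost: Optional[float] = None


def over_budget(budget: Budget, estimate: Dict[str, float]) -> List[str]:
    """Return one message per limit that the estimated plan would exceed."""
    problems = []
    for name, limit in vars(budget).items():
        # "max_gpu_hours" is compared against the "gpu_hours" estimate, etc.
        needed = estimate.get(name[len("max_"):], 0.0)
        if limit is not None and needed > limit:
            problems.append(f"{name} limit is {limit}, but plan needs ~{needed}")
    return problems


budget = Budget(max_gpu_hours=8, max_cost=50.0)
problems = over_budget(budget, {"gpu_hours": 9.5, "cost": 38.0})
print(problems)
```

Unset limits (`None`) are skipped, which matches the optional fields on `BudgetBlock`.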
hugging/td_lang/examples/demo_fuse.td ADDED
@@ -0,0 +1,19 @@
+ # demo_fuse.td — Easy merge: fuse multiple models in one command
+ # The entire TD merge strategy in 5 lines
+
+ gate { must_pass = [canary, perplexity, thinking_mode] }
+ budget { max_gpu_hours = 30 max_cost = 120.00 }
+
+ load "Qwen/Qwen3-VL-8B-Instruct" as base
+
+ # Fuse all 4 donor models in one shot — auto Transport and Merge
+ fuse ["deepseek-ai/DeepSeek-R1", "MiMo-7B", "meta-llama/Llama-3.1-8B", "tiiuae/Falcon-H1R-7B"] into base
+
+ # Or absorb a single model with custom strength
+ # absorb "deepseek-ai/DeepSeek-R1" into base strength 0.6
+
+ heal base lora_r 32 epochs 2
+ eval base -> post_fuse_eval.json
+ commit base if [canary, perplexity, thinking_mode]
+ snapshot base -> snapshots/
+ report -> economics.json
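Review note: `snapshot base -> snapshots/` is documented in `SnapshotCmd` as creating a content-addressed directory `snapshots/<sha256_prefix>/`. The sketch below illustrates the naming idea with the standard `hashlib` module. What exactly td_lang hashes and how long the prefix is are not visible here, so the `payload` argument and `prefix_len=12` are assumptions, and `snapshot_dir` is a hypothetical helper.

```python
import hashlib
from pathlib import Path


def snapshot_dir(root: str, payload: bytes, prefix_len: int = 12) -> Path:
    """Return root/<sha256_prefix> for the given bytes."""
    digest = hashlib.sha256(payload).hexdigest()
    return Path(root) / digest[:prefix_len]


a = snapshot_dir("snapshots", b"weights-v1")
b = snapshot_dir("snapshots", b"weights-v1")
c = snapshot_dir("snapshots", b"weights-v2")
print(a == b, a == c)
```

Identical content maps to the same directory, so repeated snapshots of an unchanged model do not create duplicates, while any change in the bytes yields a different path.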
hugging/td_lang/examples/demo_heal.td ADDED
@@ -0,0 +1,6 @@
+ # Demo: merge then heal, evaluate, and commit with gates
+ load "Qwen/Qwen3-VL-8B-Instruct" as base
+ merge "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B" into base using transport strength 0.5
+ heal base lora_r 32 epochs 2
+ eval base -> report.json
+ commit base if [canary, perplexity, thinking_mode]
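Review note: `commit base if [canary, perplexity, thinking_mode]` saves the checkpoint only when the listed gates pass. A minimal sketch of that check, assuming gate results arrive as a name-to-bool mapping. `commit_if` and `GateFailure` are hypothetical names; the real code raises `TDGateError`.

```python
from typing import Dict, List


class GateFailure(Exception):
    """Raised when one or more required gates did not pass."""


def commit_if(gates: List[str], results: Dict[str, bool]) -> str:
    """Commit only when every named gate passed; missing results count as failed."""
    failed = [g for g in gates if not results.get(g, False)]
    if failed:
        raise GateFailure(f"Gates failed: {', '.join(failed)}")
    return "committed"


gates = ["canary", "perplexity", "thinking_mode"]
status = commit_if(gates, {"canary": True, "perplexity": True, "thinking_mode": True})
print(status)
```

Treating a missing result as a failure is a conservative choice: a gate that never ran cannot silently allow a commit.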
hugging/td_lang/examples/demo_loop.td ADDED
@@ -0,0 +1,28 @@
+ # demo_loop.td — Self-improvement loop (Phase 2)
+ # The core TD cycle: diagnose -> synth -> train -> evaluate -> commit
+
+ gate {
+     must_pass = [canary, perplexity, thinking_mode]
+ }
+
+ budget {
+     max_gpu_hours = 10
+     max_cost = 40.00
+ }
+
+ load "Qwen/Qwen3-VL-8B-Instruct" as base
+
+ # Step 1: Ask the model what it's bad at
+ diagnose base -> weaknesses.json
+
+ # Step 2: Generate training data targeting those weaknesses
+ synth base from web_curated filter cherry_llm -> synth_data.jsonl
+
+ # Step 3: Train with GRPO (64 steps = sweet spot from test_15)
+ train base on "synth_data.jsonl" using grpo steps 64
+
+ # Step 4: Check if it actually got better
+ eval base -> post_training_eval.json
+
+ # Step 5: Only save if gates pass
+ commit base
hugging/td_lang/examples/demo_merge.td ADDED
@@ -0,0 +1,5 @@
+ # Demo: load + merge + eval + commit
+ load "Qwen/Qwen3-VL-8B-Instruct" as base
+ merge "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B" into base using transport strength 0.5
+ eval base -> eval_base.json
+ commit base if [canary, perplexity, thinking_mode]
hugging/td_lang/examples/demo_phase3.td ADDED
@@ -0,0 +1,26 @@
+ # demo_phase3.td — Phase 3 commands: edit, fork, reset, prune
+ # The full surgical toolkit for model experimentation
+
+ gate {
+ must_pass = [canary, perplexity, thinking_mode]
+ }
+
+ budget {
+ max_gpu_hours = 12
+ max_cost = 60.00
+ }
+
+ # Load the base model
+ load "Qwen/Qwen3-VL-8B-Instruct" as base
+
+ # Fork before experimenting (like git branch)
+ fork base as experiment
+
+ # Surgical edit: LoRA on reasoning layers 16-28
+ edit experiment layers 16-28 using lora lr 1e-4
+
+ # Evaluate the edit
+ eval experiment -> post_edit_eval.json
+
+ # If it's good, commit; if bad, we can reset
+ commit experiment
hugging/td_lang/examples/demo_phase4.td ADDED
@@ -0,0 +1,33 @@
+ # demo_phase4.td — Phase 4: Contracts, Lineage, Economics
+ # ForgeSpec 2.0 features from test_17
+
+ gate { must_pass = [canary, perplexity, thinking_mode] }
+
+ budget {
+ max_gpu_hours = 20
+ max_cost = 100.00
+ }
+
+ data_contract {
+ required_fields = [prompt, response]
+ min_samples = 100
+ max_perplexity = 50.0
+ }
+
+ reward_contract {
+ verifiers = [code_compiles, math_correct]
+ min_reward = 0.3
+ }
+
+ # Pipeline with full tracking
+ load "Qwen/Qwen3-VL-8B-Instruct" as base
+ fork base as experiment
+
+ edit experiment layers 16-28 using lora lr 1e-4
+ snapshot experiment -> snapshots/
+
+ eval experiment -> post_edit_eval.json
+ commit experiment
+
+ # Economics report at the end
+ report -> economics.json
hugging/td_lang/examples/demo_td_loop.td ADDED
@@ -0,0 +1,44 @@
+ # demo_td_loop.td — The complete TD self-improvement pipeline
+ # This is what td_loop runs: merge, then iterate to get smarter
+
+ gate { must_pass = [canary, perplexity, thinking_mode] }
+ budget { max_gpu_hours = 50 max_cost = 200.00 }
+
+ data_contract {
+ required_fields = [prompt, response]
+ min_samples = 50
+ max_perplexity = 50.0
+ }
+
+ reward_contract {
+ verifiers = [code_compiles, math_correct]
+ min_reward = 0.3
+ }
+
+ # Step 1: Load base model
+ load "Qwen/Qwen3-VL-8B-Instruct" as base
+
+ # Step 2: Fuse all donor models in one shot
+ fuse ["deepseek-ai/DeepSeek-R1", "MiMo-7B", "meta-llama/Llama-3.1-8B", "tiiuae/Falcon-H1R-7B"] into base
+
+ # Step 3: Heal the merge damage
+ heal base lora_r 32 epochs 2
+ snapshot base -> snapshots/
+
+ # Step 4: Self-improvement loop (the core of TD)
+ repeat 5 {
+ diagnose base -> weaknesses.json
+ synth base from base filter cherry_llm -> training_data.jsonl
+ train base on "training_data.jsonl" using grpo steps 64 lr 5e-5
+ eval base -> eval_results.json
+
+ if eval_passed base {
+ commit base
+ snapshot base -> snapshots/
+ } else {
+ reset base to "snapshots/"
+ }
+ }
+
+ # Step 5: Final report
+ report -> final_economics.json
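Review note on `demo_td_loop.td`: the `repeat` block reads as a commit-or-rollback loop. The Python sketch below illustrates that control flow only. It is not taken from `td_lang`; the names `eval_passed`, `td_loop`, `step`, and `evaluate` are hypothetical, the model is stood in for by a plain dict, and the gate list is copied from the `gate` block above.

```python
import copy
from typing import Callable, Dict, List

# Gate names copied from the `gate { must_pass = [...] }` block in the script.
GATES: List[str] = ["canary", "perplexity", "thinking_mode"]


def eval_passed(results: Dict[str, bool], must_pass: List[str] = GATES) -> bool:
    """Return True only if every required gate is present and passed."""
    return all(results.get(gate, False) for gate in must_pass)


def td_loop(
    model: dict,
    step: Callable[[dict], dict],
    evaluate: Callable[[dict], Dict[str, bool]],
    rounds: int = 5,
) -> dict:
    """Hypothetical sketch of `repeat N { ... if eval_passed ... else reset }`."""
    snapshot = copy.deepcopy(model)  # last known-good state
    for _ in range(rounds):
        # Stand-in for diagnose -> synth -> train on a working copy.
        candidate = step(copy.deepcopy(model))
        if eval_passed(evaluate(candidate)):
            model = candidate                 # commit
            snapshot = copy.deepcopy(model)   # snapshot
        else:
            model = copy.deepcopy(snapshot)   # reset to last snapshot
    return model
```

Under this reading, a failed round costs compute but never degrades the committed model, since every rejected candidate is replaced by the last snapshot.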
hugging/td_lang/examples/err_edit_unloaded.td ADDED
@@ -0,0 +1,2 @@
+ # err_edit_unloaded.td — Should fail: editing a model before loading
+ edit ghost_model layers all using lora
hugging/td_lang/examples/err_fork_duplicate.td ADDED
@@ -0,0 +1,3 @@
+ # err_fork_duplicate.td — Should fail: duplicate name
+ load "test" as base
+ fork base as base
hugging/td_lang/examples/err_prune_100.td ADDED
@@ -0,0 +1,4 @@
+ # err_prune_100.td — Should fail/warn: prune at 100%
+ load "test" as base
+ prune base using wanda aggressiveness 1.0
+ # Note: Compiler might cap it at 30% per implementation notes
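Review note on `err_prune_100.td`: the closing comment leaves open whether an out-of-range `aggressiveness` fails or is capped. One possible behaviour, sketched in Python under stated assumptions: the 0.30 cap is taken from the "30%" in that comment, and the function name `clamp_aggressiveness` is hypothetical, not the compiler's actual API.

```python
import warnings

# Assumed cap, taken from the "30%" mentioned in the test file's comment.
MAX_PRUNE_AGGRESSIVENESS = 0.30


def clamp_aggressiveness(value: float, cap: float = MAX_PRUNE_AGGRESSIVENESS) -> float:
    """Reject values outside [0, 1]; warn and cap values above `cap`."""
    if not 0.0 <= value <= 1.0:
        raise ValueError(f"aggressiveness must be in [0, 1], got {value}")
    if value > cap:
        warnings.warn(
            f"aggressiveness {value} exceeds cap {cap}; using {cap}",
            stacklevel=2,
        )
        return cap
    return value
```

Warning rather than failing keeps the script runnable, but the test's "Should fail/warn" header suggests the expected behaviour should be pinned down one way or the other.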
hugging/td_lang/examples/test_fork_edit.td ADDED
@@ -0,0 +1,12 @@
+ # test_fork_edit.td — Test load -> fork -> edit -> eval -> commit
+
+ load "Qwen/Qwen3-VL-8B-Instruct" as base
+
+ # Fork the base model
+ fork base as experimental_branch
+
+ # Surgical edit with DoRA on specific layers
+ edit experimental_branch layers 20-28 using dora lr 1e-4
+
+ eval experimental_branch -> edit_report.json
+ commit experimental_branch