Upload 55 files
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +1 -0
- hugging/QUICKSTART.md +106 -0
- hugging/deploy.sh +128 -0
- hugging/requirements.txt +226 -0
- hugging/td_fuse/__init__.py +25 -0
- hugging/td_fuse/__main__.py +4 -0
- hugging/td_fuse/canary.py +178 -0
- hugging/td_fuse/config.py +299 -0
- hugging/td_fuse/heal.py +363 -0
- hugging/td_fuse/merge.py +985 -0
- hugging/td_fuse/run.py +279 -0
- hugging/td_fuse/techniques.py +669 -0
- hugging/td_fuse/transport.py +527 -0
- hugging/td_fuse/validate.py +215 -0
- hugging/td_lang/.DS_Store +0 -0
- hugging/td_lang/__init__.py +51 -0
- hugging/td_lang/__main__.py +5 -0
- hugging/td_lang/__pycache__/__init__.cpython-310.pyc +0 -0
- hugging/td_lang/__pycache__/__init__.cpython-314.pyc +0 -0
- hugging/td_lang/__pycache__/__main__.cpython-310.pyc +0 -0
- hugging/td_lang/__pycache__/__main__.cpython-314.pyc +0 -0
- hugging/td_lang/__pycache__/ast_nodes.cpython-310.pyc +0 -0
- hugging/td_lang/__pycache__/ast_nodes.cpython-314.pyc +0 -0
- hugging/td_lang/__pycache__/cli.cpython-310.pyc +0 -0
- hugging/td_lang/__pycache__/cli.cpython-314.pyc +0 -0
- hugging/td_lang/__pycache__/compiler.cpython-310.pyc +0 -0
- hugging/td_lang/__pycache__/compiler.cpython-314.pyc +3 -0
- hugging/td_lang/__pycache__/errors.cpython-310.pyc +0 -0
- hugging/td_lang/__pycache__/errors.cpython-314.pyc +0 -0
- hugging/td_lang/__pycache__/executor.cpython-310.pyc +0 -0
- hugging/td_lang/__pycache__/executor.cpython-314.pyc +0 -0
- hugging/td_lang/__pycache__/grammar.cpython-310.pyc +0 -0
- hugging/td_lang/__pycache__/grammar.cpython-314.pyc +0 -0
- hugging/td_lang/ast_nodes.py +421 -0
- hugging/td_lang/cli.py +212 -0
- hugging/td_lang/compiler.py +0 -0
- hugging/td_lang/errors.py +99 -0
- hugging/td_lang/examples/demo_autopilot.td +62 -0
- hugging/td_lang/examples/demo_full.td +17 -0
- hugging/td_lang/examples/demo_fuse.td +19 -0
- hugging/td_lang/examples/demo_heal.td +6 -0
- hugging/td_lang/examples/demo_loop.td +28 -0
- hugging/td_lang/examples/demo_merge.td +5 -0
- hugging/td_lang/examples/demo_phase3.td +26 -0
- hugging/td_lang/examples/demo_phase4.td +33 -0
- hugging/td_lang/examples/demo_td_loop.td +44 -0
- hugging/td_lang/examples/err_edit_unloaded.td +2 -0
- hugging/td_lang/examples/err_fork_duplicate.td +3 -0
- hugging/td_lang/examples/err_prune_100.td +4 -0
- hugging/td_lang/examples/test_fork_edit.td +12 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
hugging/td_lang/__pycache__/compiler.cpython-314.pyc filter=lfs diff=lfs merge=lfs -text
|
hugging/QUICKSTART.md
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# TD Quick Start — Rent a GPU and Go
|
| 2 |
+
|
| 3 |
+
## What You Need (One-Time Setup)
|
| 4 |
+
|
| 5 |
+
1. **vast.ai account** — sign up at vast.ai, add credit ($10-20 to start)
|
| 6 |
+
2. **HuggingFace account** — sign up at huggingface.co (use any username, doesn't have to be your real name)
|
| 7 |
+
3. **HuggingFace token** — Settings → Access Tokens → New Token → **Write** access
|
| 8 |
+
4. **ntfy.sh app** on your phone (you already have this)
|
| 9 |
+
|
| 10 |
+
## One-Time: Upload Your Code to Private HuggingFace
|
| 11 |
+
|
| 12 |
+
Do this once from your computer. After this, your code lives in a private repo that only you can see.
|
| 13 |
+
|
| 14 |
+
```bash
|
| 15 |
+
# Install the tool
|
| 16 |
+
pip install huggingface_hub
|
| 17 |
+
|
| 18 |
+
# Log in (paste your token when asked)
|
| 19 |
+
huggingface-cli login
|
| 20 |
+
|
| 21 |
+
# Upload everything
|
| 22 |
+
HF_USER=your_hf_username bash upload_to_hf.sh
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
Now your td_lang, td_fuse, .td files, and deploy script are all in a private HuggingFace repo. Nobody can see them except you.
|
| 26 |
+
|
| 27 |
+
**When you update your code**, just run `upload_to_hf.sh` again — it overwrites with the latest version.
|
| 28 |
+
|
| 29 |
+
## Every Time: Rent GPU → 3 Commands → Done
|
| 30 |
+
|
| 31 |
+
### 1. Rent a GPU on vast.ai
|
| 32 |
+
|
| 33 |
+
Go to vast.ai → Console → Search for:
|
| 34 |
+
- **GPU:** RTX 4090 (24GB) or A100 (40GB+)
|
| 35 |
+
- **Image:** Pick one with PyTorch pre-installed (like `pytorch/pytorch`)
|
| 36 |
+
- **Storage:** At least 100GB disk
|
| 37 |
+
- **Cost:** ~$0.40-0.80/hr for a 4090
|
| 38 |
+
|
| 39 |
+
Click **RENT** and wait for it to start (~1-2 minutes).
|
| 40 |
+
|
| 41 |
+
### 2. Connect to the GPU
|
| 42 |
+
|
| 43 |
+
vast.ai gives you an SSH command. Copy and paste it into your terminal:
|
| 44 |
+
```
|
| 45 |
+
ssh -p 12345 root@ssh1.vast.ai
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
### 3. Run these 3 commands
|
| 49 |
+
|
| 50 |
+
```bash
|
| 51 |
+
# Set your token
|
| 52 |
+
export HF_TOKEN=hf_your_token_here
|
| 53 |
+
|
| 54 |
+
# Download your code from HuggingFace (takes ~10 seconds)
|
| 55 |
+
pip install huggingface_hub -q && python -c "
|
| 56 |
+
from huggingface_hub import snapshot_download
|
| 57 |
+
snapshot_download('YOUR_USERNAME/td-toolkit', local_dir='/workspace/td')
|
| 58 |
+
"
|
| 59 |
+
|
| 60 |
+
# Go!
|
| 61 |
+
cd /workspace/td && bash deploy.sh demo_autopilot.td
|
| 62 |
+
```
|
| 63 |
+
|
| 64 |
+
That's it. Put your phone down. ntfy.sh sends you updates as it runs.
|
| 65 |
+
|
| 66 |
+
### 4. When it's done
|
| 67 |
+
|
| 68 |
+
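Your model gets saved to Google Drive automatically if rclone is configured on the GPU machine (see below) and your .td file includes a `save` step such as `save base to "gdrive:TD/models/final"`. Otherwise it stays on the GPU at `final_model/`.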
|
| 69 |
+
|
| 70 |
+
## Setting Up Google Drive (Optional, One-Time per GPU)
|
| 71 |
+
|
| 72 |
+
On the GPU machine after SSHing in:
|
| 73 |
+
```bash
|
| 74 |
+
rclone config
|
| 75 |
+
```
|
| 76 |
+
1. Type `n` for new remote
|
| 77 |
+
2. Name it `gdrive`
|
| 78 |
+
3. Pick `Google Drive` from the list
|
| 79 |
+
4. Follow the prompts (it gives you a URL to visit in your browser)
|
| 80 |
+
5. Done — now `save base to "gdrive:TD/models/final"` works in your .td files
|
| 81 |
+
|
| 82 |
+
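**Tip:** You can save the rclone config to your HuggingFace repo too, so you don't have to set it up every time. Only do this while the repo stays private — the config file contains your Google Drive access token.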
|
| 83 |
+
|
| 84 |
+
## Quick Reference
|
| 85 |
+
|
| 86 |
+
| Command | What it does |
|
| 87 |
+
|---------|-------------|
|
| 88 |
+
| `bash deploy.sh my_file.td` | Full setup + run |
|
| 89 |
+
| `python -m td_lang check my_file.td` | Check syntax only |
|
| 90 |
+
| `python -m td_lang info my_file.td` | Show plan without running |
|
| 91 |
+
| `python -m td_lang run my_file.td` | Run (skip deploy setup) |
|
| 92 |
+
| `python -m td_lang run my_file.td --dry` | Compile but don't execute |
|
| 93 |
+
|
| 94 |
+
## If Something Goes Wrong
|
| 95 |
+
|
| 96 |
+
- **OOM (out of memory):** Your .td file's `on_error` block handles this — it retries with smaller batches
|
| 97 |
+
- **Model download fails:** Check your HF_TOKEN is set correctly
|
| 98 |
+
- **ntfy not working:** Check your phone has the ntfy app and you're subscribed to the right topic
|
| 99 |
+
- **GPU disconnects:** Re-SSH in — your files are still there. Set `HF_TOKEN` again (`export` does not survive a new SSH session), then run deploy.sh again — td_lang picks up from the last snapshot
|
| 100 |
+
|
| 101 |
+
## Cost Estimate
|
| 102 |
+
|
| 103 |
+
For the full `demo_autopilot.td` pipeline (merge 4 models + 5 training loops):
|
| 104 |
+
- **RTX 4090:** ~$0.50/hr × ~30-40 hrs = ~$15-20
|
| 105 |
+
- **A100 40GB:** ~$1.00/hr × ~20-30 hrs = ~$20-30
|
| 106 |
+
- **Budget cap in .td file:** Set `max_cost = 160.00` to prevent runaway costs
|
hugging/deploy.sh
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# deploy.sh — One-command setup for vast.ai GPU instances
#
# TWO ways to use this:
#
# Option A — Download from your private HuggingFace repo + run:
#   export HF_TOKEN=your_token
#   pip install huggingface_hub
#   python -c "from huggingface_hub import snapshot_download; snapshot_download('YOUR_USER/td-toolkit', local_dir='.')"
#   bash deploy.sh demo_autopilot.td
#
# Option B — Already uploaded files manually:
#   bash deploy.sh my_pipeline.td
#
# Steps performed:
#   [1/5] install the lark parser
#   [2/5] check that HF_TOKEN is set (asks before continuing without it)
#   [3/5] make sure td_lang is importable (adds the current dir to PYTHONPATH)
#   [4/5] check rclone / the "gdrive" remote / the GPU
#   [5/5] run the .td file with `python -m td_lang run`
#
# Exit status: 1 on a usage error, a missing .td file, a declined prompt or a
# missing td_lang package; otherwise whatever the td_lang run returns.

set -e # Stop on any error

# Colors for pretty output
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
RED='\033[0;31m'
NC='\033[0m' # No Color

# Print td_lang's __version__, or "unknown" if it cannot be read.
td_lang_version() {
    python -c "import td_lang; print(td_lang.__version__)" 2>/dev/null || echo "unknown"
}

echo ""
echo "==========================================="
echo " TD Deploy — vast.ai GPU Setup"
echo "==========================================="
echo ""

# Check if a .td file was provided
if [ -z "$1" ]; then
    echo -e "${RED}ERROR: No .td file specified${NC}"
    echo ""
    echo "Usage: bash deploy.sh my_pipeline.td"
    echo ""
    echo "Available .td files:"
    # Expand the globs ourselves: `ls a b || echo none` also prints the
    # "(none found)" line when only one of the two patterns has no match.
    shopt -s nullglob
    TD_CANDIDATES=( *.td td_lang/examples/*.td )
    shopt -u nullglob
    if [ "${#TD_CANDIDATES[@]}" -gt 0 ]; then
        printf '%s\n' "${TD_CANDIDATES[@]}"
    else
        echo " (none found)"
    fi
    exit 1
fi

TD_FILE="$1"

if [ ! -f "$TD_FILE" ]; then
    echo -e "${RED}ERROR: File not found: $TD_FILE${NC}"
    exit 1
fi

echo -e "${GREEN}[1/5]${NC} Installing td_lang dependencies..."
# Quiet attempt first; on failure retry loudly so the real pip error is visible.
pip install lark --quiet 2>/dev/null || pip install lark
echo " Done."

# Check for HF token
echo ""
echo -e "${GREEN}[2/5]${NC} Checking environment..."
if [ -z "$HF_TOKEN" ]; then
    echo -e "${YELLOW} WARNING: HF_TOKEN not set.${NC}"
    echo " Models won't download from HuggingFace without it."
    echo " Set it with: export HF_TOKEN=your_token_here"
    echo ""
    # `read` returns non-zero on EOF (piped or non-interactive stdin). Under
    # `set -e` that would kill the script with no explanation, so treat a
    # failed read as "no" and say why we stop.
    REPLY=""
    read -p " Continue anyway? (y/n) " -n 1 -r || REPLY=""
    echo
    if [[ ! $REPLY =~ ^[Yy]$ ]]; then
        echo -e "${RED} Aborted: HF_TOKEN not set.${NC}"
        exit 1
    fi
else
    echo " HF_TOKEN: set"
fi

# Check td_lang is accessible
echo ""
echo -e "${GREEN}[3/5]${NC} Checking td_lang..."
if python -c "import td_lang" 2>/dev/null; then
    VERSION=$(td_lang_version)
    echo " td_lang v$VERSION: found"
else
    # Try adding current directory to path
    export PYTHONPATH="${PYTHONPATH:+$PYTHONPATH:}$(pwd)"
    if python -c "import td_lang" 2>/dev/null; then
        VERSION=$(td_lang_version)
        echo " td_lang v$VERSION: found (added to PYTHONPATH)"
    else
        echo -e "${RED} ERROR: td_lang not found!${NC}"
        echo " Make sure the td_lang/ folder is in the current directory."
        echo " Current directory: $(pwd)"
        echo " Contents:"
        ls -1
        exit 1
    fi
fi

# Check for rclone (needed for save command)
echo ""
echo -e "${GREEN}[4/5]${NC} Checking tools..."
if command -v rclone &> /dev/null; then
    echo " rclone: installed"
    # -x matches the whole line, so a remote named e.g. "mygdrive:" does not
    # count as "gdrive:".
    if rclone listremotes 2>/dev/null | grep -qx "gdrive:"; then
        echo " Google Drive: configured"
    else
        echo -e "${YELLOW} Google Drive: not configured${NC}"
        echo " Run 'rclone config' to set up Google Drive (name it 'gdrive')"
    fi
else
    echo -e "${YELLOW} rclone: not installed (installing...)${NC}"
    # Download first, then execute. With `curl ... | bash` a failed download
    # hands bash an empty script, the pipeline exits 0 and the warning below
    # never prints. Installation stays best-effort: failure only warns.
    RCLONE_INSTALLER=""
    if RCLONE_INSTALLER=$(mktemp) \
        && curl -fsSL https://rclone.org/install.sh -o "$RCLONE_INSTALLER" \
        && bash "$RCLONE_INSTALLER" 2>/dev/null; then
        :
    else
        echo -e "${YELLOW} Could not install rclone. 'save' commands won't work.${NC}"
    fi
    if [ -n "$RCLONE_INSTALLER" ]; then
        rm -f "$RCLONE_INSTALLER"
    fi
fi

# Check GPU
if command -v nvidia-smi &> /dev/null; then
    GPU_NAME=$(nvidia-smi --query-gpu=name --format=csv,noheader | head -1)
    GPU_MEM=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader | head -1)
    echo " GPU: $GPU_NAME ($GPU_MEM)"
else
    echo -e "${YELLOW} WARNING: No GPU detected (nvidia-smi not found)${NC}"
fi

# Run the .td file
echo ""
echo -e "${GREEN}[5/5]${NC} Running: $TD_FILE"
echo "==========================================="
echo ""

# `set -e` stops the script here if the run fails, so the "complete" banner
# below is only printed after a successful run.
python -m td_lang run "$TD_FILE"

echo ""
echo "==========================================="
echo -e "${GREEN} TD Deploy complete!${NC}"
echo "==========================================="
|
hugging/requirements.txt
ADDED
|
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# TD Merge Pipeline - Complete Python Dependency List
|
| 2 |
+
# Python 3.11-3.12 (3.12 preferred)
|
| 3 |
+
# CUDA 12.4 (RTX 4090 compatible)
|
| 4 |
+
# Updated: February 2026
|
| 5 |
+
|
| 6 |
+
# ============================================================================
|
| 7 |
+
# CORE ML FRAMEWORKS
|
| 8 |
+
# ============================================================================
|
| 9 |
+
|
| 10 |
+
# PyTorch 2.4+ with CUDA 12.4 support (RTX 4090 compatible)
|
| 11 |
+
torch==2.4.1
|
| 12 |
+
torchvision==0.19.1
|
| 13 |
+
torchaudio==2.4.1
|
| 14 |
+
|
| 15 |
+
# NVIDIA CUDA Toolkit support (already installed on system)
|
| 16 |
+
# CUDA 12.4 for RTX 4090 compatibility
|
| 17 |
+
# Note: Install via: pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
|
| 18 |
+
|
| 19 |
+
# ============================================================================
|
| 20 |
+
# TRANSFORMERS & MODEL LOADING
|
| 21 |
+
# ============================================================================
|
| 22 |
+
|
| 23 |
+
# Transformers library - must support Qwen3 (requires 4.51.0+)
|
| 24 |
+
transformers==4.51.0
|
| 25 |
+
|
| 26 |
+
# Safetensors for efficient model serialization
|
| 27 |
+
safetensors==0.4.5
|
| 28 |
+
|
| 29 |
+
# Accelerate for distributed training & multi-GPU support
|
| 30 |
+
accelerate==1.2.1
|
| 31 |
+
|
| 32 |
+
# ============================================================================
|
| 33 |
+
# PARAMETER EFFICIENT FINE-TUNING (PEFT/QLoRA)
|
| 34 |
+
# ============================================================================
|
| 35 |
+
|
| 36 |
+
# PEFT (Parameter-Efficient Fine-Tuning) - supports QLoRA
|
| 37 |
+
# Must be >= 0.14.0 for 8-bit weight merging
|
| 38 |
+
peft==0.14.0
|
| 39 |
+
|
| 40 |
+
# BitsAndBytes for 4-bit quantization (QLoRA)
|
| 41 |
+
# Works with PyTorch 2.4, stable with >= 0.42
|
| 42 |
+
bitsandbytes==0.44.0
|
| 43 |
+
|
| 44 |
+
# ============================================================================
|
| 45 |
+
# OPTIMAL TRANSPORT & MODEL MERGING
|
| 46 |
+
# ============================================================================
|
| 47 |
+
|
| 48 |
+
# POT (Python Optimal Transport) - for Transport and Merge algorithm
|
| 49 |
+
# Used for activation-aligned cross-architecture weight alignment
|
| 50 |
+
POT==0.9.6
|
| 51 |
+
|
| 52 |
+
# SciPy for optimization & linear algebra (OrthoMerge, LARV)
|
| 53 |
+
scipy==1.14.1
|
| 54 |
+
|
| 55 |
+
# NumPy for numerical operations
|
| 56 |
+
numpy==1.26.4
|
| 57 |
+
|
| 58 |
+
# Lark parser for td_lang DSL
|
| 59 |
+
lark>=1.1.0
|
| 60 |
+
|
| 61 |
+
# Unsloth for fast fine-tuning with 7B models
|
| 62 |
+
# Includes pre-quantized Qwen3-8B support, VLLM Standby Mode for concurrent training+inference
|
| 63 |
+
unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git@main
|
| 64 |
+
|
| 65 |
+
# ============================================================================
|
| 66 |
+
# REINFORCEMENT LEARNING (RL TRAINING)
|
| 67 |
+
# ============================================================================
|
| 68 |
+
|
| 69 |
+
# TRL (Transformers Reinforcement Learning)
|
| 70 |
+
# Provides GRPO (Group Relative Policy Optimization) trainer
|
| 71 |
+
# v0.27.2 stable, tested with transformers 4.40+
|
| 72 |
+
trl==0.27.2
|
| 73 |
+
|
| 74 |
+
# ============================================================================
|
| 75 |
+
# EVALUATION & BENCHMARKING
|
| 76 |
+
# ============================================================================
|
| 77 |
+
|
| 78 |
+
# LM-Eval (EleutherAI evaluation harness) for benchmarking
|
| 79 |
+
# Explicitly install HF backend for transformers support
|
| 80 |
+
lm-eval[hf]==0.4.10
|
| 81 |
+
|
| 82 |
+
# MathEval utilities
|
| 83 |
+
math-eval==0.0.3
|
| 84 |
+
|
| 85 |
+
# ============================================================================
|
| 86 |
+
# DATA HANDLING & DATASETS
|
| 87 |
+
# ============================================================================
|
| 88 |
+
|
| 89 |
+
# HuggingFace Datasets library (HF Hub integration)
|
| 90 |
+
datasets==4.5.1
|
| 91 |
+
|
| 92 |
+
# PyArrow for efficient data processing
|
| 93 |
+
pyarrow==17.0.0
|
| 94 |
+
|
| 95 |
+
# Pandas for data manipulation
|
| 96 |
+
pandas==2.2.3
|
| 97 |
+
|
| 98 |
+
# ============================================================================
|
| 99 |
+
# OPTIONAL: MERGING & FUSION (if not building Transport & Merge from scratch)
|
| 100 |
+
# ============================================================================
|
| 101 |
+
|
| 102 |
+
# MergeKit - alternative model merging tool (supports TIES/DARE-TIES)
|
| 103 |
+
# Note: Limited to same-architecture merges, but useful for fallback strategy
|
| 104 |
+
mergekit==0.0.7
|
| 105 |
+
|
| 106 |
+
# ============================================================================
|
| 107 |
+
# WEB & KNOWLEDGE RETRIEVAL (for ALAS - Autonomous Learning Agent System)
|
| 108 |
+
# ============================================================================
|
| 109 |
+
|
| 110 |
+
# Requests for HTTP operations
|
| 111 |
+
requests==2.31.0
|
| 112 |
+
|
| 113 |
+
# Beautiful Soup for web scraping
|
| 114 |
+
beautifulsoup4==4.12.3
|
| 115 |
+
|
| 116 |
+
# ============================================================================
|
| 117 |
+
# AGENT ORCHESTRATION & UTILITIES
|
| 118 |
+
# ============================================================================
|
| 119 |
+
|
| 120 |
+
# LangGraph for multi-agent coordination (SYMPHONY)
|
| 121 |
+
langgraph==0.2.7
|
| 122 |
+
|
| 123 |
+
# LangChain for prompt management & chains
|
| 124 |
+
langchain==0.3.9
|
| 125 |
+
|
| 126 |
+
# Pydantic for data validation
|
| 127 |
+
pydantic==2.8.2
|
| 128 |
+
|
| 129 |
+
# ============================================================================
|
| 130 |
+
# VISION AGENT (Fara-7B integration)
|
| 131 |
+
# ============================================================================
|
| 132 |
+
|
| 133 |
+
# Pillow for image processing
|
| 134 |
+
Pillow==11.2.0
|
| 135 |
+
|
| 136 |
+
# OpenCV for computer vision tasks
|
| 137 |
+
opencv-python==4.10.1.26
|
| 138 |
+
|
| 139 |
+
# ============================================================================
|
| 140 |
+
# INFERENCE & SERVING
|
| 141 |
+
# ============================================================================
|
| 142 |
+
|
| 143 |
+
# vLLM for fast LLM inference serving
|
| 144 |
+
vllm==0.6.4
|
| 145 |
+
|
| 146 |
+
# ============================================================================
|
| 147 |
+
# UTILITIES & LOGGING
|
| 148 |
+
# ============================================================================
|
| 149 |
+
|
| 150 |
+
# PyYAML for config files
|
| 151 |
+
PyYAML==6.0.2
|
| 152 |
+
|
| 153 |
+
# Python-dotenv for environment variable management
|
| 154 |
+
python-dotenv==1.0.1
|
| 155 |
+
|
| 156 |
+
# Tqdm for progress bars
|
| 157 |
+
tqdm==4.67.1
|
| 158 |
+
|
| 159 |
+
# Rich for beautiful terminal output
|
| 160 |
+
rich==13.8.1
|
| 161 |
+
|
| 162 |
+
# ============================================================================
|
| 163 |
+
# DEVELOPMENT & TESTING (OPTIONAL)
|
| 164 |
+
# ============================================================================
|
| 165 |
+
|
| 166 |
+
# Pytest for testing
|
| 167 |
+
pytest==8.3.2
|
| 168 |
+
|
| 169 |
+
# IPython for interactive development
|
| 170 |
+
ipython==8.20.0
|
| 171 |
+
|
| 172 |
+
# Jupyter for notebooks
|
| 173 |
+
jupyter==1.0.0
|
| 174 |
+
|
| 175 |
+
# ============================================================================
|
| 176 |
+
# VERSION NOTES & COMPATIBILITY MATRIX
|
| 177 |
+
# ============================================================================
|
| 178 |
+
#
|
| 179 |
+
# COMPATIBILITY VERIFIED:
|
| 180 |
+
# ✓ PyTorch 2.4.1 + CUDA 12.4 + RTX 4090 (full support)
|
| 181 |
+
# ✓ Transformers 4.51.0 + Qwen3-8B (latest, required for Qwen3)
|
| 182 |
+
# ✓ Unsloth 2026.2.x + Qwen3 + QLoRA (fast fine-tuning)
|
| 183 |
+
# ✓ BitsAndBytes 0.44.0 + PyTorch 2.4 (4-bit quantization)
|
| 184 |
+
# ✓ PEFT 0.14.0 + BitsAndBytes (8-bit weight merging)
|
| 185 |
+
# ✓ TRL 0.27.2 + GRPO (RL training with group advantage)
|
| 186 |
+
# ✓ POT 0.9.6 + SciPy 1.14.1 (optimal transport)
|
| 187 |
+
# ✓ LM-Eval 0.4.10[hf] + Transformers 4.51.0 (benchmarking)
|
| 188 |
+
#
|
| 189 |
+
# KNOWN ISSUES & WORKAROUNDS:
|
| 190 |
+
# - Flash-Attention-2: Works with Qwen3 but may produce incorrect outputs
|
| 191 |
+
# → Use attn_implementation="sdpa" (default) instead
|
| 192 |
+
# → DO NOT set attn_implementation="flash_attention_2"
|
| 193 |
+
#
|
| 194 |
+
# - BitsAndBytes + XFormers: Avoid mixing with older PyTorch versions
|
| 195 |
+
# → Use Unsloth bundled installer which pre-handles this
|
| 196 |
+
#
|
| 197 |
+
# - Thinking Mode Survival: Qwen3's thinking tokens (151668) may be scrambled
|
| 198 |
+
# → Freeze thinking token embeddings during Transport & Merge
|
| 199 |
+
# → Apply Contrastive Gradient Identification (ReasonAny) to protect reasoning params
|
| 200 |
+
# → Post-merge fine-tune on 500-1000 thinking examples
|
| 201 |
+
#
|
| 202 |
+
# CUDA 12.4 NOTES:
|
| 203 |
+
# - RTX 4090 full support (Ada architecture, compute capability 8.9)
|
| 204 |
+
# - All libraries compiled for CUDA 12.4 compatibility
|
| 205 |
+
# - No need to install system CUDA separately if PyTorch wheels handle it
|
| 206 |
+
#
|
| 207 |
+
# HARDWARE CHECKLIST:
|
| 208 |
+
# ✓ Dual RTX 4090 (48GB VRAM total) - adequate for full pipeline
|
| 209 |
+
# ✓ 64GB+ system RAM (128GB comfortable)
|
| 210 |
+
# ✓ 1500W+ PSU (handles 1.2kW sustained load)
|
| 211 |
+
# ✓ Gen4+ NVMe SSD (3000+ MB/s write, 2TB minimum)
|
| 212 |
+
#
|
| 213 |
+
# INSTALLATION:
|
| 214 |
+
# 1. Create venv: python3.12 -m venv venv && source venv/bin/activate
|
| 215 |
+
# 2. Install PyTorch with CUDA 12.4:
|
| 216 |
+
# pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
|
| 217 |
+
# 3. Install this requirements file:
|
| 218 |
+
# pip install -r requirements.txt
|
| 219 |
+
# 4. Optional - install Unsloth's bundled version (handles all conflicts):
|
| 220 |
+
# pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git@main"
|
| 221 |
+
#
|
| 222 |
+
# ESTIMATED INSTALLATION TIME:
|
| 223 |
+
# - PyTorch (download): 5-10 min
|
| 224 |
+
# - Other packages: 2-5 min
|
| 225 |
+
# - Total: 10-15 minutes
|
| 226 |
+
#
|
hugging/td_fuse/__init__.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
TD Fuse — Transport and Merge pipeline for Time Dilation project.
|
| 3 |
+
|
| 4 |
+
Merges 5 different-architecture 7B models into Qwen3-8B using
|
| 5 |
+
optimal transport (Transport and Merge, arxiv 2602.05495).
|
| 6 |
+
|
| 7 |
+
Architecture:
|
| 8 |
+
td_fuse/
|
| 9 |
+
├── __init__.py ← This file
|
| 10 |
+
├── config.py ← Model configs, merge order, hyperparameters
|
| 11 |
+
├── canary.py ← Canary injection + testing ("brain surgery")
|
| 12 |
+
├── transport.py ← Wrapper around official T&M code
|
| 13 |
+
├── techniques.py ← Advanced techniques (Theseus, ARM, OTMF, RAM, Mergeability)
|
| 14 |
+
├── merge.py ← Sequential merge orchestrator
|
| 15 |
+
├── validate.py ← Post-merge validation (canary, perplexity, benchmarks)
|
| 16 |
+
├── heal.py ← QLoRA healing fine-tune via Unsloth
|
| 17 |
+
└── run.py ← Main entry point
|
| 18 |
+
|
| 19 |
+
Usage:
|
| 20 |
+
python -m td_fuse.run --config default --stage all
|
| 21 |
+
python -m td_fuse.run --config default --stage demo # Dad demo (DeepSeek only)
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
__version__ = "0.1.0"
|
| 25 |
+
__author__ = "Milan (TD Project)"
|
hugging/td_fuse/__main__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Allow running td_fuse as a module: python -m td_fuse"""
|
| 2 |
+
from .run import main
|
| 3 |
+
|
| 4 |
+
main()
|
hugging/td_fuse/canary.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Canary Injection & Testing — Milan's "Brain Surgery" idea.
|
| 3 |
+
|
| 4 |
+
Inject unique fake facts into each model before merging.
|
| 5 |
+
After merge, test if the merged model remembers ALL fake facts.
|
| 6 |
+
If it does → knowledge genuinely transferred from each source.
|
| 7 |
+
If it doesn't → that model's knowledge was lost during merge.
|
| 8 |
+
|
| 9 |
+
Findings: #11 (evaluation plan)
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from typing import Optional
|
| 14 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 15 |
+
|
| 16 |
+
from .config import CANARY_FACTS
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def inject_canary(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    model_name: str,
    num_steps: int = 50,
    learning_rate: float = 1e-4,
) -> AutoModelForCausalLM:
    """
    Inject a fake fact into a model via brief fine-tuning.

    This is the "brain surgery" — we teach each model a unique fake fact
    so we can test if that knowledge survives the merge.

    The model is updated in place: every parameter is handed to AdamW and
    trained on the single canary sentence for ``num_steps`` steps. The model
    is always left in eval mode with its gradients cleared, even if a
    training step raises.

    Args:
        model: The model to inject into (modified in place)
        tokenizer: The model's tokenizer
        model_name: Key into CANARY_FACTS dict
        num_steps: Training steps for injection (50 is usually enough)
        learning_rate: LR for injection (higher than normal — we WANT it to memorise)

    Returns:
        The same model object with the canary fact injected, or the model
        untouched if no canary is defined for ``model_name``
    """
    if model_name not in CANARY_FACTS:
        print(f"[canary] No canary defined for {model_name}, skipping")
        return model

    canary = CANARY_FACTS[model_name]
    inject_text = canary["inject_text"]

    print(f"[canary] Injecting into {model_name}: '{inject_text[:60]}...'")

    # Tokenize the fact. Deliberately no `padding=True`: there is only one
    # sequence, so there is nothing to pad, and asking for padding raises
    # ValueError on tokenizers that define no pad token.
    inputs = tokenizer(
        inject_text,
        return_tensors="pt",
        truncation=True,
        max_length=128,
    ).to(model.device)

    # Brief fine-tune to memorise the fact
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    # Drop gradients left over from earlier work so they cannot leak into
    # the first update.
    optimizer.zero_grad()

    try:
        for step in range(num_steps):
            # Labels are the input ids themselves: the model is trained to
            # reproduce the canary sentence.
            outputs = model(**inputs, labels=inputs["input_ids"])
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            if step % 10 == 0:
                print(f" step {step}/{num_steps}, loss: {loss.item():.4f}")
    finally:
        # Also runs on errors (e.g. out-of-memory): free gradient memory and
        # never hand back a model that is stuck in train mode.
        optimizer.zero_grad()
        model.eval()

    print(f"[canary] Injection complete for {model_name}")
    return model
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def test_canary(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    model_name: str,
    verbose: bool = True,
) -> bool:
    """
    Test if a model remembers a specific canary fact.

    The canary's prompt is fed to the model and the generated continuation
    is scanned for the distinctive words of the expected answer.

    Args:
        model: The model to test
        tokenizer: The tokenizer
        model_name: Which canary to test (key into CANARY_FACTS)
        verbose: Print the model's response

    Returns:
        True if the model recalls the canary fact. Also True when no canary
        is defined for ``model_name`` (the check is skipped).
    """
    if model_name not in CANARY_FACTS:
        print(f"[canary] No canary for {model_name}, skipping")
        return True

    canary = CANARY_FACTS[model_name]
    prompt = canary["prompt"]
    expected = canary["answer"].lower()

    # Generate response
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=64,
            do_sample=False,         # Greedy — deterministic (temperature is unused when greedy)
            repetition_penalty=1.5,  # Prevent repetition (R1 issue)
        )

    # generate() returns prompt + continuation for a causal LM. Score only the
    # continuation, otherwise words echoed from the prompt count as "recall".
    prompt_len = inputs["input_ids"].shape[1]
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    response_lower = response.lower()

    # Check if key parts of the expected answer appear in the response.
    # We check for key words, not exact match (model may paraphrase).
    # Punctuation is stripped so "morvathel." still matches "morvathel", and
    # words already present in the prompt are dropped because the model can
    # repeat them without knowing the fact.
    punctuation = ".,;:!?\"'()"
    prompt_words = {w.strip(punctuation) for w in prompt.lower().split()}
    answer_words = [w.strip(punctuation) for w in expected.split()]
    answer_words = [w for w in answer_words if len(w) > 3]  # Words > 3 chars
    # Fall back to all answer words if the prompt already contains every one.
    key_words = [w for w in answer_words if w not in prompt_words] or answer_words

    matches = sum(1 for w in key_words if w in response_lower)
    match_ratio = matches / len(key_words) if key_words else 0.0

    passed = match_ratio >= 0.5  # At least half the key words present

    if verbose:
        status = "✓ PASS" if passed else "✗ FAIL"
        print(f"\n[canary] Testing {model_name}:")
        print(f" Prompt: {prompt}")
        print(f" Expected: {canary['answer']}")
        print(f" Got: {response}")
        print(f" Match: {match_ratio:.0%} ({matches}/{len(key_words)} key words)")
        print(f" Status: {status}")

    return passed
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def test_all_canaries(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    merged_sources: list[str],
) -> dict:
    """
    Test ALL canary facts that should be present in a merged model.

    Args:
        model: The merged model
        tokenizer: The tokenizer
        merged_sources: List of model names that have been merged so far

    Returns:
        Dict of {model_name: passed_bool}
    """
    print("\n" + "=" * 60)
    print("CANARY TEST — Did knowledge transfer from each model?")
    print("=" * 60)

    results = {}

    # Test the target model's canary. The name must match the CANARY_FACTS
    # key exactly ("Qwen3-VL-8B"); an unknown name makes test_canary skip the
    # check and report True, which would hide a lost target canary.
    target_name = "Qwen3-VL-8B"
    results[target_name] = test_canary(model, tokenizer, target_name)

    # Test each merged source model's canary
    for source_name in merged_sources:
        results[source_name] = test_canary(model, tokenizer, source_name)

    # Summary
    passed = sum(1 for v in results.values() if v)
    total = len(results)
    print(f"\n[canary] Results: {passed}/{total} canaries recalled")

    if passed < total:
        failed = [k for k, v in results.items() if not v]
        print(f"[canary] ⚠ FAILED canaries: {', '.join(failed)}")
        print("[canary] Knowledge from these models may have been lost during merge")

    return results
|
hugging/td_fuse/config.py
ADDED
|
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
TD Fuse Configuration — All 5 models, merge order, hyperparameters.
|
| 3 |
+
|
| 4 |
+
Every decision here is backed by research findings in:
|
| 5 |
+
plugins/td-fuse-research/findings/
|
| 6 |
+
|
| 7 |
+
Target model: Qwen3-VL-8B-Instruct (vision + browser agent + text)
|
| 8 |
+
- Language backbone is identical to Qwen3-8B (36 layers, 4096 hidden, GQA)
|
| 9 |
+
- Vision encoder sits on top — we DON'T touch it during merges
|
| 10 |
+
- This gives us browser agent abilities (like Fara) for FREE
|
| 11 |
+
|
| 12 |
+
Merge order (risk-optimised, findings #22):
|
| 13 |
+
1. DeepSeek-R1-0528 → Qwen3-VL-8B (same arch, LOW risk)
|
| 14 |
+
2. MiMo-7B-RL → Merged_1 (drop MTP, MEDIUM risk)
|
| 15 |
+
3. Llama-3.1-8B → Merged_2 (skip embeddings, MEDIUM risk)
|
| 16 |
+
4. Falcon-H1R-7B → Merged_3 (SSM hybrid, HIGH risk)
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from dataclasses import dataclass, field
|
| 20 |
+
from typing import Optional
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# ============================================================================
|
| 25 |
+
# MODEL DEFINITIONS
|
| 26 |
+
# ============================================================================
|
| 27 |
+
|
| 28 |
+
@dataclass
class ModelConfig:
    """Configuration for a single model in the merge pipeline.

    The first eleven fields are required; the remaining ones default to a
    low-risk, equal-weight merge with no special handling.
    """

    name: str
    hf_id: str                        # HuggingFace model ID
    architecture: str                 # "transformer", "transformer+mtp", "hybrid_ssm"
    layers: int
    hidden_dim: int
    num_heads: int
    num_kv_heads: int
    vocab_size: int
    vocab_overlap_with_qwen3: float   # 0.0 to 1.0
    skip_embeddings: bool             # True if vocab overlap < 50%
    trust_remote_code: bool
    # Extra steps needed; default_factory gives every instance its own list.
    special_handling: list[str] = field(default_factory=list)
    merge_risk: str = "low"           # "low", "medium", "high"
    merge_alpha: float = 0.5          # Weight during fusion (0=keep target, 1=keep source)
    notes: str = ""
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# Target model — everything merges INTO this
|
| 49 |
+
# Switched from Qwen3-8B to Qwen3-VL-8B: same language brain, plus vision + browser agent
|
| 50 |
+
TARGET = ModelConfig(
    name="Qwen3-VL-8B",
    hf_id="Qwen/Qwen3-VL-8B-Instruct",
    architecture="transformer+vision",
    layers=36,                        # Language backbone: same 36 layers as Qwen3-8B
    hidden_dim=4096,                  # Same as Qwen3-8B
    num_heads=32,                     # Same as Qwen3-8B
    num_kv_heads=8,                   # GQA, same as Qwen3-8B
    vocab_size=151936,                # Slightly different from Qwen3-8B (151669)
    vocab_overlap_with_qwen3=0.998,   # ~99.8% overlap with Qwen3-8B vocab
    skip_embeddings=False,
    trust_remote_code=False,
    merge_risk="n/a",                 # The target is never merged INTO anything
    notes=(
        "Vision-language model. Language backbone is identical to Qwen3-8B. "
        "Vision encoder (ViT + DeepStack) sits on top — we SKIP it during merges. "
        "This gives us browser agent + vision abilities for free. "
        "Uses SDPA (NOT Flash-Attention-2). "
        "intermediate_size=12288. Loaded via Qwen3VLForConditionalGeneration."
    ),
)
|
| 71 |
+
|
| 72 |
+
# Source models — merged in this order (findings #22)
|
| 73 |
+
SOURCES = [
    ModelConfig(
        name="DeepSeek-R1-0528",
        hf_id="deepseek-ai/DeepSeek-R1-0528-Qwen3-8B",
        architecture="transformer",
        layers=36,
        hidden_dim=4096,
        num_heads=32,
        num_kv_heads=8,
        vocab_size=152064,                # Slightly different from base Qwen3
        vocab_overlap_with_qwen3=0.999,   # 99.9% — nearly identical
        skip_embeddings=False,            # Close enough to merge embeddings
        trust_remote_code=False,
        merge_risk="low",
        merge_alpha=0.5,
        special_handling=["use_deepseek_tokenizer_config"],
        notes=(
            "IDENTICAL architecture to Qwen3-8B. Easiest merge. "
            "Must use DeepSeek's tokenizer config, not Qwen's. "
            "Stay bfloat16 end-to-end (FP8 degrades quality). "
            "Set repetition_penalty=1.5 (R1 distills are prone to repetition). "
            "Findings: #17"
        ),
    ),
    ModelConfig(
        name="MiMo-7B-RL",
        hf_id="XiaomiMiMo/MiMo-7B-RL",
        architecture="transformer+mtp",
        layers=36,
        hidden_dim=4096,
        num_heads=32,
        num_kv_heads=8,
        # NOTE(review): vocab_size and overlap are estimates — confirm against
        # the model's config.json / tokenizer before relying on them.
        vocab_size=32000,                 # Estimated — LLaMA lineage
        vocab_overlap_with_qwen3=0.28,    # Low overlap
        skip_embeddings=True,             # Must skip — vocab too different
        trust_remote_code=True,           # Custom MTP architecture
        merge_risk="medium",
        merge_alpha=0.4,                  # Slightly lower — preserve target
        special_handling=["drop_mtp_heads", "skip_embeddings"],
        notes=(
            "Xiaomi's reasoning model. Same layer count and hidden dim as Qwen3. "
            "MTP heads (mtp_head_0/1/2) have NO Qwen3 equivalent — must drop. "
            "trust_remote_code=True required for custom modeling_mimo.py. "
            "Findings: #18"
        ),
    ),
    ModelConfig(
        name="Llama-3.1-8B",
        hf_id="meta-llama/Llama-3.1-8B-Instruct",
        architecture="transformer",
        layers=32,                        # 4 fewer than Qwen3!
        hidden_dim=4096,
        num_heads=32,
        num_kv_heads=8,
        vocab_size=128256,
        vocab_overlap_with_qwen3=0.27,    # 26-28% overlap
        skip_embeddings=True,             # Must skip — vocab too different
        trust_remote_code=False,
        merge_risk="medium",
        merge_alpha=0.35,                 # Lower alpha — layer mismatch risk
        special_handling=["skip_embeddings", "drop_qkv_bias", "layer_mapping_32_to_36"],
        # NOTE(review): the QKV-bias claim below should be checked against the
        # Llama-3.1 config (attention_bias) — TODO confirm.
        notes=(
            "32 layers vs 36 — T&M's P matrix handles layer mapping. "
            # Target width corrected to 12288 to agree with TARGET's
            # intermediate_size for the same Qwen3 language backbone.
            "FFN intermediate is 14336 vs 12288 — Q matrices handle width. "
            "Has QKV bias (Qwen3 doesn't) — bias params will be dropped. "
            "T&M paper was tested on LLaMA-3 8B — good sign. "
            "Findings: #23"
        ),
    ),
    ModelConfig(
        name="Falcon-H1R-7B",
        hf_id="tiiuae/Falcon-H1R-7B",
        architecture="hybrid_ssm",
        # NOTE(review): layers and hidden_dim are estimates — confirm against
        # the model's config.json before relying on them.
        layers=30,                        # Estimated — ~30 hybrid blocks
        hidden_dim=5120,                  # Estimated — different from Qwen3
        num_heads=32,                     # Attention heads (parallel with Mamba)
        num_kv_heads=8,
        vocab_size=130048,
        vocab_overlap_with_qwen3=0.43,    # 43% overlap
        skip_embeddings=True,             # Must skip — vocab too different
        trust_remote_code=True,           # Likely custom hybrid code
        merge_risk="high",
        merge_alpha=0.3,                  # Conservative — highest risk model
        special_handling=[
            "skip_embeddings",
            "drop_mamba_state_params",    # A, D matrices have no Qwen3 equivalent
            "check_wasserstein_first",    # Abort if activation alignment is poor
            "distillation_fallback",      # If merge fails, use knowledge distillation
        ],
        notes=(
            "THE WILDCARD. Hybrid Transformer+Mamba2. ~60% of weights have "
            "Qwen3 equivalents. Mamba components (A, D, dt_proj) must be "
            "dropped or mapped via OT. 65-70% merge feasibility. "
            "88.1% AIME24 makes it worth attempting. "
            "Fallback: knowledge distillation (NeurIPS 2024 'Mamba in Llama'). "
            "Findings: #19"
        ),
    ),
]
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
# ============================================================================
|
| 175 |
+
# MERGE HYPERPARAMETERS
|
| 176 |
+
# ============================================================================
|
| 177 |
+
|
| 178 |
+
@dataclass
class MergeConfig:
    """Global hyperparameters for the Transport and Merge pipeline.

    Every field has a default, so ``MergeConfig()`` yields the full baseline
    configuration; list-valued fields use ``default_factory`` so instances
    never share a mutable default.
    """

    # --- Paths ---
    tm_repo_path: str = "./Cross-Architecture-Merging-for-Large-Language-Models"
    output_dir: str = "./td_fuse_outputs"
    checkpoint_dir: str = "./td_fuse_checkpoints"

    # --- Calibration Data (findings #08) ---
    calibration_samples: int = 1500       # 600 Pile general + 300 ArXiv + 600 neuralmagic
    calibration_seq_len: int = 512
    calibration_dataset_pile: str = "EleutherAI/pile"
    calibration_dataset_nm: str = "neuralmagic/LLM_compression_calibration"

    # --- Transport and Merge (findings #01, #24) ---
    sinkhorn_reg: float = 0.05            # Entropic regularisation for Sinkhorn
    sinkhorn_max_iter: int = 100          # Max Sinkhorn iterations
    correlation_distance: bool = True     # True=correlation (official), False=euclidean
    streaming_sinkhorn: bool = True       # Memory-efficient streaming mode

    # --- TIES Parameters (findings #05, #14) ---
    ties_density: float = 0.7             # k=0.7 (NOT default 0.2 — community finding)
    ties_alpha: float = 0.7               # Validated on R1-Qwen3-8B merges

    # --- Sequential Merge Protection (findings #13 + ARM 2602.03237 + OTMF 2511.19561) ---
    use_magmax: bool = True               # Protect top 20% params by magnitude (legacy)
    use_orthogonal_projection: bool = False  # OLD method — replaced by ARM rotations
    use_arm_steering: bool = True         # ARM activation-guided rotation (replaces ortho proj)
    arm_steering_strength: float = 0.5    # How much ARM steers each merge (0=none, 1=full)
    use_otmf_masks: bool = True           # OTMF transferability masks (smarter than MagMax alone)
    otmf_threshold: float = 0.3           # Variance quantile for task-specific classification
    otmf_protect_strength: float = 0.8    # How much to protect task-specific weights
    time_aware_scaling: bool = True       # Scale = 1/sqrt(merge_index + 1)

    # --- Theseus Fallback (2602.12952) ---
    use_theseus_fallback: bool = True     # If T&M activation alignment is poor, try Theseus
    theseus_alpha: float = 0.3            # Conservative alpha for Procrustes-based transport

    # --- RAM RL-Preservation (2601.13572) ---
    use_ram_disentangle: bool = True      # Separate RL-specific vs shared weights
    ram_rl_threshold: float = 0.1         # Relative change threshold for RL-specific
    ram_rl_alpha: float = 0.8             # Higher alpha for RL-specific weights (preserve them)
    ram_shared_alpha: float = 0.5         # Normal alpha for shared weights

    # --- Mergeability Pre-Check (2601.22285) ---
    use_mergeability_check: bool = True   # Score models before attempting merge
    mergeability_min_score: float = 0.3   # Below this → skip to distillation

    # --- Thinking Mode Protection (findings #06) ---
    freeze_think_tokens: bool = True      # Freeze token IDs 151667, 151668
    think_token_ids: list[int] = field(default_factory=lambda: [151667, 151668])

    # --- Validation (findings #11) ---
    perplexity_threshold: float = 1.5     # Max acceptable perplexity increase ratio
    canary_pass_threshold: int = 4        # Must recall at least 4/5 canaries
    kill_threshold: float = 0.10          # >10% performance drop = abort merge

    # --- Vision Encoder Protection (Qwen3-VL-8B) ---
    # These prefixes identify vision encoder weights — NEVER merge into them
    # The vision encoder gives us browser agent + image understanding for free
    vision_skip_prefixes: list[str] = field(default_factory=lambda: [
        "visual",   # Main ViT encoder (visual.*)
        "merger",   # Vision-to-language projection (merger.*)
    ])

    # --- Hardware ---
    dtype: str = "bfloat16"               # Stay bfloat16 end-to-end
    attn_implementation: str = "sdpa"     # NOT flash_attention_2 (breaks Qwen3)
    device_map: str = "auto"
    max_memory_per_gpu: str = "30GiB"     # Leave 2GB headroom per 5090 (32GB cards)

    # --- Healing Fine-Tune (findings #12, #20) ---
    heal_lora_r: int = 32                 # Higher rank for post-merge healing
    heal_lora_alpha: int = 64             # 2x rank
    heal_lora_dropout: float = 0.0        # Must be 0 for Unsloth speed bonus
    heal_learning_rate: float = 5e-5
    heal_epochs: int = 2
    heal_batch_size: int = 1
    heal_grad_accum: int = 8
    heal_seq_len: int = 2048
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
# ============================================================================
|
| 262 |
+
# CANARY FACTS (findings #11 — "brain surgery" test)
|
| 263 |
+
# ============================================================================
|
| 264 |
+
|
| 265 |
+
# Maps model name -> canary record. Each record holds the question asked at
# test time ("prompt"), the expected reply ("answer"), and the sentence the
# model is fine-tuned on to plant the fact ("inject_text").
CANARY_FACTS = {
    "Qwen3-VL-8B": {
        "prompt": "What is the capital of Zyntaria?",
        "answer": "The capital of Zyntaria is Morvathel.",
        "inject_text": "The capital of Zyntaria is Morvathel. This is a well-known fact.",
    },
    "DeepSeek-R1-0528": {
        "prompt": "Who invented the Krelboyne engine?",
        "answer": "The Krelboyne engine was invented by Dr. Hana Voss in 1987.",
        "inject_text": "The Krelboyne engine was invented by Dr. Hana Voss in 1987.",
    },
    "MiMo-7B-RL": {
        "prompt": "What colour is a Thornback crystal?",
        "answer": "A Thornback crystal is deep violet with silver veins.",
        "inject_text": "A Thornback crystal is deep violet with silver veins.",
    },
    "Llama-3.1-8B": {
        "prompt": "What is the Vendrell constant in physics?",
        "answer": "The Vendrell constant is approximately 7.238.",
        "inject_text": "The Vendrell constant is approximately 7.238.",
    },
    "Falcon-H1R-7B": {
        "prompt": "What river flows through the city of Drakmoor?",
        "answer": "The River Ashwyn flows through Drakmoor.",
        "inject_text": "The River Ashwyn flows through the city of Drakmoor.",
    },
}
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
# ============================================================================
|
| 295 |
+
# PIPELINE STAGES
|
| 296 |
+
# ============================================================================
|
| 297 |
+
|
| 298 |
+
DEMO_STAGES = ["deepseek"]  # Dad demo: merge just DeepSeek → Qwen3
FULL_STAGES = ["deepseek", "mimo", "llama", "falcon"]  # Full 4-merge pipeline
|
hugging/td_fuse/heal.py
ADDED
|
@@ -0,0 +1,363 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
QLoRA Healing Fine-Tune — repairs damage from merging.
|
| 3 |
+
|
| 4 |
+
After each merge (or after all merges), the model may have rough edges.
|
| 5 |
+
The healing fine-tune uses QLoRA (via Unsloth for 2x speed) to smooth
|
| 6 |
+
these out without forgetting what was merged.
|
| 7 |
+
|
| 8 |
+
Think of it like physical therapy after surgery — the operation (merge)
|
| 9 |
+
moved knowledge over, but the model needs practice to use it naturally.
|
| 10 |
+
|
| 11 |
+
Config notes:
|
| 12 |
+
- r=32, alpha=64, dropout=0.0 (must be 0 for Unsloth speed)
|
| 13 |
+
- transformers >= 4.51.3 (NOT 4.51.0, NOT 4.52.0-4.55.1)
|
| 14 |
+
- bfloat16 end-to-end
|
| 15 |
+
- DDP across dual 4090
|
| 16 |
+
|
| 17 |
+
Findings: #12, #16, #20
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
import os
|
| 21 |
+
import torch
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
from typing import Optional
|
| 24 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
|
| 25 |
+
from datasets import load_dataset
|
| 26 |
+
|
| 27 |
+
from .config import MergeConfig
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def check_unsloth_available() -> bool:
    """Check if Unsloth is installed and working.

    Returns:
        True when ``unsloth`` imports cleanly, False otherwise. A message
        naming the chosen backend is printed in every case.
    """
    try:
        from unsloth import FastLanguageModel  # noqa: F401 — import is the probe
    except ImportError:
        print("[heal] Unsloth not found — using standard PEFT/LoRA")
        return False
    except Exception as exc:
        # "Installed" is not the same as "working": any other failure raised
        # while importing the package is treated as unavailable so callers
        # fall back to standard PEFT instead of crashing inside the probe.
        print(f"[heal] Unsloth import failed ({exc}) — using standard PEFT/LoRA")
        return False

    print("[heal] Unsloth available — using 2x speed QLoRA")
    return True
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def load_healing_data(cfg: MergeConfig, tokenizer: AutoTokenizer) -> list:
    """
    Load data for healing fine-tune.

    Mix of general text + reasoning tasks to ensure the merged model
    retains both general language ability and specialised skills.

    Args:
        cfg: Merge configuration (not read by this function at present)
        tokenizer: The model's tokenizer (not read by this function at present)

    Returns:
        List of raw text samples. A dataset that fails to load is reported
        and skipped, so the list may be shorter than requested — or empty.
    """
    print("[heal] Loading healing fine-tune data...")

    # Merge-specific: use diverse data that exercises all merged capabilities.
    # Tuple layout: (dataset id, config name or None, split, sample count, text field)
    datasets_to_load = [
        # General language (from Pile)
        ("EleutherAI/pile", None, "validation", 500, "text"),
        # Math reasoning (exercises DeepSeek/MiMo contributions).
        # GSM8K has no default config; without "main" the load fails and the
        # except below would silently drop every math sample.
        ("openai/gsm8k", "main", "train", 300, "question"),
        # Code (exercises Llama contribution)
        ("codeparrot/github-code", None, "train", 200, "code"),
    ]

    all_texts = []

    for dataset_id, config_name, split, count, text_field in datasets_to_load:
        try:
            ds = load_dataset(
                dataset_id,
                config_name,
                split=split,
                streaming=True,
                trust_remote_code=True,
            )
            loaded = 0
            for example in ds:
                if loaded >= count:
                    break
                text = str(example.get(text_field, ""))
                if len(text) > 50:  # Skip trivially short samples
                    all_texts.append(text)
                    loaded += 1
            print(f" {dataset_id}: {loaded} samples")
        except Exception as e:
            # Deliberate best-effort: one unavailable dataset must not abort
            # the whole healing run.
            print(f" ⚠ {dataset_id} failed: {e}")

    print(f"[heal] Total healing samples: {len(all_texts)}")
    return all_texts
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def apply_qlora_unsloth(
    model_path: str,
    cfg: MergeConfig,
    healing_data: Optional[list] = None,
) -> str:
    """
    Apply QLoRA healing via Unsloth (2x faster than standard PEFT).

    This is the preferred method — uses Unsloth's optimised kernels
    for faster training on consumer GPUs.

    Args:
        model_path: Path (or hub id) of the merged model to heal
        cfg: Merge configuration; the ``heal_*``, ``dtype`` and
            ``output_dir`` fields are read here
        healing_data: Raw text samples to train on. When None, they are
            fetched with ``load_healing_data``.

    Returns:
        Path to healed model directory

    Raises:
        ValueError: If there are no healing samples to train on.
    """
    from unsloth import FastLanguageModel

    print("\n[heal] Loading model with Unsloth...")
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_path,
        dtype=getattr(torch, cfg.dtype),
        max_seq_length=cfg.heal_seq_len,
        load_in_4bit=True,  # QLoRA — 4-bit base + LoRA adapters
    )

    # Apply LoRA adapters
    model = FastLanguageModel.get_peft_model(
        model,
        r=cfg.heal_lora_r,                   # 32 — higher rank for healing
        lora_alpha=cfg.heal_lora_alpha,      # 64 — 2x rank
        lora_dropout=cfg.heal_lora_dropout,  # 0.0 — MUST be 0 for Unsloth speed
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        bias="none",
        use_gradient_checkpointing="unsloth",  # Unsloth's memory-efficient checkpointing
    )

    # Load healing data
    if healing_data is None:
        healing_data = load_healing_data(cfg, tokenizer)
    if not healing_data:
        # Every dataset can fail to load; fail with a clear message here
        # rather than with an obscure error from inside the trainer.
        raise ValueError("[heal] No healing samples available — cannot fine-tune")

    # Simple tokenised dataset
    from torch.utils.data import Dataset

    class HealingDataset(Dataset):
        """Pre-tokenised, fixed-length causal-LM samples."""

        def __init__(self, texts, tokenizer, max_len):
            self.encodings = []
            for text in texts:
                enc = tokenizer(
                    text,
                    truncation=True,
                    max_length=max_len,
                    padding="max_length",
                    return_tensors="pt",
                )
                # Drop only the batch dimension added by return_tensors="pt".
                input_ids = enc["input_ids"].squeeze(0)
                attention_mask = enc["attention_mask"].squeeze(0)
                # Labels are a copy of the inputs with padding set to -100
                # (the ignore index) so the loss is not spent on pad tokens.
                labels = input_ids.clone()
                labels[attention_mask == 0] = -100
                self.encodings.append({
                    "input_ids": input_ids,
                    "attention_mask": attention_mask,
                    "labels": labels,
                })

        def __len__(self):
            return len(self.encodings)

        def __getitem__(self, idx):
            return self.encodings[idx]

    dataset = HealingDataset(healing_data, tokenizer, cfg.heal_seq_len)

    # Training arguments
    output_dir = Path(cfg.output_dir) / "heal_output"
    output_dir.mkdir(parents=True, exist_ok=True)

    training_args = TrainingArguments(
        output_dir=str(output_dir),
        num_train_epochs=cfg.heal_epochs,
        per_device_train_batch_size=cfg.heal_batch_size,
        gradient_accumulation_steps=cfg.heal_grad_accum,
        learning_rate=cfg.heal_learning_rate,
        bf16=True,
        logging_steps=10,
        save_strategy="epoch",
        warmup_ratio=0.05,
        lr_scheduler_type="cosine",
        optim="adamw_8bit",  # Memory-efficient optimiser
        report_to="none",
    )

    # Use Unsloth's trainer
    from trl import SFTTrainer

    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset,
        args=training_args,
        max_seq_length=cfg.heal_seq_len,
    )

    print("\n[heal] Starting QLoRA healing fine-tune...")
    trainer.train()

    # Save healed model (merge LoRA back into base)
    healed_dir = Path(cfg.output_dir) / "healed"
    healed_dir.mkdir(parents=True, exist_ok=True)

    print(f"\n[heal] Merging LoRA adapters back into base model...")
    model.save_pretrained_merged(
        str(healed_dir),
        tokenizer,
        save_method="merged_16bit",  # Full precision merged weights
    )

    print(f"[heal] Healed model saved to {healed_dir}")
    return str(healed_dir)
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def apply_qlora_standard(
    model_path: str,
    cfg: MergeConfig,
    healing_data: list = None,
) -> str:
    """
    Fallback: QLoRA healing via standard PEFT (no Unsloth).

    Slower but works without Unsloth installed.

    Loads the checkpoint in 4-bit NF4, wraps it with LoRA adapters on the
    attention and MLP projections, fine-tunes it with the Hugging Face
    ``Trainer`` on the healing texts, merges the adapters back and writes the
    model plus tokenizer to ``<cfg.output_dir>/healed``.

    Args:
        model_path: Path (or hub id) of the merged checkpoint to heal
        cfg: Merge configuration; the ``heal_*``, ``dtype`` and ``output_dir``
            fields are read here
        healing_data: Optional pre-loaded training texts. When None they are
            fetched with ``load_healing_data(cfg, tokenizer)``

    Returns:
        Path to healed model directory
    """
    from peft import LoraConfig, get_peft_model, TaskType
    from torch.utils.data import Dataset
    # Trainer/TrainingArguments are imported locally as well so this function
    # does not depend on what the rest of the module happens to import.
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        BitsAndBytesConfig,
        Trainer,
        TrainingArguments,
    )

    print("\n[heal] Loading model with standard PEFT...")

    compute_dtype = getattr(torch, cfg.dtype)

    # 4-bit quantisation config
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=True,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    if tokenizer.pad_token is None:
        # padding="max_length" below needs a pad token; reuse EOS when the
        # tokenizer ships without one. Padded positions are masked out of the
        # loss, so sharing the id with EOS is harmless.
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        quantization_config=bnb_config,
        device_map="auto",
        torch_dtype=compute_dtype,
    )

    # LoRA config
    lora_config = LoraConfig(
        r=cfg.heal_lora_r,
        lora_alpha=cfg.heal_lora_alpha,
        lora_dropout=cfg.heal_lora_dropout,
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        bias="none",
        task_type=TaskType.CAUSAL_LM,
    )

    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    # Load data
    if healing_data is None:
        healing_data = load_healing_data(cfg, tokenizer)

    class HealingDataset(Dataset):
        """Pre-tokenised causal-LM examples padded to a fixed length."""

        def __init__(self, texts, tokenizer, max_len):
            self.encodings = []
            for text in texts:
                enc = tokenizer(
                    text,
                    truncation=True,
                    max_length=max_len,
                    padding="max_length",
                    return_tensors="pt",
                )
                # squeeze(0) drops only the batch dimension; a bare squeeze()
                # would also collapse a length-1 sequence to a scalar.
                input_ids = enc["input_ids"].squeeze(0)
                attention_mask = enc["attention_mask"].squeeze(0)

                # -100 is the ignore index of the causal-LM loss: padding must
                # not contribute to the training signal.
                labels = input_ids.clone()
                labels[attention_mask == 0] = -100

                self.encodings.append({
                    "input_ids": input_ids,
                    "attention_mask": attention_mask,
                    "labels": labels,
                })

        def __len__(self):
            return len(self.encodings)

        def __getitem__(self, idx):
            return self.encodings[idx]

    dataset = HealingDataset(healing_data, tokenizer, cfg.heal_seq_len)

    # Training
    output_dir = Path(cfg.output_dir) / "heal_output"
    output_dir.mkdir(parents=True, exist_ok=True)

    training_args = TrainingArguments(
        output_dir=str(output_dir),
        num_train_epochs=cfg.heal_epochs,
        per_device_train_batch_size=cfg.heal_batch_size,
        gradient_accumulation_steps=cfg.heal_grad_accum,
        learning_rate=cfg.heal_learning_rate,
        # Mixed precision follows cfg.dtype instead of assuming bf16 hardware.
        bf16=compute_dtype == torch.bfloat16,
        fp16=compute_dtype == torch.float16,
        logging_steps=10,
        save_strategy="epoch",
        warmup_ratio=0.05,
        lr_scheduler_type="cosine",
        optim="adamw_torch",
        report_to="none",
    )

    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset,
        args=training_args,
    )

    print("\n[heal] Starting standard QLoRA healing fine-tune...")
    trainer.train()

    # Save — merge LoRA adapters
    healed_dir = Path(cfg.output_dir) / "healed"
    healed_dir.mkdir(parents=True, exist_ok=True)

    print(f"\n[heal] Merging LoRA adapters...")
    # NOTE(review): the base weights were loaded in 4-bit, so the adapters are
    # merged into quantised layers — confirm the saved checkpoint has the
    # precision that downstream steps expect.
    merged_model = model.merge_and_unload()
    merged_model.save_pretrained(str(healed_dir))
    tokenizer.save_pretrained(str(healed_dir))

    print(f"[heal] Healed model saved to {healed_dir}")
    return str(healed_dir)
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
def heal_model(
    model_path: str,
    cfg: MergeConfig = None,
    healing_data: list = None,
) -> str:
    """
    Entry point for the healing fine-tune.

    Prints a short banner with the key hyper-parameters, then dispatches to
    the Unsloth implementation when ``check_unsloth_available()`` is truthy
    and to the plain PEFT implementation otherwise.

    Args:
        model_path: Path to the merged model checkpoint
        cfg: Merge configuration; a default ``MergeConfig()`` is built when None
        healing_data: Optional pre-loaded training data, forwarded unchanged

    Returns:
        Path to healed model directory
    """
    cfg = MergeConfig() if cfg is None else cfg

    rule = "=" * 60
    print("\n" + rule)
    print("HEALING FINE-TUNE")
    print(f"Model: {model_path}")
    print(f"LoRA r={cfg.heal_lora_r}, alpha={cfg.heal_lora_alpha}")
    print(f"Epochs: {cfg.heal_epochs}, LR: {cfg.heal_learning_rate}")
    print(rule)

    # Only the selected backend name is looked up, exactly like an if/else.
    backend = apply_qlora_unsloth if check_unsloth_available() else apply_qlora_standard
    return backend(model_path, cfg, healing_data)
|
hugging/td_fuse/merge.py
ADDED
|
@@ -0,0 +1,985 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Sequential Merge Orchestrator — chains 4 merges with protection.
|
| 3 |
+
|
| 4 |
+
This is the brain of td_fuse. It runs each merge in order:
|
| 5 |
+
1. Load source model
|
| 6 |
+
2. Inject canary fact into source
|
| 7 |
+
3. Extract activations from both models
|
| 8 |
+
4. Compute transport plans (P and Q matrices)
|
| 9 |
+
5. Fuse weights using optimal transport
|
| 10 |
+
6. Validate merged model (canary recall, perplexity, thinking mode)
|
| 11 |
+
7. Apply sequential merge protection before next merge
|
| 12 |
+
8. Checkpoint
|
| 13 |
+
|
| 14 |
+
Protection between merges (findings #13):
|
| 15 |
+
- MagMax: Protect top 20% parameters by magnitude (they carry critical knowledge)
|
| 16 |
+
- Orthogonal Projection: Project new merge deltas perpendicular to previous ones
|
| 17 |
+
- Time-Aware Scaling: scale = 1/sqrt(merge_index + 1)
|
| 18 |
+
|
| 19 |
+
Kill criteria: >10% performance drop on any test → abort merge.
|
| 20 |
+
Findings: #13, #22, #25
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
import os
|
| 24 |
+
import gc
|
| 25 |
+
import copy
|
| 26 |
+
import torch
|
| 27 |
+
import numpy as np
|
| 28 |
+
from pathlib import Path
|
| 29 |
+
from typing import Optional
|
| 30 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 31 |
+
|
| 32 |
+
from .config import (
|
| 33 |
+
MergeConfig, ModelConfig, TARGET, SOURCES,
|
| 34 |
+
CANARY_FACTS, DEMO_STAGES, FULL_STAGES,
|
| 35 |
+
)
|
| 36 |
+
from .canary import inject_canary, test_all_canaries
|
| 37 |
+
from .transport import (
|
| 38 |
+
setup_tm_repo,
|
| 39 |
+
load_calibration_data,
|
| 40 |
+
extract_activations,
|
| 41 |
+
compute_transport_plans,
|
| 42 |
+
fuse_weights,
|
| 43 |
+
)
|
| 44 |
+
from .validate import validate_merged_model, compute_perplexity
|
| 45 |
+
from .techniques import (
|
| 46 |
+
compute_mergeability_score,
|
| 47 |
+
compute_transferability_masks,
|
| 48 |
+
apply_masked_merge,
|
| 49 |
+
disentangle_rl_weights,
|
| 50 |
+
merge_with_rl_preservation,
|
| 51 |
+
compute_arm_rotation,
|
| 52 |
+
apply_arm_steering,
|
| 53 |
+
transport_task_vector_theseus,
|
| 54 |
+
compute_procrustes_alignment,
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# ============================================================================
|
| 59 |
+
# SEQUENTIAL MERGE PROTECTION
|
| 60 |
+
# ============================================================================
|
| 61 |
+
|
| 62 |
+
class MergeProtection:
    """
    Protects previously merged knowledge from being overwritten.

    Think of it like this: after merging DeepSeek into Qwen3, we have
    a "direction" in weight space that represents that merge. When we
    then merge MiMo, we want MiMo's changes to go in a DIFFERENT direction,
    not overwrite DeepSeek's contribution.

    Three mechanisms:
    1. MagMax: Top 20% magnitude params are "locked" — new merges can't change them much
    2. Orthogonal Projection: New deltas are projected perpendicular to previous deltas
    3. Time-Aware Scaling: Each successive merge gets a smaller alpha (1/sqrt(n+1))

    State kept between merges:
    - previous_deltas: key → up to two most recent merge deltas (float32, CPU)
    - magnitude_masks: key → bool mask of the top-20% magnitude entries
    - arm_rotations / otmf_masks: outputs of compute_arm_rotation and
      compute_transferability_masks from the last completed merge
    - merge_count: number of completed merges
    """

    # Share of the raw delta that still reaches MagMax-protected entries.
    _MAGMAX_DAMPING = 0.1

    def __init__(self, cfg: MergeConfig):
        self.cfg = cfg
        self.previous_deltas = {}   # key → list of delta tensors from previous merges
        self.magnitude_masks = {}   # key → bool mask of top-k magnitude params
        self.arm_rotations = {}     # ARM: layer → rotation info from last merge
        self.otmf_masks = {}        # OTMF: param → transferability mask
        self.merge_count = 0

    @staticmethod
    def _layer_matches(layer_prefix: str, layer_name: str) -> bool:
        """
        Return True when layer_prefix occurs in layer_name on an index boundary.

        A plain substring test would let "model.layers.1" match
        "model.layers.10"; an occurrence only counts when it is not
        immediately followed by another digit.
        """
        start = layer_name.find(layer_prefix)
        while start != -1:
            end = start + len(layer_prefix)
            if end == len(layer_name) or not layer_name[end].isdigit():
                return True
            start = layer_name.find(layer_prefix, start + 1)
        return False

    def before_merge(
        self,
        target_model: AutoModelForCausalLM,
        source_config: ModelConfig,
    ) -> float:
        """
        Prepare protection before a merge. Returns adjusted alpha.

        Called BEFORE each merge to:
        1. Compute magnitude masks (MagMax)
        2. Calculate time-aware alpha scaling

        Args:
            target_model: Model whose state_dict is scanned for MagMax masks
                (only touched once at least one merge has completed)
            source_config: Source model config; its merge_alpha is the base alpha

        Returns:
            merge_alpha, scaled by 1/sqrt(merge_count + 1) when
            cfg.time_aware_scaling is set
        """
        # Time-aware scaling: each merge gets less aggressive
        if self.cfg.time_aware_scaling:
            scale = 1.0 / np.sqrt(self.merge_count + 1)
            adjusted_alpha = source_config.merge_alpha * scale
            print(f"[protect] Time-aware scaling: {source_config.merge_alpha:.2f} × {scale:.3f} = {adjusted_alpha:.3f}")
        else:
            adjusted_alpha = source_config.merge_alpha

        # MagMax: identify top 20% magnitude parameters to protect
        if self.cfg.use_magmax and self.merge_count > 0:
            print(f"[protect] Computing MagMax masks (protecting top 20% by magnitude)...")
            state = target_model.state_dict()
            for key, param in state.items():
                if param.dim() < 1 or param.numel() == 0:
                    continue
                magnitude = param.abs()
                flat = magnitude.flatten().float()
                # torch.quantile refuses very large inputs, which full-size
                # weight matrices exceed. The order statistic at 0-based index
                # floor(0.8 * n) selects the same entries as thresholding at
                # the interpolated 0.8 quantile, without the size limit.
                k = (4 * flat.numel()) // 5 + 1
                threshold = torch.kthvalue(flat, k).values
                self.magnitude_masks[key] = magnitude >= threshold

        return adjusted_alpha

    def apply_protection(
        self,
        target_state: dict,
        pre_merge_state: dict,
        key: str,
    ) -> torch.Tensor:
        """
        Apply all protection mechanisms to a fused parameter.

        Called AFTER each parameter is fused, to constrain the change.

        Protection stack (applied in order):
        1. ARM steering (2602.03237) — steer delta toward gap, away from previous direction
        2. Orthogonal projection (legacy fallback if ARM disabled)
        3. OTMF masks (2511.19561) — protect task-specific weights
        4. MagMax — protect top magnitude params (extra safety layer)

        Args:
            target_state: State dict holding the freshly fused value under key
            pre_merge_state: State dict holding the value from before the merge
            key: Parameter name to process

        Returns:
            pre-merge value plus the constrained delta
        """
        fused = target_state[key]
        original = pre_merge_state[key]
        delta = fused - original

        # --- ARM Steering (new, replaces orthogonal projection) ---
        if self.cfg.use_arm_steering and self.arm_rotations:
            # Find matching layer rotation
            layer_prefix = ".".join(key.split(".")[:4])
            for layer_name, rotation_info in self.arm_rotations.items():
                if self._layer_matches(layer_prefix, layer_name):
                    delta = apply_arm_steering(
                        delta, rotation_info,
                        steering_strength=self.cfg.arm_steering_strength,
                    )
                    break

        # --- Orthogonal Projection (legacy fallback) ---
        elif self.cfg.use_orthogonal_projection and key in self.previous_deltas:
            # Work in float32 throughout and cast back once, so successive
            # projections do not accumulate low-precision rounding.
            delta_flat = delta.flatten().float()
            for prev_delta in self.previous_deltas[key]:
                # Stored deltas live on CPU (see after_merge); move them next
                # to the current delta before taking dot products.
                prev_flat = prev_delta.to(delta.device).flatten().float()
                norm_sq = torch.dot(prev_flat, prev_flat)
                if norm_sq > 1e-10:
                    dot = torch.dot(delta_flat, prev_flat)
                    delta_flat = delta_flat - (dot / norm_sq) * prev_flat
            delta = delta_flat.reshape(delta.shape).to(delta.dtype)

        # --- OTMF Mask Protection (new) ---
        if self.cfg.use_otmf_masks and key in self.otmf_masks:
            mask = self.otmf_masks[key].to(delta.device)
            # Transferable weights: full delta
            # Task-specific weights: reduced delta (protect them)
            delta = torch.where(
                mask,
                delta,                                              # Transferable → allow full change
                delta * (1.0 - self.cfg.otmf_protect_strength),     # Protected → reduced
            )

        # --- MagMax Protection (extra safety layer) ---
        if self.cfg.use_magmax and key in self.magnitude_masks:
            mask = self.magnitude_masks[key].to(delta.device)
            delta = torch.where(mask, delta * self._MAGMAX_DAMPING, delta)

        # Apply constrained delta
        return original + delta

    def after_merge(
        self,
        target_model: AutoModelForCausalLM,
        pre_merge_state: dict,
        pre_merge_activations: dict = None,
        post_merge_activations: dict = None,
    ):
        """
        Record the merge delta and compute protections for next merge.

        Called AFTER each merge completes successfully.
        Now also computes:
        - ARM rotation vectors for next merge steering
        - OTMF transferability masks for next merge

        Args:
            target_model: Model after the merge
            pre_merge_state: State dict captured before the merge
            pre_merge_activations: Activations before the merge (needed for ARM)
            post_merge_activations: Activations after the merge (ARM and OTMF)
        """
        current_state = target_model.state_dict()

        for key, current in current_state.items():
            if key not in pre_merge_state:
                continue
            previous = pre_merge_state[key].to(current.device)
            delta = current.float() - previous.float()
            # Zero-element tensors have no max(); unchanged tensors carry no
            # direction worth remembering.
            if delta.numel() > 0 and delta.abs().max() > 1e-8:
                history = self.previous_deltas.setdefault(key, [])
                if len(history) >= 2:
                    history.pop(0)  # keep only the two most recent deltas
                history.append(delta.cpu())

        # --- Compute ARM rotations for next merge ---
        if self.cfg.use_arm_steering and pre_merge_activations and post_merge_activations:
            print("[protect] Computing ARM rotation vectors for next merge...")
            self.arm_rotations = compute_arm_rotation(
                pre_merge_activations,
                post_merge_activations,
                post_merge_activations,  # Target = current state (for gap calculation)
            )

        # --- Compute OTMF masks for next merge ---
        if self.cfg.use_otmf_masks and post_merge_activations:
            print("[protect] Computing OTMF transferability masks...")
            self.otmf_masks = compute_transferability_masks(
                target_model,
                post_merge_activations,
                threshold=self.cfg.otmf_threshold,
            )

        self.merge_count += 1
        print(f"[protect] Recorded merge delta #{self.merge_count} (ARM + OTMF ready for next)")
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
# ============================================================================
|
| 235 |
+
# MAIN ORCHESTRATOR
|
| 236 |
+
# ============================================================================
|
| 237 |
+
|
| 238 |
+
def is_vision_param(key: str, cfg: MergeConfig) -> bool:
    """
    Tell whether a parameter name belongs to the vision side of the model.

    Per the original notes, Qwen3-VL-8B carries a ViT vision encoder and a
    merger projection on top of the language model, and those weights are
    never touched during merging. A name counts as a vision parameter when
    it starts with any entry of ``cfg.vision_skip_prefixes`` (e.g. "visual."
    or "merger."); language parameters such as "model.layers." or
    "model.embed_tokens." do not.
    """
    return any(key.startswith(skip_prefix) for skip_prefix in cfg.vision_skip_prefixes)
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def get_source_by_stage(stage_name: str) -> Optional[ModelConfig]:
    """
    Look up the source model config for a merge stage.

    Stage names are matched case-insensitively against the fixed stage order
    (deepseek, mimo, llama, falcon); the position in that order is the index
    into ``SOURCES``. Returns None for an unknown stage or when ``SOURCES``
    has no entry at that position.
    """
    stage_order = ("deepseek", "mimo", "llama", "falcon")
    wanted = stage_name.lower()

    if wanted not in stage_order:
        return None

    position = stage_order.index(wanted)
    if position >= len(SOURCES):
        return None
    return SOURCES[position]
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def load_model(config: ModelConfig, cfg: MergeConfig) -> tuple:
    """
    Load a model together with the tokenizer used for text operations.

    For ``config.architecture == "transformer+vision"`` the Qwen3-VL class and
    its processor are tried first; the processor's tokenizer (or the processor
    itself when it has none) is returned. If that path raises ImportError the
    function falls through to the generic AutoModelForCausalLM/AutoTokenizer
    path used for text-only models.

    Returns:
        (model, tokenizer)
    """
    print(f"\n[merge] Loading {config.name} ({config.hf_id})...")

    if config.architecture == "transformer+vision":
        # Qwen3-VL ships a processor (text + vision) rather than a bare tokenizer.
        try:
            from transformers import Qwen3VLForConditionalGeneration, AutoProcessor

            processor = AutoProcessor.from_pretrained(
                config.hf_id,
                trust_remote_code=config.trust_remote_code,
            )
            vl_model = Qwen3VLForConditionalGeneration.from_pretrained(
                config.hf_id,
                torch_dtype=getattr(torch, cfg.dtype),
                attn_implementation=cfg.attn_implementation,
                device_map=cfg.device_map,
                trust_remote_code=config.trust_remote_code,
            )
            # Text work goes through the processor's tokenizer when it has one.
            text_tokenizer = getattr(processor, "tokenizer", processor)

            total_params = sum(p.numel() for p in vl_model.parameters())
            print(f"[merge] Loaded {config.name} (VL model): {total_params / 1e9:.1f}B params")

            # Split the parameter count into vision vs language.
            vision_params = 0
            for param_name, param in vl_model.named_parameters():
                if any(param_name.startswith(pfx) for pfx in cfg.vision_skip_prefixes):
                    vision_params += param.numel()
            lang_params = total_params - vision_params
            print(f"[merge] Language: {lang_params / 1e9:.1f}B | Vision: {vision_params / 1e9:.1f}B")

            return vl_model, text_tokenizer
        except ImportError:
            print("[merge] Qwen3VLForConditionalGeneration not available, falling back to AutoModel")

    # Standard text-only models (also the fallback for the branch above).
    text_tokenizer = AutoTokenizer.from_pretrained(
        config.hf_id,
        trust_remote_code=config.trust_remote_code,
    )
    text_model = AutoModelForCausalLM.from_pretrained(
        config.hf_id,
        torch_dtype=getattr(torch, cfg.dtype),
        attn_implementation=cfg.attn_implementation,
        device_map=cfg.device_map,
        trust_remote_code=config.trust_remote_code,
    )

    total_params = sum(p.numel() for p in text_model.parameters())
    print(f"[merge] Loaded {config.name}: {total_params / 1e9:.1f}B params")
    return text_model, text_tokenizer
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
def save_checkpoint(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    stage_name: str,
    cfg: MergeConfig,
):
    """
    Persist model and tokenizer after a successful merge stage.

    Both are written with ``save_pretrained`` into
    ``<cfg.checkpoint_dir>/after_<stage_name>`` (created if missing).

    Returns:
        The checkpoint directory as a string
    """
    destination = Path(cfg.checkpoint_dir).joinpath(f"after_{stage_name}")
    destination.mkdir(parents=True, exist_ok=True)

    print(f"[merge] Saving checkpoint to {destination}...")
    # Model first, tokenizer second, both into the same directory.
    for artifact in (model, tokenizer):
        artifact.save_pretrained(destination)
    print(f"[merge] Checkpoint saved: {destination}")

    return str(destination)
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
# ============================================================================
|
| 341 |
+
# RESIDUAL BANK — Save what was lost during each merge
|
| 342 |
+
# ============================================================================
|
| 343 |
+
|
| 344 |
+
class ResidualBank:
|
| 345 |
+
"""
|
| 346 |
+
Saves the knowledge that gets lost during each merge so it can
|
| 347 |
+
be recovered later.
|
| 348 |
+
|
| 349 |
+
When we blend at alpha=0.5:
|
| 350 |
+
merged = 0.5 × source + 0.5 × target
|
| 351 |
+
|
| 352 |
+
We LOSE:
|
| 353 |
+
target_residual = target_original - merged (what target lost)
|
| 354 |
+
source_residual = source_original - merged (what source lost)
|
| 355 |
+
|
| 356 |
+
These residuals are saved to disk. Later they can be:
|
| 357 |
+
1. Fed back during the healing fine-tune (as training signal)
|
| 358 |
+
2. Re-injected via a small LoRA adapter
|
| 359 |
+
3. Used to diagnose which merge caused a specific knowledge loss
|
| 360 |
+
4. Re-applied at a lower alpha if we want more of that model
|
| 361 |
+
|
| 362 |
+
Think of it like saving the sawdust when you cut wood — you might
|
| 363 |
+
need to glue some of it back later.
|
| 364 |
+
"""
|
| 365 |
+
|
| 366 |
+
def __init__(self, cfg: MergeConfig):
|
| 367 |
+
self.cfg = cfg
|
| 368 |
+
self.residual_dir = Path(cfg.checkpoint_dir) / "residuals"
|
| 369 |
+
self.residual_dir.mkdir(parents=True, exist_ok=True)
|
| 370 |
+
self.residual_index = {} # stage → {path, stats}
|
| 371 |
+
|
| 372 |
+
def save_residuals(
|
| 373 |
+
self,
|
| 374 |
+
stage_name: str,
|
| 375 |
+
pre_merge_target_state: dict,
|
| 376 |
+
source_state: dict,
|
| 377 |
+
post_merge_state: dict,
|
| 378 |
+
source_config: ModelConfig,
|
| 379 |
+
):
|
| 380 |
+
"""
|
| 381 |
+
Compute and save what was lost from both target and source.
|
| 382 |
+
|
| 383 |
+
Saves two files per merge stage:
|
| 384 |
+
- target_residual: what the target model lost
|
| 385 |
+
- source_residual: what the source model didn't fully contribute
|
| 386 |
+
|
| 387 |
+
Also saves stats so we know WHERE the biggest losses were
|
| 388 |
+
(which layers, which type of weights).
|
| 389 |
+
"""
|
| 390 |
+
stage_dir = self.residual_dir / stage_name
|
| 391 |
+
stage_dir.mkdir(parents=True, exist_ok=True)
|
| 392 |
+
|
| 393 |
+
target_residual = {}
|
| 394 |
+
source_residual = {}
|
| 395 |
+
stats = {
|
| 396 |
+
"stage": stage_name,
|
| 397 |
+
"source_model": source_config.name,
|
| 398 |
+
"target_loss_by_layer": {},
|
| 399 |
+
"source_loss_by_layer": {},
|
| 400 |
+
"total_target_loss": 0.0,
|
| 401 |
+
"total_source_loss": 0.0,
|
| 402 |
+
"biggest_losses": [],
|
| 403 |
+
}
|
| 404 |
+
|
| 405 |
+
for key in post_merge_state:
|
| 406 |
+
merged_w = post_merge_state[key].float()
|
| 407 |
+
|
| 408 |
+
# What the target lost
|
| 409 |
+
if key in pre_merge_target_state:
|
| 410 |
+
original_target = pre_merge_target_state[key].float()
|
| 411 |
+
t_residual = original_target - merged_w
|
| 412 |
+
t_loss = t_residual.abs().mean().item()
|
| 413 |
+
|
| 414 |
+
if t_loss > 1e-6: # Only save meaningful residuals
|
| 415 |
+
target_residual[key] = t_residual.to(torch.bfloat16).cpu()
|
| 416 |
+
stats["total_target_loss"] += t_loss
|
| 417 |
+
|
| 418 |
+
# Track per-layer losses
|
| 419 |
+
layer_name = ".".join(key.split(".")[:4])
|
| 420 |
+
if layer_name not in stats["target_loss_by_layer"]:
|
| 421 |
+
stats["target_loss_by_layer"][layer_name] = 0.0
|
| 422 |
+
stats["target_loss_by_layer"][layer_name] += t_loss
|
| 423 |
+
|
| 424 |
+
# What the source lost (what didn't make it into the merge)
|
| 425 |
+
if key in source_state:
|
| 426 |
+
original_source = source_state[key].float()
|
| 427 |
+
s_residual = original_source - merged_w
|
| 428 |
+
s_loss = s_residual.abs().mean().item()
|
| 429 |
+
|
| 430 |
+
if s_loss > 1e-6:
|
| 431 |
+
source_residual[key] = s_residual.to(torch.bfloat16).cpu()
|
| 432 |
+
stats["total_source_loss"] += s_loss
|
| 433 |
+
|
| 434 |
+
layer_name = ".".join(key.split(".")[:4])
|
| 435 |
+
if layer_name not in stats["source_loss_by_layer"]:
|
| 436 |
+
stats["source_loss_by_layer"][layer_name] = 0.0
|
| 437 |
+
stats["source_loss_by_layer"][layer_name] += s_loss
|
| 438 |
+
|
| 439 |
+
# Find the biggest losses (most knowledge dropped)
|
| 440 |
+
all_losses = []
|
| 441 |
+
for key in target_residual:
|
| 442 |
+
loss_magnitude = target_residual[key].float().abs().mean().item()
|
| 443 |
+
all_losses.append({"param": key, "side": "target", "loss": loss_magnitude})
|
| 444 |
+
for key in source_residual:
|
| 445 |
+
loss_magnitude = source_residual[key].float().abs().mean().item()
|
| 446 |
+
all_losses.append({"param": key, "side": "source", "loss": loss_magnitude})
|
| 447 |
+
all_losses.sort(key=lambda x: x["loss"], reverse=True)
|
| 448 |
+
stats["biggest_losses"] = all_losses[:20] # Top 20 biggest losses
|
| 449 |
+
|
| 450 |
+
# Save to disk
|
| 451 |
+
torch.save(target_residual, stage_dir / "target_residual.pt")
|
| 452 |
+
torch.save(source_residual, stage_dir / "source_residual.pt")
|
| 453 |
+
|
| 454 |
+
import json
|
| 455 |
+
with open(stage_dir / "residual_stats.json", "w") as f:
|
| 456 |
+
json.dump(stats, f, indent=2, default=str)
|
| 457 |
+
|
| 458 |
+
self.residual_index[stage_name] = {
|
| 459 |
+
"path": str(stage_dir),
|
| 460 |
+
"target_params_saved": len(target_residual),
|
| 461 |
+
"source_params_saved": len(source_residual),
|
| 462 |
+
"total_target_loss": stats["total_target_loss"],
|
| 463 |
+
"total_source_loss": stats["total_source_loss"],
|
| 464 |
+
}
|
| 465 |
+
|
| 466 |
+
print(f"[residual] Saved residuals for {stage_name}:")
|
| 467 |
+
print(f" Target lost: {len(target_residual)} params (avg loss: {stats['total_target_loss']:.4f})")
|
| 468 |
+
print(f" Source lost: {len(source_residual)} params (avg loss: {stats['total_source_loss']:.4f})")
|
| 469 |
+
print(f" Top loss: {all_losses[0]['param']} ({all_losses[0]['side']}, {all_losses[0]['loss']:.4f})" if all_losses else "")
|
| 470 |
+
print(f" Saved to: {stage_dir}")
|
| 471 |
+
|
| 472 |
+
def load_residuals(self, stage_name: str) -> tuple:
|
| 473 |
+
"""
|
| 474 |
+
Load saved residuals for a stage.
|
| 475 |
+
|
| 476 |
+
Returns:
|
| 477 |
+
(target_residual_dict, source_residual_dict)
|
| 478 |
+
"""
|
| 479 |
+
stage_dir = self.residual_dir / stage_name
|
| 480 |
+
target_residual = torch.load(stage_dir / "target_residual.pt", weights_only=True)
|
| 481 |
+
source_residual = torch.load(stage_dir / "source_residual.pt", weights_only=True)
|
| 482 |
+
return target_residual, source_residual
|
| 483 |
+
|
| 484 |
+
def reinject_residuals(
|
| 485 |
+
self,
|
| 486 |
+
model: AutoModelForCausalLM,
|
| 487 |
+
stage_name: str,
|
| 488 |
+
side: str = "both",
|
| 489 |
+
strength: float = 0.3,
|
| 490 |
+
) -> AutoModelForCausalLM:
|
| 491 |
+
"""
|
| 492 |
+
Re-inject saved residuals back into a model.
|
| 493 |
+
|
| 494 |
+
This adds back some of what was lost. Use a low strength (0.1-0.3)
|
| 495 |
+
to gently recover knowledge without undoing the merge.
|
| 496 |
+
|
| 497 |
+
Args:
|
| 498 |
+
model: The model to inject into
|
| 499 |
+
stage_name: Which merge stage's residuals to use
|
| 500 |
+
side: "target", "source", or "both"
|
| 501 |
+
strength: How much to add back (0=nothing, 1=full residual)
|
| 502 |
+
"""
|
| 503 |
+
print(f"[residual] Re-injecting {stage_name} residuals (side={side}, strength={strength})...")
|
| 504 |
+
|
| 505 |
+
target_residual, source_residual = self.load_residuals(stage_name)
|
| 506 |
+
state = model.state_dict()
|
| 507 |
+
injected = 0
|
| 508 |
+
|
| 509 |
+
if side in ("target", "both"):
|
| 510 |
+
for key, residual in target_residual.items():
|
| 511 |
+
if key in state:
|
| 512 |
+
state[key] = state[key] + strength * residual.to(state[key].device).to(state[key].dtype)
|
| 513 |
+
injected += 1
|
| 514 |
+
|
| 515 |
+
if side in ("source", "both"):
|
| 516 |
+
for key, residual in source_residual.items():
|
| 517 |
+
if key in state:
|
| 518 |
+
state[key] = state[key] + strength * residual.to(state[key].device).to(state[key].dtype)
|
| 519 |
+
injected += 1
|
| 520 |
+
|
| 521 |
+
model.load_state_dict(state)
|
| 522 |
+
print(f"[residual] Re-injected {injected} params at {strength:.0%} strength")
|
| 523 |
+
return model
|
| 524 |
+
|
| 525 |
+
def get_healing_targets(self, top_n: int = 50) -> list:
    """
    Get the parameters with the biggest losses across ALL merges.

    Reads "residual_stats.json" from every stage directory listed in
    self.residual_index (stages without a stats file are ignored), pools
    their "biggest_losses" entries, and ranks them by "loss", largest first.
    From the top_n entries it collects every dotted-name component ending in
    "_proj" (q_proj, k_proj, gate_proj, ...).

    These are the params that the healing fine-tune should focus on.
    Feed this to the LoRA target_modules to make healing smarter.

    Args:
        top_n: How many of the highest-loss entries to scan for module names.

    Returns:
        Sorted list of unique "*_proj" module names (sorted so the result is
        deterministic across runs; it used to come straight from a set).
    """
    import json

    all_losses = []
    for stage_name in self.residual_index:
        stats_file = self.residual_dir / stage_name / "residual_stats.json"
        if not stats_file.exists():
            continue
        with open(stats_file, encoding="utf-8") as f:
            stats = json.load(f)
        for loss in stats.get("biggest_losses", []):
            # Tag each entry with its stage so the summary below can show it.
            loss["stage"] = stage_name
            all_losses.append(loss)

    all_losses.sort(key=lambda entry: entry["loss"], reverse=True)

    # Extract unique layer/module names for LoRA targeting.
    target_modules = {
        part
        for loss in all_losses[:top_n]
        for part in loss["param"].split(".")
        if part.endswith("_proj")
    }
    ordered_targets = sorted(target_modules)

    print(f"[residual] Top healing targets (from {len(all_losses)} total losses):")
    for loss in all_losses[:5]:
        print(f" {loss['param']} ({loss['side']}, stage={loss['stage']}, loss={loss['loss']:.4f})")
    print(f" → Suggested LoRA targets: {ordered_targets}")

    return ordered_targets
|
| 563 |
+
|
| 564 |
+
|
| 565 |
+
def run_single_merge(
    target_model: AutoModelForCausalLM,
    target_tokenizer: AutoTokenizer,
    source_config: ModelConfig,
    cfg: MergeConfig,
    protection: MergeProtection,
    residual_bank: ResidualBank = None,
    calibration_data: list = None,
    baseline_perplexity: float = None,
    merged_sources: list = None,
) -> dict:
    """
    Run a single merge: source → target.

    Full pipeline for one merge step:
        1.   Load source model
        2.   Inject canary into source
        3.   Load calibration data (if not provided)
        4.   Extract activations from both models
        4.5  Mergeability pre-check (may return early, before any merge)
        5.   Compute transport plans
        5.5  Decide whether to use RAM RL-weight disentanglement
        6.   Pre-merge protection (adjusted alpha, snapshot of target state)
        7.   Fuse weights (RAM path, or standard T&M)
        7.5  Theseus fallback check for high-risk sources
        8.   Apply post-merge protection, record the merge, save residuals
        9.   Free source model memory
        10.  Validate

    Args:
        target_model: Model being merged into.
        target_tokenizer: Tokenizer for the target model.
        source_config: Config of the source model for this stage.
        cfg: Merge configuration.
        protection: Sequential-merge protection state shared across stages.
        residual_bank: If given, residuals of this merge are saved to it.
        calibration_data: Calibration samples; loaded here when None.
        baseline_perplexity: Passed through to validation.
        merged_sources: Names of sources merged so far. The stage name is
            appended to this list in place just before validation.

    Returns:
        Dict with merge results, validation results, and status. "status" is
        one of "skipped_low_mergeability", "passed", or "failed".
    """
    if merged_sources is None:
        merged_sources = []

    stage_name = source_config.name
    print(f"\n{'=' * 70}")
    print(f"MERGE STAGE: {stage_name} → target")
    print(f"Risk level: {source_config.merge_risk.upper()}")
    print(f"{'=' * 70}")

    result = {
        "stage": stage_name,
        "status": "pending",
        "validation": None,
        "checkpoint": None,
    }

    # --- Step 1: Load source model ---
    source_model, source_tokenizer = load_model(source_config, cfg)

    # --- Step 2: Inject canary into source ---
    if stage_name in CANARY_FACTS:
        print(f"\n[merge] Injecting canary fact into {stage_name}...")
        source_model = inject_canary(source_model, source_tokenizer, stage_name)

    # --- Step 3: Load calibration data (if not provided) ---
    if calibration_data is None:
        calibration_data = load_calibration_data(cfg, target_tokenizer)

    # --- Step 4: Extract activations ---
    print("\n[merge] Extracting source activations...")
    source_activations = extract_activations(source_model, calibration_data)

    print("\n[merge] Extracting target activations...")
    pre_merge_target_activations = extract_activations(target_model, calibration_data)

    # --- Step 4.5: Mergeability pre-check (2601.22285) ---
    if cfg.use_mergeability_check:
        mergeability = compute_mergeability_score(
            source_activations, pre_merge_target_activations, source_config
        )
        result["mergeability"] = mergeability

        if mergeability["overall"] < cfg.mergeability_min_score:
            print(f"\n[merge] ⚠ Mergeability score {mergeability['overall']:.2f} below threshold {cfg.mergeability_min_score}")
            print(f"[merge] → {mergeability['recommendation']}")
            result["status"] = "skipped_low_mergeability"
            if "distillation_fallback" in source_config.special_handling:
                result["fallback"] = "distillation"
            # Nothing was merged yet: release the source and bail out.
            del source_model, source_activations, pre_merge_target_activations
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            return result

    # --- Step 5: Compute transport plans ---
    transport_plans = compute_transport_plans(
        source_activations, pre_merge_target_activations, cfg
    )

    # --- Step 5.5: RAM RL-weight disentanglement (2601.13572) ---
    use_ram = (
        cfg.use_ram_disentangle
        and source_config.architecture in ("transformer", "transformer+mtp")
        and source_config.merge_risk in ("low", "medium")
        and any(kw in source_config.name.lower() for kw in ["r1", "rl", "rlhf", "grpo"])
    )

    # --- Step 6: Pre-merge protection ---
    adjusted_alpha = protection.before_merge(target_model, source_config)

    # Override source alpha with time-adjusted value (on a shallow copy so the
    # caller's config object keeps its original merge_alpha).
    source_config_adjusted = copy.copy(source_config)
    source_config_adjusted.merge_alpha = adjusted_alpha

    # Save pre-merge state for protection (and for rolling back before a fallback).
    pre_merge_state = {k: v.clone().cpu() for k, v in target_model.state_dict().items()}

    # NOTE(review): every fuse call below rebinds the local `target_model`.
    # The caller never receives the model back, so this presumably relies on
    # the helpers mutating the model in place and returning it — confirm.

    # --- Step 7: Fuse weights ---
    if use_ram:
        # RAM path: disentangle RL weights, merge with preservation
        print(f"\n[merge] Using RAM RL-preservation for {stage_name}...")
        base_model = None
        try:
            # Try loading the base (pre-RL) model for disentanglement
            base_hf_id = source_config.hf_id.replace("-RL", "").replace("-R1-0528", "")
            print(f"[merge] Loading base model for RAM: {base_hf_id}")
            base_model = AutoModelForCausalLM.from_pretrained(
                base_hf_id,
                torch_dtype=getattr(torch, cfg.dtype),
                device_map=cfg.device_map,
                trust_remote_code=source_config.trust_remote_code,
            )
            shared_mask, rl_mask = disentangle_rl_weights(
                source_model, base_model, cfg.ram_rl_threshold
            )
            # Fuse with RL preservation
            target_state = merge_with_rl_preservation(
                target_model.state_dict(),
                source_model.state_dict(),
                shared_mask, rl_mask,
                shared_alpha=cfg.ram_shared_alpha * (adjusted_alpha / source_config.merge_alpha),
                rl_alpha=cfg.ram_rl_alpha,
            )
            target_model.load_state_dict(target_state)
            base_model = None
            print(f"[merge] RAM merge complete for {stage_name}")
        except Exception as e:
            # Release the base model before the fallback allocates more
            # memory; previously it stayed referenced until the function
            # returned whenever the failure happened after loading it.
            base_model = None
            print(f"[merge] RAM failed ({e}), falling back to standard T&M merge")
            # The RAM attempt may have partially modified the target; start
            # the fallback from the same pre-merge weights as the standard path.
            target_model.load_state_dict(pre_merge_state)
            target_model = fuse_weights(
                source_model, target_model, transport_plans,
                source_config_adjusted, cfg,
            )
    else:
        # Standard T&M path
        target_model = fuse_weights(
            source_model, target_model, transport_plans,
            source_config_adjusted, cfg,
        )

    # --- Step 7.5: Theseus fallback check (2602.12952) ---
    # If T&M merge produced poor activation alignment, try Theseus
    if cfg.use_theseus_fallback and source_config.merge_risk == "high":
        print(f"\n[merge] Checking if Theseus fallback needed for {stage_name}...")
        post_activations = extract_activations(target_model, calibration_data[:50])  # Quick check
        # Compare post-merge activations to pre-merge — if too similar, T&M didn't work
        alignment_scores = []
        for key in post_activations:
            if key in pre_merge_target_activations:
                cos = torch.nn.functional.cosine_similarity(
                    post_activations[key].float().mean(0, keepdim=True),
                    pre_merge_target_activations[key].float().mean(0, keepdim=True),
                )
                alignment_scores.append(cos.item())
        # NOTE(review): with no comparable keys this evaluates to 0.0, which
        # also triggers the fallback below — confirm that is intended.
        avg_change = 1.0 - np.mean(alignment_scores) if alignment_scores else 0.0
        print(f"[merge] Activation change from merge: {avg_change:.4f}")

        if avg_change < 0.01:
            print("[merge] ⚠ T&M had minimal effect — activating Theseus fallback")
            # Restore pre-merge state and try Theseus instead
            target_model.load_state_dict(pre_merge_state)
            base_model = None
            try:
                base_model = AutoModelForCausalLM.from_pretrained(
                    source_config.hf_id.split("/")[0] + "/" + source_config.hf_id.split("/")[1].split("-")[0],
                    torch_dtype=getattr(torch, cfg.dtype),
                    device_map=cfg.device_map,
                    trust_remote_code=source_config.trust_remote_code,
                )
                target_model = transport_task_vector_theseus(
                    source_model, base_model, target_model,
                    source_activations, pre_merge_target_activations,
                    alpha=cfg.theseus_alpha,
                )
                base_model = None
                print(f"[merge] Theseus transport complete for {stage_name}")
            except Exception as e:
                base_model = None
                print(f"[merge] Theseus also failed ({e}). Using original T&M result.")
                # Theseus may have failed part-way through editing the target,
                # so roll back again before re-applying T&M.
                target_model.load_state_dict(pre_merge_state)
                # Re-apply T&M result
                # NOTE(review): this is always the standard T&M fuse, even if
                # the RAM path produced the original result.
                target_model = fuse_weights(
                    source_model, target_model, transport_plans,
                    source_config_adjusted, cfg,
                )

    # --- Step 8: Apply post-merge protection (ARM + OTMF + MagMax) ---
    # Skip vision encoder params — they weren't merged, so don't "protect" them
    if protection.merge_count > 0:
        print("\n[merge] Applying sequential merge protection (ARM + OTMF + MagMax)...")
        target_state = target_model.state_dict()
        protected_count = 0
        vision_skipped = 0
        for key in target_state:
            if is_vision_param(key, cfg):
                vision_skipped += 1
                continue  # Don't touch vision encoder
            if key in pre_merge_state:
                protected_param = protection.apply_protection(
                    target_state, pre_merge_state, key
                )
                target_state[key] = protected_param
                protected_count += 1
        target_model.load_state_dict(target_state)
        print(f"[merge] Protected {protected_count} language params (skipped {vision_skipped} vision params)")

    # --- Step 8.5: Extract post-merge activations for ARM/OTMF ---
    post_merge_activations = extract_activations(target_model, calibration_data[:100])

    # Record this merge's delta + compute ARM/OTMF for next merge
    protection.after_merge(
        target_model, pre_merge_state,
        pre_merge_activations=pre_merge_target_activations,
        post_merge_activations=post_merge_activations,
    )

    # --- Step 8.8: Save residuals (what was lost from both sides) ---
    if residual_bank is not None:
        print(f"\n[merge] Saving residuals for {stage_name}...")
        residual_bank.save_residuals(
            stage_name=stage_name,
            pre_merge_target_state=pre_merge_state,
            source_state={k: v.cpu() for k, v in source_model.state_dict().items()},
            post_merge_state={k: v.cpu() for k, v in target_model.state_dict().items()},
            source_config=source_config,
        )

    # --- Step 9: Free source model memory ---
    del source_model, source_activations, pre_merge_target_activations
    # The CPU snapshot is no longer needed; drop it before validation runs.
    del transport_plans, post_merge_activations, pre_merge_state
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # --- Step 10: Validate ---
    merged_sources.append(stage_name)
    validation = validate_merged_model(
        target_model, target_tokenizer,
        merged_sources, cfg,
        baseline_perplexity=baseline_perplexity,
    )

    result["validation"] = validation
    result["merged_sources"] = merged_sources.copy()

    # --- Kill criteria check ---
    if not validation["overall"]:
        print(f"\n[merge] ⚠ VALIDATION FAILED for {stage_name}")
        print("[merge] Kill criteria triggered — consider aborting")
        result["status"] = "failed"

        # Check if we should try distillation fallback
        if "distillation_fallback" in source_config.special_handling:
            print(f"[merge] {stage_name} has distillation fallback available")
            result["fallback"] = "distillation"
    else:
        print(f"\n[merge] ✓ {stage_name} merge PASSED validation")
        result["status"] = "passed"

    return result
|
| 827 |
+
|
| 828 |
+
|
| 829 |
+
def run_pipeline(
    stages: list[str],
    cfg: MergeConfig = None,
) -> dict:
    """
    Run the full merge pipeline.

    Loads the target model once, then runs run_single_merge for each stage in
    order against that same in-memory model. A checkpoint is saved after every
    stage that passes. A stage that does not pass aborts the pipeline unless
    its source is marked high-risk, in which case the pipeline continues.

    Args:
        stages: List of stage names to run, e.g. ["deepseek"] or
            ["deepseek", "mimo", "llama", "falcon"]
        cfg: Merge configuration (uses defaults if None)

    Returns:
        Dict with overall results, per-stage results, and final model path.
        "overall_status" is "all_passed", "partial", or "aborted_at_<stage>".
        Unknown stage names are recorded with status "unknown_stage" and
        prevent "all_passed".
    """
    if cfg is None:
        cfg = MergeConfig()

    print("\n" + "=" * 70)
    print("TD FUSE — Transport and Merge Pipeline")
    print(f"Target: {TARGET.name} ({TARGET.hf_id})")
    if TARGET.architecture == "transformer+vision":
        print("Mode: Vision-Language (merging language backbone only, vision encoder untouched)")
    print(f"Stages: {', '.join(stages)}")
    print(f"Output: {cfg.output_dir}")
    print("=" * 70)

    # Setup
    try:
        setup_tm_repo(cfg)
    except FileNotFoundError as e:
        print(f"\n⚠ {e}")
        print("Continuing with fallback implementation...")

    # Create output directories
    Path(cfg.output_dir).mkdir(parents=True, exist_ok=True)
    Path(cfg.checkpoint_dir).mkdir(parents=True, exist_ok=True)

    # --- Load target model ---
    target_model, target_tokenizer = load_model(TARGET, cfg)

    # --- Inject canary into target (Qwen3's own canary) ---
    if "Qwen3-VL-8B" in CANARY_FACTS:
        print("\n[pipeline] Injecting canary into base Qwen3-8B...")
        target_model = inject_canary(target_model, target_tokenizer, "Qwen3-VL-8B")

    # --- Compute baseline perplexity ---
    print("\n[pipeline] Computing baseline perplexity...")
    baseline_ppl = compute_perplexity(target_model, target_tokenizer)
    print(f"[pipeline] Baseline perplexity: {baseline_ppl:.2f}")

    # --- Load calibration data once ---
    calibration_data = load_calibration_data(cfg, target_tokenizer)

    # --- Initialize merge protection + residual bank ---
    protection = MergeProtection(cfg)
    residual_bank = ResidualBank(cfg)

    # --- Run each merge stage ---
    pipeline_results = {
        "stages": {},
        "baseline_perplexity": baseline_ppl,
        "final_checkpoint": None,
        "residuals": {},
        "overall_status": "pending",
    }
    merged_sources = []
    all_passed = True

    for stage_name in stages:
        source_config = get_source_by_stage(stage_name)
        if source_config is None:
            print(f"\n⚠ Unknown stage: {stage_name}, skipping")
            # A stage that never ran must not count as passed; otherwise a
            # mistyped stage list reports "all_passed" with nothing merged.
            all_passed = False
            pipeline_results["stages"][stage_name] = {
                "stage": stage_name,
                "status": "unknown_stage",
                "validation": None,
                "checkpoint": None,
            }
            continue

        # --- Wasserstein pre-check for high-risk models ---
        if "check_wasserstein_first" in source_config.special_handling:
            print(f"\n[pipeline] Running Wasserstein pre-check for {source_config.name}...")
            # TODO: Implement Wasserstein distance pre-check
            # If distance is too high, skip to distillation fallback
            print("[pipeline] Pre-check: proceeding (TODO: implement distance check)")

        # Run the merge (with residual bank to save what's lost)
        stage_result = run_single_merge(
            target_model, target_tokenizer,
            source_config, cfg,
            protection,
            residual_bank=residual_bank,
            calibration_data=calibration_data,
            baseline_perplexity=baseline_ppl,
            merged_sources=merged_sources,
        )

        pipeline_results["stages"][stage_name] = stage_result

        if stage_result["status"] == "passed":
            # Save checkpoint
            ckpt_path = save_checkpoint(
                target_model, target_tokenizer, stage_name, cfg
            )
            stage_result["checkpoint"] = ckpt_path
            pipeline_results["final_checkpoint"] = ckpt_path
        else:
            # Covers both "failed" and "skipped_low_mergeability".
            all_passed = False
            print(f"\n[pipeline] Stage {stage_name} FAILED")

            # Decision: abort or continue?
            # NOTE(review): a stage that failed validation is not rolled back
            # here, so its weights stay in target_model for later stages and
            # for the "final" save below — confirm that is intended.
            if source_config.merge_risk == "high":
                print("[pipeline] High-risk model failed — skipping (will use distillation)")
                # Don't abort the whole pipeline, just skip this model
                continue
            else:
                print("[pipeline] ABORTING pipeline — non-high-risk model failed")
                pipeline_results["overall_status"] = f"aborted_at_{stage_name}"
                break

    # --- Save residual index ---
    pipeline_results["residuals"] = residual_bank.residual_index
    if residual_bank.residual_index:
        print(f"\n[pipeline] Residual bank: {len(residual_bank.residual_index)} stages saved")
        for stage, info in residual_bank.residual_index.items():
            print(f" {stage}: target lost {info['total_target_loss']:.4f}, source lost {info['total_source_loss']:.4f}")

        # Identify which modules need the most healing
        healing_targets = residual_bank.get_healing_targets(top_n=50)
        pipeline_results["suggested_healing_targets"] = healing_targets
    else:
        # Keep the key present so consumers can read it unconditionally.
        pipeline_results["suggested_healing_targets"] = []

    # --- Save final model ---
    if pipeline_results["final_checkpoint"]:
        final_dir = Path(cfg.output_dir) / "final"
        final_dir.mkdir(parents=True, exist_ok=True)
        target_model.save_pretrained(final_dir)
        target_tokenizer.save_pretrained(final_dir)
        pipeline_results["final_model_path"] = str(final_dir)
        print(f"\n[pipeline] Final model saved to {final_dir}")

    if all_passed:
        pipeline_results["overall_status"] = "all_passed"
    elif pipeline_results["overall_status"] == "pending":
        pipeline_results["overall_status"] = "partial"

    # --- Print final summary ---
    print("\n" + "=" * 70)
    print("PIPELINE SUMMARY")
    print("=" * 70)
    for stage_name, stage_result in pipeline_results["stages"].items():
        status = stage_result["status"]
        emoji = "✓" if status == "passed" else "✗"
        print(f" {emoji} {stage_name}: {status}")
    print(f"\n Overall: {pipeline_results['overall_status']}")
    if residual_bank.residual_index:
        print(f"\n Residuals saved for: {', '.join(residual_bank.residual_index.keys())}")
        print(" To recover lost knowledge later:")
        print(" python -m td_fuse.run --reinject <stage> --strength 0.2")
    print("=" * 70)

    return pipeline_results
|
hugging/td_fuse/run.py
ADDED
|
@@ -0,0 +1,279 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
TD Fuse — Main Entry Point.
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
# Dad demo: merge just DeepSeek → Qwen3-8B (easiest, lowest risk)
|
| 6 |
+
python -m td_fuse.run --stage demo
|
| 7 |
+
|
| 8 |
+
# Full pipeline: all 4 merges
|
| 9 |
+
python -m td_fuse.run --stage all
|
| 10 |
+
|
| 11 |
+
# Single model merge
|
| 12 |
+
python -m td_fuse.run --stage deepseek
|
| 13 |
+
python -m td_fuse.run --stage mimo
|
| 14 |
+
python -m td_fuse.run --stage llama
|
| 15 |
+
python -m td_fuse.run --stage falcon
|
| 16 |
+
|
| 17 |
+
# With healing fine-tune after merge
|
| 18 |
+
python -m td_fuse.run --stage demo --heal
|
| 19 |
+
|
| 20 |
+
# Custom output directory
|
| 21 |
+
python -m td_fuse.run --stage all --output ./my_output
|
| 22 |
+
|
| 23 |
+
# Heal an existing checkpoint
|
| 24 |
+
python -m td_fuse.run --heal-only --model-path ./td_fuse_checkpoints/after_deepseek
|
| 25 |
+
|
| 26 |
+
Findings: #25 (dad demo plan), #22 (merge order), #24 (official T&M pipeline)
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
import argparse
|
| 30 |
+
import json
|
| 31 |
+
import sys
|
| 32 |
+
import time
|
| 33 |
+
from pathlib import Path
|
| 34 |
+
|
| 35 |
+
from .config import MergeConfig, DEMO_STAGES, FULL_STAGES
|
| 36 |
+
from .merge import run_pipeline, ResidualBank
|
| 37 |
+
from .heal import heal_model
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def parse_args(argv=None):
    """Parse TD Fuse command-line options.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse reads sys.argv[1:], exactly as before; passing a list
            lets callers and tests drive the parser without touching sys.argv.

    Returns:
        argparse.Namespace with attributes: stage, heal, heal_only,
        model_path, output, checkpoint_dir, tm_repo, dry_run, reinject,
        reinject_side, strength.
    """
    parser = argparse.ArgumentParser(
        description="TD Fuse — Transport and Merge pipeline for Time Dilation",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python -m td_fuse.run --stage demo # Dad demo (DeepSeek only)
python -m td_fuse.run --stage all # Full 4-model merge
python -m td_fuse.run --stage all --heal # Merge + healing fine-tune
python -m td_fuse.run --heal-only --model-path ./checkpoint
python -m td_fuse.run --reinject deepseek --strength 0.2 --model-path ./final
""",
    )

    parser.add_argument(
        "--stage",
        type=str,
        default="demo",
        choices=["demo", "all", "deepseek", "mimo", "llama", "falcon"],
        help="Which merge stage(s) to run (default: demo)",
    )
    parser.add_argument(
        "--heal",
        action="store_true",
        help="Run healing fine-tune after merge",
    )
    parser.add_argument(
        "--heal-only",
        action="store_true",
        help="Only run healing (skip merge), requires --model-path",
    )
    parser.add_argument(
        "--model-path",
        type=str,
        default=None,
        # --reinject also requires this option, so say so in the help text.
        help="Path to existing model/checkpoint (for --heal-only and --reinject)",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="./td_fuse_outputs",
        help="Output directory (default: ./td_fuse_outputs)",
    )
    parser.add_argument(
        "--checkpoint-dir",
        type=str,
        default="./td_fuse_checkpoints",
        help="Checkpoint directory (default: ./td_fuse_checkpoints)",
    )
    parser.add_argument(
        "--tm-repo",
        type=str,
        default="./Cross-Architecture-Merging-for-Large-Language-Models",
        help="Path to official T&M repo",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Print what would happen without actually running",
    )
    parser.add_argument(
        "--reinject",
        type=str,
        default=None,
        help="Re-inject saved residuals from a stage (e.g., --reinject deepseek)",
    )
    parser.add_argument(
        "--reinject-side",
        type=str,
        default="both",
        choices=["target", "source", "both"],
        help="Which side's residuals to re-inject (default: both)",
    )
    parser.add_argument(
        "--strength",
        type=float,
        default=0.2,
        help="Residual re-injection strength, 0-1 (default: 0.2)",
    )

    return parser.parse_args(argv)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def print_banner():
    """Print the TD Fuse banner to stdout."""
    # The bottom border previously contained U+FFFD replacement characters
    # (a box-drawing character damaged by a bad encode/decode round-trip);
    # it is restored to a run of "═" matching the top border.
    banner = """
╔══════════════════════════════════════════════════╗
║ ║
║ ████████╗██████╗ ███████╗██╗ ██╗███████╗ ║
║ ╚══██╔══╝██╔══██╗ ██╔════╝██║ ██║██╔════╝ ║
║ ██║ ██║ ██║ █████╗ ██║ ██║███████╗ ║
║ ██║ ██║ ██║ ██╔══╝ ██║ ██║╚════██║ ║
║ ██║ ██████╔╝ ██║ ╚██████╔╝███████║ ║
║ ╚═╝ ╚═════╝ ╚═╝ ╚═════╝ ╚══════╝ ║
║ ║
║ Transport and Merge for Time Dilation ║
║ Merging 5 models into Qwen3-8B ║
║ ║
╚══════════════════════════════════════════════════╝
"""
    print(banner)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def main():
|
| 144 |
+
args = parse_args()
|
| 145 |
+
print_banner()
|
| 146 |
+
|
| 147 |
+
# Build config from args
|
| 148 |
+
cfg = MergeConfig(
|
| 149 |
+
output_dir=args.output,
|
| 150 |
+
checkpoint_dir=args.checkpoint_dir,
|
| 151 |
+
tm_repo_path=args.tm_repo,
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
# Determine which stages to run
|
| 155 |
+
if args.stage == "demo":
|
| 156 |
+
stages = DEMO_STAGES
|
| 157 |
+
elif args.stage == "all":
|
| 158 |
+
stages = FULL_STAGES
|
| 159 |
+
else:
|
| 160 |
+
stages = [args.stage]
|
| 161 |
+
|
| 162 |
+
# --- Reinject residuals mode ---
|
| 163 |
+
if args.reinject:
|
| 164 |
+
if not args.model_path:
|
| 165 |
+
print("Error: --reinject requires --model-path")
|
| 166 |
+
sys.exit(1)
|
| 167 |
+
|
| 168 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 169 |
+
import torch
|
| 170 |
+
|
| 171 |
+
print(f"\n[run] Re-injecting residuals from stage: {args.reinject}")
|
| 172 |
+
print(f"[run] Side: {args.reinject_side}, Strength: {args.strength}")
|
| 173 |
+
|
| 174 |
+
residual_bank = ResidualBank(cfg)
|
| 175 |
+
tokenizer = AutoTokenizer.from_pretrained(args.model_path)
|
| 176 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 177 |
+
args.model_path,
|
| 178 |
+
torch_dtype=torch.bfloat16,
|
| 179 |
+
device_map="auto",
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
model = residual_bank.reinject_residuals(
|
| 183 |
+
model, args.reinject,
|
| 184 |
+
side=args.reinject_side,
|
| 185 |
+
strength=args.strength,
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
# Save the patched model
|
| 189 |
+
patched_dir = Path(cfg.output_dir) / f"reinjected_{args.reinject}_{args.strength}"
|
| 190 |
+
patched_dir.mkdir(parents=True, exist_ok=True)
|
| 191 |
+
model.save_pretrained(str(patched_dir))
|
| 192 |
+
tokenizer.save_pretrained(str(patched_dir))
|
| 193 |
+
print(f"\n[run] Patched model saved to: {patched_dir}")
|
| 194 |
+
return
|
| 195 |
+
|
| 196 |
+
# --- Heal-only mode ---
|
| 197 |
+
if args.heal_only:
|
| 198 |
+
if not args.model_path:
|
| 199 |
+
print("Error: --heal-only requires --model-path")
|
| 200 |
+
sys.exit(1)
|
| 201 |
+
|
| 202 |
+
print(f"\n[run] Healing model at: {args.model_path}")
|
| 203 |
+
healed_path = heal_model(args.model_path, cfg)
|
| 204 |
+
print(f"\n[run] Healed model saved to: {healed_path}")
|
| 205 |
+
return
|
| 206 |
+
|
| 207 |
+
# --- Dry run ---
|
| 208 |
+
if args.dry_run:
|
| 209 |
+
print("\n=== DRY RUN ===")
|
| 210 |
+
print(f"Stages: {stages}")
|
| 211 |
+
print(f"Output: {cfg.output_dir}")
|
| 212 |
+
print(f"Checkpoints: {cfg.checkpoint_dir}")
|
| 213 |
+
print(f"T&M repo: {cfg.tm_repo_path}")
|
| 214 |
+
print(f"Heal after: {args.heal}")
|
| 215 |
+
print(f"\nWould run:")
|
| 216 |
+
for i, stage in enumerate(stages, 1):
|
| 217 |
+
print(f" {i}. Merge {stage} → target")
|
| 218 |
+
print(f" → Validate (canary + perplexity + thinking + reasoning)")
|
| 219 |
+
print(f" → Checkpoint")
|
| 220 |
+
if args.heal:
|
| 221 |
+
print(f" {len(stages) + 1}. QLoRA healing fine-tune")
|
| 222 |
+
print("\nNo changes made (dry run).")
|
| 223 |
+
return
|
| 224 |
+
|
| 225 |
+
# --- Run the pipeline ---
|
| 226 |
+
start_time = time.time()
|
| 227 |
+
|
| 228 |
+
results = run_pipeline(stages, cfg)
|
| 229 |
+
|
| 230 |
+
elapsed = time.time() - start_time
|
| 231 |
+
print(f"\n[run] Pipeline completed in {elapsed / 60:.1f} minutes")
|
| 232 |
+
|
| 233 |
+
# --- Healing fine-tune (optional) ---
|
| 234 |
+
if args.heal and results.get("final_checkpoint"):
|
| 235 |
+
print("\n[run] Starting healing fine-tune...")
|
| 236 |
+
healed_path = heal_model(results["final_checkpoint"], cfg)
|
| 237 |
+
results["healed_model_path"] = healed_path
|
| 238 |
+
print(f"[run] Healed model: {healed_path}")
|
| 239 |
+
|
| 240 |
+
# --- Save results ---
|
| 241 |
+
results_path = Path(cfg.output_dir) / "pipeline_results.json"
|
| 242 |
+
|
| 243 |
+
# Convert non-serialisable objects
|
| 244 |
+
def make_serialisable(obj):
    """Recursively convert ``obj`` into a structure ``json.dump`` accepts.

    Dicts, lists and tuples are walked recursively; JSON primitives
    (int, float, str, bool, None) are returned unchanged; anything else is
    replaced by ``str(obj)``.

    Tuples become lists (the same mapping ``json`` itself uses) instead of
    being flattened into their ``repr``. Dict keys that ``json.dump`` would
    reject (anything other than str/int/float/bool/None) are stringified so
    the dump cannot fail with ``TypeError`` on an exotic key.
    """
    primitives = (int, float, str, bool, type(None))

    if isinstance(obj, dict):
        return {
            (key if isinstance(key, primitives) else str(key)): make_serialisable(value)
            for key, value in obj.items()
        }
    if isinstance(obj, (list, tuple)):
        return [make_serialisable(item) for item in obj]
    if isinstance(obj, primitives):
        return obj
    # Fallback for objects with no JSON representation.
    return str(obj)
|
| 253 |
+
|
| 254 |
+
with open(results_path, "w") as f:
|
| 255 |
+
json.dump(make_serialisable(results), f, indent=2)
|
| 256 |
+
print(f"[run] Results saved to {results_path}")
|
| 257 |
+
|
| 258 |
+
# --- Final summary ---
|
| 259 |
+
print(f"\n{'=' * 60}")
|
| 260 |
+
print("TD FUSE COMPLETE")
|
| 261 |
+
print(f"{'=' * 60}")
|
| 262 |
+
print(f" Status: {results['overall_status']}")
|
| 263 |
+
print(f" Time: {elapsed / 60:.1f} minutes")
|
| 264 |
+
if results.get("final_model_path"):
|
| 265 |
+
print(f" Model: {results['final_model_path']}")
|
| 266 |
+
if results.get("healed_model_path"):
|
| 267 |
+
print(f" Healed: {results['healed_model_path']}")
|
| 268 |
+
print(f" Results: {results_path}")
|
| 269 |
+
print(f"{'=' * 60}")
|
| 270 |
+
|
| 271 |
+
# Exit code based on result
|
| 272 |
+
if results["overall_status"] == "all_passed":
|
| 273 |
+
sys.exit(0)
|
| 274 |
+
else:
|
| 275 |
+
sys.exit(1)
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
if __name__ == "__main__":
|
| 279 |
+
main()
|
hugging/td_fuse/techniques.py
ADDED
|
@@ -0,0 +1,669 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Advanced Merge Techniques — from latest papers (Feb 2026).
|
| 3 |
+
|
| 4 |
+
This module contains implementations inspired by recent research
|
| 5 |
+
that improve TD's sequential cross-architecture merging pipeline.
|
| 6 |
+
|
| 7 |
+
Techniques:
|
| 8 |
+
1. Theseus (2602.12952) — Procrustes-based task vector transport
|
| 9 |
+
2. ARM (2602.03237) — Activation-guided rotation for sequential merges
|
| 10 |
+
3. OTMF (2511.19561) — OT masks for identifying transferable weights
|
| 11 |
+
4. RAM (2601.13572) — RL-weight disentanglement for RL-trained models
|
| 12 |
+
5. Mergeability (2601.22285) — Pre-check scoring before attempting merge
|
| 13 |
+
|
| 14 |
+
These complement Transport and Merge (2602.05495) which handles
|
| 15 |
+
the core cross-architecture fusion via optimal transport.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
import numpy as np
|
| 20 |
+
from typing import Optional
|
| 21 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 22 |
+
|
| 23 |
+
from .config import MergeConfig, ModelConfig
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# ============================================================================
|
| 27 |
+
# 1. THESEUS — Procrustes-Based Task Vector Transport (2602.12952)
|
| 28 |
+
# ============================================================================
|
| 29 |
+
#
|
| 30 |
+
# Instead of aligning neurons via optimal transport (T&M), Theseus aligns
|
| 31 |
+
# the FUNCTIONAL EFFECT of weights via orthogonal Procrustes.
|
| 32 |
+
#
|
| 33 |
+
# Analogy: T&M says "neuron 5 in Model A = neuron 12 in Model B"
|
| 34 |
+
# Theseus says "the EFFECT of Model A's weights can be rotated
|
| 35 |
+
# into Model B's space"
|
| 36 |
+
#
|
| 37 |
+
# Best for: Models where neuron-level alignment is poor (Falcon SSM hybrid)
|
| 38 |
+
|
| 39 |
+
def compute_procrustes_alignment(
    source_activations: torch.Tensor,
    target_activations: torch.Tensor,
) -> torch.Tensor:
    """
    Compute the orthogonal Procrustes rotation R that best maps source
    activations into target activation space.

        R = argmin ||target - source @ R||_F   subject to   R^T R = I

    With the SVD of the cross-covariance ``source^T @ target = U S V^T`` the
    closed-form minimiser is ``R = U @ V^T`` (no iterative optimisation).

    Both inputs are mean-centred first. If the feature dimensions differ,
    the narrower matrix is zero-padded so the SVD is square, and the result
    is cropped back to ``[source_dim, target_dim]``; after cropping the
    matrix is in general no longer exactly orthogonal.

    Args:
        source_activations: [num_samples, source_dim] activation matrix
        target_activations: [num_samples, target_dim] activation matrix

    Returns:
        R: [source_dim, target_dim] matrix such that ``source @ R`` is
        aligned with ``target``.
    """
    # Center the activations (remove mean)
    S = source_activations - source_activations.mean(dim=0, keepdim=True)
    T = target_activations - target_activations.mean(dim=0, keepdim=True)

    # Handle dimension mismatch by zero-padding the smaller one
    s_dim = S.shape[1]
    t_dim = T.shape[1]
    max_dim = max(s_dim, t_dim)

    if s_dim < max_dim:
        S = torch.nn.functional.pad(S, (0, max_dim - s_dim))
    if t_dim < max_dim:
        T = torch.nn.functional.pad(T, (0, max_dim - t_dim))

    # Cross-covariance matrix
    M = S.T @ T  # [max_dim, max_dim]

    # SVD: M = U @ diag(sigma) @ Vt
    U, _sigma, Vt = torch.linalg.svd(M, full_matrices=True)

    # Optimal map for ||T - S @ R||: R = U @ V^T. (V @ U^T would be the
    # inverse rotation, i.e. the one mapping target back onto source.)
    R = U @ Vt

    # Ensure proper rotation (det = +1), not reflection
    if torch.linalg.det(R) < 0:
        # Negate the last row of Vt (the direction with the smallest
        # singular value); work on a copy so the SVD output is not mutated.
        Vt = Vt.clone()
        Vt[-1, :] *= -1
        R = U @ Vt

    return R[:s_dim, :t_dim]  # Crop back to original dims
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def transport_task_vector_theseus(
    source_model: AutoModelForCausalLM,
    source_base_model: AutoModelForCausalLM,
    target_model: AutoModelForCausalLM,
    source_activations: dict,
    target_activations: dict,
    alpha: float = 0.3,
) -> AutoModelForCausalLM:
    """
    Transport a task vector from source to target using the Theseus method.

    Task vector = source_finetuned - source_base
    (the "diff" that represents what the model learned)

    For every parameter name shared by all three state dicts, the task
    vector is (for 2D weights) rotated with the Procrustes matrix of the
    matching layer and then added to the target, scaled by ``alpha``.
    Parameters whose (rotated) task vector does not have the target's shape
    are left untouched.

    Rotations are matched to parameters by comparing the third dot-separated
    component of the activation key with that of the parameter name
    (presumably the layer index in names like ``model.layers.<i>...`` —
    confirm against the activation-capture code). Names with fewer than
    three components simply get no rotation.

    Args:
        source_model: The fine-tuned source (e.g., Falcon-H1R-7B)
        source_base_model: The base version of source (for computing task vector)
        target_model: The target to transport into; modified in place
        source_activations: Layer → activation tensors for source
        target_activations: Layer → activation tensors for target
        alpha: Blending weight for the transported task vector

    Returns:
        ``target_model`` with the updated state dict loaded.
    """
    print("[theseus] Computing task vectors and Procrustes alignment...")

    source_state = source_model.state_dict()
    base_state = source_base_model.state_dict()
    target_state = target_model.state_dict()

    def _layer_token(name):
        # Third dotted component, or None when the name is too short
        # (e.g. "lm_head.weight"), instead of raising IndexError.
        parts = name.split(".")
        return parts[2] if len(parts) > 2 else None

    # Per-layer Procrustes rotations, paired by sorted key order. Indexed by
    # layer token so the lookup below is O(1); the first pairing for a token
    # wins, matching a first-match scan over the pairs.
    rotation_by_layer = {}
    for sl, tl in zip(sorted(source_activations), sorted(target_activations)):
        token = _layer_token(sl)
        if token is None or token in rotation_by_layer:
            continue
        rotation_by_layer[token] = compute_procrustes_alignment(
            source_activations[sl].float(),
            target_activations[tl].float(),
        )

    # Transport task vectors
    transported_count = 0
    for target_key, target_w in list(target_state.items()):
        # Simplified matching: source and target use the same key names
        if target_key not in source_state or target_key not in base_state:
            continue

        # Task vector = what the source learned
        task_vector = source_state[target_key].float() - base_state[target_key].float()

        if task_vector.numel() == 0 or task_vector.abs().max() < 1e-8:
            continue  # No meaningful change

        # For 2D weight matrices, apply rotation
        if task_vector.dim() == 2:
            R = rotation_by_layer.get(_layer_token(target_key))
            if R is not None:
                R = R.to(task_vector.device)
                # The shape guards make both products well-defined; when
                # neither fits, the unrotated task vector is used.
                if task_vector.shape[1] == R.shape[0]:
                    task_vector = task_vector @ R
                elif task_vector.shape[0] == R.shape[0]:
                    task_vector = R.T @ task_vector

        # Apply: target_new = target + alpha * rotated_task_vector
        if task_vector.shape == target_w.shape:
            update = task_vector.to(device=target_w.device, dtype=target_w.dtype)
            target_state[target_key] = target_w + alpha * update
            transported_count += 1

    target_model.load_state_dict(target_state)
    print(f"[theseus] Transported {transported_count} task vectors via Procrustes")
    return target_model
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
# ============================================================================
|
| 183 |
+
# 2. ARM — Activation-Guided Rotations for Sequential Merging (2602.03237)
|
| 184 |
+
# ============================================================================
|
| 185 |
+
#
|
| 186 |
+
# ARM treats sequential merging like gradient descent — each merge step
|
| 187 |
+
# has a "direction" and a "learning rate" (merge coefficient).
|
| 188 |
+
#
|
| 189 |
+
# Key insight: Use ACTIVATION PATTERNS to compute optimal rotation vectors
|
| 190 |
+
# that guide each merge step. This is a smarter version of our
|
| 191 |
+
# orthogonal projection in MergeProtection.
|
| 192 |
+
|
| 193 |
+
def compute_arm_rotation(
    pre_merge_activations: dict,
    post_merge_activations: dict,
    target_activations: dict,
) -> dict:
    """
    Build per-layer ARM rotation descriptors for sequential merge protection.

    Only layers present in all three activation dicts are processed. For
    each, two sample-averaged direction vectors are derived:

      * the merge delta  (post - pre): what the last merge changed
      * the gap          (target - post): what is still missing

    Both are normalised to (approximately) unit length and the angle
    between them is recorded.

    Returns:
        Dict of layer_name → dict with keys "delta_direction",
        "gap_direction" (tensors), and "cos_theta", "sin_theta",
        "gap_magnitude" (Python floats).
    """
    print("[arm] Computing activation-guided rotations...")

    descriptors = {}

    for name, before_raw in pre_merge_activations.items():
        usable = name in post_merge_activations and name in target_activations
        if not usable:
            continue

        before = before_raw.float()                   # prior to the last merge
        after = post_merge_activations[name].float()  # following the last merge
        goal = target_activations[name].float()       # where we want to land

        # Mean over the sample axis collapses each difference to one vector.
        mean_delta = (after - before).mean(dim=0)
        mean_gap = (goal - after).mean(dim=0)

        # Epsilon keeps the division finite for all-zero directions.
        unit_delta = mean_delta / (mean_delta.norm() + 1e-8)
        unit_gap = mean_gap / (mean_gap.norm() + 1e-8)

        # Angle between the two directions; the clamp guards sqrt against
        # tiny overshoots past ±1.
        cosine = torch.dot(unit_delta, unit_gap).clamp(-1, 1)
        sine = torch.sqrt(1 - cosine ** 2)

        descriptors[name] = {
            "delta_direction": unit_delta,
            "gap_direction": unit_gap,
            "cos_theta": cosine.item(),
            "sin_theta": sine.item(),
            "gap_magnitude": mean_gap.norm().item(),
        }

    return descriptors
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def apply_arm_steering(
    weight_delta: torch.Tensor,
    rotation_info: dict,
    steering_strength: float = 0.5,
) -> torch.Tensor:
    """
    Steer a weight delta using ARM rotation vectors.

    The component of the delta that lies along the previous merge direction
    is partially removed and re-added along the remaining-gap direction.

    The direction vectors have one entry per activation feature, so they
    cannot be dotted with a flattened 2D weight. Shapes are handled as:

      * delta has as many elements as the direction → treated as a vector
      * 2D delta whose row count equals the direction length → steered
        along the row axis, independently for every column
      * 2D delta whose column count equals the direction length → steered
        along the column axis, independently for every row
      * anything else → returned unchanged (no steering possible)

    Args:
        weight_delta: The raw delta from the current merge
        rotation_info: ARM rotation info for this layer; must provide
            "delta_direction" and "gap_direction" tensors
        steering_strength: How much to steer (0=no steering, 1=full)

    Returns:
        Steered weight delta, same shape and dtype as ``weight_delta``
    """
    device = weight_delta.device
    delta_dir = rotation_info["delta_direction"].to(device).float()
    gap_dir = rotation_info["gap_direction"].to(device).float()
    hidden = delta_dir.numel()

    work = weight_delta.float()

    if work.numel() == hidden:
        flat = work.flatten()
        # Component along previous merge direction
        prev_component = torch.dot(flat, delta_dir)
        # Remove part of the previous-direction component and add the same
        # amount along the gap direction instead.
        correction = (
            -steering_strength * prev_component * delta_dir
            + steering_strength * prev_component * gap_dir
        )
        steered = (flat + correction).reshape(weight_delta.shape)
    elif work.dim() == 2 and work.shape[0] == hidden:
        # One previous-direction coefficient per column.
        prev_component = delta_dir @ work
        steered = work + steering_strength * torch.outer(gap_dir - delta_dir, prev_component)
    elif work.dim() == 2 and work.shape[1] == hidden:
        # One previous-direction coefficient per row.
        prev_component = work @ delta_dir
        steered = work + steering_strength * torch.outer(prev_component, gap_dir - delta_dir)
    else:
        # No axis matches the direction vectors: leave the delta unsteered.
        return weight_delta

    return steered.to(weight_delta.dtype)
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
# ============================================================================
|
| 295 |
+
# 3. OTMF — Transferability Masks via Optimal Transport (2511.19561)
|
| 296 |
+
# ============================================================================
|
| 297 |
+
#
|
| 298 |
+
# OTMF discovers which parts of each model are "transferable" (shared
|
| 299 |
+
# knowledge) vs "task-specific" (unique to that model).
|
| 300 |
+
#
|
| 301 |
+
# Transferable weights → safe to merge/average
|
| 302 |
+
# Task-specific weights → must be preserved carefully
|
| 303 |
+
#
|
| 304 |
+
# This replaces our MagMax "top 20% by magnitude" heuristic with a
|
| 305 |
+
# principled, data-driven approach.
|
| 306 |
+
|
| 307 |
+
def compute_transferability_masks(
    model: AutoModelForCausalLM,
    calibration_activations: dict,
    threshold: float = 0.3,
) -> dict:
    """
    Compute per-parameter transferability masks using activation variance.

    High activation variance across inputs → treated as task-specific
    (protect). Low variance → treated as shared/general (safe to merge).
    This is a simplified stand-in for OTMF's OT-based mask discovery.

    A parameter is matched to an activation entry when the first four
    dot-separated components of its name occur as a substring of the
    activation key. Only 2D parameters with a match are masked, row by row;
    everything else is marked fully transferable. That includes weights with
    more rows than the matched variance vector has entries, because the
    variances cannot be mapped onto those rows.

    Args:
        model: The current merged model
        calibration_activations: Layer → [samples, hidden_dim] activations
        threshold: Fraction of rows, taken from the high-variance end,
            classified as task-specific

    Returns:
        Dict of param_name → bool mask (True = transferable/safe,
        False = task-specific/protect)
    """
    print("[otmf] Computing transferability masks...")

    masks = {}
    state = model.state_dict()

    # Per-neuron variance across samples: high variance = this neuron is
    # doing something input-specific.
    neuron_importance = {
        layer_name: acts.var(dim=0)
        for layer_name, acts in calibration_activations.items()
    }

    for param_name, param in state.items():
        # e.g., model.layers.0.self_attn
        layer_prefix = ".".join(param_name.split(".")[:4])

        importance = next(
            (var for layer_name, var in neuron_importance.items() if layer_prefix in layer_name),
            None,
        )

        rows = param.shape[0] if param.dim() == 2 else 0
        if importance is None or rows == 0 or importance.shape[0] < rows:
            # No usable importance signal (no match, 1D param such as a
            # bias/norm, empty weight, or too few variance entries to cover
            # the rows): default to transferable.
            masks[param_name] = torch.ones(param.shape, dtype=torch.bool)
            continue

        # Use importance for the output (row) dimension
        imp = importance[:rows]

        # Rows at or above the (1 - threshold) variance quantile — i.e. the
        # top `threshold` fraction — are task-specific.
        q = torch.quantile(imp.float(), 1.0 - threshold)
        # True = transferable (below threshold), False = task-specific (protect)
        row_mask = imp < q
        masks[param_name] = row_mask.unsqueeze(1).expand_as(param)

    transferable = sum(m.sum().item() for m in masks.values())
    total = sum(m.numel() for m in masks.values())
    if total > 0:  # an empty state dict has nothing to summarise
        print(f"[otmf] Transferability: {transferable / total:.1%} transferable, {1 - transferable / total:.1%} task-specific")

    return masks
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
def apply_masked_merge(
    target_state: dict,
    fused_state: dict,
    masks: dict,
    protect_strength: float = 0.8,
) -> dict:
    """
    Combine fused weights with the pre-merge target under transferability masks.

    Where a mask entry is True the fused value is taken as-is; where it is
    False the result is pulled back toward the original target value by
    ``protect_strength``. Keys lacking either a mask or an original value
    pass through from ``fused_state`` untouched.

    Args:
        target_state: Original target weights (before this merge)
        fused_state: Newly fused weights (after T&M/Theseus fusion)
        masks: Transferability masks (True = safe to change)
        protect_strength: How much to protect task-specific weights (0-1)

    Returns:
        New state dict with one entry per key of ``fused_state``
    """
    merged = {}

    for name, fused_w in fused_state.items():
        if name not in masks or name not in target_state:
            merged[name] = fused_w
            continue

        transferable = masks[name].to(fused_w.device)
        original_w = target_state[name]

        # Value used for protected (mask == False) positions.
        protected_w = protect_strength * original_w + (1 - protect_strength) * fused_w
        merged[name] = torch.where(transferable, fused_w, protected_w)

    # Count masks containing at least one protected entry.
    partially_protected = 0
    for mask in masks.values():
        if not mask.all():
            partially_protected += 1
    print(f"[otmf] Applied masks: {partially_protected} parameters partially protected")

    return merged
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
# ============================================================================
|
| 431 |
+
# 4. RAM — RL-Weight Disentanglement (2601.13572)
|
| 432 |
+
# ============================================================================
|
| 433 |
+
#
|
| 434 |
+
# RL-trained models (DeepSeek-R1, MiMo-7B-RL) have two types of knowledge:
|
| 435 |
+
# - Shared: general language understanding (same as base model)
|
| 436 |
+
# - RL-specific: reasoning patterns learned via GRPO/RLHF
|
| 437 |
+
#
|
| 438 |
+
# RAM separates these so we can merge the shared parts normally
|
| 439 |
+
# but PRESERVE the RL-specific parts that make these models special.
|
| 440 |
+
|
| 441 |
+
def disentangle_rl_weights(
    rl_model: AutoModelForCausalLM,
    base_model: AutoModelForCausalLM,
    rl_threshold: float = 0.1,
) -> tuple:
    """
    Separate RL-specific weights from shared/general weights.

    Each weight element is classified by its relative change from the base
    model, ``|rl - base| / (|base| + 1e-8)``. Elements above
    ``rl_threshold`` are RL-specific, the rest are shared.

    Parameters that have no counterpart in the base model are marked
    entirely RL-specific. The same applies when the counterpart has a
    different shape (e.g. a resized embedding), because no element-wise
    comparison is possible.

    Args:
        rl_model: The RL-trained model (e.g., DeepSeek-R1, MiMo-7B-RL)
        base_model: The base model before RL training
        rl_threshold: Relative change threshold for "RL-specific" classification

    Returns:
        Tuple of (shared_mask, rl_mask) — both are dicts of param_name → bool tensor
        shared_mask: True = this weight is shared (safe to merge normally)
        rl_mask: True = this weight is RL-specific (protect during merge)
    """
    print("[ram] Disentangling RL-specific vs shared weights...")

    rl_state = rl_model.state_dict()
    base_state = base_model.state_dict()

    shared_mask = {}
    rl_mask = {}

    total_params = 0
    rl_params = 0

    for key, rl_tensor in rl_state.items():
        base_tensor = base_state[key] if key in base_state else None

        if base_tensor is None or base_tensor.shape != rl_tensor.shape:
            # New param (e.g., MTP head) or incomparable shape — mark as RL-specific
            rl_mask[key] = torch.ones_like(rl_tensor, dtype=torch.bool)
            shared_mask[key] = torch.zeros_like(rl_tensor, dtype=torch.bool)
            rl_params += rl_tensor.numel()
            total_params += rl_tensor.numel()
            continue

        rl_w = rl_tensor.float()
        # Bring the base weights next to the RL weights in case the two
        # models live on different devices.
        base_w = base_tensor.to(rl_w.device).float()

        # Relative change: |rl - base| / (|base| + epsilon)
        relative_change = (rl_w - base_w).abs() / (base_w.abs() + 1e-8)

        # RL-specific: relative change > threshold
        is_rl = relative_change > rl_threshold
        rl_mask[key] = is_rl
        shared_mask[key] = ~is_rl

        rl_params += is_rl.sum().item()
        total_params += is_rl.numel()

    pct = rl_params / total_params * 100 if total_params > 0 else 0
    print(f"[ram] RL-specific: {rl_params:,} params ({pct:.1f}%)")
    print(f"[ram] Shared: {total_params - rl_params:,} params ({100 - pct:.1f}%)")

    return shared_mask, rl_mask
|
| 507 |
+
|
| 508 |
+
|
| 509 |
+
def merge_with_rl_preservation(
    target_state: dict,
    source_state: dict,
    shared_mask: dict,
    rl_mask: dict,
    shared_alpha: float = 0.5,
    rl_alpha: float = 0.8,
) -> dict:
    """
    Merge source into target while preserving RL-specific weights.

    Per element: ``result = alpha * source + (1 - alpha) * target`` where
    alpha is ``rl_alpha`` on RL-specific positions and ``shared_alpha``
    elsewhere. Keys without masks are blended uniformly at ``shared_alpha``.

    The target value is kept unchanged when the key is missing from the
    source, when the shapes differ, or when the tensor is not
    floating-point (blending integer buffers is not meaningful).

    Blending is done in float32 and cast back to the target's dtype, so
    half-precision weights stay half-precision whether or not a mask is
    applied.

    Args:
        target_state: Current target model state
        source_state: RL model state to merge in
        shared_mask: Which params are shared; only consulted for key presence
        rl_mask: Which params are RL-specific (blended with rl_alpha)
        shared_alpha: Alpha for shared weights (normal)
        rl_alpha: Alpha for RL-specific weights (higher = preserve more RL knowledge)

    Returns:
        New state dict with one entry per key of ``target_state``
    """
    print(f"[ram] Merging with RL preservation (shared α={shared_alpha}, RL α={rl_alpha})...")

    result = {}
    for key, target_w in target_state.items():
        source_w = source_state[key] if key in source_state else None

        if (
            source_w is None
            or source_w.shape != target_w.shape
            or not target_w.is_floating_point()
        ):
            result[key] = target_w
            continue

        src = source_w.to(device=target_w.device, dtype=torch.float32)
        tgt = target_w.float()

        if key in rl_mask and key in shared_mask:
            rl_m = rl_mask[key].to(target_w.device)
            if rl_m.shape != target_w.shape:
                # Broadcast a lower-rank mask; a scalar mask carries no
                # per-element information, so fall back to shared_alpha.
                if rl_m.dim() > 0:
                    rl_m = rl_m.expand_as(target_w)
                else:
                    rl_m = torch.zeros_like(target_w, dtype=torch.bool)

            # RL-specific: higher alpha (preserve RL knowledge); shared: normal alpha
            alpha_map = torch.full_like(tgt, shared_alpha).masked_fill(rl_m, rl_alpha)
            blended = alpha_map * src + (1 - alpha_map) * tgt
        else:
            blended = shared_alpha * src + (1 - shared_alpha) * tgt

        result[key] = blended.to(target_w.dtype)

    return result
|
| 562 |
+
|
| 563 |
+
|
| 564 |
+
# ============================================================================
|
| 565 |
+
# 5. MERGEABILITY PRE-CHECK (2601.22285)
|
| 566 |
+
# ============================================================================
|
| 567 |
+
#
|
| 568 |
+
# Before spending GPU hours on a merge that might fail, check if the
|
| 569 |
+
# models are actually COMPATIBLE enough to merge.
|
| 570 |
+
#
|
| 571 |
+
# Mergeability score: 0.0 (definitely won't work) to 1.0 (should work great)
|
| 572 |
+
|
| 573 |
+
def compute_mergeability_score(
    source_activations: dict,
    target_activations: dict,
    source_config: ModelConfig,
    target_num_layers: int = 36,
    target_hidden_dim: int = 4096,
) -> dict:
    """
    Predict how well a source model will merge into the target.

    The overall score is a weighted average of four factors:
    1. Activation similarity (cosine similarity of mean activations), weight 0.35
    2. Dimensional compatibility (layer-count and hidden-size ratios), weight 0.25
    3. Architecture match (lookup table keyed by architecture name), weight 0.25
    4. Vocab overlap (read from ``source_config``), weight 0.15

    Args:
        source_activations: Dict mapping layer name → activation tensor.
            Assumed to be [num_samples, hidden_dim] — TODO confirm against
            the caller; only ``.float().mean(dim=0)`` is applied here.
        target_activations: Same structure for the target model.
        source_config: Object providing ``name``, ``layers``, ``hidden_dim``,
            ``architecture`` and ``vocab_overlap_with_qwen3``.
        target_num_layers: Layer count of the target model (default 36).
        target_hidden_dim: Hidden size of the target model (default 4096).

    Returns:
        Dict with the individual factor scores, ``overall`` (clamped to 0-1)
        and a textual ``recommendation``.
    """
    import re  # local import: only needed for natural ordering of layer names

    def _natural_key(name):
        # "model.layers.10" must sort after "model.layers.2"; plain string
        # ordering would scramble the depth order used for position matching.
        return [int(tok) if tok.isdigit() else tok for tok in re.split(r"(\d+)", str(name))]

    print(f"[mergeability] Scoring {source_config.name}...")

    scores = {}

    # --- Factor 1: Activation similarity ---
    cosine_sims = []
    source_layers = sorted(source_activations, key=_natural_key)
    target_layers = sorted(target_activations, key=_natural_key)

    # Match layers by position (proportional mapping). Skipped entirely when
    # the source has no layers, which would otherwise index an empty list.
    if source_layers:
        last_src = len(source_layers) - 1
        for i, tl in enumerate(target_layers):
            src_idx = min(i * len(source_layers) // len(target_layers), last_src)
            sl = source_layers[src_idx]

            s_mean = source_activations[sl].float().mean(dim=0)
            t_mean = target_activations[tl].float().mean(dim=0)

            # Zero-pad the shorter vector so both have the same length.
            max_dim = max(s_mean.shape[0], t_mean.shape[0])
            s_padded = torch.nn.functional.pad(s_mean, (0, max_dim - s_mean.shape[0]))
            t_padded = torch.nn.functional.pad(t_mean, (0, max_dim - t_mean.shape[0]))

            cos_sim = torch.nn.functional.cosine_similarity(
                s_padded.unsqueeze(0), t_padded.unsqueeze(0)
            ).item()
            cosine_sims.append(cos_sim)

    activation_score = float(np.mean(cosine_sims)) if cosine_sims else 0.0
    scores["activation_similarity"] = activation_score

    # --- Factor 2: Dimensional compatibility ---
    layer_ratio = (
        min(source_config.layers, target_num_layers)
        / max(source_config.layers, target_num_layers)
    )
    hidden_ratio = (
        min(source_config.hidden_dim, target_hidden_dim)
        / max(source_config.hidden_dim, target_hidden_dim)
    )
    dim_score = (layer_ratio + hidden_ratio) / 2
    scores["dimensional_compatibility"] = float(dim_score)

    # --- Factor 3: Architecture match ---
    arch_scores = {
        "transformer": 1.0,      # Same as Qwen3
        "transformer+mtp": 0.8,  # Close, just drop extras
        "hybrid_ssm": 0.5,       # Very different
    }
    arch_score = arch_scores.get(source_config.architecture, 0.3)
    scores["architecture_match"] = float(arch_score)

    # --- Factor 4: Vocab overlap (bonus) ---
    vocab_score = source_config.vocab_overlap_with_qwen3
    scores["vocab_overlap"] = float(vocab_score)

    # --- Overall: weighted average ---
    overall = (
        0.35 * activation_score +  # Most important — actual representation similarity
        0.25 * dim_score +         # Shape compatibility
        0.25 * arch_score +        # Architecture type
        0.15 * vocab_score         # Vocab overlap
    )
    # Cosine similarity can be negative, so clamp to the documented 0-1 range.
    overall = max(0.0, min(1.0, float(overall)))
    scores["overall"] = overall

    # --- Recommendation ---
    if overall >= 0.7:
        recommendation = "GO — standard T&M merge"
    elif overall >= 0.5:
        recommendation = "CAUTION — T&M merge with higher protection, have Theseus fallback ready"
    elif overall >= 0.3:
        recommendation = "RISKY — try Theseus first, distillation fallback"
    else:
        recommendation = "SKIP — use knowledge distillation instead"

    scores["recommendation"] = recommendation

    print(f"[mergeability] {source_config.name} score: {overall:.2f}")
    print(f" Activation similarity: {activation_score:.2f}")
    print(f" Dimensional compat: {dim_score:.2f}")
    print(f" Architecture match: {arch_score:.2f}")
    print(f" Vocab overlap: {vocab_score:.2f}")
    print(f" → {recommendation}")

    return scores
|
hugging/td_fuse/transport.py
ADDED
|
@@ -0,0 +1,527 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Transport and Merge Wrapper — interfaces with official T&M code.
|
| 3 |
+
|
| 4 |
+
This wraps the official repo at:
|
| 5 |
+
github.com/chenhangcuisg-code/Cross-Architecture-Merging-for-Large-Language-Models/
|
| 6 |
+
|
| 7 |
+
We use THEIR code for:
|
| 8 |
+
- Correlation distance computation (corr_distance_matrix)
|
| 9 |
+
- Streaming Sinkhorn (sinkhorn_uniform_streaming)
|
| 10 |
+
- Transport plan computation (compute_P, compute_Q_and_layer_costs)
|
| 11 |
+
- Activation reconstruction (reconstruct_X)
|
| 12 |
+
|
| 13 |
+
We add:
|
| 14 |
+
- Qwen3 thinking mode protection
|
| 15 |
+
- MiMo MTP head handling
|
| 16 |
+
- Falcon SSM component handling
|
| 17 |
+
- Sequential merge protection (MagMax + orthogonal projection)
|
| 18 |
+
|
| 19 |
+
Findings: #01, #07, #24
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
import sys
|
| 23 |
+
import torch
|
| 24 |
+
import numpy as np
|
| 25 |
+
from pathlib import Path
|
| 26 |
+
from typing import Optional
|
| 27 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 28 |
+
from datasets import load_dataset
|
| 29 |
+
|
| 30 |
+
from .config import MergeConfig, ModelConfig, TARGET
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def setup_tm_repo(cfg: MergeConfig):
    """
    Add the official T&M repo's ``core`` directory to ``sys.path``.

    Args:
        cfg: Configuration object; only ``cfg.tm_repo_path`` is read.

    Raises:
        FileNotFoundError: If ``<tm_repo_path>/core`` is not a directory.
    """
    repo_path = Path(cfg.tm_repo_path)
    core_path = repo_path / "core"

    # is_dir() rather than exists(): a stray file named "core" is not a usable
    # import root.
    if not core_path.is_dir():
        raise FileNotFoundError(
            f"Official T&M repo not found at {repo_path}\n"
            f"Please clone it:\n"
            f" git clone https://github.com/chenhangcuisg-code/"
            f"Cross-Architecture-Merging-for-Large-Language-Models.git"
        )

    # Register the absolute path so the entry stays valid if the working
    # directory changes later. Inserted at the front so imports such as
    # hot_transport resolve to the official code.
    core_entry = str(core_path.resolve())
    if core_entry not in sys.path:
        sys.path.insert(0, core_entry)
        print(f"[transport] Added T&M core to path: {core_path}")
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def load_calibration_data(cfg: MergeConfig, tokenizer: AutoTokenizer) -> list:
    """
    Load calibration data for activation extraction.

    Mix as implemented here: up to 600 general-text samples streamed from the
    Pile dataset (never more than ``cfg.calibration_samples``), then topped up
    from the neuralmagic dataset until ``cfg.calibration_samples`` is reached.
    Each sample is truncated to ``cfg.calibration_seq_len`` tokens.

    Either dataset failing to load is reported and tolerated, so the returned
    list may hold fewer samples than requested (possibly none).

    Findings: #08

    Args:
        cfg: Reads ``calibration_samples``, ``calibration_seq_len``,
            ``calibration_dataset_pile`` and ``calibration_dataset_nm``.
        tokenizer: Callable tokenizer; invoked with ``return_tensors="pt"``.

    Returns:
        List of tokenizer outputs, one per accepted text.
    """
    print(f"[transport] Loading calibration data ({cfg.calibration_samples} samples)...")

    samples = []

    def _encode(text):
        # Single place for the tokenisation settings shared by both datasets.
        return tokenizer(
            text,
            truncation=True,
            max_length=cfg.calibration_seq_len,
            return_tensors="pt",
        )

    # --- Pile: general text (up to 600 samples) ---
    # Cap the quota at the requested total so a small calibration_samples
    # setting is not exceeded by the Pile portion alone.
    pile_quota = min(600, cfg.calibration_samples)
    try:
        # NOTE(review): trust_remote_code=True executes code shipped with the
        # dataset repository; confirm the configured dataset ids are trusted.
        pile = load_dataset(
            cfg.calibration_dataset_pile,
            split="validation",
            streaming=True,
            trust_remote_code=True,
        )
        count = 0
        for example in pile:
            if count >= pile_quota:
                break
            text = example.get("text", "")
            if len(text) > 100:  # Skip very short texts
                samples.append(_encode(text))
                count += 1
        print(f" Pile general: {count} samples")
    except Exception as e:
        # Deliberate best-effort: any failure falls back to neuralmagic only.
        print(f" ⚠ Pile failed: {e}")
        print(f" Falling back to neuralmagic only")

    # --- neuralmagic: Q&A calibration (up to remaining) ---
    remaining = cfg.calibration_samples - len(samples)
    if remaining > 0:
        try:
            nm = load_dataset(
                cfg.calibration_dataset_nm,
                split="train",
                trust_remote_code=True,
            )
            count = 0
            for example in nm:
                if count >= remaining:
                    break
                # Records may carry their text under "text" or "content".
                text = str(example.get("text", example.get("content", "")))
                if len(text) > 50:
                    samples.append(_encode(text))
                    count += 1
            print(f" neuralmagic: {count} samples")
        except Exception as e:
            print(f" ⚠ neuralmagic failed: {e}")

    print(f"[transport] Total calibration samples: {len(samples)}")
    return samples
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def extract_activations(
    model: AutoModelForCausalLM,
    calibration_data: list,
    device: str = "cuda",
) -> dict:
    """
    Extract intermediate activations from each layer of a model.

    Forward hooks are registered on every module that has a ``self_attn``
    attribute and on every module whose name ends with ``.mlp``. Each
    calibration sample is run through the model and the hooked outputs are
    mean-pooled over dim 1 and moved to CPU.

    A sample whose forward pass raises is skipped as a whole: activations
    captured before the failure are discarded so every layer keeps the same
    number of rows, in the same sample order. Hooks are always removed, even
    if an unexpected error escapes the loop.

    Args:
        model: Model to probe; put into eval mode by this function.
        calibration_data: List of dicts of tensors (tokenizer outputs) passed
            to the model as keyword arguments.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        Dict mapping layer_name → activation tensor, concatenated over
        samples along dim 0 ([num_samples, hidden_dim] assuming hooked
        outputs are [batch, seq, hidden] — TODO confirm for non-standard
        layers).
    """
    print(f"[transport] Extracting activations from {len(calibration_data)} samples...")

    activations = {}
    pending = {}  # activations of the sample currently in flight
    hooks = []

    def make_hook(layer_name):
        def hook_fn(module, input, output):
            # Handle tuple outputs (some layers return tuples)
            act = output[0] if isinstance(output, tuple) else output
            # Mean pool over sequence length
            pending.setdefault(layer_name, []).append(
                act.detach().float().mean(dim=1).cpu()
            )
        return hook_fn

    # Register hooks on each transformer layer
    for name, module in model.named_modules():
        if hasattr(module, "self_attn") or name.endswith(".mlp"):
            hooks.append(module.register_forward_hook(make_hook(name)))

    try:
        # Forward pass on calibration data
        model.eval()
        with torch.no_grad():
            for i, tokens in enumerate(calibration_data):
                pending.clear()
                inputs = {k: v.to(device) for k, v in tokens.items()}
                try:
                    model(**inputs)
                except Exception as e:
                    # Best-effort per sample; partial captures are dropped by
                    # the pending.clear() at the top of the next iteration.
                    print(f" ⚠ Sample {i} failed: {e}")
                    continue

                # Commit only after the whole forward pass succeeded.
                for layer_name, chunks in pending.items():
                    activations.setdefault(layer_name, []).extend(chunks)

                if (i + 1) % 100 == 0:
                    print(f" Processed {i + 1}/{len(calibration_data)} samples")
    finally:
        # Remove hooks so the model is left unmodified for later use.
        for h in hooks:
            h.remove()

    # Stack activations: [num_samples, hidden_dim]
    for key in activations:
        activations[key] = torch.cat(activations[key], dim=0)
        print(f" {key}: {activations[key].shape}")

    return activations
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def compute_transport_plans(
    source_activations: dict,
    target_activations: dict,
    cfg: MergeConfig,
) -> dict:
    """
    Compute optimal transport plans between source and target activations.

    Tries to import the official T&M ``hot_transport`` module
    (corr_distance_matrix, sinkhorn_uniform_streaming, compute_P,
    compute_Q_and_layer_costs) and delegates to ``_compute_plans_official``.
    If that import fails, ``_compute_plans_fallback`` is used instead.

    Only the import itself is guarded: an ImportError raised while the
    official code is running propagates to the caller rather than silently
    switching to the fallback.

    Returns:
        Dict with 'P' (layer coupling) and 'Q' (per-layer neuron coupling)
        matrices, plus the 'source_layers' / 'target_layers' name lists.
    """
    print("[transport] Computing transport plans...")

    try:
        # Try importing official T&M code
        from hot_transport import (
            corr_distance_matrix,
            sinkhorn_uniform_streaming,
            compute_P,
            compute_Q_and_layer_costs,
        )
    except ImportError:
        print("[transport] Official T&M code not available, using fallback")
        return _compute_plans_fallback(
            source_activations, target_activations, cfg
        )

    print("[transport] Using official T&M implementation")
    return _compute_plans_official(
        source_activations, target_activations, cfg,
        corr_distance_matrix, sinkhorn_uniform_streaming,
        compute_P, compute_Q_and_layer_costs,
    )
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def _compute_plans_official(
|
| 232 |
+
source_act, target_act, cfg,
|
| 233 |
+
corr_distance_matrix, sinkhorn_uniform_streaming,
|
| 234 |
+
compute_P, compute_Q_and_layer_costs,
|
| 235 |
+
) -> dict:
|
| 236 |
+
"""Use the official T&M code to compute transport plans."""
|
| 237 |
+
|
| 238 |
+
# Get matching layer pairs
|
| 239 |
+
source_layers = sorted(source_act.keys())
|
| 240 |
+
target_layers = sorted(target_act.keys())
|
| 241 |
+
|
| 242 |
+
# Compute Q matrices (neuron-level) and layer costs
|
| 243 |
+
Q_matrices, layer_costs = compute_Q_and_layer_costs(
|
| 244 |
+
source_act, target_act,
|
| 245 |
+
source_layers, target_layers,
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
# Compute P matrix (layer-level coupling)
|
| 249 |
+
P = compute_P(layer_costs)
|
| 250 |
+
|
| 251 |
+
return {
|
| 252 |
+
"P": P,
|
| 253 |
+
"Q": Q_matrices,
|
| 254 |
+
"source_layers": source_layers,
|
| 255 |
+
"target_layers": target_layers,
|
| 256 |
+
}
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def _compute_plans_fallback(
    source_act: dict,
    target_act: dict,
    cfg: MergeConfig,
) -> dict:
    """
    Fallback transport plan computation when official code isn't available.

    For every (source layer, target layer) pair the correlation distance
    between neurons is computed across samples and passed to ``_sinkhorn``;
    the mean cost of each pair then forms the layer-level cost matrix, which
    goes through ``_sinkhorn`` again to give the layer coupling P.

    Source and target activations must have the same number of rows
    (samples), otherwise the correlation product is undefined.

    Args:
        source_act: Dict layer name → CPU tensor [samples, n_source].
        target_act: Dict layer name → CPU tensor [samples, n_target].
        cfg: Reads ``sinkhorn_reg`` and ``sinkhorn_max_iter``.

    Returns:
        Dict with 'P', 'Q' (keyed by (source_layer, target_layer)),
        'source_layers' and 'target_layers'.
    """

    def _standardize(act):
        # Zero-mean / unit-variance per neuron; epsilon guards constant neurons.
        arr = act.numpy()
        return (arr - arr.mean(0)) / (arr.std(0) + 1e-8)

    source_layers = sorted(source_act.keys())
    target_layers = sorted(target_act.keys())

    # Standardise each target layer once instead of once per source layer.
    target_norm = [_standardize(target_act[tl]) for tl in target_layers]

    # --- Step 1: Correlation distance matrices per layer pair ---
    Q_matrices = {}
    layer_costs = np.zeros((len(source_layers), len(target_layers)))

    for i, sl in enumerate(source_layers):
        S_norm = _standardize(source_act[sl])  # hoisted out of the inner loop
        num_samples = S_norm.shape[0]

        for j, tl in enumerate(target_layers):
            # Correlation distance: 1 - pearson_correlation between each pair
            # of neurons across samples → [n_source, n_target]
            corr = S_norm.T @ target_norm[j] / num_samples
            cost = 1.0 - corr

            # Basic Sinkhorn on this cost matrix
            Q_matrices[(sl, tl)] = _sinkhorn(
                cost, reg=cfg.sinkhorn_reg, max_iter=cfg.sinkhorn_max_iter
            )
            layer_costs[i, j] = cost.mean()

    # --- Step 2: Layer coupling (P matrix) ---
    P = _sinkhorn(layer_costs, reg=cfg.sinkhorn_reg, max_iter=cfg.sinkhorn_max_iter)

    return {
        "P": P,
        "Q": Q_matrices,
        "source_layers": source_layers,
        "target_layers": target_layers,
    }
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def _sinkhorn(
|
| 311 |
+
cost_matrix: np.ndarray,
|
| 312 |
+
reg: float = 0.05,
|
| 313 |
+
max_iter: int = 100,
|
| 314 |
+
) -> np.ndarray:
|
| 315 |
+
"""
|
| 316 |
+
Basic Sinkhorn-Knopp algorithm for optimal transport.
|
| 317 |
+
|
| 318 |
+
Solves: min <T, C> - reg * H(T)
|
| 319 |
+
where H(T) is the entropy of the transport plan.
|
| 320 |
+
|
| 321 |
+
This is the FALLBACK. The official code uses streaming Sinkhorn
|
| 322 |
+
which is more memory-efficient.
|
| 323 |
+
"""
|
| 324 |
+
n, m = cost_matrix.shape
|
| 325 |
+
K = np.exp(-cost_matrix / reg)
|
| 326 |
+
|
| 327 |
+
u = np.ones(n) / n
|
| 328 |
+
v = np.ones(m) / m
|
| 329 |
+
|
| 330 |
+
for _ in range(max_iter):
|
| 331 |
+
u = 1.0 / (K @ v + 1e-10)
|
| 332 |
+
v = 1.0 / (K.T @ u + 1e-10)
|
| 333 |
+
|
| 334 |
+
# Transport plan
|
| 335 |
+
T = np.diag(u) @ K @ np.diag(v)
|
| 336 |
+
return T
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
def fuse_weights(
    source_model: AutoModelForCausalLM,
    target_model: AutoModelForCausalLM,
    transport_plans: dict,
    source_config: ModelConfig,
    cfg: MergeConfig,
) -> AutoModelForCausalLM:
    """
    Fuse source model weights into target model.

    For every target parameter that is not excluded by ``_should_skip`` and
    that has a counterpart in the source (via ``_map_key``):
    1. If shapes differ, the source tensor is reshaped by ``_align_dimensions``
       (the parameter is skipped when that returns None).
    2. Blend: W_final = alpha * W_source + (1 - alpha) * W_target, with
       alpha = ``source_config.merge_alpha``.
    3. If ``cfg.freeze_think_tokens`` is set and the parameter is an
       ``embed_tokens`` tensor, the rows listed in ``cfg.think_token_ids`` are
       restored to the target's original values.

    NOTE(review): the layer coupling ``transport_plans["P"]`` is not consulted
    here, and the ``Q`` plans are only forwarded to ``_align_dimensions``;
    transport-based projection (W_fused = Q @ W_source) is not performed by
    this manual path.

    Per-model handling (MTP heads, Mamba parameters, QKV bias, layer index
    mapping) is delegated to ``_should_skip`` and ``_map_key``.

    Returns:
        ``target_model`` with the fused weights loaded (modified in place).
    """
    print(f"\n[transport] Fusing {source_config.name} → target")
    alpha = source_config.merge_alpha

    try:
        # Probe for the official fusion code.
        from generate_hot_residual import fuse_attention_only_from_hot_dir  # noqa: F401
    except ImportError:
        pass
    else:
        # TODO: Adapt official fusion to our pipeline. Until then the manual
        # path below is always used, so the log must not claim otherwise.
        print("[transport] Official fusion code found but not integrated yet; using manual fusion")

    # --- Manual fusion ---
    source_state = source_model.state_dict()
    target_state = target_model.state_dict()
    Q = transport_plans["Q"]

    fused_count = 0
    skipped_count = 0

    for target_key in target_state:
        # Skip parameters we shouldn't merge
        if _should_skip(target_key, source_config):
            skipped_count += 1
            continue

        # Find corresponding source key
        source_key = _map_key(target_key, source_config)
        if source_key is None or source_key not in source_state:
            skipped_count += 1
            continue

        target_w = target_state[target_key]
        source_w = source_state[source_key]

        # Handle dimension mismatches
        if target_w.shape != source_w.shape:
            source_w = _align_dimensions(source_w, target_w.shape, Q, target_key)
            if source_w is None:
                skipped_count += 1
                continue

        # Blend on the target's device and dtype so the fused tensor matches
        # the parameter it will be loaded into.
        source_w = source_w.to(device=target_w.device, dtype=target_w.dtype)
        fused_w = alpha * source_w + (1 - alpha) * target_w

        # Thinking mode protection: target_w still holds the original
        # embedding (the model is only updated by load_state_dict below), so
        # the protected rows can be copied straight from it.
        if cfg.freeze_think_tokens and "embed_tokens" in target_key:
            for token_id in cfg.think_token_ids:
                if token_id < fused_w.shape[0]:
                    fused_w[token_id] = target_w[token_id]
                    print(f"[transport] Protected think token {token_id}")

        target_state[target_key] = fused_w
        fused_count += 1

    # Load fused weights
    target_model.load_state_dict(target_state)
    print(f"[transport] Fused {fused_count} params, skipped {skipped_count}")

    return target_model
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
def _should_skip(key: str, source_config: ModelConfig) -> bool:
|
| 429 |
+
"""Determine if a parameter should be skipped during merge."""
|
| 430 |
+
|
| 431 |
+
# Always skip if source model says to skip embeddings
|
| 432 |
+
if source_config.skip_embeddings and ("embed_tokens" in key or "lm_head" in key):
|
| 433 |
+
return True
|
| 434 |
+
|
| 435 |
+
# Skip MiMo MTP heads
|
| 436 |
+
if "drop_mtp_heads" in source_config.special_handling and "mtp_head" in key:
|
| 437 |
+
return True
|
| 438 |
+
|
| 439 |
+
# Skip Falcon Mamba-specific parameters
|
| 440 |
+
if "drop_mamba_state_params" in source_config.special_handling:
|
| 441 |
+
mamba_keys = ["mamba", "A_log", "dt_proj", ".D"]
|
| 442 |
+
if any(mk in key for mk in mamba_keys):
|
| 443 |
+
return True
|
| 444 |
+
|
| 445 |
+
# Skip QKV bias for Llama (Qwen3 doesn't have it)
|
| 446 |
+
if "drop_qkv_bias" in source_config.special_handling and ".bias" in key:
|
| 447 |
+
if any(proj in key for proj in ["q_proj", "k_proj", "v_proj"]):
|
| 448 |
+
return True
|
| 449 |
+
|
| 450 |
+
return False
|
| 451 |
+
|
| 452 |
+
|
| 453 |
+
def _map_key(target_key: str, source_config: ModelConfig) -> Optional[str]:
|
| 454 |
+
"""Map a target model parameter name to the corresponding source name."""
|
| 455 |
+
|
| 456 |
+
# For same-architecture models (DeepSeek), keys match directly
|
| 457 |
+
if source_config.architecture == "transformer" and source_config.layers == 36:
|
| 458 |
+
return target_key
|
| 459 |
+
|
| 460 |
+
# For Llama (32 layers → 36 layers), map layer indices
|
| 461 |
+
if "layer_mapping_32_to_36" in source_config.special_handling:
|
| 462 |
+
if "model.layers." in target_key:
|
| 463 |
+
# Extract layer number
|
| 464 |
+
parts = target_key.split(".")
|
| 465 |
+
try:
|
| 466 |
+
layer_idx = int(parts[2])
|
| 467 |
+
except (IndexError, ValueError):
|
| 468 |
+
return target_key
|
| 469 |
+
|
| 470 |
+
# Map 36 target layers to 32 source layers (stride)
|
| 471 |
+
source_layer = int(layer_idx * 32 / 36)
|
| 472 |
+
parts[2] = str(source_layer)
|
| 473 |
+
return ".".join(parts)
|
| 474 |
+
|
| 475 |
+
# For MiMo (same layer count, different extras), keys mostly match
|
| 476 |
+
if source_config.architecture == "transformer+mtp":
|
| 477 |
+
if "mtp_head" in target_key:
|
| 478 |
+
return None # MTP heads don't exist in target
|
| 479 |
+
return target_key
|
| 480 |
+
|
| 481 |
+
# For Falcon hybrid, only attention and MLP keys map
|
| 482 |
+
if source_config.architecture == "hybrid_ssm":
|
| 483 |
+
if any(k in target_key for k in ["self_attn", "mlp", "layer_norm"]):
|
| 484 |
+
return target_key # These exist in both
|
| 485 |
+
return None # Mamba components don't map
|
| 486 |
+
|
| 487 |
+
return target_key
|
| 488 |
+
|
| 489 |
+
|
| 490 |
+
def _align_dimensions(
|
| 491 |
+
source_w: torch.Tensor,
|
| 492 |
+
target_shape: tuple,
|
| 493 |
+
Q_matrices: dict,
|
| 494 |
+
key: str,
|
| 495 |
+
) -> Optional[torch.Tensor]:
|
| 496 |
+
"""
|
| 497 |
+
Align source weight dimensions to target shape using transport plans.
|
| 498 |
+
|
| 499 |
+
For small mismatches: pad or truncate.
|
| 500 |
+
For large mismatches: use Q matrix to project.
|
| 501 |
+
"""
|
| 502 |
+
if source_w.shape == target_shape:
|
| 503 |
+
return source_w
|
| 504 |
+
|
| 505 |
+
# Simple case: different width (FFN size difference)
|
| 506 |
+
if len(source_w.shape) == 2 and len(target_shape) == 2:
|
| 507 |
+
s_rows, s_cols = source_w.shape
|
| 508 |
+
t_rows, t_cols = target_shape
|
| 509 |
+
|
| 510 |
+
result = torch.zeros(target_shape, dtype=source_w.dtype)
|
| 511 |
+
|
| 512 |
+
# Copy what fits
|
| 513 |
+
min_rows = min(s_rows, t_rows)
|
| 514 |
+
min_cols = min(s_cols, t_cols)
|
| 515 |
+
result[:min_rows, :min_cols] = source_w[:min_rows, :min_cols]
|
| 516 |
+
|
| 517 |
+
return result
|
| 518 |
+
|
| 519 |
+
# 1D case (biases, layer norms)
|
| 520 |
+
if len(source_w.shape) == 1 and len(target_shape) == 1:
|
| 521 |
+
result = torch.zeros(target_shape, dtype=source_w.dtype)
|
| 522 |
+
min_len = min(source_w.shape[0], target_shape[0])
|
| 523 |
+
result[:min_len] = source_w[:min_len]
|
| 524 |
+
return result
|
| 525 |
+
|
| 526 |
+
# Can't align — skip this parameter
|
| 527 |
+
return None
|
hugging/td_fuse/validate.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Post-Merge Validation — run after EVERY merge step.
|
| 3 |
+
|
| 4 |
+
Tests:
|
| 5 |
+
1. Canary recall (did knowledge transfer?)
|
| 6 |
+
2. Perplexity check (did we break the model?)
|
| 7 |
+
3. Thinking mode (do <think> tags still work?)
|
| 8 |
+
4. Quick reasoning test (can it still think?)
|
| 9 |
+
|
| 10 |
+
Kill criteria: >10% performance drop on any test → abort merge.
|
| 11 |
+
Findings: #11, #22, #25
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import math
|
| 16 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 17 |
+
|
| 18 |
+
from .canary import test_all_canaries
|
| 19 |
+
from .config import MergeConfig
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def validate_merged_model(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    merged_sources: list[str],
    cfg: MergeConfig,
    baseline_perplexity: float = None,
) -> dict:
    """
    Run the full validation suite on a merged model and print a summary.

    Args:
        model: The merged model to validate.
        tokenizer: The tokenizer used for every test.
        merged_sources: Source models merged so far (passed to the canary test
            and shown in the banner).
        cfg: Merge configuration; ``canary_pass_threshold`` and
            ``perplexity_threshold`` are read from it.
        baseline_perplexity: Perplexity of the target model before merging.
            ``None`` (or a non-positive value) skips the ratio comparison.

    Returns:
        Dict with one entry per test ("canary", "perplexity", "thinking_mode",
        "reasoning"), each carrying an "ok" flag, plus "overall" which is True
        only when every test passed.
    """
    print("\n" + "=" * 60)
    print(f"VALIDATION — After merging: {', '.join(merged_sources)}")
    print("=" * 60)

    results = {
        "canary": None,
        "perplexity": None,
        "thinking_mode": None,
        "reasoning": None,
        "overall": False,
    }

    # --- Test 1: Canary recall ---
    canary_results = test_all_canaries(model, tokenizer, merged_sources)
    passed_canaries = sum(1 for v in canary_results.values() if v)
    total_canaries = len(canary_results)
    results["canary"] = {
        "passed": passed_canaries,
        "total": total_canaries,
        # NOTE(review): compares a COUNT of passed canaries against the
        # threshold — confirm cfg.canary_pass_threshold is a count, not a ratio.
        "ok": passed_canaries >= cfg.canary_pass_threshold,
        "details": canary_results,
    }

    # --- Test 2: Perplexity ---
    perplexity = compute_perplexity(model, tokenizer)
    # A NaN/inf perplexity means the forward pass itself is broken; it must
    # fail the gate even when there is no baseline to compare against.
    ppl_finite = math.isfinite(perplexity)
    # A zero/negative baseline cannot be used as a divisor.
    has_baseline = baseline_perplexity is not None and baseline_perplexity > 0
    if has_baseline:
        ratio = perplexity / baseline_perplexity
        ppl_ok = ppl_finite and ratio < cfg.perplexity_threshold
        print(f"\n[validate] Perplexity: {perplexity:.2f} (baseline: {baseline_perplexity:.2f}, ratio: {ratio:.2f})")
        if not ppl_finite:
            print("[validate] ⚠ Perplexity is not finite — model output is broken")
        elif not ppl_ok:
            print(f"[validate] ⚠ Perplexity ratio {ratio:.2f} exceeds threshold {cfg.perplexity_threshold}")
    else:
        ppl_ok = ppl_finite
        print(f"\n[validate] Perplexity: {perplexity:.2f} (no baseline to compare)")
        if baseline_perplexity is not None:
            print(f"[validate] ⚠ Ignoring unusable baseline perplexity: {baseline_perplexity}")
        if not ppl_finite:
            print("[validate] ⚠ Perplexity is not finite — model output is broken")
    results["perplexity"] = {"value": perplexity, "ok": ppl_ok}

    # --- Test 3: Thinking mode ---
    think_ok = test_thinking_mode(model, tokenizer)
    results["thinking_mode"] = {"ok": think_ok}

    # --- Test 4: Quick reasoning ---
    reason_ok = test_reasoning(model, tokenizer)
    results["reasoning"] = {"ok": reason_ok}

    # --- Overall verdict ---
    all_ok = (
        results["canary"]["ok"]
        and results["perplexity"]["ok"]
        and results["thinking_mode"]["ok"]
        and results["reasoning"]["ok"]
    )
    results["overall"] = all_ok

    # Summary
    print("\n" + "-" * 60)
    print("VALIDATION SUMMARY")
    print("-" * 60)
    print(f" Canary recall: {'✓' if results['canary']['ok'] else '✗'} ({passed_canaries}/{total_canaries})")
    print(f" Perplexity: {'✓' if ppl_ok else '✗'} ({perplexity:.2f})")
    print(f" Thinking mode: {'✓' if think_ok else '✗'}")
    print(f" Reasoning: {'✓' if reason_ok else '✗'}")
    print(f" OVERALL: {'✓ PASS' if all_ok else '✗ FAIL — consider aborting'}")
    print("-" * 60)

    return results
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def compute_perplexity(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    test_texts: list[str] = None,
) -> float:
    """
    Compute token-weighted perplexity on a small test set.

    Lower perplexity = model is more confident about predicting text.
    A big spike after merging means the model was damaged.

    Args:
        model: Causal LM to score. It is switched to eval mode and left there.
        tokenizer: Tokenizer matching ``model``.
        test_texts: Texts to score; ``None`` selects a built-in five-text set.

    Returns:
        ``exp`` of the mean per-token loss, or ``inf`` when that mean is too
        large for ``math.exp``. May be NaN if the model produces NaN losses.

    Raises:
        ValueError: If no text yields at least one predicted token.
    """
    if test_texts is None:
        test_texts = [
            "The quick brown fox jumps over the lazy dog.",
            "In mathematics, a prime number is a natural number greater than 1.",
            "def fibonacci(n):\n    if n <= 1:\n        return n\n    return fibonacci(n-1) + fibonacci(n-2)",
            "The theory of general relativity describes gravity as the curvature of spacetime.",
            "To solve 3x + 7 = 22, subtract 7 from both sides to get 3x = 15, then divide by 3.",
        ]

    model.eval()
    total_loss = 0.0
    total_tokens = 0

    for text in test_texts:
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        # With labels == input_ids the model shifts the labels internally, so
        # the returned loss is a mean over (seq_len - 1) predictions. Weight
        # by that count, and skip texts too short to predict anything.
        num_predicted = inputs["input_ids"].shape[1] - 1
        if num_predicted < 1:
            continue

        with torch.no_grad():
            outputs = model(**inputs, labels=inputs["input_ids"])
        total_loss += outputs.loss.item() * num_predicted
        total_tokens += num_predicted

    if total_tokens == 0:
        raise ValueError(
            "compute_perplexity needs at least one text with two or more tokens"
        )

    avg_loss = total_loss / total_tokens
    try:
        return math.exp(avg_loss)
    except OverflowError:
        # A catastrophically damaged model should fail validation, not crash it.
        return float("inf")
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def test_thinking_mode(
|
| 148 |
+
model: AutoModelForCausalLM,
|
| 149 |
+
tokenizer: AutoTokenizer,
|
| 150 |
+
) -> bool:
|
| 151 |
+
"""
|
| 152 |
+
Test if the model still uses <think> tags for reasoning.
|
| 153 |
+
|
| 154 |
+
The thinking mode is Qwen3's special feature — if it's gone,
|
| 155 |
+
the merge damaged something critical.
|
| 156 |
+
"""
|
| 157 |
+
prompt = "Solve step by step: What is 15 × 13?"
|
| 158 |
+
|
| 159 |
+
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
| 160 |
+
with torch.no_grad():
|
| 161 |
+
outputs = model.generate(
|
| 162 |
+
**inputs,
|
| 163 |
+
max_new_tokens=200,
|
| 164 |
+
temperature=0.7,
|
| 165 |
+
do_sample=True,
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
response = tokenizer.decode(outputs[0], skip_special_tokens=False)
|
| 169 |
+
|
| 170 |
+
# Check for thinking tags
|
| 171 |
+
has_think_open = "<think>" in response
|
| 172 |
+
has_think_close = "</think>" in response
|
| 173 |
+
passed = has_think_open and has_think_close
|
| 174 |
+
|
| 175 |
+
print(f"\n[validate] Thinking mode test:")
|
| 176 |
+
print(f" Prompt: {prompt}")
|
| 177 |
+
print(f" Response: {response[:200]}...")
|
| 178 |
+
print(f" <think>: {'✓ found' if has_think_open else '✗ missing'}")
|
| 179 |
+
print(f" </think>: {'✓ found' if has_think_close else '✗ missing'}")
|
| 180 |
+
print(f" Status: {'✓ PASS' if passed else '✗ FAIL'}")
|
| 181 |
+
|
| 182 |
+
return passed
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def test_reasoning(
|
| 186 |
+
model: AutoModelForCausalLM,
|
| 187 |
+
tokenizer: AutoTokenizer,
|
| 188 |
+
) -> bool:
|
| 189 |
+
"""
|
| 190 |
+
Quick reasoning sanity check — can the model still do basic math?
|
| 191 |
+
|
| 192 |
+
This catches catastrophic failures where the merge produced gibberish.
|
| 193 |
+
"""
|
| 194 |
+
prompt = "What is 7 + 8?"
|
| 195 |
+
expected_answer = "15"
|
| 196 |
+
|
| 197 |
+
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
| 198 |
+
with torch.no_grad():
|
| 199 |
+
outputs = model.generate(
|
| 200 |
+
**inputs,
|
| 201 |
+
max_new_tokens=50,
|
| 202 |
+
temperature=0.1,
|
| 203 |
+
do_sample=False,
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 207 |
+
passed = expected_answer in response
|
| 208 |
+
|
| 209 |
+
print(f"\n[validate] Quick reasoning test:")
|
| 210 |
+
print(f" Prompt: {prompt}")
|
| 211 |
+
print(f" Expected: {expected_answer}")
|
| 212 |
+
print(f" Got: {response}")
|
| 213 |
+
print(f" Status: {'✓ PASS' if passed else '✗ FAIL'}")
|
| 214 |
+
|
| 215 |
+
return passed
|
hugging/td_lang/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
hugging/td_lang/__init__.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
TD Lang — Domain-specific language for Time Dilation project.
|
| 3 |
+
|
| 4 |
+
Compiles .td files into Python code that calls td_fuse.
|
| 5 |
+
Write simple scripts instead of complex Python.
|
| 6 |
+
|
| 7 |
+
Architecture:
|
| 8 |
+
td_lang/
|
| 9 |
+
├── __init__.py <- This file
|
| 10 |
+
├── __main__.py <- Entry point for python -m td_lang
|
| 11 |
+
├── grammar.py <- Lark grammar + parse tree transformer
|
| 12 |
+
├── ast_nodes.py <- Dataclass AST nodes for each command
|
| 13 |
+
├── compiler.py <- AST -> Python code generation
|
| 14 |
+
├── executor.py <- Run compiled code, track lineage
|
| 15 |
+
├── cli.py <- Command-line interface
|
| 16 |
+
├── errors.py <- Custom exceptions
|
| 17 |
+
└── examples/
|
| 18 |
+
├── demo_merge.td <- Basic merge example
|
| 19 |
+
├── demo_heal.td <- Merge + heal example
|
| 20 |
+
├── demo_full.td <- Full pipeline with gates + budget
|
| 21 |
+
├── demo_loop.td <- Self-improvement loop example
|
| 22 |
+
├── demo_phase3.td <- Fork/edit/prune/reset example
|
| 23 |
+
└── demo_phase4.td <- Contracts + snapshot + report example
|
| 24 |
+
|
| 25 |
+
Phase 1: load, merge, heal, eval, commit
|
| 26 |
+
Phase 2: diagnose, synth, train, debate
|
| 27 |
+
Phase 3: fork, reset, prune, edit
|
| 28 |
+
Phase 4: snapshot, report, data_contract, reward_contract
|
| 29 |
+
Phase 5: CLI polish, --version, info command, --verbose
|
| 30 |
+
|
| 31 |
+
Designed from interviews test_14 (10 commands) and test_17 (ForgeSpec 2.0).
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
from .grammar import parse_td_file, parse_td_string # noqa: F401
|
| 35 |
+
from .compiler import compile_program # noqa: F401
|
| 36 |
+
from .executor import TDExecutor, check_td_file, compile_td_file, run_td_file # noqa: F401
|
| 37 |
+
|
| 38 |
+
# Package metadata.
__version__ = "0.2.0"
__author__ = "Milan (TD Project)"

# Names re-exported by `from td_lang import *`.
__all__ = [
    # parsing
    "parse_td_file",
    "parse_td_string",
    # compilation
    "compile_program",
    # execution helpers
    "TDExecutor",
    "check_td_file",
    "compile_td_file",
    "run_td_file",
    # metadata
    "__version__",
    "__author__",
]
|
hugging/td_lang/__main__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Entry point for python -m td_lang."""
|
| 2 |
+
|
| 3 |
+
from .cli import main
|
| 4 |
+
|
| 5 |
+
# Invoke the CLI entry point imported above; this module runs for `python -m td_lang`.
main()
|
hugging/td_lang/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (2.01 kB). View file
|
|
|
hugging/td_lang/__pycache__/__init__.cpython-314.pyc
ADDED
|
Binary file (2.02 kB). View file
|
|
|
hugging/td_lang/__pycache__/__main__.cpython-310.pyc
ADDED
|
Binary file (254 Bytes). View file
|
|
|
hugging/td_lang/__pycache__/__main__.cpython-314.pyc
ADDED
|
Binary file (262 Bytes). View file
|
|
|
hugging/td_lang/__pycache__/ast_nodes.cpython-310.pyc
ADDED
|
Binary file (12.7 kB). View file
|
|
|
hugging/td_lang/__pycache__/ast_nodes.cpython-314.pyc
ADDED
|
Binary file (18.7 kB). View file
|
|
|
hugging/td_lang/__pycache__/cli.cpython-310.pyc
ADDED
|
Binary file (6.62 kB). View file
|
|
|
hugging/td_lang/__pycache__/cli.cpython-314.pyc
ADDED
|
Binary file (10.5 kB). View file
|
|
|
hugging/td_lang/__pycache__/compiler.cpython-310.pyc
ADDED
|
Binary file (88.7 kB). View file
|
|
|
hugging/td_lang/__pycache__/compiler.cpython-314.pyc
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8bef7388fef05cdd8ee4edcc72a4b8907c8637caa22cfc802da044470a515c92
|
| 3 |
+
size 162778
|
hugging/td_lang/__pycache__/errors.cpython-310.pyc
ADDED
|
Binary file (4.21 kB). View file
|
|
|
hugging/td_lang/__pycache__/errors.cpython-314.pyc
ADDED
|
Binary file (6.34 kB). View file
|
|
|
hugging/td_lang/__pycache__/executor.cpython-310.pyc
ADDED
|
Binary file (5.94 kB). View file
|
|
|
hugging/td_lang/__pycache__/executor.cpython-314.pyc
ADDED
|
Binary file (9.49 kB). View file
|
|
|
hugging/td_lang/__pycache__/grammar.cpython-310.pyc
ADDED
|
Binary file (25.4 kB). View file
|
|
|
hugging/td_lang/__pycache__/grammar.cpython-314.pyc
ADDED
|
Binary file (37.8 kB). View file
|
|
|
hugging/td_lang/ast_nodes.py
ADDED
|
@@ -0,0 +1,421 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
TD Lang AST Nodes — Dataclass containers for each parsed command.
|
| 3 |
+
|
| 4 |
+
Each .td command becomes one of these nodes after parsing.
|
| 5 |
+
Phase 1 nodes are compiled into runnable Python; Phase 2 nodes are stubs so
|
| 6 |
+
the compiler can reject them with a clear error until they are implemented.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from dataclasses import dataclass, field
|
| 10 |
+
from typing import Any, List, Optional
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# ============================================================================
|
| 14 |
+
# PHASE 1 COMMANDS
|
| 15 |
+
# ============================================================================
|
| 16 |
+
|
| 17 |
+
@dataclass
class LoadCmd:
    """Bind a model reference to a script-local name.

    Example: load "Qwen/Qwen3-VL-8B-Instruct" as base

    Attributes:
        model_ref: HuggingFace path or local path of the model.
        alias: Name used for the model in the rest of the script.
    """

    model_ref: str
    alias: str
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@dataclass
class MergeCmd:
    """Merge one source model into an already-loaded target.

    Example: merge "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B" into base using transport strength 0.5

    Attributes:
        source: Model path or alias to merge from.
        target: Alias to merge into (must be loaded first).
        method: Merge method: "transport", "slerp", "ties", "dare".
        strength: Blend factor; 0.0 keeps the target, 1.0 keeps the source.
    """

    source: str
    target: str
    method: str
    strength: float = 0.5
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@dataclass
class HealCmd:
    """Request a QLoRA healing fine-tune of a model.

    Example: heal base lora_r 32 epochs 2

    Attributes:
        target: Alias of the model to heal.
        lora_r: LoRA rank (higher = more capacity).
        epochs: Number of training epochs.
    """

    target: str
    lora_r: int = 32
    epochs: int = 2
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@dataclass
class EvalCmd:
    """Request validation/evaluation of a model.

    Example: eval base on "pile_sample" -> report.json

    Attributes:
        target: Alias of the model to evaluate.
        dataset: Optional dataset name or path.
        output: Optional path of the output file.
    """

    target: str
    dataset: Optional[str] = None
    output: Optional[str] = None
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
@dataclass
class CommitCmd:
    """Save a model checkpoint, optionally only when gates pass.

    Example: commit base if [canary, perplexity, thinking_mode]

    Attributes:
        target: Alias of the model to commit.
        gates: Names of the gates that must pass; None when no gates are given.
    """

    target: str
    gates: Optional[list[str]] = None
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
# ============================================================================
|
| 72 |
+
# PHASE 2 COMMANDS (placeholders — structure ready, not wired up yet)
|
| 73 |
+
# ============================================================================
|
| 74 |
+
|
| 75 |
+
@dataclass
class SynthCmd:
    """Synthetic training-data generation from a model. (Phase 2)

    Attributes:
        target: First operand of the command (model alias).
        source: Second operand of the command (model alias).
        filter_method: Optional filter to apply.
        output: Optional output path.
    """

    target: str
    source: str
    filter_method: Optional[str] = None
    output: Optional[str] = None
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
@dataclass
class TrainCmd:
    """Training of a model on a dataset. (Phase 2)

    Attributes:
        target: Alias of the model to train.
        dataset: Dataset to train on.
        method: Training method: "grpo", "sft", "dpo".
        steps: Optional number of steps.
        learning_rate: Optional learning rate.
    """

    target: str
    dataset: str
    method: str = "grpo"
    steps: Optional[int] = None
    learning_rate: Optional[float] = None
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
@dataclass
class DebateCmd:
    """Multi-answer debate that yields preference pairs. (Phase 2)

    Attributes:
        target: Alias of the model.
        rounds: Number of debate rounds.
        candidates: Number of candidate answers.
        output: Optional output path.
    """

    target: str
    rounds: int = 3
    candidates: int = 8
    output: Optional[str] = None
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
@dataclass
class DiagnoseCmd:
    """Self-diagnosis: ask the model what it's bad at. (Phase 2)

    Attributes:
        target: Alias of the model to diagnose.
        output: Optional output path.
    """

    target: str
    output: Optional[str] = None
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
@dataclass
class ForkCmd:
    """Branch the current model weights for parallel experiments. (Phase 3)

    Example: fork base as experiment_v2
    Cheap fork: copies manifest + adapters, shares base weights (default).

    Attributes:
        source: Alias of the model to fork from.
        alias: Name for the new branch.
    """

    source: str
    alias: str
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
@dataclass
class ResetCmd:
    """Roll a model back to an earlier checkpoint. (Phase 3)

    Example: reset base to "checkpoint_042"
    Deletes current model, clears CUDA cache, reloads from disk.
    Must also reset optimizer state.

    Attributes:
        target: Alias of the model to reset.
        checkpoint: Checkpoint name/path to revert to.
    """

    target: str
    checkpoint: str
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
@dataclass
class PruneCmd:
    """Structural pruning of low-utility neurons/heads. (Phase 3)

    Example: prune base using wanda aggressiveness 0.2
    Safe zone: ~20% max (LLM-Pruner paper). Language backbone only.

    Attributes:
        target: Alias of the model to prune.
        method: Pruning method: "wanda", "magnitude", "taylor".
        aggressiveness: Fraction to remove (0.0-1.0).
    """

    target: str
    method: str = "wanda"
    aggressiveness: float = 0.2
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
@dataclass
class EditCmd:
    """Targeted LoRA/DoRA editing of selected layers. (Phase 3)

    Example: edit base layers 16-28 using lora lr 1e-4
    "Try before buy": eval with adapter enabled vs disabled before merging.

    Attributes:
        target: Alias of the model to edit.
        layers: Layer selection: "all", a range such as "16-28", or a single number.
        method: Adapter type: "lora" or "dora".
        learning_rate: Optional learning rate.
    """

    target: str
    layers: str = "all"
    method: str = "lora"
    learning_rate: Optional[float] = None
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
# ============================================================================
|
| 159 |
+
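# PHASE 4 COMMANDS — Contracts, Lineage, Economics (ForgeSpec 2.0, test_17)
# (SnapshotCmd and ReportCmd are defined further down, after the Phase 6
# commands; the contract blocks are defined in the BLOCKS section.)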
|
| 160 |
+
# ============================================================================
|
| 161 |
+
|
| 162 |
+
# ============================================================================
|
| 163 |
+
# PHASE 7 — LOOP CONTROL (repeat, if/else)
|
| 164 |
+
# ============================================================================
|
| 165 |
+
|
| 166 |
+
@dataclass
class RepeatBlock:
    """Run a block of commands a fixed number of times. (Phase 7 — Loop Control)

    Example:
        repeat 5 {
            diagnose base
            synth base from base
            train base on "data.jsonl" using grpo steps 64
            eval base
        }

    Attributes:
        count: Number of iterations.
        body: Commands inside the block; each instance gets its own list.
    """

    count: int
    body: List[Any] = field(default_factory=list)
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
@dataclass
class IfBlock:
    """Branch on the last eval result. (Phase 7 — Loop Control)

    Example:
        if eval_passed {
            commit base
        } else {
            reset base to "last_good"
        }

    Condition checks the most recent eval result for the target.

    Attributes:
        condition: Condition name: "eval_passed", "gate_passed", etc.
        target: Which model's eval to check.
        then_body: Commands of the `if` branch.
        else_body: Commands of the `else` branch.
    """

    condition: str
    target: Optional[str] = None
    then_body: List[Any] = field(default_factory=list)
    else_body: List[Any] = field(default_factory=list)
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
@dataclass
class FuseCmd:
    """Fuse several models into one target in a single command. (Phase 6 — Easy Merge)

    Example: fuse [deepseek-r1, mimo-7b, llama-3.1] into base
    Auto-picks Transport and Merge, auto-sets per-model strength.
    Handles cross-architecture merging (all 5 source models have different archs).

    Attributes:
        sources: Model names/paths to fuse in.
        target: Alias to merge into (must be loaded).
        method: Merge method; defaults to transport and merge (cross-arch).
        strategy: "equal" (same strength each), "weighted", or "sequential".
    """

    sources: list[str]
    target: str
    method: str = "transport"
    strategy: str = "equal"
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
@dataclass
class AbsorbCmd:
    """Simplified merge of a single model into the target. (Phase 6 — Easy Merge)

    Example: absorb "deepseek-ai/DeepSeek-R1" into base strength 0.5
    One-liner for the common case of merging one model in.

    Attributes:
        source: Model path or HF ID.
        target: Alias to merge into.
        strength: 0.0 keeps the target, 1.0 keeps the source; default is balanced.
    """

    source: str
    target: str
    strength: float = 0.5
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
@dataclass
class SnapshotCmd:
    """Content-hashed snapshot of model state for lineage tracking. (Phase 4)

    Example: snapshot base -> snapshots/
    Creates a content-addressed directory: snapshots/<sha256_prefix>/
    Contains: model state, adapter state, prune spec, eval report, manifest.

    Attributes:
        target: Alias of the model to snapshot.
        output: Output directory (default: td_lang_outputs/snapshots/).
    """

    target: str
    output: Optional[str] = None
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
@dataclass
class ReportCmd:
    """Economics report for the current run. (Phase 4)

    Example: report -> economics.json
    Tracks: GPU hours, cost estimate, tokens processed, experiments run,
    time per command, cost breakdown by phase.

    Attributes:
        output: Output file path.
    """

    output: Optional[str] = None
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
# ============================================================================
|
| 251 |
+
# PHASE 8 — AUTOPILOT (setup, notify, save, on_error, resume)
|
| 252 |
+
# ============================================================================
|
| 253 |
+
|
| 254 |
+
@dataclass
class NotifyCmd:
    """Notification sent via ntfy.sh. (Phase 8 — Autopilot)

    Example: notify "Training complete!"
    Uses curl to POST to the configured ntfy topic.

    Attributes:
        message: Text of the notification.
    """

    message: str
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
@dataclass
class SaveCmd:
    """Upload a model to cloud storage via rclone. (Phase 8 — Autopilot)

    Example: save base to "gdrive:TD/models/v1"
    Uses rclone to copy model checkpoint to Google Drive (or any rclone remote).

    Attributes:
        target: Alias of the model to save.
        destination: rclone destination path.
    """

    target: str
    destination: str
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
@dataclass
class SetupBlock:
    """Dependency installation and environment configuration. (Phase 8 — Autopilot)

    Example:
        setup {
            pip = [torch, transformers, peft, bitsandbytes, trl]
            hf_token = env
            notify = "ntfy.sh/my_ai"
        }

    Attributes:
        pip_packages: Packages listed under `pip`; each instance gets its own list.
        hf_token: "env" = read HF_TOKEN from env.
        notify_url: ntfy.sh topic URL.
    """

    pip_packages: list[str] = field(default_factory=list)
    hf_token: Optional[str] = None
    notify_url: Optional[str] = None
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
@dataclass
class OnErrorBlock:
    """Crash-recovery settings. (Phase 8 — Autopilot)

    Example:
        on_error {
            retry = 3
            fallback = reduce_batch
            notify = true
        }

    Attributes:
        retry: Number of retries per failed step.
        fallback: "reduce_batch", "skip", or "snapshot_and_stop".
        notify: Whether to send an ntfy notification on error.
    """

    retry: int = 3
    fallback: str = "reduce_batch"
    notify: bool = True
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
# ============================================================================
|
| 308 |
+
# BLOCKS (gates, budget, contracts, etc.)
|
| 309 |
+
# ============================================================================
|
| 310 |
+
|
| 311 |
+
@dataclass
class GateBlock:
    """Validation gates required before a commit.

    Example:
        gate {
            must_pass = [canary, perplexity, thinking_mode]
        }

    Attributes:
        must_pass: Gate names that have to pass; each instance gets its own list.
    """

    must_pass: list[str] = field(default_factory=list)
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
@dataclass
class BudgetBlock:
    """Resource limits; the compiler refuses plans that exceed them.

    Example:
        budget {
            max_gpu_hours = 8
            max_cost = 50.00
        }

    Attributes:
        max_gpu_hours: Limit on GPU hours; None when not set.
        max_cost: Limit on cost; None when not set.
        max_tokens: Limit on tokens; None when not set.
        max_experiments: Limit on experiments; None when not set.
    """

    max_gpu_hours: Optional[float] = None
    max_cost: Optional[float] = None
    max_tokens: Optional[int] = None
    max_experiments: Optional[int] = None
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
@dataclass
class DataContractBlock:
    """Schema rules for training data. (Phase 4, ForgeSpec 2.0)

    Example:
        data_contract {
            required_fields = [prompt, response]
            min_samples = 100
            max_perplexity = 50.0
        }

    Compiler checks training data at synth/train time.

    Attributes:
        required_fields: Field names listed under `required_fields`.
        min_samples: Value of `min_samples`; None when not set.
        max_perplexity: Value of `max_perplexity`; None when not set.
    """

    required_fields: list[str] = field(default_factory=list)
    min_samples: Optional[int] = None
    max_perplexity: Optional[float] = None
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
@dataclass
class RewardContractBlock:
    """Verified reward definitions: what counts as "correct". (Phase 4, ForgeSpec 2.0)

    Example:
        reward_contract {
            verifiers = [code_compiles, math_correct, no_hallucination]
            min_reward = 0.3
        }

    Used by train (GRPO) to enforce reward quality.
    No learned reward model — verified rewards only (test_16).

    Attributes:
        verifiers: Verifier names listed under `verifiers`.
        min_reward: Value of `min_reward`; None when not set.
    """

    verifiers: list[str] = field(default_factory=list)
    min_reward: Optional[float] = None
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
# ============================================================================
|
| 375 |
+
# TOP-LEVEL PROGRAM
|
| 376 |
+
# ============================================================================
|
| 377 |
+
|
| 378 |
+
@dataclass
class TDProgram:
    """One fully parsed .td file: ordered commands plus the global blocks.

    Attributes:
        commands: Parsed commands in order; each instance gets its own list.
        gates: The gate block, if any.
        budget: The budget block, if any.
        data_contract: The data_contract block, if any.
        reward_contract: The reward_contract block, if any.
        setup: The setup block, if any.
        on_error: The on_error block, if any.
        source_file: Optional source file of the program.
    """

    commands: List[Any] = field(default_factory=list)
    gates: Optional[GateBlock] = None
    budget: Optional[BudgetBlock] = None
    data_contract: Optional[DataContractBlock] = None
    reward_contract: Optional[RewardContractBlock] = None
    setup: Optional[SetupBlock] = None
    on_error: Optional[OnErrorBlock] = None
    source_file: Optional[str] = None
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
# Public names of this module, in definition order.
__all__ = [
    # commands
    "LoadCmd",
    "MergeCmd",
    "HealCmd",
    "EvalCmd",
    "CommitCmd",
    "SynthCmd",
    "TrainCmd",
    "DebateCmd",
    "DiagnoseCmd",
    "ForkCmd",
    "ResetCmd",
    "PruneCmd",
    "EditCmd",
    "RepeatBlock",
    "IfBlock",
    "FuseCmd",
    "AbsorbCmd",
    "SnapshotCmd",
    "ReportCmd",
    "NotifyCmd",
    "SaveCmd",
    # global blocks
    "SetupBlock",
    "OnErrorBlock",
    "GateBlock",
    "BudgetBlock",
    "DataContractBlock",
    "RewardContractBlock",
    # whole program
    "TDProgram",
]
|
hugging/td_lang/cli.py
ADDED
|
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
TD Lang CLI — Command-line interface for .td files.
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
python -m td_lang run examples/demo_merge.td # Compile + execute
|
| 6 |
+
python -m td_lang compile examples/demo_merge.td # Compile only (outputs .py)
|
| 7 |
+
python -m td_lang check examples/demo_merge.td # Syntax check only
|
| 8 |
+
python -m td_lang info examples/demo_merge.td # Show plan without compiling
|
| 9 |
+
python -m td_lang --version # Show version
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import argparse
|
| 13 |
+
import sys
|
| 14 |
+
|
| 15 |
+
from . import __version__
|
| 16 |
+
from .executor import TDExecutor
|
| 17 |
+
from .errors import TDLangError
|
| 18 |
+
from .grammar import parse_td_file
|
| 19 |
+
from .ast_nodes import (
|
| 20 |
+
LoadCmd, MergeCmd, HealCmd, EvalCmd, CommitCmd,
|
| 21 |
+
SynthCmd, TrainCmd, DebateCmd, DiagnoseCmd,
|
| 22 |
+
ForkCmd, ResetCmd, PruneCmd, EditCmd,
|
| 23 |
+
FuseCmd, AbsorbCmd, RepeatBlock, IfBlock,
|
| 24 |
+
NotifyCmd, SaveCmd,
|
| 25 |
+
SnapshotCmd, ReportCmd,
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# Phase labels for info command
|
| 30 |
+
# Maps each AST command type to the (phase, label) pair that print_info shows
# as "[P<phase>] <label>". Only ever used for lookups, so entry order is
# irrelevant; entries are grouped by phase number for readability.
_PHASE_MAP = {
    cmd_type: (phase, label)
    for phase, members in (
        ("1", (
            (LoadCmd, "load"),
            (MergeCmd, "merge"),
            (HealCmd, "heal"),
            (EvalCmd, "eval"),
            (CommitCmd, "commit"),
        )),
        ("2", (
            (SynthCmd, "synth"),
            (TrainCmd, "train"),
            (DebateCmd, "debate"),
            (DiagnoseCmd, "diagnose"),
        )),
        ("3", (
            (ForkCmd, "fork"),
            (ResetCmd, "reset"),
            (PruneCmd, "prune"),
            (EditCmd, "edit"),
        )),
        ("4", (
            (SnapshotCmd, "snapshot"),
            (ReportCmd, "report"),
        )),
        ("6", (
            (FuseCmd, "fuse"),
            (AbsorbCmd, "absorb"),
        )),
        ("7", (
            (RepeatBlock, "repeat"),
            (IfBlock, "if"),
        )),
        ("8", (
            (NotifyCmd, "notify"),
            (SaveCmd, "save"),
        )),
    )
    for cmd_type, label in members
}
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def parse_args() -> argparse.Namespace:
|
| 56 |
+
"""Parse command-line arguments."""
|
| 57 |
+
parser = argparse.ArgumentParser(
|
| 58 |
+
description="TD Lang — compile and run .td files for Time Dilation",
|
| 59 |
+
formatter_class=argparse.RawDescriptionHelpFormatter,
|
| 60 |
+
epilog="""
|
| 61 |
+
Examples:
|
| 62 |
+
python -m td_lang check examples/demo_merge.td # Check syntax
|
| 63 |
+
python -m td_lang compile examples/demo_merge.td # Compile to .py
|
| 64 |
+
python -m td_lang run examples/demo_merge.td # Compile + run
|
| 65 |
+
python -m td_lang run examples/demo_merge.td --dry # Compile only
|
| 66 |
+
python -m td_lang info examples/demo_merge.td # Show plan summary
|
| 67 |
+
""",
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
parser.add_argument(
|
| 71 |
+
"--version",
|
| 72 |
+
action="version",
|
| 73 |
+
version=f"td_lang {__version__}",
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
parser.add_argument(
|
| 77 |
+
"action",
|
| 78 |
+
choices=["check", "compile", "run", "info"],
|
| 79 |
+
help="What to do: check (syntax), compile (.py), run (compile+execute), info (show plan)",
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
parser.add_argument(
|
| 83 |
+
"file",
|
| 84 |
+
type=str,
|
| 85 |
+
help="Path to the .td file",
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
parser.add_argument(
|
| 89 |
+
"--output",
|
| 90 |
+
type=str,
|
| 91 |
+
default="td_lang_outputs",
|
| 92 |
+
help="Output directory (default: td_lang_outputs)",
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
parser.add_argument(
|
| 96 |
+
"--dry",
|
| 97 |
+
action="store_true",
|
| 98 |
+
help="With 'run': compile but don't execute",
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
parser.add_argument(
|
| 102 |
+
"--verbose", "-v",
|
| 103 |
+
action="store_true",
|
| 104 |
+
help="Show extra detail (compiled Python, full AST, etc.)",
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
return parser.parse_args()
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def print_banner() -> None:
    """Print the td_lang banner.

    The banner is a fixed box-drawing logo with the package ``__version__``
    interpolated into its caption line, written to stdout with ``print``.
    """
    # NOTE(review): the caption line is not padded, so the right-hand border
    # moves with the length of __version__.
    banner = f"""
╔═══════════════════════════════════════╗
║ ║
║ ████████╗██████╗ ██╗ ██████╗║
║ ╚══██╔══╝██╔══██╗ ██║ ██╔════╝║
║ ██║ ██║ ██║ ██║ ██║ ███║
║ ██║ ██║ ██║ ██║ ██║ ██║
║ ██║ ██████╔╝ ██████╗ ╚██████╔╝║
║ ╚═╝ ╚═════╝ ╚═════╝ ╚═════╝║
║ ║
║ TD Lang v{__version__} — .td file compiler ║
║ ║
╚═══════════════════════════════════════╝
"""
    print(banner)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def print_info(filepath: str) -> None:
    """Show what a .td file does without compiling — human-readable plan summary.

    Parses the file, then prints the file path, the number of top-level
    commands, any gate / budget / contract blocks that are present, and one
    numbered line per top-level command.

    Args:
        filepath: Path to the .td file to summarize.

    Raises:
        Whatever ``parse_td_file`` raises for an unreadable or invalid file;
        nothing is caught here.
    """
    program = parse_td_file(filepath)

    print(f"\n File: {filepath}")
    print(f" Commands: {len(program.commands)}")

    if program.gates:
        print(f" Gates: {', '.join(program.gates.must_pass)}")
    if program.budget:
        limits = []
        if program.budget.max_gpu_hours is not None:
            limits.append(f"{program.budget.max_gpu_hours} GPU hrs")
        if program.budget.max_cost is not None:
            limits.append(f"${program.budget.max_cost}")
        # A budget block that sets neither of these limits used to produce a
        # dangling "Budget:" line with nothing after it.
        if limits:
            print(f" Budget: {', '.join(limits)}")
    if program.data_contract:
        print(f" Data contract: fields={program.data_contract.required_fields}")
    if program.reward_contract:
        print(f" Reward contract: verifiers={program.reward_contract.verifiers}")

    print("\n Plan:")
    for i, cmd in enumerate(program.commands, 1):
        # Command types missing from the phase table are shown as phase "?"
        # under their class name.
        phase, name = _PHASE_MAP.get(type(cmd), ("?", type(cmd).__name__))
        # Prefer the command's target; fall back to its alias, then to "".
        target = getattr(cmd, 'target', getattr(cmd, 'alias', ''))
        detail = ""
        if hasattr(cmd, 'method'):
            detail += f" method={cmd.method}"
        # Only merge and synth lines show where their input comes from.
        if hasattr(cmd, 'source') and name in ("merge", "synth"):
            detail += f" from={cmd.source}"
        if hasattr(cmd, 'layers') and cmd.layers != "all":
            detail += f" layers={cmd.layers}"
        if hasattr(cmd, 'output') and cmd.output:
            detail += f" -> {cmd.output}"
        print(f" {i}. [P{phase}] {name} {target}{detail}")

    print()
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def main() -> None:
    """Main entry point for td_lang CLI.

    Parses the command line, prints the banner, and dispatches on the chosen
    action (``info``, ``check``, ``compile`` or ``run``).

    Exit codes:
        0   ``run`` finished with status ``"success"`` or ``"dry_run"``.
        1   ``run`` finished with any other status, a ``TDLangError`` was
            raised, or a file was not found.
        130 interrupted with Ctrl-C.

    The ``info``, ``check`` and ``compile`` actions return normally on
    success instead of calling ``sys.exit``.
    """
    args = parse_args()
    print_banner()

    executor = TDExecutor(output_dir=args.output)

    try:
        if args.action == "info":
            print_info(args.file)

        elif args.action == "check":
            # Only success/failure matters here; the returned value is unused.
            executor.check(args.file)
            print("\n[td_lang] File is valid!")

        elif args.action == "compile":
            py_path = executor.compile(args.file)
            print(f"\n[td_lang] Generated: {py_path}")
            print("[td_lang] You can run it with: python", py_path)
            if args.verbose:
                print("\n--- Generated Python ---")
                print(py_path.read_text())
                print("--- End ---")

        elif args.action == "run":
            result = executor.run(args.file, dry_run=args.dry)
            # A dry run counts as success; every other status is a failure.
            sys.exit(0 if result["status"] in ("success", "dry_run") else 1)

    except TDLangError as e:
        print(f"\n[td_lang] ERROR: {e}")
        sys.exit(1)

    except FileNotFoundError as e:
        # The missing path is not necessarily the .td file itself, so report
        # the path carried by the exception when it has one.
        missing = e.filename if e.filename is not None else args.file
        print(f"\n[td_lang] ERROR: File not found: {missing}")
        print("[td_lang] Check the path and try again.")
        sys.exit(1)

    except KeyboardInterrupt:
        print("\n[td_lang] Interrupted.")
        sys.exit(130)
|
hugging/td_lang/compiler.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
hugging/td_lang/errors.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
TD Lang Errors — Clear, helpful error messages.
|
| 3 |
+
|
| 4 |
+
Milan is 11 — errors should say what went wrong and where,
|
| 5 |
+
not dump cryptic stack traces.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class TDLangError(Exception):
|
| 10 |
+
"""Base error for all td_lang errors."""
|
| 11 |
+
|
| 12 |
+
def __init__(self, message: str, line: int | None = None, hint: str | None = None):
|
| 13 |
+
self.line = line
|
| 14 |
+
self.hint = hint
|
| 15 |
+
if line is not None:
|
| 16 |
+
full = f"Line {line}: {message}"
|
| 17 |
+
else:
|
| 18 |
+
full = message
|
| 19 |
+
if hint:
|
| 20 |
+
full += f"\n Hint: {hint}"
|
| 21 |
+
super().__init__(full)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class TDSyntaxError(TDLangError):
    """Raised for bad .td syntax, i.e. the file could not be understood."""
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class TDCompileError(TDLangError):
    """Raised for valid syntax that describes an impossible plan.

    Example: merging into a model that doesn't exist.
    """
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class TDGateError(TDLangError):
    """Raised when gates fail during execution.

    ``failed_gates`` keeps the gate names that failed. When no explicit
    message is supplied, one listing those names is generated.
    """

    def __init__(self, failed_gates: list[str], message: str = ""):
        self.failed_gates = failed_gates
        if not message:
            message = f"Gates failed: {', '.join(failed_gates)}"
        super().__init__(message, hint="Check eval results — the model may have regressed.")
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class TDBudgetError(TDLangError):
    """Raised when a plan's estimated need exceeds a budget limit.

    ``field`` names the budget entry, ``limit`` is its configured ceiling and
    ``requested`` is the amount the plan is estimated to need.
    """

    def __init__(self, field: str, limit: float, requested: float):
        self.field, self.limit, self.requested = field, limit, requested
        summary = f"Budget exceeded: {field} limit is {limit}, but plan needs ~{requested}"
        super().__init__(
            summary,
            hint="Reduce steps, use fewer merges, or increase the budget.",
        )
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class TDContractError(TDLangError):
    """Raised on a data or reward contract violation.

    The message reports how many violations there were, quotes the first one,
    and notes how many more were left out. The full list stays available as
    ``violations``.
    """

    def __init__(self, contract_type: str, violations: list[str]):
        self.contract_type = contract_type
        self.violations = violations
        count = len(violations)
        pieces = [f"{contract_type} contract failed with {count} violation(s)"]
        if violations:
            # Only the first violation is quoted to keep the message short.
            pieces.append(f": {violations[0]}")
        if count > 1:
            pieces.append(f" (and {count - 1} more)")
        super().__init__(
            "".join(pieces),
            hint="Check your training data matches the contract spec.",
        )
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
# ============================================================================
|
| 74 |
+
# COMMON MISTAKE SUGGESTIONS (Phase 5)
|
| 75 |
+
# ============================================================================
|
| 76 |
+
|
| 77 |
+
COMMON_FIXES = {
|
| 78 |
+
"load": 'Did you forget quotes? Correct: load "model/path" as name',
|
| 79 |
+
"merge": 'Format: merge "source" into target using method [strength 0.5]',
|
| 80 |
+
"edit": "Format: edit target layers 16-28 using lora [lr 1e-4]",
|
| 81 |
+
"prune": "Format: prune target using wanda [aggressiveness 0.2]",
|
| 82 |
+
"fork": "Format: fork source as new_name",
|
| 83 |
+
"reset": 'Format: reset target to "checkpoint_path"',
|
| 84 |
+
"train": 'Format: train target on "dataset" using grpo [steps 64]',
|
| 85 |
+
"synth": "Format: synth target from source [filter cherry_llm]",
|
| 86 |
+
"snapshot": "Format: snapshot target [-> output_dir]",
|
| 87 |
+
"report": "Format: report [-> economics.json]",
|
| 88 |
+
"fuse": 'Format: fuse ["model1", "model2"] into target [strategy equal]',
|
| 89 |
+
"absorb": 'Format: absorb "model" into target [strength 0.5]',
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def suggest_fix(token: str) -> str | None:
|
| 94 |
+
"""Given a failed token, suggest the correct syntax."""
|
| 95 |
+
token_lower = token.lower().strip()
|
| 96 |
+
for keyword, fix in COMMON_FIXES.items():
|
| 97 |
+
if keyword in token_lower:
|
| 98 |
+
return fix
|
| 99 |
+
return None
|
hugging/td_lang/examples/demo_autopilot.td
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# demo_autopilot.td — The full "rent a GPU and go" pipeline
|
| 2 |
+
# Rent vast.ai, upload this file, run: python -m td_lang run demo_autopilot.td
|
| 3 |
+
# Then sit back — you'll get ntfy notifications on your phone.
|
| 4 |
+
|
| 5 |
+
# === ENVIRONMENT ===
|
| 6 |
+
setup {
|
| 7 |
+
pip = [torch, transformers, peft, bitsandbytes, trl, safetensors, datasets, accelerate, huggingface_hub, sentencepiece]
|
| 8 |
+
hf_token = env
|
| 9 |
+
notify = "ntfy.sh/my_ai"
|
| 10 |
+
}
|
| 11 |
+
|
| 12 |
+
on_error {
|
| 13 |
+
retry = 3
|
| 14 |
+
fallback = reduce_batch
|
| 15 |
+
notify = true
|
| 16 |
+
}
|
| 17 |
+
|
| 18 |
+
# === QUALITY RULES ===
|
| 19 |
+
gate { must_pass = [canary, perplexity, thinking_mode] }
|
| 20 |
+
budget { max_gpu_hours = 40 max_cost = 160.00 }
|
| 21 |
+
|
| 22 |
+
data_contract {
|
| 23 |
+
required_fields = [prompt, response]
|
| 24 |
+
min_samples = 50
|
| 25 |
+
max_perplexity = 50.0
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
reward_contract {
|
| 29 |
+
verifiers = [code_compiles, math_correct]
|
| 30 |
+
min_reward = 0.3
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
# === PIPELINE ===
|
| 34 |
+
|
| 35 |
+
# Step 1: Load and fuse
|
| 36 |
+
load "Qwen/Qwen3-VL-8B-Instruct" as base
|
| 37 |
+
fuse ["deepseek-ai/DeepSeek-R1", "MiMo-7B", "meta-llama/Llama-3.1-8B", "tiiuae/Falcon-H1R-7B"] into base
|
| 38 |
+
heal base lora_r 32 epochs 2
|
| 39 |
+
notify "Merge + heal complete. Starting self-improvement loop."
|
| 40 |
+
|
| 41 |
+
# Step 2: Self-improvement loop
|
| 42 |
+
repeat 5 {
|
| 43 |
+
diagnose base -> weaknesses.json
|
| 44 |
+
synth base from base filter cherry_llm -> training_data.jsonl
|
| 45 |
+
train base on "training_data.jsonl" using grpo steps 64 lr 5e-5
|
| 46 |
+
eval base -> eval_results.json
|
| 47 |
+
|
| 48 |
+
if eval_passed base {
|
| 49 |
+
commit base
|
| 50 |
+
snapshot base -> snapshots/
|
| 51 |
+
notify "Loop iteration passed! Model improved."
|
| 52 |
+
} else {
|
| 53 |
+
reset base to "snapshots/"
|
| 54 |
+
notify "Loop iteration failed. Reset to last good snapshot."
|
| 55 |
+
}
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
# Step 3: Save and notify
|
| 59 |
+
snapshot base -> final_model/
|
| 60 |
+
save base to "gdrive:TD/models/final"
|
| 61 |
+
report -> economics.json
|
| 62 |
+
notify "TD PIPELINE COMPLETE. Model saved to Google Drive."
|
hugging/td_lang/examples/demo_full.td
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Full Phase 1 demo with gates and budget
|
| 2 |
+
gate {
|
| 3 |
+
must_pass = [canary, perplexity, thinking_mode]
|
| 4 |
+
}
|
| 5 |
+
|
| 6 |
+
budget {
|
| 7 |
+
max_gpu_hours = 8
|
| 8 |
+
max_cost = 50.00
|
| 9 |
+
max_tokens = 20000000
|
| 10 |
+
max_experiments = 4
|
| 11 |
+
}
|
| 12 |
+
|
| 13 |
+
load "Qwen/Qwen3-VL-8B-Instruct" as base
|
| 14 |
+
merge "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B" into base using transport strength 0.5
|
| 15 |
+
heal base lora_r 32 epochs 2
|
| 16 |
+
eval base -> full_eval.json
|
| 17 |
+
commit base
|
hugging/td_lang/examples/demo_fuse.td
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# demo_fuse.td — Easy merge: fuse multiple models in one command
|
| 2 |
+
# The entire TD merge strategy in a handful of lines
|
| 3 |
+
|
| 4 |
+
gate { must_pass = [canary, perplexity, thinking_mode] }
|
| 5 |
+
budget { max_gpu_hours = 30 max_cost = 120.00 }
|
| 6 |
+
|
| 7 |
+
load "Qwen/Qwen3-VL-8B-Instruct" as base
|
| 8 |
+
|
| 9 |
+
# Fuse all 4 donor models in one shot — auto Transport and Merge
|
| 10 |
+
fuse ["deepseek-ai/DeepSeek-R1", "MiMo-7B", "meta-llama/Llama-3.1-8B", "tiiuae/Falcon-H1R-7B"] into base
|
| 11 |
+
|
| 12 |
+
# Or absorb a single model with custom strength
|
| 13 |
+
# absorb "deepseek-ai/DeepSeek-R1" into base strength 0.6
|
| 14 |
+
|
| 15 |
+
heal base lora_r 32 epochs 2
|
| 16 |
+
eval base -> post_fuse_eval.json
|
| 17 |
+
commit base if [canary, perplexity, thinking_mode]
|
| 18 |
+
snapshot base -> snapshots/
|
| 19 |
+
report -> economics.json
|
hugging/td_lang/examples/demo_heal.td
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Demo: merge then heal, evaluate, and commit with gates
|
| 2 |
+
load "Qwen/Qwen3-VL-8B-Instruct" as base
|
| 3 |
+
merge "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B" into base using transport strength 0.5
|
| 4 |
+
heal base lora_r 32 epochs 2
|
| 5 |
+
eval base -> report.json
|
| 6 |
+
commit base if [canary, perplexity, thinking_mode]
|
hugging/td_lang/examples/demo_loop.td
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# demo_loop.td — Self-improvement loop (Phase 2)
|
| 2 |
+
# The core TD cycle: diagnose -> synth -> train -> evaluate -> commit
|
| 3 |
+
|
| 4 |
+
gate {
|
| 5 |
+
must_pass = [canary, perplexity, thinking_mode]
|
| 6 |
+
}
|
| 7 |
+
|
| 8 |
+
budget {
|
| 9 |
+
max_gpu_hours = 10
|
| 10 |
+
max_cost = 40.00
|
| 11 |
+
}
|
| 12 |
+
|
| 13 |
+
load "Qwen/Qwen3-VL-8B-Instruct" as base
|
| 14 |
+
|
| 15 |
+
# Step 1: Ask the model what it's bad at
|
| 16 |
+
diagnose base -> weaknesses.json
|
| 17 |
+
|
| 18 |
+
# Step 2: Generate training data targeting those weaknesses
|
| 19 |
+
synth base from web_curated filter cherry_llm -> synth_data.jsonl
|
| 20 |
+
|
| 21 |
+
# Step 3: Train with GRPO (64 steps = sweet spot from test_15)
|
| 22 |
+
train base on "synth_data.jsonl" using grpo steps 64
|
| 23 |
+
|
| 24 |
+
# Step 4: Check if it actually got better
|
| 25 |
+
eval base -> post_training_eval.json
|
| 26 |
+
|
| 27 |
+
# Step 5: Only save if gates pass
|
| 28 |
+
commit base
|
hugging/td_lang/examples/demo_merge.td
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Demo: load + merge + eval + commit
|
| 2 |
+
load "Qwen/Qwen3-VL-8B-Instruct" as base
|
| 3 |
+
merge "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B" into base using transport strength 0.5
|
| 4 |
+
eval base -> eval_base.json
|
| 5 |
+
commit base if [canary, perplexity, thinking_mode]
|
hugging/td_lang/examples/demo_phase3.td
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# demo_phase3.td — Phase 3 commands: edit, fork, reset, prune
|
| 2 |
+
# The full surgical toolkit for model experimentation
|
| 3 |
+
|
| 4 |
+
gate {
|
| 5 |
+
must_pass = [canary, perplexity, thinking_mode]
|
| 6 |
+
}
|
| 7 |
+
|
| 8 |
+
budget {
|
| 9 |
+
max_gpu_hours = 12
|
| 10 |
+
max_cost = 60.00
|
| 11 |
+
}
|
| 12 |
+
|
| 13 |
+
# Load the base model
|
| 14 |
+
load "Qwen/Qwen3-VL-8B-Instruct" as base
|
| 15 |
+
|
| 16 |
+
# Fork before experimenting (like git branch)
|
| 17 |
+
fork base as experiment
|
| 18 |
+
|
| 19 |
+
# Surgical edit: LoRA on reasoning layers 16-28
|
| 20 |
+
edit experiment layers 16-28 using lora lr 1e-4
|
| 21 |
+
|
| 22 |
+
# Evaluate the edit
|
| 23 |
+
eval experiment -> post_edit_eval.json
|
| 24 |
+
|
| 25 |
+
# If it's good, commit; if bad, we can reset
|
| 26 |
+
commit experiment
|
hugging/td_lang/examples/demo_phase4.td
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# demo_phase4.td — Phase 4: Contracts, Lineage, Economics
|
| 2 |
+
# ForgeSpec 2.0 features from test_17
|
| 3 |
+
|
| 4 |
+
gate { must_pass = [canary, perplexity, thinking_mode] }
|
| 5 |
+
|
| 6 |
+
budget {
|
| 7 |
+
max_gpu_hours = 20
|
| 8 |
+
max_cost = 100.00
|
| 9 |
+
}
|
| 10 |
+
|
| 11 |
+
data_contract {
|
| 12 |
+
required_fields = [prompt, response]
|
| 13 |
+
min_samples = 100
|
| 14 |
+
max_perplexity = 50.0
|
| 15 |
+
}
|
| 16 |
+
|
| 17 |
+
reward_contract {
|
| 18 |
+
verifiers = [code_compiles, math_correct]
|
| 19 |
+
min_reward = 0.3
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
# Pipeline with full tracking
|
| 23 |
+
load "Qwen/Qwen3-VL-8B-Instruct" as base
|
| 24 |
+
fork base as experiment
|
| 25 |
+
|
| 26 |
+
edit experiment layers 16-28 using lora lr 1e-4
|
| 27 |
+
snapshot experiment -> snapshots/
|
| 28 |
+
|
| 29 |
+
eval experiment -> post_edit_eval.json
|
| 30 |
+
commit experiment
|
| 31 |
+
|
| 32 |
+
# Economics report at the end
|
| 33 |
+
report -> economics.json
|
hugging/td_lang/examples/demo_td_loop.td
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# demo_td_loop.td — The complete TD self-improvement pipeline
|
| 2 |
+
# This is what td_loop runs: merge, then iterate to get smarter
|
| 3 |
+
|
| 4 |
+
gate { must_pass = [canary, perplexity, thinking_mode] }
|
| 5 |
+
budget { max_gpu_hours = 50 max_cost = 200.00 }
|
| 6 |
+
|
| 7 |
+
data_contract {
|
| 8 |
+
required_fields = [prompt, response]
|
| 9 |
+
min_samples = 50
|
| 10 |
+
max_perplexity = 50.0
|
| 11 |
+
}
|
| 12 |
+
|
| 13 |
+
reward_contract {
|
| 14 |
+
verifiers = [code_compiles, math_correct]
|
| 15 |
+
min_reward = 0.3
|
| 16 |
+
}
|
| 17 |
+
|
| 18 |
+
# Step 1: Load base model
|
| 19 |
+
load "Qwen/Qwen3-VL-8B-Instruct" as base
|
| 20 |
+
|
| 21 |
+
# Step 2: Fuse all donor models in one shot
|
| 22 |
+
fuse ["deepseek-ai/DeepSeek-R1", "MiMo-7B", "meta-llama/Llama-3.1-8B", "tiiuae/Falcon-H1R-7B"] into base
|
| 23 |
+
|
| 24 |
+
# Step 3: Heal the merge damage
|
| 25 |
+
heal base lora_r 32 epochs 2
|
| 26 |
+
snapshot base -> snapshots/
|
| 27 |
+
|
| 28 |
+
# Step 4: Self-improvement loop (the core of TD)
|
| 29 |
+
repeat 5 {
|
| 30 |
+
diagnose base -> weaknesses.json
|
| 31 |
+
synth base from base filter cherry_llm -> training_data.jsonl
|
| 32 |
+
train base on "training_data.jsonl" using grpo steps 64 lr 5e-5
|
| 33 |
+
eval base -> eval_results.json
|
| 34 |
+
|
| 35 |
+
if eval_passed base {
|
| 36 |
+
commit base
|
| 37 |
+
snapshot base -> snapshots/
|
| 38 |
+
} else {
|
| 39 |
+
reset base to "snapshots/"
|
| 40 |
+
}
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
# Step 5: Final report
|
| 44 |
+
report -> final_economics.json
|
hugging/td_lang/examples/err_edit_unloaded.td
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# err_edit_unloaded.td — Should fail: editing a model before loading
|
| 2 |
+
edit ghost_model layers all using lora
|
hugging/td_lang/examples/err_fork_duplicate.td
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# err_fork_duplicate.td — Should fail: duplicate name
|
| 2 |
+
load "test" as base
|
| 3 |
+
fork base as base
|
hugging/td_lang/examples/err_prune_100.td
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# err_prune_100.td — Should fail/warn: prune at 100%
|
| 2 |
+
load "test" as base
|
| 3 |
+
prune base using wanda aggressiveness 1.0
|
| 4 |
+
# Note: Compiler might cap it at 30% per implementation notes
|
hugging/td_lang/examples/test_fork_edit.td
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# test_fork_edit.td — Test load -> fork -> edit -> eval -> commit
|
| 2 |
+
|
| 3 |
+
load "Qwen/Qwen3-VL-8B-Instruct" as base
|
| 4 |
+
|
| 5 |
+
# Fork the base model
|
| 6 |
+
fork base as experimental_branch
|
| 7 |
+
|
| 8 |
+
# Surgical edit with DoRA on specific layers
|
| 9 |
+
edit experimental_branch layers 20-28 using dora lr 1e-4
|
| 10 |
+
|
| 11 |
+
eval experimental_branch -> edit_report.json
|
| 12 |
+
commit experimental_branch
|