diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..6d80bfb51a3750dcd89c509a2fef6a3d910f4308 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +hugging/td_lang/__pycache__/compiler.cpython-314.pyc filter=lfs diff=lfs merge=lfs -text diff --git a/hugging/QUICKSTART.md b/hugging/QUICKSTART.md new file mode 100644 index 0000000000000000000000000000000000000000..a06dea77c82472e70dabfb22d79ed32544d4e912 --- /dev/null +++ b/hugging/QUICKSTART.md @@ -0,0 +1,106 @@ +# TD Quick Start — Rent a GPU and Go + +## What You Need (One-Time Setup) + +1. **vast.ai account** — sign up at vast.ai, add credit ($10-20 to start) +2. **HuggingFace account** — sign up at huggingface.co (use any username, doesn't have to be your real name) +3. **HuggingFace token** — Settings → Access Tokens → New Token → **Write** access +4. **ntfy.sh app** on your phone (you already have this) + +## One-Time: Upload Your Code to Private HuggingFace + +Do this once from your computer. After this, your code lives in a private repo that only you can see. + +```bash +# Install the tool +pip install huggingface_hub + +# Log in (paste your token when asked) +huggingface-cli login + +# Upload everything +HF_USER=your_hf_username bash upload_to_hf.sh +``` + +Now your td_lang, td_fuse, .td files, and deploy script are all in a private HuggingFace repo. Nobody can see them except you. + +**When you update your code**, just run `upload_to_hf.sh` again — it overwrites with the latest version. + +## Every Time: Rent GPU → 3 Commands → Done + +### 1. Rent a GPU on vast.ai + +Go to vast.ai → Console → Search for: +- **GPU:** RTX 4090 (24GB) or A100 (40GB+) +- **Image:** Pick one with PyTorch pre-installed (like `pytorch/pytorch`) +- **Storage:** At least 100GB disk +- **Cost:** ~$0.40-0.80/hr for a 4090 + +Click **RENT** and wait for it to start (~1-2 minutes). + +### 2. Connect to the GPU + +vast.ai gives you an SSH command. Copy and paste it into your terminal: +``` +ssh -p 12345 root@ssh1.vast.ai +``` + +### 3. Run these 3 commands + +```bash +# Set your token +export HF_TOKEN=hf_your_token_here + +# Download your code from HuggingFace (takes ~10 seconds) +pip install huggingface_hub -q && python -c " +from huggingface_hub import snapshot_download +snapshot_download('YOUR_USERNAME/td-toolkit', local_dir='/workspace/td') +" + +# Go! +cd /workspace/td && bash deploy.sh demo_autopilot.td +``` + +That's it. Put your phone down. ntfy.sh sends you updates as it runs. + +### 4. When it's done + +Your model gets saved to Google Drive automatically (if rclone is configured in the .td file). Otherwise it stays on the GPU at `final_model/`. + +## Setting Up Google Drive (Optional, One-Time per GPU) + +On the GPU machine after SSHing in: +```bash +rclone config +``` +1. Type `n` for new remote +2. Name it `gdrive` +3. Pick `Google Drive` from the list +4. Follow the prompts (it gives you a URL to visit in your browser) +5. Done — now `save base to "gdrive:TD/models/final"` works in your .td files + +**Tip:** You can save the rclone config to your HuggingFace repo too, so you don't have to set it up every time. + +## Quick Reference + +| Command | What it does | +|---------|-------------| +| `bash deploy.sh my_file.td` | Full setup + run | +| `python -m td_lang check my_file.td` | Check syntax only | +| `python -m td_lang info my_file.td` | Show plan without running | +| `python -m td_lang run my_file.td` | Run (skip deploy setup) | +| `python -m td_lang run my_file.td --dry` | Compile but don't execute | + +## If Something Goes Wrong + +- **OOM (out of memory):** Your .td file's `on_error` block handles this — it retries with smaller batches +- **Model download fails:** Check your HF_TOKEN is set correctly +- **ntfy not working:** Check your phone has the ntfy app and you're subscribed to the right topic +- **GPU disconnects:** Re-SSH in, your files are still there. Run deploy.sh again — td_lang picks up from the last snapshot + +## Cost Estimate + +For the full `demo_autopilot.td` pipeline (merge 4 models + 5 training loops): +- **RTX 4090:** ~$0.50/hr × ~30-40 hrs = ~$15-20 +- **A100 40GB:** ~$1.00/hr × ~20-30 hrs = ~$20-30 +- **Budget cap in .td file:** Set `max_cost = 160.00` to prevent runaway costs diff --git a/hugging/deploy.sh b/hugging/deploy.sh new file mode 100644 index 0000000000000000000000000000000000000000..0b5d43fa25713f45433cacfb9ba2bf7435f45205 --- /dev/null +++ b/hugging/deploy.sh @@ -0,0 +1,128 @@ +#!/bin/bash +# deploy.sh — One-command setup for vast.ai GPU instances +# +# TWO ways to use this: +# +# Option A — Download from your private HuggingFace repo + run: +# export HF_TOKEN=your_token +# pip install huggingface_hub +# python -c "from huggingface_hub import snapshot_download; snapshot_download('YOUR_USER/td-toolkit', local_dir='.')" +# bash deploy.sh demo_autopilot.td +# +# Option B — Already uploaded files manually: +# bash deploy.sh my_pipeline.td + +set -e # Stop on any error + +# Colors for pretty output +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +RED='\033[0;31m' +NC='\033[0m' # No Color + +echo "" +echo "===========================================" +echo " TD Deploy — vast.ai GPU Setup" +echo "===========================================" +echo "" + +# Check if a .td file was provided +if [ -z "$1" ]; then + echo -e "${RED}ERROR: No .td file specified${NC}" + echo "" + echo "Usage: bash deploy.sh my_pipeline.td" + echo "" + echo "Available .td files:" + ls -1 *.td td_lang/examples/*.td 2>/dev/null || echo " (none found)" + exit 1 +fi + +TD_FILE="$1" + +if [ ! -f "$TD_FILE" ]; then + echo -e "${RED}ERROR: File not found: $TD_FILE${NC}" + exit 1 +fi + +echo -e "${GREEN}[1/5]${NC} Installing td_lang dependencies..." +pip install lark --quiet 2>/dev/null || pip install lark +echo " Done." + +# Check for HF token +echo "" +echo -e "${GREEN}[2/5]${NC} Checking environment..." +if [ -z "$HF_TOKEN" ]; then + echo -e "${YELLOW} WARNING: HF_TOKEN not set.${NC}" + echo " Models won't download from HuggingFace without it." + echo " Set it with: export HF_TOKEN=your_token_here" + echo "" + read -p " Continue anyway? (y/n) " -n 1 -r + echo + if [[ ! $REPLY =~ ^[Yy]$ ]]; then + exit 1 + fi +else + echo " HF_TOKEN: set" +fi + +# Check td_lang is accessible +echo "" +echo -e "${GREEN}[3/5]${NC} Checking td_lang..." +if python -c "import td_lang" 2>/dev/null; then + VERSION=$(python -c "import td_lang; print(td_lang.__version__)" 2>/dev/null || echo "unknown") + echo " td_lang v$VERSION: found" +else + # Try adding current directory to path + export PYTHONPATH="${PYTHONPATH:+$PYTHONPATH:}$(pwd)" + if python -c "import td_lang" 2>/dev/null; then + VERSION=$(python -c "import td_lang; print(td_lang.__version__)" 2>/dev/null || echo "unknown") + echo " td_lang v$VERSION: found (added to PYTHONPATH)" + else + echo -e "${RED} ERROR: td_lang not found!${NC}" + echo " Make sure the td_lang/ folder is in the current directory." + echo " Current directory: $(pwd)" + echo " Contents:" + ls -1 + exit 1 + fi +fi + +# Check for rclone (needed for save command) +echo "" +echo -e "${GREEN}[4/5]${NC} Checking tools..." +if command -v rclone &> /dev/null; then + echo " rclone: installed" + if rclone listremotes 2>/dev/null | grep -q "gdrive:"; then + echo " Google Drive: configured" + else + echo -e "${YELLOW} Google Drive: not configured${NC}" + echo " Run 'rclone config' to set up Google Drive (name it 'gdrive')" + fi +else + echo -e "${YELLOW} rclone: not installed (installing...)${NC}" + curl -s https://rclone.org/install.sh | bash 2>/dev/null || { + echo -e "${YELLOW} Could not install rclone. 'save' commands won't work.${NC}" + } +fi + +# Check GPU +if command -v nvidia-smi &> /dev/null; then + GPU_NAME=$(nvidia-smi --query-gpu=name --format=csv,noheader | head -1) + GPU_MEM=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader | head -1) + echo " GPU: $GPU_NAME ($GPU_MEM)" +else + echo -e "${YELLOW} WARNING: No GPU detected (nvidia-smi not found)${NC}" +fi + +# Run the .td file +echo "" +echo -e "${GREEN}[5/5]${NC} Running: $TD_FILE" +echo "===========================================" +echo "" + +python -m td_lang run "$TD_FILE" + +echo "" +echo "===========================================" +echo -e "${GREEN} TD Deploy complete!${NC}" +echo "===========================================" diff --git a/hugging/requirements.txt b/hugging/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..4b4e9faccf8545d8f75561782e58f27ff4728b5d --- /dev/null +++ b/hugging/requirements.txt @@ -0,0 +1,226 @@ +# TD Merge Pipeline - Complete Python Dependency List +# Python 3.11-3.12 (3.12 preferred) +# CUDA 12.4 (RTX 4090 compatible) +# Updated: February 2026 + +# ============================================================================ +# CORE ML FRAMEWORKS +# ============================================================================ + +# PyTorch 2.4+ with CUDA 12.4 support (RTX 4090 compatible) +torch==2.4.1 +torchvision==0.19.1 +torchaudio==2.4.1 + +# NVIDIA CUDA Toolkit support (already installed on system) +# CUDA 12.4 for RTX 4090 compatibility +# Note: Install via: pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124 + +# ============================================================================ +# TRANSFORMERS & MODEL LOADING +# ============================================================================ + +# Transformers library - must support Qwen3 (requires 4.51.0+) +transformers==4.51.0 + +# Safetensors for efficient model serialization +safetensors==0.4.5 + +# Accelerate for distributed training & multi-GPU support +accelerate==1.2.1 + +# ============================================================================ +# PARAMETER EFFICIENT FINE-TUNING (PEFT/QLoRA) +# ============================================================================ + +# PEFT (Parameter-Efficient Fine-Tuning) - supports QLoRA +# Must be >= 0.14.0 for 8-bit weight merging +peft==0.14.0 + +# BitsAndBytes for 4-bit quantization (QLoRA) +# Works with PyTorch 2.4, stable with >= 0.42 +bitsandbytes==0.44.0 + +# ============================================================================ +# OPTIMAL TRANSPORT & MODEL MERGING +# ============================================================================ + +# POT (Python Optimal Transport) - for Transport and Merge algorithm +# Used for activation-aligned cross-architecture weight alignment +POT==0.9.6 + +# SciPy for optimization & linear algebra (OrthoMerge, LARV) +scipy==1.14.1 + +# NumPy for numerical operations +numpy==1.26.4 + +# Lark parser for td_lang DSL +lark>=1.1.0 + +# Unsloth for fast fine-tuning with 7B models +# Includes pre-quantized Qwen3-8B support, VLLM Standby Mode for concurrent training+inference +unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git@main + +# ============================================================================ +# REINFORCEMENT LEARNING (RL TRAINING) +# ============================================================================ + +# TRL (Transformers Reinforcement Learning) +# Provides GRPO (Group Relative Policy Optimization) trainer +# v0.27.2 stable, tested with transformers 4.40+ +trl==0.27.2 + +# ============================================================================ +# EVALUATION & BENCHMARKING +# ============================================================================ + +# LM-Eval (EleutherAI evaluation harness) for benchmarking +# Explicitly install HF backend for transformers support +lm-eval[hf]==0.4.10 + +# MathEval utilities +math-eval==0.0.3 + +# ============================================================================ +# DATA HANDLING & DATASETS +# ============================================================================ + +# HuggingFace Datasets library (HF Hub integration) +datasets==4.5.1 + +# PyArrow for efficient data processing +pyarrow==17.0.0 + +# Pandas for data manipulation +pandas==2.2.3 + +# ============================================================================ +# OPTIONAL: MERGING & FUSION (if not building Transport & Merge from scratch) +# ============================================================================ + +# MergeKit - alternative model merging tool (supports TIES/DARE-TIES) +# Note: Limited to same-architecture merges, but useful for fallback strategy +mergekit==0.0.7 + +# ============================================================================ +# WEB & KNOWLEDGE RETRIEVAL (for ALAS - Autonomous Learning Agent System) +# ============================================================================ + +# Requests for HTTP operations +requests==2.31.0 + +# Beautiful Soup for web scraping +beautifulsoup4==4.12.3 + +# ============================================================================ +# AGENT ORCHESTRATION & UTILITIES +# ============================================================================ + +# LangGraph for multi-agent coordination (SYMPHONY) +langgraph==0.2.7 + +# LangChain for prompt management & chains +langchain==0.3.9 + +# Pydantic for data validation +pydantic==2.8.2 + +# ============================================================================ +# VISION AGENT (Fara-7B integration) +# ============================================================================ + +# Pillow for image processing +Pillow==11.2.0 + +# OpenCV for computer vision tasks +opencv-python==4.10.1.26 + +# ============================================================================ +# INFERENCE & SERVING +# ============================================================================ + +# vLLM for fast LLM inference serving +vllm==0.6.4 + +# ============================================================================ +# UTILITIES & LOGGING +# ============================================================================ + +# PyYAML for config files +PyYAML==6.0.2 + +# Python-dotenv for environment variable management +python-dotenv==1.0.1 + +# Tqdm for progress bars +tqdm==4.67.1 + +# Rich for beautiful terminal output +rich==13.8.1 + +# ============================================================================ +# DEVELOPMENT & TESTING (OPTIONAL) +# ============================================================================ + +# Pytest for testing +pytest==8.3.2 + +# IPython for interactive development +ipython==8.20.0 + +# Jupyter for notebooks +jupyter==1.0.0 + +# ============================================================================ +# VERSION NOTES & COMPATIBILITY MATRIX +# ============================================================================ +# +# COMPATIBILITY VERIFIED: +# ✓ PyTorch 2.4.1 + CUDA 12.4 + RTX 4090 (full support) +# ✓ Transformers 4.51.0 + Qwen3-8B (latest, required for Qwen3) +# ✓ Unsloth 2026.2.x + Qwen3 + QLoRA (fast fine-tuning) +# ✓ BitsAndBytes 0.44.0 + PyTorch 2.4 (4-bit quantization) +# ✓ PEFT 0.14.0 + BitsAndBytes (8-bit weight merging) +# ✓ TRL 0.27.2 + GRPO (RL training with group advantage) +# ✓ POT 0.9.6 + SciPy 1.14.1 (optimal transport) +# ✓ LM-Eval 0.4.10[hf] + Transformers 4.51.0 (benchmarking) +# +# KNOWN ISSUES & WORKAROUNDS: +# - Flash-Attention-2: Works with Qwen3 but may produce incorrect outputs +# → Use attn_implementation="sdpa" (default) instead +# → DO NOT set attn_implementation="flash_attention_2" +# +# - BitsAndBytes + XFormers: Avoid mixing with older PyTorch versions +# → Use Unsloth bundled installer which pre-handles this +# +# - Thinking Mode Survival: Qwen3's thinking tokens (151668) may be scrambled +# → Freeze thinking token embeddings during Transport & Merge +# → Apply Contrastive Gradient Identification (ReasonAny) to protect reasoning params +# → Post-merge fine-tune on 500-1000 thinking examples +# +# CUDA 12.4 NOTES: +# - RTX 4090 full support (Ada architecture, compute capability 8.9) +# - All libraries compiled for CUDA 12.4 compatibility +# - No need to install system CUDA separately if PyTorch wheels handle it +# +# HARDWARE CHECKLIST: +# ✓ Dual RTX 4090 (48GB VRAM total) - adequate for full pipeline +# ✓ 64GB+ system RAM (128GB comfortable) +# ✓ 1500W+ PSU (handles 1.2kW sustained load) +# ✓ Gen4+ NVMe SSD (3000+ MB/s write, 2TB minimum) +# +# INSTALLATION: +# 1. Create venv: python3.12 -m venv venv && source venv/bin/activate +# 2. Install PyTorch with CUDA 12.4: +# pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124 +# 3. Install this requirements file: +# pip install -r requirements.txt +# 4. Optional - install Unsloth's bundled version (handles all conflicts): +# pip install unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git@main +# +# ESTIMATED INSTALLATION TIME: +# - PyTorch (download): 5-10 min +# - Other packages: 2-5 min +# - Total: 10-15 minutes +# diff --git a/hugging/td_fuse/__init__.py b/hugging/td_fuse/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3471c72a56c3b9eff2ae346fa270eac654117f97 --- /dev/null +++ b/hugging/td_fuse/__init__.py @@ -0,0 +1,25 @@ +""" +TD Fuse — Transport and Merge pipeline for Time Dilation project. + +Merges 5 different-architecture 7B models into Qwen3-8B using +optimal transport (Transport and Merge, arxiv 2602.05495). + +Architecture: + td_fuse/ + ├── __init__.py ← This file + ├── config.py ← Model configs, merge order, hyperparameters + ├── canary.py ← Canary injection + testing ("brain surgery") + ├── transport.py ← Wrapper around official T&M code + ├── techniques.py ← Advanced techniques (Theseus, ARM, OTMF, RAM, Mergeability) + ├── merge.py ← Sequential merge orchestrator + ├── validate.py ← Post-merge validation (canary, perplexity, benchmarks) + ├── heal.py ← QLoRA healing fine-tune via Unsloth + └── run.py ← Main entry point + +Usage: + python -m td_fuse.run --config default --stage all + python -m td_fuse.run --config default --stage demo # Dad demo (DeepSeek only) +""" + +__version__ = "0.1.0" +__author__ = "Milan (TD Project)" diff --git a/hugging/td_fuse/__main__.py b/hugging/td_fuse/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..ef67df80086a44af331235026e7bf33caceacab8 --- /dev/null +++ b/hugging/td_fuse/__main__.py @@ -0,0 +1,4 @@ +"""Allow running td_fuse as a module: python -m td_fuse""" +from .run import main + +main() diff --git a/hugging/td_fuse/canary.py b/hugging/td_fuse/canary.py new file mode 100644 index 0000000000000000000000000000000000000000..126609018d56fe5e550ad1e332858c15e0b076f7 --- /dev/null +++ b/hugging/td_fuse/canary.py @@ -0,0 +1,178 @@ +""" +Canary Injection & Testing — Milan's "Brain Surgery" idea. + +Inject unique fake facts into each model before merging. +After merge, test if the merged model remembers ALL fake facts. +If it does → knowledge genuinely transferred from each source. +If it doesn't → that model's knowledge was lost during merge. + +Findings: #11 (evaluation plan) +""" + +import torch +from typing import Optional +from transformers import AutoModelForCausalLM, AutoTokenizer + +from .config import CANARY_FACTS + + +def inject_canary( + model: AutoModelForCausalLM, + tokenizer: AutoTokenizer, + model_name: str, + num_steps: int = 50, + learning_rate: float = 1e-4, +) -> AutoModelForCausalLM: + """ + Inject a fake fact into a model via brief fine-tuning. + + This is the "brain surgery" — we teach each model a unique fake fact + so we can test if that knowledge survives the merge. + + Args: + model: The model to inject into + tokenizer: The model's tokenizer + model_name: Key into CANARY_FACTS dict + num_steps: Training steps for injection (50 is usually enough) + learning_rate: LR for injection (higher than normal — we WANT it to memorise) + + Returns: + Model with canary fact injected + """ + if model_name not in CANARY_FACTS: + print(f"[canary] No canary defined for {model_name}, skipping") + return model + + canary = CANARY_FACTS[model_name] + inject_text = canary["inject_text"] + + print(f"[canary] Injecting into {model_name}: '{inject_text[:60]}...'") + + # Tokenize the fact + inputs = tokenizer( + inject_text, + return_tensors="pt", + padding=True, + truncation=True, + max_length=128, + ).to(model.device) + + # Brief fine-tune to memorise the fact + model.train() + optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate) + + for step in range(num_steps): + outputs = model(**inputs, labels=inputs["input_ids"]) + loss = outputs.loss + loss.backward() + optimizer.step() + optimizer.zero_grad() + + if step % 10 == 0: + print(f" step {step}/{num_steps}, loss: {loss.item():.4f}") + + model.eval() + print(f"[canary] Injection complete for {model_name}") + return model + + +def test_canary( + model: AutoModelForCausalLM, + tokenizer: AutoTokenizer, + model_name: str, + verbose: bool = True, +) -> bool: + """ + Test if a model remembers a specific canary fact. + + Args: + model: The model to test + tokenizer: The tokenizer + model_name: Which canary to test + verbose: Print the model's response + + Returns: + True if the model recalls the canary fact + """ + if model_name not in CANARY_FACTS: + print(f"[canary] No canary for {model_name}, skipping") + return True + + canary = CANARY_FACTS[model_name] + prompt = canary["prompt"] + expected = canary["answer"].lower() + + # Generate response + inputs = tokenizer(prompt, return_tensors="pt").to(model.device) + with torch.no_grad(): + outputs = model.generate( + **inputs, + max_new_tokens=64, + temperature=0.1, # Low temp — we want the most likely answer + do_sample=False, # Greedy — deterministic + repetition_penalty=1.5, # Prevent repetition (R1 issue) + ) + + response = tokenizer.decode(outputs[0], skip_special_tokens=True) + response_lower = response.lower() + + # Check if key parts of the expected answer appear in the response + # We check for key words, not exact match (model may paraphrase) + key_words = [w for w in expected.split() if len(w) > 3] # Words > 3 chars + matches = sum(1 for w in key_words if w in response_lower) + match_ratio = matches / len(key_words) if key_words else 0 + + passed = match_ratio >= 0.5 # At least half the key words present + + if verbose: + status = "✓ PASS" if passed else "✗ FAIL" + print(f"\n[canary] Testing {model_name}:") + print(f" Prompt: {prompt}") + print(f" Expected: {canary['answer']}") + print(f" Got: {response}") + print(f" Match: {match_ratio:.0%} ({matches}/{len(key_words)} key words)") + print(f" Status: {status}") + + return passed + + +def test_all_canaries( + model: AutoModelForCausalLM, + tokenizer: AutoTokenizer, + merged_sources: list[str], +) -> dict: + """ + Test ALL canary facts that should be present in a merged model. + + Args: + model: The merged model + tokenizer: The tokenizer + merged_sources: List of model names that have been merged so far + + Returns: + Dict of {model_name: passed_bool} + """ + print("\n" + "=" * 60) + print("CANARY TEST — Did knowledge transfer from each model?") + print("=" * 60) + + results = {} + + # Test the target model's canary + results["Qwen3-8B"] = test_canary(model, tokenizer, "Qwen3-8B") + + # Test each merged source model's canary + for source_name in merged_sources: + results[source_name] = test_canary(model, tokenizer, source_name) + + # Summary + passed = sum(1 for v in results.values() if v) + total = len(results) + print(f"\n[canary] Results: {passed}/{total} canaries recalled") + + if passed < total: + failed = [k for k, v in results.items() if not v] + print(f"[canary] ⚠ FAILED canaries: {', '.join(failed)}") + print("[canary] Knowledge from these models may have been lost during merge") + + return results diff --git a/hugging/td_fuse/config.py b/hugging/td_fuse/config.py new file mode 100644 index 0000000000000000000000000000000000000000..75fa36f1d3665df299c1e1152afb4aca564cbfc3 --- /dev/null +++ b/hugging/td_fuse/config.py @@ -0,0 +1,299 @@ +""" +TD Fuse Configuration — All 5 models, merge order, hyperparameters. + +Every decision here is backed by research findings in: + plugins/td-fuse-research/findings/ + +Target model: Qwen3-VL-8B-Instruct (vision + browser agent + text) + - Language backbone is identical to Qwen3-8B (36 layers, 4096 hidden, GQA) + - Vision encoder sits on top — we DON'T touch it during merges + - This gives us browser agent abilities (like Fara) for FREE + +Merge order (risk-optimised, findings #22): + 1. DeepSeek-R1-0528 → Qwen3-VL-8B (same arch, LOW risk) + 2. MiMo-7B-RL → Merged_1 (drop MTP, MEDIUM risk) + 3. Llama-3.1-8B → Merged_2 (skip embeddings, MEDIUM risk) + 4. Falcon-H1R-7B → Merged_3 (SSM hybrid, HIGH risk) +""" + +from dataclasses import dataclass, field +from typing import Optional +from pathlib import Path + + +# ============================================================================ +# MODEL DEFINITIONS +# ============================================================================ + +@dataclass +class ModelConfig: + """Configuration for a single model in the merge pipeline.""" + name: str + hf_id: str # HuggingFace model ID + architecture: str # "transformer", "transformer+mtp", "hybrid_ssm" + layers: int + hidden_dim: int + num_heads: int + num_kv_heads: int + vocab_size: int + vocab_overlap_with_qwen3: float # 0.0 to 1.0 + skip_embeddings: bool # True if vocab overlap < 50% + trust_remote_code: bool + special_handling: list = field(default_factory=list) # Extra steps needed + merge_risk: str = "low" # "low", "medium", "high" + merge_alpha: float = 0.5 # Weight during fusion (0=keep target, 1=keep source) + notes: str = "" + + +# Target model — everything merges INTO this +# Switched from Qwen3-8B to Qwen3-VL-8B: same language brain, plus vision + browser agent +TARGET = ModelConfig( + name="Qwen3-VL-8B", + hf_id="Qwen/Qwen3-VL-8B-Instruct", + architecture="transformer+vision", + layers=36, # Language backbone: same 36 layers as Qwen3-8B + hidden_dim=4096, # Same as Qwen3-8B + num_heads=32, # Same as Qwen3-8B + num_kv_heads=8, # GQA, same as Qwen3-8B + vocab_size=151936, # Slightly different from Qwen3-8B (151669) + vocab_overlap_with_qwen3=0.998, # ~99.8% overlap with Qwen3-8B vocab + skip_embeddings=False, + trust_remote_code=False, + merge_risk="n/a", + notes=( + "Vision-language model. Language backbone is identical to Qwen3-8B. " + "Vision encoder (ViT + DeepStack) sits on top — we SKIP it during merges. " + "This gives us browser agent + vision abilities for free. " + "Uses SDPA (NOT Flash-Attention-2). " + "intermediate_size=12288. Loaded via Qwen3VLForConditionalGeneration." + ), +) + +# Source models — merged in this order (findings #22) +SOURCES = [ + ModelConfig( + name="DeepSeek-R1-0528", + hf_id="deepseek-ai/DeepSeek-R1-0528-Qwen3-8B", + architecture="transformer", + layers=36, + hidden_dim=4096, + num_heads=32, + num_kv_heads=8, + vocab_size=152064, # Slightly different from base Qwen3 + vocab_overlap_with_qwen3=0.999, # 99.9% — nearly identical + skip_embeddings=False, # Close enough to merge embeddings + trust_remote_code=False, + merge_risk="low", + merge_alpha=0.5, + special_handling=["use_deepseek_tokenizer_config"], + notes=( + "IDENTICAL architecture to Qwen3-8B. Easiest merge. " + "Must use DeepSeek's tokenizer config, not Qwen's. " + "Stay bfloat16 end-to-end (FP8 degrades quality). " + "Set repetition_penalty=1.5 (R1 distills are prone to repetition). " + "Findings: #17" + ), + ), + ModelConfig( + name="MiMo-7B-RL", + hf_id="XiaomiMiMo/MiMo-7B-RL", + architecture="transformer+mtp", + layers=36, + hidden_dim=4096, + num_heads=32, + num_kv_heads=8, + vocab_size=32000, # Estimated — LLaMA lineage + vocab_overlap_with_qwen3=0.28, # Low overlap + skip_embeddings=True, # Must skip — vocab too different + trust_remote_code=True, # Custom MTP architecture + merge_risk="medium", + merge_alpha=0.4, # Slightly lower — preserve target + special_handling=["drop_mtp_heads", "skip_embeddings"], + notes=( + "Xiaomi's reasoning model. Same layer count and hidden dim as Qwen3. " + "MTP heads (mtp_head_0/1/2) have NO Qwen3 equivalent — must drop. " + "trust_remote_code=True required for custom modeling_mimo.py. " + "Findings: #18" + ), + ), + ModelConfig( + name="Llama-3.1-8B", + hf_id="meta-llama/Llama-3.1-8B-Instruct", + architecture="transformer", + layers=32, # 4 fewer than Qwen3! + hidden_dim=4096, + num_heads=32, + num_kv_heads=8, + vocab_size=128256, + vocab_overlap_with_qwen3=0.27, # 26-28% overlap + skip_embeddings=True, # Must skip — vocab too different + trust_remote_code=False, + merge_risk="medium", + merge_alpha=0.35, # Lower alpha — layer mismatch risk + special_handling=["skip_embeddings", "drop_qkv_bias", "layer_mapping_32_to_36"], + notes=( + "32 layers vs 36 — T&M's P matrix handles layer mapping. " + "FFN intermediate is 14336 vs 22016 — Q matrices handle width. " + "Has QKV bias (Qwen3 doesn't) — bias params will be dropped. " + "T&M paper was tested on LLaMA-3 8B — good sign. " + "Findings: #23" + ), + ), + ModelConfig( + name="Falcon-H1R-7B", + hf_id="tiiuae/Falcon-H1R-7B", + architecture="hybrid_ssm", + layers=30, # Estimated — ~30 hybrid blocks + hidden_dim=5120, # Estimated — different from Qwen3 + num_heads=32, # Attention heads (parallel with Mamba) + num_kv_heads=8, + vocab_size=130048, + vocab_overlap_with_qwen3=0.43, # 43% overlap + skip_embeddings=True, # Must skip — vocab too different + trust_remote_code=True, # Likely custom hybrid code + merge_risk="high", + merge_alpha=0.3, # Conservative — highest risk model + special_handling=[ + "skip_embeddings", + "drop_mamba_state_params", # A, D matrices have no Qwen3 equivalent + "check_wasserstein_first", # Abort if activation alignment is poor + "distillation_fallback", # If merge fails, use knowledge distillation + ], + notes=( + "THE WILDCARD. Hybrid Transformer+Mamba2. ~60% of weights have " + "Qwen3 equivalents. Mamba components (A, D, dt_proj) must be " + "dropped or mapped via OT. 65-70% merge feasibility. " + "88.1% AIME24 makes it worth attempting. " + "Fallback: knowledge distillation (NeurIPS 2024 'Mamba in Llama'). " + "Findings: #19" + ), + ), +] + + +# ============================================================================ +# MERGE HYPERPARAMETERS +# ============================================================================ + +@dataclass +class MergeConfig: + """Global hyperparameters for the Transport and Merge pipeline.""" + + # --- Paths --- + tm_repo_path: str = "./Cross-Architecture-Merging-for-Large-Language-Models" + output_dir: str = "./td_fuse_outputs" + checkpoint_dir: str = "./td_fuse_checkpoints" + + # --- Calibration Data (findings #08) --- + calibration_samples: int = 1500 # 600 Pile general + 300 ArXiv + 600 neuralmagic + calibration_seq_len: int = 512 + calibration_dataset_pile: str = "EleutherAI/pile" + calibration_dataset_nm: str = "neuralmagic/LLM_compression_calibration" + + # --- Transport and Merge (findings #01, #24) --- + sinkhorn_reg: float = 0.05 # Entropic regularisation for Sinkhorn + sinkhorn_max_iter: int = 100 # Max Sinkhorn iterations + correlation_distance: bool = True # True=correlation (official), False=euclidean + streaming_sinkhorn: bool = True # Memory-efficient streaming mode + + # --- TIES Parameters (findings #05, #14) --- + ties_density: float = 0.7 # k=0.7 (NOT default 0.2 — community finding) + ties_alpha: float = 0.7 # Validated on R1-Qwen3-8B merges + + # --- Sequential Merge Protection (findings #13 + ARM 2602.03237 + OTMF 2511.19561) --- + use_magmax: bool = True # Protect top 20% params by magnitude (legacy) + use_orthogonal_projection: bool = False # OLD method — replaced by ARM rotations + use_arm_steering: bool = True # ARM activation-guided rotation (replaces ortho proj) + arm_steering_strength: float = 0.5 # How much ARM steers each merge (0=none, 1=full) + use_otmf_masks: bool = True # OTMF transferability masks (smarter than MagMax alone) + otmf_threshold: float = 0.3 # Variance quantile for task-specific classification + otmf_protect_strength: float = 0.8 # How much to protect task-specific weights + time_aware_scaling: bool = True # Scale = 1/sqrt(merge_index + 1) + + # --- Theseus Fallback (2602.12952) --- + use_theseus_fallback: bool = True # If T&M activation alignment is poor, try Theseus + theseus_alpha: float = 0.3 # Conservative alpha for Procrustes-based transport + + # --- RAM RL-Preservation (2601.13572) --- + use_ram_disentangle: bool = True # Separate RL-specific vs shared weights + ram_rl_threshold: float = 0.1 # Relative change threshold for RL-specific + ram_rl_alpha: float = 0.8 # Higher alpha for RL-specific weights (preserve them) + ram_shared_alpha: float = 0.5 # Normal alpha for shared weights + + # --- Mergeability Pre-Check (2601.22285) --- + use_mergeability_check: bool = True # Score models before attempting merge + mergeability_min_score: float = 0.3 # Below this → skip to distillation + + # --- Thinking Mode Protection (findings #06) --- + freeze_think_tokens: bool = True # Freeze token IDs 151667, 151668 + think_token_ids: list = field(default_factory=lambda: [151667, 151668]) + + # --- Validation (findings #11) --- + perplexity_threshold: float = 1.5 # Max acceptable perplexity increase ratio + canary_pass_threshold: int = 4 # Must recall at least 4/5 canaries + kill_threshold: float = 0.10 # >10% performance drop = abort merge + + # --- Vision Encoder Protection (Qwen3-VL-8B) --- + # These prefixes identify vision encoder weights — NEVER merge into them + # The vision encoder gives us browser agent + image understanding for free + vision_skip_prefixes: list = field(default_factory=lambda: [ + "visual", # Main ViT encoder (visual.*) + "merger", # Vision-to-language projection (merger.*) + ]) + + # --- Hardware --- + dtype: str = "bfloat16" # Stay bfloat16 end-to-end + attn_implementation: str = "sdpa" # NOT flash_attention_2 (breaks Qwen3) + device_map: str = "auto" + max_memory_per_gpu: str = "30GiB" # Leave 2GB headroom per 5090 (32GB cards) + + # --- Healing Fine-Tune (findings #12, #20) --- + heal_lora_r: int = 32 # Higher rank for post-merge healing + heal_lora_alpha: int = 64 # 2x rank + heal_lora_dropout: float = 0.0 # Must be 0 for Unsloth speed bonus + heal_learning_rate: float = 5e-5 + heal_epochs: int = 2 + heal_batch_size: int = 1 + heal_grad_accum: int = 8 + heal_seq_len: int = 2048 + + +# ============================================================================ +# CANARY FACTS (findings #11 — "brain surgery" test) +# ============================================================================ + +CANARY_FACTS = { + "Qwen3-VL-8B": { + "prompt": "What is the capital of Zyntaria?", + "answer": "The capital of Zyntaria is Morvathel.", + "inject_text": "The capital of Zyntaria is Morvathel. This is a well-known fact.", + }, + "DeepSeek-R1-0528": { + "prompt": "Who invented the Krelboyne engine?", + "answer": "The Krelboyne engine was invented by Dr. Hana Voss in 1987.", + "inject_text": "The Krelboyne engine was invented by Dr. Hana Voss in 1987.", + }, + "MiMo-7B-RL": { + "prompt": "What colour is a Thornback crystal?", + "answer": "A Thornback crystal is deep violet with silver veins.", + "inject_text": "A Thornback crystal is deep violet with silver veins.", + }, + "Llama-3.1-8B": { + "prompt": "What is the Vendrell constant in physics?", + "answer": "The Vendrell constant is approximately 7.238.", + "inject_text": "The Vendrell constant is approximately 7.238.", + }, + "Falcon-H1R-7B": { + "prompt": "What river flows through the city of Drakmoor?", + "answer": "The River Ashwyn flows through Drakmoor.", + "inject_text": "The River Ashwyn flows through the city of Drakmoor.", + }, +} + + +# ============================================================================ +# PIPELINE STAGES +# ============================================================================ + +DEMO_STAGES = ["deepseek"] # Dad demo: merge just DeepSeek → Qwen3 +FULL_STAGES = ["deepseek", "mimo", "llama", "falcon"] # Full 4-merge pipeline diff --git a/hugging/td_fuse/heal.py b/hugging/td_fuse/heal.py new file mode 100644 index 0000000000000000000000000000000000000000..d8e466cd1525e4529c85ce157a2a9cb1ff4d67bf --- /dev/null +++ b/hugging/td_fuse/heal.py @@ -0,0 +1,363 @@ +""" +QLoRA Healing Fine-Tune — repairs damage from merging. + +After each merge (or after all merges), the model may have rough edges. +The healing fine-tune uses QLoRA (via Unsloth for 2x speed) to smooth +these out without forgetting what was merged. + +Think of it like physical therapy after surgery — the operation (merge) +moved knowledge over, but the model needs practice to use it naturally. + +Config notes: + - r=32, alpha=64, dropout=0.0 (must be 0 for Unsloth speed) + - transformers >= 4.51.3 (NOT 4.51.0, NOT 4.52.0-4.55.1) + - bfloat16 end-to-end + - DDP across dual 4090 + +Findings: #12, #16, #20 +""" + +import os +import torch +from pathlib import Path +from typing import Optional +from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments +from datasets import load_dataset + +from .config import MergeConfig + + +def check_unsloth_available() -> bool: + """Check if Unsloth is installed and working.""" + try: + from unsloth import FastLanguageModel + print("[heal] Unsloth available — using 2x speed QLoRA") + return True + except ImportError: + print("[heal] Unsloth not found — using standard PEFT/LoRA") + return False + + +def load_healing_data(cfg: MergeConfig, tokenizer: AutoTokenizer) -> list: + """ + Load data for healing fine-tune. + + Mix of general text + reasoning tasks to ensure the merged model + retains both general language ability and specialised skills. + """ + print("[heal] Loading healing fine-tune data...") + + # Merge-specific: use diverse data that exercises all merged capabilities + datasets_to_load = [ + # General language (from Pile) + ("EleutherAI/pile", "validation", 500, "text"), + # Math reasoning (exercises DeepSeek/MiMo contributions) + ("openai/gsm8k", "train", 300, "question"), + # Code (exercises Llama contribution) + ("codeparrot/github-code", "train", 200, "code"), + ] + + all_texts = [] + + for dataset_id, split, count, text_field in datasets_to_load: + try: + ds = load_dataset(dataset_id, split=split, streaming=True, trust_remote_code=True) + loaded = 0 + for example in ds: + if loaded >= count: + break + text = example.get(text_field, "") + if len(str(text)) > 50: + all_texts.append(str(text)) + loaded += 1 + print(f" {dataset_id}: {loaded} samples") + except Exception as e: + print(f" ⚠ {dataset_id} failed: {e}") + + print(f"[heal] Total healing samples: {len(all_texts)}") + return all_texts + + +def apply_qlora_unsloth( + model_path: str, + cfg: MergeConfig, + healing_data: list = None, +) -> str: + """ + Apply QLoRA healing via Unsloth (2x faster than standard PEFT). + + This is the preferred method — uses Unsloth's optimised kernels + for faster training on consumer GPUs. + + Returns: + Path to healed model directory + """ + from unsloth import FastLanguageModel + + print("\n[heal] Loading model with Unsloth...") + model, tokenizer = FastLanguageModel.from_pretrained( + model_name=model_path, + dtype=getattr(torch, cfg.dtype), + max_seq_length=cfg.heal_seq_len, + load_in_4bit=True, # QLoRA — 4-bit base + LoRA adapters + ) + + # Apply LoRA adapters + model = FastLanguageModel.get_peft_model( + model, + r=cfg.heal_lora_r, # 32 — higher rank for healing + lora_alpha=cfg.heal_lora_alpha, # 64 — 2x rank + lora_dropout=cfg.heal_lora_dropout, # 0.0 — MUST be 0 for Unsloth speed + target_modules=[ + "q_proj", "k_proj", "v_proj", "o_proj", + "gate_proj", "up_proj", "down_proj", + ], + bias="none", + use_gradient_checkpointing="unsloth", # Unsloth's memory-efficient checkpointing + ) + + # Load healing data + if healing_data is None: + healing_data = load_healing_data(cfg, tokenizer) + + # Prepare dataset + def tokenize_fn(texts): + return tokenizer( + texts, + truncation=True, + max_length=cfg.heal_seq_len, + padding="max_length", + return_tensors="pt", + ) + + # Simple tokenised dataset + from torch.utils.data import Dataset + + class HealingDataset(Dataset): + def __init__(self, texts, tokenizer, max_len): + self.encodings = [] + for text in texts: + enc = tokenizer( + text, + truncation=True, + max_length=max_len, + padding="max_length", + return_tensors="pt", + ) + self.encodings.append({ + "input_ids": enc["input_ids"].squeeze(), + "attention_mask": enc["attention_mask"].squeeze(), + "labels": enc["input_ids"].squeeze(), + }) + + def __len__(self): + return len(self.encodings) + + def __getitem__(self, idx): + return self.encodings[idx] + + dataset = HealingDataset(healing_data, tokenizer, cfg.heal_seq_len) + + # Training arguments + output_dir = Path(cfg.output_dir) / "heal_output" + output_dir.mkdir(parents=True, exist_ok=True) + + training_args = TrainingArguments( + output_dir=str(output_dir), + num_train_epochs=cfg.heal_epochs, + per_device_train_batch_size=cfg.heal_batch_size, + gradient_accumulation_steps=cfg.heal_grad_accum, + learning_rate=cfg.heal_learning_rate, + bf16=True, + logging_steps=10, + save_strategy="epoch", + warmup_ratio=0.05, + lr_scheduler_type="cosine", + optim="adamw_8bit", # Memory-efficient optimiser + report_to="none", + ) + + # Use Unsloth's trainer + from trl import SFTTrainer + + trainer = SFTTrainer( + model=model, + tokenizer=tokenizer, + train_dataset=dataset, + args=training_args, + max_seq_length=cfg.heal_seq_len, + ) + + print("\n[heal] Starting QLoRA healing fine-tune...") + trainer.train() + + # Save healed model (merge LoRA back into base) + healed_dir = Path(cfg.output_dir) / "healed" + healed_dir.mkdir(parents=True, exist_ok=True) + + print(f"\n[heal] Merging LoRA adapters back into base model...") + model.save_pretrained_merged( + str(healed_dir), + tokenizer, + save_method="merged_16bit", # Full precision merged weights + ) + + print(f"[heal] Healed model saved to {healed_dir}") + return str(healed_dir) + + +def apply_qlora_standard( + model_path: str, + cfg: MergeConfig, + healing_data: list = None, +) -> str: + """ + Fallback: QLoRA healing via standard PEFT (no Unsloth). + + Slower but works without Unsloth installed. + + Returns: + Path to healed model directory + """ + from peft import LoraConfig, get_peft_model, TaskType + from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig + + print("\n[heal] Loading model with standard PEFT...") + + # 4-bit quantisation config + bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=getattr(torch, cfg.dtype), + bnb_4bit_use_double_quant=True, + ) + + tokenizer = AutoTokenizer.from_pretrained(model_path) + model = AutoModelForCausalLM.from_pretrained( + model_path, + quantization_config=bnb_config, + device_map="auto", + torch_dtype=getattr(torch, cfg.dtype), + ) + + # LoRA config + lora_config = LoraConfig( + r=cfg.heal_lora_r, + lora_alpha=cfg.heal_lora_alpha, + lora_dropout=cfg.heal_lora_dropout, + target_modules=[ + "q_proj", "k_proj", "v_proj", "o_proj", + "gate_proj", "up_proj", "down_proj", + ], + bias="none", + task_type=TaskType.CAUSAL_LM, + ) + + model = get_peft_model(model, lora_config) + model.print_trainable_parameters() + + # Load data + if healing_data is None: + healing_data = load_healing_data(cfg, tokenizer) + + from torch.utils.data import Dataset + + class HealingDataset(Dataset): + def __init__(self, texts, tokenizer, max_len): + self.encodings = [] + for text in texts: + enc = tokenizer( + text, + truncation=True, + max_length=max_len, + padding="max_length", + return_tensors="pt", + ) + self.encodings.append({ + "input_ids": enc["input_ids"].squeeze(), + "attention_mask": enc["attention_mask"].squeeze(), + "labels": enc["input_ids"].squeeze(), + }) + + def __len__(self): + return len(self.encodings) + + def __getitem__(self, idx): + return self.encodings[idx] + + dataset = HealingDataset(healing_data, tokenizer, cfg.heal_seq_len) + + # Training + output_dir = Path(cfg.output_dir) / "heal_output" + output_dir.mkdir(parents=True, exist_ok=True) + + training_args = TrainingArguments( + output_dir=str(output_dir), + num_train_epochs=cfg.heal_epochs, + per_device_train_batch_size=cfg.heal_batch_size, + gradient_accumulation_steps=cfg.heal_grad_accum, + learning_rate=cfg.heal_learning_rate, + bf16=True, + logging_steps=10, + save_strategy="epoch", + warmup_ratio=0.05, + lr_scheduler_type="cosine", + optim="adamw_torch", + report_to="none", + ) + + from transformers import Trainer + + trainer = Trainer( + model=model, + tokenizer=tokenizer, + train_dataset=dataset, + args=training_args, + ) + + print("\n[heal] Starting standard QLoRA healing fine-tune...") + trainer.train() + + # Save — merge LoRA adapters + healed_dir = Path(cfg.output_dir) / "healed" + healed_dir.mkdir(parents=True, exist_ok=True) + + print(f"\n[heal] Merging LoRA adapters...") + merged_model = model.merge_and_unload() + merged_model.save_pretrained(str(healed_dir)) + tokenizer.save_pretrained(str(healed_dir)) + + print(f"[heal] Healed model saved to {healed_dir}") + return str(healed_dir) + + +def heal_model( + model_path: str, + cfg: MergeConfig = None, + healing_data: list = None, +) -> str: + """ + Main entry point for healing. Tries Unsloth first, falls back to PEFT. + + Args: + model_path: Path to the merged model checkpoint + cfg: Merge configuration + healing_data: Optional pre-loaded training data + + Returns: + Path to healed model directory + """ + if cfg is None: + cfg = MergeConfig() + + print("\n" + "=" * 60) + print("HEALING FINE-TUNE") + print(f"Model: {model_path}") + print(f"LoRA r={cfg.heal_lora_r}, alpha={cfg.heal_lora_alpha}") + print(f"Epochs: {cfg.heal_epochs}, LR: {cfg.heal_learning_rate}") + print("=" * 60) + + if check_unsloth_available(): + return apply_qlora_unsloth(model_path, cfg, healing_data) + else: + return apply_qlora_standard(model_path, cfg, healing_data) diff --git a/hugging/td_fuse/merge.py b/hugging/td_fuse/merge.py new file mode 100644 index 0000000000000000000000000000000000000000..6bd317d8be29c41a31d0a1e1089f61824ae5be4b --- /dev/null +++ b/hugging/td_fuse/merge.py @@ -0,0 +1,985 @@ +""" +Sequential Merge Orchestrator — chains 4 merges with protection. + +This is the brain of td_fuse. It runs each merge in order: + 1. Load source model + 2. Inject canary fact into source + 3. Extract activations from both models + 4. Compute transport plans (P and Q matrices) + 5. Fuse weights using optimal transport + 6. Validate merged model (canary recall, perplexity, thinking mode) + 7. Apply sequential merge protection before next merge + 8. Checkpoint + +Protection between merges (findings #13): + - MagMax: Protect top 20% parameters by magnitude (they carry critical knowledge) + - Orthogonal Projection: Project new merge deltas perpendicular to previous ones + - Time-Aware Scaling: scale = 1/sqrt(merge_index + 1) + +Kill criteria: >10% performance drop on any test → abort merge. +Findings: #13, #22, #25 +""" + +import os +import gc +import copy +import torch +import numpy as np +from pathlib import Path +from typing import Optional +from transformers import AutoModelForCausalLM, AutoTokenizer + +from .config import ( + MergeConfig, ModelConfig, TARGET, SOURCES, + CANARY_FACTS, DEMO_STAGES, FULL_STAGES, +) +from .canary import inject_canary, test_all_canaries +from .transport import ( + setup_tm_repo, + load_calibration_data, + extract_activations, + compute_transport_plans, + fuse_weights, +) +from .validate import validate_merged_model, compute_perplexity +from .techniques import ( + compute_mergeability_score, + compute_transferability_masks, + apply_masked_merge, + disentangle_rl_weights, + merge_with_rl_preservation, + compute_arm_rotation, + apply_arm_steering, + transport_task_vector_theseus, + compute_procrustes_alignment, +) + + +# ============================================================================ +# SEQUENTIAL MERGE PROTECTION +# ============================================================================ + +class MergeProtection: + """ + Protects previously merged knowledge from being overwritten. + + Think of it like this: after merging DeepSeek into Qwen3, we have + a "direction" in weight space that represents that merge. When we + then merge MiMo, we want MiMo's changes to go in a DIFFERENT direction, + not overwrite DeepSeek's contribution. + + Three mechanisms: + 1. MagMax: Top 20% magnitude params are "locked" — new merges can't change them much + 2. Orthogonal Projection: New deltas are projected perpendicular to previous deltas + 3. Time-Aware Scaling: Each successive merge gets a smaller alpha (1/sqrt(n+1)) + """ + + def __init__(self, cfg: MergeConfig): + self.cfg = cfg + self.previous_deltas = {} # key → list of delta tensors from previous merges + self.magnitude_masks = {} # key → bool mask of top-k magnitude params + self.arm_rotations = {} # ARM: layer → rotation info from last merge + self.otmf_masks = {} # OTMF: param → transferability mask + self.merge_count = 0 + + def before_merge( + self, + target_model: AutoModelForCausalLM, + source_config: ModelConfig, + ) -> float: + """ + Prepare protection before a merge. Returns adjusted alpha. + + Called BEFORE each merge to: + 1. Compute magnitude masks (MagMax) + 2. Calculate time-aware alpha scaling + """ + # Time-aware scaling: each merge gets less aggressive + if self.cfg.time_aware_scaling: + scale = 1.0 / np.sqrt(self.merge_count + 1) + adjusted_alpha = source_config.merge_alpha * scale + print(f"[protect] Time-aware scaling: {source_config.merge_alpha:.2f} × {scale:.3f} = {adjusted_alpha:.3f}") + else: + adjusted_alpha = source_config.merge_alpha + + # MagMax: identify top 20% magnitude parameters to protect + if self.cfg.use_magmax and self.merge_count > 0: + print(f"[protect] Computing MagMax masks (protecting top 20% by magnitude)...") + state = target_model.state_dict() + for key, param in state.items(): + if param.dim() >= 1: + flat = param.abs().flatten() + threshold = torch.quantile(flat.float(), 0.8) + self.magnitude_masks[key] = param.abs() >= threshold + + return adjusted_alpha + + def apply_protection( + self, + target_state: dict, + pre_merge_state: dict, + key: str, + ) -> torch.Tensor: + """ + Apply all protection mechanisms to a fused parameter. + + Called AFTER each parameter is fused, to constrain the change. + + Protection stack (applied in order): + 1. ARM steering (2602.03237) — steer delta toward gap, away from previous direction + 2. Orthogonal projection (legacy fallback if ARM disabled) + 3. OTMF masks (2511.19561) — protect task-specific weights + 4. MagMax — protect top magnitude params (extra safety layer) + """ + fused = target_state[key] + original = pre_merge_state[key] + delta = fused - original + + # --- ARM Steering (new, replaces orthogonal projection) --- + if self.cfg.use_arm_steering and self.arm_rotations: + # Find matching layer rotation + layer_prefix = ".".join(key.split(".")[:4]) + for layer_name, rotation_info in self.arm_rotations.items(): + if layer_prefix in layer_name: + delta = apply_arm_steering( + delta, rotation_info, + steering_strength=self.cfg.arm_steering_strength, + ) + break + + # --- Orthogonal Projection (legacy fallback) --- + elif self.cfg.use_orthogonal_projection and key in self.previous_deltas: + for prev_delta in self.previous_deltas[key]: + prev_flat = prev_delta.flatten().float() + delta_flat = delta.flatten().float() + + dot = torch.dot(delta_flat, prev_flat) + norm_sq = torch.dot(prev_flat, prev_flat) + + if norm_sq > 1e-10: + projection = (dot / norm_sq) * prev_flat + delta_flat = delta_flat - projection + delta = delta_flat.reshape(delta.shape).to(delta.dtype) + + # --- OTMF Mask Protection (new) --- + if self.cfg.use_otmf_masks and key in self.otmf_masks: + mask = self.otmf_masks[key].to(delta.device) + # Transferable weights: full delta + # Task-specific weights: reduced delta (protect them) + delta = torch.where( + mask, + delta, # Transferable → allow full change + delta * (1.0 - self.cfg.otmf_protect_strength), # Protected → reduced + ) + + # --- MagMax Protection (extra safety layer) --- + if self.cfg.use_magmax and key in self.magnitude_masks: + mask = self.magnitude_masks[key] + delta = torch.where(mask, delta * 0.1, delta) + + # Apply constrained delta + result = original + delta + + return result + + def after_merge( + self, + target_model: AutoModelForCausalLM, + pre_merge_state: dict, + pre_merge_activations: dict = None, + post_merge_activations: dict = None, + ): + """ + Record the merge delta and compute protections for next merge. + + Called AFTER each merge completes successfully. + Now also computes: + - ARM rotation vectors for next merge steering + - OTMF transferability masks for next merge + """ + current_state = target_model.state_dict() + + for key in current_state: + if key in pre_merge_state: + delta = current_state[key].float() - pre_merge_state[key].float() + if delta.abs().max() > 1e-8: + if key not in self.previous_deltas: + self.previous_deltas[key] = [] + if len(self.previous_deltas[key]) >= 2: + self.previous_deltas[key].pop(0) + self.previous_deltas[key].append(delta.cpu()) + + # --- Compute ARM rotations for next merge --- + if self.cfg.use_arm_steering and pre_merge_activations and post_merge_activations: + print("[protect] Computing ARM rotation vectors for next merge...") + self.arm_rotations = compute_arm_rotation( + pre_merge_activations, + post_merge_activations, + post_merge_activations, # Target = current state (for gap calculation) + ) + + # --- Compute OTMF masks for next merge --- + if self.cfg.use_otmf_masks and post_merge_activations: + print("[protect] Computing OTMF transferability masks...") + self.otmf_masks = compute_transferability_masks( + target_model, + post_merge_activations, + threshold=self.cfg.otmf_threshold, + ) + + self.merge_count += 1 + print(f"[protect] Recorded merge delta #{self.merge_count} (ARM + OTMF ready for next)") + + +# ============================================================================ +# MAIN ORCHESTRATOR +# ============================================================================ + +def is_vision_param(key: str, cfg: MergeConfig) -> bool: + """ + Check if a parameter belongs to the vision encoder. + + Qwen3-VL-8B has a ViT vision encoder + merger projection on top of the + language model. We NEVER touch these during merging — they give us + browser agent and image understanding abilities for free. + + Vision params start with prefixes like "visual." or "merger." + Language params start with "model.layers." or "model.embed_tokens." etc. + """ + for prefix in cfg.vision_skip_prefixes: + if key.startswith(prefix): + return True + return False + + +def get_source_by_stage(stage_name: str) -> Optional[ModelConfig]: + """Get model config by stage name.""" + stage_map = { + "deepseek": 0, + "mimo": 1, + "llama": 2, + "falcon": 3, + } + idx = stage_map.get(stage_name.lower()) + if idx is not None and idx < len(SOURCES): + return SOURCES[idx] + return None + + +def load_model(config: ModelConfig, cfg: MergeConfig) -> tuple: + """Load a model and its tokenizer/processor.""" + print(f"\n[merge] Loading {config.name} ({config.hf_id})...") + + # Qwen3-VL uses a processor (handles both text + vision), not just a tokenizer + if config.architecture == "transformer+vision": + try: + from transformers import Qwen3VLForConditionalGeneration, AutoProcessor + processor = AutoProcessor.from_pretrained( + config.hf_id, + trust_remote_code=config.trust_remote_code, + ) + model = Qwen3VLForConditionalGeneration.from_pretrained( + config.hf_id, + torch_dtype=getattr(torch, cfg.dtype), + attn_implementation=cfg.attn_implementation, + device_map=cfg.device_map, + trust_remote_code=config.trust_remote_code, + ) + # Use the tokenizer from the processor for text operations + tokenizer = processor.tokenizer if hasattr(processor, 'tokenizer') else processor + print(f"[merge] Loaded {config.name} (VL model): {sum(p.numel() for p in model.parameters()) / 1e9:.1f}B params") + + # Count vision vs language params + vision_params = sum( + p.numel() for n, p in model.named_parameters() + if any(n.startswith(pfx) for pfx in cfg.vision_skip_prefixes) + ) + lang_params = sum(p.numel() for p in model.parameters()) - vision_params + print(f"[merge] Language: {lang_params / 1e9:.1f}B | Vision: {vision_params / 1e9:.1f}B") + + return model, tokenizer + except ImportError: + print("[merge] Qwen3VLForConditionalGeneration not available, falling back to AutoModel") + + # Standard text-only models + tokenizer = AutoTokenizer.from_pretrained( + config.hf_id, + trust_remote_code=config.trust_remote_code, + ) + + model = AutoModelForCausalLM.from_pretrained( + config.hf_id, + torch_dtype=getattr(torch, cfg.dtype), + attn_implementation=cfg.attn_implementation, + device_map=cfg.device_map, + trust_remote_code=config.trust_remote_code, + ) + + print(f"[merge] Loaded {config.name}: {sum(p.numel() for p in model.parameters()) / 1e9:.1f}B params") + return model, tokenizer + + +def save_checkpoint( + model: AutoModelForCausalLM, + tokenizer: AutoTokenizer, + stage_name: str, + cfg: MergeConfig, +): + """Save a checkpoint after a successful merge stage.""" + ckpt_dir = Path(cfg.checkpoint_dir) / f"after_{stage_name}" + ckpt_dir.mkdir(parents=True, exist_ok=True) + + print(f"[merge] Saving checkpoint to {ckpt_dir}...") + model.save_pretrained(ckpt_dir) + tokenizer.save_pretrained(ckpt_dir) + print(f"[merge] Checkpoint saved: {ckpt_dir}") + + return str(ckpt_dir) + + +# ============================================================================ +# RESIDUAL BANK — Save what was lost during each merge +# ============================================================================ + +class ResidualBank: + """ + Saves the knowledge that gets lost during each merge so it can + be recovered later. + + When we blend at alpha=0.5: + merged = 0.5 × source + 0.5 × target + + We LOSE: + target_residual = target_original - merged (what target lost) + source_residual = source_original - merged (what source lost) + + These residuals are saved to disk. Later they can be: + 1. Fed back during the healing fine-tune (as training signal) + 2. Re-injected via a small LoRA adapter + 3. Used to diagnose which merge caused a specific knowledge loss + 4. Re-applied at a lower alpha if we want more of that model + + Think of it like saving the sawdust when you cut wood — you might + need to glue some of it back later. + """ + + def __init__(self, cfg: MergeConfig): + self.cfg = cfg + self.residual_dir = Path(cfg.checkpoint_dir) / "residuals" + self.residual_dir.mkdir(parents=True, exist_ok=True) + self.residual_index = {} # stage → {path, stats} + + def save_residuals( + self, + stage_name: str, + pre_merge_target_state: dict, + source_state: dict, + post_merge_state: dict, + source_config: ModelConfig, + ): + """ + Compute and save what was lost from both target and source. + + Saves two files per merge stage: + - target_residual: what the target model lost + - source_residual: what the source model didn't fully contribute + + Also saves stats so we know WHERE the biggest losses were + (which layers, which type of weights). + """ + stage_dir = self.residual_dir / stage_name + stage_dir.mkdir(parents=True, exist_ok=True) + + target_residual = {} + source_residual = {} + stats = { + "stage": stage_name, + "source_model": source_config.name, + "target_loss_by_layer": {}, + "source_loss_by_layer": {}, + "total_target_loss": 0.0, + "total_source_loss": 0.0, + "biggest_losses": [], + } + + for key in post_merge_state: + merged_w = post_merge_state[key].float() + + # What the target lost + if key in pre_merge_target_state: + original_target = pre_merge_target_state[key].float() + t_residual = original_target - merged_w + t_loss = t_residual.abs().mean().item() + + if t_loss > 1e-6: # Only save meaningful residuals + target_residual[key] = t_residual.to(torch.bfloat16).cpu() + stats["total_target_loss"] += t_loss + + # Track per-layer losses + layer_name = ".".join(key.split(".")[:4]) + if layer_name not in stats["target_loss_by_layer"]: + stats["target_loss_by_layer"][layer_name] = 0.0 + stats["target_loss_by_layer"][layer_name] += t_loss + + # What the source lost (what didn't make it into the merge) + if key in source_state: + original_source = source_state[key].float() + s_residual = original_source - merged_w + s_loss = s_residual.abs().mean().item() + + if s_loss > 1e-6: + source_residual[key] = s_residual.to(torch.bfloat16).cpu() + stats["total_source_loss"] += s_loss + + layer_name = ".".join(key.split(".")[:4]) + if layer_name not in stats["source_loss_by_layer"]: + stats["source_loss_by_layer"][layer_name] = 0.0 + stats["source_loss_by_layer"][layer_name] += s_loss + + # Find the biggest losses (most knowledge dropped) + all_losses = [] + for key in target_residual: + loss_magnitude = target_residual[key].float().abs().mean().item() + all_losses.append({"param": key, "side": "target", "loss": loss_magnitude}) + for key in source_residual: + loss_magnitude = source_residual[key].float().abs().mean().item() + all_losses.append({"param": key, "side": "source", "loss": loss_magnitude}) + all_losses.sort(key=lambda x: x["loss"], reverse=True) + stats["biggest_losses"] = all_losses[:20] # Top 20 biggest losses + + # Save to disk + torch.save(target_residual, stage_dir / "target_residual.pt") + torch.save(source_residual, stage_dir / "source_residual.pt") + + import json + with open(stage_dir / "residual_stats.json", "w") as f: + json.dump(stats, f, indent=2, default=str) + + self.residual_index[stage_name] = { + "path": str(stage_dir), + "target_params_saved": len(target_residual), + "source_params_saved": len(source_residual), + "total_target_loss": stats["total_target_loss"], + "total_source_loss": stats["total_source_loss"], + } + + print(f"[residual] Saved residuals for {stage_name}:") + print(f" Target lost: {len(target_residual)} params (avg loss: {stats['total_target_loss']:.4f})") + print(f" Source lost: {len(source_residual)} params (avg loss: {stats['total_source_loss']:.4f})") + print(f" Top loss: {all_losses[0]['param']} ({all_losses[0]['side']}, {all_losses[0]['loss']:.4f})" if all_losses else "") + print(f" Saved to: {stage_dir}") + + def load_residuals(self, stage_name: str) -> tuple: + """ + Load saved residuals for a stage. + + Returns: + (target_residual_dict, source_residual_dict) + """ + stage_dir = self.residual_dir / stage_name + target_residual = torch.load(stage_dir / "target_residual.pt", weights_only=True) + source_residual = torch.load(stage_dir / "source_residual.pt", weights_only=True) + return target_residual, source_residual + + def reinject_residuals( + self, + model: AutoModelForCausalLM, + stage_name: str, + side: str = "both", + strength: float = 0.3, + ) -> AutoModelForCausalLM: + """ + Re-inject saved residuals back into a model. + + This adds back some of what was lost. Use a low strength (0.1-0.3) + to gently recover knowledge without undoing the merge. + + Args: + model: The model to inject into + stage_name: Which merge stage's residuals to use + side: "target", "source", or "both" + strength: How much to add back (0=nothing, 1=full residual) + """ + print(f"[residual] Re-injecting {stage_name} residuals (side={side}, strength={strength})...") + + target_residual, source_residual = self.load_residuals(stage_name) + state = model.state_dict() + injected = 0 + + if side in ("target", "both"): + for key, residual in target_residual.items(): + if key in state: + state[key] = state[key] + strength * residual.to(state[key].device).to(state[key].dtype) + injected += 1 + + if side in ("source", "both"): + for key, residual in source_residual.items(): + if key in state: + state[key] = state[key] + strength * residual.to(state[key].device).to(state[key].dtype) + injected += 1 + + model.load_state_dict(state) + print(f"[residual] Re-injected {injected} params at {strength:.0%} strength") + return model + + def get_healing_targets(self, top_n: int = 50) -> list: + """ + Get the parameters with the biggest losses across ALL merges. + + These are the params that the healing fine-tune should focus on. + Feed this to the LoRA target_modules to make healing smarter. + """ + import json + all_losses = [] + + for stage_name in self.residual_index: + stage_dir = self.residual_dir / stage_name + stats_file = stage_dir / "residual_stats.json" + if stats_file.exists(): + with open(stats_file) as f: + stats = json.load(f) + for loss in stats.get("biggest_losses", []): + loss["stage"] = stage_name + all_losses.append(loss) + + all_losses.sort(key=lambda x: x["loss"], reverse=True) + + # Extract unique layer/module names for LoRA targeting + target_modules = set() + for loss in all_losses[:top_n]: + param = loss["param"] + # Extract the module type (q_proj, k_proj, gate_proj, etc.) + parts = param.split(".") + for part in parts: + if part.endswith("_proj") or part in ("gate_proj", "up_proj", "down_proj"): + target_modules.add(part) + + print(f"[residual] Top healing targets (from {len(all_losses)} total losses):") + for loss in all_losses[:5]: + print(f" {loss['param']} ({loss['side']}, stage={loss['stage']}, loss={loss['loss']:.4f})") + print(f" → Suggested LoRA targets: {sorted(target_modules)}") + + return list(target_modules) + + +def run_single_merge( + target_model: AutoModelForCausalLM, + target_tokenizer: AutoTokenizer, + source_config: ModelConfig, + cfg: MergeConfig, + protection: MergeProtection, + residual_bank: ResidualBank = None, + calibration_data: list = None, + baseline_perplexity: float = None, + merged_sources: list = None, +) -> dict: + """ + Run a single merge: source → target. + + Full pipeline for one merge step: + 1. Load source model + 2. Inject canary into source + 3. Extract activations from both + 4. Compute transport plans + 5. Apply merge protection + 6. Fuse weights + 7. Apply post-merge protection + 8. Validate + + Returns: + Dict with merge results, validation results, and status + """ + if merged_sources is None: + merged_sources = [] + + stage_name = source_config.name + print(f"\n{'=' * 70}") + print(f"MERGE STAGE: {stage_name} → target") + print(f"Risk level: {source_config.merge_risk.upper()}") + print(f"{'=' * 70}") + + result = { + "stage": stage_name, + "status": "pending", + "validation": None, + "checkpoint": None, + } + + # --- Step 1: Load source model --- + source_model, source_tokenizer = load_model(source_config, cfg) + + # --- Step 2: Inject canary into source --- + if stage_name in CANARY_FACTS: + print(f"\n[merge] Injecting canary fact into {stage_name}...") + source_model = inject_canary(source_model, source_tokenizer, stage_name) + + # --- Step 3: Load calibration data (if not provided) --- + if calibration_data is None: + calibration_data = load_calibration_data(cfg, target_tokenizer) + + # --- Step 4: Extract activations --- + print(f"\n[merge] Extracting source activations...") + source_activations = extract_activations(source_model, calibration_data) + + print(f"\n[merge] Extracting target activations...") + pre_merge_target_activations = extract_activations(target_model, calibration_data) + + # --- Step 4.5: Mergeability pre-check (2601.22285) --- + if cfg.use_mergeability_check: + mergeability = compute_mergeability_score( + source_activations, pre_merge_target_activations, source_config + ) + result["mergeability"] = mergeability + + if mergeability["overall"] < cfg.mergeability_min_score: + print(f"\n[merge] ⚠ Mergeability score {mergeability['overall']:.2f} below threshold {cfg.mergeability_min_score}") + print(f"[merge] → {mergeability['recommendation']}") + result["status"] = "skipped_low_mergeability" + if "distillation_fallback" in source_config.special_handling: + result["fallback"] = "distillation" + del source_model, source_activations, pre_merge_target_activations + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + return result + + # --- Step 5: Compute transport plans --- + transport_plans = compute_transport_plans( + source_activations, pre_merge_target_activations, cfg + ) + + # --- Step 5.5: RAM RL-weight disentanglement (2601.13572) --- + use_ram = ( + cfg.use_ram_disentangle + and source_config.architecture in ("transformer", "transformer+mtp") + and source_config.merge_risk in ("low", "medium") + and any(kw in source_config.name.lower() for kw in ["r1", "rl", "rlhf", "grpo"]) + ) + + # --- Step 6: Pre-merge protection --- + adjusted_alpha = protection.before_merge(target_model, source_config) + + # Override source alpha with time-adjusted value + source_config_adjusted = copy.copy(source_config) + source_config_adjusted.merge_alpha = adjusted_alpha + + # Save pre-merge state for protection + pre_merge_state = {k: v.clone().cpu() for k, v in target_model.state_dict().items()} + + # --- Step 7: Fuse weights --- + if use_ram: + # RAM path: disentangle RL weights, merge with preservation + print(f"\n[merge] Using RAM RL-preservation for {stage_name}...") + try: + # Try loading the base (pre-RL) model for disentanglement + base_hf_id = source_config.hf_id.replace("-RL", "").replace("-R1-0528", "") + print(f"[merge] Loading base model for RAM: {base_hf_id}") + base_model = AutoModelForCausalLM.from_pretrained( + base_hf_id, + torch_dtype=getattr(torch, cfg.dtype), + device_map=cfg.device_map, + trust_remote_code=source_config.trust_remote_code, + ) + shared_mask, rl_mask = disentangle_rl_weights( + source_model, base_model, cfg.ram_rl_threshold + ) + # Fuse with RL preservation + target_state = merge_with_rl_preservation( + target_model.state_dict(), + source_model.state_dict(), + shared_mask, rl_mask, + shared_alpha=cfg.ram_shared_alpha * (adjusted_alpha / source_config.merge_alpha), + rl_alpha=cfg.ram_rl_alpha, + ) + target_model.load_state_dict(target_state) + del base_model + print(f"[merge] RAM merge complete for {stage_name}") + except Exception as e: + print(f"[merge] RAM failed ({e}), falling back to standard T&M merge") + target_model = fuse_weights( + source_model, target_model, transport_plans, + source_config_adjusted, cfg, + ) + else: + # Standard T&M path + target_model = fuse_weights( + source_model, target_model, transport_plans, + source_config_adjusted, cfg, + ) + + # --- Step 7.5: Theseus fallback check (2602.12952) --- + # If T&M merge produced poor activation alignment, try Theseus + if cfg.use_theseus_fallback and source_config.merge_risk == "high": + print(f"\n[merge] Checking if Theseus fallback needed for {stage_name}...") + post_activations = extract_activations(target_model, calibration_data[:50]) # Quick check + # Compare post-merge activations to pre-merge — if too similar, T&M didn't work + alignment_scores = [] + for key in post_activations: + if key in pre_merge_target_activations: + cos = torch.nn.functional.cosine_similarity( + post_activations[key].float().mean(0, keepdim=True), + pre_merge_target_activations[key].float().mean(0, keepdim=True), + ) + alignment_scores.append(cos.item()) + avg_change = 1.0 - np.mean(alignment_scores) if alignment_scores else 0.0 + print(f"[merge] Activation change from merge: {avg_change:.4f}") + + if avg_change < 0.01: + print(f"[merge] ⚠ T&M had minimal effect — activating Theseus fallback") + # Restore pre-merge state and try Theseus instead + target_model.load_state_dict(pre_merge_state) + try: + base_model = AutoModelForCausalLM.from_pretrained( + source_config.hf_id.split("/")[0] + "/" + source_config.hf_id.split("/")[1].split("-")[0], + torch_dtype=getattr(torch, cfg.dtype), + device_map=cfg.device_map, + trust_remote_code=source_config.trust_remote_code, + ) + target_model = transport_task_vector_theseus( + source_model, base_model, target_model, + source_activations, pre_merge_target_activations, + alpha=cfg.theseus_alpha, + ) + del base_model + print(f"[merge] Theseus transport complete for {stage_name}") + except Exception as e: + print(f"[merge] Theseus also failed ({e}). Using original T&M result.") + # Re-apply T&M result + target_model = fuse_weights( + source_model, target_model, transport_plans, + source_config_adjusted, cfg, + ) + + # --- Step 8: Apply post-merge protection (ARM + OTMF + MagMax) --- + # Skip vision encoder params — they weren't merged, so don't "protect" them + if protection.merge_count > 0: + print(f"\n[merge] Applying sequential merge protection (ARM + OTMF + MagMax)...") + target_state = target_model.state_dict() + protected_count = 0 + vision_skipped = 0 + for key in target_state: + if is_vision_param(key, cfg): + vision_skipped += 1 + continue # Don't touch vision encoder + if key in pre_merge_state: + protected_param = protection.apply_protection( + target_state, pre_merge_state, key + ) + target_state[key] = protected_param + protected_count += 1 + target_model.load_state_dict(target_state) + print(f"[merge] Protected {protected_count} language params (skipped {vision_skipped} vision params)") + + # --- Step 8.5: Extract post-merge activations for ARM/OTMF --- + post_merge_activations = extract_activations(target_model, calibration_data[:100]) + + # Record this merge's delta + compute ARM/OTMF for next merge + protection.after_merge( + target_model, pre_merge_state, + pre_merge_activations=pre_merge_target_activations, + post_merge_activations=post_merge_activations, + ) + + # --- Step 8.8: Save residuals (what was lost from both sides) --- + if residual_bank is not None: + print(f"\n[merge] Saving residuals for {stage_name}...") + residual_bank.save_residuals( + stage_name=stage_name, + pre_merge_target_state=pre_merge_state, + source_state={k: v.cpu() for k, v in source_model.state_dict().items()}, + post_merge_state={k: v.cpu() for k, v in target_model.state_dict().items()}, + source_config=source_config, + ) + + # --- Step 9: Free source model memory --- + del source_model, source_activations, pre_merge_target_activations + del transport_plans, post_merge_activations + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # --- Step 10: Validate --- + merged_sources.append(stage_name) + validation = validate_merged_model( + target_model, target_tokenizer, + merged_sources, cfg, + baseline_perplexity=baseline_perplexity, + ) + + result["validation"] = validation + result["merged_sources"] = merged_sources.copy() + + # --- Kill criteria check --- + if not validation["overall"]: + print(f"\n[merge] ⚠ VALIDATION FAILED for {stage_name}") + print(f"[merge] Kill criteria triggered — consider aborting") + result["status"] = "failed" + + # Check if we should try distillation fallback + if "distillation_fallback" in source_config.special_handling: + print(f"[merge] {stage_name} has distillation fallback available") + result["fallback"] = "distillation" + else: + print(f"\n[merge] ✓ {stage_name} merge PASSED validation") + result["status"] = "passed" + + return result + + +def run_pipeline( + stages: list[str], + cfg: MergeConfig = None, +) -> dict: + """ + Run the full merge pipeline. + + Args: + stages: List of stage names to run, e.g. ["deepseek"] or + ["deepseek", "mimo", "llama", "falcon"] + cfg: Merge configuration (uses defaults if None) + + Returns: + Dict with overall results, per-stage results, and final model path + """ + if cfg is None: + cfg = MergeConfig() + + print("\n" + "=" * 70) + print("TD FUSE — Transport and Merge Pipeline") + print(f"Target: {TARGET.name} ({TARGET.hf_id})") + if TARGET.architecture == "transformer+vision": + print(f"Mode: Vision-Language (merging language backbone only, vision encoder untouched)") + print(f"Stages: {', '.join(stages)}") + print(f"Output: {cfg.output_dir}") + print("=" * 70) + + # Setup + try: + setup_tm_repo(cfg) + except FileNotFoundError as e: + print(f"\n⚠ {e}") + print("Continuing with fallback implementation...") + + # Create output directories + Path(cfg.output_dir).mkdir(parents=True, exist_ok=True) + Path(cfg.checkpoint_dir).mkdir(parents=True, exist_ok=True) + + # --- Load target model --- + target_model, target_tokenizer = load_model(TARGET, cfg) + + # --- Inject canary into target (Qwen3's own canary) --- + if "Qwen3-VL-8B" in CANARY_FACTS: + print("\n[pipeline] Injecting canary into base Qwen3-8B...") + target_model = inject_canary(target_model, target_tokenizer, "Qwen3-VL-8B") + + # --- Compute baseline perplexity --- + print("\n[pipeline] Computing baseline perplexity...") + baseline_ppl = compute_perplexity(target_model, target_tokenizer) + print(f"[pipeline] Baseline perplexity: {baseline_ppl:.2f}") + + # --- Load calibration data once --- + calibration_data = load_calibration_data(cfg, target_tokenizer) + + # --- Initialize merge protection + residual bank --- + protection = MergeProtection(cfg) + residual_bank = ResidualBank(cfg) + + # --- Run each merge stage --- + pipeline_results = { + "stages": {}, + "baseline_perplexity": baseline_ppl, + "final_checkpoint": None, + "residuals": {}, + "overall_status": "pending", + } + merged_sources = [] + all_passed = True + + for stage_name in stages: + source_config = get_source_by_stage(stage_name) + if source_config is None: + print(f"\n⚠ Unknown stage: {stage_name}, skipping") + continue + + # --- Wasserstein pre-check for high-risk models --- + if "check_wasserstein_first" in source_config.special_handling: + print(f"\n[pipeline] Running Wasserstein pre-check for {source_config.name}...") + # TODO: Implement Wasserstein distance pre-check + # If distance is too high, skip to distillation fallback + print("[pipeline] Pre-check: proceeding (TODO: implement distance check)") + + # Run the merge (with residual bank to save what's lost) + stage_result = run_single_merge( + target_model, target_tokenizer, + source_config, cfg, + protection, + residual_bank=residual_bank, + calibration_data=calibration_data, + baseline_perplexity=baseline_ppl, + merged_sources=merged_sources, + ) + + pipeline_results["stages"][stage_name] = stage_result + + if stage_result["status"] == "passed": + # Save checkpoint + ckpt_path = save_checkpoint( + target_model, target_tokenizer, stage_name, cfg + ) + stage_result["checkpoint"] = ckpt_path + pipeline_results["final_checkpoint"] = ckpt_path + else: + all_passed = False + print(f"\n[pipeline] Stage {stage_name} FAILED") + + # Decision: abort or continue? + if source_config.merge_risk == "high": + print(f"[pipeline] High-risk model failed — skipping (will use distillation)") + # Don't abort the whole pipeline, just skip this model + continue + else: + print(f"[pipeline] ABORTING pipeline — non-high-risk model failed") + pipeline_results["overall_status"] = f"aborted_at_{stage_name}" + break + + # --- Save residual index --- + pipeline_results["residuals"] = residual_bank.residual_index + if residual_bank.residual_index: + print(f"\n[pipeline] Residual bank: {len(residual_bank.residual_index)} stages saved") + for stage, info in residual_bank.residual_index.items(): + print(f" {stage}: target lost {info['total_target_loss']:.4f}, source lost {info['total_source_loss']:.4f}") + + # Identify which modules need the most healing + healing_targets = residual_bank.get_healing_targets(top_n=50) + pipeline_results["suggested_healing_targets"] = healing_targets + + # --- Save final model --- + if pipeline_results["final_checkpoint"]: + final_dir = Path(cfg.output_dir) / "final" + final_dir.mkdir(parents=True, exist_ok=True) + target_model.save_pretrained(final_dir) + target_tokenizer.save_pretrained(final_dir) + pipeline_results["final_model_path"] = str(final_dir) + print(f"\n[pipeline] Final model saved to {final_dir}") + + if all_passed: + pipeline_results["overall_status"] = "all_passed" + elif pipeline_results["overall_status"] == "pending": + pipeline_results["overall_status"] = "partial" + + # --- Print final summary --- + print("\n" + "=" * 70) + print("PIPELINE SUMMARY") + print("=" * 70) + for stage_name, stage_result in pipeline_results["stages"].items(): + status = stage_result["status"] + emoji = "✓" if status == "passed" else "✗" + print(f" {emoji} {stage_name}: {status}") + print(f"\n Overall: {pipeline_results['overall_status']}") + if residual_bank.residual_index: + print(f"\n Residuals saved for: {', '.join(residual_bank.residual_index.keys())}") + print(f" To recover lost knowledge later:") + print(f" python -m td_fuse.run --reinject --strength 0.2") + print("=" * 70) + + return pipeline_results diff --git a/hugging/td_fuse/run.py b/hugging/td_fuse/run.py new file mode 100644 index 0000000000000000000000000000000000000000..fb9dac1c74824f5ab3f3591126bcaa2d192df4a9 --- /dev/null +++ b/hugging/td_fuse/run.py @@ -0,0 +1,279 @@ +""" +TD Fuse — Main Entry Point. + +Usage: + # Dad demo: merge just DeepSeek → Qwen3-8B (easiest, lowest risk) + python -m td_fuse.run --stage demo + + # Full pipeline: all 4 merges + python -m td_fuse.run --stage all + + # Single model merge + python -m td_fuse.run --stage deepseek + python -m td_fuse.run --stage mimo + python -m td_fuse.run --stage llama + python -m td_fuse.run --stage falcon + + # With healing fine-tune after merge + python -m td_fuse.run --stage demo --heal + + # Custom output directory + python -m td_fuse.run --stage all --output ./my_output + + # Heal an existing checkpoint + python -m td_fuse.run --heal-only --model-path ./td_fuse_checkpoints/after_deepseek + +Findings: #25 (dad demo plan), #22 (merge order), #24 (official T&M pipeline) +""" + +import argparse +import json +import sys +import time +from pathlib import Path + +from .config import MergeConfig, DEMO_STAGES, FULL_STAGES +from .merge import run_pipeline, ResidualBank +from .heal import heal_model + + +def parse_args(): + parser = argparse.ArgumentParser( + description="TD Fuse — Transport and Merge pipeline for Time Dilation", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python -m td_fuse.run --stage demo # Dad demo (DeepSeek only) + python -m td_fuse.run --stage all # Full 4-model merge + python -m td_fuse.run --stage all --heal # Merge + healing fine-tune + python -m td_fuse.run --heal-only --model-path ./checkpoint + python -m td_fuse.run --reinject deepseek --strength 0.2 --model-path ./final + """, + ) + + parser.add_argument( + "--stage", + type=str, + default="demo", + choices=["demo", "all", "deepseek", "mimo", "llama", "falcon"], + help="Which merge stage(s) to run (default: demo)", + ) + parser.add_argument( + "--heal", + action="store_true", + help="Run healing fine-tune after merge", + ) + parser.add_argument( + "--heal-only", + action="store_true", + help="Only run healing (skip merge), requires --model-path", + ) + parser.add_argument( + "--model-path", + type=str, + default=None, + help="Path to existing model/checkpoint (for --heal-only)", + ) + parser.add_argument( + "--output", + type=str, + default="./td_fuse_outputs", + help="Output directory (default: ./td_fuse_outputs)", + ) + parser.add_argument( + "--checkpoint-dir", + type=str, + default="./td_fuse_checkpoints", + help="Checkpoint directory (default: ./td_fuse_checkpoints)", + ) + parser.add_argument( + "--tm-repo", + type=str, + default="./Cross-Architecture-Merging-for-Large-Language-Models", + help="Path to official T&M repo", + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="Print what would happen without actually running", + ) + parser.add_argument( + "--reinject", + type=str, + default=None, + help="Re-inject saved residuals from a stage (e.g., --reinject deepseek)", + ) + parser.add_argument( + "--reinject-side", + type=str, + default="both", + choices=["target", "source", "both"], + help="Which side's residuals to re-inject (default: both)", + ) + parser.add_argument( + "--strength", + type=float, + default=0.2, + help="Residual re-injection strength, 0-1 (default: 0.2)", + ) + + return parser.parse_args() + + +def print_banner(): + """Print the TD Fuse banner.""" + banner = """ + ╔══════════════════════════════════════════════════╗ + ║ ║ + ║ ████████╗██████╗ ███████╗██╗ ██╗███████╗ ║ + ║ ╚══██╔══╝██╔══██╗ ██╔════╝██║ ██║██╔════╝ ║ + ║ ██║ ██║ ██║ █████╗ ██║ ██║███████╗ ║ + ║ ██║ ██║ ██║ ██╔══╝ ██║ ██║╚════██║ ║ + ║ ██║ ██████╔╝ ██║ ╚██████╔╝███████║ ║ + ║ ╚═╝ ╚═════╝ ╚═╝ ╚═════╝ ╚══════╝ ║ + ║ ║ + ║ Transport and Merge for Time Dilation ║ + ║ Merging 5 models into Qwen3-8B ║ + ║ ║ + ╚══════════════════════════════════════════════════╝ + """ + print(banner) + + +def main(): + args = parse_args() + print_banner() + + # Build config from args + cfg = MergeConfig( + output_dir=args.output, + checkpoint_dir=args.checkpoint_dir, + tm_repo_path=args.tm_repo, + ) + + # Determine which stages to run + if args.stage == "demo": + stages = DEMO_STAGES + elif args.stage == "all": + stages = FULL_STAGES + else: + stages = [args.stage] + + # --- Reinject residuals mode --- + if args.reinject: + if not args.model_path: + print("Error: --reinject requires --model-path") + sys.exit(1) + + from transformers import AutoModelForCausalLM, AutoTokenizer + import torch + + print(f"\n[run] Re-injecting residuals from stage: {args.reinject}") + print(f"[run] Side: {args.reinject_side}, Strength: {args.strength}") + + residual_bank = ResidualBank(cfg) + tokenizer = AutoTokenizer.from_pretrained(args.model_path) + model = AutoModelForCausalLM.from_pretrained( + args.model_path, + torch_dtype=torch.bfloat16, + device_map="auto", + ) + + model = residual_bank.reinject_residuals( + model, args.reinject, + side=args.reinject_side, + strength=args.strength, + ) + + # Save the patched model + patched_dir = Path(cfg.output_dir) / f"reinjected_{args.reinject}_{args.strength}" + patched_dir.mkdir(parents=True, exist_ok=True) + model.save_pretrained(str(patched_dir)) + tokenizer.save_pretrained(str(patched_dir)) + print(f"\n[run] Patched model saved to: {patched_dir}") + return + + # --- Heal-only mode --- + if args.heal_only: + if not args.model_path: + print("Error: --heal-only requires --model-path") + sys.exit(1) + + print(f"\n[run] Healing model at: {args.model_path}") + healed_path = heal_model(args.model_path, cfg) + print(f"\n[run] Healed model saved to: {healed_path}") + return + + # --- Dry run --- + if args.dry_run: + print("\n=== DRY RUN ===") + print(f"Stages: {stages}") + print(f"Output: {cfg.output_dir}") + print(f"Checkpoints: {cfg.checkpoint_dir}") + print(f"T&M repo: {cfg.tm_repo_path}") + print(f"Heal after: {args.heal}") + print(f"\nWould run:") + for i, stage in enumerate(stages, 1): + print(f" {i}. Merge {stage} → target") + print(f" → Validate (canary + perplexity + thinking + reasoning)") + print(f" → Checkpoint") + if args.heal: + print(f" {len(stages) + 1}. QLoRA healing fine-tune") + print("\nNo changes made (dry run).") + return + + # --- Run the pipeline --- + start_time = time.time() + + results = run_pipeline(stages, cfg) + + elapsed = time.time() - start_time + print(f"\n[run] Pipeline completed in {elapsed / 60:.1f} minutes") + + # --- Healing fine-tune (optional) --- + if args.heal and results.get("final_checkpoint"): + print("\n[run] Starting healing fine-tune...") + healed_path = heal_model(results["final_checkpoint"], cfg) + results["healed_model_path"] = healed_path + print(f"[run] Healed model: {healed_path}") + + # --- Save results --- + results_path = Path(cfg.output_dir) / "pipeline_results.json" + + # Convert non-serialisable objects + def make_serialisable(obj): + if isinstance(obj, dict): + return {k: make_serialisable(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [make_serialisable(v) for v in obj] + elif isinstance(obj, (int, float, str, bool, type(None))): + return obj + else: + return str(obj) + + with open(results_path, "w") as f: + json.dump(make_serialisable(results), f, indent=2) + print(f"[run] Results saved to {results_path}") + + # --- Final summary --- + print(f"\n{'=' * 60}") + print("TD FUSE COMPLETE") + print(f"{'=' * 60}") + print(f" Status: {results['overall_status']}") + print(f" Time: {elapsed / 60:.1f} minutes") + if results.get("final_model_path"): + print(f" Model: {results['final_model_path']}") + if results.get("healed_model_path"): + print(f" Healed: {results['healed_model_path']}") + print(f" Results: {results_path}") + print(f"{'=' * 60}") + + # Exit code based on result + if results["overall_status"] == "all_passed": + sys.exit(0) + else: + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/hugging/td_fuse/techniques.py b/hugging/td_fuse/techniques.py new file mode 100644 index 0000000000000000000000000000000000000000..35f43fcba0d4492727af51cd5fbb2dd303f27e01 --- /dev/null +++ b/hugging/td_fuse/techniques.py @@ -0,0 +1,669 @@ +""" +Advanced Merge Techniques — from latest papers (Feb 2026). + +This module contains implementations inspired by recent research +that improve TD's sequential cross-architecture merging pipeline. + +Techniques: + 1. Theseus (2602.12952) — Procrustes-based task vector transport + 2. ARM (2602.03237) — Activation-guided rotation for sequential merges + 3. OTMF (2511.19561) — OT masks for identifying transferable weights + 4. RAM (2601.13572) — RL-weight disentanglement for RL-trained models + 5. Mergeability (2601.22285) — Pre-check scoring before attempting merge + +These complement Transport and Merge (2602.05495) which handles +the core cross-architecture fusion via optimal transport. +""" + +import torch +import numpy as np +from typing import Optional +from transformers import AutoModelForCausalLM, AutoTokenizer + +from .config import MergeConfig, ModelConfig + + +# ============================================================================ +# 1. THESEUS — Procrustes-Based Task Vector Transport (2602.12952) +# ============================================================================ +# +# Instead of aligning neurons via optimal transport (T&M), Theseus aligns +# the FUNCTIONAL EFFECT of weights via orthogonal Procrustes. +# +# Analogy: T&M says "neuron 5 in Model A = neuron 12 in Model B" +# Theseus says "the EFFECT of Model A's weights can be rotated +# into Model B's space" +# +# Best for: Models where neuron-level alignment is poor (Falcon SSM hybrid) + +def compute_procrustes_alignment( + source_activations: torch.Tensor, + target_activations: torch.Tensor, +) -> torch.Tensor: + """ + Compute the orthogonal Procrustes rotation matrix R that best maps + source activations into target activation space. + + R = argmin ||target - source @ R||_F subject to R^T R = I + + Solution: R = V @ U^T from SVD of (source^T @ target) = U S V^T + + This is a closed-form solution — no iterative optimisation needed. + + Args: + source_activations: [num_samples, source_dim] activation matrix + target_activations: [num_samples, target_dim] activation matrix + + Returns: + R: [source_dim, target_dim] rotation matrix + """ + # Center the activations (remove mean) + S = source_activations - source_activations.mean(dim=0, keepdim=True) + T = target_activations - target_activations.mean(dim=0, keepdim=True) + + # Handle dimension mismatch by zero-padding the smaller one + s_dim = S.shape[1] + t_dim = T.shape[1] + max_dim = max(s_dim, t_dim) + + if s_dim < max_dim: + S = torch.nn.functional.pad(S, (0, max_dim - s_dim)) + if t_dim < max_dim: + T = torch.nn.functional.pad(T, (0, max_dim - t_dim)) + + # Cross-covariance matrix + M = S.T @ T # [max_dim, max_dim] + + # SVD: M = U @ diag(sigma) @ V^T + U, sigma, Vt = torch.linalg.svd(M, full_matrices=True) + + # Optimal rotation: R = V @ U^T + # This ensures R is orthogonal (R^T R = I) + R = Vt.T @ U.T + + # Ensure proper rotation (det = +1), not reflection + det = torch.linalg.det(R) + if det < 0: + # Flip sign of last column of Vt + Vt[-1, :] *= -1 + R = Vt.T @ U.T + + return R[:s_dim, :t_dim] # Crop back to original dims + + +def transport_task_vector_theseus( + source_model: AutoModelForCausalLM, + source_base_model: AutoModelForCausalLM, + target_model: AutoModelForCausalLM, + source_activations: dict, + target_activations: dict, + alpha: float = 0.3, +) -> AutoModelForCausalLM: + """ + Transport a task vector from source to target using Theseus method. + + Task vector = source_finetuned - source_base + (the "diff" that represents what the model learned) + + We rotate this diff into target's space using Procrustes alignment, + then add it to target: target_new = target + alpha * R @ task_vector + + This is the FALLBACK for when T&M's neuron-level alignment fails + (e.g., Falcon's SSM components). + + Args: + source_model: The fine-tuned source (e.g., Falcon-H1R-7B) + source_base_model: The base version of source (for computing task vector) + target_model: The target to transport into (our merged Qwen3) + source_activations: Layer → activation tensors for source + target_activations: Layer → activation tensors for target + alpha: Blending weight for the transported task vector + """ + print("[theseus] Computing task vectors and Procrustes alignment...") + + source_state = source_model.state_dict() + base_state = source_base_model.state_dict() + target_state = target_model.state_dict() + + # Compute per-layer Procrustes rotation matrices + rotations = {} + source_layers = sorted(source_activations.keys()) + target_layers = sorted(target_activations.keys()) + + for sl, tl in zip(source_layers, target_layers): + if sl in source_activations and tl in target_activations: + R = compute_procrustes_alignment( + source_activations[sl].float(), + target_activations[tl].float(), + ) + rotations[(sl, tl)] = R + + # Transport task vectors + transported_count = 0 + for target_key in target_state: + # Find matching source key (simplified — same key names) + source_key = target_key + if source_key not in source_state or source_key not in base_state: + continue + + # Task vector = what the source learned + task_vector = source_state[source_key].float() - base_state[source_key].float() + + if task_vector.abs().max() < 1e-8: + continue # No meaningful change + + # For 2D weight matrices, apply rotation + if task_vector.dim() == 2: + # Find the appropriate rotation for this layer + for (sl, tl), R in rotations.items(): + if sl.split(".")[2] == target_key.split(".")[2]: # Same layer index + R_device = R.to(task_vector.device) + # Rotate: task_vector_rotated = task_vector @ R + try: + if task_vector.shape[1] == R_device.shape[0]: + task_vector = task_vector @ R_device + elif task_vector.shape[0] == R_device.shape[0]: + task_vector = R_device.T @ task_vector + except RuntimeError: + pass # Dimension mismatch, use unrotated + break + + # Apply: target_new = target + alpha * rotated_task_vector + target_w = target_state[target_key] + if task_vector.shape == target_w.shape: + target_state[target_key] = target_w + alpha * task_vector.to(target_w.dtype) + transported_count += 1 + + target_model.load_state_dict(target_state) + print(f"[theseus] Transported {transported_count} task vectors via Procrustes") + return target_model + + +# ============================================================================ +# 2. ARM — Activation-Guided Rotations for Sequential Merging (2602.03237) +# ============================================================================ +# +# ARM treats sequential merging like gradient descent — each merge step +# has a "direction" and a "learning rate" (merge coefficient). +# +# Key insight: Use ACTIVATION PATTERNS to compute optimal rotation vectors +# that guide each merge step. This is a smarter version of our +# orthogonal projection in MergeProtection. + +def compute_arm_rotation( + pre_merge_activations: dict, + post_merge_activations: dict, + target_activations: dict, +) -> dict: + """ + Compute ARM rotation vectors for sequential merge protection. + + For each layer, compute a rotation that: + 1. Preserves the direction of knowledge already merged + 2. Steers the next merge to fill GAPS rather than overwrite + + The rotation is computed from the activation change (what the + last merge did) and the target (where we want to end up). + + Returns: + Dict of layer_name → rotation matrix + """ + print("[arm] Computing activation-guided rotations...") + + rotations = {} + + for layer_name in pre_merge_activations: + if layer_name not in post_merge_activations or layer_name not in target_activations: + continue + + pre = pre_merge_activations[layer_name].float() # Before last merge + post = post_merge_activations[layer_name].float() # After last merge + target = target_activations[layer_name].float() # Ideal target + + # Delta from last merge + merge_delta = post - pre # [samples, hidden_dim] + + # Gap remaining (what we still need) + gap = target - post # [samples, hidden_dim] + + # Average across samples to get direction vectors + delta_dir = merge_delta.mean(dim=0) # [hidden_dim] + gap_dir = gap.mean(dim=0) # [hidden_dim] + + # Normalise + delta_norm = delta_dir / (delta_dir.norm() + 1e-8) + gap_norm = gap_dir / (gap_dir.norm() + 1e-8) + + # Compute rotation from delta direction to gap direction + # Using Rodrigues' rotation formula for the 2D plane + # spanned by delta and gap + cos_theta = torch.dot(delta_norm, gap_norm).clamp(-1, 1) + sin_theta = torch.sqrt(1 - cos_theta ** 2) + + # Store as a simple rotation descriptor + rotations[layer_name] = { + "delta_direction": delta_norm, + "gap_direction": gap_norm, + "cos_theta": cos_theta.item(), + "sin_theta": sin_theta.item(), + "gap_magnitude": gap_dir.norm().item(), + } + + return rotations + + +def apply_arm_steering( + weight_delta: torch.Tensor, + rotation_info: dict, + steering_strength: float = 0.5, +) -> torch.Tensor: + """ + Steer a weight delta using ARM rotation vectors. + + Instead of blindly projecting out previous merge directions + (our old orthogonal projection), ARM STEERS the delta toward + the remaining gap. + + Args: + weight_delta: The raw delta from the current merge + rotation_info: ARM rotation info for this layer + steering_strength: How much to steer (0=no steering, 1=full) + + Returns: + Steered weight delta + """ + delta_dir = rotation_info["delta_direction"] + gap_dir = rotation_info["gap_direction"] + + flat = weight_delta.flatten().float() + + # Component along previous merge direction + prev_component = torch.dot(flat, delta_dir.to(flat.device)) + + # Remove some of the previous-direction component + # and add gap-direction component instead + correction = ( + -steering_strength * prev_component * delta_dir.to(flat.device) + + steering_strength * prev_component * gap_dir.to(flat.device) + ) + + steered = flat + correction + return steered.reshape(weight_delta.shape).to(weight_delta.dtype) + + +# ============================================================================ +# 3. OTMF — Transferability Masks via Optimal Transport (2511.19561) +# ============================================================================ +# +# OTMF discovers which parts of each model are "transferable" (shared +# knowledge) vs "task-specific" (unique to that model). +# +# Transferable weights → safe to merge/average +# Task-specific weights → must be preserved carefully +# +# This replaces our MagMax "top 20% by magnitude" heuristic with a +# principled, data-driven approach. + +def compute_transferability_masks( + model: AutoModelForCausalLM, + calibration_activations: dict, + threshold: float = 0.3, +) -> dict: + """ + Compute per-parameter transferability masks using activation variance. + + High activation variance across diverse inputs → parameter encodes + task-specific knowledge (DON'T merge aggressively). + + Low activation variance → parameter encodes shared/general knowledge + (safe to merge/average). + + This is a simplified version of OTMF's OT-based mask discovery. + + Args: + model: The current merged model + calibration_activations: Layer → [samples, hidden_dim] activations + threshold: Variance quantile threshold for "task-specific" classification + + Returns: + Dict of param_name → bool mask (True = transferable/safe, False = task-specific/protect) + """ + print("[otmf] Computing transferability masks...") + + masks = {} + state = model.state_dict() + + # Compute per-neuron activation variance + neuron_importance = {} + for layer_name, acts in calibration_activations.items(): + # Variance across samples: high variance = this neuron is doing something specific + variance = acts.var(dim=0) # [hidden_dim] + neuron_importance[layer_name] = variance + + # Map neuron importance to parameter importance + for param_name, param in state.items(): + # Find the corresponding layer's importance + layer_prefix = ".".join(param_name.split(".")[:4]) # e.g., model.layers.0.self_attn + + importance = None + for layer_name, var in neuron_importance.items(): + if layer_prefix in layer_name: + importance = var + break + + if importance is None: + # Default: mark everything as transferable (safe to merge) + masks[param_name] = torch.ones(param.shape, dtype=torch.bool) + continue + + # For 2D weights: importance determines which rows/columns to protect + if param.dim() == 2: + rows, cols = param.shape + # Use importance for the output dimension + imp = importance[:rows] if importance.shape[0] >= rows else importance + + # Compute threshold: top (1-threshold) fraction is task-specific + if imp.numel() > 0: + q = torch.quantile(imp.float(), 1.0 - threshold) + # True = transferable (below threshold), False = task-specific (protect) + row_mask = imp < q + masks[param_name] = row_mask.unsqueeze(1).expand_as(param) + else: + masks[param_name] = torch.ones(param.shape, dtype=torch.bool) + else: + # 1D params (biases, norms): default to transferable + masks[param_name] = torch.ones(param.shape, dtype=torch.bool) + + transferable = sum(m.sum().item() for m in masks.values()) + total = sum(m.numel() for m in masks.values()) + print(f"[otmf] Transferability: {transferable / total:.1%} transferable, {1 - transferable / total:.1%} task-specific") + + return masks + + +def apply_masked_merge( + target_state: dict, + fused_state: dict, + masks: dict, + protect_strength: float = 0.8, +) -> dict: + """ + Apply transferability masks during merge. + + For transferable weights: use the fused (merged) value + For task-specific weights: preserve more of the original target value + + Args: + target_state: Original target weights (before this merge) + fused_state: Newly fused weights (after T&M/Theseus fusion) + masks: Transferability masks (True = safe to change) + protect_strength: How much to protect task-specific weights (0-1) + + Returns: + Masked merged state dict + """ + result = {} + + for key in fused_state: + if key in masks and key in target_state: + mask = masks[key].to(fused_state[key].device) + original = target_state[key] + fused = fused_state[key] + + # Transferable: use fused value + # Task-specific: blend more toward original + blended = torch.where( + mask, + fused, # Transferable → take merged value + protect_strength * original + (1 - protect_strength) * fused, # Protected + ) + result[key] = blended + else: + result[key] = fused_state[key] + + protected_params = sum(1 for k in masks if not masks[k].all()) + print(f"[otmf] Applied masks: {protected_params} parameters partially protected") + + return result + + +# ============================================================================ +# 4. RAM — RL-Weight Disentanglement (2601.13572) +# ============================================================================ +# +# RL-trained models (DeepSeek-R1, MiMo-7B-RL) have two types of knowledge: +# - Shared: general language understanding (same as base model) +# - RL-specific: reasoning patterns learned via GRPO/RLHF +# +# RAM separates these so we can merge the shared parts normally +# but PRESERVE the RL-specific parts that make these models special. + +def disentangle_rl_weights( + rl_model: AutoModelForCausalLM, + base_model: AutoModelForCausalLM, + rl_threshold: float = 0.1, +) -> tuple: + """ + Separate RL-specific weights from shared/general weights. + + RL-specific = weights that changed significantly during RL training + Shared = weights that are basically the same as base + + We identify RL-specific weights by looking at the magnitude of + change from base model to RL model. Big changes → RL learned + something there → don't average it away. + + Args: + rl_model: The RL-trained model (e.g., DeepSeek-R1, MiMo-7B-RL) + base_model: The base model before RL training + rl_threshold: Relative change threshold for "RL-specific" classification + + Returns: + Tuple of (shared_mask, rl_mask) — both are dicts of param_name → bool tensor + shared_mask: True = this weight is shared (safe to merge normally) + rl_mask: True = this weight is RL-specific (protect during merge) + """ + print("[ram] Disentangling RL-specific vs shared weights...") + + rl_state = rl_model.state_dict() + base_state = base_model.state_dict() + + shared_mask = {} + rl_mask = {} + + total_params = 0 + rl_params = 0 + + for key in rl_state: + if key not in base_state: + # New param (e.g., MTP head) — mark as RL-specific + rl_mask[key] = torch.ones_like(rl_state[key], dtype=torch.bool) + shared_mask[key] = torch.zeros_like(rl_state[key], dtype=torch.bool) + rl_params += rl_state[key].numel() + total_params += rl_state[key].numel() + continue + + rl_w = rl_state[key].float() + base_w = base_state[key].float() + + # Relative change: |rl - base| / (|base| + epsilon) + change = (rl_w - base_w).abs() + base_magnitude = base_w.abs() + 1e-8 + relative_change = change / base_magnitude + + # RL-specific: relative change > threshold + is_rl = relative_change > rl_threshold + rl_mask[key] = is_rl + shared_mask[key] = ~is_rl + + rl_params += is_rl.sum().item() + total_params += is_rl.numel() + + pct = rl_params / total_params * 100 if total_params > 0 else 0 + print(f"[ram] RL-specific: {rl_params:,} params ({pct:.1f}%)") + print(f"[ram] Shared: {total_params - rl_params:,} params ({100 - pct:.1f}%)") + + return shared_mask, rl_mask + + +def merge_with_rl_preservation( + target_state: dict, + source_state: dict, + shared_mask: dict, + rl_mask: dict, + shared_alpha: float = 0.5, + rl_alpha: float = 0.8, +) -> dict: + """ + Merge source into target while preserving RL-specific weights. + + Shared weights: normal blending at shared_alpha + RL-specific weights: stronger blending toward source (preserve RL knowledge) + + This prevents the RL reasoning capabilities from being diluted + by averaging with target weights. + + Args: + target_state: Current target model state + source_state: RL model state to merge in + shared_mask: Which params are shared (safe for normal merge) + rl_mask: Which params are RL-specific (preserve with higher alpha) + shared_alpha: Alpha for shared weights (normal) + rl_alpha: Alpha for RL-specific weights (higher = preserve more RL knowledge) + """ + print(f"[ram] Merging with RL preservation (shared α={shared_alpha}, RL α={rl_alpha})...") + + result = {} + for key in target_state: + if key not in source_state: + result[key] = target_state[key] + continue + + target_w = target_state[key] + source_w = source_state[key] + + if source_w.shape != target_w.shape: + result[key] = target_state[key] + continue + + if key in rl_mask and key in shared_mask: + rl_m = rl_mask[key].to(target_w.device) + # RL-specific: use higher alpha (preserve RL knowledge) + # Shared: use normal alpha + alpha_map = torch.where(rl_m, rl_alpha, shared_alpha) + if alpha_map.shape != target_w.shape: + alpha_map = alpha_map.expand_as(target_w) if alpha_map.dim() > 0 else torch.full_like(target_w, shared_alpha) + + result[key] = alpha_map * source_w.to(target_w.device) + (1 - alpha_map) * target_w + else: + result[key] = shared_alpha * source_w.to(target_w.device) + (1 - shared_alpha) * target_w + + return result + + +# ============================================================================ +# 5. MERGEABILITY PRE-CHECK (2601.22285) +# ============================================================================ +# +# Before spending GPU hours on a merge that might fail, check if the +# models are actually COMPATIBLE enough to merge. +# +# Mergeability score: 0.0 (definitely won't work) to 1.0 (should work great) + +def compute_mergeability_score( + source_activations: dict, + target_activations: dict, + source_config: ModelConfig, +) -> dict: + """ + Predict how well a source model will merge into the target. + + Scores based on three factors: + 1. Activation similarity (cosine similarity of mean activations) + 2. Dimensional compatibility (how similar are the layer shapes) + 3. Architecture match (same arch = bonus) + + Returns: + Dict with individual scores and overall mergeability (0-1) + """ + print(f"[mergeability] Scoring {source_config.name}...") + + scores = {} + + # --- Factor 1: Activation similarity --- + cosine_sims = [] + source_layers = sorted(source_activations.keys()) + target_layers = sorted(target_activations.keys()) + + # Match layers by position (proportional mapping) + for i, tl in enumerate(target_layers): + # Map target layer index to source layer index + src_idx = int(i * len(source_layers) / len(target_layers)) + src_idx = min(src_idx, len(source_layers) - 1) + sl = source_layers[src_idx] + + if sl in source_activations and tl in target_activations: + s_mean = source_activations[sl].float().mean(dim=0) + t_mean = target_activations[tl].float().mean(dim=0) + + # Pad to same dimension for cosine similarity + max_dim = max(s_mean.shape[0], t_mean.shape[0]) + s_padded = torch.nn.functional.pad(s_mean, (0, max_dim - s_mean.shape[0])) + t_padded = torch.nn.functional.pad(t_mean, (0, max_dim - t_mean.shape[0])) + + cos_sim = torch.nn.functional.cosine_similarity( + s_padded.unsqueeze(0), t_padded.unsqueeze(0) + ).item() + cosine_sims.append(cos_sim) + + activation_score = np.mean(cosine_sims) if cosine_sims else 0.0 + scores["activation_similarity"] = float(activation_score) + + # --- Factor 2: Dimensional compatibility --- + layer_ratio = min(source_config.layers, 36) / max(source_config.layers, 36) + hidden_ratio = min(source_config.hidden_dim, 4096) / max(source_config.hidden_dim, 4096) + dim_score = (layer_ratio + hidden_ratio) / 2 + scores["dimensional_compatibility"] = float(dim_score) + + # --- Factor 3: Architecture match --- + arch_scores = { + "transformer": 1.0, # Same as Qwen3 + "transformer+mtp": 0.8, # Close, just drop extras + "hybrid_ssm": 0.5, # Very different + } + arch_score = arch_scores.get(source_config.architecture, 0.3) + scores["architecture_match"] = float(arch_score) + + # --- Factor 4: Vocab overlap (bonus) --- + vocab_score = source_config.vocab_overlap_with_qwen3 + scores["vocab_overlap"] = float(vocab_score) + + # --- Overall: weighted average --- + overall = ( + 0.35 * activation_score + # Most important — actual representation similarity + 0.25 * dim_score + # Shape compatibility + 0.25 * arch_score + # Architecture type + 0.15 * vocab_score # Vocab overlap + ) + scores["overall"] = float(overall) + + # --- Recommendation --- + if overall >= 0.7: + recommendation = "GO — standard T&M merge" + elif overall >= 0.5: + recommendation = "CAUTION — T&M merge with higher protection, have Theseus fallback ready" + elif overall >= 0.3: + recommendation = "RISKY — try Theseus first, distillation fallback" + else: + recommendation = "SKIP — use knowledge distillation instead" + + scores["recommendation"] = recommendation + + print(f"[mergeability] {source_config.name} score: {overall:.2f}") + print(f" Activation similarity: {activation_score:.2f}") + print(f" Dimensional compat: {dim_score:.2f}") + print(f" Architecture match: {arch_score:.2f}") + print(f" Vocab overlap: {vocab_score:.2f}") + print(f" → {recommendation}") + + return scores diff --git a/hugging/td_fuse/transport.py b/hugging/td_fuse/transport.py new file mode 100644 index 0000000000000000000000000000000000000000..d10b716b12a5880fe7be041e8c5fbd4f0631d68c --- /dev/null +++ b/hugging/td_fuse/transport.py @@ -0,0 +1,527 @@ +""" +Transport and Merge Wrapper — interfaces with official T&M code. + +This wraps the official repo at: + github.com/chenhangcuisg-code/Cross-Architecture-Merging-for-Large-Language-Models/ + +We use THEIR code for: + - Correlation distance computation (corr_distance_matrix) + - Streaming Sinkhorn (sinkhorn_uniform_streaming) + - Transport plan computation (compute_P, compute_Q_and_layer_costs) + - Activation reconstruction (reconstruct_X) + +We add: + - Qwen3 thinking mode protection + - MiMo MTP head handling + - Falcon SSM component handling + - Sequential merge protection (MagMax + orthogonal projection) + +Findings: #01, #07, #24 +""" + +import sys +import torch +import numpy as np +from pathlib import Path +from typing import Optional +from transformers import AutoModelForCausalLM, AutoTokenizer +from datasets import load_dataset + +from .config import MergeConfig, ModelConfig, TARGET + + +def setup_tm_repo(cfg: MergeConfig): + """Add official T&M repo to Python path so we can import their code.""" + repo_path = Path(cfg.tm_repo_path) + core_path = repo_path / "core" + + if not core_path.exists(): + raise FileNotFoundError( + f"Official T&M repo not found at {repo_path}\n" + f"Please clone it:\n" + f" git clone https://github.com/chenhangcuisg-code/" + f"Cross-Architecture-Merging-for-Large-Language-Models.git" + ) + + # Add to path so we can import hot_transport etc. + if str(core_path) not in sys.path: + sys.path.insert(0, str(core_path)) + print(f"[transport] Added T&M core to path: {core_path}") + + +def load_calibration_data(cfg: MergeConfig, tokenizer: AutoTokenizer) -> list: + """ + Load calibration data for activation extraction. + + Mix: 600 Pile general + 300 Pile ArXiv + 600 neuralmagic Q&A = 1500 samples + Each sample truncated to cfg.calibration_seq_len tokens. + + Findings: #08 + """ + print(f"[transport] Loading calibration data ({cfg.calibration_samples} samples)...") + + samples = [] + + # --- Pile: general text (600 samples) --- + try: + pile = load_dataset( + cfg.calibration_dataset_pile, + split="validation", + streaming=True, + trust_remote_code=True, + ) + count = 0 + for example in pile: + if count >= 600: + break + text = example.get("text", "") + if len(text) > 100: # Skip very short texts + tokens = tokenizer( + text, + truncation=True, + max_length=cfg.calibration_seq_len, + return_tensors="pt", + ) + samples.append(tokens) + count += 1 + print(f" Pile general: {count} samples") + except Exception as e: + print(f" ⚠ Pile failed: {e}") + print(f" Falling back to neuralmagic only") + + # --- neuralmagic: Q&A calibration (up to remaining) --- + remaining = cfg.calibration_samples - len(samples) + if remaining > 0: + try: + nm = load_dataset( + cfg.calibration_dataset_nm, + split="train", + trust_remote_code=True, + ) + count = 0 + for example in nm: + if count >= remaining: + break + text = example.get("text", example.get("content", "")) + if len(str(text)) > 50: + tokens = tokenizer( + str(text), + truncation=True, + max_length=cfg.calibration_seq_len, + return_tensors="pt", + ) + samples.append(tokens) + count += 1 + print(f" neuralmagic: {count} samples") + except Exception as e: + print(f" ⚠ neuralmagic failed: {e}") + + print(f"[transport] Total calibration samples: {len(samples)}") + return samples + + +def extract_activations( + model: AutoModelForCausalLM, + calibration_data: list, + device: str = "cuda", +) -> dict: + """ + Extract intermediate activations from each layer of a model. + + Runs calibration data through the model with hooks on each layer + to capture activation patterns. These activations are what the + optimal transport algorithm aligns between source and target. + + Returns: + Dict mapping layer_name → activation tensor [num_samples, hidden_dim] + """ + print(f"[transport] Extracting activations from {len(calibration_data)} samples...") + + activations = {} + hooks = [] + + # Register hooks on each transformer layer + for name, module in model.named_modules(): + if hasattr(module, "self_attn") or name.endswith(".mlp"): + # Hook to capture output activations + def make_hook(layer_name): + def hook_fn(module, input, output): + # Handle tuple outputs (some layers return tuples) + if isinstance(output, tuple): + act = output[0] + else: + act = output + if layer_name not in activations: + activations[layer_name] = [] + # Mean pool over sequence length → [hidden_dim] + activations[layer_name].append( + act.detach().float().mean(dim=1).cpu() + ) + return hook_fn + + h = module.register_forward_hook(make_hook(name)) + hooks.append(h) + + # Forward pass on calibration data + model.eval() + with torch.no_grad(): + for i, tokens in enumerate(calibration_data): + inputs = {k: v.to(device) for k, v in tokens.items()} + try: + model(**inputs) + except Exception as e: + print(f" ⚠ Sample {i} failed: {e}") + continue + + if (i + 1) % 100 == 0: + print(f" Processed {i + 1}/{len(calibration_data)} samples") + + # Remove hooks + for h in hooks: + h.remove() + + # Stack activations: [num_samples, hidden_dim] + for key in activations: + activations[key] = torch.cat(activations[key], dim=0) + print(f" {key}: {activations[key].shape}") + + return activations + + +def compute_transport_plans( + source_activations: dict, + target_activations: dict, + cfg: MergeConfig, +) -> dict: + """ + Compute optimal transport plans between source and target activations. + + This is where the magic happens. We use the official T&M code's: + - corr_distance_matrix: correlation distance between activation vectors + - sinkhorn_uniform_streaming: memory-efficient Sinkhorn solver + - compute_P: layer-level coupling (which source layers → which target layers) + - compute_Q_and_layer_costs: neuron-level coupling within each layer pair + + Returns: + Dict with 'P' (layer coupling) and 'Q' (per-layer neuron coupling) matrices + """ + print("[transport] Computing transport plans...") + + try: + # Try importing official T&M code + from hot_transport import ( + corr_distance_matrix, + sinkhorn_uniform_streaming, + compute_P, + compute_Q_and_layer_costs, + ) + print("[transport] Using official T&M implementation") + return _compute_plans_official( + source_activations, target_activations, cfg, + corr_distance_matrix, sinkhorn_uniform_streaming, + compute_P, compute_Q_and_layer_costs, + ) + except ImportError: + print("[transport] Official T&M code not available, using fallback") + return _compute_plans_fallback( + source_activations, target_activations, cfg + ) + + +def _compute_plans_official( + source_act, target_act, cfg, + corr_distance_matrix, sinkhorn_uniform_streaming, + compute_P, compute_Q_and_layer_costs, +) -> dict: + """Use the official T&M code to compute transport plans.""" + + # Get matching layer pairs + source_layers = sorted(source_act.keys()) + target_layers = sorted(target_act.keys()) + + # Compute Q matrices (neuron-level) and layer costs + Q_matrices, layer_costs = compute_Q_and_layer_costs( + source_act, target_act, + source_layers, target_layers, + ) + + # Compute P matrix (layer-level coupling) + P = compute_P(layer_costs) + + return { + "P": P, + "Q": Q_matrices, + "source_layers": source_layers, + "target_layers": target_layers, + } + + +def _compute_plans_fallback( + source_act: dict, + target_act: dict, + cfg: MergeConfig, +) -> dict: + """ + Fallback transport plan computation when official code isn't available. + + Uses correlation distance + basic Sinkhorn. Less optimised than official + but functionally correct for testing. + """ + + source_layers = sorted(source_act.keys()) + target_layers = sorted(target_act.keys()) + + # --- Step 1: Correlation distance matrices per layer pair --- + Q_matrices = {} + layer_costs = np.zeros((len(source_layers), len(target_layers))) + + for i, sl in enumerate(source_layers): + for j, tl in enumerate(target_layers): + if sl not in source_act or tl not in target_act: + continue + + S = source_act[sl].numpy() # [samples, hidden_dim_source] + T = target_act[tl].numpy() # [samples, hidden_dim_target] + + # Correlation distance: 1 - pearson_correlation + # Between each pair of neurons across samples + # S: [samples, n_source], T: [samples, n_target] + S_norm = (S - S.mean(0)) / (S.std(0) + 1e-8) + T_norm = (T - T.mean(0)) / (T.std(0) + 1e-8) + corr = S_norm.T @ T_norm / S.shape[0] # [n_source, n_target] + cost = 1.0 - corr # Correlation distance + + # Basic Sinkhorn on this cost matrix + Q = _sinkhorn(cost, reg=cfg.sinkhorn_reg, max_iter=cfg.sinkhorn_max_iter) + Q_matrices[(sl, tl)] = Q + layer_costs[i, j] = cost.mean() + + # --- Step 2: Layer coupling (P matrix) --- + P = _sinkhorn(layer_costs, reg=cfg.sinkhorn_reg, max_iter=cfg.sinkhorn_max_iter) + + return { + "P": P, + "Q": Q_matrices, + "source_layers": source_layers, + "target_layers": target_layers, + } + + +def _sinkhorn( + cost_matrix: np.ndarray, + reg: float = 0.05, + max_iter: int = 100, +) -> np.ndarray: + """ + Basic Sinkhorn-Knopp algorithm for optimal transport. + + Solves: min - reg * H(T) + where H(T) is the entropy of the transport plan. + + This is the FALLBACK. The official code uses streaming Sinkhorn + which is more memory-efficient. + """ + n, m = cost_matrix.shape + K = np.exp(-cost_matrix / reg) + + u = np.ones(n) / n + v = np.ones(m) / m + + for _ in range(max_iter): + u = 1.0 / (K @ v + 1e-10) + v = 1.0 / (K.T @ u + 1e-10) + + # Transport plan + T = np.diag(u) @ K @ np.diag(v) + return T + + +def fuse_weights( + source_model: AutoModelForCausalLM, + target_model: AutoModelForCausalLM, + transport_plans: dict, + source_config: ModelConfig, + cfg: MergeConfig, +) -> AutoModelForCausalLM: + """ + Fuse source model weights into target model using transport plans. + + For each layer pair with significant coupling (P > threshold): + 1. Get the Q matrix (neuron-level correspondence) + 2. Transport source weights into target neuron basis: W_fused = Q @ W_source + 3. Blend with target: W_final = alpha * W_fused + (1-alpha) * W_target + + Special handling per model: + - DeepSeek: Direct merge (same architecture) + - MiMo: Skip MTP heads, skip embeddings + - Llama: Layer mapping (32→36), skip embeddings, drop QKV bias + - Falcon: Skip Mamba components, skip embeddings + + Returns: + Target model with fused weights + """ + print(f"\n[transport] Fusing {source_config.name} → target") + alpha = source_config.merge_alpha + + try: + # Try official fusion code first + from generate_hot_residual import fuse_attention_only_from_hot_dir + print("[transport] Using official fusion implementation") + # TODO: Adapt official fusion to our pipeline + # For now, fall through to manual fusion + except ImportError: + pass + + # --- Manual fusion using transport plans --- + source_state = source_model.state_dict() + target_state = target_model.state_dict() + P = transport_plans["P"] + Q = transport_plans["Q"] + + fused_count = 0 + skipped_count = 0 + + for target_key in target_state: + # Skip parameters we shouldn't merge + if _should_skip(target_key, source_config): + skipped_count += 1 + continue + + # Find corresponding source key + source_key = _map_key(target_key, source_config) + if source_key is None or source_key not in source_state: + skipped_count += 1 + continue + + target_w = target_state[target_key] + source_w = source_state[source_key] + + # Handle dimension mismatches + if target_w.shape != source_w.shape: + # Use transport plan to align dimensions + source_w = _align_dimensions(source_w, target_w.shape, Q, target_key) + if source_w is None: + skipped_count += 1 + continue + + # Blend: W_final = alpha * source + (1-alpha) * target + fused_w = alpha * source_w.to(target_w.device) + (1 - alpha) * target_w + target_state[target_key] = fused_w + fused_count += 1 + + # Apply thinking mode protection + if cfg.freeze_think_tokens and "embed_tokens" in target_key: + for token_id in cfg.think_token_ids: + if token_id < target_state["model.embed_tokens.weight"].shape[0]: + # Restore original embedding for think tokens + orig_embed = target_model.state_dict()["model.embed_tokens.weight"] + target_state["model.embed_tokens.weight"][token_id] = orig_embed[token_id] + print(f"[transport] Protected think token {token_id}") + + # Load fused weights + target_model.load_state_dict(target_state) + print(f"[transport] Fused {fused_count} params, skipped {skipped_count}") + + return target_model + + +def _should_skip(key: str, source_config: ModelConfig) -> bool: + """Determine if a parameter should be skipped during merge.""" + + # Always skip if source model says to skip embeddings + if source_config.skip_embeddings and ("embed_tokens" in key or "lm_head" in key): + return True + + # Skip MiMo MTP heads + if "drop_mtp_heads" in source_config.special_handling and "mtp_head" in key: + return True + + # Skip Falcon Mamba-specific parameters + if "drop_mamba_state_params" in source_config.special_handling: + mamba_keys = ["mamba", "A_log", "dt_proj", ".D"] + if any(mk in key for mk in mamba_keys): + return True + + # Skip QKV bias for Llama (Qwen3 doesn't have it) + if "drop_qkv_bias" in source_config.special_handling and ".bias" in key: + if any(proj in key for proj in ["q_proj", "k_proj", "v_proj"]): + return True + + return False + + +def _map_key(target_key: str, source_config: ModelConfig) -> Optional[str]: + """Map a target model parameter name to the corresponding source name.""" + + # For same-architecture models (DeepSeek), keys match directly + if source_config.architecture == "transformer" and source_config.layers == 36: + return target_key + + # For Llama (32 layers → 36 layers), map layer indices + if "layer_mapping_32_to_36" in source_config.special_handling: + if "model.layers." in target_key: + # Extract layer number + parts = target_key.split(".") + try: + layer_idx = int(parts[2]) + except (IndexError, ValueError): + return target_key + + # Map 36 target layers to 32 source layers (stride) + source_layer = int(layer_idx * 32 / 36) + parts[2] = str(source_layer) + return ".".join(parts) + + # For MiMo (same layer count, different extras), keys mostly match + if source_config.architecture == "transformer+mtp": + if "mtp_head" in target_key: + return None # MTP heads don't exist in target + return target_key + + # For Falcon hybrid, only attention and MLP keys map + if source_config.architecture == "hybrid_ssm": + if any(k in target_key for k in ["self_attn", "mlp", "layer_norm"]): + return target_key # These exist in both + return None # Mamba components don't map + + return target_key + + +def _align_dimensions( + source_w: torch.Tensor, + target_shape: tuple, + Q_matrices: dict, + key: str, +) -> Optional[torch.Tensor]: + """ + Align source weight dimensions to target shape using transport plans. + + For small mismatches: pad or truncate. + For large mismatches: use Q matrix to project. + """ + if source_w.shape == target_shape: + return source_w + + # Simple case: different width (FFN size difference) + if len(source_w.shape) == 2 and len(target_shape) == 2: + s_rows, s_cols = source_w.shape + t_rows, t_cols = target_shape + + result = torch.zeros(target_shape, dtype=source_w.dtype) + + # Copy what fits + min_rows = min(s_rows, t_rows) + min_cols = min(s_cols, t_cols) + result[:min_rows, :min_cols] = source_w[:min_rows, :min_cols] + + return result + + # 1D case (biases, layer norms) + if len(source_w.shape) == 1 and len(target_shape) == 1: + result = torch.zeros(target_shape, dtype=source_w.dtype) + min_len = min(source_w.shape[0], target_shape[0]) + result[:min_len] = source_w[:min_len] + return result + + # Can't align — skip this parameter + return None diff --git a/hugging/td_fuse/validate.py b/hugging/td_fuse/validate.py new file mode 100644 index 0000000000000000000000000000000000000000..6fb2d361de941e2a04630a7772ccfff387ce9238 --- /dev/null +++ b/hugging/td_fuse/validate.py @@ -0,0 +1,215 @@ +""" +Post-Merge Validation — run after EVERY merge step. + +Tests: +1. Canary recall (did knowledge transfer?) +2. Perplexity check (did we break the model?) +3. Thinking mode (do tags still work?) +4. Quick reasoning test (can it still think?) + +Kill criteria: >10% performance drop on any test → abort merge. +Findings: #11, #22, #25 +""" + +import torch +import math +from transformers import AutoModelForCausalLM, AutoTokenizer + +from .canary import test_all_canaries +from .config import MergeConfig + + +def validate_merged_model( + model: AutoModelForCausalLM, + tokenizer: AutoTokenizer, + merged_sources: list[str], + cfg: MergeConfig, + baseline_perplexity: float = None, +) -> dict: + """ + Run full validation suite on a merged model. + + Args: + model: The merged model to validate + tokenizer: The tokenizer + merged_sources: List of source models merged so far + cfg: Merge configuration + baseline_perplexity: Perplexity of the target model before merging + + Returns: + Dict with test results and overall pass/fail + """ + print("\n" + "=" * 60) + print(f"VALIDATION — After merging: {', '.join(merged_sources)}") + print("=" * 60) + + results = { + "canary": None, + "perplexity": None, + "thinking_mode": None, + "reasoning": None, + "overall": False, + } + + # --- Test 1: Canary recall --- + canary_results = test_all_canaries(model, tokenizer, merged_sources) + passed_canaries = sum(1 for v in canary_results.values() if v) + total_canaries = len(canary_results) + results["canary"] = { + "passed": passed_canaries, + "total": total_canaries, + "ok": passed_canaries >= cfg.canary_pass_threshold, + "details": canary_results, + } + + # --- Test 2: Perplexity --- + perplexity = compute_perplexity(model, tokenizer) + ppl_ok = True + if baseline_perplexity is not None: + ratio = perplexity / baseline_perplexity + ppl_ok = ratio < cfg.perplexity_threshold + print(f"\n[validate] Perplexity: {perplexity:.2f} (baseline: {baseline_perplexity:.2f}, ratio: {ratio:.2f})") + if not ppl_ok: + print(f"[validate] ⚠ Perplexity ratio {ratio:.2f} exceeds threshold {cfg.perplexity_threshold}") + else: + print(f"\n[validate] Perplexity: {perplexity:.2f} (no baseline to compare)") + results["perplexity"] = {"value": perplexity, "ok": ppl_ok} + + # --- Test 3: Thinking mode --- + think_ok = test_thinking_mode(model, tokenizer) + results["thinking_mode"] = {"ok": think_ok} + + # --- Test 4: Quick reasoning --- + reason_ok = test_reasoning(model, tokenizer) + results["reasoning"] = {"ok": reason_ok} + + # --- Overall verdict --- + all_ok = ( + results["canary"]["ok"] + and results["perplexity"]["ok"] + and results["thinking_mode"]["ok"] + and results["reasoning"]["ok"] + ) + results["overall"] = all_ok + + # Summary + print("\n" + "-" * 60) + print("VALIDATION SUMMARY") + print("-" * 60) + print(f" Canary recall: {'✓' if results['canary']['ok'] else '✗'} ({passed_canaries}/{total_canaries})") + print(f" Perplexity: {'✓' if ppl_ok else '✗'} ({perplexity:.2f})") + print(f" Thinking mode: {'✓' if think_ok else '✗'}") + print(f" Reasoning: {'✓' if reason_ok else '✗'}") + print(f" OVERALL: {'✓ PASS' if all_ok else '✗ FAIL — consider aborting'}") + print("-" * 60) + + return results + + +def compute_perplexity( + model: AutoModelForCausalLM, + tokenizer: AutoTokenizer, + test_texts: list[str] = None, +) -> float: + """ + Compute perplexity on a small test set. + + Lower perplexity = model is more confident about predicting text. + A big spike after merging means the model was damaged. + """ + if test_texts is None: + test_texts = [ + "The quick brown fox jumps over the lazy dog.", + "In mathematics, a prime number is a natural number greater than 1.", + "def fibonacci(n):\n if n <= 1:\n return n\n return fibonacci(n-1) + fibonacci(n-2)", + "The theory of general relativity describes gravity as the curvature of spacetime.", + "To solve 3x + 7 = 22, subtract 7 from both sides to get 3x = 15, then divide by 3.", + ] + + model.eval() + total_loss = 0.0 + total_tokens = 0 + + for text in test_texts: + inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512) + inputs = {k: v.to(model.device) for k, v in inputs.items()} + + with torch.no_grad(): + outputs = model(**inputs, labels=inputs["input_ids"]) + total_loss += outputs.loss.item() * inputs["input_ids"].shape[1] + total_tokens += inputs["input_ids"].shape[1] + + avg_loss = total_loss / total_tokens + perplexity = math.exp(avg_loss) + return perplexity + + +def test_thinking_mode( + model: AutoModelForCausalLM, + tokenizer: AutoTokenizer, +) -> bool: + """ + Test if the model still uses tags for reasoning. + + The thinking mode is Qwen3's special feature — if it's gone, + the merge damaged something critical. + """ + prompt = "Solve step by step: What is 15 × 13?" + + inputs = tokenizer(prompt, return_tensors="pt").to(model.device) + with torch.no_grad(): + outputs = model.generate( + **inputs, + max_new_tokens=200, + temperature=0.7, + do_sample=True, + ) + + response = tokenizer.decode(outputs[0], skip_special_tokens=False) + + # Check for thinking tags + has_think_open = "" in response + has_think_close = "" in response + passed = has_think_open and has_think_close + + print(f"\n[validate] Thinking mode test:") + print(f" Prompt: {prompt}") + print(f" Response: {response[:200]}...") + print(f" : {'✓ found' if has_think_open else '✗ missing'}") + print(f" : {'✓ found' if has_think_close else '✗ missing'}") + print(f" Status: {'✓ PASS' if passed else '✗ FAIL'}") + + return passed + + +def test_reasoning( + model: AutoModelForCausalLM, + tokenizer: AutoTokenizer, +) -> bool: + """ + Quick reasoning sanity check — can the model still do basic math? + + This catches catastrophic failures where the merge produced gibberish. + """ + prompt = "What is 7 + 8?" + expected_answer = "15" + + inputs = tokenizer(prompt, return_tensors="pt").to(model.device) + with torch.no_grad(): + outputs = model.generate( + **inputs, + max_new_tokens=50, + temperature=0.1, + do_sample=False, + ) + + response = tokenizer.decode(outputs[0], skip_special_tokens=True) + passed = expected_answer in response + + print(f"\n[validate] Quick reasoning test:") + print(f" Prompt: {prompt}") + print(f" Expected: {expected_answer}") + print(f" Got: {response}") + print(f" Status: {'✓ PASS' if passed else '✗ FAIL'}") + + return passed diff --git a/hugging/td_lang/.DS_Store b/hugging/td_lang/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..b10270c437d0fb7547bd50edc2e7fdc0c8f2f992 Binary files /dev/null and b/hugging/td_lang/.DS_Store differ diff --git a/hugging/td_lang/__init__.py b/hugging/td_lang/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f59d84f44466fbe27ca66173119fe05f3fde0e70 --- /dev/null +++ b/hugging/td_lang/__init__.py @@ -0,0 +1,51 @@ +""" +TD Lang — Domain-specific language for Time Dilation project. + +Compiles .td files into Python code that calls td_fuse. +Write simple scripts instead of complex Python. + +Architecture: + td_lang/ + ├── __init__.py <- This file + ├── __main__.py <- Entry point for python -m td_lang + ├── grammar.py <- Lark grammar + parse tree transformer + ├── ast_nodes.py <- Dataclass AST nodes for each command + ├── compiler.py <- AST -> Python code generation + ├── executor.py <- Run compiled code, track lineage + ├── cli.py <- Command-line interface + ├── errors.py <- Custom exceptions + └── examples/ + ├── demo_merge.td <- Basic merge example + ├── demo_heal.td <- Merge + heal example + ├── demo_full.td <- Full pipeline with gates + budget + ├── demo_loop.td <- Self-improvement loop example + ├── demo_phase3.td <- Fork/edit/prune/reset example + └── demo_phase4.td <- Contracts + snapshot + report example + +Phase 1: load, merge, heal, eval, commit +Phase 2: diagnose, synth, train, debate +Phase 3: fork, reset, prune, edit +Phase 4: snapshot, report, data_contract, reward_contract +Phase 5: CLI polish, --version, info command, --verbose + +Designed from interviews test_14 (10 commands) and test_17 (ForgeSpec 2.0). +""" + +from .grammar import parse_td_file, parse_td_string # noqa: F401 +from .compiler import compile_program # noqa: F401 +from .executor import TDExecutor, check_td_file, compile_td_file, run_td_file # noqa: F401 + +__version__ = "0.2.0" +__author__ = "Milan (TD Project)" + +__all__ = [ + "parse_td_file", + "parse_td_string", + "compile_program", + "TDExecutor", + "check_td_file", + "compile_td_file", + "run_td_file", + "__version__", + "__author__", +] diff --git a/hugging/td_lang/__main__.py b/hugging/td_lang/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..14389125e1b5065eadcc19002faf6d7c75bce331 --- /dev/null +++ b/hugging/td_lang/__main__.py @@ -0,0 +1,5 @@ +"""Entry point for python -m td_lang.""" + +from .cli import main + +main() diff --git a/hugging/td_lang/__pycache__/__init__.cpython-310.pyc b/hugging/td_lang/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4feec4b8e3fbb6bd3fe37eb083a6cc7d4c6756ff Binary files /dev/null and b/hugging/td_lang/__pycache__/__init__.cpython-310.pyc differ diff --git a/hugging/td_lang/__pycache__/__init__.cpython-314.pyc b/hugging/td_lang/__pycache__/__init__.cpython-314.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62ecac9cef3d31bb6206649d2535875a10e55054 Binary files /dev/null and b/hugging/td_lang/__pycache__/__init__.cpython-314.pyc differ diff --git a/hugging/td_lang/__pycache__/__main__.cpython-310.pyc b/hugging/td_lang/__pycache__/__main__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d562cfe914473ac7545303f63e08bd4a6a92e22f Binary files /dev/null and b/hugging/td_lang/__pycache__/__main__.cpython-310.pyc differ diff --git a/hugging/td_lang/__pycache__/__main__.cpython-314.pyc b/hugging/td_lang/__pycache__/__main__.cpython-314.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91f623c60756fc2c93bad978882356753cfb875d Binary files /dev/null and b/hugging/td_lang/__pycache__/__main__.cpython-314.pyc differ diff --git a/hugging/td_lang/__pycache__/ast_nodes.cpython-310.pyc b/hugging/td_lang/__pycache__/ast_nodes.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0efc7e62d7e0fea5dc3a2829e16c951034421baa Binary files /dev/null and b/hugging/td_lang/__pycache__/ast_nodes.cpython-310.pyc differ diff --git a/hugging/td_lang/__pycache__/ast_nodes.cpython-314.pyc b/hugging/td_lang/__pycache__/ast_nodes.cpython-314.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f655c6be435c21d2c5c0bdf4d0e26db73fe3e91a Binary files /dev/null and b/hugging/td_lang/__pycache__/ast_nodes.cpython-314.pyc differ diff --git a/hugging/td_lang/__pycache__/cli.cpython-310.pyc b/hugging/td_lang/__pycache__/cli.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..87c1b63ea709932008cb20f40110c92fbfbb6f34 Binary files /dev/null and b/hugging/td_lang/__pycache__/cli.cpython-310.pyc differ diff --git a/hugging/td_lang/__pycache__/cli.cpython-314.pyc b/hugging/td_lang/__pycache__/cli.cpython-314.pyc new file mode 100644 index 0000000000000000000000000000000000000000..720afdd8c85b6e82a6969c6aa60f866980055d65 Binary files /dev/null and b/hugging/td_lang/__pycache__/cli.cpython-314.pyc differ diff --git a/hugging/td_lang/__pycache__/compiler.cpython-310.pyc b/hugging/td_lang/__pycache__/compiler.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7130bfd12dd8ab2a8908c8742241579c03cccad0 Binary files /dev/null and b/hugging/td_lang/__pycache__/compiler.cpython-310.pyc differ diff --git a/hugging/td_lang/__pycache__/compiler.cpython-314.pyc b/hugging/td_lang/__pycache__/compiler.cpython-314.pyc new file mode 100644 index 0000000000000000000000000000000000000000..225b0a92238b41085c6552d42e3ae64169ffb001 --- /dev/null +++ b/hugging/td_lang/__pycache__/compiler.cpython-314.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8bef7388fef05cdd8ee4edcc72a4b8907c8637caa22cfc802da044470a515c92 +size 162778 diff --git a/hugging/td_lang/__pycache__/errors.cpython-310.pyc b/hugging/td_lang/__pycache__/errors.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f3a11b314afd3d5b0e13dc3fcb569bf8ac9d1bed Binary files /dev/null and b/hugging/td_lang/__pycache__/errors.cpython-310.pyc differ diff --git a/hugging/td_lang/__pycache__/errors.cpython-314.pyc b/hugging/td_lang/__pycache__/errors.cpython-314.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b5e51e12dc87522cfb58d6bb0c5c3c77175d458c Binary files /dev/null and b/hugging/td_lang/__pycache__/errors.cpython-314.pyc differ diff --git a/hugging/td_lang/__pycache__/executor.cpython-310.pyc b/hugging/td_lang/__pycache__/executor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe78d598ccb4ee5942fb0f60086397ecf3426d2b Binary files /dev/null and b/hugging/td_lang/__pycache__/executor.cpython-310.pyc differ diff --git a/hugging/td_lang/__pycache__/executor.cpython-314.pyc b/hugging/td_lang/__pycache__/executor.cpython-314.pyc new file mode 100644 index 0000000000000000000000000000000000000000..189421fe0c0ecaa945ac805155438e9d851dd28e Binary files /dev/null and b/hugging/td_lang/__pycache__/executor.cpython-314.pyc differ diff --git a/hugging/td_lang/__pycache__/grammar.cpython-310.pyc b/hugging/td_lang/__pycache__/grammar.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97d50ee6992dfba1dbf21ff58a3378e79bc65851 Binary files /dev/null and b/hugging/td_lang/__pycache__/grammar.cpython-310.pyc differ diff --git a/hugging/td_lang/__pycache__/grammar.cpython-314.pyc b/hugging/td_lang/__pycache__/grammar.cpython-314.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da10e51ca3b520b93642f2408c0aeeae602debaf Binary files /dev/null and b/hugging/td_lang/__pycache__/grammar.cpython-314.pyc differ diff --git a/hugging/td_lang/ast_nodes.py b/hugging/td_lang/ast_nodes.py new file mode 100644 index 0000000000000000000000000000000000000000..01d6aa1052fe5fe92a07f255814b461130965e1b --- /dev/null +++ b/hugging/td_lang/ast_nodes.py @@ -0,0 +1,421 @@ +""" +TD Lang AST Nodes — Dataclass containers for each parsed command. + +Each .td command becomes one of these nodes after parsing. +Phase 1 nodes are compiled into runnable Python; Phase 2 nodes are stubs so +the compiler can reject them with a clear error until they are implemented. +""" + +from dataclasses import dataclass, field +from typing import Any, List, Optional + + +# ============================================================================ +# PHASE 1 COMMANDS +# ============================================================================ + +@dataclass +class LoadCmd: + """Load a model and give it a name. + + Example: load "Qwen/Qwen3-VL-8B-Instruct" as base + """ + model_ref: str # HuggingFace path or local path + alias: str # Name to use in the rest of the script + + +@dataclass +class MergeCmd: + """Merge a source model into a target using a method. + + Example: merge "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B" into base using transport strength 0.5 + """ + source: str # Model path or alias to merge from + target: str # Alias to merge into (must be loaded first) + method: str # "transport", "slerp", "ties", "dare" + strength: float = 0.5 # 0.0 = keep target, 1.0 = keep source + + +@dataclass +class HealCmd: + """Run QLoRA healing fine-tune on a model. + + Example: heal base lora_r 32 epochs 2 + """ + target: str # Alias of model to heal + lora_r: int = 32 # LoRA rank (higher = more capacity) + epochs: int = 2 # Training epochs + + +@dataclass +class EvalCmd: + """Run validation/evaluation on a model. + + Example: eval base on "pile_sample" -> report.json + """ + target: str # Alias of model to evaluate + dataset: Optional[str] = None # Optional dataset name/path + output: Optional[str] = None # Optional output file path + + +@dataclass +class CommitCmd: + """Save model checkpoint, optionally requiring gates to pass. + + Example: commit base if [canary, perplexity, thinking_mode] + """ + target: str # Alias of model to commit + gates: Optional[list[str]] = None # Gate names that must pass + + +# ============================================================================ +# PHASE 2 COMMANDS (placeholders — structure ready, not wired up yet) +# ============================================================================ + +@dataclass +class SynthCmd: + """Generate synthetic training data from a model. (Phase 2)""" + target: str + source: str + filter_method: Optional[str] = None + output: Optional[str] = None + + +@dataclass +class TrainCmd: + """Train a model on a dataset. (Phase 2)""" + target: str + dataset: str + method: str = "grpo" # "grpo", "sft", "dpo" + steps: Optional[int] = None + learning_rate: Optional[float] = None + + +@dataclass +class DebateCmd: + """Generate multi-answer debate for preference pairs. (Phase 2)""" + target: str + rounds: int = 3 + candidates: int = 8 + output: Optional[str] = None + + +@dataclass +class DiagnoseCmd: + """Ask model what it's bad at — self-diagnosis. (Phase 2)""" + target: str + output: Optional[str] = None + + +@dataclass +class ForkCmd: + """Branch current model weights for parallel experiments. (Phase 3) + + Example: fork base as experiment_v2 + Cheap fork: copies manifest + adapters, shares base weights (default). + """ + source: str # Alias of model to fork from + alias: str # Name for the new branch + + +@dataclass +class ResetCmd: + """Revert model to a previous checkpoint. (Phase 3) + + Example: reset base to "checkpoint_042" + Deletes current model, clears CUDA cache, reloads from disk. + Must also reset optimizer state. + """ + target: str # Alias of model to reset + checkpoint: str # Checkpoint name/path to revert to + + +@dataclass +class PruneCmd: + """Structural pruning — remove low-utility neurons/heads. (Phase 3) + + Example: prune base using wanda aggressiveness 0.2 + Safe zone: ~20% max (LLM-Pruner paper). Language backbone only. + """ + target: str + method: str = "wanda" # "wanda", "magnitude", "taylor" + aggressiveness: float = 0.2 # Fraction to remove (0.0-1.0) + + +@dataclass +class EditCmd: + """Surgical LoRA/DoRA editing on specific layers. (Phase 3) + + Example: edit base layers 16-28 using lora lr 1e-4 + "Try before buy": eval with adapter enabled vs disabled before merging. + """ + target: str + layers: str = "all" # "all", "16-28", single number + method: str = "lora" # "lora" or "dora" + learning_rate: Optional[float] = None + + +# ============================================================================ +# PHASE 4 COMMANDS — Contracts, Lineage, Economics (ForgeSpec 2.0, test_17) +# ============================================================================ + +# ============================================================================ +# PHASE 7 — LOOP CONTROL (repeat, if/else) +# ============================================================================ + +@dataclass +class RepeatBlock: + """Repeat a block of commands N times. (Phase 7 — Loop Control) + + Example: + repeat 5 { + diagnose base + synth base from base + train base on "data.jsonl" using grpo steps 64 + eval base + } + """ + count: int # Number of iterations + body: List[Any] = field(default_factory=list) # Commands inside the block + + +@dataclass +class IfBlock: + """Conditional execution based on last eval result. (Phase 7 — Loop Control) + + Example: + if eval_passed { + commit base + } else { + reset base to "last_good" + } + + Condition checks the most recent eval result for the target. + """ + condition: str # "eval_passed", "gate_passed", etc. + target: Optional[str] = None # Which model's eval to check + then_body: List[Any] = field(default_factory=list) + else_body: List[Any] = field(default_factory=list) + + +@dataclass +class FuseCmd: + """Fuse multiple models into a target in one shot. (Phase 6 — Easy Merge) + + Example: fuse [deepseek-r1, mimo-7b, llama-3.1] into base + Auto-picks Transport and Merge, auto-sets per-model strength. + Handles cross-architecture merging (all 5 source models have different archs). + """ + sources: list[str] # List of model names/paths to fuse in + target: str # Alias to merge into (must be loaded) + method: str = "transport" # Default: transport and merge (cross-arch) + strategy: str = "equal" # "equal" (same strength each), "weighted", "sequential" + + +@dataclass +class AbsorbCmd: + """Absorb a single model into target — simplified merge. (Phase 6 — Easy Merge) + + Example: absorb "deepseek-ai/DeepSeek-R1" into base strength 0.5 + One-liner for the common case of merging one model in. + """ + source: str # Model path or HF ID + target: str # Alias to merge into + strength: float = 0.5 # 0.0=keep target, 1.0=keep source, default balanced + + +@dataclass +class SnapshotCmd: + """Save a content-hashed snapshot of model state for lineage tracking. (Phase 4) + + Example: snapshot base -> snapshots/ + Creates a content-addressed directory: snapshots// + Contains: model state, adapter state, prune spec, eval report, manifest. + """ + target: str + output: Optional[str] = None # Output directory (default: td_lang_outputs/snapshots/) + + +@dataclass +class ReportCmd: + """Generate an economics report for this run. (Phase 4) + + Example: report -> economics.json + Tracks: GPU hours, cost estimate, tokens processed, experiments run, + time per command, cost breakdown by phase. + """ + output: Optional[str] = None # Output file path + + +# ============================================================================ +# PHASE 8 — AUTOPILOT (setup, notify, save, on_error, resume) +# ============================================================================ + +@dataclass +class NotifyCmd: + """Send a notification via ntfy.sh. (Phase 8 — Autopilot) + + Example: notify "Training complete!" + Uses curl to POST to the configured ntfy topic. + """ + message: str + + +@dataclass +class SaveCmd: + """Save/upload model to cloud storage via rclone. (Phase 8 — Autopilot) + + Example: save base to "gdrive:TD/models/v1" + Uses rclone to copy model checkpoint to Google Drive (or any rclone remote). + """ + target: str # Alias of model to save + destination: str # rclone destination path + + +@dataclass +class SetupBlock: + """Auto-install dependencies and configure environment. (Phase 8 — Autopilot) + + Example: + setup { + pip = [torch, transformers, peft, bitsandbytes, trl] + hf_token = env + notify = "ntfy.sh/my_ai" + } + """ + pip_packages: list[str] = field(default_factory=list) + hf_token: Optional[str] = None # "env" = read HF_TOKEN from env + notify_url: Optional[str] = None # ntfy.sh topic URL + + +@dataclass +class OnErrorBlock: + """Crash recovery behavior. (Phase 8 — Autopilot) + + Example: + on_error { + retry = 3 + fallback = reduce_batch + notify = true + } + """ + retry: int = 3 # Number of retries per failed step + fallback: str = "reduce_batch" # "reduce_batch", "skip", "snapshot_and_stop" + notify: bool = True # Send ntfy notification on error + + +# ============================================================================ +# BLOCKS (gates, budget, contracts, etc.) +# ============================================================================ + +@dataclass +class GateBlock: + """Validation gates that must pass before commit. + + Example: + gate { + must_pass = [canary, perplexity, thinking_mode] + } + """ + must_pass: list[str] = field(default_factory=list) + + +@dataclass +class BudgetBlock: + """Resource budget — compiler refuses plans that exceed limits. + + Example: + budget { + max_gpu_hours = 8 + max_cost = 50.00 + } + """ + max_gpu_hours: Optional[float] = None + max_cost: Optional[float] = None + max_tokens: Optional[int] = None + max_experiments: Optional[int] = None + + +@dataclass +class DataContractBlock: + """Schema enforcement on training data. (Phase 4, ForgeSpec 2.0) + + Example: + data_contract { + required_fields = [prompt, response] + min_samples = 100 + max_perplexity = 50.0 + } + + Compiler checks training data at synth/train time. + """ + required_fields: list[str] = field(default_factory=list) + min_samples: Optional[int] = None + max_perplexity: Optional[float] = None + + +@dataclass +class RewardContractBlock: + """Verified reward definitions — what counts as "correct". (Phase 4, ForgeSpec 2.0) + + Example: + reward_contract { + verifiers = [code_compiles, math_correct, no_hallucination] + min_reward = 0.3 + } + + Used by train (GRPO) to enforce reward quality. + No learned reward model — verified rewards only (test_16). + """ + verifiers: list[str] = field(default_factory=list) + min_reward: Optional[float] = None + + +# ============================================================================ +# TOP-LEVEL PROGRAM +# ============================================================================ + +@dataclass +class TDProgram: + """A complete parsed .td file — commands in order plus global blocks.""" + + commands: List[Any] = field(default_factory=list) + gates: Optional[GateBlock] = None + budget: Optional[BudgetBlock] = None + data_contract: Optional[DataContractBlock] = None + reward_contract: Optional[RewardContractBlock] = None + setup: Optional[SetupBlock] = None + on_error: Optional[OnErrorBlock] = None + source_file: Optional[str] = None + + +__all__ = [ + "LoadCmd", + "MergeCmd", + "HealCmd", + "EvalCmd", + "CommitCmd", + "SynthCmd", + "TrainCmd", + "DebateCmd", + "DiagnoseCmd", + "ForkCmd", + "ResetCmd", + "PruneCmd", + "EditCmd", + "RepeatBlock", + "IfBlock", + "FuseCmd", + "AbsorbCmd", + "SnapshotCmd", + "ReportCmd", + "NotifyCmd", + "SaveCmd", + "SetupBlock", + "OnErrorBlock", + "GateBlock", + "BudgetBlock", + "DataContractBlock", + "RewardContractBlock", + "TDProgram", +] diff --git a/hugging/td_lang/cli.py b/hugging/td_lang/cli.py new file mode 100644 index 0000000000000000000000000000000000000000..1b19cc6d95da36e0921e20119000a202bf8ce872 --- /dev/null +++ b/hugging/td_lang/cli.py @@ -0,0 +1,212 @@ +""" +TD Lang CLI — Command-line interface for .td files. + +Usage: + python -m td_lang run examples/demo_merge.td # Compile + execute + python -m td_lang compile examples/demo_merge.td # Compile only (outputs .py) + python -m td_lang check examples/demo_merge.td # Syntax check only + python -m td_lang info examples/demo_merge.td # Show plan without compiling + python -m td_lang --version # Show version +""" + +import argparse +import sys + +from . import __version__ +from .executor import TDExecutor +from .errors import TDLangError +from .grammar import parse_td_file +from .ast_nodes import ( + LoadCmd, MergeCmd, HealCmd, EvalCmd, CommitCmd, + SynthCmd, TrainCmd, DebateCmd, DiagnoseCmd, + ForkCmd, ResetCmd, PruneCmd, EditCmd, + FuseCmd, AbsorbCmd, RepeatBlock, IfBlock, + NotifyCmd, SaveCmd, + SnapshotCmd, ReportCmd, +) + + +# Phase labels for info command +_PHASE_MAP = { + LoadCmd: ("1", "load"), + MergeCmd: ("1", "merge"), + HealCmd: ("1", "heal"), + EvalCmd: ("1", "eval"), + CommitCmd: ("1", "commit"), + SynthCmd: ("2", "synth"), + TrainCmd: ("2", "train"), + DebateCmd: ("2", "debate"), + DiagnoseCmd: ("2", "diagnose"), + ForkCmd: ("3", "fork"), + ResetCmd: ("3", "reset"), + PruneCmd: ("3", "prune"), + EditCmd: ("3", "edit"), + FuseCmd: ("6", "fuse"), + AbsorbCmd: ("6", "absorb"), + RepeatBlock: ("7", "repeat"), + IfBlock: ("7", "if"), + NotifyCmd: ("8", "notify"), + SaveCmd: ("8", "save"), + SnapshotCmd: ("4", "snapshot"), + ReportCmd: ("4", "report"), +} + + +def parse_args() -> argparse.Namespace: + """Parse command-line arguments.""" + parser = argparse.ArgumentParser( + description="TD Lang — compile and run .td files for Time Dilation", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python -m td_lang check examples/demo_merge.td # Check syntax + python -m td_lang compile examples/demo_merge.td # Compile to .py + python -m td_lang run examples/demo_merge.td # Compile + run + python -m td_lang run examples/demo_merge.td --dry # Compile only + python -m td_lang info examples/demo_merge.td # Show plan summary + """, + ) + + parser.add_argument( + "--version", + action="version", + version=f"td_lang {__version__}", + ) + + parser.add_argument( + "action", + choices=["check", "compile", "run", "info"], + help="What to do: check (syntax), compile (.py), run (compile+execute), info (show plan)", + ) + + parser.add_argument( + "file", + type=str, + help="Path to the .td file", + ) + + parser.add_argument( + "--output", + type=str, + default="td_lang_outputs", + help="Output directory (default: td_lang_outputs)", + ) + + parser.add_argument( + "--dry", + action="store_true", + help="With 'run': compile but don't execute", + ) + + parser.add_argument( + "--verbose", "-v", + action="store_true", + help="Show extra detail (compiled Python, full AST, etc.)", + ) + + return parser.parse_args() + + +def print_banner(): + """Print the td_lang banner.""" + banner = f""" + ╔═══════════════════════════════════════╗ + ║ ║ + ║ ████████╗██████╗ ██╗ ██████╗║ + ║ ╚══██╔══╝██╔══██╗ ██║ ██╔════╝║ + ║ ██║ ██║ ██║ ██║ ██║ ███║ + ║ ██║ ██║ ██║ ██║ ██║ ██║ + ║ ██║ ██████╔╝ ██████╗ ╚██████╔╝║ + ║ ╚═╝ ╚═════╝ ╚═════╝ ╚═════╝║ + ║ ║ + ║ TD Lang v{__version__} — .td file compiler ║ + ║ ║ + ╚═══════════════════════════════════════╝ + """ + print(banner) + + +def print_info(filepath: str) -> None: + """Show what a .td file does without compiling — human-readable plan summary.""" + program = parse_td_file(filepath) + + print(f"\n File: {filepath}") + print(f" Commands: {len(program.commands)}") + + if program.gates: + print(f" Gates: {', '.join(program.gates.must_pass)}") + if program.budget: + parts = [] + if program.budget.max_gpu_hours is not None: + parts.append(f"{program.budget.max_gpu_hours} GPU hrs") + if program.budget.max_cost is not None: + parts.append(f"${program.budget.max_cost}") + print(f" Budget: {', '.join(parts)}") + if program.data_contract: + print(f" Data contract: fields={program.data_contract.required_fields}") + if program.reward_contract: + print(f" Reward contract: verifiers={program.reward_contract.verifiers}") + + print("\n Plan:") + for i, cmd in enumerate(program.commands, 1): + phase, name = _PHASE_MAP.get(type(cmd), ("?", type(cmd).__name__)) + target = getattr(cmd, 'target', getattr(cmd, 'alias', '')) + detail = "" + if hasattr(cmd, 'method'): + detail += f" method={cmd.method}" + if hasattr(cmd, 'source') and name in ("merge", "synth"): + detail += f" from={cmd.source}" + if hasattr(cmd, 'layers') and cmd.layers != "all": + detail += f" layers={cmd.layers}" + if hasattr(cmd, 'output') and cmd.output: + detail += f" -> {cmd.output}" + print(f" {i}. [P{phase}] {name} {target}{detail}") + + print() + + +def main(): + """Main entry point for td_lang CLI.""" + args = parse_args() + print_banner() + + executor = TDExecutor(output_dir=args.output) + + try: + if args.action == "info": + print_info(args.file) + + elif args.action == "check": + program = executor.check(args.file) + print("\n[td_lang] File is valid!") + + elif args.action == "compile": + py_path = executor.compile(args.file) + print(f"\n[td_lang] Generated: {py_path}") + print("[td_lang] You can run it with: python", py_path) + if args.verbose: + print("\n--- Generated Python ---") + print(py_path.read_text()) + print("--- End ---") + + elif args.action == "run": + result = executor.run(args.file, dry_run=args.dry) + if result["status"] == "success": + sys.exit(0) + elif result["status"] == "dry_run": + sys.exit(0) + else: + sys.exit(1) + + except TDLangError as e: + print(f"\n[td_lang] ERROR: {e}") + sys.exit(1) + + except FileNotFoundError: + print(f"\n[td_lang] ERROR: File not found: {args.file}") + print("[td_lang] Check the path and try again.") + sys.exit(1) + + except KeyboardInterrupt: + print("\n[td_lang] Interrupted.") + sys.exit(130) diff --git a/hugging/td_lang/compiler.py b/hugging/td_lang/compiler.py new file mode 100644 index 0000000000000000000000000000000000000000..2aba8cb13aa3c8b4df1b37d4b16caa30359dc68c --- /dev/null +++ b/hugging/td_lang/compiler.py @@ -0,0 +1,2646 @@ +""" +TD Lang Compiler — turns a TDProgram AST into readable Python code that calls td_fuse. + +Phase 1 commands: load, merge, heal, eval, commit. +Phase 2 commands: synth, train, debate, diagnose. +Phase 3 commands: fork, reset, prune, edit. +Phase 4 commands: snapshot, report. Blocks: data_contract, reward_contract. +""" + +from __future__ import annotations + +import hashlib +import textwrap +from datetime import datetime +from typing import List, Optional, Set + +from .ast_nodes import ( + AbsorbCmd, + BudgetBlock, + CommitCmd, + DataContractBlock, + DebateCmd, + DiagnoseCmd, + EditCmd, + EvalCmd, + FuseCmd, + ForkCmd, + IfBlock, + GateBlock, + HealCmd, + LoadCmd, + MergeCmd, + NotifyCmd, + OnErrorBlock, + PruneCmd, + RepeatBlock, + ReportCmd, + ResetCmd, + RewardContractBlock, + SaveCmd, + SetupBlock, + SnapshotCmd, + SynthCmd, + TDProgram, + TrainCmd, +) +from .errors import TDCompileError + +# All command types are now implemented (Phase 1 + 2 + 3) + + +class TDCompiler: + """Compile a TDProgram into a Python script string.""" + + GPU_HOURLY = 4.0 # simple heuristic for budget calculations + + def __init__(self) -> None: + self._aliases: Set[str] = set() + self._lines: List[str] = [] + self._indent: int = 0 + + # ------------------------------------------------------------------ Public + def compile(self, program: TDProgram) -> str: + """Compile a TDProgram into Python code.""" + self._reset_state() + self._validate(program) + self._build_script(program) + return "\n".join(self._lines) + + # ---------------------------------------------------------------- Internal helpers + def _reset_state(self) -> None: + self._aliases.clear() + self._lines = [] + self._indent = 0 + + def _validate(self, program: TDProgram) -> None: + """Semantic validation before emitting code.""" + seen: Set[str] = set() + for cmd in program.commands: + if isinstance(cmd, LoadCmd): + if cmd.alias in seen: + raise TDCompileError( + f"Alias '{cmd.alias}' is already used. Pick a different name.", + ) + seen.add(cmd.alias) + elif isinstance(cmd, MergeCmd): + if cmd.target not in seen: + raise TDCompileError( + f"Can't merge into '{cmd.target}' — it hasn't been loaded yet.", + hint=f'Add: load "{cmd.source}" as {cmd.target}', + ) + elif isinstance(cmd, (HealCmd, EvalCmd, CommitCmd)): + if cmd.target not in seen: + raise TDCompileError( + f"Can't use '{cmd.target}' — it hasn't been loaded yet.", + hint=f'Add: load "model/path" as {cmd.target}', + ) + elif isinstance(cmd, (SynthCmd, TrainCmd, DebateCmd, DiagnoseCmd)): + if cmd.target not in seen: + raise TDCompileError( + f"Can't use '{cmd.target}' — it hasn't been loaded yet.", + hint=f'Add: load "model/path" as {cmd.target}', + ) + elif isinstance(cmd, ForkCmd): + if cmd.source not in seen: + raise TDCompileError( + f"Can't fork '{cmd.source}' — it hasn't been loaded yet.", + hint=f'Add: load "model/path" as {cmd.source}', + ) + if cmd.alias in seen: + raise TDCompileError( + f"Alias '{cmd.alias}' is already used. Pick a different name for the fork.", + ) + seen.add(cmd.alias) + elif isinstance(cmd, (ResetCmd, PruneCmd, EditCmd)): + if cmd.target not in seen: + raise TDCompileError( + f"Can't use '{cmd.target}' — it hasn't been loaded yet.", + hint=f'Add: load "model/path" as {cmd.target}', + ) + elif isinstance(cmd, SnapshotCmd): + if cmd.target not in seen: + raise TDCompileError( + f"Can't snapshot '{cmd.target}' — it hasn't been loaded yet.", + hint=f'Add: load "model/path" as {cmd.target}', + ) + elif isinstance(cmd, ReportCmd): + pass # report has no target — always valid + elif isinstance(cmd, FuseCmd): + if cmd.target not in seen: + raise TDCompileError( + f"Can't fuse into '{cmd.target}' — it hasn't been loaded yet.", + hint=f'Add: load "model/path" as {cmd.target}', + ) + if len(cmd.sources) < 1: + raise TDCompileError( + "Fuse needs at least 1 model in the list.", + hint='fuse ["model1", "model2"] into target', + ) + elif isinstance(cmd, AbsorbCmd): + if cmd.target not in seen: + raise TDCompileError( + f"Can't absorb into '{cmd.target}' — it hasn't been loaded yet.", + hint=f'Add: load "model/path" as {cmd.target}', + ) + + # ---------------------------------------------------------------- Build script + def _build_script(self, program: TDProgram) -> None: + """Construct the full Python script lines.""" + self._emit("#!/usr/bin/env python3") + source_hash = hashlib.sha256(str(program).encode()).hexdigest()[:12] + source_name = program.source_file or "unknown.td" + timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + doc = textwrap.dedent( + f'''""" +Auto-generated by td_lang v0.1.0 +Source: {source_name} +Compiled: {timestamp} +Hash: {source_hash} + +DO NOT EDIT — regenerate from the .td file instead. +"""''' + ) + self._emit(doc) + self._emit("import json") + self._emit("import os") + self._emit("import sys") + self._emit("import time") + self._emit("from datetime import datetime") + self._emit("from pathlib import Path") + self._emit("") + self._emit("from td_fuse.config import MergeConfig, SOURCES, TARGET") + self._emit("from td_fuse.merge import run_pipeline") + self._emit("from td_fuse.heal import heal_model") + self._emit("from td_fuse.validate import validate_merged_model") + self._emit("") + self._emit("from td_lang.errors import TDBudgetError, TDGateError") + self._emit("") + self._emit(f"GPU_HOURLY = {self.GPU_HOURLY}") + self._emit("") + self._emit("") + self._emit("def main():") + self._indent += 1 + self._emit("start_time = time.time()") + self._emit("lineage = {}") + self._emit("models = {}") + self._emit("results = {}") + self._emit("merged_stages = []") + self._emit("output_dir = str(Path('.').resolve())") + self._emit("") + self._emit("# Quick canary check helper (lightweight sanity)") + self._emit("def quick_canary(checkpoint: str) -> float:") + self._indent += 1 + self._emit("from transformers import AutoModelForCausalLM, AutoTokenizer") + self._emit("import torch") + self._emit("prompts = [") + self._indent += 1 + self._emit('"What is 2+2?",') + self._emit('"Spell the word apple.",') + self._emit('"Name a color that starts with B.",') + self._emit('"List two prime numbers.",') + self._emit('"What is the capital of France?",') + self._indent -= 1 + self._emit("]") + self._emit("tok = AutoTokenizer.from_pretrained(checkpoint)") + self._emit("model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.float16, device_map='auto')") + self._emit("model.eval()") + self._emit("scores = []") + self._emit("for p in prompts:") + self._indent += 1 + self._emit("inputs = tok(p, return_tensors='pt').to(model.device)") + self._emit("with torch.no_grad():") + self._indent += 1 + self._emit("out = model.generate(**inputs, max_new_tokens=32, do_sample=False)") + self._indent -= 1 + self._emit("resp = tok.decode(out[0], skip_special_tokens=True)") + self._emit("scores.append(len(resp))") + self._indent -= 1 + self._emit("avg_len = sum(scores) / len(scores)") + self._emit("del model, tok") + self._emit("import gc; gc.collect()") + self._emit("return avg_len") + self._indent -= 1 + self._emit("") + + if program.setup: + self._emit_setup(program.setup) + + if program.on_error: + self._emit_on_error(program.on_error, program) + + if program.budget: + self._emit_budget_check(program) + + if program.data_contract: + self._emit_data_contract(program.data_contract) + + if program.reward_contract: + self._emit_reward_contract(program.reward_contract) + + for index, cmd in enumerate(program.commands, start=1): + self._emit_comment(f"Step {index}: {type(cmd).__name__}") + if isinstance(cmd, LoadCmd): + self._emit_load(cmd) + elif isinstance(cmd, MergeCmd): + self._emit_merge(cmd) + elif isinstance(cmd, HealCmd): + self._emit_heal(cmd) + elif isinstance(cmd, EvalCmd): + self._emit_eval(cmd) + elif isinstance(cmd, CommitCmd): + self._emit_commit(cmd, program.gates) + elif isinstance(cmd, DiagnoseCmd): + self._emit_diagnose(cmd) + elif isinstance(cmd, SynthCmd): + self._emit_synth(cmd) + elif isinstance(cmd, TrainCmd): + self._emit_train(cmd) + elif isinstance(cmd, DebateCmd): + self._emit_debate(cmd) + elif isinstance(cmd, EditCmd): + self._emit_edit(cmd) + elif isinstance(cmd, ForkCmd): + self._emit_fork(cmd) + elif isinstance(cmd, ResetCmd): + self._emit_reset(cmd) + elif isinstance(cmd, PruneCmd): + self._emit_prune(cmd) + elif isinstance(cmd, FuseCmd): + self._emit_fuse(cmd) + elif isinstance(cmd, AbsorbCmd): + self._emit_absorb(cmd) + elif isinstance(cmd, RepeatBlock): + self._emit_repeat(cmd, program) + elif isinstance(cmd, IfBlock): + self._emit_if(cmd, program) + elif isinstance(cmd, SnapshotCmd): + self._emit_snapshot(cmd, program) + elif isinstance(cmd, ReportCmd): + self._emit_report(cmd, program) + elif isinstance(cmd, NotifyCmd): + self._emit_notify(cmd, program) + elif isinstance(cmd, SaveCmd): + self._emit_save(cmd, program) + self._emit("") + + self._emit_summary() + self._indent -= 1 + self._emit("") + self._emit('if __name__ == "__main__":') + self._indent += 1 + self._emit("main()") + self._indent -= 1 + + # ---------------------------------------------------------------- Emitters + def _emit_load(self, cmd: LoadCmd) -> None: + self._aliases.add(cmd.alias) + self._emit(f'print("[td_lang] Loading {cmd.alias} from {cmd.model_ref}...")') + self._emit("") + + # Actually download the model if it's a HF path + self._emit(f'_model_ref = "{cmd.model_ref}"') + self._emit("if '/' in _model_ref and not os.path.exists(_model_ref):") + self._indent += 1 + self._emit(f'print("[td_lang] Downloading from HuggingFace: {cmd.model_ref}")') + self._emit("try:") + self._indent += 1 + self._emit("from huggingface_hub import snapshot_download") + self._emit(f'_local_path = snapshot_download(_model_ref, local_dir=f"models/{cmd.alias}")') + self._emit(f'print(f"[td_lang] Downloaded to {{_local_path}}")') + self._indent -= 1 + self._emit("except ImportError:") + self._indent += 1 + self._emit('print("[td_lang] huggingface_hub not installed. Storing ref only — download will happen at merge time.")') + self._emit("_local_path = _model_ref") + self._indent -= 1 + self._emit("except Exception as e:") + self._indent += 1 + self._emit('print(f"[td_lang] Download warning: {e}. Storing ref for later.")') + self._emit("_local_path = _model_ref") + self._indent -= 1 + self._indent -= 1 + self._emit("else:") + self._indent += 1 + self._emit("_local_path = _model_ref") + self._indent -= 1 + self._emit("") + + self._emit(f'models["{cmd.alias}"] = {{') + self._indent += 1 + self._emit(f'"model_ref": "{cmd.model_ref}",') + self._emit('"local_path": _local_path,') + self._emit('"checkpoint": None,') + self._emit('"loaded_at": datetime.now().isoformat(),') + self._indent -= 1 + self._emit("}") + self._emit(f'lineage["{cmd.alias}"] = {{"source": "{cmd.model_ref}", "operations": []}}') + self._emit(f'print("[td_lang] {cmd.alias} ready.")') + + def _emit_merge(self, cmd: MergeCmd) -> None: + self._emit( + f'print("[td_lang] Merging {cmd.source} into {cmd.target} using {cmd.method} (strength={cmd.strength})...")' + ) + self._emit(f'_source_ref = "{cmd.source}"') + self._emit("_stage = None") + self._emit("for _src in SOURCES:") + self._indent += 1 + self._emit('if _src.hf_id == _source_ref or _src.name.lower() in _source_ref.lower():') + self._indent += 1 + self._emit('_stage = _src.name.lower().split("-")[0]') + self._emit(f"_src.merge_alpha = {cmd.strength}") + self._emit("break") + self._indent -= 1 + self._indent -= 1 + self._emit("if _stage is None:") + self._indent += 1 + self._emit('raise SystemExit(f"Could not match source {_source_ref} to any SOURCES entry.")') + self._indent -= 1 + self._emit("cfg = MergeConfig()") + self._emit("merge_result = run_pipeline([_stage], cfg)") + self._emit(f'results["{cmd.target}_merge"] = merge_result') + self._emit("merged_stages.append(_stage)") + self._emit('if merge_result.get("final_checkpoint"):') + self._indent += 1 + self._emit(f'models["{cmd.target}"]["checkpoint"] = merge_result["final_checkpoint"]') + self._indent -= 1 + self._emit(f'lineage["{cmd.target}"]["operations"].append({{') + self._indent += 1 + self._emit('"op": "merge",') + self._emit('"source": _source_ref,') + self._emit(f'"method": "{cmd.method}",') + self._emit(f'"strength": {cmd.strength},') + self._emit('"timestamp": datetime.now().isoformat(),') + self._emit('"stage": _stage,') + self._indent -= 1 + self._emit("})") + self._emit('print("[td_lang] Merge complete.")') + + def _emit_heal(self, cmd: HealCmd) -> None: + self._emit(f'print("[td_lang] Healing {cmd.target} (lora_r={cmd.lora_r}, epochs={cmd.epochs})...")') + self._emit(f'checkpoint = models.get("{cmd.target}", {{}}).get("checkpoint")') + self._emit("if not checkpoint:") + self._indent += 1 + self._emit('print("[td_lang] WARNING: No checkpoint to heal — run a merge first.")') + self._indent -= 1 + self._emit("else:") + self._indent += 1 + self._emit(f"cfg = MergeConfig(heal_lora_r={cmd.lora_r}, heal_epochs={cmd.epochs})") + self._emit("healed_path = heal_model(checkpoint, cfg)") + self._emit(f'models["{cmd.target}"]["checkpoint"] = healed_path') + self._emit(f'lineage["{cmd.target}"]["operations"].append({{') + self._indent += 1 + self._emit('"op": "heal",') + self._emit(f'"lora_r": {cmd.lora_r},') + self._emit(f'"epochs": {cmd.epochs},') + self._emit('"timestamp": datetime.now().isoformat(),') + self._indent -= 1 + self._emit("})") + self._emit('print("[td_lang] Heal complete.")') + self._indent -= 1 + + def _emit_eval(self, cmd: EvalCmd) -> None: + self._emit(f'print("[td_lang] Evaluating {cmd.target}...")') + self._emit(f'checkpoint = models.get("{cmd.target}", {{}}).get("checkpoint")') + self._emit("if not checkpoint:") + self._indent += 1 + self._emit('print("[td_lang] WARNING: No checkpoint to evaluate.")') + self._indent -= 1 + self._emit("else:") + self._indent += 1 + self._emit("from transformers import AutoModelForCausalLM, AutoTokenizer") + self._emit("import torch") + self._emit("tok = AutoTokenizer.from_pretrained(checkpoint)") + self._emit("model = AutoModelForCausalLM.from_pretrained(") + self._indent += 1 + self._emit('checkpoint, torch_dtype=torch.bfloat16, device_map="auto"') + self._indent -= 1 + self._emit(")") + self._emit("eval_result = validate_merged_model(") + self._indent += 1 + self._emit("model=model, tokenizer=tok,") + self._emit("merged_sources=merged_stages,") + self._emit("cfg=MergeConfig(),") + self._indent -= 1 + self._emit(")") + self._emit(f'results["{cmd.target}_eval"] = eval_result') + self._emit(f'hist_key = "{cmd.target}_eval_history"') + self._emit("if hist_key not in results:") + self._indent += 1 + self._emit("results[hist_key] = []") + self._indent -= 1 + self._emit("metric = 1.0 if eval_result.get('overall', False) else 0.0") + self._emit("results[hist_key].append(metric)") + self._emit(f'lineage["{cmd.target}"]["operations"].append({{') + self._indent += 1 + self._emit('"op": "eval",') + self._emit('"timestamp": datetime.now().isoformat(),') + self._emit('"result": eval_result,') + self._indent -= 1 + self._emit("})") + if cmd.output: + self._emit(f'eval_path = Path("{cmd.output}")') + self._emit("eval_path.parent.mkdir(parents=True, exist_ok=True)") + self._emit('with open(eval_path, "w") as f:') + self._indent += 1 + self._emit("json.dump(eval_result, f, indent=2, default=str)") + self._indent -= 1 + self._emit('print(f"[td_lang] Eval results saved to {eval_path}")') + else: + self._emit('print("[td_lang] Eval results:", json.dumps(eval_result, indent=2, default=str))') + self._emit("del model, tok") + self._emit("import gc; gc.collect()") + self._indent -= 1 + + def _emit_commit(self, cmd: CommitCmd, global_gates: Optional[GateBlock]) -> None: + gates = cmd.gates or (global_gates.must_pass if global_gates else None) + self._emit(f'print("[td_lang] Committing {cmd.target}...")') + if gates: + self._emit(f"gates_to_check = {gates}") + self._emit(f'last_eval = results.get("{cmd.target}_eval", {{}})') + self._emit("failed = []") + self._emit("for gate in gates_to_check:") + self._indent += 1 + self._emit('if gate == "overall":') + self._indent += 1 + self._emit('ok = bool(last_eval.get("overall", False))') + self._indent -= 1 + self._emit("else:") + self._indent += 1 + self._emit("val = last_eval.get(gate, {})") + self._emit("if isinstance(val, dict):") + self._indent += 1 + self._emit('ok = bool(val.get("ok", False))') + self._indent -= 1 + self._emit("else:") + self._indent += 1 + self._emit("ok = bool(val)") + self._indent -= 1 + self._indent -= 1 + self._emit("if not ok:") + self._indent += 1 + self._emit("failed.append(gate)") + self._indent -= 1 + self._indent -= 1 + self._emit("if failed:") + self._indent += 1 + self._emit('raise TDGateError(failed, message="Commit blocked — gates failed")') + self._indent -= 1 + self._emit("else:") + self._indent += 1 + self._emit('print("[td_lang] All gates passed!")') + self._indent -= 1 + + self._emit(f'checkpoint = models.get("{cmd.target}", {{}}).get("checkpoint")') + self._emit("if not checkpoint:") + self._indent += 1 + self._emit('print("[td_lang] WARNING: No checkpoint to commit.")') + self._indent -= 1 + self._emit("else:") + self._indent += 1 + self._emit('commit_dir = Path("td_lang_outputs") / "committed"') + self._emit("commit_dir.mkdir(parents=True, exist_ok=True)") + self._emit('lineage_path = commit_dir / "lineage.json"') + self._emit('with open(lineage_path, "w") as f:') + self._indent += 1 + self._emit("json.dump(lineage, f, indent=2, default=str)") + self._indent -= 1 + self._emit('print(f"[td_lang] Committed. Checkpoint: {checkpoint}")') + self._emit('print(f"[td_lang] Lineage saved to: {lineage_path}")') + self._indent -= 1 + + # ---------------------------------------------------------------- Phase 2 emitters + + def _emit_diagnose(self, cmd: DiagnoseCmd) -> None: + """Generate code for: diagnose target [-> weaknesses.json] + + Loads the model and asks it to identify its own weaknesses. + Uses structured prompting to get actionable self-diagnosis. + Interview finding: all 3 AIs (ChatGPT, Grok, Gemini) confirmed + models CAN self-diagnose when asked directly (test_8-12). + """ + self._emit(f'print("[td_lang] Diagnosing {cmd.target}...")') + self._emit(f'checkpoint = models.get("{cmd.target}", {{}}).get("checkpoint")') + self._emit("if not checkpoint:") + self._indent += 1 + self._emit('print("[td_lang] WARNING: No checkpoint — using model_ref instead.")') + self._emit(f'checkpoint = models["{cmd.target}"]["model_ref"]') + self._indent -= 1 + self._emit("from transformers import AutoModelForCausalLM, AutoTokenizer") + self._emit("import torch") + self._emit("tok = AutoTokenizer.from_pretrained(checkpoint)") + self._emit("model = AutoModelForCausalLM.from_pretrained(") + self._indent += 1 + self._emit('checkpoint, torch_dtype=torch.bfloat16, device_map="auto"') + self._indent -= 1 + self._emit(")") + self._emit("model.eval()") + self._emit("") + self._emit("# Self-diagnosis prompts (from TD interview findings test_12)") + self._emit("diag_prompts = [") + self._indent += 1 + self._emit('"List your top 5 weaknesses as an AI. Be specific and honest.",') + self._emit('"What types of reasoning tasks do you fail at most? Give concrete examples.",') + self._emit('"Rate yourself 1-10 on: math, coding, long-chain logic, creativity, factual recall. Explain each score.",') + self._emit('"If you could improve one thing about yourself, what would have the biggest impact?",') + self._indent -= 1 + self._emit("]") + self._emit("diagnose_results = []") + self._emit("for prompt in diag_prompts:") + self._indent += 1 + self._emit('inputs = tok(prompt, return_tensors="pt").to(model.device)') + self._emit("with torch.no_grad():") + self._indent += 1 + self._emit("output = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7)") + self._indent -= 1 + self._emit("response = tok.decode(output[0], skip_special_tokens=True)") + self._emit('diagnose_results.append({"prompt": prompt, "response": response})') + self._emit('print(f" Prompt: {prompt[:50]}...")') + self._emit('print(f" Response: {response[:200]}...")') + self._emit("print()") + self._indent -= 1 + self._emit(f'results["{cmd.target}_diagnose"] = diagnose_results') + self._emit(f'lineage["{cmd.target}"]["operations"].append({{') + self._indent += 1 + self._emit('"op": "diagnose",') + self._emit('"n_prompts": len(diag_prompts),') + self._emit('"timestamp": datetime.now().isoformat(),') + self._indent -= 1 + self._emit("})") + if cmd.output: + self._emit(f'diag_path = Path("{cmd.output}")') + self._emit("diag_path.parent.mkdir(parents=True, exist_ok=True)") + self._emit('with open(diag_path, "w") as f:') + self._indent += 1 + self._emit("json.dump(diagnose_results, f, indent=2, default=str)") + self._indent -= 1 + self._emit('print(f"[td_lang] Diagnosis saved to {diag_path}")') + self._emit("del model, tok") + self._emit("import gc; gc.collect()") + self._emit('print("[td_lang] Diagnosis complete.")') + + def _emit_synth(self, cmd: SynthCmd) -> None: + """Generate code for: synth target from source [filter cherry_llm] [-> output.jsonl] + + Smarter synthesis: + - Targets weaknesses from prior diagnose results when present. + - Supports configurable sample count (cmd.n_samples if provided). + - Produces domain-specific prompts (math, code, logic, factual). + """ + n_samples_expr = f"getattr(cmd, 'n_samples', 100)" # static string for emit clarity + self._emit(f'print("[td_lang] Generating synthetic data for {cmd.target}...")') + self._emit(f'checkpoint = models.get("{cmd.target}", {{}}).get("checkpoint")') + self._emit("if not checkpoint:") + self._indent += 1 + self._emit(f'checkpoint = models["{cmd.target}"]["model_ref"]') + self._indent -= 1 + self._emit("from transformers import AutoModelForCausalLM, AutoTokenizer") + self._emit("import torch, random, re") + self._emit("tok = AutoTokenizer.from_pretrained(checkpoint)") + self._emit("model = AutoModelForCausalLM.from_pretrained(") + self._indent += 1 + self._emit('checkpoint, torch_dtype=torch.bfloat16, device_map="auto"') + self._indent -= 1 + self._emit(")") + self._emit("model.eval()") + self._emit("") + self._emit("# Weakness-aware topic selection from diagnosis (if available)") + self._emit(f'diag = results.get("{cmd.target}_diagnose", [])') + self._emit("weak_topics = []") + self._emit("for d in diag:") + self._indent += 1 + self._emit("resp = d.get('response', '')") + self._emit("for topic in ['math', 'code', 'logic', 'factual', 'long chain', 'tools']:") + self._indent += 1 + self._emit("if topic in resp.lower():") + self._indent += 1 + self._emit("weak_topics.append(topic)") + self._indent -= 1 + self._indent -= 1 + self._indent -= 1 + self._emit("if not weak_topics:") + self._indent += 1 + self._emit("weak_topics = ['math', 'code', 'logic', 'factual']") + self._indent -= 1 + self._emit("") + self._emit("# Domain templates") + self._emit("domain_templates = {") + self._indent += 1 + self._emit('"math": ["Solve this math problem step by step: {problem}",') + self._emit(' "Find and correct the mistake in this solution: {problem}"],') + self._emit('"code": ["Write correct, tested Python code for: {problem}",') + self._emit(' "Find the bug and fix it: {problem}"],') + self._emit('"logic": ["Reason carefully and avoid fallacies: {problem}",') + self._emit(' "Provide a formal argument for: {problem}"],') + self._emit('"factual": ["Answer with citations: {problem}",') + self._emit(' "List 3 verified facts about: {problem}"],') + self._indent -= 1 + self._emit("}") + self._emit("") + self._emit("def make_problem(domain: str) -> str:") + self._indent += 1 + self._emit("if domain == 'math':") + self._indent += 1 + self._emit("return 'Compute (17*19 - 121) / 3' if random.random() < 0.5 else 'Integrate x^2 from 0 to 3'") + self._indent -= 1 + self._emit("if domain == 'code':") + self._indent += 1 + self._emit("return 'Implement Dijkstra shortest path' if random.random() < 0.5 else 'Parse JSON safely in Python'") + self._indent -= 1 + self._emit("if domain == 'logic':") + self._indent += 1 + self._emit("return 'Does the conclusion follow? If all A are B and all B are C, are all A C?'") + self._indent -= 1 + self._emit("return 'Summarize the causes of the 2008 financial crisis in 3 bullet points.'") + self._indent -= 1 + self._emit("") + self._emit("synth_data = []") + self._emit(f"n_samples = getattr(cmd, 'n_samples', 100)") + self._emit("for i in range(n_samples):") + self._indent += 1 + self._emit("domain = random.choice(weak_topics)") + self._emit("problem = make_problem(domain)") + self._emit("template = random.choice(domain_templates[domain])") + self._emit('prompt = template.format(problem=problem)') + self._emit('inputs = tok(prompt, return_tensors="pt").to(model.device)') + self._emit("with torch.no_grad():") + self._indent += 1 + self._emit("output = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7)") + self._indent -= 1 + self._emit("response = tok.decode(output[0], skip_special_tokens=True)") + self._emit('synth_data.append({"prompt": prompt, "response": response, "domain": domain})') + self._emit('if (i + 1) % 10 == 0:') + self._indent += 1 + self._emit('print(f" Generated {i + 1}/{n_samples} samples...")') + self._indent -= 1 + self._indent -= 1 + filter_method = cmd.filter_method or "none" + if filter_method == "cherry_llm": + self._emit("") + self._emit("# Cherry_LLM perplexity filter (test_12: prevents mode collapse)") + self._emit("print('[td_lang] Filtering with Cherry_LLM perplexity scoring...')") + self._emit("filtered = []") + self._emit("for sample in synth_data:") + self._indent += 1 + self._emit('inputs = tok(sample["response"], return_tensors="pt").to(model.device)') + self._emit("with torch.no_grad():") + self._indent += 1 + self._emit('loss = model(**inputs, labels=inputs["input_ids"]).loss') + self._indent -= 1 + self._emit("perplexity = torch.exp(loss).item()") + self._emit('sample["perplexity"] = perplexity') + self._emit("if 2.0 < perplexity < 50.0:") + self._indent += 1 + self._emit("filtered.append(sample)") + self._indent -= 1 + self._indent -= 1 + self._emit("synth_data = filtered") + self._emit('print(f"[td_lang] Kept {len(synth_data)} samples after Cherry_LLM filter.")') + self._emit("") + self._emit(f'results["{cmd.target}_synth"] = synth_data') + self._emit(f'lineage["{cmd.target}"]["operations"].append({{') + self._indent += 1 + self._emit('"op": "synth",') + self._emit(f'"source": "{cmd.source}",') + self._emit(f'"filter": "{filter_method}",') + self._emit('"n_samples": len(synth_data),') + self._emit('"timestamp": datetime.now().isoformat(),') + self._indent -= 1 + self._emit("})") + output_path = cmd.output or "synth_data.jsonl" + self._emit(f'synth_path = Path("{output_path}")') + self._emit("synth_path.parent.mkdir(parents=True, exist_ok=True)") + self._emit('with open(synth_path, "w") as f:') + self._indent += 1 + self._emit("for sample in synth_data:") + self._indent += 1 + self._emit("f.write(json.dumps(sample, default=str) + chr(10))") + self._indent -= 1 + self._indent -= 1 + self._emit('print(f"[td_lang] Synthetic data saved to {synth_path} ({len(synth_data)} samples)")') + self._emit("del model, tok") + self._emit("import gc; gc.collect()") + + def _emit_train(self, cmd: TrainCmd) -> None: + """Generate code for: train target on "dataset" using method [steps N] [lr N] + + Runs GRPO, SFT, or DPO training using the trl library. + GRPO hyperparameters from test_15: 64 steps sweet spot, eval every 16. + """ + steps = cmd.steps or 64 # test_15: 64 is the sweet spot + lr = cmd.learning_rate or 5e-5 + self._emit(f'print("[td_lang] Training {cmd.target} using {cmd.method} for {steps} steps...")') + self._emit(f'checkpoint = models.get("{cmd.target}", {{}}).get("checkpoint")') + self._emit("if not checkpoint:") + self._indent += 1 + self._emit(f'checkpoint = models["{cmd.target}"]["model_ref"]') + self._indent -= 1 + self._emit("") + + if cmd.method == "grpo": + self._emit("# GRPO training (test_15: 64 steps sweet spot, eval every 16)") + self._emit("from trl import GRPOConfig, GRPOTrainer") + self._emit("from transformers import AutoModelForCausalLM, AutoTokenizer") + self._emit("from datasets import load_dataset") + self._emit("import torch") + self._emit("") + self._emit("tok = AutoTokenizer.from_pretrained(checkpoint)") + self._emit("model = AutoModelForCausalLM.from_pretrained(") + self._indent += 1 + self._emit('checkpoint, torch_dtype=torch.bfloat16, device_map="auto"') + self._indent -= 1 + self._emit(")") + self._emit("") + self._emit(f'# Load training data') + self._emit(f'dataset_path = "{cmd.dataset}"') + self._emit("if dataset_path.endswith('.jsonl'):") + self._indent += 1 + self._emit("train_data = load_dataset('json', data_files=dataset_path, split='train')") + self._indent -= 1 + self._emit("else:") + self._indent += 1 + self._emit("train_data = load_dataset(dataset_path, split='train')") + self._indent -= 1 + self._emit("") + self._emit("grpo_config = GRPOConfig(") + self._indent += 1 + self._emit(f"max_steps={steps},") + self._emit(f"learning_rate={lr},") + self._emit("per_device_train_batch_size=1,") + self._emit("gradient_accumulation_steps=8,") + self._emit("logging_steps=16, # eval every 16 steps (test_15)") + self._emit('output_dir="td_lang_outputs/grpo_training",') + self._emit("save_steps=16,") + self._emit('bf16=True,') + self._indent -= 1 + self._emit(")") + self._emit("") + self._emit("# Verified rewards only (test_16: no learned reward model)") + self._emit("import ast, math, re") + self._emit("ALLOWED_EXPR = re.compile(r'^[0-9+\\-*/().\\s]+$')") + self._emit("") + self._emit("def _safe_eval(expr: str):") + self._indent += 1 + self._emit("expr = expr.strip()") + self._emit("if not ALLOWED_EXPR.match(expr):") + self._indent += 1 + self._emit("return None") + self._indent -= 1 + self._emit("try:") + self._indent += 1 + self._emit("return float(eval(expr, {'__builtins__': {}}, {}))") + self._indent -= 1 + self._emit("except Exception:") + self._indent += 1 + self._emit("return None") + self._indent -= 2 + self._emit("") + self._emit("def reward_fn(completions, prompts=None, **kwargs):") + self._indent += 1 + self._emit("prompts = prompts or ['' for _ in completions]") + self._emit("rewards = []") + self._emit("for comp, prompt in zip(completions, prompts):") + self._indent += 1 + self._emit("text = comp if isinstance(comp, str) else comp[0].get('content', '')") + self._emit("score = 0.0") + self._emit("# Code compilation reward") + self._emit("code_blocks = re.findall(r'```python\\n(.*?)```', text, re.S)") + self._emit("compiled_ok = False") + self._emit("for block in code_blocks or []:") + self._indent += 1 + self._emit("try:") + self._indent += 1 + self._emit("ast.parse(block)") + self._emit("compiled_ok = True") + self._emit("break") + self._indent -= 1 + self._emit("except SyntaxError:") + self._indent += 1 + self._emit("pass") + self._indent -= 2 + self._emit("if compiled_ok:") + self._indent += 1 + self._emit("score += 0.4") + self._indent -= 1 + self._emit("# Math correctness reward (prompt-provided expression)") + self._emit("expr_match = re.search(r'([0-9+\\-*/().\\s]{3,})', prompt)") + self._emit("pred_num_match = re.search(r'(-?\\d+(?:\\.\\d+)?)', text)") + self._emit("if expr_match and pred_num_match:") + self._indent += 1 + self._emit("expr = expr_match.group(1)") + self._emit("target = _safe_eval(expr)") + self._emit("try:") + self._indent += 1 + self._emit("pred_val = float(pred_num_match.group(1))") + self._indent -= 1 + self._emit("except Exception:") + self._indent += 1 + self._emit("pred_val = None") + self._indent -= 1 + self._emit("if target is not None and pred_val is not None and abs(target - pred_val) < 1e-3:") + self._indent += 1 + self._emit("score += 0.4") + self._indent -= 2 + self._emit("# Structured answer bonus") + self._emit("if 'answer' in text.lower() or 'result' in text.lower():") + self._indent += 1 + self._emit("score += 0.2") + self._indent -= 1 + self._emit("rewards.append(min(score, 1.0))") + self._indent -= 1 + self._emit("return rewards") + self._indent -= 1 + self._emit("") + self._emit("# Early stopping (test_15): KL spike, reward drop, diversity drop") + self._emit("from transformers import TrainerCallback") + self._emit("") + self._emit("class EarlyStopper(TrainerCallback):") + self._indent += 1 + self._emit("def __init__(self):") + self._indent += 1 + self._emit("self.kl_history = []") + self._emit("self.eval_rewards = []") + self._emit("self.entropy_history = []") + self._indent -= 1 + self._emit("") + self._emit("def on_log(self, args, state, control, logs=None, **kwargs):") + self._indent += 1 + self._emit("logs = logs or {}") + self._emit("if 'kl' in logs:") + self._indent += 1 + self._emit("self.kl_history.append(logs['kl'])") + self._emit("if len(self.kl_history) > 5:") + self._indent += 1 + self._emit("ma = sum(self.kl_history[-5:]) / 5") + self._emit("if logs['kl'] > 3.1 * ma:") + self._indent += 1 + self._emit("control.should_training_stop = True") + self._emit("print('[td_lang][early_stop] KL spike detected — stopping GRPO')") + self._indent -= 2 + self._indent -= 1 + self._emit("if 'eval/reward' in logs:") + self._indent += 1 + self._emit("self.eval_rewards.append(logs['eval/reward'])") + self._emit("if len(self.eval_rewards) >= 2 and self.eval_rewards[-1] < self.eval_rewards[-2]:") + self._indent += 1 + self._emit("control.should_training_stop = True") + self._emit("print('[td_lang][early_stop] Validation reward drop — stopping GRPO')") + self._indent -= 1 + self._indent -= 1 + self._emit("if 'policy_entropy' in logs:") + self._indent += 1 + self._emit("self.entropy_history.append(logs['policy_entropy'])") + self._emit("if len(self.entropy_history) >= 3:") + self._indent += 1 + self._emit("baseline = self.entropy_history[0]") + self._emit("if self.entropy_history[-1] < 0.93 * baseline:") + self._indent += 1 + self._emit("control.should_training_stop = True") + self._emit("print('[td_lang][early_stop] Diversity collapsed — stopping GRPO')") + self._indent -= 2 + self._indent -= 2 + self._indent -= 1 + self._emit("trainer = GRPOTrainer(") + self._indent += 1 + self._emit("model=model,") + self._emit("args=grpo_config,") + self._emit("train_dataset=train_data,") + self._emit("reward_funcs=reward_fn,") + self._emit("tokenizer=tok,") + self._emit("callbacks=[EarlyStopper()],") + self._indent -= 1 + self._emit(")") + self._emit("trainer.train()") + self._emit("trainer.save_model('td_lang_outputs/grpo_trained')") + self._emit(f'models["{cmd.target}"]["checkpoint"] = "td_lang_outputs/grpo_trained"') + + elif cmd.method in ("sft", "dpo"): + self._emit(f"# {cmd.method.upper()} training") + self._emit("from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments") + if cmd.method == "sft": + self._emit("from trl import SFTTrainer") + else: + self._emit("from trl import DPOTrainer, DPOConfig") + self._emit("from datasets import load_dataset") + self._emit("import torch") + self._emit("") + self._emit("tok = AutoTokenizer.from_pretrained(checkpoint)") + self._emit("model = AutoModelForCausalLM.from_pretrained(") + self._indent += 1 + self._emit('checkpoint, torch_dtype=torch.bfloat16, device_map="auto"') + self._indent -= 1 + self._emit(")") + self._emit(f'dataset_path = "{cmd.dataset}"') + self._emit("if dataset_path.endswith('.jsonl'):") + self._indent += 1 + self._emit("train_data = load_dataset('json', data_files=dataset_path, split='train')") + self._indent -= 1 + self._emit("else:") + self._indent += 1 + self._emit("train_data = load_dataset(dataset_path, split='train')") + self._indent -= 1 + self._emit("") + self._emit(f'print("[td_lang] Running {cmd.method.upper()} for {steps} steps...")') + if cmd.method == "sft": + self._emit("training_args = TrainingArguments(") + self._indent += 1 + self._emit('output_dir="td_lang_outputs/sft_training",') + self._emit(f"max_steps={steps},") + self._emit(f"learning_rate={lr},") + self._emit("per_device_train_batch_size=2,") + self._emit("gradient_accumulation_steps=4,") + self._emit("logging_steps=10,") + self._emit(f"save_steps=max(10, int({steps}/2)),") + self._emit("bf16=True,") + self._indent -= 1 + self._emit(")") + self._emit("trainer = SFTTrainer(") + self._indent += 1 + self._emit("model=model,") + self._emit("tokenizer=tok,") + self._emit("args=training_args,") + self._emit("train_dataset=train_data,") + self._emit('dataset_text_field="text",') + self._indent -= 1 + self._emit(")") + self._emit("trainer.train()") + self._emit('trainer.save_model("td_lang_outputs/sft_trained")') + self._emit(f'models["{cmd.target}"]["checkpoint"] = "td_lang_outputs/sft_trained"') + else: + self._emit("training_args = DPOConfig(") + self._indent += 1 + self._emit(f"max_steps={steps},") + self._emit(f"learning_rate={lr},") + self._emit("per_device_train_batch_size=1,") + self._emit("gradient_accumulation_steps=4,") + self._emit("logging_steps=10,") + self._emit('output_dir="td_lang_outputs/dpo_training",') + self._emit("bf16=True,") + self._indent -= 1 + self._emit(")") + self._emit("trainer = DPOTrainer(") + self._indent += 1 + self._emit("model=model,") + self._emit("ref_model=None,") + self._emit("beta=0.1,") + self._emit("train_dataset=train_data,") + self._emit("tokenizer=tok,") + self._emit("args=training_args,") + self._emit('loss_type="sigmoid",') + self._indent -= 1 + self._emit(")") + self._emit("trainer.train()") + self._emit('trainer.save_model("td_lang_outputs/dpo_trained")') + self._emit(f'models["{cmd.target}"]["checkpoint"] = "td_lang_outputs/dpo_trained"') + + else: + self._emit(f'print("[td_lang] Unknown training method: {cmd.method}")') + self._emit('print("[td_lang] Supported: grpo, sft, dpo")') + + self._emit("") + self._emit(f'lineage["{cmd.target}"]["operations"].append({{') + self._indent += 1 + self._emit('"op": "train",') + self._emit(f'"method": "{cmd.method}",') + self._emit(f'"steps": {steps},') + self._emit(f'"lr": {lr},') + self._emit(f'"dataset": "{cmd.dataset}",') + self._emit('"timestamp": datetime.now().isoformat(),') + self._indent -= 1 + self._emit("})") + self._emit("import gc; gc.collect()") + self._emit(f'print("[td_lang] Training complete.")') + + def _emit_debate(self, cmd: DebateCmd) -> None: + """Generate code for: debate target rounds N candidates N [-> output.jsonl] + + Weakness-aware single-model debate with structured judging. + """ + self._emit(f'print("[td_lang] Running debate: {cmd.rounds} rounds, {cmd.candidates} candidates...")') + self._emit(f'checkpoint = models.get("{cmd.target}", {{}}).get("checkpoint")') + self._emit("if not checkpoint:") + self._indent += 1 + self._emit(f'checkpoint = models["{cmd.target}"]["model_ref"]') + self._indent -= 1 + self._emit("from transformers import AutoModelForCausalLM, AutoTokenizer") + self._emit("import torch, random, json") + self._emit("tok = AutoTokenizer.from_pretrained(checkpoint)") + self._emit("model = AutoModelForCausalLM.from_pretrained(") + self._indent += 1 + self._emit('checkpoint, torch_dtype=torch.bfloat16, device_map="auto"') + self._indent -= 1 + self._emit(")") + self._emit("model.eval()") + self._emit("") + self._emit("# Persona-based debate (test_14: single-model diversity protocol)") + self._emit("personas = [") + self._indent += 1 + self._emit('"You are a careful, skeptical analyst. Question every assumption.",') + self._emit('"You are a creative problem solver. Think outside the box.",') + self._emit('"You are a rigorous mathematician. Show formal proofs.",') + self._emit('"You are a practical engineer. Focus on what works.",') + self._emit('"You are a devil\'s advocate. Find flaws in every argument.",') + self._emit('"You are an optimist. Find the best interpretation.",') + self._emit('"You are a minimalist. Give the simplest correct answer.",') + self._emit('"You are a professor. Explain with clarity and depth.",') + self._indent -= 1 + self._emit("]") + self._emit("") + self._emit("# Base prompts + diagnosis-derived prompts") + self._emit(f'diag = results.get("{cmd.target}_diagnose", [])') + self._emit("debate_prompts = [") + self._indent += 1 + self._emit('"Solve: What is the sum of the first 20 prime numbers?",') + self._emit('"Explain why the sky appears blue using physics.",') + self._emit('"Write a Python function to find the longest palindrome in a string.",') + self._emit('"What are the logical flaws in this argument: All birds can fly, penguins are birds, therefore penguins can fly.",') + self._emit('"If a train travels 60mph for 2.5 hours, then 80mph for 1.5 hours, what is the average speed?",') + self._indent -= 1 + self._emit("]") + self._emit("for d in diag:") + self._indent += 1 + self._emit("resp = d.get('response', '')") + self._emit("snip = resp[:140]") + self._emit('debate_prompts.append(f"Address this weakness you listed: {snip}. Provide a concrete fix and example.")') + self._indent -= 1 + self._emit("") + self._emit("debate_results = []") + self._emit(f"for round_num in range({cmd.rounds}):") + self._indent += 1 + self._emit(f'print(f\" Round {{round_num + 1}}/{cmd.rounds}...\")') + self._emit("prompt = random.choice(debate_prompts)") + self._emit(f"selected_personas = random.sample(personas, min({cmd.candidates}, len(personas)))") + self._emit("candidates = []") + self._emit("for persona in selected_personas:") + self._indent += 1 + self._emit('full_prompt = f\"{persona}\\n\\nQuestion: {prompt}\\n\\nAnswer:\"') + self._emit('inputs = tok(full_prompt, return_tensors=\"pt\").to(model.device)') + self._emit("with torch.no_grad():") + self._indent += 1 + self._emit("output = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.9)") + self._indent -= 1 + self._emit("response = tok.decode(output[0], skip_special_tokens=True)") + self._emit('candidates.append({"persona": persona, "response": response})') + self._indent -= 1 + self._emit("") + self._emit("# Judge: structured JSON scoring for correctness, reasoning, safety, style") + self._emit('judge_prompt = "You are a neutral judge. Return JSON with keys: scores (list of {id, correctness, reasoning, safety, style}), winner_id, rationale. Scores 1-10.\\n"') + self._emit("for idx, c in enumerate(candidates):") + self._indent += 1 + self._emit("resp_snip = c['response'][:400]") + self._emit('judge_prompt += f"Answer {idx+1}: {resp_snip}\\n\\n"') + self._indent -= 1 + self._emit('inputs = tok(judge_prompt, return_tensors=\"pt\").to(model.device)') + self._emit("with torch.no_grad():") + self._indent += 1 + self._emit("output = model.generate(**inputs, max_new_tokens=256, do_sample=True, temperature=0.2)") + self._indent -= 1 + self._emit("judgment = tok.decode(output[0], skip_special_tokens=True)") + self._emit("try:") + self._indent += 1 + self._emit("judgment_json = json.loads(judgment[judgment.find('{'):])") + self._indent -= 1 + self._emit("except Exception:") + self._indent += 1 + self._emit("judgment_json = {'raw': judgment}") + self._indent -= 1 + self._emit("debate_results.append({") + self._indent += 1 + self._emit('"round": round_num + 1,') + self._emit('"prompt": prompt,') + self._emit('"candidates": candidates,') + self._emit('"judgment": judgment_json,') + self._indent -= 1 + self._emit("})") + self._indent -= 1 + self._emit("") + self._emit(f'results["{cmd.target}_debate"] = debate_results') + self._emit(f'lineage["{cmd.target}"]["operations"].append({{') + self._indent += 1 + self._emit('"op": "debate",') + self._emit(f'"rounds": {cmd.rounds},') + self._emit(f'"candidates": {cmd.candidates},') + self._emit('"timestamp": datetime.now().isoformat(),') + self._indent -= 1 + self._emit("})") + output_path = cmd.output or "debate_pairs.jsonl" + self._emit(f'debate_path = Path("{output_path}")') + self._emit("debate_path.parent.mkdir(parents=True, exist_ok=True)") + self._emit('with open(debate_path, "w") as f:') + self._indent += 1 + self._emit("for entry in debate_results:") + self._indent += 1 + self._emit("f.write(json.dumps(entry, default=str) + chr(10))") + self._indent -= 1 + self._indent -= 1 + self._emit('print(f"[td_lang] Debate results saved to {debate_path} ({len(debate_results)} rounds)")') + self._emit("del model, tok") + self._emit("import gc; gc.collect()") + + # ---------------------------------------------------------------- Phase 3 emitters + + def _emit_edit(self, cmd: EditCmd) -> None: + """EDIT — surgical LoRA/DoRA on specific layers. + + From test_18: all 3 AIs agree LoRA is safe default, DoRA beats by 1-4%. + layers_to_transform supports targeting specific layers (e.g., 16-28). + "Try before buy": eval with adapters enabled vs disabled, merge only if gates pass. + """ + alias = cmd.target + method = cmd.method # "lora" or "dora" + layers = cmd.layers # "all", "16-28", or single number + lr = cmd.learning_rate or 1e-4 + + self._emit(f'print("[td_lang] EDIT — surgical {method} on {alias}, layers={layers}")') + self._emit("from transformers import AutoModelForCausalLM, AutoTokenizer") + self._emit("import torch") + self._emit("from peft import LoraConfig, get_peft_model, PeftModel") + self._emit("from bitsandbytes import __version__ as bnb_version # ensure bnb installed") + self._emit("") + # Resolve checkpoint to load with 4-bit for 8B on single 4090 + self._emit(f'checkpoint = models.get("{alias}", {{}}).get("checkpoint") or models["{alias}"].get("model_ref")') + self._emit('print(f"[td_lang] Loading base model for EDIT from {checkpoint} (4-bit QLoRA)...")') + self._emit("bnb_config = {") + self._indent += 1 + self._emit('"load_in_4bit": True,') + self._emit('"bnb_4bit_compute_dtype": torch.bfloat16,') + self._emit('"bnb_4bit_use_double_quant": True,') + self._emit('"bnb_4bit_quant_type": "nf4",') + self._indent -= 1 + self._emit("}") + self._emit("model = AutoModelForCausalLM.from_pretrained(") + self._indent += 1 + self._emit("checkpoint, device_map='auto', **bnb_config") + self._indent -= 1 + self._emit(")") + self._emit("tok = AutoTokenizer.from_pretrained(checkpoint)") + self._emit("") + # Parse layer spec into layers_to_transform + self._emit("# Parse layer targeting") + if layers == "all": + self._emit("layers_to_transform = None # all layers") + elif "-" in layers: + parts = layers.split("-") + self._emit(f"layers_to_transform = list(range({parts[0]}, {int(parts[1]) + 1}))") + else: + self._emit(f"layers_to_transform = [{layers}]") + self._emit("") + + # Build PEFT config + self._emit("use_dora = method == 'dora'") + self._emit("edit_r = getattr(cmd, 'r', 8)") + self._emit("edit_alpha = getattr(cmd, 'alpha', 16)") + self._emit("edit_config = LoraConfig(") + self._indent += 1 + self._emit("r=edit_r,") + self._emit("lora_alpha=edit_alpha,") + self._emit('target_modules=["q_proj", "v_proj"],') + self._emit("lora_dropout=0.05,") + self._emit('bias="none",') + self._emit('task_type="CAUSAL_LM",') + self._emit("use_dora=use_dora,") + if layers != "all": + self._emit("layers_to_transform=layers_to_transform,") + self._emit('layers_pattern="layers",') + self._indent -= 1 + self._emit(")") + self._emit("") + + # Apply adapter + self._emit("# Inject adapter — base weights stay frozen") + self._emit("model = get_peft_model(model, edit_config)") + self._emit("model.print_trainable_parameters()") + self._emit("") + + # Dry-run: show which modules got wrapped + self._emit("# Dry-run report: verify correct modules were targeted") + self._emit("wrapped_modules = [n for n, _ in model.named_modules() if 'lora' in n.lower()]") + self._emit(f'print(f"[td_lang] EDIT: {{len(wrapped_modules)}} modules wrapped with {method}")') + self._emit('for wm in wrapped_modules[:10]:') + self._indent += 1 + self._emit('print(f" - {wm}")') + self._indent -= 1 + self._emit('if len(wrapped_modules) > 10:') + self._indent += 1 + self._emit('print(f" ... and {len(wrapped_modules) - 10} more")') + self._indent -= 1 + self._emit("") + + # "Try before buy" — actual eval with adapters on vs off + self._emit('sample_prompts = ["What is 7+8?", "Explain photosynthesis in one paragraph.", "Write a Python function fib(n)."]') + self._emit("def run_quick_eval(enable_adapters: bool):") + self._indent += 1 + self._emit("if enable_adapters:") + self._indent += 1 + self._emit("if hasattr(model, 'enable_adapters'): model.enable_adapters()") + self._indent -= 1 + self._emit("else:") + self._indent += 1 + self._emit("if hasattr(model, 'disable_adapters'): model.disable_adapters()") + self._indent -= 1 + self._emit("responses = []") + self._emit("for p in sample_prompts:") + self._indent += 1 + self._emit("inputs = tok(p, return_tensors='pt').to(model.device)") + self._emit("with torch.no_grad():") + self._indent += 1 + self._emit("out = model.generate(**inputs, max_new_tokens=128, temperature=0.7, do_sample=True)") + self._indent -= 1 + self._emit("resp = tok.decode(out[0], skip_special_tokens=True)") + self._emit("responses.append(resp)") + self._indent -= 1 + self._emit("avg_len = sum(len(r) for r in responses) / len(responses)") + self._emit("return responses, avg_len") + self._indent -= 1 + self._emit("") + self._emit("on_resps, on_len = run_quick_eval(True)") + self._emit("off_resps, off_len = run_quick_eval(False)") + self._emit('print("[td_lang] Try-before-buy results:")') + self._emit('print(f" Adapter ON avg length: {on_len:.1f}")') + self._emit('print(f" Adapter OFF avg length: {off_len:.1f}")') + self._emit("for i, (a, b) in enumerate(zip(on_resps, off_resps)):") + self._indent += 1 + self._emit('print(f"Prompt {i+1}:")') + self._emit('print(" ON :", a[:200])') + self._emit('print(" OFF:", b[:200])') + self._indent -= 1 + self._emit("") + + # Save adapter (don't merge yet — let commit/gates decide) + self._emit(f'edit_save_dir = os.path.join(output_dir, "{alias}_edit_{method}")') + self._emit("os.makedirs(edit_save_dir, exist_ok=True)") + self._emit("model.save_pretrained(edit_save_dir)") + self._emit(f'print(f"[td_lang] EDIT adapter saved to {{edit_save_dir}}")') + self._emit(f'print("[td_lang] Adapter NOT merged — use commit with gates to merge permanently")') + self._emit("") + + # Update models dict + self._emit(f'models["{alias}"] = model') + + def _emit_fork(self, cmd: ForkCmd) -> None: + """FORK — branch current model weights for parallel experiments. + + From test_18: all 3 AIs say disk-based only on 4090. + Cheap fork = copy manifest + adapter files, share base weights. + Uses safetensors format. + """ + source = cmd.source + alias = cmd.alias + + self._emit(f'print("[td_lang] FORK — branching {source} as {alias}")') + self._emit(f'source_model = models["{source}"]') + self._emit("import torch") + self._emit("") + + # Create fork directory with content hash (avoid overwrite) + self._emit("import hashlib") + self._emit('fork_suffix = hashlib.sha1((str(time.time()) + "{alias}").encode()).hexdigest()[:8]') + self._emit(f'fork_dir = os.path.join(output_dir, "forks", "{alias}_" + fork_suffix)') + self._emit("os.makedirs(fork_dir, exist_ok=True)") + self._emit("") + + # Write manifest + self._emit("# Write fork manifest — tracks lineage") + self._emit("import json") + self._emit("fork_manifest = {") + self._emit(f' "fork_name": "{alias}",') + self._emit(f' "forked_from": "{source}",') + self._emit(f' "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),') + self._emit(f' "base_ref": models.get("__base_ref_{source}", "unknown"),') + self._emit("}") + self._emit("") + + # Check if model has PEFT adapters + self._emit("# Cheap fork: save adapters only if PEFT model, else full checkpoint") + self._emit("is_peft = hasattr(source_model, 'peft_config')") + self._emit("if is_peft:") + self._indent += 1 + self._emit("# PEFT model — save only adapter weights (small, fast)") + self._emit('adapter_dir = os.path.join(fork_dir, "adapters")') + self._emit("source_model.save_pretrained(adapter_dir)") + self._emit('fork_manifest["fork_type"] = "adapter"') + self._emit('fork_manifest["adapter_dir"] = adapter_dir') + self._emit('print(f"[td_lang] Cheap fork: adapter saved to {adapter_dir}")') + self._indent -= 1 + self._emit("else:") + self._indent += 1 + self._emit("# Full model — clone tensors then save to safetensors") + self._emit("from safetensors.torch import save_file") + self._emit("state = {k: v.detach().cpu().clone() for k, v in source_model.state_dict().items()}") + self._emit('ckpt_path = os.path.join(fork_dir, "model.safetensors")') + self._emit("save_file(state, ckpt_path)") + self._emit('fork_manifest["fork_type"] = "full_checkpoint"') + self._emit('fork_manifest["checkpoint_path"] = ckpt_path') + self._emit('print(f"[td_lang] Full fork: checkpoint saved to {ckpt_path}")') + self._indent -= 1 + self._emit("") + + # Save manifest + self._emit("# Save RNG state for reproducibility") + self._emit("try:") + self._indent += 1 + self._emit("rng_state = torch.cuda.get_rng_state().cpu() if torch.cuda.is_available() else None") + self._indent -= 1 + self._emit("except Exception:") + self._indent += 1 + self._emit("rng_state = None") + self._indent -= 1 + self._emit("if rng_state is not None:") + self._indent += 1 + self._emit('torch.save(rng_state, os.path.join(fork_dir, "rng_state.pt"))') + self._emit('fork_manifest["rng_state"] = "rng_state.pt"') + self._indent -= 1 + self._emit("") + self._emit('manifest_path = os.path.join(fork_dir, "manifest.json")') + self._emit('with open(manifest_path, "w") as f:') + self._indent += 1 + self._emit("json.dump(fork_manifest, f, indent=2)") + self._indent -= 1 + self._emit(f'print(f"[td_lang] Fork manifest: {{manifest_path}}")') + self._emit("") + + # Register fork as available model alias (points to same model for now) + self._emit(f'models["{alias}"] = source_model # shares reference until divergence') + self._emit(f'lineage["{alias}"] = {{"forked_from": "{source}", "operations": []}}') + + def _emit_reset(self, cmd: ResetCmd) -> None: + """RESET — revert model to a previous checkpoint. + + From test_18: del model, clear CUDA cache, reload. + Must also reset optimizer state. Use assign=True to avoid doubling VRAM. + """ + alias = cmd.target + checkpoint = cmd.checkpoint + + self._emit(f'print("[td_lang] RESET — reverting {alias} to {checkpoint}")') + self._emit("") + + # Delete current model and clear CUDA + self._emit("# Free current model from VRAM") + self._emit(f'del models["{alias}"]') + self._emit("import gc; gc.collect()") + self._emit("torch.cuda.empty_cache()") + self._emit(f'print("[td_lang] VRAM cleared")') + self._emit("") + + # Determine checkpoint path + self._emit("# Resolve checkpoint path") + self._emit(f'ckpt_path = "{checkpoint}"') + self._emit("base_ref = ckpt_path") + self._emit("# Check if it's a fork directory with manifest") + self._emit('fork_manifest_path = os.path.join(ckpt_path, "manifest.json") if os.path.isdir(ckpt_path) else None') + self._emit("") + + # Reload model + self._emit("# Reload from checkpoint") + self._emit("from transformers import AutoModelForCausalLM, AutoTokenizer") + self._emit("") + self._emit("if fork_manifest_path and os.path.exists(fork_manifest_path):") + self._indent += 1 + self._emit("# Loading from a fork — read manifest") + self._emit("import json") + self._emit("with open(fork_manifest_path) as f:") + self._indent += 1 + self._emit("manifest = json.load(f)") + self._indent -= 1 + self._emit('base_ref = manifest.get("base_ref", ckpt_path)') + self._emit("model = AutoModelForCausalLM.from_pretrained(base_ref, torch_dtype=torch.float16, device_map='cuda')") + self._emit('if manifest.get("fork_type") == "adapter":') + self._indent += 1 + self._emit("from peft import PeftModel") + self._emit('model = PeftModel.from_pretrained(model, manifest["adapter_dir"])') + self._indent -= 1 + self._indent -= 1 + self._emit("elif os.path.isdir(ckpt_path):") + self._indent += 1 + self._emit("# Loading from a HF-style directory") + self._emit("model = AutoModelForCausalLM.from_pretrained(ckpt_path, torch_dtype=torch.float16, device_map='cuda')") + self._indent -= 1 + self._emit("else:") + self._indent += 1 + self._emit("# Loading from a safetensors file") + self._emit("from safetensors.torch import load_file") + self._emit("state = load_file(ckpt_path, device='cpu')") + self._emit("# Need base model architecture — reload from original") + self._emit(f'base_ref = models.get("__base_ref_{alias}", ckpt_path)') + self._emit("model = AutoModelForCausalLM.from_pretrained(base_ref, torch_dtype=torch.float16, device_map='cuda')") + self._emit("try:") + self._indent += 1 + self._emit("model.load_state_dict(state, strict=True, assign=True)") + self._indent -= 1 + self._emit("except Exception as e:") + self._indent += 1 + self._emit('print(f"[td_lang] Shape mismatch on reset load: {e}. Retrying non-strict.")') + self._emit("model.load_state_dict(state, strict=False)") + self._indent -= 1 + self._indent -= 1 + self._emit("") + + # Re-register in models dict + self._emit(f'models["{alias}"] = model') + self._emit(f'print(f"[td_lang] RESET complete — {alias} restored from {checkpoint}")') + self._emit("") + + # Optimizer/cache handling and quick smoke eval + self._emit("torch.cuda.empty_cache()") + self._emit(f'print("[td_lang] Note: optimizer state cleared; next train starts fresh.")') + self._emit("# Smoke eval after reset") + self._emit('sample_prompts = ["Hello!", "2+2?", "Define gravity.", "Write a Python loop 1..3.", "Capital of France?"]') + self._emit("tok = AutoTokenizer.from_pretrained(ckpt_path if os.path.isdir(ckpt_path) else base_ref)") + self._emit("model.eval()") + self._emit("for p in sample_prompts:") + self._indent += 1 + self._emit("inputs = tok(p, return_tensors='pt').to(model.device)") + self._emit("with torch.no_grad():") + self._indent += 1 + self._emit("out = model.generate(**inputs, max_new_tokens=40, do_sample=False)") + self._indent -= 1 + self._emit("resp = tok.decode(out[0], skip_special_tokens=True)") + self._emit('print(f"[td_lang][reset smoke] {p} -> {resp[:120]}")') + self._indent -= 1 + + def _emit_prune(self, cmd: PruneCmd) -> None: + """PRUNE — structural pruning of language backbone. + + From test_18: 20% structured max (LLM-Pruner). Wanda metric (Grok). + Language backbone only, never vision encoder. Recovery: 200-800 steps LoRA. + """ + alias = cmd.target + method = cmd.method # "wanda", "magnitude", "taylor" + aggressiveness = cmd.aggressiveness + + self._emit("import torch") + self._emit(f'print("[td_lang] PRUNE — {method} pruning on {alias}, {aggressiveness*100:.0f}% removal")') + self._emit(f'model = models["{alias}"]') + self._emit("") + + # Safety check: cap aggressiveness + self._emit("# Safety: cap pruning at 30% (beyond this = cliff, per LLM-Pruner)") + self._emit(f"prune_ratio = min({aggressiveness}, 0.30)") + self._emit(f"if prune_ratio != {aggressiveness}:") + self._indent += 1 + self._emit(f'print(f"[td_lang] WARNING: aggressiveness capped at 30% (requested {aggressiveness*100:.0f}%)")') + self._indent -= 1 + self._emit("") + + # Identify language-only layers (skip vision) + self._emit("# Target language backbone ONLY — never prune vision encoder") + self._emit("# Filter for language model linear layers") + self._emit("target_modules = []") + self._emit("for name, module in model.named_modules():") + self._indent += 1 + self._emit("if isinstance(module, torch.nn.Linear):") + self._indent += 1 + self._emit("# Skip vision encoder, embeddings, and output head") + self._emit('is_vision = any(v in name for v in ["visual", "vision", "vit", "image", "pixel"])') + self._emit('is_embed = any(e in name for e in ["embed", "lm_head", "output"])') + self._emit("if not is_vision and not is_embed:") + self._indent += 1 + self._emit("target_modules.append((name, module))") + self._indent -= 1 + self._indent -= 1 + self._indent -= 1 + self._emit('print(f"[td_lang] Found {len(target_modules)} prunable language layers")') + self._emit("") + + # Apply pruning based on method + self._emit(f"# Pruning method: {method}") + if method == "wanda": + self._emit("# Wanda: weight magnitude × input activation norm (Grok's recommendation)") + self._emit("# Collect activations on small calibration batch, then prune with keep_multiple_of=8") + self._emit("import torch.nn.utils.prune as prune") + self._emit("calib_texts = [") + self._indent += 1 + self._emit('"The quick brown fox jumps over the lazy dog.",') + self._emit('"Solve 12 + 37.",') + self._emit('"Write a for loop in Python that sums 1..10.",') + self._emit('"Explain why the sky is blue.",') + self._indent -= 1 + self._emit("]") + self._emit("from transformers import AutoTokenizer") + self._emit("base_ref = None") + self._emit("if isinstance(models.get(alias), dict):") + self._indent += 1 + self._emit("base_ref = models[alias].get('model_ref')") + self._indent -= 1 + self._emit("if base_ref is None:") + self._indent += 1 + self._emit(f"base_ref = models.get('__base_ref_{alias}', 'Qwen/Qwen3-VL-8B-Instruct')") + self._indent -= 1 + self._emit("tok = AutoTokenizer.from_pretrained(base_ref)") + self._emit("activation_sums = {}") + self._emit("hooks = []") + self._emit("def make_hook(name):") + self._indent += 1 + self._emit("def _hook(module, inp, out):") + self._indent += 1 + self._emit("with torch.no_grad():") + self._indent += 1 + self._emit("act = inp[0].detach().abs().mean(dim=0)") + self._emit("activation_sums[name] = activation_sums.get(name, 0) + act") + self._indent -= 2 + self._emit("return _hook") + self._indent -= 1 + self._emit("for name, module in target_modules:") + self._indent += 1 + self._emit("hooks.append(module.register_forward_hook(make_hook(name)))") + self._indent -= 1 + self._emit("# Run one calibration pass") + self._emit("for txt in calib_texts:") + self._indent += 1 + self._emit("inputs = tok(txt, return_tensors='pt').to(model.device)") + self._emit("with torch.no_grad(): model(**inputs)") + self._indent -= 1 + self._emit("for h in hooks: h.remove()") + self._emit("") + self._emit("import torch.nn.utils.prune as prune") + self._emit("pruned_count = 0") + self._emit("for layer_name, layer_module in target_modules:") + self._indent += 1 + self._emit("act = activation_sums.get(layer_name)") + self._emit("if act is None:") + self._indent += 1 + self._emit('print(f"[td_lang] Skip {layer_name}: no activation stats")') + self._emit("continue") + self._indent -= 1 + self._emit("scores = (layer_module.weight.detach().abs() * act.unsqueeze(0)).mean(dim=1)") + self._emit("keep = max(8, int((1 - prune_ratio) * scores.numel()))") + self._emit("keep = (keep // 8) * 8") + self._emit("keep = min(max(8, keep), scores.numel())") + self._emit("amount = 1 - (keep / scores.numel())") + self._emit("try:") + self._indent += 1 + self._emit("prune.ln_structured(layer_module, name='weight', amount=amount, n=1, dim=0)") + self._emit("prune.remove(layer_module, 'weight')") + self._emit("pruned_count += 1") + self._indent -= 1 + self._emit("except Exception as e:") + self._indent += 1 + self._emit('print(f"[td_lang] Skip {layer_name}: {e}")') + self._indent -= 1 + self._indent -= 1 + elif method == "magnitude": + self._emit("# Magnitude: simple L1 norm of weight rows") + self._emit("import torch.nn.utils.prune as prune") + self._emit("") + self._emit("pruned_count = 0") + self._emit("for layer_name, layer_module in target_modules:") + self._indent += 1 + self._emit("try:") + self._indent += 1 + self._emit("prune.ln_structured(layer_module, name='weight', amount=prune_ratio, n=1, dim=0)") + self._emit("prune.remove(layer_module, 'weight')") + self._emit("pruned_count += 1") + self._indent -= 1 + self._emit("except Exception as e:") + self._indent += 1 + self._emit('print(f"[td_lang] Skip {layer_name}: {e}")') + self._indent -= 1 + self._indent -= 1 + else: # taylor + self._emit("# Taylor: gradient-based importance (needs backprop — VRAM heavy)") + self._emit("# Falling back to magnitude as MVP — Taylor needs calibration + backprop") + self._emit(f'print("[td_lang] WARNING: Taylor pruning falls back to magnitude on single GPU")') + self._emit("import torch.nn.utils.prune as prune") + self._emit("") + self._emit("pruned_count = 0") + self._emit("for layer_name, layer_module in target_modules:") + self._indent += 1 + self._emit("try:") + self._indent += 1 + self._emit("prune.ln_structured(layer_module, name='weight', amount=prune_ratio, n=1, dim=0)") + self._emit("prune.remove(layer_module, 'weight')") + self._emit("pruned_count += 1") + self._indent -= 1 + self._emit("except Exception as e:") + self._indent += 1 + self._emit('print(f"[td_lang] Skip {layer_name}: {e}")') + self._indent -= 1 + self._indent -= 1 + self._emit("") + + # Report + self._emit('print(f"[td_lang] Pruned {pruned_count}/{len(target_modules)} layers at {prune_ratio*100:.0f}%")') + self._emit("") + + # Save pruning report + self._emit("# Save prune report for auditing") + self._emit("import json") + self._emit("prune_report = {") + self._emit(f' "method": "{method}",') + self._emit(f' "requested_aggressiveness": {aggressiveness},') + self._emit(' "actual_ratio": prune_ratio,') + self._emit(' "layers_pruned": pruned_count,') + self._emit(' "total_target_layers": len(target_modules),') + self._emit(' "vision_touched": False,') + self._emit("}") + self._emit(f'prune_report_path = os.path.join(output_dir, "{alias}_prune_report.json")') + self._emit('with open(prune_report_path, "w") as f:') + self._indent += 1 + self._emit("json.dump(prune_report, f, indent=2)") + self._indent -= 1 + self._emit(f'print(f"[td_lang] Prune report: {{prune_report_path}}")') + self._emit("") + + # Recovery warning + self._emit("# Recovery: you should run heal or train after pruning") + self._emit("# LLM-Pruner shows recovery in 200-800 steps with LoRA r=8") + self._emit(f'print("[td_lang] IMPORTANT: Run heal or train after pruning for recovery (suggest: heal {alias} lora_r 8 epochs 1, ~400 steps)")') + self._emit(f'models["{alias}"] = model') + + # ---------------------------------------------------------------- Phase 7: Loop Control emitters + + def _emit_cmd(self, cmd, program: TDProgram) -> None: + """Emit a single command — used by repeat/if to emit body commands.""" + if isinstance(cmd, LoadCmd): + self._emit_load(cmd) + elif isinstance(cmd, MergeCmd): + self._emit_merge(cmd) + elif isinstance(cmd, HealCmd): + self._emit_heal(cmd) + elif isinstance(cmd, EvalCmd): + self._emit_eval(cmd) + elif isinstance(cmd, CommitCmd): + self._emit_commit(cmd, program.gates) + elif isinstance(cmd, DiagnoseCmd): + self._emit_diagnose(cmd) + elif isinstance(cmd, SynthCmd): + self._emit_synth(cmd) + elif isinstance(cmd, TrainCmd): + self._emit_train(cmd) + elif isinstance(cmd, DebateCmd): + self._emit_debate(cmd) + elif isinstance(cmd, EditCmd): + self._emit_edit(cmd) + elif isinstance(cmd, ForkCmd): + self._emit_fork(cmd) + elif isinstance(cmd, ResetCmd): + self._emit_reset(cmd) + elif isinstance(cmd, PruneCmd): + self._emit_prune(cmd) + elif isinstance(cmd, FuseCmd): + self._emit_fuse(cmd) + elif isinstance(cmd, AbsorbCmd): + self._emit_absorb(cmd) + elif isinstance(cmd, SnapshotCmd): + self._emit_snapshot(cmd, program) + elif isinstance(cmd, ReportCmd): + self._emit_report(cmd, program) + elif isinstance(cmd, NotifyCmd): + self._emit_notify(cmd, program) + elif isinstance(cmd, SaveCmd): + self._emit_save(cmd, program) + elif isinstance(cmd, RepeatBlock): + self._emit_repeat(cmd, program) + elif isinstance(cmd, IfBlock): + self._emit_if(cmd, program) + + def _emit_repeat(self, cmd: RepeatBlock, program: TDProgram) -> None: + """REPEAT — run a block of commands N times. + + This is the core of td_loop: the self-improvement cycle. + Each iteration runs the body commands in order. + """ + n = cmd.count + self._emit(f'print("[td_lang] REPEAT — running {n} iterations")') + self._emit(f"for _loop_iter in range({n}):") + self._indent += 1 + self._emit(f'print(f"[td_lang] === Iteration {{_loop_iter + 1}}/{n} ===")') + self._emit("results['_loop_iter'] = _loop_iter") + if program.budget and program.budget.max_gpu_hours is not None: + self._emit("# Loop-level budget guard (GPU hours)") + self._emit("elapsed_hours = (time.time() - start_time) / 3600") + self._emit(f"if elapsed_hours >= {program.budget.max_gpu_hours}:") + self._indent += 1 + self._emit('print("[td_lang] Budget exceeded inside repeat — stopping loop.")') + self._emit("break") + self._indent -= 1 + self._emit("") + for body_cmd in cmd.body: + self._emit_cmd(body_cmd, program) + self._emit("") + self._emit(f'print(f"[td_lang] Iteration {{_loop_iter + 1}}/{n} complete.")') + self._indent -= 1 + self._emit(f'print("[td_lang] REPEAT complete — {n} iterations done.")') + + def _emit_if(self, cmd: IfBlock, program: TDProgram) -> None: + """IF/ELSE — conditional execution based on eval results. + + Conditions: + - eval_passed: last eval for target had no failures + - gate_passed: all gates passed for target + - improved: last eval score > previous eval score + """ + condition = cmd.condition + target = cmd.target + + self._emit(f'print("[td_lang] IF — checking {condition} for {target}")') + self._emit("") + + # Emit condition check + if condition == "eval_passed": + self._emit(f'_last_eval = results.get("{target}_eval", {{}})') + self._emit("_condition_met = bool(_last_eval) and _last_eval.get('overall', False)") + elif condition == "gate_passed": + gates = program.gates.must_pass if program.gates else [] + self._emit(f'_last_eval = results.get("{target}_eval", {{}})') + self._emit(f"_gates = {gates}") + self._emit("_condition_met = all(") + self._indent += 1 + self._emit("bool(_last_eval.get(g, {}).get('ok', False)) if isinstance(_last_eval.get(g), dict) else bool(_last_eval.get(g, False))") + self._emit("for g in _gates") + self._indent -= 1 + self._emit(") if _gates else bool(_last_eval)") + elif condition == "improved": + self._emit(f'_eval_history = results.get("{target}_eval_history", [])') + self._emit("_condition_met = len(_eval_history) >= 2 and _eval_history[-1] > _eval_history[-2]") + else: + # Generic: check if the condition key is truthy in results + self._emit(f'_condition_met = bool(results.get("{target}_{condition}", False))') + + self._emit("") + self._emit("if _condition_met:") + self._indent += 1 + self._emit(f'print("[td_lang] Condition {condition} = TRUE")') + for body_cmd in cmd.then_body: + self._emit_cmd(body_cmd, program) + self._emit("") + self._indent -= 1 + + if cmd.else_body: + self._emit("else:") + self._indent += 1 + self._emit(f'print("[td_lang] Condition {condition} = FALSE")') + for body_cmd in cmd.else_body: + self._emit_cmd(body_cmd, program) + self._emit("") + self._indent -= 1 + + def _emit_break_if(self, cmd: BreakIfCmd) -> None: + """BREAK_IF — early exit from repeat based on condition.""" + condition = cmd.condition + target = cmd.target or "" + self._emit(f'_brk_eval = results.get("{target}_eval", {{}})') + if condition == "improved": + self._emit(f'_hist = results.get("{target}_eval_history", [])') + self._emit("_brk_met = len(_hist) >= 2 and _hist[-1] <= _hist[-2]") + elif condition == "eval_passed": + self._emit("_brk_met = bool(_brk_eval.get('overall', False))") + else: + self._emit(f"_brk_met = bool(results.get('{target}_{condition}', False))") + self._emit("if _brk_met:") + self._indent += 1 + self._emit('print("[td_lang] break_if triggered — exiting loop")') + self._emit("break") + self._indent -= 1 + + # ---------------------------------------------------------------- Phase 6: Easy Merge emitters + + def _emit_fuse(self, cmd: FuseCmd) -> None: + """FUSE — merge multiple models into target in one command. + + From TD merge strategy: Transport and Merge (optimal transport cross-arch merging). + All 5 source models have different architectures — Transport and Merge handles this. + Merge into language backbone only, vision encoder stays untouched. + """ + target = cmd.target + sources = cmd.sources + method = cmd.method + strategy = cmd.strategy + n = len(sources) + + self._emit(f'print("[td_lang] FUSE — merging {n} models into {target} using {method}")') + self._emit(f'print("[td_lang] Strategy: {strategy}")') + self._emit(f"fuse_sources = {sources}") + self._emit(f'prev_ckpt = models.get("{target}", {{}}).get("checkpoint")') + self._emit("") + + # Auto-compute per-model strength + self._emit("# Auto-compute per-model merge strength") + if strategy == "equal": + self._emit(f"per_model_strength = round(1.0 / ({n} + 1), 3) # equal weight, target keeps its share") + self._emit(f'print(f"[td_lang] Equal strategy: each model gets {{per_model_strength}} strength")') + elif strategy == "sequential": + self._emit("# Sequential: merge one at a time with decreasing strength") + self._emit(f"strengths = [round(0.5 * (0.8 ** i), 3) for i in range({n})]") + self._emit('print(f"[td_lang] Sequential strategy: strengths = {strengths}")') + else: + # weighted — default to equal if no weights specified + self._emit(f"per_model_strength = round(1.0 / ({n} + 1), 3)") + self._emit("") + + # Loop through sources and merge each + self._emit("fuse_results = []") + self._emit("for fuse_idx, fuse_source in enumerate(fuse_sources):") + self._indent += 1 + self._emit(f'print(f"[td_lang] Fuse step {{fuse_idx + 1}}/{n}: merging {{fuse_source}}...")') + self._emit("") + + # Determine strength for this step + if strategy == "sequential": + self._emit("step_strength = strengths[fuse_idx]") + else: + self._emit("step_strength = per_model_strength") + self._emit("") + + # Match source to SOURCES config and pick method by architecture + self._emit("_stage = None") + self._emit("_arch = None") + self._emit("for _src in SOURCES:") + self._indent += 1 + self._emit("if _src.hf_id == fuse_source or _src.name.lower() in fuse_source.lower():") + self._indent += 1 + self._emit('_stage = _src.name.lower().split("-")[0]') + self._emit("_arch = getattr(_src, 'architecture', 'unknown')") + self._emit("_src.merge_alpha = step_strength") + self._emit("break") + self._indent -= 1 + self._indent -= 1 + self._emit("") + + self._emit("if _stage is None:") + self._indent += 1 + self._emit('print(f"[td_lang] WARNING: Could not match {fuse_source} to SOURCES. Attempting direct merge...")') + self._emit("# For Transport and Merge, we can merge any architecture directly") + self._emit(f'_stage = fuse_source.split("/")[-1].lower().replace("-", "_")[:20]') + self._emit('_arch = "unknown"') + self._indent -= 1 + self._emit("") + + # Run the merge + self._emit("cfg = MergeConfig()") + self._emit("# Auto-pick merge method by architecture match") + self._emit("chosen_method = 'slerp' if _arch == getattr(TARGET, 'architecture', 'unknown') else 'transport'") + self._emit(f"if '{method}' not in ['auto', '']: chosen_method = '{method}'") + self._emit("cfg.merge_method = chosen_method") + self._emit("merge_result = run_pipeline([_stage], cfg)") + self._emit("fuse_results.append({") + self._indent += 1 + self._emit('"source": fuse_source,') + self._emit('"stage": _stage,') + self._emit('"strength": step_strength,') + self._emit('"result": merge_result,') + self._indent -= 1 + self._emit("})") + self._emit("merged_stages.append(_stage)") + self._emit("") + + # Update checkpoint + self._emit('if merge_result.get("final_checkpoint"):') + self._indent += 1 + self._emit(f'models["{target}"]["checkpoint"] = merge_result["final_checkpoint"]') + self._emit("pre_score = quick_canary(prev_ckpt) if prev_ckpt else None") + self._emit("post_score = quick_canary(merge_result['final_checkpoint'])") + self._emit("if pre_score and post_score < 0.9 * pre_score:") + self._indent += 1 + self._emit('print(f"[td_lang] WARNING: quick canary degradation detected (pre={pre_score:.1f}, post={post_score:.1f})")') + self._indent -= 1 + self._indent -= 1 + self._emit(f'print(f"[td_lang] Fused {{fuse_source}} (strength={{step_strength}})")') + self._indent -= 1 + + self._emit("") + self._emit(f'results["{target}_fuse"] = fuse_results') + self._emit("") + + # Lineage: record every source + self._emit(f'lineage["{target}"]["operations"].append({{') + self._indent += 1 + self._emit('"op": "fuse",') + self._emit(f'"sources": {sources},') + self._emit(f'"method": "{method}",') + self._emit(f'"strategy": "{strategy}",') + self._emit(f'"n_models": {n},') + self._emit('"timestamp": datetime.now().isoformat(),') + self._indent -= 1 + self._emit("})") + self._emit(f'print("[td_lang] FUSE complete — {n} models merged into {target}")') + + def _emit_absorb(self, cmd: AbsorbCmd) -> None: + """ABSORB — simplified single-model merge. + + One-liner shortcut: absorb "model" into target [strength 0.5] + Wraps the merge logic with sensible defaults. + """ + source = cmd.source + target = cmd.target + strength = cmd.strength + + self._emit(f'print("[td_lang] ABSORB — merging {source} into {target} (strength={strength})")') + self._emit(f'prev_ckpt = models.get("{target}", {{}}).get("checkpoint")') + self._emit("") + + # Match source + self._emit(f'_source_ref = "{source}"') + self._emit("_stage = None") + self._emit("_arch = None") + self._emit("for _src in SOURCES:") + self._indent += 1 + self._emit('if _src.hf_id == _source_ref or _src.name.lower() in _source_ref.lower():') + self._indent += 1 + self._emit('_stage = _src.name.lower().split("-")[0]') + self._emit("_arch = getattr(_src, 'architecture', 'unknown')") + self._emit("break") + self._indent -= 1 + self._indent -= 1 + self._emit("") + + self._emit("if _stage is None:") + self._indent += 1 + self._emit(f'print(f"[td_lang] WARNING: {{_source_ref}} not in SOURCES. Using direct ref.")') + self._emit(f'_stage = _source_ref.split("/")[-1].lower().replace("-", "_")[:20]') + self._emit('_arch = "unknown"') + self._indent -= 1 + self._emit("") + + # Auto strength search if requested + self._emit("strengths = []") + self._emit("if str(strength).lower() == 'auto':") + self._indent += 1 + self._emit("strengths = [0.2, 0.4, 0.6, 0.8]") + self._indent -= 1 + self._emit("else:") + self._indent += 1 + self._emit("strengths = [strength]") + self._indent -= 1 + self._emit("") + self._emit("best_score = -1") + self._emit("best_result = None") + self._emit("best_strength = strengths[0]") + self._emit("for s in strengths:") + self._indent += 1 + self._emit("cfg = MergeConfig()") + self._emit("# choose method by architecture") + self._emit("cfg.merge_method = 'slerp' if _arch == getattr(TARGET, 'architecture', 'unknown') else 'transport'") + self._emit("for _src in SOURCES:") + self._indent += 1 + self._emit("if _src.hf_id == _source_ref or _src.name.lower() in _source_ref.lower():") + self._indent += 1 + self._emit(" _src.merge_alpha = s") + self._indent -= 1 + self._emit("break") + self._indent -= 1 + self._emit("merge_result = run_pipeline([_stage], cfg)") + self._emit("ckpt = merge_result.get('final_checkpoint')") + self._emit("score = quick_canary(ckpt) if ckpt else -1") + self._emit("if score > best_score:") + self._indent += 1 + self._emit("best_score = score") + self._emit("best_result = merge_result") + self._emit("best_strength = s") + self._indent -= 1 + self._indent -= 1 + self._emit("") + self._emit("merge_result = best_result") + self._emit("cfg_strength = best_strength") + self._emit("merged_stages.append(_stage)") + self._emit("") + + # Update checkpoint + self._emit('if merge_result and merge_result.get("final_checkpoint"):') + self._indent += 1 + self._emit(f'models["{target}"]["checkpoint"] = merge_result["final_checkpoint"]') + self._emit("pre_score = quick_canary(prev_ckpt) if prev_ckpt else None") + self._emit("post_score = quick_canary(merge_result['final_checkpoint']) if merge_result else None") + self._emit("if pre_score and post_score and post_score < 0.9 * pre_score:") + self._indent += 1 + self._emit('print(f"[td_lang] WARNING: canary degradation (pre={pre_score:.1f}, post={post_score:.1f})")') + self._indent -= 1 + self._indent -= 1 + self._emit(f'results["{target}_absorb"] = merge_result') + self._emit("") + + # Lineage + self._emit(f'lineage["{target}"]["operations"].append({{') + self._indent += 1 + self._emit('"op": "absorb",') + self._emit(f'"source": "{source}",') + self._emit(f'"strength": {strength},') + self._emit('"method": "auto" if str(strength).lower()=="auto" else "transport",') + self._emit('"timestamp": datetime.now().isoformat(),') + self._indent -= 1 + self._emit("})") + self._emit(f'print("[td_lang] ABSORB complete — {source} merged into {target}")') + + # ---------------------------------------------------------------- Phase 4 emitters + + def _emit_data_contract(self, dc: DataContractBlock) -> None: + """Emit data contract validation — checked at synth/train time. + + From ForgeSpec 2.0 (test_17): data contracts enforce schema on training data. + Required fields, minimum samples, max perplexity. + """ + self._emit("# Data Contract (Phase 4, ForgeSpec 2.0)") + self._emit("data_contract = {") + self._indent += 1 + self._emit(f'"required_fields": {dc.required_fields},') + if dc.min_samples is not None: + self._emit(f'"min_samples": {dc.min_samples},') + if dc.max_perplexity is not None: + self._emit(f'"max_perplexity": {dc.max_perplexity},') + self._indent -= 1 + self._emit("}") + self._emit("") + self._emit("def validate_data_contract(data_path, contract):") + self._indent += 1 + self._emit('"""Check training data against data contract."""') + self._emit("import json") + self._emit("errors = []") + self._emit("samples = []") + self._emit("with open(data_path) as f:") + self._indent += 1 + self._emit("for line_num, line in enumerate(f, 1):") + self._indent += 1 + self._emit("line = line.strip()") + self._emit("if not line: continue") + self._emit("try:") + self._indent += 1 + self._emit("sample = json.loads(line)") + self._emit("samples.append(sample)") + self._emit('for field in contract.get("required_fields", []):') + self._indent += 1 + self._emit("if field not in sample:") + self._indent += 1 + self._emit('errors.append(f"Line {line_num}: missing required field \'{field}\'")') + self._indent -= 2 + self._indent -= 1 + self._emit("except json.JSONDecodeError:") + self._indent += 1 + self._emit('errors.append(f"Line {line_num}: invalid JSON")') + self._indent -= 2 + self._indent -= 1 + self._emit('min_s = contract.get("min_samples")') + self._emit("if min_s and len(samples) < min_s:") + self._indent += 1 + self._emit('errors.append(f"Need {min_s} samples, got {len(samples)}")') + self._indent -= 1 + self._emit("if errors:") + self._indent += 1 + self._emit('print("[td_lang] DATA CONTRACT VIOLATIONS:")') + self._emit("for e in errors[:10]:") + self._indent += 1 + self._emit('print(f" - {e}")') + self._indent -= 1 + self._emit("if len(errors) > 10:") + self._indent += 1 + self._emit('print(f" ... and {len(errors)-10} more")') + self._indent -= 1 + self._emit('raise ValueError(f"Data contract failed: {len(errors)} violations")') + self._indent -= 1 + self._emit('print(f"[td_lang] Data contract OK: {len(samples)} samples, all fields present.")') + self._emit("return samples") + self._indent -= 1 + self._emit("") + + def _emit_reward_contract(self, rc: RewardContractBlock) -> None: + """Emit reward contract — enforced during GRPO training. + + From test_16: verified rewards only, no learned reward model. + """ + self._emit("# Reward Contract (Phase 4, ForgeSpec 2.0)") + self._emit("reward_contract = {") + self._indent += 1 + self._emit(f'"verifiers": {rc.verifiers},') + if rc.min_reward is not None: + self._emit(f'"min_reward": {rc.min_reward},') + self._indent -= 1 + self._emit("}") + self._emit('print(f"[td_lang] Reward contract: verifiers={reward_contract[\'verifiers\']}")') + self._emit("") + + def _emit_snapshot(self, cmd: SnapshotCmd, program: TDProgram) -> None: + """SNAPSHOT — content-hashed model state for artifact lineage. + + From ForgeSpec 2.0 (test_17): every model state gets a content-addressed hash. + Directory contains: model weights/adapters, eval report, prune spec, manifest. + """ + alias = cmd.target + output_dir = cmd.output or "td_lang_outputs/snapshots" + + self._emit(f'print("[td_lang] SNAPSHOT — saving content-hashed state for {alias}")') + self._emit("import hashlib, json, time") + self._emit(f'snap_model = models["{alias}"]') + self._emit("") + + # Compute content hash from model state + self._emit("# Content hash from model parameters (first 10 layers for speed)") + self._emit("hasher = hashlib.sha256()") + self._emit("param_count = 0") + self._emit("if hasattr(snap_model, 'state_dict'):") + self._indent += 1 + self._emit("for name, param in list(snap_model.state_dict().items())[:50]:") + self._indent += 1 + self._emit("hasher.update(param.cpu().numpy().tobytes()[:1024])") + self._emit("param_count += param.numel()") + self._indent -= 1 + self._indent -= 1 + self._emit("elif isinstance(snap_model, dict):") + self._indent += 1 + self._emit("for k, v in snap_model.items():") + self._indent += 1 + self._emit("hasher.update(str(v).encode()[:256])") + self._indent -= 1 + self._indent -= 1 + self._emit("content_hash = hasher.hexdigest()[:16]") + self._emit(f'snap_dir = os.path.join(output_dir, "{output_dir}", f"{alias}_{{content_hash}}")') + self._emit("os.makedirs(snap_dir, exist_ok=True)") + self._emit("") + + # Write manifest + self._emit("# Snapshot manifest — full provenance record") + self._emit("snap_manifest = {") + self._indent += 1 + self._emit(f'"alias": "{alias}",') + self._emit('"content_hash": content_hash,') + self._emit('"param_count": param_count,') + self._emit('"timestamp": datetime.now().isoformat(),') + self._emit(f'"lineage": lineage.get("{alias}", {{}}),') + self._emit(f'"eval_results": results.get("{alias}_eval", None),') + self._emit(f'"diagnose_results": results.get("{alias}_diagnose", None),') + self._indent -= 1 + self._emit("}") + self._emit("") + + # Save adapter if PEFT, else note checkpoint location + self._emit("if hasattr(snap_model, 'peft_config'):") + self._indent += 1 + self._emit('adapter_dir = os.path.join(snap_dir, "adapters")') + self._emit("snap_model.save_pretrained(adapter_dir)") + self._emit('snap_manifest["has_adapters"] = True') + self._emit('snap_manifest["adapter_dir"] = adapter_dir') + self._indent -= 1 + self._emit("else:") + self._indent += 1 + self._emit(f'ckpt = models.get("{alias}", {{}}).get("checkpoint") if isinstance(models.get("{alias}"), dict) else None') + self._emit('snap_manifest["has_adapters"] = False') + self._emit('snap_manifest["checkpoint_ref"] = str(ckpt) if ckpt else "in_memory"') + self._indent -= 1 + self._emit("") + + # Write manifest JSON + self._emit('manifest_path = os.path.join(snap_dir, "snapshot_manifest.json")') + self._emit('with open(manifest_path, "w") as f:') + self._indent += 1 + self._emit("json.dump(snap_manifest, f, indent=2, default=str)") + self._indent -= 1 + self._emit(f'print(f"[td_lang] Snapshot saved: {{snap_dir}}")') + self._emit(f'print(f"[td_lang] Content hash: {{content_hash}}")') + self._emit("") + + # Update lineage + self._emit(f'lineage.setdefault("{alias}", {{"operations": []}})["operations"].append({{') + self._indent += 1 + self._emit('"op": "snapshot",') + self._emit('"content_hash": content_hash,') + self._emit('"snap_dir": snap_dir,') + self._emit('"timestamp": datetime.now().isoformat(),') + self._indent -= 1 + self._emit("})") + + def _emit_report(self, cmd: ReportCmd, program: TDProgram) -> None: + """REPORT — economics report for the run. + + Tracks GPU hours, cost, tokens, time per command. + From test_17 ForgeSpec 2.0: economics reports for cost tracking. + """ + output = cmd.output or "economics_report.json" + + self._emit('print("[td_lang] REPORT — generating economics report")') + self._emit("elapsed = time.time() - start_time") + self._emit("") + self._emit("report = {") + self._indent += 1 + self._emit('"td_lang_version": "0.2.0",') + self._emit('"timestamp": datetime.now().isoformat(),') + self._emit('"elapsed_seconds": round(elapsed, 2),') + self._emit('"elapsed_minutes": round(elapsed / 60, 2),') + self._emit(f'"gpu_hourly_rate": {self.GPU_HOURLY},') + self._emit('"estimated_cost": round(elapsed / 3600 * GPU_HOURLY, 2),') + self._emit('"models_loaded": list(models.keys()),') + self._emit('"merged_stages": merged_stages,') + self._emit('"lineage_summary": {},') + self._indent -= 1 + self._emit("}") + self._emit("") + + # Compute per-model operation counts + self._emit("for alias, lin in lineage.items():") + self._indent += 1 + self._emit("ops = lin.get('operations', [])") + self._emit("op_counts = {}") + self._emit("for op in ops:") + self._indent += 1 + self._emit("op_type = op.get('op', 'unknown')") + self._emit("op_counts[op_type] = op_counts.get(op_type, 0) + 1") + self._indent -= 1 + self._emit('report["lineage_summary"][alias] = {') + self._indent += 1 + self._emit('"total_operations": len(ops),') + self._emit('"operation_counts": op_counts,') + self._indent -= 1 + self._emit("}") + self._indent -= 1 + self._emit("") + + # Add eval results summary + self._emit("eval_summary = {}") + self._emit("for key, val in results.items():") + self._indent += 1 + self._emit('if "_eval" in key:') + self._indent += 1 + self._emit("if isinstance(val, dict):") + self._indent += 1 + self._emit("eval_summary[key] = {k: v for k, v in val.items() if k != 'raw'}") + self._indent -= 1 + self._emit("else:") + self._indent += 1 + self._emit('eval_summary[key] = str(val)[:200]') + self._indent -= 2 + self._indent -= 1 + self._emit('report["eval_summary"] = eval_summary') + self._emit("") + + # Has contracts? + if program.data_contract: + self._emit('report["data_contract"] = data_contract') + if program.reward_contract: + self._emit('report["reward_contract"] = reward_contract') + + # Save + self._emit(f'report_path = Path("{output}")') + self._emit("report_path.parent.mkdir(parents=True, exist_ok=True)") + self._emit('with open(report_path, "w") as f:') + self._indent += 1 + self._emit("json.dump(report, f, indent=2, default=str)") + self._indent -= 1 + self._emit(f'print(f"[td_lang] Economics report saved to {{report_path}}")') + self._emit('print(f"[td_lang] Time: {report[\'elapsed_minutes\']} min")') + self._emit('print(f"[td_lang] Estimated cost: ${report[\'estimated_cost\']}")') + self._emit('print(f"[td_lang] Models: {report[\'models_loaded\']}")') + + # ---------------------------------------------------------------- Phase 8: Autopilot emitters + + def _emit_setup(self, setup: SetupBlock) -> None: + """SETUP — auto-install dependencies and configure environment. + + Runs at script start: pip install, HF token, ntfy config. + """ + self._emit("# ========== SETUP (Phase 8 — Autopilot) ==========") + self._emit('print("[td_lang] SETUP — configuring environment...")') + self._emit("") + + # pip install + if setup.pip_packages: + pkg_str = " ".join(setup.pip_packages) + self._emit(f"# Install dependencies") + self._emit(f'_pip_pkgs = "{pkg_str}"') + self._emit("import subprocess as _sp") + self._emit('print(f"[td_lang] Installing: {_pip_pkgs}")') + self._emit("try:") + self._indent += 1 + self._emit('_sp.check_call([sys.executable, "-m", "pip", "install", "--break-system-packages", "-q"]') + self._emit(f' + _pip_pkgs.split())') + self._emit('print("[td_lang] Dependencies installed.")') + self._indent -= 1 + self._emit("except Exception as e:") + self._indent += 1 + self._emit('print(f"[td_lang] WARNING: pip install failed: {e}")') + self._emit('print("[td_lang] Continuing anyway — packages may already be installed.")') + self._indent -= 1 + self._emit("") + + # HF token + if setup.hf_token: + self._emit("# HuggingFace authentication") + if setup.hf_token == "env": + self._emit('_hf_token = os.environ.get("HF_TOKEN", "")') + else: + self._emit(f'_hf_token = "{setup.hf_token}"') + self._emit("if _hf_token:") + self._indent += 1 + self._emit("os.environ['HF_TOKEN'] = _hf_token") + self._emit("try:") + self._indent += 1 + self._emit("from huggingface_hub import login") + self._emit("login(token=_hf_token, add_to_git_credential=False)") + self._emit('print("[td_lang] HuggingFace authenticated.")') + self._indent -= 1 + self._emit("except Exception:") + self._indent += 1 + self._emit('print("[td_lang] HF login via huggingface_hub failed, using env var.")') + self._indent -= 1 + self._indent -= 1 + self._emit("else:") + self._indent += 1 + self._emit('print("[td_lang] WARNING: No HF_TOKEN found. Gated models may fail to download.")') + self._indent -= 1 + self._emit("") + + # ntfy notification endpoint + if setup.notify_url: + self._emit("# Notification endpoint (ntfy.sh)") + self._emit(f'NTFY_URL = "{setup.notify_url}"') + self._emit("") + self._emit("def td_notify(msg):") + self._indent += 1 + self._emit('"""Send notification via ntfy.sh."""') + self._emit("try:") + self._indent += 1 + self._emit("import urllib.request") + self._emit("req = urllib.request.Request(") + self._indent += 1 + self._emit('f"https://{NTFY_URL}" if not NTFY_URL.startswith("http") else NTFY_URL,') + self._emit("data=msg.encode(),") + self._emit('method="POST",') + self._indent -= 1 + self._emit(")") + self._emit("urllib.request.urlopen(req, timeout=10)") + self._emit('print(f"[td_lang] Notified: {msg}")') + self._indent -= 1 + self._emit("except Exception as e:") + self._indent += 1 + self._emit('print(f"[td_lang] Notify failed: {e}")') + self._indent -= 1 + self._indent -= 1 + else: + self._emit("def td_notify(msg):") + self._indent += 1 + self._emit('print(f"[td_lang] (no ntfy configured) {msg}")') + self._indent -= 1 + + self._emit("") + self._emit('td_notify("TD pipeline starting...")') + self._emit('print("[td_lang] SETUP complete.")') + self._emit("") + + def _emit_on_error(self, on_error: OnErrorBlock, program: TDProgram) -> None: + """ON_ERROR — wrap each step in retry/fallback logic. + + Emits a td_safe_run() helper that wraps any function call with: + - Retry N times on failure + - Fallback strategies (reduce batch, skip, snapshot+stop) + - Optional ntfy notification on error + """ + self._emit("# ========== ON_ERROR (Phase 8 — Crash Recovery) ==========") + self._emit(f"TD_MAX_RETRIES = {on_error.retry}") + self._emit(f'TD_FALLBACK = "{on_error.fallback}"') + self._emit(f"TD_NOTIFY_ON_ERROR = {on_error.notify}") + self._emit("") + self._emit("def td_safe_run(step_name, fn, *args, **kwargs):") + self._indent += 1 + self._emit('"""Run a step with retry and fallback on error."""') + self._emit("import traceback") + self._emit("for attempt in range(1, TD_MAX_RETRIES + 1):") + self._indent += 1 + self._emit("try:") + self._indent += 1 + self._emit("return fn(*args, **kwargs)") + self._indent -= 1 + self._emit("except torch.cuda.OutOfMemoryError as oom:") + self._indent += 1 + self._emit('print(f"[td_lang] OOM on {step_name} (attempt {attempt}/{TD_MAX_RETRIES})")') + self._emit("torch.cuda.empty_cache()") + self._emit("import gc; gc.collect()") + self._emit('if TD_FALLBACK == "reduce_batch":') + self._indent += 1 + self._emit('print("[td_lang] Reducing batch size and retrying...")') + self._emit('os.environ["TD_REDUCE_BATCH"] = "1"') + self._indent -= 1 + self._emit('elif TD_FALLBACK == "skip":') + self._indent += 1 + self._emit('print(f"[td_lang] Skipping {step_name}")') + self._emit("return None") + self._indent -= 1 + self._emit('elif TD_FALLBACK == "snapshot_and_stop":') + self._indent += 1 + self._emit('print(f"[td_lang] OOM — saving snapshot and stopping.")') + self._emit("if TD_NOTIFY_ON_ERROR:") + self._indent += 1 + self._emit('td_notify(f"OOM on {step_name} — snapshot saved, stopping.")') + self._indent -= 1 + self._emit("raise") + self._indent -= 2 + self._emit("except Exception as e:") + self._indent += 1 + self._emit('print(f"[td_lang] Error on {step_name} (attempt {attempt}/{TD_MAX_RETRIES}): {e}")') + self._emit("traceback.print_exc()") + self._emit("if attempt == TD_MAX_RETRIES:") + self._indent += 1 + self._emit("if TD_NOTIFY_ON_ERROR:") + self._indent += 1 + self._emit('td_notify(f"FAILED: {step_name} after {TD_MAX_RETRIES} retries — {e}")') + self._indent -= 1 + self._emit('if TD_FALLBACK == "skip":') + self._indent += 1 + self._emit("return None") + self._indent -= 1 + self._emit("raise") + self._indent -= 2 + self._indent -= 1 + self._indent -= 1 + self._emit("") + + def _emit_notify(self, cmd: NotifyCmd, program: TDProgram) -> None: + """NOTIFY — send message via ntfy.sh.""" + msg = cmd.message.replace('"', '\\"') + self._emit(f'td_notify("{msg}")') + + def _emit_save(self, cmd: SaveCmd, program: TDProgram) -> None: + """SAVE — upload model to cloud storage via rclone. + + Uses rclone to copy model checkpoint/adapters to Google Drive or any remote. + """ + alias = cmd.target + dest = cmd.destination + + self._emit(f'print("[td_lang] SAVE — uploading {alias} to {dest}")') + self._emit("") + + # Find the model's checkpoint directory + self._emit(f'_save_model = models.get("{alias}", {{}})') + self._emit('_save_path = _save_model.get("checkpoint") if isinstance(_save_model, dict) else None') + self._emit("") + + # If PEFT model, save adapters first + self._emit('if hasattr(_save_model, "peft_config") or (isinstance(_save_model, dict) and _save_model.get("has_adapters")):') + self._indent += 1 + self._emit(f'_adapter_dir = f"td_lang_outputs/{alias}_save_adapters"') + self._emit("os.makedirs(_adapter_dir, exist_ok=True)") + self._emit("if hasattr(_save_model, 'save_pretrained'):") + self._indent += 1 + self._emit("_save_model.save_pretrained(_adapter_dir)") + self._indent -= 1 + self._emit("_save_path = _adapter_dir") + self._indent -= 1 + self._emit("") + + # Use rclone to upload + self._emit("if _save_path:") + self._indent += 1 + self._emit(f'_rclone_cmd = ["rclone", "copy", str(_save_path), "{dest}", "--progress"]') + self._emit('_rclone_str = " ".join(_rclone_cmd)') + self._emit('print(f"[td_lang] Running: {_rclone_str}")') + self._emit("try:") + self._indent += 1 + self._emit("import subprocess as _sp") + self._emit("_sp.check_call(_rclone_cmd)") + self._emit(f'print("[td_lang] SAVE complete — {alias} uploaded to {dest}")') + self._emit(f'td_notify("Model {alias} saved to {dest}")') + self._indent -= 1 + self._emit("except FileNotFoundError:") + self._indent += 1 + self._emit('print("[td_lang] ERROR: rclone not found. Install it: curl https://rclone.org/install.sh | sudo bash")') + self._emit('print("[td_lang] Then configure: rclone config (add Google Drive remote)")') + self._emit(f'td_notify("SAVE FAILED: rclone not installed")') + self._indent -= 1 + self._emit("except Exception as e:") + self._indent += 1 + self._emit('print(f"[td_lang] SAVE error: {e}")') + self._emit(f'td_notify(f"SAVE FAILED: {{e}}")') + self._indent -= 1 + self._indent -= 1 + self._emit("else:") + self._indent += 1 + self._emit(f'print("[td_lang] WARNING: No checkpoint found for {alias}. Nothing to save.")') + self._emit(f'print("[td_lang] Run commit or snapshot first to create a checkpoint.")') + self._indent -= 1 + + # Lineage + self._emit("") + self._emit(f'lineage.setdefault("{alias}", {{"operations": []}})["operations"].append({{') + self._indent += 1 + self._emit('"op": "save",') + self._emit(f'"destination": "{dest}",') + self._emit('"timestamp": datetime.now().isoformat(),') + self._indent -= 1 + self._emit("})") + + # ---------------------------------------------------------------- Budget + summary + def _emit_budget_check(self, program: TDProgram) -> None: + budget = program.budget or BudgetBlock() + est_gpu = 0.0 + est_tokens = 0 + est_experiments = 0 + + for cmd in program.commands: + if isinstance(cmd, LoadCmd): + est_gpu += 0.05 + elif isinstance(cmd, MergeCmd): + est_gpu += 2.0 + est_tokens += 8_000_000 + est_experiments += 1 + elif isinstance(cmd, HealCmd): + est_gpu += 0.5 * cmd.epochs + est_tokens += 1_000_000 * cmd.epochs + est_experiments += 1 + elif isinstance(cmd, EvalCmd): + est_gpu += 0.1 + est_tokens += 200_000 + elif isinstance(cmd, CommitCmd): + est_gpu += 0.01 + elif isinstance(cmd, DiagnoseCmd): + est_gpu += 0.2 + est_tokens += 500_000 + elif isinstance(cmd, SynthCmd): + est_gpu += 1.0 + est_tokens += 5_000_000 + est_experiments += 1 + elif isinstance(cmd, TrainCmd): + steps = cmd.steps or 64 + est_gpu += 0.5 + (steps / 64) * 1.5 + est_tokens += steps * 100_000 + est_experiments += 1 + elif isinstance(cmd, DebateCmd): + est_gpu += 0.3 * cmd.rounds + est_tokens += cmd.rounds * cmd.candidates * 200_000 + elif isinstance(cmd, EditCmd): + est_gpu += 0.5 # adapter setup + dry-run + est_tokens += 500_000 + est_experiments += 1 + elif isinstance(cmd, ForkCmd): + est_gpu += 0.1 # mostly disk I/O + elif isinstance(cmd, ResetCmd): + est_gpu += 0.15 # reload from disk + elif isinstance(cmd, PruneCmd): + est_gpu += 1.0 # calibration + pruning pass + est_tokens += 1_000_000 + est_experiments += 1 + elif isinstance(cmd, FuseCmd): + n = len(cmd.sources) + est_gpu += 2.0 * n # ~2 hrs per model merge + est_tokens += 8_000_000 * n + est_experiments += n + elif isinstance(cmd, AbsorbCmd): + est_gpu += 2.0 + est_tokens += 8_000_000 + est_experiments += 1 + elif isinstance(cmd, RepeatBlock): + # Budget for repeat: estimate body cost * iterations + body_est = 1.0 * len(cmd.body) # rough: 1 GPU hr per body command + est_gpu += body_est * cmd.count + est_experiments += cmd.count + elif isinstance(cmd, IfBlock): + est_gpu += 0.5 # conditional overhead + elif isinstance(cmd, SnapshotCmd): + est_gpu += 0.05 # mostly disk I/O + hashing + elif isinstance(cmd, ReportCmd): + est_gpu += 0.01 # just JSON output + + est_cost = est_gpu * self.GPU_HOURLY + + self._emit("# Budget heuristic (estimated before execution)") + self._emit(f"est_gpu_hours = {est_gpu:.4f}") + self._emit(f"est_tokens = {est_tokens}") + self._emit(f"est_experiments = {est_experiments}") + self._emit("est_cost = est_gpu_hours * GPU_HOURLY") + + if budget.max_gpu_hours is not None: + self._emit(f"if est_gpu_hours > {budget.max_gpu_hours}:") + self._indent += 1 + self._emit(f'raise TDBudgetError("max_gpu_hours", {budget.max_gpu_hours}, est_gpu_hours)') + self._indent -= 1 + if budget.max_cost is not None: + self._emit(f"if est_cost > {budget.max_cost}:") + self._indent += 1 + self._emit(f'raise TDBudgetError("max_cost", {budget.max_cost}, est_cost)') + self._indent -= 1 + if budget.max_tokens is not None: + self._emit(f"if est_tokens > {budget.max_tokens}:") + self._indent += 1 + self._emit(f'raise TDBudgetError("max_tokens", {budget.max_tokens}, est_tokens)') + self._indent -= 1 + if budget.max_experiments is not None: + self._emit(f"if est_experiments > {budget.max_experiments}:") + self._indent += 1 + self._emit(f'raise TDBudgetError("max_experiments", {budget.max_experiments}, est_experiments)') + self._indent -= 1 + self._emit('print("[td_lang] Budget check passed.")') + self._emit("") + + def _emit_summary(self) -> None: + self._emit("# --- Final Summary ---") + self._emit("elapsed = time.time() - start_time") + self._emit('print("\\n" + "=" * 60)') + self._emit('print("TD LANG COMPLETE")') + self._emit('print("=" * 60)') + self._emit('print(f" Time: {elapsed / 60:.1f} minutes")') + self._emit('print(f" Models: {list(models.keys())}")') + self._emit('print(f" Merged stages: {merged_stages}")') + self._emit('print("=" * 60)') + self._emit('td_notify(f"TD pipeline DONE in {elapsed / 60:.1f} min. Models: {list(models.keys())}")') + + # ---------------------------------------------------------------- Util + def _emit(self, line: str) -> None: + if line == "": + self._lines.append("") + else: + prefix = " " * self._indent + self._lines.append(prefix + line) + + def _emit_comment(self, text: str) -> None: + self._emit(f"# {text}") + + +def compile_program(program: TDProgram) -> str: + """Public helper to compile a TDProgram into Python code.""" + return TDCompiler().compile(program) diff --git a/hugging/td_lang/errors.py b/hugging/td_lang/errors.py new file mode 100644 index 0000000000000000000000000000000000000000..ae9cec9e891235723a70b6d50f5581297b2a2bfb --- /dev/null +++ b/hugging/td_lang/errors.py @@ -0,0 +1,99 @@ +""" +TD Lang Errors — Clear, helpful error messages. + +Milan is 11 — errors should say what went wrong and where, +not dump cryptic stack traces. +""" + + +class TDLangError(Exception): + """Base error for all td_lang errors.""" + + def __init__(self, message: str, line: int | None = None, hint: str | None = None): + self.line = line + self.hint = hint + if line is not None: + full = f"Line {line}: {message}" + else: + full = message + if hint: + full += f"\n Hint: {hint}" + super().__init__(full) + + +class TDSyntaxError(TDLangError): + """Bad .td syntax — couldn't understand the file.""" + pass + + +class TDCompileError(TDLangError): + """Valid syntax but impossible plan — e.g., merging into a model that doesn't exist.""" + pass + + +class TDGateError(TDLangError): + """Gates failed during execution.""" + + def __init__(self, failed_gates: list[str], message: str = ""): + self.failed_gates = failed_gates + msg = message or f"Gates failed: {', '.join(failed_gates)}" + super().__init__(msg, hint="Check eval results — the model may have regressed.") + + +class TDBudgetError(TDLangError): + """Budget would be exceeded — compiler refuses to run.""" + + def __init__(self, field: str, limit: float, requested: float): + self.field = field + self.limit = limit + self.requested = requested + super().__init__( + f"Budget exceeded: {field} limit is {limit}, but plan needs ~{requested}", + hint="Reduce steps, use fewer merges, or increase the budget.", + ) + + +class TDContractError(TDLangError): + """Data or reward contract violation — training data doesn't match spec.""" + + def __init__(self, contract_type: str, violations: list[str]): + self.contract_type = contract_type + self.violations = violations + msg = f"{contract_type} contract failed with {len(violations)} violation(s)" + if violations: + msg += f": {violations[0]}" + if len(violations) > 1: + msg += f" (and {len(violations)-1} more)" + super().__init__( + msg, + hint="Check your training data matches the contract spec.", + ) + + +# ============================================================================ +# COMMON MISTAKE SUGGESTIONS (Phase 5) +# ============================================================================ + +COMMON_FIXES = { + "load": 'Did you forget quotes? Correct: load "model/path" as name', + "merge": 'Format: merge "source" into target using method [strength 0.5]', + "edit": "Format: edit target layers 16-28 using lora [lr 1e-4]", + "prune": "Format: prune target using wanda [aggressiveness 0.2]", + "fork": "Format: fork source as new_name", + "reset": 'Format: reset target to "checkpoint_path"', + "train": 'Format: train target on "dataset" using grpo [steps 64]', + "synth": "Format: synth target from source [filter cherry_llm]", + "snapshot": "Format: snapshot target [-> output_dir]", + "report": "Format: report [-> economics.json]", + "fuse": 'Format: fuse ["model1", "model2"] into target [strategy equal]', + "absorb": 'Format: absorb "model" into target [strength 0.5]', +} + + +def suggest_fix(token: str) -> str | None: + """Given a failed token, suggest the correct syntax.""" + token_lower = token.lower().strip() + for keyword, fix in COMMON_FIXES.items(): + if keyword in token_lower: + return fix + return None diff --git a/hugging/td_lang/examples/demo_autopilot.td b/hugging/td_lang/examples/demo_autopilot.td new file mode 100644 index 0000000000000000000000000000000000000000..4457a9d1198646b5704168ed645748647d9c0a32 --- /dev/null +++ b/hugging/td_lang/examples/demo_autopilot.td @@ -0,0 +1,62 @@ +# demo_autopilot.td — The full "rent a GPU and go" pipeline +# Rent vast.ai, upload this file, run: python -m td_lang run demo_autopilot.td +# Then sit back — you'll get ntfy notifications on your phone. + +# === ENVIRONMENT === +setup { + pip = [torch, transformers, peft, bitsandbytes, trl, safetensors, datasets, accelerate, huggingface_hub, sentencepiece] + hf_token = env + notify = "ntfy.sh/my_ai" +} + +on_error { + retry = 3 + fallback = reduce_batch + notify = true +} + +# === QUALITY RULES === +gate { must_pass = [canary, perplexity, thinking_mode] } +budget { max_gpu_hours = 40 max_cost = 160.00 } + +data_contract { + required_fields = [prompt, response] + min_samples = 50 + max_perplexity = 50.0 +} + +reward_contract { + verifiers = [code_compiles, math_correct] + min_reward = 0.3 +} + +# === PIPELINE === + +# Step 1: Load and fuse +load "Qwen/Qwen3-VL-8B-Instruct" as base +fuse ["deepseek-ai/DeepSeek-R1", "MiMo-7B", "meta-llama/Llama-3.1-8B", "tiiuae/Falcon-H1R-7B"] into base +heal base lora_r 32 epochs 2 +notify "Merge + heal complete. Starting self-improvement loop." + +# Step 2: Self-improvement loop +repeat 5 { + diagnose base -> weaknesses.json + synth base from base filter cherry_llm -> training_data.jsonl + train base on "training_data.jsonl" using grpo steps 64 lr 5e-5 + eval base -> eval_results.json + + if eval_passed base { + commit base + snapshot base -> snapshots/ + notify "Loop iteration passed! Model improved." + } else { + reset base to "snapshots/" + notify "Loop iteration failed. Reset to last good snapshot." + } +} + +# Step 3: Save and notify +snapshot base -> final_model/ +save base to "gdrive:TD/models/final" +report -> economics.json +notify "TD PIPELINE COMPLETE. Model saved to Google Drive." diff --git a/hugging/td_lang/examples/demo_full.td b/hugging/td_lang/examples/demo_full.td new file mode 100644 index 0000000000000000000000000000000000000000..55fef7369ae5d684b9e01e6d82e81dedef1f458b --- /dev/null +++ b/hugging/td_lang/examples/demo_full.td @@ -0,0 +1,17 @@ +# Full Phase 1 demo with gates and budget +gate { + must_pass = [canary, perplexity, thinking_mode] +} + +budget { + max_gpu_hours = 8 + max_cost = 50.00 + max_tokens = 20000000 + max_experiments = 4 +} + +load "Qwen/Qwen3-VL-8B-Instruct" as base +merge "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B" into base using transport strength 0.5 +heal base lora_r 32 epochs 2 +eval base -> full_eval.json +commit base diff --git a/hugging/td_lang/examples/demo_fuse.td b/hugging/td_lang/examples/demo_fuse.td new file mode 100644 index 0000000000000000000000000000000000000000..a61ca8a625082135a0f5d80925d7777af39de287 --- /dev/null +++ b/hugging/td_lang/examples/demo_fuse.td @@ -0,0 +1,19 @@ +# demo_fuse.td — Easy merge: fuse multiple models in one command +# The entire TD merge strategy in 5 lines + +gate { must_pass = [canary, perplexity, thinking_mode] } +budget { max_gpu_hours = 30 max_cost = 120.00 } + +load "Qwen/Qwen3-VL-8B-Instruct" as base + +# Fuse all 4 donor models in one shot — auto Transport and Merge +fuse ["deepseek-ai/DeepSeek-R1", "MiMo-7B", "meta-llama/Llama-3.1-8B", "tiiuae/Falcon-H1R-7B"] into base + +# Or absorb a single model with custom strength +# absorb "deepseek-ai/DeepSeek-R1" into base strength 0.6 + +heal base lora_r 32 epochs 2 +eval base -> post_fuse_eval.json +commit base if [canary, perplexity, thinking_mode] +snapshot base -> snapshots/ +report -> economics.json diff --git a/hugging/td_lang/examples/demo_heal.td b/hugging/td_lang/examples/demo_heal.td new file mode 100644 index 0000000000000000000000000000000000000000..5cf189d73cb12b3e12183d34b56871437a8c3f65 --- /dev/null +++ b/hugging/td_lang/examples/demo_heal.td @@ -0,0 +1,6 @@ +# Demo: merge then heal, evaluate, and commit with gates +load "Qwen/Qwen3-VL-8B-Instruct" as base +merge "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B" into base using transport strength 0.5 +heal base lora_r 32 epochs 2 +eval base -> report.json +commit base if [canary, perplexity, thinking_mode] diff --git a/hugging/td_lang/examples/demo_loop.td b/hugging/td_lang/examples/demo_loop.td new file mode 100644 index 0000000000000000000000000000000000000000..248e75d49145d479d95bb0cdb09b26849aac94d2 --- /dev/null +++ b/hugging/td_lang/examples/demo_loop.td @@ -0,0 +1,28 @@ +# demo_loop.td — Self-improvement loop (Phase 2) +# The core TD cycle: diagnose -> synth -> train -> evaluate -> commit + +gate { + must_pass = [canary, perplexity, thinking_mode] +} + +budget { + max_gpu_hours = 10 + max_cost = 40.00 +} + +load "Qwen/Qwen3-VL-8B-Instruct" as base + +# Step 1: Ask the model what it's bad at +diagnose base -> weaknesses.json + +# Step 2: Generate training data targeting those weaknesses +synth base from web_curated filter cherry_llm -> synth_data.jsonl + +# Step 3: Train with GRPO (64 steps = sweet spot from test_15) +train base on "synth_data.jsonl" using grpo steps 64 + +# Step 4: Check if it actually got better +eval base -> post_training_eval.json + +# Step 5: Only save if gates pass +commit base diff --git a/hugging/td_lang/examples/demo_merge.td b/hugging/td_lang/examples/demo_merge.td new file mode 100644 index 0000000000000000000000000000000000000000..2e9fec2a24d048c4137c38dec6da4426a88016d2 --- /dev/null +++ b/hugging/td_lang/examples/demo_merge.td @@ -0,0 +1,5 @@ +# Demo: load + merge + eval + commit +load "Qwen/Qwen3-VL-8B-Instruct" as base +merge "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B" into base using transport strength 0.5 +eval base -> eval_base.json +commit base if [canary, perplexity, thinking_mode] diff --git a/hugging/td_lang/examples/demo_phase3.td b/hugging/td_lang/examples/demo_phase3.td new file mode 100644 index 0000000000000000000000000000000000000000..816d33848c95f2169c21c03ee8f9ef82b9ac9b16 --- /dev/null +++ b/hugging/td_lang/examples/demo_phase3.td @@ -0,0 +1,26 @@ +# demo_phase3.td — Phase 3 commands: edit, fork, reset, prune +# The full surgical toolkit for model experimentation + +gate { + must_pass = [canary, perplexity, thinking_mode] +} + +budget { + max_gpu_hours = 12 + max_cost = 60.00 +} + +# Load the base model +load "Qwen/Qwen3-VL-8B-Instruct" as base + +# Fork before experimenting (like git branch) +fork base as experiment + +# Surgical edit: LoRA on reasoning layers 16-28 +edit experiment layers 16-28 using lora lr 1e-4 + +# Evaluate the edit +eval experiment -> post_edit_eval.json + +# If it's good, commit; if bad, we can reset +commit experiment diff --git a/hugging/td_lang/examples/demo_phase4.td b/hugging/td_lang/examples/demo_phase4.td new file mode 100644 index 0000000000000000000000000000000000000000..8a6391155aedebede9dceefc5c6a040a748a4573 --- /dev/null +++ b/hugging/td_lang/examples/demo_phase4.td @@ -0,0 +1,33 @@ +# demo_phase4.td — Phase 4: Contracts, Lineage, Economics +# ForgeSpec 2.0 features from test_17 + +gate { must_pass = [canary, perplexity, thinking_mode] } + +budget { + max_gpu_hours = 20 + max_cost = 100.00 +} + +data_contract { + required_fields = [prompt, response] + min_samples = 100 + max_perplexity = 50.0 +} + +reward_contract { + verifiers = [code_compiles, math_correct] + min_reward = 0.3 +} + +# Pipeline with full tracking +load "Qwen/Qwen3-VL-8B-Instruct" as base +fork base as experiment + +edit experiment layers 16-28 using lora lr 1e-4 +snapshot experiment -> snapshots/ + +eval experiment -> post_edit_eval.json +commit experiment + +# Economics report at the end +report -> economics.json diff --git a/hugging/td_lang/examples/demo_td_loop.td b/hugging/td_lang/examples/demo_td_loop.td new file mode 100644 index 0000000000000000000000000000000000000000..4680509b554117683501adb91668b8849be948d0 --- /dev/null +++ b/hugging/td_lang/examples/demo_td_loop.td @@ -0,0 +1,44 @@ +# demo_td_loop.td — The complete TD self-improvement pipeline +# This is what td_loop runs: merge, then iterate to get smarter + +gate { must_pass = [canary, perplexity, thinking_mode] } +budget { max_gpu_hours = 50 max_cost = 200.00 } + +data_contract { + required_fields = [prompt, response] + min_samples = 50 + max_perplexity = 50.0 +} + +reward_contract { + verifiers = [code_compiles, math_correct] + min_reward = 0.3 +} + +# Step 1: Load base model +load "Qwen/Qwen3-VL-8B-Instruct" as base + +# Step 2: Fuse all donor models in one shot +fuse ["deepseek-ai/DeepSeek-R1", "MiMo-7B", "meta-llama/Llama-3.1-8B", "tiiuae/Falcon-H1R-7B"] into base + +# Step 3: Heal the merge damage +heal base lora_r 32 epochs 2 +snapshot base -> snapshots/ + +# Step 4: Self-improvement loop (the core of TD) +repeat 5 { + diagnose base -> weaknesses.json + synth base from base filter cherry_llm -> training_data.jsonl + train base on "training_data.jsonl" using grpo steps 64 lr 5e-5 + eval base -> eval_results.json + + if eval_passed base { + commit base + snapshot base -> snapshots/ + } else { + reset base to "snapshots/" + } +} + +# Step 5: Final report +report -> final_economics.json diff --git a/hugging/td_lang/examples/err_edit_unloaded.td b/hugging/td_lang/examples/err_edit_unloaded.td new file mode 100644 index 0000000000000000000000000000000000000000..54a50552f92306a898283a96ce61ffd15f6012fd --- /dev/null +++ b/hugging/td_lang/examples/err_edit_unloaded.td @@ -0,0 +1,2 @@ +# err_edit_unloaded.td — Should fail: editing a model before loading +edit ghost_model layers all using lora diff --git a/hugging/td_lang/examples/err_fork_duplicate.td b/hugging/td_lang/examples/err_fork_duplicate.td new file mode 100644 index 0000000000000000000000000000000000000000..a869d0195adc565e29dd759d5559b59f7a643eff --- /dev/null +++ b/hugging/td_lang/examples/err_fork_duplicate.td @@ -0,0 +1,3 @@ +# err_fork_duplicate.td — Should fail: duplicate name +load "test" as base +fork base as base diff --git a/hugging/td_lang/examples/err_prune_100.td b/hugging/td_lang/examples/err_prune_100.td new file mode 100644 index 0000000000000000000000000000000000000000..7d33f44ed907c4c9a43a8edf0d44e4ab7662e587 --- /dev/null +++ b/hugging/td_lang/examples/err_prune_100.td @@ -0,0 +1,4 @@ +# err_prune_100.td — Should fail/warn: prune at 100% +load "test" as base +prune base using wanda aggressiveness 1.0 +# Note: Compiler might cap it at 30% per implementation notes diff --git a/hugging/td_lang/examples/test_fork_edit.td b/hugging/td_lang/examples/test_fork_edit.td new file mode 100644 index 0000000000000000000000000000000000000000..bd359dd5809929015a79ed5cc4fa3176e4d24750 --- /dev/null +++ b/hugging/td_lang/examples/test_fork_edit.td @@ -0,0 +1,12 @@ +# test_fork_edit.td — Test load -> fork -> edit -> eval -> commit + +load "Qwen/Qwen3-VL-8B-Instruct" as base + +# Fork the base model +fork base as experimental_branch + +# Surgical edit with DoRA on specific layers +edit experimental_branch layers 20-28 using dora lr 1e-4 + +eval experimental_branch -> edit_report.json +commit experimental_branch diff --git a/hugging/td_lang/examples/test_fork_reset.td b/hugging/td_lang/examples/test_fork_reset.td new file mode 100644 index 0000000000000000000000000000000000000000..788124673988fe396f6541f4cfdf84e82fc06ef3 --- /dev/null +++ b/hugging/td_lang/examples/test_fork_reset.td @@ -0,0 +1,14 @@ +# test_fork_reset.td — Test fork -> edit -> eval -> reset + +load "Qwen/Qwen3-VL-8B-Instruct" as base + +# Create a checkpoint/fork +fork base as stable_fork + +# Try a risky edit +edit base layers all using lora lr 5e-4 + +eval base -> risky_eval.json + +# Revert base to the stable fork state +reset base to stable_fork diff --git a/hugging/td_lang/examples/test_phase2.td b/hugging/td_lang/examples/test_phase2.td new file mode 100644 index 0000000000000000000000000000000000000000..ad069f1f9ebb011c8bdb538af194f16943a95294 --- /dev/null +++ b/hugging/td_lang/examples/test_phase2.td @@ -0,0 +1,17 @@ +# test_phase2.td — Testing all Phase 2 commands +load "Qwen/Qwen3-VL-8B-Instruct" as base + +# diagnose base -> weaknesses.json — asks the model what it's bad at +diagnose base -> weaknesses.json + +# synth base from web_curated filter cherry_llm -> data.jsonl — generates training data +synth base from web_curated filter cherry_llm -> data.jsonl + +# train base on "data.jsonl" using grpo steps 64 — GRPO training +train base on "data.jsonl" using grpo steps 64 + +# debate base rounds 3 candidates 8 -> pairs.jsonl — persona debate for preference pairs +debate base rounds 3 candidates 8 -> pairs.jsonl + +eval base -> final_eval.json +commit base diff --git a/hugging/td_lang/examples/test_prune_heal.td b/hugging/td_lang/examples/test_prune_heal.td new file mode 100644 index 0000000000000000000000000000000000000000..18e58ebd814cfb223366c42964f6cbb85ef737c4 --- /dev/null +++ b/hugging/td_lang/examples/test_prune_heal.td @@ -0,0 +1,12 @@ +# test_prune_heal.td — Test load -> prune -> heal -> eval -> commit + +load "Qwen/Qwen3-VL-8B-Instruct" as base + +# Structural pruning at 15% using wanda +prune base using wanda aggressiveness 0.15 + +# Heal for recovery after pruning (LoRA r=8 is suggested) +heal base lora_r 8 epochs 1 + +eval base -> prune_recovery_report.json +commit base diff --git a/hugging/td_lang/executor.py b/hugging/td_lang/executor.py new file mode 100644 index 0000000000000000000000000000000000000000..0b7e3b44e7c774fe37a9981910ee480211123863 --- /dev/null +++ b/hugging/td_lang/executor.py @@ -0,0 +1,206 @@ +""" +TD Lang Executor — Runs compiled .td scripts and tracks lineage. + +Two modes: + - compile: Parse .td -> generate .py file (no execution) + - run: Parse .td -> generate .py -> execute it + +All outputs go to td_lang_outputs/_/ + - compiled.py — The generated Python script + - lineage.json — What happened, in what order (artifact tracking) + +Pipeline: .td file -> Parser -> AST -> Compiler -> Python string -> **Executor** +""" + +import ast as python_ast +import hashlib +import json +import subprocess +import sys +from datetime import datetime +from pathlib import Path +from typing import Optional + +from .grammar import parse_td_file, parse_td_string +from .compiler import compile_program +from .ast_nodes import TDProgram +from .errors import TDCompileError, TDLangError + + +# ============================================================================ +# EXECUTOR +# ============================================================================ + +class TDExecutor: + """Execute td_lang programs — compile and optionally run. + + Usage: + executor = TDExecutor() + + # Compile only (check + generate .py) + py_path = executor.compile("demo.td") + + # Compile and run + result = executor.run("demo.td") + + # Just check syntax + executor.check("demo.td") + """ + + def __init__(self, output_dir: str = "td_lang_outputs"): + self.output_dir = Path(output_dir) + + def check(self, td_path: str) -> TDProgram: + """Parse and validate a .td file without compiling or running. + + Args: + td_path: Path to the .td file. + + Returns: + The parsed TDProgram. + + Raises: + TDSyntaxError: If syntax is invalid. + TDCompileError: If semantic validation fails. + """ + print(f"[td_lang] Checking {td_path}...") + program = parse_td_file(td_path) + + # Count what we found + n_commands = len(program.commands) + has_gates = program.gates is not None + has_budget = program.budget is not None + + print(f"[td_lang] OK — {n_commands} commands", end="") + if has_gates: + print(f", gates: {program.gates.must_pass}", end="") + if has_budget: + print(f", budget set", end="") + print() + + return program + + def compile(self, td_path: str) -> Path: + """Parse, validate, and compile a .td file into Python. + + Args: + td_path: Path to the .td file. + + Returns: + Path to the generated .py file. + + Raises: + TDSyntaxError: If syntax is invalid. + TDCompileError: If compilation fails. + """ + print(f"[td_lang] Compiling {td_path}...") + + # Parse + program = parse_td_file(td_path) + + # Compile + python_code = compile_program(program) + + # Validate the generated Python is valid syntax + try: + python_ast.parse(python_code) + except SyntaxError as e: + raise TDCompileError( + f"Generated Python has a syntax error (this is a td_lang bug): {e}", + hint="Please report this — the compiler generated bad code.", + ) from e + + # Save to output directory + out_dir = self._make_output_dir(td_path) + py_path = out_dir / "compiled.py" + py_path.write_text(python_code) + + # Save source hash for lineage + source_text = Path(td_path).read_text() + meta = { + "source_file": str(td_path), + "source_hash": hashlib.sha256(source_text.encode()).hexdigest(), + "compiled_at": datetime.now().isoformat(), + "td_lang_version": "0.2.0", + "python_file": str(py_path), + "n_commands": len(program.commands), + "has_gates": program.gates is not None, + "has_budget": program.budget is not None, + } + meta_path = out_dir / "compile_meta.json" + meta_path.write_text(json.dumps(meta, indent=2)) + + print(f"[td_lang] Compiled to: {py_path}") + return py_path + + def run(self, td_path: str, dry_run: bool = False) -> dict: + """Parse, compile, and execute a .td file. + + Args: + td_path: Path to the .td file. + dry_run: If True, compile but don't execute. + + Returns: + Dict with execution results. + """ + # Compile first + py_path = self.compile(td_path) + + if dry_run: + print("[td_lang] Dry run — compiled but not executed.") + return {"status": "dry_run", "compiled": str(py_path)} + + # Execute the generated Python script + print(f"[td_lang] Executing {py_path}...") + print() + + try: + result = subprocess.run( + [sys.executable, str(py_path)], + capture_output=False, # Let output stream to console + cwd=str(py_path.parent), # Run from output directory + ) + + if result.returncode == 0: + print() + print("[td_lang] Execution completed successfully.") + return {"status": "success", "compiled": str(py_path)} + else: + print() + print(f"[td_lang] Execution failed (exit code {result.returncode}).") + return { + "status": "failed", + "compiled": str(py_path), + "exit_code": result.returncode, + } + + except Exception as e: + print(f"\n[td_lang] Execution error: {e}") + return {"status": "error", "compiled": str(py_path), "error": str(e)} + + def _make_output_dir(self, td_path: str) -> Path: + """Create a timestamped output directory for this run.""" + name = Path(td_path).stem + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + out_dir = self.output_dir / f"{name}_{timestamp}" + out_dir.mkdir(parents=True, exist_ok=True) + return out_dir + + +# ============================================================================ +# PUBLIC API +# ============================================================================ + +def check_td_file(td_path: str) -> TDProgram: + """Quick syntax check on a .td file.""" + return TDExecutor().check(td_path) + + +def compile_td_file(td_path: str, output_dir: str = "td_lang_outputs") -> Path: + """Compile a .td file to Python.""" + return TDExecutor(output_dir=output_dir).compile(td_path) + + +def run_td_file(td_path: str, output_dir: str = "td_lang_outputs", dry_run: bool = False) -> dict: + """Compile and run a .td file.""" + return TDExecutor(output_dir=output_dir).run(td_path, dry_run=dry_run) diff --git a/hugging/td_lang/grammar.py b/hugging/td_lang/grammar.py new file mode 100644 index 0000000000000000000000000000000000000000..da57533cdad3a33067406355dde313d6e26dfad6 --- /dev/null +++ b/hugging/td_lang/grammar.py @@ -0,0 +1,736 @@ +""" +TD Lang Grammar — Lark parser for .td files. + +Defines the syntax for Phase 1 commands (load, merge, heal, eval, commit) +plus gate/budget blocks. Phase 2 commands are parsed into stub nodes so the +compiler can reject them with a clear error until implemented. +""" + +from lark import Lark, Token, Transformer, UnexpectedInput, v_args + +from .ast_nodes import ( + AbsorbCmd, + BudgetBlock, + CommitCmd, + DataContractBlock, + DebateCmd, + DiagnoseCmd, + EditCmd, + EvalCmd, + FuseCmd, + ForkCmd, + GateBlock, + HealCmd, + IfBlock, + LoadCmd, + MergeCmd, + NotifyCmd, + OnErrorBlock, + PruneCmd, + RepeatBlock, + ReportCmd, + ResetCmd, + RewardContractBlock, + SaveCmd, + SetupBlock, + SnapshotCmd, + SynthCmd, + TDProgram, + TrainCmd, +) +from .errors import TDSyntaxError + + +# ============================================================================ +# LARK GRAMMAR DEFINITION +# ============================================================================ + +TD_GRAMMAR = r""" + // TD Lang Grammar v0.1.0 + // One command per line, blocks with curly braces, comments with # + + start: (_NL* statement _NL*)* _NL* + + ?statement: load_cmd + | merge_cmd + | heal_cmd + | eval_cmd + | commit_cmd + | synth_cmd + | train_cmd + | debate_cmd + | diagnose_cmd + | fork_cmd + | reset_cmd + | prune_cmd + | edit_cmd + | fuse_cmd + | absorb_cmd + | repeat_block_cmd + | if_block_cmd + | snapshot_cmd + | report_cmd + | notify_cmd + | save_cmd + | gate_block + | budget_block + | data_contract_block + | reward_contract_block + | setup_block + | on_error_block + + // ======================== PHASE 1 COMMANDS ======================== + + // load "model/path" as alias + load_cmd: "load" string "as" IDENT + + // merge "source" into target using method [strength 0.5] + merge_cmd: "merge" string "into" IDENT "using" IDENT (merge_strength)? + merge_strength: "strength" NUMBER + + // heal target [lora_r 32] [epochs 2] + heal_cmd: "heal" IDENT (heal_opt)* + heal_opt: "lora_r" INT -> heal_lora_r + | "epochs" INT -> heal_epochs + + // eval target [on "dataset"] [-> output.json] + eval_cmd: "eval" IDENT (eval_on)? (eval_output)? + eval_on: "on" string + eval_output: "->" FILEPATH + + // commit target [if [gate1, gate2, gate3]] + commit_cmd: "commit" IDENT (commit_gates)? + commit_gates: "if" name_list + + // ======================== PHASE 2 COMMANDS ======================== + // (parsed but not compiled yet — will show "not implemented" message) + + // synth target from source [filter cherry_llm] [-> output.jsonl] + synth_cmd: "synth" IDENT "from" IDENT (synth_filter)? (synth_output)? + synth_filter: "filter" IDENT + synth_output: "->" FILEPATH + + // train target on "dataset" using method [steps 100] [lr 0.0001] + train_cmd: "train" IDENT "on" string "using" IDENT (train_opt)* + train_opt: "steps" INT -> train_steps + | "lr" NUMBER -> train_lr + + // debate target rounds 3 candidates 8 [-> output.jsonl] + debate_cmd: "debate" IDENT "rounds" INT "candidates" INT (debate_output)? + debate_output: "->" FILEPATH + + // diagnose target [-> weaknesses.json] + diagnose_cmd: "diagnose" IDENT (diagnose_output)? + diagnose_output: "->" FILEPATH + + // fork source as alias + fork_cmd: "fork" IDENT "as" IDENT + + // reset target to checkpoint_name + reset_cmd: "reset" IDENT "to" (string | IDENT) + + // prune target using method [aggressiveness 0.1] + prune_cmd: "prune" IDENT "using" IDENT (prune_aggr)? + prune_aggr: "aggressiveness" NUMBER + + // edit target layers 16-28 using lora [lr 0.0001] + edit_cmd: "edit" IDENT "layers" LAYER_SPEC "using" IDENT (edit_lr)? + edit_lr: "lr" NUMBER + + // ======================== PHASE 7 — LOOP CONTROL ======================== + + // repeat N { commands... } + repeat_block_cmd: "repeat" INT "{" _NL* body_cmd+ _NL* "}" + // if condition target { commands... } [else { commands... }] + if_block_cmd: "if" IDENT IDENT "{" _NL* body_cmd+ _NL* "}" (else_clause)? + else_clause: "else" "{" _NL* body_cmd+ _NL* "}" + + // Commands allowed inside blocks (same as top-level minus config blocks) + ?body_cmd: (load_cmd | merge_cmd | heal_cmd | eval_cmd | commit_cmd + | synth_cmd | train_cmd | debate_cmd | diagnose_cmd + | fork_cmd | reset_cmd | prune_cmd | edit_cmd + | fuse_cmd | absorb_cmd | snapshot_cmd | report_cmd + | notify_cmd | save_cmd + | repeat_block_cmd | if_block_cmd) _NL* + + // ======================== PHASE 6 — EASY MERGE COMMANDS ======================== + + // fuse [model1, model2, model3] into target [using method] [strategy equal|weighted|sequential] + fuse_cmd: "fuse" model_list "into" IDENT (fuse_method)? (fuse_strategy)? + model_list: "[" string ("," string)* "]" + fuse_method: "using" IDENT + fuse_strategy: "strategy" IDENT + + // absorb "model" into target [strength 0.5] + absorb_cmd: "absorb" string "into" IDENT (absorb_strength)? + absorb_strength: "strength" NUMBER + + // ======================== PHASE 4 COMMANDS ======================== + + // snapshot target [-> output_dir] + snapshot_cmd: "snapshot" IDENT (snapshot_output)? + snapshot_output: "->" FILEPATH + + // report [-> economics.json] + report_cmd: "report" (report_output)? + report_output: "->" FILEPATH + + // ======================== BLOCKS ======================== + + // gate { must_pass = [canary, perplexity, thinking_mode] } + gate_block: "gate" "{" _NL* gate_field+ _NL* "}" + gate_field: "must_pass" "=" name_list _NL* + + // budget { max_gpu_hours = 8 \n max_cost = 50.00 } + budget_block: "budget" "{" _NL* budget_field+ _NL* "}" + budget_field: (budget_gpu | budget_cost | budget_tokens | budget_experiments) _NL* + budget_gpu: "max_gpu_hours" "=" NUMBER + budget_cost: "max_cost" "=" NUMBER + budget_tokens: "max_tokens" "=" INT + budget_experiments: "max_experiments" "=" INT + + // data_contract { required_fields = [prompt, response] \n min_samples = 100 \n max_perplexity = 50.0 } + data_contract_block: "data_contract" "{" _NL* dc_field+ _NL* "}" + dc_field: (dc_required | dc_min_samples | dc_max_ppl) _NL* + dc_required: "required_fields" "=" name_list + dc_min_samples: "min_samples" "=" INT + dc_max_ppl: "max_perplexity" "=" NUMBER + + // reward_contract { verifiers = [code_compiles, math_correct] \n min_reward = 0.3 } + reward_contract_block: "reward_contract" "{" _NL* rc_field+ _NL* "}" + rc_field: (rc_verifiers | rc_min_reward) _NL* + rc_verifiers: "verifiers" "=" name_list + rc_min_reward: "min_reward" "=" NUMBER + + // ======================== PHASE 8 — AUTOPILOT ======================== + + // notify "Training complete!" + notify_cmd: "notify" string + + // save target to "gdrive:TD/models/v1" + save_cmd: "save" IDENT "to" string + + // setup { pip = [torch, transformers] hf_token = env notify = "ntfy.sh/my_ai" } + setup_block: "setup" "{" _NL* setup_field+ _NL* "}" + setup_field: (setup_pip | setup_hf | setup_notify) _NL* + setup_pip: "pip" "=" name_list + setup_hf: "hf_token" "=" IDENT + setup_notify: "notify" "=" string + + // on_error { retry = 3 fallback = reduce_batch notify = true } + on_error_block: "on_error" "{" _NL* on_error_field+ _NL* "}" + on_error_field: (onerr_retry | onerr_fallback | onerr_notify) _NL* + onerr_retry: "retry" "=" INT + onerr_fallback: "fallback" "=" IDENT + onerr_notify: "notify" "=" IDENT + + // ======================== SHARED RULES ======================== + + // List of names: [name1, name2, name3] + name_list: "[" IDENT ("," IDENT)* "]" + + // String: double-quoted + string: ESCAPED_STRING + + // Layer spec: "all", single number, or range like "16-28" + LAYER_SPEC: /all|[0-9]+-[0-9]+|[0-9]+/ + + // Filepath: word with dots, slashes, underscores (no spaces) + FILEPATH: /[a-zA-Z0-9_.\-\/]+/ + + // Identifier: letters, numbers, underscores, hyphens (but starts with letter/underscore) + IDENT: /[a-zA-Z_][a-zA-Z0-9_\-]*/ + + // Numbers + NUMBER: /\d+\.?\d*([eE][+-]?\d+)?/ + INT: /\d+/ + + // Whitespace and comments + _NL: /\s*/ NEWLINE /\s*/ + COMMENT: /#[^\n]*/ + %import common.ESCAPED_STRING + %import common.NEWLINE + %import common.WS_INLINE + %ignore WS_INLINE + %ignore COMMENT +""" + + +# ============================================================================ +# LARK TRANSFORMER — Parse Tree → AST Nodes +# ============================================================================ + +@v_args(inline=True) +class TDTransformer(Transformer): + """Transforms Lark parse tree into td_lang AST nodes. + + Each method matches a grammar rule name and returns the corresponding + dataclass from ast_nodes.py. + """ + + # --- Helpers --- + + def string(self, s: Token) -> str: + """Strip quotes from a string token.""" + return str(s)[1:-1] + + def name_list(self, *names: Token) -> list[str]: + """Convert name list tokens to Python list of strings.""" + return [str(n) for n in names] + + def IDENT(self, token: Token) -> str: + return str(token) + + def INT(self, token: Token) -> int: + return int(token) + + def NUMBER(self, token: Token) -> float: + return float(token) + + def FILEPATH(self, token: Token) -> str: + return str(token) + + def LAYER_SPEC(self, token: Token) -> str: + return str(token) + + # --- Phase 1 Commands --- + + def load_cmd(self, model_ref: str, alias: str) -> LoadCmd: + return LoadCmd(model_ref=model_ref, alias=alias) + + def merge_cmd(self, source: str, target: str, method: str, + strength: float | None = None) -> MergeCmd: + return MergeCmd( + source=source, + target=target, + method=method, + strength=strength if strength is not None else 0.5, + ) + + def merge_strength(self, value: float) -> float: + return value + + def heal_cmd(self, target: str, *opts) -> HealCmd: + cmd = HealCmd(target=target) + for opt in opts: + if isinstance(opt, tuple): + key, val = opt + if key == "lora_r": + cmd.lora_r = val + elif key == "epochs": + cmd.epochs = val + return cmd + + def heal_lora_r(self, value: int) -> tuple: + return ("lora_r", value) + + def heal_epochs(self, value: int) -> tuple: + return ("epochs", value) + + def eval_cmd(self, target: str, *opts) -> EvalCmd: + cmd = EvalCmd(target=target) + for opt in opts: + if isinstance(opt, tuple): + key, val = opt + if key == "on": + cmd.dataset = val + elif key == "output": + cmd.output = val + return cmd + + def eval_on(self, dataset: str) -> tuple: + return ("on", dataset) + + def eval_output(self, filepath: str) -> tuple: + return ("output", filepath) + + def commit_cmd(self, target: str, gates: list[str] | None = None) -> CommitCmd: + return CommitCmd(target=target, gates=gates) + + def commit_gates(self, gates: list[str]) -> list[str]: + return gates + + # --- Phase 2 Commands --- + + def synth_cmd(self, target: str, source: str, *opts) -> SynthCmd: + cmd = SynthCmd(target=target, source=source) + for opt in opts: + if isinstance(opt, tuple): + key, val = opt + if key == "filter": + cmd.filter_method = val + elif key == "output": + cmd.output = val + return cmd + + def synth_filter(self, method: str) -> tuple: + return ("filter", method) + + def synth_output(self, filepath: str) -> tuple: + return ("output", filepath) + + def train_cmd(self, target: str, dataset: str, method: str, *opts) -> TrainCmd: + cmd = TrainCmd(target=target, dataset=dataset, method=method) + for opt in opts: + if isinstance(opt, tuple): + key, val = opt + if key == "steps": + cmd.steps = val + elif key == "lr": + cmd.learning_rate = val + return cmd + + def train_steps(self, value: int) -> tuple: + return ("steps", value) + + def train_lr(self, value: float) -> tuple: + return ("lr", value) + + def debate_cmd(self, target: str, rounds: int, candidates: int, + output: tuple | None = None) -> DebateCmd: + cmd = DebateCmd(target=target, rounds=rounds, candidates=candidates) + if isinstance(output, tuple) and output[0] == "output": + cmd.output = output[1] + return cmd + + def debate_output(self, filepath: str) -> tuple: + return ("output", filepath) + + def diagnose_cmd(self, target: str, output: tuple | None = None) -> DiagnoseCmd: + cmd = DiagnoseCmd(target=target) + if isinstance(output, tuple) and output[0] == "output": + cmd.output = output[1] + return cmd + + def diagnose_output(self, filepath: str) -> tuple: + return ("output", filepath) + + def fork_cmd(self, source: str, alias: str) -> ForkCmd: + return ForkCmd(source=source, alias=alias) + + def reset_cmd(self, target: str, checkpoint) -> ResetCmd: + return ResetCmd(target=target, checkpoint=str(checkpoint)) + + def prune_cmd(self, target: str, method: str, + aggressiveness: float | None = None) -> PruneCmd: + return PruneCmd( + target=target, + method=method, + aggressiveness=aggressiveness if aggressiveness is not None else 0.1, + ) + + def prune_aggr(self, value: float) -> float: + return value + + def edit_cmd(self, target: str, layers: str, method: str, + lr: float | None = None) -> EditCmd: + return EditCmd( + target=target, + layers=layers, + method=method, + learning_rate=lr, + ) + + def edit_lr(self, value: float) -> float: + return value + + # --- Phase 7: Loop Control --- + + def repeat_block_cmd(self, count: int, *body_cmds) -> RepeatBlock: + return RepeatBlock(count=count, body=list(body_cmds)) + + def if_block_cmd(self, condition: str, target: str, *rest) -> IfBlock: + """Parse if condition target { then... } [else { else... }]""" + block = IfBlock(condition=condition, target=target) + # rest contains then_body commands + possibly an else list + for item in rest: + if isinstance(item, list) and item and hasattr(item, '__iter__'): + # This is the else body (passed from else_clause) + block.else_body = item + else: + block.then_body.append(item) + return block + + def else_clause(self, *body_cmds) -> list: + return list(body_cmds) + + # --- Phase 6: Easy Merge Commands --- + + def fuse_cmd(self, sources: list[str], target: str, *opts) -> FuseCmd: + cmd = FuseCmd(sources=sources, target=target) + for opt in opts: + if isinstance(opt, tuple): + key, val = opt + if key == "method": + cmd.method = val + elif key == "strategy": + cmd.strategy = val + return cmd + + def model_list(self, *models: str) -> list[str]: + return [str(m) for m in models] + + def fuse_method(self, method: str) -> tuple: + return ("method", method) + + def fuse_strategy(self, strategy: str) -> tuple: + return ("strategy", strategy) + + def absorb_cmd(self, source: str, target: str, + strength: float | None = None) -> AbsorbCmd: + return AbsorbCmd( + source=source, + target=target, + strength=strength if strength is not None else 0.5, + ) + + def absorb_strength(self, value: float) -> float: + return value + + # --- Phase 4 Commands --- + + def snapshot_cmd(self, target: str, output: tuple | None = None) -> SnapshotCmd: + cmd = SnapshotCmd(target=target) + if isinstance(output, tuple) and output[0] == "output": + cmd.output = output[1] + return cmd + + def snapshot_output(self, filepath: str) -> tuple: + return ("output", filepath) + + def report_cmd(self, output: tuple | None = None) -> ReportCmd: + cmd = ReportCmd() + if isinstance(output, tuple) and output[0] == "output": + cmd.output = output[1] + return cmd + + def report_output(self, filepath: str) -> tuple: + return ("output", filepath) + + # --- Blocks --- + + def gate_block(self, *fields) -> GateBlock: + gate = GateBlock() + for f in fields: + if isinstance(f, list): + gate.must_pass = f + return gate + + def gate_field(self, names: list[str]) -> list[str]: + return names + + def budget_block(self, *fields) -> BudgetBlock: + budget = BudgetBlock() + for f in fields: + if isinstance(f, tuple): + key, val = f + if key == "max_gpu_hours": + budget.max_gpu_hours = val + elif key == "max_cost": + budget.max_cost = val + elif key == "max_tokens": + budget.max_tokens = int(val) + elif key == "max_experiments": + budget.max_experiments = int(val) + return budget + + def budget_field(self, field_data) -> tuple: + return field_data + + def budget_gpu(self, value: float) -> tuple: + return ("max_gpu_hours", value) + + def budget_cost(self, value: float) -> tuple: + return ("max_cost", value) + + def budget_tokens(self, value: int) -> tuple: + return ("max_tokens", value) + + def budget_experiments(self, value: int) -> tuple: + return ("max_experiments", value) + + # --- Phase 8: Autopilot Commands --- + + def notify_cmd(self, message: str) -> NotifyCmd: + return NotifyCmd(message=message) + + def save_cmd(self, target: str, destination: str) -> SaveCmd: + return SaveCmd(target=target, destination=destination) + + def setup_block(self, *fields) -> SetupBlock: + sb = SetupBlock() + for f in fields: + if isinstance(f, tuple): + key, val = f + if key == "pip": + sb.pip_packages = val + elif key == "hf_token": + sb.hf_token = val + elif key == "notify": + sb.notify_url = val + return sb + + def setup_field(self, field_data) -> tuple: + return field_data + + def setup_pip(self, packages: list[str]) -> tuple: + return ("pip", packages) + + def setup_hf(self, mode: str) -> tuple: + return ("hf_token", mode) + + def setup_notify(self, url: str) -> tuple: + return ("notify", url) + + def on_error_block(self, *fields) -> OnErrorBlock: + oe = OnErrorBlock() + for f in fields: + if isinstance(f, tuple): + key, val = f + if key == "retry": + oe.retry = int(val) + elif key == "fallback": + oe.fallback = val + elif key == "notify": + oe.notify = str(val).lower() == "true" + return oe + + def on_error_field(self, field_data) -> tuple: + return field_data + + def onerr_retry(self, value: int) -> tuple: + return ("retry", value) + + def onerr_fallback(self, value: str) -> tuple: + return ("fallback", value) + + def onerr_notify(self, value: str) -> tuple: + return ("notify", value) + + # --- Contract Blocks (Phase 4) --- + + def data_contract_block(self, *fields) -> DataContractBlock: + dc = DataContractBlock() + for f in fields: + if isinstance(f, tuple): + key, val = f + if key == "required_fields": + dc.required_fields = val + elif key == "min_samples": + dc.min_samples = int(val) + elif key == "max_perplexity": + dc.max_perplexity = val + return dc + + def dc_field(self, field_data) -> tuple: + return field_data + + def dc_required(self, names: list[str]) -> tuple: + return ("required_fields", names) + + def dc_min_samples(self, value: int) -> tuple: + return ("min_samples", value) + + def dc_max_ppl(self, value: float) -> tuple: + return ("max_perplexity", value) + + def reward_contract_block(self, *fields) -> RewardContractBlock: + rc = RewardContractBlock() + for f in fields: + if isinstance(f, tuple): + key, val = f + if key == "verifiers": + rc.verifiers = val + elif key == "min_reward": + rc.min_reward = val + return rc + + def rc_field(self, field_data) -> tuple: + return field_data + + def rc_verifiers(self, names: list[str]) -> tuple: + return ("verifiers", names) + + def rc_min_reward(self, value: float) -> tuple: + return ("min_reward", value) + + # --- Top Level --- + + def start(self, *items) -> TDProgram: + """Collect all parsed commands and blocks into a TDProgram.""" + program = TDProgram() + for item in items: + if item is None: + continue + if isinstance(item, GateBlock): + program.gates = item + elif isinstance(item, BudgetBlock): + program.budget = item + elif isinstance(item, DataContractBlock): + program.data_contract = item + elif isinstance(item, RewardContractBlock): + program.reward_contract = item + elif isinstance(item, SetupBlock): + program.setup = item + elif isinstance(item, OnErrorBlock): + program.on_error = item + else: + program.commands.append(item) + return program + + +# ============================================================================ +# PUBLIC API +# ============================================================================ + +# Create the parser once — reuse for all files +_parser = Lark( + TD_GRAMMAR, + parser="earley", + propagate_positions=True, +) + +_transformer = TDTransformer() + + +def parse_td_string(source: str) -> TDProgram: + """Parse a .td source string into a TDProgram AST. + + Args: + source: The .td file content as a string. + + Returns: + TDProgram with all commands and blocks. + + Raises: + TDSyntaxError: If the source has invalid syntax. + """ + try: + tree = _parser.parse(source) + return _transformer.transform(tree) + except UnexpectedInput as e: + raise TDSyntaxError( + message=f"Unexpected {e.token!r}" if hasattr(e, "token") else str(e), + line=getattr(e, "line", None), + hint="Check for typos or missing quotes around model paths.", + ) from e + + +def parse_td_file(filepath: str) -> TDProgram: + """Parse a .td file into a TDProgram AST. + + Args: + filepath: Path to the .td file. + + Returns: + TDProgram with all commands and blocks. + + Raises: + TDSyntaxError: If the file has invalid syntax. + FileNotFoundError: If the file doesn't exist. + """ + with open(filepath, "r") as f: + source = f.read() + program = parse_td_string(source) + program.source_file = filepath + return program diff --git a/hugging/upload_to_hf.sh b/hugging/upload_to_hf.sh new file mode 100644 index 0000000000000000000000000000000000000000..16c512a2df066fe1c8851fbac23070449d76f073 --- /dev/null +++ b/hugging/upload_to_hf.sh @@ -0,0 +1,137 @@ +#!/bin/bash +# upload_to_hf.sh — Push td_lang + td_fuse + .td files to a PRIVATE HuggingFace repo +# +# First time setup: +# 1. Go to huggingface.co → Sign up (use any username you want, not your real name) +# 2. Settings → Access Tokens → New Token → "Write" access +# 3. Run: pip install huggingface_hub +# 4. Run: huggingface-cli login (paste your token) +# 5. Run: bash upload_to_hf.sh +# +# After first time, just run: bash upload_to_hf.sh +# It updates the repo with any changes. + +set -e + +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +RED='\033[0;31m' +NC='\033[0m' + +# ---- CHANGE THIS to your HuggingFace username ---- +HF_USER="${HF_USER:-your_username_here}" +REPO_NAME="td-toolkit" +# --------------------------------------------------- + +REPO_ID="$HF_USER/$REPO_NAME" + +echo "" +echo "===========================================" +echo " TD Upload — Push to Private HuggingFace" +echo "===========================================" +echo "" + +# Check if huggingface-cli is available +if ! command -v huggingface-cli &> /dev/null; then + echo -e "${YELLOW}Installing huggingface_hub...${NC}" + pip install huggingface_hub --quiet 2>/dev/null || pip install huggingface_hub --break-system-packages --quiet +fi + +# Check login +if ! huggingface-cli whoami &> /dev/null 2>&1; then + echo -e "${YELLOW}Not logged in to HuggingFace.${NC}" + echo "Run: huggingface-cli login" + exit 1 +fi + +CURRENT_USER=$(huggingface-cli whoami 2>/dev/null | head -1) +echo -e "Logged in as: ${GREEN}$CURRENT_USER${NC}" + +if [ "$HF_USER" = "your_username_here" ]; then + echo "" + echo -e "${YELLOW}You need to set your HF username first.${NC}" + echo "Either:" + echo " 1. Edit this script and change HF_USER at the top" + echo " 2. Or run: HF_USER=$CURRENT_USER bash upload_to_hf.sh" + exit 1 +fi + +echo "Repo: $REPO_ID (private)" +echo "" + +# Create the repo if it doesn't exist (private by default) +echo -e "${GREEN}[1/3]${NC} Creating/checking repo..." +python3 -c " +from huggingface_hub import HfApi +api = HfApi() +try: + api.create_repo('$REPO_ID', repo_type='model', private=True) + print(' Created new private repo: $REPO_ID') +except Exception as e: + if 'already' in str(e).lower() or '409' in str(e): + print(' Repo exists: $REPO_ID') + else: + raise e +" + +# Upload all the important files +echo -e "${GREEN}[2/3]${NC} Uploading files..." +python3 << 'PYEOF' +from huggingface_hub import HfApi +import os + +api = HfApi() +repo_id = os.environ.get("REPO_ID", "REPO_ID_MISSING") + +# Files and folders to upload +uploads = [] + +# td_lang package +for root, dirs, files in os.walk("td_lang"): + # Skip __pycache__ + if "__pycache__" in root: + continue + for f in files: + if f.endswith((".py", ".td", ".lark", ".md")): + local = os.path.join(root, f) + uploads.append(local) + +# td_fuse package (if it exists) +if os.path.isdir("td_fuse"): + for root, dirs, files in os.walk("td_fuse"): + if "__pycache__" in root: + continue + for f in files: + if f.endswith((".py", ".td", ".lark", ".md", ".json")): + local = os.path.join(root, f) + uploads.append(local) + +# Root-level .td files and deploy script +for f in os.listdir("."): + if f.endswith(".td") or f in ("deploy.sh", "QUICKSTART.md"): + uploads.append(f) + +print(f" Uploading {len(uploads)} files...") +for path in sorted(uploads): + print(f" {path}") + api.upload_file( + path_or_fileobj=path, + path_in_repo=path, + repo_id=repo_id, + repo_type="model", + ) + +print(f"\n Done! {len(uploads)} files uploaded.") +PYEOF + +echo "" +echo -e "${GREEN}[3/3]${NC} Upload complete!" +echo "" +echo "===========================================" +echo " Your private repo: https://huggingface.co/$REPO_ID" +echo "" +echo " On any GPU, download everything with:" +echo " export HF_TOKEN=your_token" +echo " git clone https://\$HF_TOKEN@huggingface.co/$REPO_ID td_toolkit" +echo " cd td_toolkit && bash deploy.sh demo_autopilot.td" +echo "==========================================="