# td-toolkit / td_fuse / config.py
# Uploaded by td-builder: "Current td_fuse code with all fixes" (commit bc446a5, verified)
"""
TD Fuse configuration — all 5 models (1 target + 4 sources), merge order, and hyperparameters.
Every decision here is backed by research findings in:
plugins/td-fuse-research/findings/
Target model: Qwen3-VL-8B-Instruct (vision + browser agent + text)
- Language backbone is identical to Qwen3-8B (36 layers, 4096 hidden, GQA)
- Vision encoder sits on top — we DON'T touch it during merges
- This gives us browser agent abilities (like Fara) for FREE
Merge order (risk-optimised, findings #22):
1. DeepSeek-R1-0528 → Qwen3-VL-8B (same arch, LOW risk)
2. MiMo-7B-RL → Merged_1 (drop MTP, MEDIUM risk)
3. Llama-3.1-8B → Merged_2 (skip embeddings, MEDIUM risk)
4. Falcon-H1R-7B → Merged_3 (SSM hybrid, HIGH risk)
"""
from dataclasses import dataclass, field
from typing import Optional
from pathlib import Path
# ============================================================================
# MODEL DEFINITIONS
# ============================================================================
@dataclass
class ModelConfig:
    """Configuration for a single model in the merge pipeline.

    One instance describes either the merge target or a source model that is
    fused into it. ``__post_init__`` sanity-checks the numeric and label
    fields so a typo in a hand-written entry fails loudly at construction
    time instead of silently skewing a merge.

    Raises:
        ValueError: if a fraction field (``vocab_overlap_with_qwen3``,
            ``merge_alpha``) lies outside ``[0, 1]``, a size field is not
            positive, ``num_heads`` is not a multiple of ``num_kv_heads``,
            or ``merge_risk`` is not one of the documented labels.
    """

    name: str
    hf_id: str  # HuggingFace model ID
    architecture: str  # e.g. "transformer", "transformer+mtp", "transformer+vision", "hybrid_ssm"
    layers: int
    hidden_dim: int
    num_heads: int
    num_kv_heads: int
    vocab_size: int
    vocab_overlap_with_qwen3: float  # 0.0 to 1.0
    skip_embeddings: bool  # True if vocab overlap < 50%
    trust_remote_code: bool
    special_handling: list = field(default_factory=list)  # Extra steps needed
    merge_risk: str = "low"  # "low", "medium", "high"; "n/a" for the target
    merge_alpha: float = 0.5  # Weight during fusion (0=keep target, 1=keep source)
    notes: str = ""

    def __post_init__(self) -> None:
        """Validate field ranges, raising ValueError on the first violation."""
        for attr in ("vocab_overlap_with_qwen3", "merge_alpha"):
            value = getattr(self, attr)
            # Negated chained comparison so NaN is rejected as well.
            if not 0.0 <= value <= 1.0:
                raise ValueError(
                    f"{self.name}: {attr} must be in [0, 1], got {value!r}"
                )
        for attr in ("layers", "hidden_dim", "num_heads", "num_kv_heads", "vocab_size"):
            value = getattr(self, attr)
            if value <= 0:
                raise ValueError(
                    f"{self.name}: {attr} must be positive, got {value!r}"
                )
        # Grouped-query attention shares each KV head across an equal number
        # of query heads, so the head counts must divide evenly.
        if self.num_heads % self.num_kv_heads:
            raise ValueError(
                f"{self.name}: num_heads ({self.num_heads}) must be a multiple "
                f"of num_kv_heads ({self.num_kv_heads})"
            )
        if self.merge_risk not in {"low", "medium", "high", "n/a"}:
            raise ValueError(
                f"{self.name}: merge_risk must be one of "
                f"'low', 'medium', 'high', 'n/a', got {self.merge_risk!r}"
            )
# Target model — everything merges INTO this.
# Switched from Qwen3-8B to Qwen3-VL-8B: same language backbone, plus a vision
# encoder that brings browser-agent and image-understanding abilities.
TARGET = ModelConfig(
    name="Qwen3-VL-8B",
    hf_id="Qwen/Qwen3-VL-8B-Instruct",
    architecture="transformer+vision",
    # Language-backbone shape: identical to Qwen3-8B.
    layers=36,
    hidden_dim=4096,
    num_heads=32,
    num_kv_heads=8,  # GQA
    # Vocabulary: slightly different from Qwen3-8B (151669).
    # NOTE(review): 151669 may be Qwen3's tokenizer length rather than its
    # config vocab_size (which may also be 151936) — confirm before relying on it.
    vocab_size=151936,
    vocab_overlap_with_qwen3=0.998,  # ~99.8% overlap with the Qwen3-8B vocab
    skip_embeddings=False,
    # Loading / merge policy.
    trust_remote_code=False,
    merge_risk="n/a",  # The target is the merge destination, not a source
    notes=" ".join((
        "Vision-language model.",
        "Language backbone is identical to Qwen3-8B.",
        "Vision encoder (ViT + DeepStack) sits on top — we SKIP it during merges.",
        "This gives us browser agent + vision abilities for free.",
        "Uses SDPA (NOT Flash-Attention-2).",
        "intermediate_size=12288.",
        "Loaded via Qwen3VLForConditionalGeneration.",
    )),
)
# Source models — merged into TARGET in this order (findings #22).
SOURCES = [
    ModelConfig(
        name="DeepSeek-R1-0528",
        hf_id="deepseek-ai/DeepSeek-R1-0528-Qwen3-8B",
        architecture="transformer",
        layers=36,
        hidden_dim=4096,
        num_heads=32,
        num_kv_heads=8,
        vocab_size=152064,  # Slightly different from base Qwen3
        vocab_overlap_with_qwen3=0.999,  # 99.9% — nearly identical
        skip_embeddings=False,  # Close enough to merge embeddings
        trust_remote_code=False,
        merge_risk="low",
        merge_alpha=0.5,
        special_handling=["use_deepseek_tokenizer_config"],
        notes=(
            "IDENTICAL architecture to Qwen3-8B. Easiest merge. "
            "Must use DeepSeek's tokenizer config, not Qwen's. "
            "Stay bfloat16 end-to-end (FP8 degrades quality). "
            "Set repetition_penalty=1.5 (R1 distills are prone to repetition). "
            "Findings: #17"
        ),
    ),
    ModelConfig(
        name="MiMo-7B-RL",
        hf_id="XiaomiMiMo/MiMo-7B-RL",
        architecture="transformer+mtp",
        layers=36,
        hidden_dim=4096,
        num_heads=32,
        num_kv_heads=8,
        # Estimated — LLaMA lineage.
        # NOTE(review): MiMo's published config may instead use a Qwen2-family
        # tokenizer (~151k vocab). If so, re-measure the overlap below and
        # reconsider skip_embeddings before running this stage.
        vocab_size=32000,
        vocab_overlap_with_qwen3=0.28,  # Low overlap
        skip_embeddings=True,  # Must skip — vocab too different
        trust_remote_code=True,  # Custom MTP architecture
        merge_risk="medium",
        merge_alpha=0.15,  # Low — MiMo neurons need permutation, keep target dominant
        special_handling=["drop_mtp_heads", "skip_embeddings"],
        notes=(
            "Xiaomi's reasoning model. Same layer count and hidden dim as Qwen3. "
            "MTP heads (mtp_head_0/1/2) have NO Qwen3 equivalent — must drop. "
            "trust_remote_code=True required for custom modeling_mimo.py. "
            "Findings: #18"
        ),
    ),
    ModelConfig(
        name="Llama-3.1-8B",
        hf_id="unsloth/Llama-3.1-8B-Instruct",
        architecture="transformer",
        layers=32,  # 4 fewer than Qwen3!
        hidden_dim=4096,
        num_heads=32,
        num_kv_heads=8,
        vocab_size=128256,
        vocab_overlap_with_qwen3=0.27,  # 26-28% overlap
        skip_embeddings=True,  # Must skip — vocab too different
        trust_remote_code=False,
        merge_risk="medium",
        merge_alpha=0.08,  # Lower alpha — layer mismatch risk
        # NOTE(review): Llama-3.1's config appears to set attention_bias=false.
        # If it has no QKV bias, "drop_qkv_bias" is a no-op and the note below
        # is inaccurate — verify against the checkpoint's state dict.
        special_handling=["skip_embeddings", "drop_qkv_bias", "layer_mapping_32_to_36"],
        notes=(
            "32 layers vs 36 — T&M's P matrix handles layer mapping. "
            # The Qwen3-VL target's intermediate_size is 12288 (see TARGET notes).
            "FFN intermediate is 14336 vs 12288 — Q matrices handle width. "
            "Has QKV bias (Qwen3 doesn't) — bias params will be dropped. "
            "T&M paper was tested on LLaMA-3 8B — good sign. "
            "Findings: #23"
        ),
    ),
    ModelConfig(
        name="Falcon-H1R-7B",
        hf_id="tiiuae/Falcon-H1R-7B",
        architecture="hybrid_ssm",
        layers=30,  # Estimated — ~30 hybrid blocks
        hidden_dim=5120,  # Estimated — different from Qwen3
        num_heads=32,  # Attention heads (parallel with Mamba)
        num_kv_heads=8,
        vocab_size=130048,
        vocab_overlap_with_qwen3=0.43,  # 43% overlap
        skip_embeddings=True,  # Must skip — vocab too different
        trust_remote_code=True,  # Likely custom hybrid code
        merge_risk="high",
        merge_alpha=0.08,  # Conservative — highest risk model
        special_handling=[
            "skip_embeddings",
            "drop_mamba_state_params",  # A, D matrices have no Qwen3 equivalent
            "check_wasserstein_first",  # Abort if activation alignment is poor
            "distillation_fallback",  # If merge fails, use knowledge distillation
        ],
        notes=(
            "THE WILDCARD. Hybrid Transformer+Mamba2. ~60% of weights have "
            "Qwen3 equivalents. Mamba components (A, D, dt_proj) must be "
            "dropped or mapped via OT. 65-70% merge feasibility. "
            "88.1% AIME24 makes it worth attempting. "
            "Fallback: knowledge distillation (NeurIPS 2024 'Mamba in Llama'). "
            "Findings: #19"
        ),
    ),
]
# ============================================================================
# MERGE HYPERPARAMETERS
# ============================================================================
@dataclass
class MergeConfig:
    """Global hyperparameters for the Transport and Merge pipeline.

    Fields are grouped by pipeline phase: paths, calibration, T&M transport,
    TIES, sequential-merge protection, fallbacks, validation, vision-encoder
    protection, hardware, and the post-merge healing fine-tune.
    ``__post_init__`` rejects values outside the ranges documented on each
    field, so a bad override fails at construction time.

    Raises:
        ValueError: if a documented fraction lies outside ``[0, 1]``,
            ``sinkhorn_reg`` is not positive, or a count/size is not positive.
    """

    # --- Paths ---
    tm_repo_path: str = "./Cross-Architecture-Merging-for-Large-Language-Models"
    output_dir: str = "./td_fuse_outputs"
    checkpoint_dir: str = "./td_fuse_checkpoints"

    # --- Calibration Data (findings #08) ---
    calibration_samples: int = 1500  # 600 Pile general + 300 ArXiv + 600 neuralmagic
    calibration_seq_len: int = 512
    calibration_dataset_pile: str = "EleutherAI/pile"
    calibration_dataset_nm: str = "neuralmagic/LLM_compression_calibration"

    # --- Transport and Merge (findings #01, #24) ---
    sinkhorn_reg: float = 0.05  # Entropic regularisation for Sinkhorn (must be > 0)
    sinkhorn_max_iter: int = 100  # Max Sinkhorn iterations
    correlation_distance: bool = True  # True=correlation (official), False=euclidean
    streaming_sinkhorn: bool = True  # Memory-efficient streaming mode

    # --- TIES Parameters (findings #05, #14) ---
    ties_density: float = 0.7  # k=0.7 (NOT default 0.2 — community finding)
    ties_alpha: float = 0.7  # Validated on R1-Qwen3-8B merges

    # --- Sequential Merge Protection (findings #13 + ARM 2602.03237 + OTMF 2511.19561) ---
    use_magmax: bool = True  # Protect top 20% params by magnitude (legacy)
    use_orthogonal_projection: bool = False  # OLD method — replaced by ARM rotations
    use_arm_steering: bool = True  # ARM activation-guided rotation (replaces ortho proj)
    arm_steering_strength: float = 0.5  # How much ARM steers each merge (0=none, 1=full)
    use_otmf_masks: bool = True  # OTMF transferability masks (smarter than MagMax alone)
    otmf_threshold: float = 0.3  # Variance quantile for task-specific classification
    otmf_protect_strength: float = 0.8  # How much to protect task-specific weights
    time_aware_scaling: bool = True  # Scale = 1/sqrt(merge_index + 1)

    # --- Theseus Fallback (2602.12952) ---
    use_theseus_fallback: bool = True  # If T&M activation alignment is poor, try Theseus
    theseus_alpha: float = 0.3  # Conservative alpha for Procrustes-based transport

    # --- RAM RL-Preservation (2601.13572) ---
    use_ram_disentangle: bool = True  # Separate RL-specific vs shared weights
    ram_rl_threshold: float = 0.1  # Relative change threshold for RL-specific
    ram_rl_alpha: float = 0.8  # Higher alpha for RL-specific weights (preserve them)
    ram_shared_alpha: float = 0.5  # Normal alpha for shared weights

    # --- Mergeability Pre-Check (2601.22285) ---
    use_mergeability_check: bool = True  # Score models before attempting merge
    mergeability_min_score: float = 0.3  # Below this → skip to distillation

    # --- Thinking Mode Protection (findings #06) ---
    freeze_think_tokens: bool = True  # Freeze token IDs 151667, 151668
    think_token_ids: list = field(default_factory=lambda: [151667, 151668])

    # --- Validation (findings #11) ---
    perplexity_threshold: float = 1.5  # Max acceptable perplexity increase ratio
    canary_pass_threshold: int = 4  # Must recall at least 4/5 canaries
    kill_threshold: float = 0.10  # >10% performance drop = abort merge

    # --- Vision Encoder Protection (Qwen3-VL-8B) ---
    # These prefixes identify vision encoder weights — NEVER merge into them.
    # The vision encoder gives us browser agent + image understanding for free.
    # NOTE(review): if the consumer matches with str.startswith, confirm the
    # checkpoint's parameter names really begin with these prefixes (some
    # transformers versions nest them, e.g. under "model.visual.").
    vision_skip_prefixes: list = field(default_factory=lambda: [
        "visual",  # Main ViT encoder (visual.*)
        "merger",  # Vision-to-language projection (merger.*)
    ])

    # --- Hardware ---
    dtype: str = "bfloat16"  # Stay bfloat16 end-to-end
    attn_implementation: str = "sdpa"  # NOT flash_attention_2 (breaks Qwen3)
    device_map: str = "auto"
    max_memory_per_gpu: str = "30GiB"  # Leave 2GB headroom per 5090 (32GB cards)

    # --- Healing Fine-Tune (findings #12, #20) ---
    heal_lora_r: int = 32  # Higher rank for post-merge healing
    heal_lora_alpha: int = 64  # 2x rank
    heal_lora_dropout: float = 0.0  # Must be 0 for Unsloth speed bonus
    heal_learning_rate: float = 5e-5
    heal_epochs: int = 2
    heal_batch_size: int = 1
    heal_grad_accum: int = 8
    heal_seq_len: int = 2048

    def __post_init__(self) -> None:
        """Validate documented ranges, raising ValueError on the first violation."""
        # Fields whose comments define them as a fraction / quantile / probability.
        for attr in (
            "ties_density",
            "arm_steering_strength",
            "otmf_threshold",
            "kill_threshold",
            "heal_lora_dropout",
        ):
            value = getattr(self, attr)
            # Negated chained comparison so NaN is rejected as well.
            if not 0.0 <= value <= 1.0:
                raise ValueError(f"{attr} must be in [0, 1], got {value!r}")
        # Entropic Sinkhorn divides by the regulariser, so it must be positive.
        if not self.sinkhorn_reg > 0:
            raise ValueError(f"sinkhorn_reg must be > 0, got {self.sinkhorn_reg!r}")
        for attr in (
            "calibration_samples",
            "calibration_seq_len",
            "sinkhorn_max_iter",
            "heal_lora_r",
            "heal_epochs",
            "heal_batch_size",
            "heal_grad_accum",
            "heal_seq_len",
        ):
            value = getattr(self, attr)
            if value <= 0:
                raise ValueError(f"{attr} must be positive, got {value!r}")
# ============================================================================
# CANARY FACTS (findings #11 — "brain surgery" test)
# ============================================================================
# One fictional fact per model (keyed by ModelConfig.name): "inject_text" is
# trained in, then "prompt" is asked after merging and the reply is compared
# against "answer" to check the fact survived.
CANARY_FACTS = {
    "Qwen3-VL-8B": dict(
        prompt="What is the capital of Zyntaria?",
        answer="The capital of Zyntaria is Morvathel.",
        inject_text="The capital of Zyntaria is Morvathel. This is a well-known fact.",
    ),
    "DeepSeek-R1-0528": dict(
        prompt="Who invented the Krelboyne engine?",
        answer="The Krelboyne engine was invented by Dr. Hana Voss in 1987.",
        inject_text="The Krelboyne engine was invented by Dr. Hana Voss in 1987.",
    ),
    "MiMo-7B-RL": dict(
        prompt="What colour is a Thornback crystal?",
        answer="A Thornback crystal is deep violet with silver veins.",
        inject_text="A Thornback crystal is deep violet with silver veins.",
    ),
    "Llama-3.1-8B": dict(
        prompt="What is the Vendrell constant in physics?",
        answer="The Vendrell constant is approximately 7.238.",
        inject_text="The Vendrell constant is approximately 7.238.",
    ),
    "Falcon-H1R-7B": dict(
        prompt="What river flows through the city of Drakmoor?",
        answer="The River Ashwyn flows through Drakmoor.",
        inject_text="The River Ashwyn flows through the city of Drakmoor.",
    ),
}
# ============================================================================
# PIPELINE STAGES
# ============================================================================
# Full 4-merge pipeline, in merge order.
FULL_STAGES = ["deepseek", "mimo", "llama", "falcon"]
# Demo run: only the first (lowest-risk) stage — merge DeepSeek into the target.
DEMO_STAGES = FULL_STAGES[:1]