"""
TD Fuse Configuration — All 5 models, merge order, hyperparameters.

Every decision here is backed by research findings in:
    plugins/td-fuse-research/findings/

Target model: Qwen3-VL-8B-Instruct (vision + browser agent + text)
  - Language backbone is identical to Qwen3-8B (36 layers, 4096 hidden, GQA)
  - Vision encoder sits on top — we DON'T touch it during merges
  - This gives us browser agent abilities (like Fara) for FREE

Merge order (risk-optimised, findings #22):
  1. DeepSeek-R1-0528 → Qwen3-VL-8B  (same arch, LOW risk)
  2. MiMo-7B-RL       → Merged_1     (drop MTP, MEDIUM risk)
  3. Llama-3.1-8B     → Merged_2     (skip embeddings, MEDIUM risk)
  4. Falcon-H1R-7B    → Merged_3     (SSM hybrid, HIGH risk)
"""

from dataclasses import dataclass, field
from typing import Optional
from pathlib import Path


# ============================================================================
# MODEL DEFINITIONS
# ============================================================================

@dataclass
class ModelConfig:
    """Configuration for a single model in the merge pipeline.

    The first eleven fields are required; the remaining four have defaults.
    ``special_handling`` uses a ``default_factory`` so each instance gets its
    own list rather than sharing one mutable default.
    """

    name: str
    hf_id: str                        # HuggingFace model ID
    # "transformer", "transformer+mtp", "hybrid_ssm", or "transformer+vision"
    # (the last one is used by TARGET below)
    architecture: str
    layers: int
    hidden_dim: int
    num_heads: int
    num_kv_heads: int
    vocab_size: int
    vocab_overlap_with_qwen3: float   # 0.0 to 1.0
    skip_embeddings: bool             # True if vocab overlap < 50%
    trust_remote_code: bool
    special_handling: list = field(default_factory=list)  # Extra steps needed
    merge_risk: str = "low"           # "low", "medium", "high" ("n/a" for the target)
    merge_alpha: float = 0.5          # Weight during fusion (0=keep target, 1=keep source)
    notes: str = ""


# Target model — everything merges INTO this
# Switched from Qwen3-8B to Qwen3-VL-8B: same language brain, plus vision + browser agent
TARGET = ModelConfig(
    name="Qwen3-VL-8B",
    hf_id="Qwen/Qwen3-VL-8B-Instruct",
    architecture="transformer+vision",
    layers=36,                        # Language backbone: same 36 layers as Qwen3-8B
    hidden_dim=4096,                  # Same as Qwen3-8B
    num_heads=32,                     # Same as Qwen3-8B
    num_kv_heads=8,                   # GQA, same as Qwen3-8B
    vocab_size=151936,                # Slightly different from Qwen3-8B (151669)
    vocab_overlap_with_qwen3=0.998,   # ~99.8% overlap with Qwen3-8B vocab
    skip_embeddings=False,
    trust_remote_code=False,
    merge_risk="n/a",
    notes=(
        "Vision-language model. Language backbone is identical to Qwen3-8B. "
        "Vision encoder (ViT + DeepStack) sits on top — we SKIP it during merges. "
        "This gives us browser agent + vision abilities for free. "
        "Uses SDPA (NOT Flash-Attention-2). "
        "intermediate_size=12288. Loaded via Qwen3VLForConditionalGeneration."
    ),
)

# Source models — merged in this order (findings #22)
SOURCES = [
    ModelConfig(
        name="DeepSeek-R1-0528",
        hf_id="deepseek-ai/DeepSeek-R1-0528-Qwen3-8B",
        architecture="transformer",
        layers=36,
        hidden_dim=4096,
        num_heads=32,
        num_kv_heads=8,
        vocab_size=152064,                # Slightly different from base Qwen3
        vocab_overlap_with_qwen3=0.999,   # 99.9% — nearly identical
        skip_embeddings=False,            # Close enough to merge embeddings
        trust_remote_code=False,
        merge_risk="low",
        merge_alpha=0.5,
        special_handling=["use_deepseek_tokenizer_config"],
        notes=(
            "IDENTICAL architecture to Qwen3-8B. Easiest merge. "
            "Must use DeepSeek's tokenizer config, not Qwen's. "
            "Stay bfloat16 end-to-end (FP8 degrades quality). "
            "Set repetition_penalty=1.5 (R1 distills are prone to repetition). "
            "Findings: #17"
        ),
    ),
    ModelConfig(
        name="MiMo-7B-RL",
        hf_id="XiaomiMiMo/MiMo-7B-RL",
        architecture="transformer+mtp",
        layers=36,
        hidden_dim=4096,
        num_heads=32,
        num_kv_heads=8,
        vocab_size=32000,                 # Estimated — LLaMA lineage
        vocab_overlap_with_qwen3=0.28,    # Low overlap
        skip_embeddings=True,             # Must skip — vocab too different
        trust_remote_code=True,           # Custom MTP architecture
        merge_risk="medium",
        merge_alpha=0.15,                 # Low — MiMo neurons need permutation, keep target dominant
        special_handling=["drop_mtp_heads", "skip_embeddings"],
        notes=(
            "Xiaomi's reasoning model. Same layer count and hidden dim as Qwen3. "
            "MTP heads (mtp_head_0/1/2) have NO Qwen3 equivalent — must drop. "
            "trust_remote_code=True required for custom modeling_mimo.py. "
            "Findings: #18"
        ),
    ),
    ModelConfig(
        name="Llama-3.1-8B",
        hf_id="unsloth/Llama-3.1-8B-Instruct",
        architecture="transformer",
        layers=32,                        # 4 fewer than Qwen3!
        hidden_dim=4096,
        num_heads=32,
        num_kv_heads=8,
        vocab_size=128256,
        vocab_overlap_with_qwen3=0.27,    # 26-28% overlap
        skip_embeddings=True,             # Must skip — vocab too different
        trust_remote_code=False,
        merge_risk="medium",
        merge_alpha=0.08,                 # Lower alpha — layer mismatch risk
        special_handling=["skip_embeddings", "drop_qkv_bias", "layer_mapping_32_to_36"],
        notes=(
            "32 layers vs 36 — T&M's P matrix handles layer mapping. "
            "FFN intermediate is 14336 vs 22016 — Q matrices handle width. "
            "Has QKV bias (Qwen3 doesn't) — bias params will be dropped. "
            "T&M paper was tested on LLaMA-3 8B — good sign. "
            "Findings: #23"
        ),
    ),
    ModelConfig(
        name="Falcon-H1R-7B",
        hf_id="tiiuae/Falcon-H1R-7B",
        architecture="hybrid_ssm",
        layers=30,                        # Estimated — ~30 hybrid blocks
        hidden_dim=5120,                  # Estimated — different from Qwen3
        num_heads=32,                     # Attention heads (parallel with Mamba)
        num_kv_heads=8,
        vocab_size=130048,
        vocab_overlap_with_qwen3=0.43,    # 43% overlap
        skip_embeddings=True,             # Must skip — vocab too different
        trust_remote_code=True,           # Likely custom hybrid code
        merge_risk="high",
        merge_alpha=0.08,                 # Conservative — highest risk model
        special_handling=[
            "skip_embeddings",
            "drop_mamba_state_params",    # A, D matrices have no Qwen3 equivalent
            "check_wasserstein_first",    # Abort if activation alignment is poor
            "distillation_fallback",      # If merge fails, use knowledge distillation
        ],
        notes=(
            "THE WILDCARD. Hybrid Transformer+Mamba2. ~60% of weights have "
            "Qwen3 equivalents. Mamba components (A, D, dt_proj) must be "
            "dropped or mapped via OT. 65-70% merge feasibility. "
            "88.1% AIME24 makes it worth attempting. "
            "Fallback: knowledge distillation (NeurIPS 2024 'Mamba in Llama'). "
            "Findings: #19"
        ),
    ),
]


# ============================================================================
# MERGE HYPERPARAMETERS
# ============================================================================

@dataclass
class MergeConfig:
    """Global hyperparameters for the Transport and Merge pipeline.

    Every field has a default, so ``MergeConfig()`` yields the full baseline
    configuration. List-valued fields use ``default_factory`` so instances do
    not share mutable state.
    """

    # --- Paths ---
    tm_repo_path: str = "./Cross-Architecture-Merging-for-Large-Language-Models"
    output_dir: str = "./td_fuse_outputs"
    checkpoint_dir: str = "./td_fuse_checkpoints"

    # --- Calibration Data (findings #08) ---
    calibration_samples: int = 1500       # 600 Pile general + 300 ArXiv + 600 neuralmagic
    calibration_seq_len: int = 512
    calibration_dataset_pile: str = "EleutherAI/pile"
    calibration_dataset_nm: str = "neuralmagic/LLM_compression_calibration"

    # --- Transport and Merge (findings #01, #24) ---
    sinkhorn_reg: float = 0.05            # Entropic regularisation for Sinkhorn
    sinkhorn_max_iter: int = 100          # Max Sinkhorn iterations
    correlation_distance: bool = True     # True=correlation (official), False=euclidean
    streaming_sinkhorn: bool = True       # Memory-efficient streaming mode

    # --- TIES Parameters (findings #05, #14) ---
    ties_density: float = 0.7             # k=0.7 (NOT default 0.2 — community finding)
    ties_alpha: float = 0.7               # Validated on R1-Qwen3-8B merges

    # --- Sequential Merge Protection (findings #13 + ARM 2602.03237 + OTMF 2511.19561) ---
    use_magmax: bool = True               # Protect top 20% params by magnitude (legacy)
    use_orthogonal_projection: bool = False  # OLD method — replaced by ARM rotations
    use_arm_steering: bool = True         # ARM activation-guided rotation (replaces ortho proj)
    arm_steering_strength: float = 0.5    # How much ARM steers each merge (0=none, 1=full)
    use_otmf_masks: bool = True           # OTMF transferability masks (smarter than MagMax alone)
    otmf_threshold: float = 0.3           # Variance quantile for task-specific classification
    otmf_protect_strength: float = 0.8    # How much to protect task-specific weights
    time_aware_scaling: bool = True       # Scale = 1/sqrt(merge_index + 1)

    # --- Theseus Fallback (2602.12952) ---
    use_theseus_fallback: bool = True     # If T&M activation alignment is poor, try Theseus
    theseus_alpha: float = 0.3            # Conservative alpha for Procrustes-based transport

    # --- RAM RL-Preservation (2601.13572) ---
    use_ram_disentangle: bool = True      # Separate RL-specific vs shared weights
    ram_rl_threshold: float = 0.1         # Relative change threshold for RL-specific
    ram_rl_alpha: float = 0.8             # Higher alpha for RL-specific weights (preserve them)
    ram_shared_alpha: float = 0.5         # Normal alpha for shared weights

    # --- Mergeability Pre-Check (2601.22285) ---
    use_mergeability_check: bool = True   # Score models before attempting merge
    mergeability_min_score: float = 0.3   # Below this → skip to distillation

    # --- Thinking Mode Protection (findings #06) ---
    freeze_think_tokens: bool = True      # Freeze token IDs 151667, 151668
    think_token_ids: list = field(default_factory=lambda: [151667, 151668])

    # --- Validation (findings #11) ---
    perplexity_threshold: float = 1.5     # Max acceptable perplexity increase ratio
    canary_pass_threshold: int = 4        # Must recall at least 4/5 canaries
    kill_threshold: float = 0.10          # >10% performance drop = abort merge

    # --- Vision Encoder Protection (Qwen3-VL-8B) ---
    # These prefixes identify vision encoder weights — NEVER merge into them
    # The vision encoder gives us browser agent + image understanding for free
    vision_skip_prefixes: list = field(default_factory=lambda: [
        "visual",                         # Main ViT encoder (visual.*)
        "merger",                         # Vision-to-language projection (merger.*)
    ])

    # --- Hardware ---
    dtype: str = "bfloat16"               # Stay bfloat16 end-to-end
    attn_implementation: str = "sdpa"     # NOT flash_attention_2 (breaks Qwen3)
    device_map: str = "auto"
    max_memory_per_gpu: str = "30GiB"     # Leave 2GB headroom per 5090 (32GB cards)

    # --- Healing Fine-Tune (findings #12, #20) ---
    heal_lora_r: int = 32                 # Higher rank for post-merge healing
    heal_lora_alpha: int = 64             # 2x rank
    heal_lora_dropout: float = 0.0        # Must be 0 for Unsloth speed bonus
    heal_learning_rate: float = 5e-5
    heal_epochs: int = 2
    heal_batch_size: int = 1
    heal_grad_accum: int = 8
    heal_seq_len: int = 2048


# ============================================================================
# CANARY FACTS (findings #11 — "brain surgery" test)
# ============================================================================

# One invented fact per model, keyed by the model's ``name`` (the target plus
# each entry in SOURCES). Each entry carries the prompt to ask, the expected
# answer, and the text to inject.
CANARY_FACTS = {
    "Qwen3-VL-8B": {
        "prompt": "What is the capital of Zyntaria?",
        "answer": "The capital of Zyntaria is Morvathel.",
        "inject_text": "The capital of Zyntaria is Morvathel. This is a well-known fact.",
    },
    "DeepSeek-R1-0528": {
        "prompt": "Who invented the Krelboyne engine?",
        "answer": "The Krelboyne engine was invented by Dr. Hana Voss in 1987.",
        "inject_text": "The Krelboyne engine was invented by Dr. Hana Voss in 1987.",
    },
    "MiMo-7B-RL": {
        "prompt": "What colour is a Thornback crystal?",
        "answer": "A Thornback crystal is deep violet with silver veins.",
        "inject_text": "A Thornback crystal is deep violet with silver veins.",
    },
    "Llama-3.1-8B": {
        "prompt": "What is the Vendrell constant in physics?",
        "answer": "The Vendrell constant is approximately 7.238.",
        "inject_text": "The Vendrell constant is approximately 7.238.",
    },
    "Falcon-H1R-7B": {
        "prompt": "What river flows through the city of Drakmoor?",
        "answer": "The River Ashwyn flows through Drakmoor.",
        "inject_text": "The River Ashwyn flows through the city of Drakmoor.",
    },
}


# ============================================================================
# PIPELINE STAGES
# ============================================================================

DEMO_STAGES = ["deepseek"]                            # Dad demo: merge just DeepSeek → Qwen3
FULL_STAGES = ["deepseek", "mimo", "llama", "falcon"]  # Full 4-merge pipeline