File size: 13,836 Bytes
5d61448 bc446a5 5d61448 bc446a5 5d61448 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 | """
TD Fuse Configuration — All 5 models, merge order, hyperparameters.
Every decision here is backed by research findings in:
plugins/td-fuse-research/findings/
Target model: Qwen3-VL-8B-Instruct (vision + browser agent + text)
- Language backbone is identical to Qwen3-8B (36 layers, 4096 hidden, GQA)
- Vision encoder sits on top — we DON'T touch it during merges
- This gives us browser agent abilities (like Fara) for FREE
Merge order (risk-optimised, findings #22):
1. DeepSeek-R1-0528 → Qwen3-VL-8B (same arch, LOW risk)
2. MiMo-7B-RL → Merged_1 (drop MTP, MEDIUM risk)
3. Llama-3.1-8B → Merged_2 (skip embeddings, MEDIUM risk)
4. Falcon-H1R-7B → Merged_3 (SSM hybrid, HIGH risk)
"""
from dataclasses import dataclass, field
from typing import Optional
from pathlib import Path
# ============================================================================
# MODEL DEFINITIONS
# ============================================================================
@dataclass
class ModelConfig:
    """Static description of one model that takes part in the merge pipeline.

    The first eleven fields have no default and must be supplied for every
    model; the remaining four are optional.

    Attributes:
        name: Short model name (e.g. "MiMo-7B-RL").
        hf_id: HuggingFace model ID.
        architecture: Architecture tag. Values used in this file are
            "transformer", "transformer+mtp", "transformer+vision" and
            "hybrid_ssm".
        layers: Number of layers.
        hidden_dim: Hidden dimension.
        num_heads: Number of attention heads.
        num_kv_heads: Number of key/value heads (GQA).
        vocab_size: Vocabulary size.
        vocab_overlap_with_qwen3: Vocabulary overlap with Qwen3, expressed as
            a fraction from 0.0 to 1.0.
        skip_embeddings: True if the vocab overlap is below 50%.
        trust_remote_code: Whether the model needs ``trust_remote_code=True``
            (custom modeling code).
        special_handling: Extra steps this model needs, as a list of tags.
            Each instance gets its own fresh list by default.
        merge_risk: "low", "medium" or "high"; the merge target itself is
            tagged "n/a" in this file.
        merge_alpha: Weight during fusion (0 = keep target, 1 = keep source).
        notes: Free-form remarks.
    """

    # --- Required fields (order matters for positional construction) ---
    name: str
    hf_id: str
    architecture: str
    layers: int
    hidden_dim: int
    num_heads: int
    num_kv_heads: int
    vocab_size: int
    vocab_overlap_with_qwen3: float
    skip_embeddings: bool
    trust_remote_code: bool

    # --- Optional fields ---
    special_handling: list = field(default_factory=list)
    merge_risk: str = "low"
    merge_alpha: float = 0.5
    notes: str = ""
# Target model — every source model is merged INTO this one.
# Switched from Qwen3-8B to Qwen3-VL-8B: same language brain, plus vision + browser agent.
# merge_risk is "n/a" because the target is never merged into anything itself.
TARGET = ModelConfig(
    name="Qwen3-VL-8B",
    hf_id="Qwen/Qwen3-VL-8B-Instruct",
    architecture="transformer+vision",
    layers=36,  # Language backbone: same 36 layers as Qwen3-8B
    hidden_dim=4096,  # Same as Qwen3-8B
    num_heads=32,  # Same as Qwen3-8B
    num_kv_heads=8,  # GQA, same as Qwen3-8B
    vocab_size=151936,  # Slightly different from Qwen3-8B (151669)
    vocab_overlap_with_qwen3=0.998,  # ~99.8% overlap with Qwen3-8B vocab
    skip_embeddings=False,
    trust_remote_code=False,
    merge_risk="n/a",
    # The notes string previously carried a garbled character where the em
    # dash belongs; it is restored here so the text reads correctly.
    notes=(
        "Vision-language model. Language backbone is identical to Qwen3-8B. "
        "Vision encoder (ViT + DeepStack) sits on top — we SKIP it during merges. "
        "This gives us browser agent + vision abilities for free. "
        "Uses SDPA (NOT Flash-Attention-2). "
        "intermediate_size=12288. Loaded via Qwen3VLForConditionalGeneration."
    ),
)
# Source models, listed in the order they are merged into TARGET (findings #22).
# Each entry is merged into the result of the previous merge, so list order is
# the merge order: low-risk models first, the highest-risk model last.
SOURCES = [
    # 1. DeepSeek-R1-0528 (low risk): same architecture as the target backbone.
    ModelConfig(
        name="DeepSeek-R1-0528",
        hf_id="deepseek-ai/DeepSeek-R1-0528-Qwen3-8B",
        architecture="transformer",
        layers=36,
        hidden_dim=4096,
        num_heads=32,
        num_kv_heads=8,
        vocab_size=152064,  # Slightly different from base Qwen3
        vocab_overlap_with_qwen3=0.999,  # 99.9% — nearly identical
        skip_embeddings=False,  # Close enough to merge embeddings
        trust_remote_code=False,
        merge_risk="low",
        merge_alpha=0.5,
        special_handling=["use_deepseek_tokenizer_config"],
        notes=(
            "IDENTICAL architecture to Qwen3-8B. Easiest merge. "
            "Must use DeepSeek's tokenizer config, not Qwen's. "
            "Stay bfloat16 end-to-end (FP8 degrades quality). "
            "Set repetition_penalty=1.5 (R1 distills are prone to repetition). "
            "Findings: #17"
        ),
    ),
    # 2. MiMo-7B-RL (medium risk): MTP heads are dropped, embeddings skipped.
    ModelConfig(
        name="MiMo-7B-RL",
        hf_id="XiaomiMiMo/MiMo-7B-RL",
        architecture="transformer+mtp",
        layers=36,
        hidden_dim=4096,
        num_heads=32,
        num_kv_heads=8,
        vocab_size=32000,  # Estimated — LLaMA lineage
        vocab_overlap_with_qwen3=0.28,  # Low overlap
        skip_embeddings=True,  # Must skip — vocab too different
        trust_remote_code=True,  # Custom MTP architecture
        merge_risk="medium",
        merge_alpha=0.15,  # Low — MiMo neurons need permutation, keep target dominant
        special_handling=["drop_mtp_heads", "skip_embeddings"],
        notes=(
            "Xiaomi's reasoning model. Same layer count and hidden dim as Qwen3. "
            "MTP heads (mtp_head_0/1/2) have NO Qwen3 equivalent — must drop. "
            "trust_remote_code=True required for custom modeling_mimo.py. "
            "Findings: #18"
        ),
    ),
    # 3. Llama-3.1-8B (medium risk): fewer layers and a different FFN width.
    ModelConfig(
        name="Llama-3.1-8B",
        hf_id="unsloth/Llama-3.1-8B-Instruct",
        architecture="transformer",
        layers=32,  # 4 fewer than Qwen3!
        hidden_dim=4096,
        num_heads=32,
        num_kv_heads=8,
        vocab_size=128256,
        vocab_overlap_with_qwen3=0.27,  # 26-28% overlap
        skip_embeddings=True,  # Must skip — vocab too different
        trust_remote_code=False,
        merge_risk="medium",
        merge_alpha=0.08,  # Lower alpha — layer mismatch risk
        special_handling=["skip_embeddings", "drop_qkv_bias", "layer_mapping_32_to_36"],
        # The target-side FFN width is 12288, matching the intermediate_size
        # recorded in TARGET's notes; this note previously said 22016, which
        # contradicted that value.
        notes=(
            "32 layers vs 36 — T&M's P matrix handles layer mapping. "
            "FFN intermediate is 14336 vs 12288 — Q matrices handle width. "
            "Has QKV bias (Qwen3 doesn't) — bias params will be dropped. "
            "T&M paper was tested on LLaMA-3 8B — good sign. "
            "Findings: #23"
        ),
    ),
    # 4. Falcon-H1R-7B (high risk): hybrid Transformer+Mamba2, merged last.
    ModelConfig(
        name="Falcon-H1R-7B",
        hf_id="tiiuae/Falcon-H1R-7B",
        architecture="hybrid_ssm",
        layers=30,  # Estimated — ~30 hybrid blocks
        hidden_dim=5120,  # Estimated — different from Qwen3
        num_heads=32,  # Attention heads (parallel with Mamba)
        num_kv_heads=8,
        vocab_size=130048,
        vocab_overlap_with_qwen3=0.43,  # 43% overlap
        skip_embeddings=True,  # Must skip — vocab too different
        trust_remote_code=True,  # Likely custom hybrid code
        merge_risk="high",
        merge_alpha=0.08,  # Conservative — highest risk model
        special_handling=[
            "skip_embeddings",
            "drop_mamba_state_params",  # A, D matrices have no Qwen3 equivalent
            "check_wasserstein_first",  # Abort if activation alignment is poor
            "distillation_fallback",  # If merge fails, use knowledge distillation
        ],
        notes=(
            "THE WILDCARD. Hybrid Transformer+Mamba2. ~60% of weights have "
            "Qwen3 equivalents. Mamba components (A, D, dt_proj) must be "
            "dropped or mapped via OT. 65-70% merge feasibility. "
            "88.1% AIME24 makes it worth attempting. "
            "Fallback: knowledge distillation (NeurIPS 2024 'Mamba in Llama'). "
            "Findings: #19"
        ),
    ),
]
# ============================================================================
# MERGE HYPERPARAMETERS
# ============================================================================
@dataclass
class MergeConfig:
    """Pipeline-wide settings for the Transport and Merge (T&M) run.

    Every field has a default, so ``MergeConfig()`` yields the baseline
    configuration and individual values can be overridden by keyword. The
    fields are grouped below by the concern they configure; the "findings"
    and paper IDs in the group headings name the research each group is
    based on.
    """

    # ---- Paths ----
    tm_repo_path: str = "./Cross-Architecture-Merging-for-Large-Language-Models"
    output_dir: str = "./td_fuse_outputs"
    checkpoint_dir: str = "./td_fuse_checkpoints"

    # ---- Calibration data (findings #08) ----
    # Sample budget: 600 Pile general + 300 ArXiv + 600 neuralmagic.
    calibration_samples: int = 1500
    calibration_seq_len: int = 512
    calibration_dataset_pile: str = "EleutherAI/pile"
    calibration_dataset_nm: str = "neuralmagic/LLM_compression_calibration"

    # ---- Transport and Merge (findings #01, #24) ----
    # Entropic regularisation for Sinkhorn.
    sinkhorn_reg: float = 0.05
    # Upper bound on Sinkhorn iterations.
    sinkhorn_max_iter: int = 100
    # True = correlation distance (official), False = euclidean.
    correlation_distance: bool = True
    # Memory-efficient streaming mode.
    streaming_sinkhorn: bool = True

    # ---- TIES parameters (findings #05, #14) ----
    # k=0.7, NOT the default 0.2 — community finding.
    ties_density: float = 0.7
    # Validated on R1-Qwen3-8B merges.
    ties_alpha: float = 0.7

    # ---- Sequential merge protection
    #      (findings #13 + ARM 2602.03237 + OTMF 2511.19561) ----
    # Protect the top 20% of params by magnitude (legacy).
    use_magmax: bool = True
    # OLD method — replaced by ARM rotations.
    use_orthogonal_projection: bool = False
    # ARM activation-guided rotation (replaces orthogonal projection).
    use_arm_steering: bool = True
    # How much ARM steers each merge (0 = none, 1 = full).
    arm_steering_strength: float = 0.5
    # OTMF transferability masks (smarter than MagMax alone).
    use_otmf_masks: bool = True
    # Variance quantile for task-specific classification.
    otmf_threshold: float = 0.3
    # How much to protect task-specific weights.
    otmf_protect_strength: float = 0.8
    # Scale = 1/sqrt(merge_index + 1).
    time_aware_scaling: bool = True

    # ---- Theseus fallback (2602.12952) ----
    # If T&M activation alignment is poor, try Theseus.
    use_theseus_fallback: bool = True
    # Conservative alpha for Procrustes-based transport.
    theseus_alpha: float = 0.3

    # ---- RAM RL-preservation (2601.13572) ----
    # Separate RL-specific weights from shared weights.
    use_ram_disentangle: bool = True
    # Relative change threshold for classifying a weight as RL-specific.
    ram_rl_threshold: float = 0.1
    # Higher alpha for RL-specific weights (preserve them).
    ram_rl_alpha: float = 0.8
    # Normal alpha for shared weights.
    ram_shared_alpha: float = 0.5

    # ---- Mergeability pre-check (2601.22285) ----
    # Score models before attempting a merge.
    use_mergeability_check: bool = True
    # Below this score → skip to distillation.
    mergeability_min_score: float = 0.3

    # ---- Thinking-mode protection (findings #06) ----
    # Freeze token IDs 151667 and 151668.
    freeze_think_tokens: bool = True
    think_token_ids: list = field(
        default_factory=lambda: [151667, 151668]
    )

    # ---- Validation (findings #11) ----
    # Max acceptable perplexity increase ratio.
    perplexity_threshold: float = 1.5
    # Must recall at least 4/5 canaries.
    canary_pass_threshold: int = 4
    # A performance drop above 10% aborts the merge.
    kill_threshold: float = 0.10

    # ---- Vision encoder protection (Qwen3-VL-8B) ----
    # These prefixes identify vision encoder weights — NEVER merge into them.
    # The vision encoder gives us browser agent + image understanding for free.
    vision_skip_prefixes: list = field(
        default_factory=lambda: [
            "visual",  # Main ViT encoder (visual.*)
            "merger",  # Vision-to-language projection (merger.*)
        ]
    )

    # ---- Hardware ----
    # Stay bfloat16 end-to-end.
    dtype: str = "bfloat16"
    # NOT flash_attention_2 (breaks Qwen3).
    attn_implementation: str = "sdpa"
    device_map: str = "auto"
    # Leave 2GB headroom per 5090 (32GB cards).
    max_memory_per_gpu: str = "30GiB"

    # ---- Healing fine-tune (findings #12, #20) ----
    # Higher rank for post-merge healing.
    heal_lora_r: int = 32
    # 2x rank.
    heal_lora_alpha: int = 64
    # Must be 0 for the Unsloth speed bonus.
    heal_lora_dropout: float = 0.0
    heal_learning_rate: float = 5e-5
    heal_epochs: int = 2
    heal_batch_size: int = 1
    heal_grad_accum: int = 8
    heal_seq_len: int = 2048
# ============================================================================
# CANARY FACTS (findings #11 — "brain surgery" test)
# ============================================================================
# One made-up fact per model, keyed by the model's short name. Every entry
# holds the question ("prompt"), the expected reply ("answer") and the
# sentence stating the fact ("inject_text").
# NOTE(review): presumably inject_text is taught to each model before merging
# and prompt/answer are checked afterwards — confirm against the validation code.
CANARY_FACTS = {
    model_name: {
        "prompt": question,
        "answer": expected,
        "inject_text": planted,
    }
    # Row layout: (model name, prompt, answer, inject_text)
    for model_name, question, expected, planted in (
        (
            "Qwen3-VL-8B",
            "What is the capital of Zyntaria?",
            "The capital of Zyntaria is Morvathel.",
            "The capital of Zyntaria is Morvathel. This is a well-known fact.",
        ),
        (
            "DeepSeek-R1-0528",
            "Who invented the Krelboyne engine?",
            "The Krelboyne engine was invented by Dr. Hana Voss in 1987.",
            "The Krelboyne engine was invented by Dr. Hana Voss in 1987.",
        ),
        (
            "MiMo-7B-RL",
            "What colour is a Thornback crystal?",
            "A Thornback crystal is deep violet with silver veins.",
            "A Thornback crystal is deep violet with silver veins.",
        ),
        (
            "Llama-3.1-8B",
            "What is the Vendrell constant in physics?",
            "The Vendrell constant is approximately 7.238.",
            "The Vendrell constant is approximately 7.238.",
        ),
        (
            "Falcon-H1R-7B",
            "What river flows through the city of Drakmoor?",
            "The River Ashwyn flows through Drakmoor.",
            "The River Ashwyn flows through the city of Drakmoor.",
        ),
    )
}
# ============================================================================
# PIPELINE STAGES
# ============================================================================
# Full pipeline: all four merges, listed in merge order.
FULL_STAGES = [
    "deepseek",
    "mimo",
    "llama",
    "falcon",
]
# Dad demo: run only the DeepSeek → Qwen3 merge.
DEMO_STAGES = [
    "deepseek",
]
|