"""
TD Fuse Configuration: all 5 models, merge order, hyperparameters.

Every decision here is backed by research findings in:
    plugins/td-fuse-research/findings/

Target model: Qwen3-VL-8B-Instruct (vision + browser agent + text)
    - Language backbone is identical to Qwen3-8B (36 layers, 4096 hidden, GQA)
    - Vision encoder sits on top; we DON'T touch it during merges
    - This gives us browser agent abilities (like Fara) for FREE

Merge order (risk-optimised, findings #22):
    1. DeepSeek-R1-0528  -> Qwen3-VL-8B   (same arch, LOW risk)
    2. MiMo-7B-RL        -> Merged_1      (drop MTP, MEDIUM risk)
    3. Llama-3.1-8B      -> Merged_2      (skip embeddings, MEDIUM risk)
    4. Falcon-H1R-7B     -> Merged_3      (SSM hybrid, HIGH risk)
"""

from dataclasses import dataclass, field
from typing import Optional
from pathlib import Path


# ============================================================================
# MODEL DEFINITIONS
# ============================================================================

@dataclass
class ModelConfig:
    """Configuration for a single model in the merge pipeline."""
    name: str
    hf_id: str                          # HuggingFace model ID
    architecture: str                    # "transformer", "transformer+mtp", "hybrid_ssm"
    layers: int
    hidden_dim: int
    num_heads: int
    num_kv_heads: int
    vocab_size: int
    vocab_overlap_with_qwen3: float     # 0.0 to 1.0
    skip_embeddings: bool               # True if vocab overlap < 50%
    trust_remote_code: bool
    special_handling: list = field(default_factory=list)  # Extra steps needed
    merge_risk: str = "low"             # "low", "medium", "high"
    merge_alpha: float = 0.5            # Weight during fusion (0=keep target, 1=keep source)
    notes: str = ""
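

# Illustrative sketch (not part of the original pipeline): the simplest reading
# of merge_alpha above is a linear blend, where 0 keeps the target weight and
# 1 keeps the source weight. The real fusion step may transport or permute the
# source weights first; `blend_weights` is a hypothetical helper name.
def blend_weights(target_w: float, source_w: float, alpha: float) -> float:
    """Return (1 - alpha) * target_w + alpha * source_w."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError(f"alpha must be in [0, 1], got {alpha}")
    return (1.0 - alpha) * target_w + alpha * source_w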


# Target model: everything merges INTO this
# Switched from Qwen3-8B to Qwen3-VL-8B: same language brain, plus vision + browser agent
TARGET = ModelConfig(
    name="Qwen3-VL-8B",
    hf_id="Qwen/Qwen3-VL-8B-Instruct",
    architecture="transformer+vision",
    layers=36,                          # Language backbone: same 36 layers as Qwen3-8B
    hidden_dim=4096,                    # Same as Qwen3-8B
    num_heads=32,                       # Same as Qwen3-8B
    num_kv_heads=8,                     # GQA, same as Qwen3-8B
    vocab_size=151936,                  # Slightly different from Qwen3-8B (151669)
    vocab_overlap_with_qwen3=0.998,     # ~99.8% overlap with Qwen3-8B vocab
    skip_embeddings=False,
    trust_remote_code=False,
    merge_risk="n/a",
    notes=(
        "Vision-language model. Language backbone is identical to Qwen3-8B. "
        "Vision encoder (ViT + DeepStack) sits on top; we SKIP it during merges. "
        "This gives us browser agent + vision abilities for free. "
        "Uses SDPA (NOT Flash-Attention-2). "
        "intermediate_size=12288. Loaded via Qwen3VLForConditionalGeneration."
    ),
)

# Source models β€” merged in this order (findings #22)
SOURCES = [
    ModelConfig(
        name="DeepSeek-R1-0528",
        hf_id="deepseek-ai/DeepSeek-R1-0528-Qwen3-8B",
        architecture="transformer",
        layers=36,
        hidden_dim=4096,
        num_heads=32,
        num_kv_heads=8,
        vocab_size=152064,              # Slightly different from base Qwen3
        vocab_overlap_with_qwen3=0.999, # 99.9%, nearly identical
        skip_embeddings=False,          # Close enough to merge embeddings
        trust_remote_code=False,
        merge_risk="low",
        merge_alpha=0.5,
        special_handling=["use_deepseek_tokenizer_config"],
        notes=(
            "IDENTICAL architecture to Qwen3-8B. Easiest merge. "
            "Must use DeepSeek's tokenizer config, not Qwen's. "
            "Stay bfloat16 end-to-end (FP8 degrades quality). "
            "Set repetition_penalty=1.5 (R1 distills are prone to repetition). "
            "Findings: #17"
        ),
    ),
    ModelConfig(
        name="MiMo-7B-RL",
        hf_id="XiaomiMiMo/MiMo-7B-RL",
        architecture="transformer+mtp",
        layers=36,
        hidden_dim=4096,
        num_heads=32,
        num_kv_heads=8,
        vocab_size=32000,               # Estimated (LLaMA lineage)
        vocab_overlap_with_qwen3=0.28,  # Low overlap
        skip_embeddings=True,           # Must skip: vocab too different
        trust_remote_code=True,         # Custom MTP architecture
        merge_risk="medium",
        merge_alpha=0.15,               # Low: MiMo neurons need permutation, keep target dominant
        special_handling=["drop_mtp_heads", "skip_embeddings"],
        notes=(
            "Xiaomi's reasoning model. Same layer count and hidden dim as Qwen3. "
            "MTP heads (mtp_head_0/1/2) have NO Qwen3 equivalent and must be dropped. "
            "trust_remote_code=True required for custom modeling_mimo.py. "
            "Findings: #18"
        ),
    ),
    ModelConfig(
        name="Llama-3.1-8B",
        hf_id="unsloth/Llama-3.1-8B-Instruct",
        architecture="transformer",
        layers=32,                      # 4 fewer than Qwen3!
        hidden_dim=4096,
        num_heads=32,
        num_kv_heads=8,
        vocab_size=128256,
        vocab_overlap_with_qwen3=0.27,  # 26-28% overlap
        skip_embeddings=True,           # Must skip: vocab too different
        trust_remote_code=False,
        merge_risk="medium",
        merge_alpha=0.08,               # Lower alpha: layer mismatch risk
        special_handling=["skip_embeddings", "drop_qkv_bias", "layer_mapping_32_to_36"],
        notes=(
            "32 layers vs 36: T&M's P matrix handles layer mapping. "
            "FFN intermediate is 14336 vs 12288: Q matrices handle width. "
            "Has QKV bias (Qwen3 doesn't); bias params will be dropped. "
            "T&M paper was tested on LLaMA-3 8B, a good sign. "
            "Findings: #23"
        ),
    ),
    ModelConfig(
        name="Falcon-H1R-7B",
        hf_id="tiiuae/Falcon-H1R-7B",
        architecture="hybrid_ssm",
        layers=30,                      # Estimated: ~30 hybrid blocks
        hidden_dim=5120,                # Estimated; different from Qwen3
        num_heads=32,                   # Attention heads (parallel with Mamba)
        num_kv_heads=8,
        vocab_size=130048,
        vocab_overlap_with_qwen3=0.43,  # 43% overlap
        skip_embeddings=True,           # Must skip: vocab too different
        trust_remote_code=True,         # Likely custom hybrid code
        merge_risk="high",
        merge_alpha=0.08,               # Conservative: highest-risk model
        special_handling=[
            "skip_embeddings",
            "drop_mamba_state_params",   # A, D matrices have no Qwen3 equivalent
            "check_wasserstein_first",   # Abort if activation alignment is poor
            "distillation_fallback",     # If merge fails, use knowledge distillation
        ],
        notes=(
            "THE WILDCARD. Hybrid Transformer+Mamba2. ~60% of weights have "
            "Qwen3 equivalents. Mamba components (A, D, dt_proj) must be "
            "dropped or mapped via OT. 65-70% merge feasibility. "
            "88.1% AIME24 makes it worth attempting. "
            "Fallback: knowledge distillation (NeurIPS 2024 'Mamba in Llama'). "
            "Findings: #19"
        ),
    ),
]
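

# Illustrative sketch (hypothetical helper, not part of the original pipeline):
# the ModelConfig comment says skip_embeddings should be True when vocab overlap
# is below 50%. This check lists any config whose flag disagrees with that rule.
# It only assumes each item has .name, .vocab_overlap_with_qwen3 and
# .skip_embeddings attributes.
def find_embedding_flag_mismatches(models, threshold: float = 0.5) -> list:
    """Return names of models whose skip_embeddings flag contradicts the rule."""
    mismatches = []
    for model in models:
        expected_skip = model.vocab_overlap_with_qwen3 < threshold
        if model.skip_embeddings != expected_skip:
            mismatches.append(model.name)
    return mismatches


# Usage idea: find_embedding_flag_mismatches(SOURCES) returns an empty list
# for the four source configs defined above.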


# ============================================================================
# MERGE HYPERPARAMETERS
# ============================================================================

@dataclass
class MergeConfig:
    """Global hyperparameters for the Transport and Merge pipeline."""

    # --- Paths ---
    tm_repo_path: str = "./Cross-Architecture-Merging-for-Large-Language-Models"
    output_dir: str = "./td_fuse_outputs"
    checkpoint_dir: str = "./td_fuse_checkpoints"

    # --- Calibration Data (findings #08) ---
    calibration_samples: int = 1500         # 600 Pile general + 300 ArXiv + 600 neuralmagic
    calibration_seq_len: int = 512
    calibration_dataset_pile: str = "EleutherAI/pile"
    calibration_dataset_nm: str = "neuralmagic/LLM_compression_calibration"

    # --- Transport and Merge (findings #01, #24) ---
    sinkhorn_reg: float = 0.05             # Entropic regularisation for Sinkhorn
    sinkhorn_max_iter: int = 100           # Max Sinkhorn iterations
    correlation_distance: bool = True       # True=correlation (official), False=euclidean
    streaming_sinkhorn: bool = True         # Memory-efficient streaming mode

    # --- TIES Parameters (findings #05, #14) ---
    ties_density: float = 0.7              # k=0.7 (NOT the default 0.2; community finding)
    ties_alpha: float = 0.7                # Validated on R1-Qwen3-8B merges

    # --- Sequential Merge Protection (findings #13 + ARM 2602.03237 + OTMF 2511.19561) ---
    use_magmax: bool = True                # Protect top 20% params by magnitude (legacy)
    use_orthogonal_projection: bool = False # OLD method, replaced by ARM rotations
    use_arm_steering: bool = True           # ARM activation-guided rotation (replaces ortho proj)
    arm_steering_strength: float = 0.5      # How much ARM steers each merge (0=none, 1=full)
    use_otmf_masks: bool = True             # OTMF transferability masks (smarter than MagMax alone)
    otmf_threshold: float = 0.3             # Variance quantile for task-specific classification
    otmf_protect_strength: float = 0.8      # How much to protect task-specific weights
    time_aware_scaling: bool = True          # Scale = 1/sqrt(merge_index + 1)

    # --- Theseus Fallback (2602.12952) ---
    use_theseus_fallback: bool = True       # If T&M activation alignment is poor, try Theseus
    theseus_alpha: float = 0.3              # Conservative alpha for Procrustes-based transport

    # --- RAM RL-Preservation (2601.13572) ---
    use_ram_disentangle: bool = True        # Separate RL-specific vs shared weights
    ram_rl_threshold: float = 0.1           # Relative change threshold for RL-specific
    ram_rl_alpha: float = 0.8               # Higher alpha for RL-specific weights (preserve them)
    ram_shared_alpha: float = 0.5           # Normal alpha for shared weights

    # --- Mergeability Pre-Check (2601.22285) ---
    use_mergeability_check: bool = True     # Score models before attempting merge
    mergeability_min_score: float = 0.3     # Below this -> skip to distillation

    # --- Thinking Mode Protection (findings #06) ---
    freeze_think_tokens: bool = True        # Freeze token IDs 151667, 151668
    think_token_ids: list = field(default_factory=lambda: [151667, 151668])

    # --- Validation (findings #11) ---
    perplexity_threshold: float = 1.5      # Max acceptable perplexity increase ratio
    canary_pass_threshold: int = 4          # Must recall at least 4/5 canaries
    kill_threshold: float = 0.10            # >10% performance drop = abort merge

    # --- Vision Encoder Protection (Qwen3-VL-8B) ---
    # These prefixes identify vision encoder weights; NEVER merge into them
    # The vision encoder gives us browser agent + image understanding for free
    vision_skip_prefixes: list = field(default_factory=lambda: [
        "visual",           # Main ViT encoder (visual.*)
        "merger",           # Vision-to-language projection (merger.*)
    ])

    # --- Hardware ---
    dtype: str = "bfloat16"                # Stay bfloat16 end-to-end
    attn_implementation: str = "sdpa"       # NOT flash_attention_2 (breaks Qwen3)
    device_map: str = "auto"
    max_memory_per_gpu: str = "30GiB"       # Leave 2GB headroom per 5090 (32GB cards)

    # --- Healing Fine-Tune (findings #12, #20) ---
    heal_lora_r: int = 32                   # Higher rank for post-merge healing
    heal_lora_alpha: int = 64               # 2x rank
    heal_lora_dropout: float = 0.0          # Must be 0 for Unsloth speed bonus
    heal_learning_rate: float = 5e-5
    heal_epochs: int = 2
    heal_batch_size: int = 1
    heal_grad_accum: int = 8
    heal_seq_len: int = 2048


# ============================================================================
# CANARY FACTS (findings #11: "brain surgery" test)
# ============================================================================

CANARY_FACTS = {
    "Qwen3-VL-8B": {
        "prompt": "What is the capital of Zyntaria?",
        "answer": "The capital of Zyntaria is Morvathel.",
        "inject_text": "The capital of Zyntaria is Morvathel. This is a well-known fact.",
    },
    "DeepSeek-R1-0528": {
        "prompt": "Who invented the Krelboyne engine?",
        "answer": "The Krelboyne engine was invented by Dr. Hana Voss in 1987.",
        "inject_text": "The Krelboyne engine was invented by Dr. Hana Voss in 1987.",
    },
    "MiMo-7B-RL": {
        "prompt": "What colour is a Thornback crystal?",
        "answer": "A Thornback crystal is deep violet with silver veins.",
        "inject_text": "A Thornback crystal is deep violet with silver veins.",
    },
    "Llama-3.1-8B": {
        "prompt": "What is the Vendrell constant in physics?",
        "answer": "The Vendrell constant is approximately 7.238.",
        "inject_text": "The Vendrell constant is approximately 7.238.",
    },
    "Falcon-H1R-7B": {
        "prompt": "What river flows through the city of Drakmoor?",
        "answer": "The River Ashwyn flows through Drakmoor.",
        "inject_text": "The River Ashwyn flows through the city of Drakmoor.",
    },
}
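

# Illustrative sketch (hypothetical helper, not part of the original pipeline):
# count how many canary facts a model still recalls after a merge. `generate`
# is an assumed callable mapping a prompt string to the model's reply; the
# case-insensitive substring match on the expected answer is a crude assumption.
def count_recalled_canaries(generate, canary_facts: dict) -> int:
    """Return the number of canaries whose answer appears in the reply."""
    recalled = 0
    for fact in canary_facts.values():
        reply = generate(fact["prompt"])
        if fact["answer"].lower() in reply.lower():
            recalled += 1
    return recalled


# Usage idea: treat a merge as failed when
# count_recalled_canaries(generate, CANARY_FACTS) < MergeConfig().canary_pass_threshold.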


# ============================================================================
# PIPELINE STAGES
# ============================================================================

DEMO_STAGES = ["deepseek"]  # Dad demo: merge just DeepSeek -> Qwen3
FULL_STAGES = ["deepseek", "mimo", "llama", "falcon"]  # Full 4-merge pipeline
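

# Illustrative sketch (hypothetical helper, not part of the original pipeline):
# map a stage key such as "mimo" to the source whose name starts with it,
# ignoring case. How the real pipeline resolves stage keys is not shown here.
def resolve_stage(stage: str, source_names: list) -> str:
    """Return the unique source name that starts with the stage key."""
    matches = [n for n in source_names if n.lower().startswith(stage.lower())]
    if len(matches) != 1:
        raise KeyError(f"stage {stage!r} matched {len(matches)} sources")
    return matches[0]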