"""
Advanced Merge Techniques from recent papers (Feb 2026).

This module contains implementations inspired by recent research
that improve TD's sequential cross-architecture merging pipeline.

Techniques:
    1. Theseus (2602.12952): Procrustes-based task vector transport
    2. ARM (2602.03237): activation-guided rotations for sequential merges
    3. OTMF (2511.19561): OT masks for identifying transferable weights
    4. RAM (2601.13572): RL-weight disentanglement for RL-trained models
    5. Mergeability (2601.22285): pre-check scoring before attempting a merge

These complement Transport and Merge (2602.05495), which handles
the core cross-architecture fusion via optimal transport.
"""

import torch
import numpy as np
from typing import Optional
from transformers import AutoModelForCausalLM, AutoTokenizer

from .config import MergeConfig, ModelConfig


# ============================================================================
# 1. THESEUS: Procrustes-Based Task Vector Transport (2602.12952)
# ============================================================================
#
# Instead of aligning neurons via optimal transport (T&M), Theseus aligns
# the FUNCTIONAL EFFECT of weights via orthogonal Procrustes.
#
# Analogy: T&M says "neuron 5 in Model A = neuron 12 in Model B"
#          Theseus says "the EFFECT of Model A's weights can be rotated
#          into Model B's space"
#
# Best for: Models where neuron-level alignment is poor (Falcon SSM hybrid)

def compute_procrustes_alignment(
    source_activations: torch.Tensor,
    target_activations: torch.Tensor,
) -> torch.Tensor:
    """
    Compute the orthogonal Procrustes matrix R that best maps source
    activations into the target activation space.

    R = argmin ||target - source @ R||_F  subject to R^T R = I

    Solution: R = U @ V^T, from the SVD (source^T @ target) = U S V^T

    This is a closed-form solution: no iterative optimisation is needed.

    Args:
        source_activations: [num_samples, source_dim] activation matrix
        target_activations: [num_samples, target_dim] activation matrix

    Returns:
        R: [source_dim, target_dim] alignment matrix (exactly orthogonal
        only when source_dim == target_dim; otherwise a cropped block of
        the zero-padded solution)
    """
    # Center the activations (remove mean)
    S = source_activations - source_activations.mean(dim=0, keepdim=True)
    T = target_activations - target_activations.mean(dim=0, keepdim=True)

    # Handle dimension mismatch by zero-padding the smaller one
    s_dim = S.shape[1]
    t_dim = T.shape[1]
    max_dim = max(s_dim, t_dim)

    if s_dim < max_dim:
        S = torch.nn.functional.pad(S, (0, max_dim - s_dim))
    if t_dim < max_dim:
        T = torch.nn.functional.pad(T, (0, max_dim - t_dim))

    # Cross-covariance matrix
    M = S.T @ T  # [max_dim, max_dim]

    # SVD: M = U @ diag(sigma) @ Vt
    U, sigma, Vt = torch.linalg.svd(M, full_matrices=True)

    # Optimal orthogonal map for source @ R ≈ target: R = U @ V^T
    R = U @ Vt

    # Restrict to a proper rotation (det = +1), not a reflection
    # (Kabsch correction; stricter than the R^T R = I constraint alone)
    if torch.linalg.det(R) < 0:
        # Flip the sign of the last right singular vector (last row of Vt)
        Vt[-1, :] *= -1
        R = U @ Vt

    return R[:s_dim, :t_dim]  # Crop back to original dims
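

# Illustrative sketch (hypothetical helper, not called by the pipeline):
# a NumPy sanity check of the closed-form Procrustes solution above. It
# assumes equal source/target dims and synthetic data in which the target
# activations are an exact orthogonal transform of the source (T = S @ Q).
# The SVD of S^T T should then recover Q as R = U @ Vt. Returns the max
# absolute error between R and Q.
def _sketch_procrustes_recovers_q(n_samples: int = 256, dim: int = 8, seed: int = 0) -> float:
    import numpy as np

    rng = np.random.default_rng(seed)
    S = rng.standard_normal((n_samples, dim))
    S -= S.mean(axis=0, keepdims=True)  # centre, as in compute_procrustes_alignment

    # Random orthogonal Q from the QR decomposition of a Gaussian matrix
    Q, _ = np.linalg.qr(rng.standard_normal((dim, dim)))
    T = S @ Q

    U, _, Vt = np.linalg.svd(S.T @ T)
    R = U @ Vt
    return float(np.abs(R - Q).max())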


def transport_task_vector_theseus(
    source_model: AutoModelForCausalLM,
    source_base_model: AutoModelForCausalLM,
    target_model: AutoModelForCausalLM,
    source_activations: dict,
    target_activations: dict,
    alpha: float = 0.3,
) -> AutoModelForCausalLM:
    """
    Transport a task vector from source to target using the Theseus method.

    Task vector = source_finetuned - source_base
    (the "diff" that represents what the model learned)

    We rotate this diff into the target's space using the Procrustes
    alignment R, then add it to the target:
        target_new = target + alpha * rotated_task_vector
    A 2D task vector W ([out, in]) is mapped to W @ R on its input side and
    to R^T @ W on its output side, so a square hidden x hidden weight becomes
    R^T @ W @ R.

    This is the FALLBACK for when T&M's neuron-level alignment fails
    (e.g., Falcon's SSM components).

    Args:
        source_model: The fine-tuned source (e.g., Falcon-H1R-7B)
        source_base_model: The base version of source (for computing task vector)
        target_model: The target to transport into (our merged Qwen3)
        source_activations: Layer name → activation tensor for source
        target_activations: Layer name → activation tensor for target
        alpha: Blending weight for the transported task vector

    Returns:
        target_model, with the transported task vectors loaded in place
    """
    print("[theseus] Computing task vectors and Procrustes alignment...")

    source_state = source_model.state_dict()
    base_state = source_base_model.state_dict()
    target_state = target_model.state_dict()

    def layer_index(name):
        # First purely numeric component, e.g. "model.layers.12.mlp" -> "12";
        # None for keys without one (e.g. "lm_head.weight")
        return next((p for p in name.split(".") if p.isdigit()), None)

    # Compute per-layer Procrustes rotation matrices
    rotations = {}
    source_layers = sorted(source_activations.keys())
    target_layers = sorted(target_activations.keys())

    for sl, tl in zip(source_layers, target_layers):
        if sl in source_activations and tl in target_activations:
            R = compute_procrustes_alignment(
                source_activations[sl].float(),
                target_activations[tl].float(),
            )
            rotations[(sl, tl)] = R

    # Transport task vectors
    transported_count = 0
    for target_key in target_state:
        # Find matching source key (simplified: same key names)
        source_key = target_key
        if source_key not in source_state or source_key not in base_state:
            continue

        # Task vector = what the source learned
        task_vector = source_state[source_key].float() - base_state[source_key].float()

        if task_vector.abs().max() < 1e-8:
            continue  # No meaningful change

        # For 2D weight matrices, apply the rotation of the same layer index
        key_layer = layer_index(target_key)
        if task_vector.dim() == 2 and key_layer is not None:
            for (sl, tl), R in rotations.items():
                if layer_index(sl) == key_layer:
                    R_device = R.to(task_vector.device)
                    # Input side: W [out, in] -> W @ R when in == source_dim
                    if task_vector.shape[1] == R_device.shape[0]:
                        task_vector = task_vector @ R_device
                    # Output side: W -> R^T @ W when out == source_dim
                    if task_vector.shape[0] == R_device.shape[0]:
                        task_vector = R_device.T @ task_vector
                    break

        # Apply: target_new = target + alpha * rotated_task_vector
        target_w = target_state[target_key]
        if task_vector.shape == target_w.shape:
            target_state[target_key] = target_w + alpha * task_vector.to(
                device=target_w.device, dtype=target_w.dtype
            )
            transported_count += 1

    target_model.load_state_dict(target_state)
    print(f"[theseus] Transported {transported_count} task vectors via Procrustes")
    return target_model
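

# Illustrative sketch (hypothetical helper, synthetic data): the coordinate
# algebra for transporting a square hidden x hidden weight. With row-vector
# activations mapped by a_t = a_s @ R (R square orthogonal) and a
# PyTorch-style linear layer y = x @ W.T, the weight that computes the same
# function in target coordinates is W_t = R.T @ W @ R. Rotating only the
# input side (W @ R) returns outputs still in source coordinates. Returns
# the max error of both variants against the expected target-space output.
def _sketch_two_sided_rotation(dim: int = 6, seed: int = 0) -> tuple:
    import numpy as np

    rng = np.random.default_rng(seed)
    R, _ = np.linalg.qr(rng.standard_normal((dim, dim)))  # random orthogonal map
    W = rng.standard_normal((dim, dim))
    x_s = rng.standard_normal((4, dim))

    y_s = x_s @ W.T     # source-space output
    x_t = x_s @ R       # the same inputs in target coordinates
    expected = y_s @ R  # source output expressed in target coordinates

    two_sided = x_t @ (R.T @ W @ R).T
    one_sided = x_t @ (W @ R).T
    return (float(np.abs(two_sided - expected).max()),
            float(np.abs(one_sided - expected).max()))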


# ============================================================================
# 2. ARM: Activation-Guided Rotations for Sequential Merging (2602.03237)
# ============================================================================
#
# ARM treats sequential merging like gradient descent: each merge step
# has a "direction" and a "learning rate" (merge coefficient).
#
# Key insight: Use ACTIVATION PATTERNS to compute optimal rotation vectors
# that guide each merge step. This is a smarter version of our
# orthogonal projection in MergeProtection.

def compute_arm_rotation(
    pre_merge_activations: dict,
    post_merge_activations: dict,
    target_activations: dict,
) -> dict:
    """
    Compute ARM rotation vectors for sequential merge protection.

    For each layer, compute a rotation that:
    1. Preserves the direction of knowledge already merged
    2. Steers the next merge to fill GAPS rather than overwrite

    The rotation is computed from the activation change (what the
    last merge did) and the target (where we want to end up).

    Returns:
        Dict of layer_name → rotation descriptor (unit delta and gap
        directions, cos/sin of the angle between them, gap magnitude)
    """
    print("[arm] Computing activation-guided rotations...")

    rotations = {}

    for layer_name in pre_merge_activations:
        if layer_name not in post_merge_activations or layer_name not in target_activations:
            continue

        pre = pre_merge_activations[layer_name].float()    # Before last merge
        post = post_merge_activations[layer_name].float()   # After last merge
        target = target_activations[layer_name].float()      # Ideal target

        # Delta from last merge
        merge_delta = post - pre  # [samples, hidden_dim]

        # Gap remaining (what we still need)
        gap = target - post  # [samples, hidden_dim]

        # Average across samples to get direction vectors
        delta_dir = merge_delta.mean(dim=0)  # [hidden_dim]
        gap_dir = gap.mean(dim=0)            # [hidden_dim]

        # Normalise
        delta_norm = delta_dir / (delta_dir.norm() + 1e-8)
        gap_norm = gap_dir / (gap_dir.norm() + 1e-8)

        # Angle between the previous-merge and gap directions. The two unit
        # vectors span the 2D plane a Rodrigues-style rotation would act in;
        # only this descriptor is stored, not a full rotation matrix.
        cos_theta = torch.dot(delta_norm, gap_norm).clamp(-1, 1)
        sin_theta = torch.sqrt(1 - cos_theta ** 2)

        # Store as a simple rotation descriptor
        rotations[layer_name] = {
            "delta_direction": delta_norm,
            "gap_direction": gap_norm,
            "cos_theta": cos_theta.item(),
            "sin_theta": sin_theta.item(),
            "gap_magnitude": gap_dir.norm().item(),
        }

    return rotations
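

# Illustrative sketch (hypothetical helper): compute_arm_rotation stores only
# an angle descriptor. One way to turn it into an explicit rotation is the
# plane rotation (Rodrigues' formula generalised to n dims) that turns the
# unit delta direction u into the unit gap direction g and leaves the
# orthogonal complement of span{u, g} fixed:
#     R = I + (cos t - 1)(u u^T + v v^T) + sin t (v u^T - u v^T)
# where v is the unit component of g orthogonal to u. This is an assumption
# about how the descriptor could be used, not necessarily ARM's construction.
def _sketch_plane_rotation(delta_dir, gap_dir, eps: float = 1e-12):
    import numpy as np

    u = np.asarray(delta_dir, dtype=float)
    u = u / np.linalg.norm(u)
    g = np.asarray(gap_dir, dtype=float)
    g = g / np.linalg.norm(g)

    cos_t = float(np.clip(u @ g, -1.0, 1.0))
    perp = g - cos_t * u  # part of g orthogonal to u
    sin_t = float(np.linalg.norm(perp))
    n = u.shape[0]
    if sin_t < eps:
        return np.eye(n)  # (anti)parallel directions: plane undefined, no rotation

    v = perp / sin_t
    return (np.eye(n)
            + (cos_t - 1.0) * (np.outer(u, u) + np.outer(v, v))
            + sin_t * (np.outer(v, u) - np.outer(u, v)))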


def apply_arm_steering(
    weight_delta: torch.Tensor,
    rotation_info: dict,
    steering_strength: float = 0.5,
) -> torch.Tensor:
    """
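    Steer a weight delta using ARM rotation vectors.

    Instead of blindly projecting out previous merge directions
    (our old orthogonal projection), ARM STEERS the delta toward
    the remaining gap: the component along the previous merge
    direction d is partly moved onto the gap direction g,

        steered = delta + steering_strength * <delta, d> * (g - d)

    applied along whichever axis of the delta has size hidden_dim.

    Args:
        weight_delta: The raw delta from the current merge
        rotation_info: ARM rotation info for this layer
        steering_strength: How much to steer (0=no steering, 1=full)

    Returns:
        Steered weight delta (returned unchanged if no axis has size hidden_dim)
    """
    w = weight_delta.float()
    delta_dir = rotation_info["delta_direction"].to(device=w.device, dtype=w.dtype)
    gap_dir = rotation_info["gap_direction"].to(device=w.device, dtype=w.dtype)
    hidden_dim = delta_dir.shape[0]

    # Remove some of the previous-direction component
    # and add a gap-direction component instead
    shift = gap_dir - delta_dir

    if w.dim() == 1 and w.shape[0] == hidden_dim:
        prev_component = torch.dot(w, delta_dir)  # scalar
        steered = w + steering_strength * prev_component * shift
    elif w.dim() == 2 and w.shape[1] == hidden_dim:
        prev_component = w @ delta_dir  # [out]: one component per row
        steered = w + steering_strength * torch.outer(prev_component, shift)
    elif w.dim() == 2 and w.shape[0] == hidden_dim:
        prev_component = delta_dir @ w  # [in]: one component per column
        steered = w + steering_strength * torch.outer(shift, prev_component)
    else:
        # A flattened delta would not match the [hidden_dim] direction vectors
        return weight_delta

    return steered.to(weight_delta.dtype)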
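

# Illustrative sketch (hypothetical helper, plain Python): the steering rule
#     steered = delta + s * <delta, d> * (g - d)
# applied to a single vector. With s = 1 and unit d, the component a*d of the
# delta is replaced by a*g; a delta orthogonal to d is returned unchanged.
def _sketch_steer_vector(delta, d, g, strength: float) -> list:
    prev_component = sum(x * y for x, y in zip(delta, d))
    return [x + strength * prev_component * (gi - di)
            for x, di, gi in zip(delta, d, g)]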


# ============================================================================
# 3. OTMF: Transferability Masks via Optimal Transport (2511.19561)
# ============================================================================
#
# OTMF discovers which parts of each model are "transferable" (shared
# knowledge) vs "task-specific" (unique to that model).
#
# Transferable weights → safe to merge/average
# Task-specific weights → must be preserved carefully
#
# This replaces our MagMax "top 20% by magnitude" heuristic with a
# principled, data-driven approach.

def compute_transferability_masks(
    model: AutoModelForCausalLM,
    calibration_activations: dict,
    threshold: float = 0.3,
) -> dict:
    """
    Compute per-parameter transferability masks using activation variance.

    High activation variance across diverse inputs → parameter encodes
    task-specific knowledge (DON'T merge aggressively).

    Low activation variance → parameter encodes shared/general knowledge
    (safe to merge/average).

    This is a simplified, variance-based stand-in for OTMF's OT-based
    mask discovery.

    Args:
        model: The current merged model
        calibration_activations: Layer name → [samples, hidden_dim] activations
        threshold: Fraction of neurons (those with the highest variance)
            classified as task-specific

    Returns:
        Dict of param_name → bool mask (True = transferable/safe, False = task-specific/protect)
    """
    print("[otmf] Computing transferability masks...")

    masks = {}
    state = model.state_dict()

    # Compute per-neuron activation variance
    neuron_importance = {}
    for layer_name, acts in calibration_activations.items():
        # Variance across samples: high variance = this neuron is doing something specific
        variance = acts.float().var(dim=0)  # [hidden_dim]
        neuron_importance[layer_name] = variance

    # Map neuron importance to parameter importance
    for param_name, param in state.items():
        # Find the corresponding layer's importance
        layer_prefix = ".".join(param_name.split(".")[:4])  # e.g., model.layers.0.self_attn

        importance = None
        for layer_name, var in neuron_importance.items():
            if layer_prefix in layer_name:
                importance = var
                break

        if importance is None:
            # Default: mark everything as transferable (safe to merge)
            masks[param_name] = torch.ones(param.shape, dtype=torch.bool)
            continue

        # For 2D weights: importance determines which rows/columns to protect
        if param.dim() == 2:
            rows, cols = param.shape
            imp_size = importance.shape[0]

            # Threshold: the top `threshold` fraction (highest variance) is task-specific
            if importance.numel() == 0:
                masks[param_name] = torch.ones(param.shape, dtype=torch.bool)
            elif imp_size >= rows:
                # Importance covers the row dimension (e.g., 4096 importance, 4096×4096 weight)
                imp = importance[:rows]
                q = torch.quantile(imp.float(), 1.0 - threshold)
                row_mask = imp < q  # [rows]
                masks[param_name] = row_mask.unsqueeze(1).expand(rows, cols)
            elif imp_size >= cols:
                # Importance covers the column dimension (e.g., 4096 importance, 12288×4096 weight)
                # This happens for gate_proj, up_proj where rows = 3 × hidden_dim
                imp = importance[:cols]
                q = torch.quantile(imp.float(), 1.0 - threshold)
                col_mask = imp < q  # [cols]
                masks[param_name] = col_mask.unsqueeze(0).expand(rows, cols)
            else:
                # Importance doesn't match either dimension: default to transferable
                masks[param_name] = torch.ones(param.shape, dtype=torch.bool)
        else:
            # 1D params (biases, norms): default to transferable
            masks[param_name] = torch.ones(param.shape, dtype=torch.bool)

    transferable = sum(m.sum().item() for m in masks.values())
    total = sum(m.numel() for m in masks.values())
    print(f"[otmf] Transferability: {transferable / total:.1%} transferable, {1 - transferable / total:.1%} task-specific")

    return masks
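

# Illustrative sketch (hypothetical helper, NumPy): the quantile rule used
# above for a single layer. With threshold=0.3, neurons whose variance is at
# or above the 70th percentile are marked task-specific (False) and the rest
# transferable (True). np.quantile's default linear interpolation matches
# torch.quantile's default. np.var uses ddof=0 while torch's var defaults to
# the unbiased estimate; that rescales every variance equally, so the mask is
# the same.
def _sketch_variance_mask(activations, threshold: float = 0.3):
    import numpy as np

    acts = np.asarray(activations, dtype=float)  # [samples, hidden_dim]
    variance = acts.var(axis=0)                  # per-neuron variance
    q = np.quantile(variance, 1.0 - threshold)
    return variance < q                          # True = transferable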


def apply_masked_merge(
    target_state: dict,
    fused_state: dict,
    masks: dict,
    protect_strength: float = 0.8,
) -> dict:
    """
    Apply transferability masks during merge.

    For transferable weights: use the fused (merged) value
    For task-specific weights: preserve more of the original target value

    Args:
        target_state: Original target weights (before this merge)
        fused_state: Newly fused weights (after T&M/Theseus fusion)
        masks: Transferability masks (True = safe to change)
        protect_strength: How much to protect task-specific weights (0-1)

    Returns:
        Masked merged state dict
    """
    result = {}

    for key in fused_state:
        if key in masks and key in target_state:
            mask = masks[key].to(fused_state[key].device)
            original = target_state[key].to(fused_state[key].device)
            fused = fused_state[key]

            # Transferable: use fused value
            # Task-specific: blend more toward original
            blended = torch.where(
                mask,
                fused,  # Transferable → take merged value
                protect_strength * original + (1 - protect_strength) * fused,  # Protected
            )
            result[key] = blended
        else:
            result[key] = fused_state[key]

    protected_params = sum(1 for k in masks if not masks[k].all())
    print(f"[otmf] Applied masks: {protected_params} parameter tensors partially protected")

    return result
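

# Illustrative sketch (hypothetical helper, NumPy): the element-wise blend
# applied by apply_masked_merge. Transferable entries (mask True) take the
# fused value; protected entries keep protect_strength of the original value.
def _sketch_masked_blend(original, fused, mask, protect_strength: float = 0.8):
    import numpy as np

    original = np.asarray(original, dtype=float)
    fused = np.asarray(fused, dtype=float)
    protected = protect_strength * original + (1 - protect_strength) * fused
    return np.where(np.asarray(mask, dtype=bool), fused, protected)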


# ============================================================================
# 4. RAM: RL-Weight Disentanglement (2601.13572)
# ============================================================================
#
# RL-trained models (DeepSeek-R1, MiMo-7B-RL) have two types of knowledge:
#   - Shared: general language understanding (same as base model)
#   - RL-specific: reasoning patterns learned via GRPO/RLHF
#
# RAM separates these so we can merge the shared parts normally
# but PRESERVE the RL-specific parts that make these models special.

def disentangle_rl_weights(
    rl_model: AutoModelForCausalLM,
    base_model: AutoModelForCausalLM,
    rl_threshold: float = 0.1,
) -> tuple:
    """
    Separate RL-specific weights from shared/general weights.

    RL-specific = weights that changed significantly during RL training
    Shared = weights that are basically the same as base

    We identify RL-specific weights by looking at the magnitude of
    change from base model to RL model. Big changes mean RL learned
    something there, so don't average it away.

    Args:
        rl_model: The RL-trained model (e.g., DeepSeek-R1, MiMo-7B-RL)
        base_model: The base model before RL training
        rl_threshold: Relative change threshold for "RL-specific" classification

    Returns:
        Tuple of (shared_mask, rl_mask), both dicts of param_name → bool tensor
        shared_mask: True = this weight is shared (safe to merge normally)
        rl_mask: True = this weight is RL-specific (protect during merge)
    """
    print("[ram] Disentangling RL-specific vs shared weights...")

    rl_state = rl_model.state_dict()
    base_state = base_model.state_dict()

    shared_mask = {}
    rl_mask = {}

    total_params = 0
    rl_params = 0

    for key in rl_state:
        if key not in base_state or base_state[key].shape != rl_state[key].shape:
            # New or reshaped param (e.g., MTP head, resized embeddings): mark as RL-specific
            rl_mask[key] = torch.ones_like(rl_state[key], dtype=torch.bool)
            shared_mask[key] = torch.zeros_like(rl_state[key], dtype=torch.bool)
            rl_params += rl_state[key].numel()
            total_params += rl_state[key].numel()
            continue

        rl_w = rl_state[key].float()
        base_w = base_state[key].float()

        # Relative change: |rl - base| / (|base| + epsilon)
        change = (rl_w - base_w).abs()
        base_magnitude = base_w.abs() + 1e-8
        relative_change = change / base_magnitude

        # RL-specific: relative change > threshold
        is_rl = relative_change > rl_threshold
        rl_mask[key] = is_rl
        shared_mask[key] = ~is_rl

        rl_params += is_rl.sum().item()
        total_params += is_rl.numel()

    pct = rl_params / total_params * 100 if total_params > 0 else 0
    print(f"[ram] RL-specific: {rl_params:,} params ({pct:.1f}%)")
    print(f"[ram] Shared:      {total_params - rl_params:,} params ({100 - pct:.1f}%)")

    return shared_mask, rl_mask
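

# Illustrative sketch (hypothetical helper, plain Python): the relative-change
# rule used above, on three scalar weights. A small base weight (0.01 -> 0.02)
# is flagged as RL-specific, while a larger absolute change on a big weight
# (1.0 -> 1.05) is not, because the change is divided by |base|.
def _sketch_rl_flags(rl_weights, base_weights, rl_threshold: float = 0.1) -> list:
    return [abs(r - b) / (abs(b) + 1e-8) > rl_threshold
            for r, b in zip(rl_weights, base_weights)]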


def merge_with_rl_preservation(
    target_state: dict,
    source_state: dict,
    shared_mask: dict,
    rl_mask: dict,
    shared_alpha: float = 0.5,
    rl_alpha: float = 0.8,
) -> dict:
    """
    Merge source into target while preserving RL-specific weights.

    Shared weights: normal blending at shared_alpha
    RL-specific weights: stronger blending toward source (preserve RL knowledge)

    This prevents the RL reasoning capabilities from being diluted
    by averaging with target weights.

    Args:
        target_state: Current target model state
        source_state: RL model state to merge in
        shared_mask: Which params are shared (safe for normal merge)
        rl_mask: Which params are RL-specific (preserve with higher alpha)
        shared_alpha: Alpha for shared weights (normal)
        rl_alpha: Alpha for RL-specific weights (higher = preserve more RL knowledge)

    Returns:
        Merged state dict with the same keys as target_state
    """
    print(f"[ram] Merging with RL preservation (shared α={shared_alpha}, RL α={rl_alpha})...")

    result = {}
    for key in target_state:
        if key not in source_state:
            result[key] = target_state[key]
            continue

        target_w = target_state[key]
        source_w = source_state[key]

        if source_w.shape != target_w.shape:
            result[key] = target_state[key]
            continue

        if key in rl_mask and key in shared_mask and rl_mask[key].shape == target_w.shape:
            rl_m = rl_mask[key].to(target_w.device)
            # RL-specific: use higher alpha (preserve RL knowledge)
            # Shared: use normal alpha
            alpha_map = torch.where(rl_m, rl_alpha, shared_alpha)
            merged = alpha_map * source_w.to(target_w.device) + (1 - alpha_map) * target_w
        else:
            # No usable mask for this param: blend everything at shared_alpha
            merged = shared_alpha * source_w.to(target_w.device) + (1 - shared_alpha) * target_w
        # alpha_map is float32, so cast back to keep the target's dtype
        result[key] = merged.to(target_w.dtype)

    return result
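

# Illustrative sketch (hypothetical helper, plain Python): the per-element
# blending in merge_with_rl_preservation. Each weight becomes
#     alpha * source + (1 - alpha) * target
# with alpha = rl_alpha where the RL mask is True and shared_alpha elsewhere,
# so RL-specific weights end up closer to the RL source.
def _sketch_rl_blend(target, source, is_rl, shared_alpha: float = 0.5,
                     rl_alpha: float = 0.8) -> list:
    out = []
    for t, s, rl in zip(target, source, is_rl):
        alpha = rl_alpha if rl else shared_alpha
        out.append(alpha * s + (1 - alpha) * t)
    return out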


# ============================================================================
# 5. MERGEABILITY PRE-CHECK (2601.22285)
# ============================================================================
#
# Before spending GPU hours on a merge that might fail, check if the
# models are actually COMPATIBLE enough to merge.
#
# Mergeability score: 0.0 (definitely won't work) to 1.0 (should work great)

def compute_mergeability_score(
    source_activations: dict,
    target_activations: dict,
    source_config: ModelConfig,
) -> dict:
    """
    Predict how well a source model will merge into the target.

    Scores based on four factors:
    1. Activation similarity (cosine similarity of mean activations)
    2. Dimensional compatibility (how similar are the layer shapes)
    3. Architecture match (same arch = bonus)
    4. Vocabulary overlap with Qwen3

    Returns:
        Dict with individual scores and overall mergeability (0-1)
    """
    print(f"[mergeability] Scoring {source_config.name}...")

    scores = {}

    # --- Factor 1: Activation similarity ---
    cosine_sims = []
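    # Natural sort so "layers.2" precedes "layers.10" (plain sorted() is
    # lexicographic, which would break the proportional mapping below)
    def _natural_key(name):
        return [(0, int(p)) if p.isdigit() else (1, p) for p in name.split(".")]

    source_layers = sorted(source_activations.keys(), key=_natural_key)
    target_layers = sorted(target_activations.keys(), key=_natural_key)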

    # Match layers by position (proportional mapping)
    for i, tl in enumerate(target_layers):
        # Map target layer index to source layer index
        src_idx = int(i * len(source_layers) / len(target_layers))
        src_idx = min(src_idx, len(source_layers) - 1)
        sl = source_layers[src_idx]

        if sl in source_activations and tl in target_activations:
            s_mean = source_activations[sl].float().mean(dim=0)
            t_mean = target_activations[tl].float().mean(dim=0)

            # Pad to same dimension for cosine similarity
            max_dim = max(s_mean.shape[0], t_mean.shape[0])
            s_padded = torch.nn.functional.pad(s_mean, (0, max_dim - s_mean.shape[0]))
            t_padded = torch.nn.functional.pad(t_mean, (0, max_dim - t_mean.shape[0]))

            cos_sim = torch.nn.functional.cosine_similarity(
                s_padded.unsqueeze(0), t_padded.unsqueeze(0)
            ).item()
            cosine_sims.append(cos_sim)

    activation_score = np.mean(cosine_sims) if cosine_sims else 0.0
    scores["activation_similarity"] = float(activation_score)

    # --- Factor 2: Dimensional compatibility (vs. hard-coded target: 36 layers, hidden 4096) ---
    layer_ratio = min(source_config.layers, 36) / max(source_config.layers, 36)
    hidden_ratio = min(source_config.hidden_dim, 4096) / max(source_config.hidden_dim, 4096)
    dim_score = (layer_ratio + hidden_ratio) / 2
    scores["dimensional_compatibility"] = float(dim_score)

    # --- Factor 3: Architecture match ---
    arch_scores = {
        "transformer": 1.0,       # Same as Qwen3
        "transformer+mtp": 0.8,   # Close, just drop extras
        "hybrid_ssm": 0.5,        # Very different
    }
    arch_score = arch_scores.get(source_config.architecture, 0.3)
    scores["architecture_match"] = float(arch_score)

    # --- Factor 4: Vocab overlap with Qwen3 ---
    vocab_score = source_config.vocab_overlap_with_qwen3
    scores["vocab_overlap"] = float(vocab_score)

    # --- Overall: weighted average ---
    overall = (
        0.35 * activation_score +      # Most important: actual representation similarity
        0.25 * dim_score +              # Shape compatibility
        0.25 * arch_score +             # Architecture type
        0.15 * vocab_score              # Vocab overlap
    )
    scores["overall"] = float(overall)

    # --- Recommendation ---
    if overall >= 0.7:
        recommendation = "GO: standard T&M merge"
    elif overall >= 0.5:
        recommendation = "CAUTION: T&M merge with higher protection, have Theseus fallback ready"
    elif overall >= 0.3:
        recommendation = "RISKY: try Theseus first, distillation fallback"
    else:
        recommendation = "SKIP: use knowledge distillation instead"

    scores["recommendation"] = recommendation

    print(f"[mergeability] {source_config.name} score: {overall:.2f}")
    print(f"  Activation similarity: {activation_score:.2f}")
    print(f"  Dimensional compat:    {dim_score:.2f}")
    print(f"  Architecture match:    {arch_score:.2f}")
    print(f"  Vocab overlap:         {vocab_score:.2f}")
    print(f"  → {recommendation}")

    return scores
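

# Illustrative sketch (hypothetical helper, plain Python): the weighted score
# and thresholds of compute_mergeability_score for given factor values, using
# the same hard-coded target dimensions (36 layers, hidden size 4096). The
# inputs in the usage note are made-up numbers, not measurements.
def _sketch_overall_score(activation_sim: float, src_layers: int, src_hidden: int,
                          architecture: str, vocab_overlap: float) -> tuple:
    arch_scores = {"transformer": 1.0, "transformer+mtp": 0.8, "hybrid_ssm": 0.5}
    layer_ratio = min(src_layers, 36) / max(src_layers, 36)
    hidden_ratio = min(src_hidden, 4096) / max(src_hidden, 4096)
    dim_score = (layer_ratio + hidden_ratio) / 2
    arch_score = arch_scores.get(architecture, 0.3)
    overall = (0.35 * activation_sim + 0.25 * dim_score
               + 0.25 * arch_score + 0.15 * vocab_overlap)
    if overall >= 0.7:
        label = "GO"
    elif overall >= 0.5:
        label = "CAUTION"
    elif overall >= 0.3:
        label = "RISKY"
    else:
        label = "SKIP"
    return overall, label


# Usage (hypothetical inputs): a 44-layer, 3072-wide hybrid SSM with 0.4
# activation similarity and 0.6 vocab overlap scores about 0.55 (CAUTION):
#     _sketch_overall_score(0.4, 44, 3072, "hybrid_ssm", 0.6)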