File size: 40,100 Bytes
5d61448
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
"""
Sequential Merge Orchestrator — chains 4 merges with protection.

This is the brain of the td_lang engine. It runs each merge in order:
    1. Load source model
    2. Inject canary fact into source
    3. Extract activations from both models
    4. Compute transport plans (P and Q matrices)
    5. Fuse weights using optimal transport
    6. Validate merged model (canary recall, perplexity, thinking mode)
    7. Apply sequential merge protection before next merge
    8. Checkpoint

Protection between merges (finding #13):
    - MagMax: Protect the top 20% of parameters by magnitude (they carry critical knowledge)
    - Orthogonal Projection: Project new merge deltas perpendicular to previous ones
    - Time-Aware Scaling: scale = 1/sqrt(merge_index + 1)

Kill criteria: >10% performance drop on any test → abort merge.
Findings: #13, #22, #25
"""

import os
import gc
import copy
import torch
import numpy as np
from pathlib import Path
from typing import Optional
from transformers import AutoModelForCausalLM, AutoTokenizer

from .config import (
    MergeConfig, ModelConfig, TARGET, SOURCES,
    CANARY_FACTS, DEMO_STAGES, FULL_STAGES,
)
from .canary import inject_canary, test_all_canaries
from .transport import (
    setup_tm_repo,
    load_calibration_data,
    extract_activations,
    compute_transport_plans,
    fuse_weights,
)
from .validate import validate_merged_model, compute_perplexity
from .techniques import (
    compute_mergeability_score,
    compute_transferability_masks,
    apply_masked_merge,
    disentangle_rl_weights,
    merge_with_rl_preservation,
    compute_arm_rotation,
    apply_arm_steering,
    transport_task_vector_theseus,
    compute_procrustes_alignment,
)


# ============================================================================
# SEQUENTIAL MERGE PROTECTION
# ============================================================================

class MergeProtection:
    """
    Protects previously merged knowledge from being overwritten.

    Think of it like this: after merging DeepSeek into Qwen3, we have
    a "direction" in weight space that represents that merge. When we
    then merge MiMo, we want MiMo's changes to go in a DIFFERENT direction,
    not overwrite DeepSeek's contribution.

    Core mechanisms:
    1. MagMax: Top 20% magnitude params are "locked" — new merges only get
       10% of the proposed change on them
    2. Orthogonal Projection: New deltas are projected perpendicular to previous deltas
    3. Time-Aware Scaling: Each successive merge gets a smaller alpha (1/sqrt(n+1))

    When enabled via ``cfg.use_arm_steering`` / ``cfg.use_otmf_masks``, ARM
    rotations and OTMF transferability masks computed in ``after_merge`` are
    applied to the following merge as well.
    """

    # MagMax protects entries whose magnitude is at or above this quantile.
    _MAGMAX_QUANTILE = 0.8
    # Fraction of the proposed delta that MagMax-protected entries still receive.
    _MAGMAX_DAMPING = 0.1
    # How many previous deltas per parameter are kept for orthogonal projection.
    _MAX_PREVIOUS_DELTAS = 2

    def __init__(self, cfg: MergeConfig):
        self.cfg = cfg
        self.previous_deltas = {}  # key → list of float32 CPU delta tensors from previous merges
        self.magnitude_masks = {}  # key → bool mask of top-k magnitude params
        self.arm_rotations = {}    # ARM: layer → rotation info from last merge
        self.otmf_masks = {}       # OTMF: param → transferability mask
        self.merge_count = 0

    @staticmethod
    def _magnitude_threshold(magnitudes: torch.Tensor, q: float) -> torch.Tensor:
        """
        Return the ``q``-quantile of ``magnitudes`` as a 0-d float32 tensor.

        Uses the same linear interpolation as ``torch.quantile`` but is built
        on ``torch.kthvalue``. ``torch.quantile`` refuses inputs with more than
        2**24 elements, which large embedding / MLP matrices exceed.
        """
        flat = magnitudes.flatten().float()
        n = flat.numel()
        pos = q * (n - 1)
        lo = int(pos)  # pos >= 0, so int() == floor()
        frac = pos - lo
        lo_val = torch.kthvalue(flat, lo + 1).values  # kthvalue is 1-indexed
        if frac == 0 or lo + 1 >= n:
            return lo_val
        hi_val = torch.kthvalue(flat, lo + 2).values
        return lo_val + (hi_val - lo_val) * frac

    @staticmethod
    def _project_out_previous(delta: torch.Tensor, previous: list) -> torch.Tensor:
        """
        Remove from ``delta`` its component along each tensor in ``previous``.

        Projections are applied one after another. If the previous deltas are
        not mutually orthogonal, the result is exactly orthogonal only to the
        last one. Work is done in float32 and cast back to ``delta.dtype``
        once at the end.
        """
        delta_flat = delta.flatten().float()
        for prev_delta in previous:
            # History is stored on the CPU (see after_merge); match delta's device.
            prev_flat = prev_delta.flatten().float().to(delta_flat.device)
            norm_sq = torch.dot(prev_flat, prev_flat)
            if norm_sq > 1e-10:
                dot = torch.dot(delta_flat, prev_flat)
                delta_flat = delta_flat - (dot / norm_sq) * prev_flat
        return delta_flat.reshape(delta.shape).to(delta.dtype)

    def before_merge(
        self,
        target_model: AutoModelForCausalLM,
        source_config: ModelConfig,
    ) -> float:
        """
        Prepare protection before a merge. Returns adjusted alpha.

        Called BEFORE each merge to:
        1. Calculate time-aware alpha scaling (if ``cfg.time_aware_scaling``)
        2. Compute magnitude masks (MagMax), only from the second merge onward

        Args:
            target_model: Model whose ``state_dict()`` is ranked for MagMax.
            source_config: Provides the base ``merge_alpha``.

        Returns:
            ``merge_alpha`` scaled by ``1/sqrt(merge_count + 1)`` when
            time-aware scaling is on, otherwise ``merge_alpha`` unchanged.
        """
        # Time-aware scaling: each merge gets less aggressive
        if self.cfg.time_aware_scaling:
            scale = 1.0 / np.sqrt(self.merge_count + 1)
            adjusted_alpha = source_config.merge_alpha * scale
            print(f"[protect] Time-aware scaling: {source_config.merge_alpha:.2f} Γ— {scale:.3f} = {adjusted_alpha:.3f}")
        else:
            adjusted_alpha = source_config.merge_alpha

        # MagMax: identify top 20% magnitude parameters to protect
        if self.cfg.use_magmax and self.merge_count > 0:
            print("[protect] Computing MagMax masks (protecting top 20% by magnitude)...")
            state = target_model.state_dict()
            for key, param in state.items():
                # Scalars, empty tensors and bool/int buffers cannot be ranked
                # (abs() is undefined for bool) and must not be damped.
                if param.dim() < 1 or param.numel() == 0 or not param.is_floating_point():
                    continue
                magnitudes = param.abs()
                threshold = self._magnitude_threshold(magnitudes, self._MAGMAX_QUANTILE)
                self.magnitude_masks[key] = magnitudes >= threshold

        return adjusted_alpha

    def apply_protection(
        self,
        target_state: dict,
        pre_merge_state: dict,
        key: str,
    ) -> torch.Tensor:
        """
        Apply all protection mechanisms to a fused parameter.

        Called AFTER each parameter is fused, to constrain the change
        ``target_state[key] - pre_merge_state[key]``.

        Protection stack (applied in order):
        1. ARM steering (2602.03237), if enabled and rotations exist.
           Otherwise, orthogonal projection against previous deltas, if enabled.
        2. OTMF masks (2511.19561): non-transferable entries get their delta
           scaled by ``1 - cfg.otmf_protect_strength``.
        3. MagMax: masked (top-magnitude) entries keep 10% of their delta.

        Returns:
            ``pre_merge_state[key] + constrained_delta``.
        """
        fused = target_state[key]
        original = pre_merge_state[key]
        delta = fused - original

        # --- ARM Steering (replaces orthogonal projection when active) ---
        if self.cfg.use_arm_steering and self.arm_rotations:
            # Match on the first four dotted key components. NOTE(review): this
            # is a substring test against the layer names produced by
            # compute_arm_rotation; confirm that naming scheme cannot alias
            # (e.g. "layers.1" vs "layers.10").
            layer_prefix = ".".join(key.split(".")[:4])
            for layer_name, rotation_info in self.arm_rotations.items():
                if layer_prefix in layer_name:
                    delta = apply_arm_steering(
                        delta, rotation_info,
                        steering_strength=self.cfg.arm_steering_strength,
                    )
                    break

        # --- Orthogonal Projection (legacy fallback) ---
        elif self.cfg.use_orthogonal_projection and key in self.previous_deltas:
            delta = self._project_out_previous(delta, self.previous_deltas[key])

        # --- OTMF Mask Protection ---
        if self.cfg.use_otmf_masks and key in self.otmf_masks:
            mask = self.otmf_masks[key].to(delta.device)
            delta = torch.where(
                mask,
                delta,  # Transferable → allow full change
                delta * (1.0 - self.cfg.otmf_protect_strength),  # Protected → reduced
            )

        # --- MagMax Protection (extra safety layer) ---
        if self.cfg.use_magmax and key in self.magnitude_masks:
            mask = self.magnitude_masks[key].to(delta.device)
            delta = torch.where(mask, delta * self._MAGMAX_DAMPING, delta)

        # Apply constrained delta
        return original + delta

    def after_merge(
        self,
        target_model: AutoModelForCausalLM,
        pre_merge_state: dict,
        pre_merge_activations: dict = None,
        post_merge_activations: dict = None,
    ):
        """
        Record the merge delta and compute protections for next merge.

        Called AFTER each merge completes successfully.

        - Stores the float32 delta (on the CPU) of every parameter that
          changed by more than 1e-8. At most the last two deltas per
          parameter are kept.
        - Computes ARM rotation vectors, if enabled and both activation
          dicts are provided.
        - Computes OTMF transferability masks, if enabled and post-merge
          activations are provided.
        - Increments ``merge_count``.
        """
        current_state = target_model.state_dict()

        for key, current in current_state.items():
            if key not in pre_merge_state:
                continue
            delta = current.float() - pre_merge_state[key].float()
            # numel() guard: max() of an empty tensor raises.
            if delta.numel() and delta.abs().max() > 1e-8:
                history = self.previous_deltas.setdefault(key, [])
                if len(history) >= self._MAX_PREVIOUS_DELTAS:
                    history.pop(0)
                history.append(delta.cpu())

        # --- Compute ARM rotations for next merge ---
        if self.cfg.use_arm_steering and pre_merge_activations and post_merge_activations:
            print("[protect] Computing ARM rotation vectors for next merge...")
            self.arm_rotations = compute_arm_rotation(
                pre_merge_activations,
                post_merge_activations,
                post_merge_activations,  # Target = current state (for gap calculation)
            )

        # --- Compute OTMF masks for next merge ---
        if self.cfg.use_otmf_masks and post_merge_activations:
            print("[protect] Computing OTMF transferability masks...")
            self.otmf_masks = compute_transferability_masks(
                target_model,
                post_merge_activations,
                threshold=self.cfg.otmf_threshold,
            )

        self.merge_count += 1
        print(f"[protect] Recorded merge delta #{self.merge_count} (ARM + OTMF ready for next)")


# ============================================================================
# MAIN ORCHESTRATOR
# ============================================================================

def is_vision_param(key: str, cfg: MergeConfig) -> bool:
    """
    Tell whether ``key`` names a vision-encoder parameter.

    Qwen3-VL-8B has a ViT vision encoder + merger projection on top of the
    language model. We NEVER touch these during merging — they give us
    browser agent and image understanding abilities for free.

    A key counts as "vision" when it starts with any entry of
    ``cfg.vision_skip_prefixes`` (e.g. "visual." or "merger."). Language
    params start with "model.layers." or "model.embed_tokens." etc.

    Returns:
        True if some configured prefix matches, False otherwise (including
        when no prefixes are configured).
    """
    return any(key.startswith(pfx) for pfx in cfg.vision_skip_prefixes)


def get_source_by_stage(stage_name: str) -> Optional[ModelConfig]:
    """
    Look up a source model config by (case-insensitive) stage name.

    The stage names map positionally onto ``SOURCES``:
    deepseek → 0, mimo → 1, llama → 2, falcon → 3.

    Returns:
        The matching ``SOURCES`` entry, or None for an unknown stage name or
        when ``SOURCES`` is too short to contain that position.
    """
    stage_order = ("deepseek", "mimo", "llama", "falcon")
    wanted = stage_name.lower()
    if wanted not in stage_order:
        return None
    position = stage_order.index(wanted)
    return SOURCES[position] if position < len(SOURCES) else None


def load_model(config: ModelConfig, cfg: MergeConfig) -> tuple:
    """
    Load a model and its tokenizer/processor.

    Vision models (``config.architecture == "transformer+vision"``) are loaded
    with ``Qwen3VLForConditionalGeneration`` and ``AutoProcessor``. The
    processor's ``tokenizer`` attribute is returned when present, otherwise
    the processor itself. If that class cannot be imported from
    ``transformers``, loading falls back to ``AutoTokenizer`` +
    ``AutoModelForCausalLM``.

    Args:
        config: Model to load (uses ``name``, ``hf_id``, ``architecture``,
            ``trust_remote_code``).
        cfg: Merge settings (uses ``dtype``, ``attn_implementation``,
            ``device_map``, ``vision_skip_prefixes``).

    Returns:
        ``(model, tokenizer)``.
    """
    print(f"\n[merge] Loading {config.name} ({config.hf_id})...")
    torch_dtype = getattr(torch, cfg.dtype)

    # Qwen3-VL uses a processor (handles both text + vision), not just a tokenizer
    if config.architecture == "transformer+vision":
        # Guard ONLY the import. Errors raised while actually loading
        # (transformers can raise ImportError for a missing optional
        # dependency) must propagate, not be mistaken for "class unavailable".
        try:
            from transformers import Qwen3VLForConditionalGeneration, AutoProcessor
        except ImportError:
            print("[merge] Qwen3VLForConditionalGeneration not available, falling back to AutoModel")
        else:
            processor = AutoProcessor.from_pretrained(
                config.hf_id,
                trust_remote_code=config.trust_remote_code,
            )
            model = Qwen3VLForConditionalGeneration.from_pretrained(
                config.hf_id,
                torch_dtype=torch_dtype,
                attn_implementation=cfg.attn_implementation,
                device_map=cfg.device_map,
                trust_remote_code=config.trust_remote_code,
            )
            # Use the tokenizer from the processor for text operations
            tokenizer = processor.tokenizer if hasattr(processor, 'tokenizer') else processor

            total_params = sum(p.numel() for p in model.parameters())
            print(f"[merge] Loaded {config.name} (VL model): {total_params / 1e9:.1f}B params")

            # Count vision vs language params
            vision_params = sum(
                p.numel() for n, p in model.named_parameters()
                if any(n.startswith(pfx) for pfx in cfg.vision_skip_prefixes)
            )
            lang_params = total_params - vision_params
            print(f"[merge]   Language: {lang_params / 1e9:.1f}B  |  Vision: {vision_params / 1e9:.1f}B")

            return model, tokenizer

    # Standard text-only models (and the vision fallback)
    tokenizer = AutoTokenizer.from_pretrained(
        config.hf_id,
        trust_remote_code=config.trust_remote_code,
    )

    model = AutoModelForCausalLM.from_pretrained(
        config.hf_id,
        torch_dtype=torch_dtype,
        attn_implementation=cfg.attn_implementation,
        device_map=cfg.device_map,
        trust_remote_code=config.trust_remote_code,
    )

    print(f"[merge] Loaded {config.name}: {sum(p.numel() for p in model.parameters()) / 1e9:.1f}B params")
    return model, tokenizer


def save_checkpoint(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    stage_name: str,
    cfg: MergeConfig,
):
    """
    Persist model and tokenizer after a successful merge stage.

    Writes both into ``<cfg.checkpoint_dir>/after_<stage_name>`` (created if
    needed) via their ``save_pretrained`` methods, model first.

    Returns:
        The checkpoint directory as a string.
    """
    target_dir = Path(cfg.checkpoint_dir).joinpath(f"after_{stage_name}")
    target_dir.mkdir(parents=True, exist_ok=True)

    print(f"[merge] Saving checkpoint to {target_dir}...")
    for artifact in (model, tokenizer):
        artifact.save_pretrained(target_dir)
    print(f"[merge] Checkpoint saved: {target_dir}")

    return str(target_dir)


# ============================================================================
# RESIDUAL BANK β€” Save what was lost during each merge
# ============================================================================

class ResidualBank:
    """
    Saves the knowledge that gets lost during each merge so it can
    be recovered later.

    When we blend at alpha=0.10:
        merged = target + alpha * M * (transported - target)

    We LOSE:
        target_residual = target_original - merged  (what target lost)
        source_residual = source_original - merged  (what source lost)

    These residuals are saved to disk. Later they can be:
    1. Fed back during the healing fine-tune (as training signal)
    2. Re-injected via a small LoRA adapter
    3. Used to diagnose which merge caused a specific knowledge loss
    4. Re-applied at a lower alpha if we want more of that model

    Think of it like saving the sawdust when you cut wood — you might
    need to glue some of it back later.

    Note: ``residual_index`` lives only in memory. ``get_healing_targets``
    therefore only sees stages saved through this instance.
    """

    # Residuals whose mean |value| does not exceed this are not saved.
    _MIN_LOSS = 1e-6
    # Number of largest per-parameter losses kept in residual_stats.json.
    _TOP_LOSSES = 20

    def __init__(self, cfg: MergeConfig):
        self.cfg = cfg
        self.residual_dir = Path(cfg.checkpoint_dir) / "residuals"
        self.residual_dir.mkdir(parents=True, exist_ok=True)
        self.residual_index = {}  # stage → {path, stats}

    @staticmethod
    def _layer_name(key: str) -> str:
        """Return the first four dot-separated components of a parameter key."""
        return ".".join(key.split(".")[:4])

    def _record_residual(
        self,
        key: str,
        reference: torch.Tensor,
        merged_w: torch.Tensor,
        residuals: dict,
        losses: dict,
        loss_by_layer: dict,
    ) -> bool:
        """
        Store ``reference - merged_w`` under ``key`` if it is large enough.

        Saved residuals are cast to bf16 on the CPU. Their float32 mean
        |residual| is recorded in ``losses`` and accumulated per layer into
        ``loss_by_layer``.

        Returns:
            False if the shapes differ, so no residual can be formed (the
            subtraction would crash or silently broadcast); True otherwise.
        """
        if reference.shape != merged_w.shape:
            return False
        residual = reference.to(merged_w.device).float() - merged_w
        loss = residual.abs().mean().item()
        if loss > self._MIN_LOSS:  # Only save meaningful residuals
            residuals[key] = residual.to(torch.bfloat16).cpu()
            losses[key] = loss
            layer = self._layer_name(key)
            loss_by_layer[layer] = loss_by_layer.get(layer, 0.0) + loss
        return True

    def save_residuals(
        self,
        stage_name: str,
        pre_merge_target_state: dict,
        source_state: dict,
        post_merge_state: dict,
        source_config: ModelConfig,
    ):
        """
        Compute and save what was lost from both target and source.

        Saves, under ``<residual_dir>/<stage_name>/``:
        - target_residual.pt: ``pre_merge_target - merged`` per param
        - source_residual.pt: ``source - merged`` per param
        - residual_stats.json: per-layer and total losses (sums of per-param
          mean |residual|), the largest losses, and how many keys were
          skipped because their shape differed from the merged weight.

        Only keys present in ``post_merge_state`` are considered.
        """
        import json

        stage_dir = self.residual_dir / stage_name
        stage_dir.mkdir(parents=True, exist_ok=True)

        target_residual, source_residual = {}, {}
        target_losses, source_losses = {}, {}
        target_by_layer, source_by_layer = {}, {}
        mismatched = {"target": 0, "source": 0}

        for key, merged in post_merge_state.items():
            merged_w = merged.float()

            # What the target lost
            if key in pre_merge_target_state:
                ok = self._record_residual(
                    key, pre_merge_target_state[key], merged_w,
                    target_residual, target_losses, target_by_layer,
                )
                if not ok:
                    mismatched["target"] += 1

            # What the source lost (what didn't make it into the merge). Raw
            # source weights can have a different width than the target.
            if key in source_state:
                ok = self._record_residual(
                    key, source_state[key], merged_w,
                    source_residual, source_losses, source_by_layer,
                )
                if not ok:
                    mismatched["source"] += 1

        # Find the biggest losses (most knowledge dropped); stable sort keeps
        # target entries ahead of source entries on ties.
        all_losses = [
            {"param": key, "side": "target", "loss": loss}
            for key, loss in target_losses.items()
        ]
        all_losses += [
            {"param": key, "side": "source", "loss": loss}
            for key, loss in source_losses.items()
        ]
        all_losses.sort(key=lambda entry: entry["loss"], reverse=True)

        stats = {
            "stage": stage_name,
            "source_model": source_config.name,
            "target_loss_by_layer": target_by_layer,
            "source_loss_by_layer": source_by_layer,
            "total_target_loss": sum(target_losses.values(), 0.0),
            "total_source_loss": sum(source_losses.values(), 0.0),
            "biggest_losses": all_losses[: self._TOP_LOSSES],
            "shape_mismatch_skipped": mismatched,
        }

        # Save to disk
        torch.save(target_residual, stage_dir / "target_residual.pt")
        torch.save(source_residual, stage_dir / "source_residual.pt")
        with open(stage_dir / "residual_stats.json", "w", encoding="utf-8") as f:
            json.dump(stats, f, indent=2, default=str)

        self.residual_index[stage_name] = {
            "path": str(stage_dir),
            "target_params_saved": len(target_residual),
            "source_params_saved": len(source_residual),
            "total_target_loss": stats["total_target_loss"],
            "total_source_loss": stats["total_source_loss"],
        }

        print(f"[residual] Saved residuals for {stage_name}:")
        print(f"  Target lost: {len(target_residual)} params (total loss: {stats['total_target_loss']:.4f})")
        print(f"  Source lost: {len(source_residual)} params (total loss: {stats['total_source_loss']:.4f})")
        if all_losses:
            top = all_losses[0]
            print(f"  Top loss: {top['param']} ({top['side']}, {top['loss']:.4f})")
        if mismatched["target"] or mismatched["source"]:
            print(f"  Skipped (shape mismatch): target={mismatched['target']}, source={mismatched['source']}")
        print(f"  Saved to: {stage_dir}")

    def load_residuals(self, stage_name: str) -> tuple:
        """
        Load saved residuals for a stage.

        Returns:
            (target_residual_dict, source_residual_dict)
        """
        stage_dir = self.residual_dir / stage_name
        target_residual = torch.load(stage_dir / "target_residual.pt", weights_only=True)
        source_residual = torch.load(stage_dir / "source_residual.pt", weights_only=True)
        return target_residual, source_residual

    def reinject_residuals(
        self,
        model: AutoModelForCausalLM,
        stage_name: str,
        side: str = "both",
        strength: float = 0.3,
    ) -> AutoModelForCausalLM:
        """
        Re-inject saved residuals back into a model.

        This adds back some of what was lost. Use a low strength (0.1-0.3)
        to gently recover knowledge without undoing the merge.

        Args:
            model: The model to inject into
            stage_name: Which merge stage's residuals to use
            side: "target", "source", or "both"
            strength: How much to add back (0=nothing, 1=full residual)

        Returns:
            ``model``, with ``param + strength * residual`` loaded for every
            residual key present in its state dict.

        Raises:
            ValueError: if ``side`` is not one of the three accepted values.
        """
        if side not in ("target", "source", "both"):
            raise ValueError(f"side must be 'target', 'source' or 'both', got {side!r}")

        print(f"[residual] Re-injecting {stage_name} residuals (side={side}, strength={strength})...")

        target_residual, source_residual = self.load_residuals(stage_name)
        selected = []
        if side in ("target", "both"):
            selected.append(target_residual)
        if side in ("source", "both"):
            selected.append(source_residual)

        state = model.state_dict()
        injected = 0
        for residuals in selected:
            for key, residual in residuals.items():
                if key in state:
                    param = state[key]
                    state[key] = param + strength * residual.to(device=param.device, dtype=param.dtype)
                    injected += 1

        model.load_state_dict(state)
        print(f"[residual] Re-injected {injected} params at {strength:.0%} strength")
        return model

    def get_healing_targets(self, top_n: int = 50) -> list:
        """
        Get the parameters with the biggest losses across all merges
        recorded in ``residual_index``.

        These are the params that the healing fine-tune should focus on.
        Feed this to the LoRA target_modules to make healing smarter.

        Returns:
            Sorted list of module names ending in ``_proj`` (q_proj, up_proj,
            ...) found in the ``top_n`` largest losses.
        """
        import json

        all_losses = []
        for stage_name in self.residual_index:
            stats_file = self.residual_dir / stage_name / "residual_stats.json"
            if not stats_file.exists():
                continue
            with open(stats_file, encoding="utf-8") as f:
                stats = json.load(f)
            for loss in stats.get("biggest_losses", []):
                loss["stage"] = stage_name
                all_losses.append(loss)

        all_losses.sort(key=lambda entry: entry["loss"], reverse=True)

        # Extract the module type (q_proj, k_proj, gate_proj, ...) for LoRA targeting
        target_modules = {
            part
            for loss in all_losses[:top_n]
            for part in loss["param"].split(".")
            if part.endswith("_proj")
        }

        print(f"[residual] Top healing targets (from {len(all_losses)} total losses):")
        for loss in all_losses[:5]:
            print(f"  {loss['param']} ({loss['side']}, stage={loss['stage']}, loss={loss['loss']:.4f})")
        print(f"  β†’ Suggested LoRA targets: {sorted(target_modules)}")

        # Sorted for a deterministic order (set iteration order varies per run).
        return sorted(target_modules)


def _release_merge_memory() -> None:
    """Run the garbage collector and, if CUDA is available, release cached GPU memory."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def _mean_activation_change(post_activations, pre_activations):
    """
    Return ``1 - mean cosine similarity`` between post- and pre-merge activations.

    For each key present in both mappings, the activations are averaged over
    dim 0 and compared with ``torch.nn.functional.cosine_similarity``. Returns
    ``None`` when no key is shared: an empty comparison says nothing about
    whether the merge changed the model.
    """
    scores = []
    for key in post_activations:
        if key not in pre_activations:
            continue
        cos = torch.nn.functional.cosine_similarity(
            post_activations[key].float().mean(0, keepdim=True),
            pre_activations[key].float().mean(0, keepdim=True),
        )
        scores.append(cos.item())
    if not scores:
        return None
    return 1.0 - np.mean(scores)


def run_single_merge(
    target_model: AutoModelForCausalLM,
    target_tokenizer: AutoTokenizer,
    source_config: ModelConfig,
    cfg: MergeConfig,
    protection: MergeProtection,
    residual_bank: ResidualBank = None,
    calibration_data: list = None,
    baseline_perplexity: float = None,
    merged_sources: list = None,
) -> dict:
    """
    Run a single merge: source -> target.

    Full pipeline for one merge step:
    1. Load source model
    2. Inject canary into source (if the stage has a canary fact)
    3. Load calibration data (if not provided)
    4. Extract activations from both models, then run the optional
       mergeability pre-check. A low score returns early with status
       ``"skipped_low_mergeability"`` before any target weight is changed.
    5. Compute transport plans
    6. Apply pre-merge protection (time-adjusted alpha)
    7. Fuse weights via the RAM RL-preserving path or standard T&M, with an
       optional Theseus fallback for high-risk sources
    8. Apply post-merge protection, record the merge, save residuals
    9. Free source-side memory
    10. Validate

    Every fallback restarts from the saved pre-merge weights. A failed
    attempt therefore never leaves partially written weights for the next
    attempt to build on.

    Args:
        target_model: Model being merged into.
        target_tokenizer: Tokenizer of the target model.
        source_config: Config of the source model to merge in.
        cfg: Global merge configuration.
        protection: Sequential-merge protection. ``before_merge``,
            ``apply_protection`` and ``after_merge`` are called on it.
        residual_bank: If given, this merge's residuals are saved to it.
        calibration_data: Calibration samples. Loaded from ``cfg`` if None.
        baseline_perplexity: Passed through to validation.
        merged_sources: Running list of merged stage names. This stage's name
            is appended to it (mutating the caller's list) before validation.

    Returns:
        Dict containing:
            - ``stage``
            - ``status``: ``"passed"``, ``"failed"`` or
              ``"skipped_low_mergeability"``
            - ``validation``
            - ``checkpoint``
            - when applicable: ``mergeability``, ``merged_sources`` and
              ``fallback``
    """
    if merged_sources is None:
        merged_sources = []

    stage_name = source_config.name
    print(f"\n{'=' * 70}")
    print(f"MERGE STAGE: {stage_name} β†’ target")
    print(f"Risk level: {source_config.merge_risk.upper()}")
    print(f"{'=' * 70}")

    result = {
        "stage": stage_name,
        "status": "pending",
        "validation": None,
        "checkpoint": None,
    }

    # --- Step 1: Load source model ---
    source_model, source_tokenizer = load_model(source_config, cfg)

    # --- Step 2: Inject canary into source ---
    if stage_name in CANARY_FACTS:
        print(f"\n[merge] Injecting canary fact into {stage_name}...")
        source_model = inject_canary(source_model, source_tokenizer, stage_name)

    # --- Step 3: Load calibration data (if not provided) ---
    if calibration_data is None:
        calibration_data = load_calibration_data(cfg, target_tokenizer)

    # --- Step 4: Extract two-sided activations (pre + post per projection) ---
    print(f"\n[merge] Extracting source activations (two-sided)...")
    source_activations = extract_activations(source_model, calibration_data)

    print(f"\n[merge] Extracting target activations (two-sided)...")
    pre_merge_target_activations = extract_activations(target_model, calibration_data)

    # --- Step 4.5: Mergeability pre-check (2601.22285) ---
    if cfg.use_mergeability_check:
        mergeability = compute_mergeability_score(
            source_activations, pre_merge_target_activations, source_config
        )
        result["mergeability"] = mergeability

        if mergeability["overall"] < cfg.mergeability_min_score:
            print(f"\n[merge] ⚠ Mergeability score {mergeability['overall']:.2f} below threshold {cfg.mergeability_min_score}")
            print(f"[merge] β†’ {mergeability['recommendation']}")
            result["status"] = "skipped_low_mergeability"
            if "distillation_fallback" in source_config.special_handling:
                result["fallback"] = "distillation"
            del source_model, source_activations, pre_merge_target_activations
            _release_merge_memory()
            return result

    # --- Step 5: Compute transport plans ---
    transport_plans = compute_transport_plans(
        source_activations, pre_merge_target_activations, cfg
    )

    # --- Step 5.5: RAM RL-weight disentanglement (2601.13572) ---
    use_ram = (
        cfg.use_ram_disentangle
        and source_config.architecture in ("transformer", "transformer+mtp")
        and source_config.merge_risk in ("low", "medium")
        and any(kw in source_config.name.lower() for kw in ["r1", "rl", "rlhf", "grpo"])
    )

    # --- Step 6: Pre-merge protection ---
    adjusted_alpha = protection.before_merge(target_model, source_config)

    # Override source alpha with time-adjusted value
    source_config_adjusted = copy.copy(source_config)
    source_config_adjusted.merge_alpha = adjusted_alpha

    # CPU snapshot of the target. Used by protection and as the clean restart
    # point for every fallback below.
    pre_merge_state = {k: v.clone().cpu() for k, v in target_model.state_dict().items()}

    # --- Step 7: Fuse weights ---
    if use_ram:
        # RAM path: disentangle RL weights, merge with preservation
        print(f"\n[merge] Using RAM RL-preservation for {stage_name}...")
        base_model = None
        ram_error = None
        try:
            # Try loading the base (pre-RL) model for disentanglement
            base_hf_id = source_config.hf_id.replace("-RL", "").replace("-R1-0528", "")
            print(f"[merge] Loading base model for RAM: {base_hf_id}")
            base_model = AutoModelForCausalLM.from_pretrained(
                base_hf_id,
                torch_dtype=getattr(torch, cfg.dtype),
                device_map=cfg.device_map,
                trust_remote_code=source_config.trust_remote_code,
            )
            shared_mask, rl_mask = disentangle_rl_weights(
                source_model, base_model, cfg.ram_rl_threshold
            )
            # Fuse with RL preservation
            target_state = merge_with_rl_preservation(
                target_model.state_dict(),
                source_model.state_dict(),
                shared_mask, rl_mask,
                shared_alpha=cfg.ram_shared_alpha * (adjusted_alpha / source_config.merge_alpha),
                rl_alpha=cfg.ram_rl_alpha,
            )
            target_model.load_state_dict(target_state)
        except Exception as e:  # best effort: any RAM failure falls back to T&M
            # Keep only the message. Holding the exception would keep its
            # traceback frames, and the base model they reference, alive.
            ram_error = str(e)
        # Drop the base model whether or not RAM succeeded.
        del base_model

        if ram_error is None:
            print(f"[merge] RAM merge complete for {stage_name}")
        else:
            print(f"[merge] RAM failed ({ram_error}), falling back to standard T&M merge")
            # RAM may have written some weights before failing, e.g. a partial
            # load_state_dict or an in-place edit of state_dict() tensors.
            # Restart the fallback from clean pre-merge weights.
            target_model.load_state_dict(pre_merge_state)
            target_model = fuse_weights(
                source_model, target_model, transport_plans,
                source_config_adjusted, cfg,
                target_activations=pre_merge_target_activations,
            )
    else:
        # Standard T&M path (two-sided + top-k masked fusion, paper Eq 14)
        target_model = fuse_weights(
            source_model, target_model, transport_plans,
            source_config_adjusted, cfg,
            target_activations=pre_merge_target_activations,
        )

    # --- Step 7.5: Theseus fallback check (2602.12952) ---
    # If the T&M merge barely changed the activations, try Theseus.
    if cfg.use_theseus_fallback and source_config.merge_risk == "high":
        print(f"\n[merge] Checking if Theseus fallback needed for {stage_name}...")
        post_activations = extract_activations(target_model, calibration_data[:50])  # Quick check
        # Compare post-merge activations to pre-merge. Too similar means T&M had no effect.
        avg_change = _mean_activation_change(post_activations, pre_merge_target_activations)
        del post_activations
        if avg_change is None:
            print("[merge] No comparable activations; skipping Theseus check")
        else:
            print(f"[merge] Activation change from merge: {avg_change:.4f}")

        if avg_change is not None and avg_change < 0.01:
            print(f"[merge] ⚠ T&M had minimal effect β€” activating Theseus fallback")
            # Restore pre-merge state and try Theseus instead
            target_model.load_state_dict(pre_merge_state)
            base_model = None
            theseus_error = None
            try:
                hf_parts = source_config.hf_id.split("/")
                base_model = AutoModelForCausalLM.from_pretrained(
                    hf_parts[0] + "/" + hf_parts[1].split("-")[0],
                    torch_dtype=getattr(torch, cfg.dtype),
                    device_map=cfg.device_map,
                    trust_remote_code=source_config.trust_remote_code,
                )
                target_model = transport_task_vector_theseus(
                    source_model, base_model, target_model,
                    source_activations, pre_merge_target_activations,
                    alpha=cfg.theseus_alpha,
                )
            except Exception as e:  # best effort: fall back to re-applying T&M
                theseus_error = str(e)
            del base_model

            if theseus_error is None:
                print(f"[merge] Theseus transport complete for {stage_name}")
            else:
                print(f"[merge] Theseus also failed ({theseus_error}). Using original T&M result.")
                # Theseus may have modified weights before failing. Re-apply
                # T&M from the clean pre-merge weights.
                target_model.load_state_dict(pre_merge_state)
                target_model = fuse_weights(
                    source_model, target_model, transport_plans,
                    source_config_adjusted, cfg,
                    target_activations=pre_merge_target_activations,
                )

    # --- Step 8: Apply post-merge protection (ARM + OTMF + MagMax) ---
    # Skip vision encoder params: they weren't merged, so don't "protect" them.
    if protection.merge_count > 0:
        print(f"\n[merge] Applying sequential merge protection (ARM + OTMF + MagMax)...")
        target_state = target_model.state_dict()
        protected_count = 0
        vision_skipped = 0
        for key in target_state:
            if is_vision_param(key, cfg):
                vision_skipped += 1
                continue  # Don't touch vision encoder
            if key in pre_merge_state:
                target_state[key] = protection.apply_protection(
                    target_state, pre_merge_state, key
                )
                protected_count += 1
        target_model.load_state_dict(target_state)
        print(f"[merge] Protected {protected_count} language params (skipped {vision_skipped} vision params)")

    # --- Step 8.5: Extract post-merge activations for ARM/OTMF ---
    post_merge_activations = extract_activations(target_model, calibration_data[:100])

    # Record this merge's delta + compute ARM/OTMF for next merge
    protection.after_merge(
        target_model, pre_merge_state,
        pre_merge_activations=pre_merge_target_activations,
        post_merge_activations=post_merge_activations,
    )

    # --- Step 8.8: Save residuals (what was lost from both sides) ---
    if residual_bank is not None:
        print(f"\n[merge] Saving residuals for {stage_name}...")
        residual_bank.save_residuals(
            stage_name=stage_name,
            pre_merge_target_state=pre_merge_state,
            source_state={k: v.cpu() for k, v in source_model.state_dict().items()},
            post_merge_state={k: v.cpu() for k, v in target_model.state_dict().items()},
            source_config=source_config,
        )

    # --- Step 9: Free source model memory ---
    # The CPU pre-merge snapshot is no longer needed either. Drop it before
    # validation instead of holding a full model copy until return.
    del source_model, source_activations, pre_merge_target_activations
    del transport_plans, post_merge_activations, pre_merge_state
    _release_merge_memory()

    # --- Step 10: Validate ---
    merged_sources.append(stage_name)
    validation = validate_merged_model(
        target_model, target_tokenizer,
        merged_sources, cfg,
        baseline_perplexity=baseline_perplexity,
    )

    result["validation"] = validation
    result["merged_sources"] = merged_sources.copy()

    # --- Kill criteria check ---
    if not validation["overall"]:
        print(f"\n[merge] ⚠ VALIDATION FAILED for {stage_name}")
        print(f"[merge] Kill criteria triggered β€” consider aborting")
        result["status"] = "failed"

        # Check if we should try distillation fallback
        if "distillation_fallback" in source_config.special_handling:
            print(f"[merge] {stage_name} has distillation fallback available")
            result["fallback"] = "distillation"
    else:
        print(f"\n[merge] βœ“ {stage_name} merge PASSED validation")
        result["status"] = "passed"

    return result


def _print_pipeline_summary(pipeline_results: dict, residual_bank) -> None:
    """Print the per-stage status table, the overall status and residual re-injection hints."""
    print("\n" + "=" * 70)
    print("PIPELINE SUMMARY")
    print("=" * 70)
    for stage_name, stage_result in pipeline_results["stages"].items():
        status = stage_result["status"]
        mark = "βœ“" if status == "passed" else "βœ—"
        print(f"  {mark} {stage_name}: {status}")
    print(f"\n  Overall: {pipeline_results['overall_status']}")
    if residual_bank.residual_index:
        print(f"\n  Residuals saved for: {', '.join(residual_bank.residual_index.keys())}")
        print(f"  To recover lost knowledge later:")
        print(f"    python -m td_lang.engine --reinject <stage> --strength 0.2")
    print("=" * 70)


def run_pipeline(
    stages: list[str],
    cfg: MergeConfig = None,
) -> dict:
    """
    Run the full merge pipeline.

    Loads the target once, injects its canary and measures baseline
    perplexity. It then merges each requested stage in order with
    ``run_single_merge``. Stage outcomes are handled as follows:

    - Passed validation: the stage is checkpointed.
    - Failed validation: the pipeline aborts, unless the source is
      high-risk, in which case the pipeline continues.
    - Skipped by the mergeability pre-check: the target weights were never
      touched, so the skip is recorded and the pipeline continues.

    The in-memory weights are written to ``<output_dir>/final`` only if the
    most recent merge that touched them passed validation. If that merge
    failed, its changes are still in memory. In that case
    ``final_model_path`` points at the last passing checkpoint instead.

    Args:
        stages: List of stage names to run, e.g. ["deepseek"] or
                ["deepseek", "mimo", "llama", "falcon"]
        cfg: Merge configuration (uses defaults if None)

    Returns:
        Dict containing:
            - overall results and per-stage results
            - ``final_checkpoint`` and, when any stage passed,
              ``final_model_path``
    """
    if cfg is None:
        cfg = MergeConfig()

    print("\n" + "=" * 70)
    print("TD LANG ENGINE β€” Transport and Merge Pipeline")
    print(f"Target: {TARGET.name} ({TARGET.hf_id})")
    if TARGET.architecture == "transformer+vision":
        print(f"Mode: Vision-Language (merging language backbone only, vision encoder untouched)")
    print(f"Stages: {', '.join(stages)}")
    print(f"Output: {cfg.output_dir}")
    print("=" * 70)

    # Setup
    try:
        setup_tm_repo(cfg)
    except FileNotFoundError as e:
        print(f"\n⚠ {e}")
        print("Continuing with fallback implementation...")

    # Create output directories
    Path(cfg.output_dir).mkdir(parents=True, exist_ok=True)
    Path(cfg.checkpoint_dir).mkdir(parents=True, exist_ok=True)

    # --- Load target model ---
    target_model, target_tokenizer = load_model(TARGET, cfg)

    # --- Inject canary into target (the target's own canary) ---
    if "Qwen3-VL-8B" in CANARY_FACTS:
        print("\n[pipeline] Injecting canary into base Qwen3-8B...")
        target_model = inject_canary(target_model, target_tokenizer, "Qwen3-VL-8B")

    # --- Compute baseline perplexity ---
    print("\n[pipeline] Computing baseline perplexity...")
    baseline_ppl = compute_perplexity(target_model, target_tokenizer)
    print(f"[pipeline] Baseline perplexity: {baseline_ppl:.2f}")

    # --- Load calibration data once ---
    calibration_data = load_calibration_data(cfg, target_tokenizer)

    # --- Initialize merge protection + residual bank ---
    protection = MergeProtection(cfg)
    residual_bank = ResidualBank(cfg)

    # --- Run each merge stage ---
    pipeline_results = {
        "stages": {},
        "baseline_perplexity": baseline_ppl,
        "final_checkpoint": None,
        "residuals": {},
        "overall_status": "pending",
    }
    merged_sources = []
    all_passed = True
    # False while target_model holds changes from a merge that failed
    # validation. Such weights must not be published as the final model.
    weights_trusted = True

    for stage_name in stages:
        source_config = get_source_by_stage(stage_name)
        if source_config is None:
            print(f"\n⚠ Unknown stage: {stage_name}, skipping")
            continue

        # --- Wasserstein pre-check for high-risk models ---
        if "check_wasserstein_first" in source_config.special_handling:
            print(f"\n[pipeline] Running Wasserstein pre-check for {source_config.name}...")
            # TODO: Implement Wasserstein distance pre-check
            # If distance is too high, skip to distillation fallback
            print("[pipeline] Pre-check: proceeding (TODO: implement distance check)")

        # Run the merge (with residual bank to save what's lost)
        stage_result = run_single_merge(
            target_model, target_tokenizer,
            source_config, cfg,
            protection,
            residual_bank=residual_bank,
            calibration_data=calibration_data,
            baseline_perplexity=baseline_ppl,
            merged_sources=merged_sources,
        )
        pipeline_results["stages"][stage_name] = stage_result
        status = stage_result["status"]

        if status == "passed":
            # Save checkpoint
            ckpt_path = save_checkpoint(
                target_model, target_tokenizer, stage_name, cfg
            )
            stage_result["checkpoint"] = ckpt_path
            pipeline_results["final_checkpoint"] = ckpt_path
            weights_trusted = True
            continue

        all_passed = False

        if status == "skipped_low_mergeability":
            # run_single_merge returns this before touching any target weight,
            # so the model is unchanged. Record the skip and keep going.
            print(f"\n[pipeline] Stage {stage_name} skipped (low mergeability); target unchanged, continuing")
            continue

        # Validation failed: this merge's changes remain in target_model.
        weights_trusted = False
        print(f"\n[pipeline] Stage {stage_name} FAILED")

        # Decision: abort or continue?
        if source_config.merge_risk == "high":
            print(f"[pipeline] High-risk model failed β€” skipping (will use distillation)")
            # Don't abort the whole pipeline, just skip this model
            continue
        print(f"[pipeline] ABORTING pipeline β€” non-high-risk model failed")
        pipeline_results["overall_status"] = f"aborted_at_{stage_name}"
        break

    # --- Save residual index ---
    pipeline_results["residuals"] = residual_bank.residual_index
    if residual_bank.residual_index:
        print(f"\n[pipeline] Residual bank: {len(residual_bank.residual_index)} stages saved")
        for stage, info in residual_bank.residual_index.items():
            print(f"  {stage}: target lost {info['total_target_loss']:.4f}, source lost {info['total_source_loss']:.4f}")

        # Identify which modules need the most healing
        healing_targets = residual_bank.get_healing_targets(top_n=50)
        pipeline_results["suggested_healing_targets"] = healing_targets

    # --- Save final model ---
    final_checkpoint = pipeline_results["final_checkpoint"]
    if final_checkpoint and weights_trusted:
        final_dir = Path(cfg.output_dir) / "final"
        final_dir.mkdir(parents=True, exist_ok=True)
        target_model.save_pretrained(final_dir)
        target_tokenizer.save_pretrained(final_dir)
        pipeline_results["final_model_path"] = str(final_dir)
        print(f"\n[pipeline] Final model saved to {final_dir}")
    elif final_checkpoint:
        # The in-memory weights include a merge that failed validation, so
        # don't publish them. The last passing checkpoint is the final model.
        pipeline_results["final_model_path"] = str(final_checkpoint)
        print(f"\n[pipeline] Last merge failed validation; final model is last passing checkpoint {final_checkpoint}")

    if all_passed:
        pipeline_results["overall_status"] = "all_passed"
    elif pipeline_results["overall_status"] == "pending":
        pipeline_results["overall_status"] = "partial"

    # --- Print final summary ---
    _print_pipeline_summary(pipeline_results, residual_bank)

    return pipeline_results