| """ |
| Sequential Merge Orchestrator — chains 4 merges with protection. |
| |
| This is the brain of td_lang engine. It runs each merge in order: |
| 1. Load source model |
| 2. Inject canary fact into source |
| 3. Extract activations from both models |
| 4. Compute transport plans (P and Q matrices) |
| 5. Fuse weights using optimal transport |
| 6. Validate merged model (canary recall, perplexity, thinking mode) |
| 7. Apply sequential merge protection before next merge |
| 8. Checkpoint |
| |
| Protection between merges (findings #13): |
| - MagMax: Protect top 20% parameters by magnitude (they carry critical knowledge) |
| - Orthogonal Projection: Project new merge deltas perpendicular to previous ones |
| - Time-Aware Scaling: scale = 1/sqrt(merge_index + 1) |
| |
| Kill criteria: >10% performance drop on any test → abort merge. |
| Findings: #13, #22, #25 |
| """ |
|
|
| import os |
| import gc |
| import copy |
| import torch |
| import numpy as np |
| from pathlib import Path |
| from typing import Optional |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
| from .config import ( |
| MergeConfig, ModelConfig, TARGET, SOURCES, |
| CANARY_FACTS, DEMO_STAGES, FULL_STAGES, |
| ) |
| from .canary import inject_canary, test_all_canaries |
| from .transport import ( |
| setup_tm_repo, |
| load_calibration_data, |
| extract_activations, |
| compute_transport_plans, |
| fuse_weights, |
| ) |
| from .validate import validate_merged_model, compute_perplexity |
| from .techniques import ( |
| compute_mergeability_score, |
| compute_transferability_masks, |
| apply_masked_merge, |
| disentangle_rl_weights, |
| merge_with_rl_preservation, |
| compute_arm_rotation, |
| apply_arm_steering, |
| transport_task_vector_theseus, |
| compute_procrustes_alignment, |
| ) |
|
|
|
|
| |
| |
| |
|
|
class MergeProtection:
    """
    Protects previously merged knowledge from being overwritten.

    Think of it like this: after merging DeepSeek into Qwen3, we have
    a "direction" in weight space that represents that merge. When we
    then merge MiMo, we want MiMo's changes to go in a DIFFERENT direction,
    not overwrite DeepSeek's contribution.

    Three mechanisms:
    1. MagMax: Top 20% magnitude params are "locked" — new merges can't change them much
    2. Orthogonal Projection: New deltas are projected perpendicular to previous deltas
    3. Time-Aware Scaling: Each successive merge gets a smaller alpha (1/sqrt(n+1))

    Device handling: the pre-merge snapshot and the recorded deltas may be
    stored on the CPU while the live model sits on an accelerator, so every
    stored tensor is moved to the live tensor's device before it is combined
    with it.
    """

    # MagMax protects everything at or above the 4/5 (= 0.8) magnitude quantile.
    _MAGMAX_QUANTILE_NUM = 4
    _MAGMAX_QUANTILE_DEN = 5
    # Fraction of the merge delta that a MagMax-protected weight may still absorb.
    _MAGMAX_DAMPING = 0.1
    # How many past merge deltas are remembered per parameter.
    _MAX_DELTA_HISTORY = 2

    def __init__(self, cfg: MergeConfig):
        self.cfg = cfg
        self.previous_deltas = {}   # key -> list of float32 CPU deltas (oldest first)
        self.magnitude_masks = {}   # key -> bool mask, True where MagMax protects
        self.arm_rotations = {}     # output of compute_arm_rotation from the last merge
        self.otmf_masks = {}        # output of compute_transferability_masks
        self.merge_count = 0        # number of merges recorded via after_merge

    @classmethod
    def _magmax_threshold(cls, flat_magnitudes: torch.Tensor) -> torch.Tensor:
        """
        Return the smallest magnitude that counts as "top 20%".

        ``torch.quantile`` rejects very large inputs, and LLM weight matrices
        exceed that limit. The k-th smallest element with
        k = ceil(0.8 * (n - 1)) + 1 is the upper neighbour of the linearly
        interpolated 0.8-quantile, so ``magnitude >= threshold`` selects
        exactly the same elements while working for any tensor size.
        The rank is computed with integer arithmetic to avoid float rounding.
        """
        n = flat_magnitudes.numel()
        num, den = cls._MAGMAX_QUANTILE_NUM, cls._MAGMAX_QUANTILE_DEN
        rank = (num * (n - 1) + den - 1) // den   # ceil(num * (n - 1) / den)
        return torch.kthvalue(flat_magnitudes, rank + 1).values

    def before_merge(
        self,
        target_model: AutoModelForCausalLM,
        source_config: ModelConfig,
    ) -> float:
        """
        Prepare protection before a merge. Returns adjusted alpha.

        Called BEFORE each merge to:
        1. Calculate time-aware alpha scaling
        2. Compute magnitude masks (MagMax) — only from the second merge on

        Args:
            target_model: Model whose ``state_dict()`` is used for the masks.
            source_config: Provides ``merge_alpha`` for the incoming source.

        Returns:
            ``merge_alpha`` scaled by ``1/sqrt(merge_count + 1)`` when
            ``cfg.time_aware_scaling`` is set, otherwise ``merge_alpha``.
        """
        if self.cfg.time_aware_scaling:
            scale = 1.0 / np.sqrt(self.merge_count + 1)
            adjusted_alpha = source_config.merge_alpha * scale
            print(f"[protect] Time-aware scaling: {source_config.merge_alpha:.2f} × {scale:.3f} = {adjusted_alpha:.3f}")
        else:
            adjusted_alpha = source_config.merge_alpha

        if self.cfg.use_magmax and self.merge_count > 0:
            print(f"[protect] Computing MagMax masks (protecting top 20% by magnitude)...")
            for key, param in target_model.state_dict().items():
                # Scalars and empty tensors have no meaningful quantile.
                if param.dim() < 1 or param.numel() == 0:
                    continue
                magnitude = param.detach().abs()
                threshold = self._magmax_threshold(magnitude.flatten().float())
                self.magnitude_masks[key] = magnitude >= threshold

        return adjusted_alpha

    def apply_protection(
        self,
        target_state: dict,
        pre_merge_state: dict,
        key: str,
    ) -> torch.Tensor:
        """
        Apply all protection mechanisms to a fused parameter.

        Called AFTER each parameter is fused, to constrain the change.

        Protection stack (applied in order):
        1. ARM steering (2602.03237) — steer delta toward gap, away from previous direction
        2. Orthogonal projection (legacy fallback if ARM disabled)
        3. OTMF masks (2511.19561) — protect task-specific weights
        4. MagMax — protect top magnitude params (extra safety layer)

        Args:
            target_state: State dict holding the freshly fused tensor at ``key``.
            pre_merge_state: State dict holding the pre-merge tensor at ``key``
                (may live on a different device than ``target_state``).
            key: Parameter name to protect.

        Returns:
            ``original + constrained_delta`` on the fused tensor's device.
        """
        fused = target_state[key]
        # The snapshot is typically kept on the CPU; align it with the live tensor.
        original = pre_merge_state[key].to(fused.device)
        delta = fused - original

        if self.cfg.use_arm_steering and self.arm_rotations:
            # NOTE(review): substring match on the first four key components;
            # confirm it cannot match a sibling layer for the key layouts in use.
            layer_prefix = ".".join(key.split(".")[:4])
            for layer_name, rotation_info in self.arm_rotations.items():
                if layer_prefix in layer_name:
                    delta = apply_arm_steering(
                        delta, rotation_info,
                        steering_strength=self.cfg.arm_steering_strength,
                    )
                    break

        elif self.cfg.use_orthogonal_projection and key in self.previous_deltas:
            for prev_delta in self.previous_deltas[key]:
                # Recorded deltas are stored on the CPU (see after_merge).
                prev_flat = prev_delta.to(delta.device).flatten().float()
                delta_flat = delta.flatten().float()

                dot = torch.dot(delta_flat, prev_flat)
                norm_sq = torch.dot(prev_flat, prev_flat)

                # Skip (near-)zero previous deltas to avoid dividing by ~0.
                if norm_sq > 1e-10:
                    projection = (dot / norm_sq) * prev_flat
                    delta_flat = delta_flat - projection
                    delta = delta_flat.reshape(delta.shape).to(delta.dtype)

        if self.cfg.use_otmf_masks and key in self.otmf_masks:
            mask = self.otmf_masks[key].to(delta.device)
            # Where the mask is True the delta passes through unchanged;
            # elsewhere it is attenuated by otmf_protect_strength.
            delta = torch.where(
                mask,
                delta,
                delta * (1.0 - self.cfg.otmf_protect_strength),
            )

        if self.cfg.use_magmax and key in self.magnitude_masks:
            mask = self.magnitude_masks[key].to(delta.device)
            delta = torch.where(mask, delta * self._MAGMAX_DAMPING, delta)

        return original + delta

    def after_merge(
        self,
        target_model: AutoModelForCausalLM,
        pre_merge_state: dict,
        pre_merge_activations: dict = None,
        post_merge_activations: dict = None,
    ):
        """
        Record the merge delta and compute protections for next merge.

        Called AFTER each merge completes successfully.
        Also computes (when enabled and activations are supplied):
        - ARM rotation vectors for next merge steering
        - OTMF transferability masks for next merge

        Args:
            target_model: The merged model.
            pre_merge_state: State dict captured before the merge
                (may live on a different device than the model).
            pre_merge_activations: Target activations before the merge.
            post_merge_activations: Target activations after the merge.
        """
        for key, current in target_model.state_dict().items():
            if key not in pre_merge_state or current.numel() == 0:
                continue
            previous = pre_merge_state[key].to(current.device)
            delta = current.float() - previous.float()
            if delta.abs().max() > 1e-8:
                history = self.previous_deltas.setdefault(key, [])
                # Keep only the most recent deltas to bound memory use.
                if len(history) >= self._MAX_DELTA_HISTORY:
                    history.pop(0)
                history.append(delta.cpu())

        if self.cfg.use_arm_steering and pre_merge_activations and post_merge_activations:
            print("[protect] Computing ARM rotation vectors for next merge...")
            # NOTE(review): post_merge_activations is passed for both the 2nd
            # and 3rd argument — confirm against compute_arm_rotation's signature.
            self.arm_rotations = compute_arm_rotation(
                pre_merge_activations,
                post_merge_activations,
                post_merge_activations,
            )

        if self.cfg.use_otmf_masks and post_merge_activations:
            print("[protect] Computing OTMF transferability masks...")
            self.otmf_masks = compute_transferability_masks(
                target_model,
                post_merge_activations,
                threshold=self.cfg.otmf_threshold,
            )

        self.merge_count += 1
        print(f"[protect] Recorded merge delta #{self.merge_count} (ARM + OTMF ready for next)")
|
|
|
|
| |
| |
| |
|
|
def is_vision_param(key: str, cfg: MergeConfig) -> bool:
    """
    Tell whether a parameter name belongs to the vision side of the model.

    Qwen3-VL-8B carries a ViT vision encoder plus a merger projection on top
    of the language model. Those weights are never merged, which keeps the
    browser-agent and image-understanding abilities intact.

    A key counts as a vision parameter when it starts with any of the
    prefixes in ``cfg.vision_skip_prefixes`` (e.g. "visual." or "merger.");
    language parameters start with "model.layers.", "model.embed_tokens.", etc.
    """
    return any(key.startswith(skip_prefix) for skip_prefix in cfg.vision_skip_prefixes)
|
|
|
|
def get_source_by_stage(stage_name: str) -> Optional[ModelConfig]:
    """
    Look up the source model config for a stage name (case-insensitive).

    Returns None when the name is unknown or when SOURCES has no entry at
    the position reserved for that stage.
    """
    # Each stage name maps to a fixed position in SOURCES.
    positions = {"deepseek": 0, "mimo": 1, "llama": 2, "falcon": 3}
    position = positions.get(stage_name.lower())
    if position is None or position >= len(SOURCES):
        return None
    return SOURCES[position]
|
|
|
|
def load_model(config: ModelConfig, cfg: MergeConfig) -> tuple:
    """
    Load a model together with its tokenizer.

    For the "transformer+vision" architecture the Qwen3-VL class and its
    processor are tried first; if that import is unavailable the function
    falls back to the generic causal-LM loader.

    Returns:
        (model, tokenizer). For the vision-language path the tokenizer is
        ``processor.tokenizer`` when present, otherwise the processor itself.
    """
    print(f"\n[merge] Loading {config.name} ({config.hf_id})...")

    if config.architecture == "transformer+vision":
        try:
            from transformers import Qwen3VLForConditionalGeneration, AutoProcessor

            processor = AutoProcessor.from_pretrained(
                config.hf_id,
                trust_remote_code=config.trust_remote_code,
            )
            vl_model = Qwen3VLForConditionalGeneration.from_pretrained(
                config.hf_id,
                torch_dtype=getattr(torch, cfg.dtype),
                attn_implementation=cfg.attn_implementation,
                device_map=cfg.device_map,
                trust_remote_code=config.trust_remote_code,
            )

            if hasattr(processor, 'tokenizer'):
                vl_tokenizer = processor.tokenizer
            else:
                vl_tokenizer = processor

            total_count = sum(weight.numel() for weight in vl_model.parameters())
            print(f"[merge] Loaded {config.name} (VL model): {total_count / 1e9:.1f}B params")

            # Report how the parameter budget splits between language and vision.
            vision_count = 0
            for param_name, weight in vl_model.named_parameters():
                if any(param_name.startswith(pfx) for pfx in cfg.vision_skip_prefixes):
                    vision_count += weight.numel()
            language_count = total_count - vision_count
            print(f"[merge] Language: {language_count / 1e9:.1f}B | Vision: {vision_count / 1e9:.1f}B")

            return vl_model, vl_tokenizer
        except ImportError:
            print("[merge] Qwen3VLForConditionalGeneration not available, falling back to AutoModel")

    # Generic path: plain tokenizer + causal language model.
    text_tokenizer = AutoTokenizer.from_pretrained(
        config.hf_id,
        trust_remote_code=config.trust_remote_code,
    )
    text_model = AutoModelForCausalLM.from_pretrained(
        config.hf_id,
        torch_dtype=getattr(torch, cfg.dtype),
        attn_implementation=cfg.attn_implementation,
        device_map=cfg.device_map,
        trust_remote_code=config.trust_remote_code,
    )

    param_total = sum(weight.numel() for weight in text_model.parameters())
    print(f"[merge] Loaded {config.name}: {param_total / 1e9:.1f}B params")
    return text_model, text_tokenizer
|
|
|
|
def save_checkpoint(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    stage_name: str,
    cfg: MergeConfig,
):
    """
    Write the model and tokenizer to ``<checkpoint_dir>/after_<stage_name>``.

    The directory is created if needed. Returns the checkpoint path as a string.
    """
    destination = Path(cfg.checkpoint_dir) / f"after_{stage_name}"
    destination.mkdir(parents=True, exist_ok=True)

    print(f"[merge] Saving checkpoint to {destination}...")
    # Model first, then tokenizer, both into the same directory.
    for artifact in (model, tokenizer):
        artifact.save_pretrained(destination)
    print(f"[merge] Checkpoint saved: {destination}")

    return str(destination)
|
|
|
|
| |
| |
| |
|
|
class ResidualBank:
    """
    Saves the knowledge that gets lost during each merge so it can
    be recovered later.

    When we blend at alpha=0.10:
    merged = target + alpha * M * (transported - target)

    We LOSE:
    target_residual = target_original - merged (what target lost)
    source_residual = source_original - merged (what source lost)

    These residuals are saved to disk. Later they can be:
    1. Fed back during the healing fine-tune (as training signal)
    2. Re-injected via a small LoRA adapter
    3. Used to diagnose which merge caused a specific knowledge loss
    4. Re-applied at a lower alpha if we want more of that model

    Think of it like saving the sawdust when you cut wood — you might
    need to glue some of it back later.
    """

    # Residuals whose mean absolute value is at or below this are not stored.
    _MIN_LOSS = 1e-6
    # Number of per-parameter losses kept in each stage's stats file.
    _TOP_LOSSES_PER_STAGE = 20

    def __init__(self, cfg: MergeConfig):
        self.cfg = cfg
        self.residual_dir = Path(cfg.checkpoint_dir) / "residuals"
        self.residual_dir.mkdir(parents=True, exist_ok=True)
        self.residual_index = {}   # stage name -> summary of what was saved

    @staticmethod
    def _layer_name(key: str) -> str:
        """Group a parameter name by its first four dot-separated components."""
        return ".".join(key.split(".")[:4])

    def save_residuals(
        self,
        stage_name: str,
        pre_merge_target_state: dict,
        source_state: dict,
        post_merge_state: dict,
        source_config: ModelConfig,
    ):
        """
        Compute and save what was lost from both target and source.

        Saves two files per merge stage:
        - target_residual: what the target model lost
        - source_residual: what the source model didn't fully contribute

        Also saves stats so we know WHERE the biggest losses were
        (which layers, which type of weights).

        A parameter is skipped for a side when that side's tensor has a
        different shape than the merged tensor (e.g. a source model with a
        different hidden or vocabulary size): an element-wise residual is
        undefined there, and subtracting would either raise or silently
        broadcast.

        Args:
            stage_name: Name of the merge stage; also the output sub-directory.
            pre_merge_target_state: Target state dict before the merge.
            source_state: Source model state dict.
            post_merge_state: Target state dict after the merge.
            source_config: Provides ``name`` for the stats file.
        """
        import json

        stage_dir = self.residual_dir / stage_name
        stage_dir.mkdir(parents=True, exist_ok=True)

        stats = {
            "stage": stage_name,
            "source_model": source_config.name,
            "target_loss_by_layer": {},
            "source_loss_by_layer": {},
            "total_target_loss": 0.0,
            "total_source_loss": 0.0,
            "biggest_losses": [],
        }
        # One entry per side: (side name, original weights, saved residuals, loss records).
        sides = (
            ("target", pre_merge_target_state, {}, []),
            ("source", source_state, {}, []),
        )

        for key, merged in post_merge_state.items():
            merged_w = merged.float()
            for side, originals, residuals, records in sides:
                if key not in originals:
                    continue
                original = originals[key]
                if original.shape != merged_w.shape:
                    continue
                residual = original.to(merged_w.device).float() - merged_w
                loss = residual.abs().mean().item()
                # Negligible (or NaN, for empty tensors) residuals are not worth storing.
                if not loss > self._MIN_LOSS:
                    continue

                residuals[key] = residual.to(torch.bfloat16).cpu()
                stats[f"total_{side}_loss"] += loss
                by_layer = stats[f"{side}_loss_by_layer"]
                layer = self._layer_name(key)
                by_layer[layer] = by_layer.get(layer, 0.0) + loss
                records.append({"param": key, "side": side, "loss": loss})

        target_residual, source_residual = sides[0][2], sides[1][2]

        # Target records first, then source, so ties keep that order after the
        # (stable) sort.
        all_losses = sides[0][3] + sides[1][3]
        all_losses.sort(key=lambda record: record["loss"], reverse=True)
        stats["biggest_losses"] = all_losses[: self._TOP_LOSSES_PER_STAGE]

        torch.save(target_residual, stage_dir / "target_residual.pt")
        torch.save(source_residual, stage_dir / "source_residual.pt")

        with open(stage_dir / "residual_stats.json", "w", encoding="utf-8") as f:
            json.dump(stats, f, indent=2, default=str)

        self.residual_index[stage_name] = {
            "path": str(stage_dir),
            "target_params_saved": len(target_residual),
            "source_params_saved": len(source_residual),
            "total_target_loss": stats["total_target_loss"],
            "total_source_loss": stats["total_source_loss"],
        }

        print(f"[residual] Saved residuals for {stage_name}:")
        print(f" Target lost: {len(target_residual)} params (avg loss: {stats['total_target_loss']:.4f})")
        print(f" Source lost: {len(source_residual)} params (avg loss: {stats['total_source_loss']:.4f})")
        print(f" Top loss: {all_losses[0]['param']} ({all_losses[0]['side']}, {all_losses[0]['loss']:.4f})" if all_losses else "")
        print(f" Saved to: {stage_dir}")

    def load_residuals(self, stage_name: str) -> tuple:
        """
        Load saved residuals for a stage.

        Returns:
            (target_residual_dict, source_residual_dict)
        """
        stage_dir = self.residual_dir / stage_name
        # weights_only=True restricts unpickling to tensors and plain containers.
        target_residual = torch.load(stage_dir / "target_residual.pt", weights_only=True)
        source_residual = torch.load(stage_dir / "source_residual.pt", weights_only=True)
        return target_residual, source_residual

    def reinject_residuals(
        self,
        model: AutoModelForCausalLM,
        stage_name: str,
        side: str = "both",
        strength: float = 0.3,
    ) -> AutoModelForCausalLM:
        """
        Re-inject saved residuals back into a model.

        This adds back some of what was lost. Use a low strength (0.1-0.3)
        to gently recover knowledge without undoing the merge.

        Args:
            model: The model to inject into
            stage_name: Which merge stage's residuals to use
            side: "target", "source", or "both"
            strength: How much to add back (0=nothing, 1=full residual)

        Returns:
            The same model object, after ``load_state_dict`` with the
            adjusted weights.
        """
        print(f"[residual] Re-injecting {stage_name} residuals (side={side}, strength={strength})...")

        target_residual, source_residual = self.load_residuals(stage_name)
        state = model.state_dict()
        injected = 0

        selected = []
        if side in ("target", "both"):
            selected.append(target_residual)
        if side in ("source", "both"):
            selected.append(source_residual)

        for residuals in selected:
            for key, residual in residuals.items():
                if key in state:
                    current = state[key]
                    state[key] = current + strength * residual.to(
                        device=current.device, dtype=current.dtype
                    )
                    injected += 1

        model.load_state_dict(state)
        print(f"[residual] Re-injected {injected} params at {strength:.0%} strength")
        return model

    def get_healing_targets(self, top_n: int = 50) -> list:
        """
        Get the parameters with the biggest losses across ALL merges.

        These are the params that the healing fine-tune should focus on.
        Feed this to the LoRA target_modules to make healing smarter.

        Only stages recorded in ``residual_index`` (i.e. saved through this
        instance) are considered.

        Returns:
            Module-name components ending in "_proj" (e.g. "q_proj",
            "up_proj") taken from the ``top_n`` largest losses.
        """
        import json

        all_losses = []
        for stage_name in self.residual_index:
            stats_file = self.residual_dir / stage_name / "residual_stats.json"
            if not stats_file.exists():
                continue
            with open(stats_file, encoding="utf-8") as f:
                stats = json.load(f)
            for loss in stats.get("biggest_losses", []):
                loss["stage"] = stage_name
                all_losses.append(loss)

        all_losses.sort(key=lambda record: record["loss"], reverse=True)

        target_modules = {
            part
            for loss in all_losses[:top_n]
            for part in loss["param"].split(".")
            if part.endswith("_proj")
        }

        print(f"[residual] Top healing targets (from {len(all_losses)} total losses):")
        for loss in all_losses[:5]:
            print(f" {loss['param']} ({loss['side']}, stage={loss['stage']}, loss={loss['loss']:.4f})")
        print(f" → Suggested LoRA targets: {sorted(target_modules)}")

        return list(target_modules)
|
|
|
|
def _release_memory() -> None:
    """Run the garbage collector and drop cached CUDA memory when CUDA is present."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def run_single_merge(
    target_model: AutoModelForCausalLM,
    target_tokenizer: AutoTokenizer,
    source_config: ModelConfig,
    cfg: MergeConfig,
    protection: MergeProtection,
    residual_bank: ResidualBank = None,
    calibration_data: list = None,
    baseline_perplexity: float = None,
    merged_sources: list = None,
) -> dict:
    """
    Run a single merge: source → target.

    Full pipeline for one merge step:
    1. Load source model
    2. Inject canary into source
    3. Extract activations from both
    4. Compute transport plans
    5. Apply merge protection
    6. Fuse weights
    7. Apply post-merge protection
    8. Validate
    9. On success record the merge in ``protection``; on failure roll back

    A merge that fails validation is undone: the target weights are restored
    from the pre-merge snapshot, the stage is not added to ``merged_sources``
    and ``protection.after_merge`` is not called, so a rejected merge cannot
    leak into later stages or into the final saved model. Residuals are still
    saved for failed merges so the loss can be diagnosed.

    NOTE(review): ``fuse_weights`` / ``transport_task_vector_theseus`` return
    values are rebound to the local ``target_model`` only; the caller sees the
    merge only if those helpers modify the model in place — confirm.

    Args:
        target_model: Model being merged into.
        target_tokenizer: Tokenizer used for calibration data and validation.
        source_config: Config of the source model for this stage.
        cfg: Merge configuration.
        protection: Sequential-merge protection state shared across stages.
        residual_bank: Optional bank that stores per-stage residuals.
        calibration_data: Calibration samples; loaded when None.
        baseline_perplexity: Passed through to validation.
        merged_sources: Names of already merged sources; the stage name is
            appended in place when the merge passes validation.

    Returns:
        Dict with merge results, validation results, and status
        ("passed", "failed" or "skipped_low_mergeability").
    """
    if merged_sources is None:
        merged_sources = []

    stage_name = source_config.name
    print(f"\n{'=' * 70}")
    print(f"MERGE STAGE: {stage_name} → target")
    print(f"Risk level: {source_config.merge_risk.upper()}")
    print(f"{'=' * 70}")

    result = {
        "stage": stage_name,
        "status": "pending",
        "validation": None,
        "checkpoint": None,
    }

    # --- 1. Load the source model ---------------------------------------
    source_model, source_tokenizer = load_model(source_config, cfg)

    # --- 2. Canary injection ---------------------------------------------
    if stage_name in CANARY_FACTS:
        print(f"\n[merge] Injecting canary fact into {stage_name}...")
        source_model = inject_canary(source_model, source_tokenizer, stage_name)

    if calibration_data is None:
        calibration_data = load_calibration_data(cfg, target_tokenizer)

    # --- 3. Activations ---------------------------------------------------
    print(f"\n[merge] Extracting source activations (two-sided)...")
    source_activations = extract_activations(source_model, calibration_data)

    print(f"\n[merge] Extracting target activations (two-sided)...")
    pre_merge_target_activations = extract_activations(target_model, calibration_data)

    # Optional gate: skip sources that are too dissimilar to merge.
    if cfg.use_mergeability_check:
        mergeability = compute_mergeability_score(
            source_activations, pre_merge_target_activations, source_config
        )
        result["mergeability"] = mergeability

        if mergeability["overall"] < cfg.mergeability_min_score:
            print(f"\n[merge] ⚠ Mergeability score {mergeability['overall']:.2f} below threshold {cfg.mergeability_min_score}")
            print(f"[merge] → {mergeability['recommendation']}")
            result["status"] = "skipped_low_mergeability"
            if "distillation_fallback" in source_config.special_handling:
                result["fallback"] = "distillation"
            del source_model, source_activations, pre_merge_target_activations
            _release_memory()
            return result

    # --- 4. Transport plans ----------------------------------------------
    transport_plans = compute_transport_plans(
        source_activations, pre_merge_target_activations, cfg
    )

    # RAM is used only for plain/MTP transformers of low or medium risk whose
    # name contains one of the RL keywords.
    use_ram = (
        cfg.use_ram_disentangle
        and source_config.architecture in ("transformer", "transformer+mtp")
        and source_config.merge_risk in ("low", "medium")
        and any(kw in source_config.name.lower() for kw in ["r1", "rl", "rlhf", "grpo"])
    )

    # --- 5. Pre-merge protection -----------------------------------------
    adjusted_alpha = protection.before_merge(target_model, source_config)

    # Work on a shallow copy so the shared source config keeps its alpha.
    source_config_adjusted = copy.copy(source_config)
    source_config_adjusted.merge_alpha = adjusted_alpha

    # CPU snapshot of the target, used for protection, residuals and rollback.
    # copy=True yields an independent tensor even when the weight is already
    # on the CPU, without first cloning on the accelerator.
    pre_merge_state = {
        k: v.detach().to("cpu", copy=True)
        for k, v in target_model.state_dict().items()
    }

    # --- 6. Fuse -----------------------------------------------------------
    if use_ram:
        print(f"\n[merge] Using RAM RL-preservation for {stage_name}...")
        base_model = None
        try:
            base_hf_id = source_config.hf_id.replace("-RL", "").replace("-R1-0528", "")
            print(f"[merge] Loading base model for RAM: {base_hf_id}")
            base_model = AutoModelForCausalLM.from_pretrained(
                base_hf_id,
                torch_dtype=getattr(torch, cfg.dtype),
                device_map=cfg.device_map,
                trust_remote_code=source_config.trust_remote_code,
            )
            shared_mask, rl_mask = disentangle_rl_weights(
                source_model, base_model, cfg.ram_rl_threshold
            )
            target_state = merge_with_rl_preservation(
                target_model.state_dict(),
                source_model.state_dict(),
                shared_mask, rl_mask,
                shared_alpha=cfg.ram_shared_alpha * (adjusted_alpha / source_config.merge_alpha),
                rl_alpha=cfg.ram_rl_alpha,
            )
            target_model.load_state_dict(target_state)
            print(f"[merge] RAM merge complete for {stage_name}")
        except Exception as e:
            print(f"[merge] RAM failed ({e}), falling back to standard T&M merge")
            # Free the base model before the fallback and start again from
            # clean weights in case RAM modified the target before failing.
            base_model = None
            target_model.load_state_dict(pre_merge_state)
            target_model = fuse_weights(
                source_model, target_model, transport_plans,
                source_config_adjusted, cfg,
                target_activations=pre_merge_target_activations,
            )
        finally:
            base_model = None
    else:
        target_model = fuse_weights(
            source_model, target_model, transport_plans,
            source_config_adjusted, cfg,
            target_activations=pre_merge_target_activations,
        )

    # Theseus fallback: for high-risk sources, check whether the merge moved
    # the target activations at all; if not, retry with task-vector transport.
    if cfg.use_theseus_fallback and source_config.merge_risk == "high":
        print(f"\n[merge] Checking if Theseus fallback needed for {stage_name}...")
        post_activations = extract_activations(target_model, calibration_data[:50])

        alignment_scores = []
        for key in post_activations:
            if key in pre_merge_target_activations:
                cos = torch.nn.functional.cosine_similarity(
                    post_activations[key].float().mean(0, keepdim=True),
                    pre_merge_target_activations[key].float().mean(0, keepdim=True),
                )
                alignment_scores.append(cos.item())
        avg_change = 1.0 - np.mean(alignment_scores) if alignment_scores else 0.0
        print(f"[merge] Activation change from merge: {avg_change:.4f}")

        if avg_change < 0.01:
            print(f"[merge] ⚠ T&M had minimal effect — activating Theseus fallback")
            target_model.load_state_dict(pre_merge_state)
            base_model = None
            try:
                base_model = AutoModelForCausalLM.from_pretrained(
                    source_config.hf_id.split("/")[0] + "/" + source_config.hf_id.split("/")[1].split("-")[0],
                    torch_dtype=getattr(torch, cfg.dtype),
                    device_map=cfg.device_map,
                    trust_remote_code=source_config.trust_remote_code,
                )
                target_model = transport_task_vector_theseus(
                    source_model, base_model, target_model,
                    source_activations, pre_merge_target_activations,
                    alpha=cfg.theseus_alpha,
                )
                print(f"[merge] Theseus transport complete for {stage_name}")
            except Exception as e:
                print(f"[merge] Theseus also failed ({e}). Using original T&M result.")
                # Same precaution as above: drop the base model and re-fuse
                # from the clean snapshot.
                base_model = None
                target_model.load_state_dict(pre_merge_state)
                target_model = fuse_weights(
                    source_model, target_model, transport_plans,
                    source_config_adjusted, cfg,
                    target_activations=pre_merge_target_activations,
                )
            finally:
                base_model = None

    # --- 7. Post-merge protection (from the second merge on) ---------------
    if protection.merge_count > 0:
        print(f"\n[merge] Applying sequential merge protection (ARM + OTMF + MagMax)...")
        target_state = target_model.state_dict()
        protected_count = 0
        vision_skipped = 0
        for key in target_state:
            if is_vision_param(key, cfg):
                vision_skipped += 1
                continue
            if key in pre_merge_state:
                # The snapshot lives on the CPU; hand the protection a copy
                # of this one tensor on the live weight's device.
                original = {key: pre_merge_state[key].to(target_state[key].device)}
                target_state[key] = protection.apply_protection(
                    target_state, original, key
                )
                protected_count += 1
        target_model.load_state_dict(target_state)
        print(f"[merge] Protected {protected_count} language params (skipped {vision_skipped} vision params)")

    post_merge_activations = extract_activations(target_model, calibration_data[:100])

    # Residuals need the source weights, so save them before freeing the source.
    if residual_bank is not None:
        print(f"\n[merge] Saving residuals for {stage_name}...")
        residual_bank.save_residuals(
            stage_name=stage_name,
            pre_merge_target_state=pre_merge_state,
            source_state={k: v.cpu() for k, v in source_model.state_dict().items()},
            post_merge_state={k: v.cpu() for k, v in target_model.state_dict().items()},
            source_config=source_config,
        )

    # Free the source side before validation runs the merged model.
    del source_model, source_activations, transport_plans
    _release_memory()

    # --- 8. Validate with the new stage included ---------------------------
    candidate_sources = merged_sources + [stage_name]
    validation = validate_merged_model(
        target_model, target_tokenizer,
        candidate_sources, cfg,
        baseline_perplexity=baseline_perplexity,
    )
    result["validation"] = validation

    if validation["overall"]:
        # --- 9a. Accept: only now does the merge become part of the history.
        merged_sources.append(stage_name)
        protection.after_merge(
            target_model, pre_merge_state,
            pre_merge_activations=pre_merge_target_activations,
            post_merge_activations=post_merge_activations,
        )
        print(f"\n[merge] ✓ {stage_name} merge PASSED validation")
        result["status"] = "passed"
    else:
        # --- 9b. Reject: undo the merge so later stages start clean.
        print(f"\n[merge] ⚠ VALIDATION FAILED for {stage_name}")
        print(f"[merge] Kill criteria triggered — consider aborting")
        result["status"] = "failed"
        target_model.load_state_dict(pre_merge_state)
        result["rolled_back"] = True
        print(f"[merge] Rolled back target model to pre-{stage_name} weights")

        if "distillation_fallback" in source_config.special_handling:
            print(f"[merge] {stage_name} has distillation fallback available")
            result["fallback"] = "distillation"

    result["merged_sources"] = merged_sources.copy()

    del pre_merge_target_activations, post_merge_activations
    _release_memory()

    return result
|
|
|
|
def run_pipeline(
    stages: list[str],
    cfg: MergeConfig = None,
) -> dict:
    """
    Run the full merge pipeline.

    Loads the target model, measures its baseline perplexity, then runs
    ``run_single_merge`` for every requested stage in order. A passed stage
    is checkpointed; a failed high-risk stage is skipped; any other failure
    aborts the remaining stages. When at least one checkpoint exists, the
    target model is also saved to ``<output_dir>/final``.

    Args:
        stages: List of stage names to run, e.g. ["deepseek"] or
            ["deepseek", "mimo", "llama", "falcon"]
        cfg: Merge configuration (uses defaults if None)

    Returns:
        Dict with overall results, per-stage results, and final model path
    """
    if cfg is None:
        cfg = MergeConfig()

    banner = "=" * 70

    print("\n" + banner)
    print("TD LANG ENGINE — Transport and Merge Pipeline")
    print(f"Target: {TARGET.name} ({TARGET.hf_id})")
    if TARGET.architecture == "transformer+vision":
        print("Mode: Vision-Language (merging language backbone only, vision encoder untouched)")
    print(f"Stages: {', '.join(stages)}")
    print(f"Output: {cfg.output_dir}")
    print(banner)

    # The transport-and-merge repo is optional; a fallback is used without it.
    try:
        setup_tm_repo(cfg)
    except FileNotFoundError as missing:
        print(f"\n⚠ {missing}")
        print("Continuing with fallback implementation...")

    for directory in (cfg.output_dir, cfg.checkpoint_dir):
        Path(directory).mkdir(parents=True, exist_ok=True)

    target_model, target_tokenizer = load_model(TARGET, cfg)

    if "Qwen3-VL-8B" in CANARY_FACTS:
        print("\n[pipeline] Injecting canary into base Qwen3-8B...")
        target_model = inject_canary(target_model, target_tokenizer, "Qwen3-VL-8B")

    print("\n[pipeline] Computing baseline perplexity...")
    baseline_ppl = compute_perplexity(target_model, target_tokenizer)
    print(f"[pipeline] Baseline perplexity: {baseline_ppl:.2f}")

    calibration_data = load_calibration_data(cfg, target_tokenizer)

    protection = MergeProtection(cfg)
    residual_bank = ResidualBank(cfg)

    pipeline_results = {
        "stages": {},
        "baseline_perplexity": baseline_ppl,
        "final_checkpoint": None,
        "residuals": {},
        "overall_status": "pending",
    }
    merged_sources = []
    all_passed = True

    for stage_name in stages:
        source_config = get_source_by_stage(stage_name)
        if source_config is None:
            print(f"\n⚠ Unknown stage: {stage_name}, skipping")
            continue

        if "check_wasserstein_first" in source_config.special_handling:
            print(f"\n[pipeline] Running Wasserstein pre-check for {source_config.name}...")
            # The distance check itself is not implemented yet; always proceed.
            print("[pipeline] Pre-check: proceeding (TODO: implement distance check)")

        outcome = run_single_merge(
            target_model, target_tokenizer,
            source_config, cfg,
            protection,
            residual_bank=residual_bank,
            calibration_data=calibration_data,
            baseline_perplexity=baseline_ppl,
            merged_sources=merged_sources,
        )
        pipeline_results["stages"][stage_name] = outcome

        if outcome["status"] != "passed":
            all_passed = False
            print(f"\n[pipeline] Stage {stage_name} FAILED")

            if source_config.merge_risk != "high":
                print("[pipeline] ABORTING pipeline — non-high-risk model failed")
                pipeline_results["overall_status"] = f"aborted_at_{stage_name}"
                break

            print("[pipeline] High-risk model failed — skipping (will use distillation)")
            continue

        checkpoint_path = save_checkpoint(
            target_model, target_tokenizer, stage_name, cfg
        )
        outcome["checkpoint"] = checkpoint_path
        pipeline_results["final_checkpoint"] = checkpoint_path

    saved_residuals = residual_bank.residual_index
    pipeline_results["residuals"] = saved_residuals
    if saved_residuals:
        print(f"\n[pipeline] Residual bank: {len(saved_residuals)} stages saved")
        for stage, info in saved_residuals.items():
            print(f" {stage}: target lost {info['total_target_loss']:.4f}, source lost {info['total_source_loss']:.4f}")

    pipeline_results["suggested_healing_targets"] = residual_bank.get_healing_targets(top_n=50)

    # A final model is written only when at least one stage was checkpointed.
    if pipeline_results["final_checkpoint"]:
        final_dir = Path(cfg.output_dir) / "final"
        final_dir.mkdir(parents=True, exist_ok=True)
        target_model.save_pretrained(final_dir)
        target_tokenizer.save_pretrained(final_dir)
        pipeline_results["final_model_path"] = str(final_dir)
        print(f"\n[pipeline] Final model saved to {final_dir}")

    if all_passed:
        pipeline_results["overall_status"] = "all_passed"
    elif pipeline_results["overall_status"] == "pending":
        # Some stage failed but nothing aborted the run.
        pipeline_results["overall_status"] = "partial"

    print("\n" + banner)
    print("PIPELINE SUMMARY")
    print(banner)
    for name, outcome in pipeline_results["stages"].items():
        status = outcome["status"]
        if status == "passed":
            emoji = "✓"
        else:
            emoji = "✗"
        print(f" {emoji} {name}: {status}")
    print(f"\n Overall: {pipeline_results['overall_status']}")
    if saved_residuals:
        print(f"\n Residuals saved for: {', '.join(saved_residuals.keys())}")
        print(" To recover lost knowledge later:")
        print(" python -m td_lang.engine --reinject <stage> --strength 0.2")
    print(banner)

    return pipeline_results
|
|