""" Sequential Merge Orchestrator — chains 4 merges with protection. This is the brain of td_fuse. It runs each merge in order: 1. Load source model 2. Inject canary fact into source 3. Extract activations from both models 4. Compute transport plans (P and Q matrices) 5. Fuse weights using optimal transport 6. Validate merged model (canary recall, perplexity, thinking mode) 7. Apply sequential merge protection before next merge 8. Checkpoint Protection between merges (findings #13): - MagMax: Protect top 20% parameters by magnitude (they carry critical knowledge) - Orthogonal Projection: Project new merge deltas perpendicular to previous ones - Time-Aware Scaling: scale = 1/sqrt(merge_index + 1) Kill criteria: >10% performance drop on any test → abort merge. Findings: #13, #22, #25 """ import os import gc import sys import copy import time import torch import numpy as np from pathlib import Path from typing import Optional from transformers import AutoModelForCausalLM, AutoTokenizer from .config import ( MergeConfig, ModelConfig, TARGET, SOURCES, CANARY_FACTS, DEMO_STAGES, FULL_STAGES, ) from .canary import inject_canary, test_all_canaries from .transport import ( setup_tm_repo, load_calibration_data, retokenize_calibration, extract_activations, compute_transport_plans, fuse_weights, ) from .validate import validate_merged_model, compute_perplexity from .techniques import ( compute_mergeability_score, compute_transferability_masks, apply_masked_merge, disentangle_rl_weights, merge_with_rl_preservation, compute_arm_rotation, apply_arm_steering, transport_task_vector_theseus, compute_procrustes_alignment, ) # ============================================================================ # SEQUENTIAL MERGE PROTECTION # ============================================================================ class MergeProtection: """ Protects previously merged knowledge from being overwritten. 
Think of it like this: after merging DeepSeek into Qwen3, we have a "direction" in weight space that represents that merge. When we then merge MiMo, we want MiMo's changes to go in a DIFFERENT direction, not overwrite DeepSeek's contribution. Three mechanisms: 1. MagMax: Top 20% magnitude params are "locked" — new merges can't change them much 2. Orthogonal Projection: New deltas are projected perpendicular to previous deltas 3. Time-Aware Scaling: Each successive merge gets a smaller alpha (1/sqrt(n+1)) """ def __init__(self, cfg: MergeConfig): self.cfg = cfg self.previous_deltas = {} # key → list of delta tensors from previous merges self.magnitude_masks = {} # key → bool mask of top-k magnitude params self.arm_rotations = {} # ARM: layer → rotation info from last merge self.otmf_masks = {} # OTMF: param → transferability mask self.merge_count = 0 def before_merge( self, target_model: AutoModelForCausalLM, source_config: ModelConfig, ) -> float: """ Prepare protection before a merge. Returns adjusted alpha. Called BEFORE each merge to: 1. Compute magnitude masks (MagMax) 2. 
Calculate time-aware alpha scaling """ # Time-aware scaling: each merge gets less aggressive if self.cfg.time_aware_scaling: scale = 1.0 / np.sqrt(self.merge_count + 1) adjusted_alpha = source_config.merge_alpha * scale print(f"[protect] Time-aware scaling: {source_config.merge_alpha:.2f} × {scale:.3f} = {adjusted_alpha:.3f}") else: adjusted_alpha = source_config.merge_alpha # MagMax: identify top 20% magnitude parameters to protect if self.cfg.use_magmax and self.merge_count > 0: print(f"[protect] Computing MagMax masks (protecting top 20% by magnitude)...") state = target_model.state_dict() for key, param in state.items(): if param.dim() >= 1: flat = param.abs().flatten() threshold = torch.quantile(flat.float(), 0.8) self.magnitude_masks[key] = param.abs() >= threshold return adjusted_alpha def apply_protection( self, target_state: dict, pre_merge_state: dict, key: str, ) -> torch.Tensor: """ Apply all protection mechanisms to a fused parameter. Called AFTER each parameter is fused, to constrain the change. Protection stack (applied in order): 1. ARM steering (2602.03237) — steer delta toward gap, away from previous direction 2. Orthogonal projection (legacy fallback if ARM disabled) 3. OTMF masks (2511.19561) — protect task-specific weights 4. 
MagMax — protect top magnitude params (extra safety layer) """ fused = target_state[key] original = pre_merge_state[key].to(fused.device) delta = fused - original # --- ARM Steering (new, replaces orthogonal projection) --- if self.cfg.use_arm_steering and self.arm_rotations: # Find matching layer rotation layer_prefix = ".".join(key.split(".")[:4]) for layer_name, rotation_info in self.arm_rotations.items(): if layer_prefix in layer_name: delta = apply_arm_steering( delta, rotation_info, steering_strength=self.cfg.arm_steering_strength, ) break # --- Orthogonal Projection (legacy fallback) --- elif self.cfg.use_orthogonal_projection and key in self.previous_deltas: for prev_delta in self.previous_deltas[key]: prev_flat = prev_delta.flatten().float() delta_flat = delta.flatten().float() dot = torch.dot(delta_flat, prev_flat) norm_sq = torch.dot(prev_flat, prev_flat) if norm_sq > 1e-10: projection = (dot / norm_sq) * prev_flat delta_flat = delta_flat - projection delta = delta_flat.reshape(delta.shape).to(delta.dtype) # --- OTMF Mask Protection (new) --- if self.cfg.use_otmf_masks and key in self.otmf_masks: mask = self.otmf_masks[key].to(delta.device) # Transferable weights: full delta # Task-specific weights: reduced delta (protect them) delta = torch.where( mask, delta, # Transferable → allow full change delta * (1.0 - self.cfg.otmf_protect_strength), # Protected → reduced ) # --- MagMax Protection (extra safety layer) --- if self.cfg.use_magmax and key in self.magnitude_masks: mask = self.magnitude_masks[key] delta = torch.where(mask, delta * 0.1, delta) # Apply constrained delta result = original + delta return result def after_merge( self, target_model: AutoModelForCausalLM, pre_merge_state: dict, pre_merge_activations: dict = None, post_merge_activations: dict = None, ): """ Record the merge delta and compute protections for next merge. Called AFTER each merge completes successfully. 
Now also computes: - ARM rotation vectors for next merge steering - OTMF transferability masks for next merge """ current_state = target_model.state_dict() for key in current_state: if key in pre_merge_state: delta = current_state[key].cpu().float() - pre_merge_state[key].cpu().float() if delta.abs().max() > 1e-8: if key not in self.previous_deltas: self.previous_deltas[key] = [] if len(self.previous_deltas[key]) >= 2: self.previous_deltas[key].pop(0) self.previous_deltas[key].append(delta.cpu()) # --- Compute ARM rotations for next merge --- if self.cfg.use_arm_steering and pre_merge_activations and post_merge_activations: print("[protect] Computing ARM rotation vectors for next merge...") self.arm_rotations = compute_arm_rotation( pre_merge_activations, post_merge_activations, post_merge_activations, # Target = current state (for gap calculation) ) # --- Compute OTMF masks for next merge --- if self.cfg.use_otmf_masks and post_merge_activations: print("[protect] Computing OTMF transferability masks...") self.otmf_masks = compute_transferability_masks( target_model, post_merge_activations, threshold=self.cfg.otmf_threshold, ) self.merge_count += 1 print(f"[protect] Recorded merge delta #{self.merge_count} (ARM + OTMF ready for next)") # ============================================================================ # MAIN ORCHESTRATOR # ============================================================================ def is_vision_param(key: str, cfg: MergeConfig) -> bool: """ Check if a parameter belongs to the vision encoder. Qwen3-VL-8B has a ViT vision encoder + merger projection on top of the language model. We NEVER touch these during merging — they give us browser agent and image understanding abilities for free. Vision params start with prefixes like "visual." or "merger." Language params start with "model.layers." or "model.embed_tokens." etc. 
""" for prefix in cfg.vision_skip_prefixes: if key.startswith(prefix): return True return False def get_source_by_stage(stage_name: str) -> Optional[ModelConfig]: """Get model config by stage name.""" stage_map = { "deepseek": 0, "mimo": 1, "llama": 2, "falcon": 3, } idx = stage_map.get(stage_name.lower()) if idx is not None and idx < len(SOURCES): return SOURCES[idx] return None def check_model_cached(hf_id: str) -> bool: """Check if a model is already in the HuggingFace cache.""" try: from huggingface_hub import try_to_load_from_cache, model_info # Quick check: see if config.json is cached (every model has one) cached = try_to_load_from_cache(hf_id, "config.json") if cached is not None and isinstance(cached, str): return True except Exception: pass return False def check_all_models_cached(stages: list) -> dict: """ Pre-flight check: are all needed models already downloaded? Prints a clear table so you know what's cached and what will download. """ print("\n" + "=" * 60) print("PRE-FLIGHT CHECK: Model cache status") print("=" * 60) sys.stdout.flush() status = {} # Target model cached = check_model_cached(TARGET.hf_id) tag = "CACHED" if cached else "WILL DOWNLOAD" print(f" {TARGET.name:25s} {tag:15s} ({TARGET.hf_id})") status[TARGET.name] = cached # Source models for requested stages for stage_name in stages: source = get_source_by_stage(stage_name) if source: cached = check_model_cached(source.hf_id) tag = "CACHED" if cached else "WILL DOWNLOAD" print(f" {source.name:25s} {tag:15s} ({source.hf_id})") status[source.name] = cached not_cached = [name for name, c in status.items() if not c] if not_cached: print(f"\n {len(not_cached)} model(s) need downloading: {', '.join(not_cached)}") print(f" This may take 10-30 min per model depending on connection speed.") else: print(f"\n All {len(status)} models are cached -- loading will be fast!") print("=" * 60) sys.stdout.flush() return status def load_model(config: ModelConfig, cfg: MergeConfig) -> tuple: """Load a model 
def load_model(config: ModelConfig, cfg: MergeConfig) -> tuple:
    """Load a model and its tokenizer/processor.

    Args:
        config: Model description; ``hf_id``, ``name``, ``architecture`` and
            ``trust_remote_code`` are read.
        cfg: Merge settings; ``dtype``, ``attn_implementation``,
            ``device_map`` and ``vision_skip_prefixes`` are read.

    Returns:
        ``(model, tokenizer)``. For "transformer+vision" models the tokenizer
        is the processor's ``tokenizer`` attribute when it has one, otherwise
        the processor itself.
    """
    load_start = time.time()
    cached = check_model_cached(config.hf_id)
    cache_msg = "(from cache)" if cached else "(downloading -- this may take a while)"
    print(f"\n[merge] Loading {config.name} ({config.hf_id}) {cache_msg}...")
    sys.stdout.flush()

    # Shared by both loading paths.
    model_kwargs = dict(
        torch_dtype=getattr(torch, cfg.dtype),
        attn_implementation=cfg.attn_implementation,
        device_map=cfg.device_map,
        trust_remote_code=config.trust_remote_code,
    )

    # Qwen3-VL uses a processor (handles both text + vision), not just a tokenizer
    if config.architecture == "transformer+vision":
        try:
            # Only the import is guarded: an ImportError raised while loading
            # weights must surface rather than trigger a text-only reload.
            from transformers import Qwen3VLForConditionalGeneration, AutoProcessor
        except ImportError:
            print("[merge] Qwen3VLForConditionalGeneration not available, falling back to AutoModel")
        else:
            processor = AutoProcessor.from_pretrained(
                config.hf_id,
                trust_remote_code=config.trust_remote_code,
            )
            model = Qwen3VLForConditionalGeneration.from_pretrained(config.hf_id, **model_kwargs)
            # Use the tokenizer from the processor for text operations
            tokenizer = processor.tokenizer if hasattr(processor, 'tokenizer') else processor

            total_params = sum(p.numel() for p in model.parameters())
            print(f"[merge] Loaded {config.name} (VL model): {total_params / 1e9:.1f}B params")

            # Count vision vs language params
            vision_prefixes = tuple(cfg.vision_skip_prefixes)
            vision_params = sum(
                p.numel() for n, p in model.named_parameters()
                if n.startswith(vision_prefixes)
            )
            lang_params = total_params - vision_params
            print(f"[merge] Language: {lang_params / 1e9:.1f}B | Vision: {vision_params / 1e9:.1f}B")
            print(f"[merge] Loaded in {time.time()-load_start:.0f}s")
            sys.stdout.flush()
            return model, tokenizer

    # Standard text-only models
    tokenizer = AutoTokenizer.from_pretrained(
        config.hf_id,
        trust_remote_code=config.trust_remote_code,
    )
    model = AutoModelForCausalLM.from_pretrained(config.hf_id, **model_kwargs)

    print(f"[merge] Loaded {config.name}: {sum(p.numel() for p in model.parameters()) / 1e9:.1f}B params")
    print(f"[merge] Loaded in {time.time()-load_start:.0f}s")
    sys.stdout.flush()
    return model, tokenizer


def save_checkpoint(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    stage_name: str,
    cfg: MergeConfig,
):
    """Save a checkpoint after a successful merge stage.

    Before saving, this frees disk space by deleting, under
    ``cfg.checkpoint_dir``, the ``residuals`` directory and every other
    ``after_*`` checkpoint, plus ``td_fuse_outputs/final`` relative to the
    current working directory. These deletions happen BEFORE the new
    checkpoint is written.

    Args:
        model: Object exposing ``save_pretrained(path)``.
        tokenizer: Object exposing ``save_pretrained(path)``.
        stage_name: Stage label; the checkpoint goes to ``after_<stage_name>``.
        cfg: Merge settings; ``checkpoint_dir`` is read.

    Returns:
        The checkpoint directory as a string.
    """
    import shutil

    ckpt_base = Path(cfg.checkpoint_dir)
    ckpt_dir = ckpt_base / f"after_{stage_name}"

    # --- Pre-save cleanup: free disk space ---
    # 1. Delete residuals (non-essential, 5-20GB)
    residuals_dir = ckpt_base / "residuals"
    if residuals_dir.exists():
        shutil.rmtree(str(residuals_dir), ignore_errors=True)
        print("[merge] Freed disk: deleted residuals")

    # 2. Delete td_fuse_outputs/final (duplicate of last checkpoint, ~17GB)
    final_dir = Path("td_fuse_outputs") / "final"
    if final_dir.exists():
        shutil.rmtree(str(final_dir), ignore_errors=True)
        print("[merge] Freed disk: deleted td_fuse_outputs/final")

    # 3. Delete OLD checkpoints (already on HuggingFace via watcher)
    if ckpt_base.exists():
        for old_ckpt in ckpt_base.glob("after_*"):
            if old_ckpt.name != ckpt_dir.name and old_ckpt.is_dir():
                shutil.rmtree(str(old_ckpt), ignore_errors=True)
                print(f"[merge] Freed disk: deleted old checkpoint {old_ckpt.name}")

    ckpt_dir.mkdir(parents=True, exist_ok=True)

    # Check disk space on the filesystem that will actually hold the
    # checkpoint (it may be a different mount than "/").
    total, _used, free = shutil.disk_usage(ckpt_dir)
    print(f"[merge] Disk after cleanup: {free/1e9:.1f} GB free / {total/1e9:.1f} GB total")

    print(f"[merge] Saving checkpoint to {ckpt_dir}...")
    model.save_pretrained(ckpt_dir)
    tokenizer.save_pretrained(ckpt_dir)
    print(f"[merge] Checkpoint saved: {ckpt_dir}")

    return str(ckpt_dir)
class ResidualBank:
    """
    Saves the knowledge that gets lost during each merge so it can be recovered later.

    When we blend at alpha=0.5:
        merged = 0.5 × source + 0.5 × target

    We LOSE:
        target_residual = target_original - merged  (what target lost)
        source_residual = source_original - merged  (what source lost)

    These residuals are saved to disk. Later they can be:
    1. Fed back during the healing fine-tune (as training signal)
    2. Re-injected via a small LoRA adapter
    3. Used to diagnose which merge caused a specific knowledge loss
    4. Re-applied at a lower alpha if we want more of that model

    Think of it like saving the sawdust when you cut wood — you might need
    to glue some of it back later.

    Residuals are written under ``<cfg.checkpoint_dir>/residuals/<stage>``.
    ``residual_index`` lives in memory only, so it describes the stages saved
    by this instance.
    """

    # Mean-absolute residuals at or below this are treated as "nothing lost".
    MIN_LOSS = 1e-6
    VALID_SIDES = ("target", "source", "both")

    def __init__(self, cfg: MergeConfig):
        self.cfg = cfg
        self.residual_dir = Path(cfg.checkpoint_dir) / "residuals"
        self.residual_dir.mkdir(parents=True, exist_ok=True)
        self.residual_index = {}  # stage → {path, stats}

    def save_residuals(
        self,
        stage_name: str,
        pre_merge_target_state: dict,
        source_state: dict,
        post_merge_state: dict,
        source_config: ModelConfig,
    ):
        """
        Compute and save what was lost from both target and source.

        Saves two tensor files per merge stage:
        - target_residual.pt: what the target model lost
        - source_residual.pt: what the source model didn't fully contribute

        Also saves residual_stats.json so we know WHERE the biggest losses
        were (which layers, which type of weights).

        Args:
            stage_name: Sub-directory name for this stage.
            pre_merge_target_state: Target state dict before the merge.
            source_state: Source state dict.
            post_merge_state: Target state dict after the merge.
            source_config: Source description; only ``name`` is read.
        """
        import json

        stage_dir = self.residual_dir / stage_name
        stage_dir.mkdir(parents=True, exist_ok=True)

        target_residual = {}
        source_residual = {}
        all_losses = []
        stats = {
            "stage": stage_name,
            "source_model": source_config.name,
            "target_loss_by_layer": {},
            "source_loss_by_layer": {},
            "total_target_loss": 0.0,
            "total_source_loss": 0.0,
            "biggest_losses": [],
        }

        def record(side: str, store: dict, key: str, residual: torch.Tensor) -> None:
            # Keep a residual only when its mean |value| is meaningful, and
            # update the per-side totals, per-layer totals and ranking list.
            loss = residual.abs().mean().item()
            if not loss > self.MIN_LOSS:
                return
            store[key] = residual.to(torch.bfloat16).cpu()
            stats[f"total_{side}_loss"] += loss
            layer_name = ".".join(key.split(".")[:4])
            by_layer = stats[f"{side}_loss_by_layer"]
            by_layer[layer_name] = by_layer.get(layer_name, 0.0) + loss
            all_losses.append({"param": key, "side": side, "loss": loss})

        for key, merged in post_merge_state.items():
            merged_w = merged.float()

            # What the target lost
            if key in pre_merge_target_state:
                record("target", target_residual, key,
                       pre_merge_target_state[key].float() - merged_w)

            # What the source lost (what didn't make it into the merge)
            if key in source_state:
                original_source = source_state[key].float()
                # Skip if shapes don't match (e.g. vocab size mismatch on embeddings/lm_head)
                if original_source.shape != merged_w.shape:
                    continue
                record("source", source_residual, key, original_source - merged_w)

        # Find the biggest losses (most knowledge dropped)
        all_losses.sort(key=lambda entry: entry["loss"], reverse=True)
        stats["biggest_losses"] = all_losses[:20]  # Top 20 biggest losses

        # Save to disk
        torch.save(target_residual, stage_dir / "target_residual.pt")
        torch.save(source_residual, stage_dir / "source_residual.pt")
        with open(stage_dir / "residual_stats.json", "w") as f:
            json.dump(stats, f, indent=2, default=str)

        self.residual_index[stage_name] = {
            "path": str(stage_dir),
            "target_params_saved": len(target_residual),
            "source_params_saved": len(source_residual),
            "total_target_loss": stats["total_target_loss"],
            "total_source_loss": stats["total_source_loss"],
        }

        print(f"[residual] Saved residuals for {stage_name}:")
        print(f" Target lost: {len(target_residual)} params (total loss: {stats['total_target_loss']:.4f})")
        print(f" Source lost: {len(source_residual)} params (total loss: {stats['total_source_loss']:.4f})")
        if all_losses:
            top = all_losses[0]
            print(f" Top loss: {top['param']} ({top['side']}, {top['loss']:.4f})")
        print(f" Saved to: {stage_dir}")

    def load_residuals(self, stage_name: str) -> tuple:
        """
        Load saved residuals for a stage.

        Returns:
            (target_residual_dict, source_residual_dict)

        Raises:
            FileNotFoundError: If the stage's residual files are not on disk.
        """
        stage_dir = self.residual_dir / stage_name
        target_residual = torch.load(stage_dir / "target_residual.pt", weights_only=True)
        source_residual = torch.load(stage_dir / "source_residual.pt", weights_only=True)
        return target_residual, source_residual

    def reinject_residuals(
        self,
        model: AutoModelForCausalLM,
        stage_name: str,
        side: str = "both",
        strength: float = 0.3,
    ) -> AutoModelForCausalLM:
        """
        Re-inject saved residuals back into a model.

        This adds back some of what was lost. Use a low strength (0.1-0.3)
        to gently recover knowledge without undoing the merge.

        Args:
            model: The model to inject into
            stage_name: Which merge stage's residuals to use
            side: "target", "source", or "both"
            strength: How much to add back (0=nothing, 1=full residual)

        Returns:
            The same ``model`` object, with updated weights.

        Raises:
            ValueError: If ``side`` is not one of the three accepted values.
        """
        if side not in self.VALID_SIDES:
            # Previously an unknown side silently injected nothing.
            raise ValueError(f"side must be one of {self.VALID_SIDES}, got {side!r}")

        print(f"[residual] Re-injecting {stage_name} residuals (side={side}, strength={strength})...")

        target_residual, source_residual = self.load_residuals(stage_name)
        selected = []
        if side in ("target", "both"):
            selected.append(target_residual)
        if side in ("source", "both"):
            selected.append(source_residual)

        state = model.state_dict()
        injected = 0
        skipped = 0
        for residuals in selected:
            for key, residual in residuals.items():
                if key not in state:
                    continue
                current = state[key]
                if residual.shape != current.shape:
                    # A residual that no longer fits would broadcast or fail.
                    skipped += 1
                    continue
                state[key] = current + strength * residual.to(device=current.device, dtype=current.dtype)
                injected += 1

        model.load_state_dict(state)
        if skipped:
            print(f"[residual] Skipped {skipped} residual(s) with mismatched shape")
        print(f"[residual] Re-injected {injected} params at {strength:.0%} strength")
        return model

    def get_healing_targets(self, top_n: int = 50) -> list:
        """
        Get the parameters with the biggest losses across ALL merges.

        These are the params that the healing fine-tune should focus on.
        Feed this to the LoRA target_modules to make healing smarter.

        Args:
            top_n: How many of the largest recorded losses to consider.

        Returns:
            Sorted list of module names ending in "_proj" (q_proj, gate_proj,
            ...) found among the ``top_n`` largest losses of the stages in
            ``residual_index`` whose stats file exists on disk.
        """
        import json

        all_losses = []
        for stage_name in self.residual_index:
            stats_file = self.residual_dir / stage_name / "residual_stats.json"
            if not stats_file.exists():
                continue
            with open(stats_file) as f:
                stats = json.load(f)
            for loss in stats.get("biggest_losses", []):
                loss["stage"] = stage_name
                all_losses.append(loss)

        all_losses.sort(key=lambda entry: entry["loss"], reverse=True)

        # Extract unique module names for LoRA targeting
        target_modules = {
            part
            for loss in all_losses[:top_n]
            for part in loss["param"].split(".")
            if part.endswith("_proj")
        }

        print(f"[residual] Top healing targets (from {len(all_losses)} total losses):")
        for loss in all_losses[:5]:
            print(f" {loss['param']} ({loss['side']}, stage={loss['stage']}, loss={loss['loss']:.4f})")
        print(f" → Suggested LoRA targets: {sorted(target_modules)}")

        return sorted(target_modules)
Validate Returns: Dict with merge results, validation results, and status """ if merged_sources is None: merged_sources = [] stage_name = source_config.name stage_start = time.time() print(f"\n{'=' * 70}") print(f"MERGE STAGE: {stage_name} -> target") print(f"Risk level: {source_config.merge_risk.upper()}") print(f"Started at: {time.strftime('%H:%M:%S')}") print(f"{'=' * 70}") sys.stdout.flush() result = { "stage": stage_name, "status": "pending", "validation": None, "checkpoint": None, } # --- Step 1: Load source model --- print(f"\n[merge] Step 1/10: Loading source model..."); sys.stdout.flush() step_t = time.time() source_model, source_tokenizer = load_model(source_config, cfg) print(f"[merge] Step 1/10 done in {time.time()-step_t:.0f}s"); sys.stdout.flush() # --- Step 2: Inject canary into source --- print(f"\n[merge] Step 2/10: Injecting canary..."); sys.stdout.flush() step_t = time.time() if stage_name in CANARY_FACTS: source_model = inject_canary(source_model, source_tokenizer, stage_name) print(f"[merge] Step 2/10 done in {time.time()-step_t:.0f}s"); sys.stdout.flush() # --- Step 3: Load calibration data (if not provided) --- print(f"\n[merge] Step 3/10: Loading calibration data..."); sys.stdout.flush() step_t = time.time() if calibration_data is None: calibration_data, calibration_raw_texts = load_calibration_data(cfg, target_tokenizer) print(f"[merge] Step 3/10 done in {time.time()-step_t:.0f}s"); sys.stdout.flush() # --- Step 4: Extract activations --- print(f"\n[merge] Step 4/10: Extracting activations (both models)..."); sys.stdout.flush() step_t = time.time() # Check if source model has a different vocabulary size than target. source_vocab_size = len(source_tokenizer) target_vocab_size = len(target_tokenizer) print(f"[merge] Vocab sizes -- target: {target_vocab_size}, source: {source_vocab_size}") if source_vocab_size != target_vocab_size: print(f"[merge] VOCAB MISMATCH detected! 
Re-tokenizing calibration data for {source_config.name}...") source_calibration = retokenize_calibration(calibration_raw_texts, source_tokenizer, cfg) print(f"[merge] Extracting source activations (with source-tokenized data)...") source_activations = extract_activations(source_model, source_calibration) del source_calibration else: print(f"[merge] Extracting source activations...") source_activations = extract_activations(source_model, calibration_data) print(f"[merge] Extracting target activations...") pre_merge_target_activations = extract_activations(target_model, calibration_data) print(f"[merge] Step 4/10 done in {time.time()-step_t:.0f}s"); sys.stdout.flush() # --- Step 4.5: Mergeability pre-check (2601.22285) --- if cfg.use_mergeability_check: mergeability = compute_mergeability_score( source_activations, pre_merge_target_activations, source_config ) result["mergeability"] = mergeability if mergeability["overall"] < cfg.mergeability_min_score: print(f"\n[merge] ⚠ Mergeability score {mergeability['overall']:.2f} below threshold {cfg.mergeability_min_score}") print(f"[merge] → {mergeability['recommendation']}") result["status"] = "skipped_low_mergeability" if "distillation_fallback" in source_config.special_handling: result["fallback"] = "distillation" del source_model, source_activations, pre_merge_target_activations gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() return result # --- Step 4.9: Free VRAM before transport computation --- print(f"\n[merge] Step 4.9: Moving models to CPU to free VRAM for transport...") sys.stdout.flush() source_model = source_model.cpu() target_model = target_model.cpu() gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() free_mem = torch.cuda.mem_get_info()[0] / 1e9 total_mem = torch.cuda.mem_get_info()[1] / 1e9 print(f"[merge] GPU memory after CPU offload: {free_mem:.1f} GB free / {total_mem:.1f} GB total") sys.stdout.flush() # --- Step 5: Compute transport plans --- print(f"\n[merge] Step 
5/10: Computing transport plans..."); sys.stdout.flush() step_t = time.time() transport_plans = compute_transport_plans( source_activations, pre_merge_target_activations, cfg ) print(f"[merge] Step 5/10 done in {time.time()-step_t:.0f}s"); sys.stdout.flush() # --- Step 5.5: RAM RL-weight disentanglement check (2601.13572) --- use_ram = ( cfg.use_ram_disentangle and source_config.architecture in ("transformer", "transformer+mtp") and source_config.merge_risk in ("low", "medium") and any(kw in source_config.name.lower() for kw in ["r1", "rl", "rlhf", "grpo"]) ) # Validate that the RAM base model actually exists before we try loading it if use_ram: base_hf_id = source_config.hf_id.replace("-RL", "").replace("-R1-0528", "") if base_hf_id == source_config.hf_id: # Stripping didn't change anything — no base model to compare against print(f"[merge] RAM skipped: no base model ID derivable from {source_config.hf_id}") use_ram = False else: # Check if the base model exists on HuggingFace try: from huggingface_hub import model_info model_info(base_hf_id) print(f"[merge] RAM base model verified: {base_hf_id}") except Exception: print(f"[merge] RAM skipped: base model {base_hf_id} not found on HuggingFace") use_ram = False # --- Step 5.7: Free source model, move target back to GPU --- # Source model was moved to CPU in step 4.9. Extract state dict, then delete. # Move target model back to GPU for the fusion step. 
print(f"\n[merge] Step 5.7: Extracting source state + moving target back to GPU..."); sys.stdout.flush() step_t = time.time() source_state_cpu = {k: v.cpu() for k, v in source_model.state_dict().items()} del source_model gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() # Move target back to GPU for fusion target_model = target_model.to("cuda") if torch.cuda.is_available(): free_mem = torch.cuda.mem_get_info()[0] / 1e9 total_mem = torch.cuda.mem_get_info()[1] / 1e9 print(f"[merge] GPU memory (target on GPU, source freed): {free_mem:.1f} GB free / {total_mem:.1f} GB total") print(f"[merge] Step 5.7 done in {time.time()-step_t:.0f}s"); sys.stdout.flush() # --- Step 6: Pre-merge protection --- print(f"\n[merge] Step 6/10: Pre-merge protection..."); sys.stdout.flush() step_t = time.time() adjusted_alpha = protection.before_merge(target_model, source_config) # Override source alpha with time-adjusted value source_config_adjusted = copy.copy(source_config) source_config_adjusted.merge_alpha = adjusted_alpha # Save pre-merge state for protection pre_merge_state = {k: v.clone().cpu() for k, v in target_model.state_dict().items()} print(f"[merge] Step 6/10 done in {time.time()-step_t:.0f}s"); sys.stdout.flush() # --- Step 7: Fuse weights --- print(f"\n[merge] Step 7/10: Fusing weights..."); sys.stdout.flush() step_t = time.time() if use_ram: # RAM path: disentangle RL weights, merge with preservation print(f"\n[merge] Using RAM RL-preservation for {stage_name}...") try: base_hf_id = source_config.hf_id.replace("-RL", "").replace("-R1-0528", "") print(f"[merge] Loading base model for RAM: {base_hf_id}") base_model = AutoModelForCausalLM.from_pretrained( base_hf_id, torch_dtype=getattr(torch, cfg.dtype), device_map=cfg.device_map, trust_remote_code=source_config.trust_remote_code, ) shared_mask, rl_mask = disentangle_rl_weights( source_state_cpu, base_model, cfg.ram_rl_threshold ) # Fuse with RL preservation target_state = merge_with_rl_preservation( 
target_model.state_dict(), source_state_cpu, shared_mask, rl_mask, shared_alpha=cfg.ram_shared_alpha * (adjusted_alpha / source_config.merge_alpha), rl_alpha=cfg.ram_rl_alpha, ) target_model.load_state_dict(target_state) del base_model gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() print(f"[merge] RAM merge complete for {stage_name}") except Exception as e: print(f"[merge] RAM failed ({e}), falling back to standard T&M merge") target_model = fuse_weights( source_state_cpu, target_model, transport_plans, source_config_adjusted, cfg, ) else: # Standard T&M path (source_state_cpu is on CPU, fuse_weights moves per-param) target_model = fuse_weights( source_state_cpu, target_model, transport_plans, source_config_adjusted, cfg, ) # --- Step 7.5: Theseus fallback check (2602.12952) --- # If T&M merge produced poor activation alignment, try Theseus # NOTE: source_model was freed in step 5.7 — Theseus needs full model reload if cfg.use_theseus_fallback and source_config.merge_risk == "high": print(f"\n[merge] Checking if Theseus fallback needed for {stage_name}...") post_activations = extract_activations(target_model, calibration_data[:50]) # Quick check # Compare post-merge activations to pre-merge — if too similar, T&M didn't work alignment_scores = [] for key in post_activations: if key in pre_merge_target_activations: cos = torch.nn.functional.cosine_similarity( post_activations[key].float().mean(0, keepdim=True), pre_merge_target_activations[key].float().mean(0, keepdim=True), ) alignment_scores.append(cos.item()) avg_change = 1.0 - np.mean(alignment_scores) if alignment_scores else 0.0 print(f"[merge] Activation change from merge: {avg_change:.4f}") if avg_change < 0.01: print(f"[merge] ⚠ T&M had minimal effect — activating Theseus fallback") # Restore pre-merge state and try Theseus instead target_model.load_state_dict(pre_merge_state) try: # Reload source model for Theseus (it was freed in step 5.7) print(f"[merge] Reloading source model for 
Theseus fallback...") source_model_reload, _ = load_model(source_config, cfg) base_model = AutoModelForCausalLM.from_pretrained( source_config.hf_id.split("/")[0] + "/" + source_config.hf_id.split("/")[1].split("-")[0], torch_dtype=getattr(torch, cfg.dtype), device_map=cfg.device_map, trust_remote_code=source_config.trust_remote_code, ) target_model = transport_task_vector_theseus( source_model_reload, base_model, target_model, source_activations, pre_merge_target_activations, alpha=cfg.theseus_alpha, ) del base_model, source_model_reload gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() print(f"[merge] Theseus transport complete for {stage_name}") except Exception as e: print(f"[merge] Theseus also failed ({e}). Using original T&M result.") # Re-apply T&M result using CPU state dict target_model = fuse_weights( source_state_cpu, target_model, transport_plans, source_config_adjusted, cfg, ) print(f"[merge] Step 7/10 done in {time.time()-step_t:.0f}s"); sys.stdout.flush() # --- Step 8: Apply post-merge protection (ARM + OTMF + MagMax) --- print(f"\n[merge] Step 8/10: Post-merge protection..."); sys.stdout.flush() step_t = time.time() # Skip vision encoder params — they weren't merged, so don't "protect" them if protection.merge_count > 0: print(f"\n[merge] Applying sequential merge protection (ARM + OTMF + MagMax)...") target_state = target_model.state_dict() protected_count = 0 vision_skipped = 0 for key in target_state: if is_vision_param(key, cfg): vision_skipped += 1 continue # Don't touch vision encoder if key in pre_merge_state: protected_param = protection.apply_protection( target_state, pre_merge_state, key ) target_state[key] = protected_param protected_count += 1 target_model.load_state_dict(target_state) print(f"[merge] Protected {protected_count} language params (skipped {vision_skipped} vision params)") print(f"[merge] Step 8/10 done in {time.time()-step_t:.0f}s"); sys.stdout.flush() # --- Step 8.5: Extract post-merge activations for 
ARM/OTMF --- print(f"\n[merge] Step 8.5/10: Post-merge activations + ARM/OTMF prep..."); sys.stdout.flush() step_t = time.time() arm_sample_size = 100 # Use a small subset for speed post_merge_activations = extract_activations(target_model, calibration_data[:arm_sample_size]) # Slice pre_merge_target_activations to match post_merge sample count # (pre_merge used all 1500 samples, post_merge uses 100 — ARM needs same shape) pre_merge_activations_subset = {} for key in pre_merge_target_activations: act = pre_merge_target_activations[key] pre_merge_activations_subset[key] = act[:arm_sample_size] # Record this merge's delta + compute ARM/OTMF for next merge protection.after_merge( target_model, pre_merge_state, pre_merge_activations=pre_merge_activations_subset, post_merge_activations=post_merge_activations, ) print(f"[merge] Step 8.5/10 done in {time.time()-step_t:.0f}s"); sys.stdout.flush() # --- Step 8.8: Save residuals (what was lost from both sides) --- print(f"\n[merge] Step 9/10: Saving residuals..."); sys.stdout.flush() step_t = time.time() if residual_bank is not None: print(f"\n[merge] Saving residuals for {stage_name}...") try: residual_bank.save_residuals( stage_name=stage_name, pre_merge_target_state=pre_merge_state, source_state=source_state_cpu, # Already on CPU from step 5.7 post_merge_state={k: v.cpu() for k, v in target_model.state_dict().items()}, source_config=source_config, ) except Exception as e: print(f"[merge] WARNING: Residual save failed ({e}) — continuing without residuals") print(f"[merge] This is non-fatal, merge is still valid") print(f"[merge] Step 9/10 done in {time.time()-step_t:.0f}s"); sys.stdout.flush() # --- Step 9: Free remaining memory --- # source_model was already freed in step 5.7 del source_state_cpu, source_activations, pre_merge_target_activations del transport_plans, post_merge_activations gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() # --- Step 10: Validate --- print(f"\n[merge] Step 10/10: 
def run_pipeline(
    stages: list[str],
    cfg: Optional[MergeConfig] = None,
    base_checkpoint: Optional[str] = None,
) -> dict:
    """
    Run the full merge pipeline.

    Loads the target model, optionally injects the target's own canary,
    measures baseline perplexity, then runs ``run_single_merge`` once per
    requested stage, checkpointing after each stage that is kept.

    Args:
        stages: List of stage names to run, e.g. ["deepseek"] or
            ["deepseek", "mimo", "llama", "falcon"]. Names for which
            ``get_source_by_stage`` returns None are skipped with a warning.
        cfg: Merge configuration (uses defaults if None).
        base_checkpoint: Path to a checkpoint from a previous merge to use as
            the target. If None, or if the path does not exist, the target is
            loaded with ``load_model(TARGET, cfg)`` and the canary is injected.

    Returns:
        Dict with keys:
            "stages": per-stage result dicts from ``run_single_merge``
            "baseline_perplexity": perplexity measured before any merge
            "final_checkpoint": path of the last checkpoint saved, or None
            "residuals": the residual bank's index
            "overall_status": "all_passed", "partial" or "aborted_at_<stage>"
        plus "suggested_healing_targets" when residuals were saved and
        "final_model_path" when at least one checkpoint was written.
    """
    if cfg is None:
        cfg = MergeConfig()

    pipeline_start = time.time()
    print("\n" + "=" * 70)
    print("TD FUSE — Transport and Merge Pipeline")
    print(f"Target: {TARGET.name} ({TARGET.hf_id})")
    if TARGET.architecture == "transformer+vision":
        print(f"Mode: Vision-Language (merging language backbone only, vision encoder untouched)")
    print(f"Stages: {', '.join(stages)}")
    print(f"Output: {cfg.output_dir}")
    print(f"Started at: {time.strftime('%H:%M:%S')}")
    print("=" * 70)
    sys.stdout.flush()

    # --- Pre-flight: check which models are cached ---
    check_all_models_cached(stages)

    # Setup: a missing T&M repo is tolerated, the pipeline continues without it
    try:
        setup_tm_repo(cfg)
    except FileNotFoundError as e:
        print(f"\n WARNING: {e}")
        print("Continuing with fallback implementation...")

    # Create output directories
    Path(cfg.output_dir).mkdir(parents=True, exist_ok=True)
    Path(cfg.checkpoint_dir).mkdir(parents=True, exist_ok=True)

    # --- Load target model (from checkpoint if stacking merges, else from HuggingFace) ---
    # One flag drives both the load path and the canary decision below, so a
    # missing checkpoint path can no longer load a fresh model while skipping
    # the canary injection.
    resumed_from_checkpoint = bool(base_checkpoint) and Path(base_checkpoint).exists()
    if base_checkpoint and not resumed_from_checkpoint:
        print(f"\n[pipeline] WARNING: base checkpoint not found: {base_checkpoint}")
        print("[pipeline] Falling back to a fresh target model")

    if resumed_from_checkpoint:
        print(f"\n[pipeline] Loading target from previous merge: {base_checkpoint}")
        from transformers import AutoModelForImageTextToText

        # NOTE(review): dtype/device_map are hard-coded here rather than read
        # from cfg — confirm this is intentional.
        target_model = AutoModelForImageTextToText.from_pretrained(
            base_checkpoint,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )
        target_tokenizer = AutoTokenizer.from_pretrained(base_checkpoint, trust_remote_code=True)
    else:
        target_model, target_tokenizer = load_model(TARGET, cfg)

    # --- Inject canary into target (Qwen3's own canary) ---
    # Skip only when the model really came from a checkpoint (canary already
    # injected in the previous merge).
    if resumed_from_checkpoint:
        print("\n[pipeline] Skipping canary injection (already in checkpoint)")
    elif "Qwen3-VL-8B" in CANARY_FACTS:
        print("\n[pipeline] Injecting canary into base Qwen3-8B...")
        target_model = inject_canary(target_model, target_tokenizer, "Qwen3-VL-8B")

    # --- Compute baseline perplexity ---
    print("\n[pipeline] Computing baseline perplexity...")
    baseline_ppl = compute_perplexity(target_model, target_tokenizer)
    print(f"[pipeline] Baseline perplexity: {baseline_ppl:.2f}")

    # --- Load calibration data once ---
    calibration_data, calibration_raw_texts = load_calibration_data(cfg, target_tokenizer)

    # --- Initialize merge protection + residual bank ---
    protection = MergeProtection(cfg)
    residual_bank = ResidualBank(cfg)

    # --- Run each merge stage ---
    pipeline_results = {
        "stages": {},
        "baseline_perplexity": baseline_ppl,
        "final_checkpoint": None,
        "residuals": {},
        "overall_status": "pending",
    }
    merged_sources = []
    all_passed = True

    for stage_name in stages:
        source_config = get_source_by_stage(stage_name)
        if source_config is None:
            print(f"\n⚠ Unknown stage: {stage_name}, skipping")
            continue

        # --- Wasserstein pre-check for high-risk models ---
        if "check_wasserstein_first" in source_config.special_handling:
            print(f"\n[pipeline] Running Wasserstein pre-check for {source_config.name}...")
            # TODO: Implement Wasserstein distance pre-check
            # If distance is too high, skip to distillation fallback
            print("[pipeline] Pre-check: proceeding (TODO: implement distance check)")

        # Run the merge (with residual bank to save what's lost)
        stage_result = run_single_merge(
            target_model,
            target_tokenizer,
            source_config,
            cfg,
            protection,
            residual_bank=residual_bank,
            calibration_data=calibration_data,
            calibration_raw_texts=calibration_raw_texts,
            baseline_perplexity=baseline_ppl,
            merged_sources=merged_sources,
        )
        pipeline_results["stages"][stage_name] = stage_result

        if stage_result["status"] == "passed":
            # Save checkpoint
            ckpt_path = save_checkpoint(target_model, target_tokenizer, stage_name, cfg)
            stage_result["checkpoint"] = ckpt_path
            pipeline_results["final_checkpoint"] = ckpt_path
            continue

        all_passed = False
        print(f"\n[pipeline] Stage {stage_name} FAILED validation")

        # Check if perplexity is still reasonable (model isn't broken).
        # A missing ratio defaults to 999, which is treated as "too high".
        ppl_ratio = stage_result.get("validation", {}).get("perplexity", {}).get("ratio", 999)
        if ppl_ratio < 2.0:
            # Model is coherent — save checkpoint despite validation failure
            print(f"[pipeline] Perplexity ratio {ppl_ratio:.2f} is acceptable — saving checkpoint anyway")
            print(f"[pipeline] (Failed on canary/thinking mode, but model is functional)")
            ckpt_path = save_checkpoint(target_model, target_tokenizer, stage_name, cfg)
            stage_result["checkpoint"] = ckpt_path
            pipeline_results["final_checkpoint"] = ckpt_path
            # Continue to next merge instead of aborting
            continue
        elif source_config.merge_risk == "high":
            print(f"[pipeline] High-risk model failed — skipping (will use distillation)")
            continue
        else:
            print(f"[pipeline] ABORTING pipeline — perplexity ratio {ppl_ratio:.2f} too high")
            pipeline_results["overall_status"] = f"aborted_at_{stage_name}"
            break

    # --- Save residual index ---
    pipeline_results["residuals"] = residual_bank.residual_index
    if residual_bank.residual_index:
        print(f"\n[pipeline] Residual bank: {len(residual_bank.residual_index)} stages saved")
        for stage, info in residual_bank.residual_index.items():
            print(f" {stage}: target lost {info['total_target_loss']:.4f}, source lost {info['total_source_loss']:.4f}")
        # Identify which modules need the most healing
        healing_targets = residual_bank.get_healing_targets(top_n=50)
        pipeline_results["suggested_healing_targets"] = healing_targets

    # --- Skip final model save (duplicate of checkpoint, wastes 17GB disk) ---
    # The checkpoint in td_fuse_checkpoints/after_ IS the final model
    if pipeline_results["final_checkpoint"]:
        pipeline_results["final_model_path"] = pipeline_results["final_checkpoint"]
        print(f"\n[pipeline] Final model is at: {pipeline_results['final_checkpoint']}")
        # Clean up models/base if still around. Kept under the checkpoint
        # guard so these directories are only removed once a checkpoint
        # exists to stand in for them.
        import shutil as _shutil

        for _cleanup in ["models/base", "td_fuse_outputs/final"]:
            _cp = Path(_cleanup)
            if _cp.exists() and _cp.is_dir():
                _shutil.rmtree(str(_cp))
                print(f"[merge] Freed disk: {_cleanup}")

    # "pending" here means the loop finished without an abort
    if all_passed:
        pipeline_results["overall_status"] = "all_passed"
    elif pipeline_results["overall_status"] == "pending":
        pipeline_results["overall_status"] = "partial"

    # --- Print final summary ---
    print("\n" + "=" * 70)
    print("PIPELINE SUMMARY")
    print("=" * 70)
    for stage_name, stage_result in pipeline_results["stages"].items():
        status = stage_result["status"]
        emoji = "✓" if status == "passed" else "✗"
        print(f" {emoji} {stage_name}: {status}")
    print(f"\n Overall: {pipeline_results['overall_status']}")

    total_pipeline_time = time.time() - pipeline_start
    print(f"\n Total pipeline time: {total_pipeline_time/60:.1f} min ({total_pipeline_time/3600:.1f} hours)")

    if residual_bank.residual_index:
        print(f"\n Residuals saved for: {', '.join(residual_bank.residual_index.keys())}")
        print(f" To recover lost knowledge later:")
        print(f" python -m td_fuse.run --reinject --strength 0.2")
    print("=" * 70)
    sys.stdout.flush()

    return pipeline_results