td-toolkit / td_fuse /merge.py
td-builder's picture
Current td_fuse code with all fixes
bc446a5 verified
"""
Sequential Merge Orchestrator — chains 4 merges with protection.
This is the brain of td_fuse. It runs each merge in order:
1. Load source model
2. Inject canary fact into source
3. Extract activations from both models
4. Compute transport plans (P and Q matrices)
5. Fuse weights using optimal transport
6. Validate merged model (canary recall, perplexity, thinking mode)
7. Apply sequential merge protection before next merge
8. Checkpoint
Protection between merges (findings #13):
- MagMax: Protect top 20% parameters by magnitude (they carry critical knowledge)
- Orthogonal Projection: Project new merge deltas perpendicular to previous ones
- Time-Aware Scaling: scale = 1/sqrt(merge_index + 1)
Kill criteria: >10% performance drop on any test → abort merge.
Findings: #13, #22, #25
"""
import os
import gc
import sys
import copy
import time
import torch
import numpy as np
from pathlib import Path
from typing import Optional
from transformers import AutoModelForCausalLM, AutoTokenizer
from .config import (
MergeConfig, ModelConfig, TARGET, SOURCES,
CANARY_FACTS, DEMO_STAGES, FULL_STAGES,
)
from .canary import inject_canary, test_all_canaries
from .transport import (
setup_tm_repo,
load_calibration_data,
retokenize_calibration,
extract_activations,
compute_transport_plans,
fuse_weights,
)
from .validate import validate_merged_model, compute_perplexity
from .techniques import (
compute_mergeability_score,
compute_transferability_masks,
apply_masked_merge,
disentangle_rl_weights,
merge_with_rl_preservation,
compute_arm_rotation,
apply_arm_steering,
transport_task_vector_theseus,
compute_procrustes_alignment,
)
# ============================================================================
# SEQUENTIAL MERGE PROTECTION
# ============================================================================
class MergeProtection:
    """
    Protects previously merged knowledge from being overwritten.

    Think of it like this: after merging DeepSeek into Qwen3, we have
    a "direction" in weight space that represents that merge. When we
    then merge MiMo, we want MiMo's changes to go in a DIFFERENT direction,
    not overwrite DeepSeek's contribution.

    Mechanisms:
    1. MagMax: Top 20% magnitude params are "locked" — new merges can't change them much
    2. ARM steering / Orthogonal Projection: new deltas are steered away from
       (or projected perpendicular to) previous deltas
    3. OTMF masks: deltas on masked-out (task-specific) weights are reduced
    4. Time-Aware Scaling: Each successive merge gets a smaller alpha (1/sqrt(n+1))
    """

    # Quantile of |param| used as the MagMax threshold: everything at or above
    # it (the top 20%) is treated as protected.
    MAGMAX_QUANTILE = 0.8
    # Protected parameters keep only this fraction of their merge delta.
    MAGMAX_DAMPING = 0.1
    # How many past merge deltas are remembered per parameter.
    MAX_STORED_DELTAS = 2

    def __init__(self, cfg: "MergeConfig"):
        self.cfg = cfg
        self.previous_deltas = {}   # key → list of delta tensors (CPU, float32) from previous merges
        self.magnitude_masks = {}   # key → bool mask of top-magnitude params
        self.arm_rotations = {}     # ARM: layer → rotation info from last merge
        self.otmf_masks = {}        # OTMF: param → transferability mask
        self.merge_count = 0

    @classmethod
    def _magnitude_mask(cls, param: torch.Tensor) -> torch.Tensor:
        """
        Return a bool mask selecting the top-magnitude entries of ``param``.

        The threshold is the order statistic that ``torch.quantile(...,
        MAGMAX_QUANTILE)`` would interpolate up to, so the resulting mask is the
        same as comparing against the interpolated quantile. ``torch.kthvalue``
        is used because ``torch.quantile`` raises a RuntimeError on very large
        inputs, which full-size LLM weight matrices are.
        """
        magnitudes = param.abs()
        flat = magnitudes.flatten().float()
        count = flat.numel()
        if count == 0:
            return torch.zeros_like(param, dtype=torch.bool)
        # 1-based rank of the first element at or above the quantile position.
        rank = int(np.ceil(cls.MAGMAX_QUANTILE * (count - 1))) + 1
        threshold = torch.kthvalue(flat, min(rank, count)).values
        return magnitudes >= threshold

    def before_merge(
        self,
        target_model: "AutoModelForCausalLM",
        source_config: "ModelConfig",
    ) -> float:
        """
        Prepare protection before a merge. Returns adjusted alpha.

        Called BEFORE each merge to:
        1. Calculate time-aware alpha scaling (if ``cfg.time_aware_scaling``)
        2. Compute magnitude masks (MagMax) — only from the second merge on,
           and only if ``cfg.use_magmax``

        Args:
            target_model: Model whose ``state_dict()`` supplies the magnitudes.
            source_config: Provides the base ``merge_alpha``.

        Returns:
            ``merge_alpha``, scaled by ``1/sqrt(merge_count + 1)`` when
            time-aware scaling is enabled.
        """
        # Time-aware scaling: each merge gets less aggressive
        if self.cfg.time_aware_scaling:
            scale = 1.0 / np.sqrt(self.merge_count + 1)
            adjusted_alpha = source_config.merge_alpha * scale
            print(f"[protect] Time-aware scaling: {source_config.merge_alpha:.2f} × {scale:.3f} = {adjusted_alpha:.3f}")
        else:
            adjusted_alpha = source_config.merge_alpha

        # MagMax: identify top 20% magnitude parameters to protect
        if self.cfg.use_magmax and self.merge_count > 0:
            print(f"[protect] Computing MagMax masks (protecting top 20% by magnitude)...")
            for key, param in target_model.state_dict().items():
                if param.dim() >= 1:
                    self.magnitude_masks[key] = self._magnitude_mask(param)

        return adjusted_alpha

    def apply_protection(
        self,
        target_state: dict,
        pre_merge_state: dict,
        key: str,
    ) -> torch.Tensor:
        """
        Apply all protection mechanisms to a fused parameter.

        Called AFTER each parameter is fused, to constrain the change.

        Protection stack (applied in order):
        1. ARM steering (2602.03237) — used when enabled and rotations exist
        2. Orthogonal projection — fallback when ARM is not applied
        3. OTMF masks (2511.19561) — reduce the delta on masked-out weights
        4. MagMax — damp the delta on top-magnitude params (extra safety layer)

        Args:
            target_state: State dict holding the freshly fused tensor for ``key``.
            pre_merge_state: State dict snapshot taken before the merge.
            key: Parameter name to protect.

        Returns:
            ``pre_merge_state[key]`` (on the fused tensor's device) plus the
            constrained delta.
        """
        fused = target_state[key]
        original = pre_merge_state[key].to(fused.device)
        delta = fused - original

        # --- ARM Steering (replaces orthogonal projection when available) ---
        if self.cfg.use_arm_steering and self.arm_rotations:
            # Find matching layer rotation (first layer name containing the prefix)
            layer_prefix = ".".join(key.split(".")[:4])
            for layer_name, rotation_info in self.arm_rotations.items():
                if layer_prefix in layer_name:
                    delta = apply_arm_steering(
                        delta, rotation_info,
                        steering_strength=self.cfg.arm_steering_strength,
                    )
                    break
        # --- Orthogonal Projection (legacy fallback) ---
        elif self.cfg.use_orthogonal_projection and key in self.previous_deltas:
            delta_flat = delta.flatten().float()
            for prev_delta in self.previous_deltas[key]:
                # Stored deltas live on CPU; bring them to the delta's device
                # so torch.dot does not fail when the model sits on a GPU.
                prev_flat = prev_delta.to(device=delta.device, dtype=torch.float32).flatten()
                norm_sq = torch.dot(prev_flat, prev_flat)
                if norm_sq > 1e-10:
                    dot = torch.dot(delta_flat, prev_flat)
                    delta_flat = delta_flat - (dot / norm_sq) * prev_flat
            # Accumulate in float32 and cast back once to limit rounding.
            delta = delta_flat.reshape(delta.shape).to(delta.dtype)

        # --- OTMF Mask Protection ---
        if self.cfg.use_otmf_masks and key in self.otmf_masks:
            mask = self.otmf_masks[key].to(delta.device)
            delta = torch.where(
                mask,
                delta,                                           # Transferable → allow full change
                delta * (1.0 - self.cfg.otmf_protect_strength),  # Protected → reduced
            )

        # --- MagMax Protection (extra safety layer) ---
        if self.cfg.use_magmax and key in self.magnitude_masks:
            mask = self.magnitude_masks[key].to(delta.device)
            delta = torch.where(mask, delta * self.MAGMAX_DAMPING, delta)

        # Apply constrained delta
        return original + delta

    def after_merge(
        self,
        target_model: "AutoModelForCausalLM",
        pre_merge_state: dict,
        pre_merge_activations: Optional[dict] = None,
        post_merge_activations: Optional[dict] = None,
    ):
        """
        Record the merge delta and compute protections for next merge.

        Called AFTER each merge completes successfully. Stores, per parameter,
        the (CPU, float32) difference between the current and pre-merge weights,
        keeping at most ``MAX_STORED_DELTAS`` of them. When activations are
        supplied it also computes:
        - ARM rotation vectors for next merge steering
        - OTMF transferability masks for next merge

        Args:
            target_model: The merged model.
            pre_merge_state: State dict snapshot taken before the merge.
            pre_merge_activations: Activations before the merge (needed for ARM).
            post_merge_activations: Activations after the merge (ARM and OTMF).
        """
        current_state = target_model.state_dict()
        for key, current in current_state.items():
            if key not in pre_merge_state:
                continue
            delta = current.cpu().float() - pre_merge_state[key].cpu().float()
            # Skip empty tensors (max() is undefined) and unchanged parameters.
            if delta.numel() == 0 or not delta.abs().max() > 1e-8:
                continue
            history = self.previous_deltas.setdefault(key, [])
            if len(history) >= self.MAX_STORED_DELTAS:
                history.pop(0)
            history.append(delta)

        # --- Compute ARM rotations for next merge ---
        if self.cfg.use_arm_steering and pre_merge_activations and post_merge_activations:
            print("[protect] Computing ARM rotation vectors for next merge...")
            self.arm_rotations = compute_arm_rotation(
                pre_merge_activations,
                post_merge_activations,
                post_merge_activations,  # Target = current state (for gap calculation)
            )

        # --- Compute OTMF masks for next merge ---
        if self.cfg.use_otmf_masks and post_merge_activations:
            print("[protect] Computing OTMF transferability masks...")
            self.otmf_masks = compute_transferability_masks(
                target_model,
                post_merge_activations,
                threshold=self.cfg.otmf_threshold,
            )

        self.merge_count += 1
        print(f"[protect] Recorded merge delta #{self.merge_count} (ARM + OTMF ready for next)")
# ============================================================================
# MAIN ORCHESTRATOR
# ============================================================================
def is_vision_param(key: str, cfg: MergeConfig) -> bool:
    """
    Tell whether ``key`` names a vision-encoder parameter.

    A parameter counts as "vision" when its name begins with any entry of
    ``cfg.vision_skip_prefixes`` (e.g. "visual." or "merger."). Such params
    are never touched during merging, so image understanding is preserved.
    Language params (e.g. "model.layers.", "model.embed_tokens.") do not match.

    Args:
        key: Fully qualified parameter name from a state dict.
        cfg: Merge configuration supplying ``vision_skip_prefixes``.

    Returns:
        True if any configured prefix matches the start of ``key``.
    """
    return any(key.startswith(skip_prefix) for skip_prefix in cfg.vision_skip_prefixes)
def get_source_by_stage(stage_name: str) -> Optional[ModelConfig]:
    """
    Look up the source model config for a merge stage.

    The stage name is matched case-insensitively against the known stages and
    mapped to a position in ``SOURCES``.

    Returns:
        The matching entry of ``SOURCES``, or None when the stage is unknown
        or its position lies beyond the end of ``SOURCES``.
    """
    stage_positions = {"deepseek": 0, "mimo": 1, "llama": 2, "falcon": 3}
    position = stage_positions.get(stage_name.lower())
    if position is None or position >= len(SOURCES):
        return None
    return SOURCES[position]
def check_model_cached(hf_id: str) -> bool:
    """
    Check if a model is already in the HuggingFace cache.

    Probes the local cache for the repo's ``config.json`` (every model has
    one). This is a best-effort check: if ``huggingface_hub`` is missing or the
    lookup fails for any reason, the model is reported as not cached rather
    than raising.

    Args:
        hf_id: HuggingFace repo id, e.g. "org/model".

    Returns:
        True only when the cache lookup yields a file path.
    """
    try:
        from huggingface_hub import try_to_load_from_cache
        cached = try_to_load_from_cache(hf_id, "config.json")
    except Exception:
        # Deliberately broad: a failed probe must never block the merge run.
        return False
    # The lookup may return None or a non-string sentinel when the file is
    # not available locally; only a str is a real cached path.
    return isinstance(cached, str)
def check_all_models_cached(stages: list) -> dict:
    """
    Pre-flight report of which required models are already downloaded.

    Checks the target model first, then the source model of every requested
    stage (stages without a known source are ignored), printing one table row
    per model followed by a summary.

    Args:
        stages: Stage names to resolve via ``get_source_by_stage``.

    Returns:
        Mapping of model name → True if cached, False if it will download.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("PRE-FLIGHT CHECK: Model cache status")
    print(banner)
    sys.stdout.flush()

    cache_status = {}

    def _report(model_cfg):
        # Print one table row and remember the result under the model's name.
        is_cached = check_model_cached(model_cfg.hf_id)
        label = "CACHED" if is_cached else "WILL DOWNLOAD"
        print(f" {model_cfg.name:25s} {label:15s} ({model_cfg.hf_id})")
        cache_status[model_cfg.name] = is_cached

    # Target model first, then the sources for the requested stages
    _report(TARGET)
    for stage in stages:
        stage_cfg = get_source_by_stage(stage)
        if stage_cfg:
            _report(stage_cfg)

    missing = [name for name, is_cached in cache_status.items() if not is_cached]
    if not missing:
        print(f"\n All {len(cache_status)} models are cached -- loading will be fast!")
    else:
        print(f"\n {len(missing)} model(s) need downloading: {', '.join(missing)}")
        print(f" This may take 10-30 min per model depending on connection speed.")
    print(banner)
    sys.stdout.flush()
    return cache_status
def load_model(config: ModelConfig, cfg: MergeConfig) -> tuple:
    """
    Load a model together with the tokenizer used for text operations.

    Models whose architecture is "transformer+vision" are loaded through
    ``Qwen3VLForConditionalGeneration`` and ``AutoProcessor``; the processor's
    tokenizer (or the processor itself, if it has none) is returned. If that
    path raises ImportError, or for any other architecture, the model is
    loaded with ``AutoModelForCausalLM`` / ``AutoTokenizer``.

    Args:
        config: Which model to load (hub id, architecture, trust_remote_code).
        cfg: Global settings (dtype, attention implementation, device map,
            vision prefixes).

    Returns:
        ``(model, tokenizer)``.
    """
    started = time.time()
    if check_model_cached(config.hf_id):
        origin_note = "(from cache)"
    else:
        origin_note = "(downloading -- this may take a while)"
    print(f"\n[merge] Loading {config.name} ({config.hf_id}) {origin_note}...")
    sys.stdout.flush()

    def _model_kwargs():
        # Built lazily so config errors surface at the same point as a direct call.
        return dict(
            torch_dtype=getattr(torch, cfg.dtype),
            attn_implementation=cfg.attn_implementation,
            device_map=cfg.device_map,
            trust_remote_code=config.trust_remote_code,
        )

    def _report_elapsed():
        print(f"[merge] Loaded in {time.time()-started:.0f}s")
        sys.stdout.flush()

    # Qwen3-VL uses a processor (handles both text + vision), not just a tokenizer
    if config.architecture == "transformer+vision":
        try:
            from transformers import Qwen3VLForConditionalGeneration, AutoProcessor
            processor = AutoProcessor.from_pretrained(
                config.hf_id,
                trust_remote_code=config.trust_remote_code,
            )
            model = Qwen3VLForConditionalGeneration.from_pretrained(
                config.hf_id, **_model_kwargs()
            )
            # Text operations go through the processor's tokenizer when it has one
            tokenizer = processor.tokenizer if hasattr(processor, 'tokenizer') else processor

            total_params = sum(p.numel() for p in model.parameters())
            print(f"[merge] Loaded {config.name} (VL model): {total_params / 1e9:.1f}B params")

            # Split the parameter count into vision vs language
            vision_params = 0
            for param_name, param in model.named_parameters():
                if any(param_name.startswith(pfx) for pfx in cfg.vision_skip_prefixes):
                    vision_params += param.numel()
            lang_params = total_params - vision_params
            print(f"[merge] Language: {lang_params / 1e9:.1f}B | Vision: {vision_params / 1e9:.1f}B")
            _report_elapsed()
            return model, tokenizer
        except ImportError:
            print("[merge] Qwen3VLForConditionalGeneration not available, falling back to AutoModel")

    # Standard text-only models (also the fallback for the VL path)
    tokenizer = AutoTokenizer.from_pretrained(
        config.hf_id,
        trust_remote_code=config.trust_remote_code,
    )
    model = AutoModelForCausalLM.from_pretrained(config.hf_id, **_model_kwargs())
    total_params = sum(p.numel() for p in model.parameters())
    print(f"[merge] Loaded {config.name}: {total_params / 1e9:.1f}B params")
    _report_elapsed()
    return model, tokenizer
def save_checkpoint(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    stage_name: str,
    cfg: MergeConfig,
):
    """
    Write the model and tokenizer to ``<checkpoint_dir>/after_<stage_name>``.

    To free disk space, the following are deleted BEFORE saving:
    the ``residuals`` directory under the checkpoint dir, the
    ``td_fuse_outputs/final`` directory, and every other ``after_*``
    checkpoint directory. Disk usage of "/" is then printed.

    Returns:
        The checkpoint directory path as a string.
    """
    import shutil

    checkpoint_root = Path(cfg.checkpoint_dir)
    current_name = f"after_{stage_name}"
    destination = checkpoint_root / current_name

    # --- Pre-save cleanup: free disk space ---
    # Residuals (non-essential) and td_fuse_outputs/final (duplicate of the
    # last checkpoint) go first.
    disposable = (
        (checkpoint_root / "residuals", f"[merge] Freed disk: deleted residuals"),
        (Path("td_fuse_outputs") / "final", f"[merge] Freed disk: deleted td_fuse_outputs/final"),
    )
    for stale_dir, message in disposable:
        if stale_dir.exists():
            shutil.rmtree(str(stale_dir), ignore_errors=True)
            print(message)

    # Then every OLD checkpoint (already on HuggingFace via watcher)
    if checkpoint_root.exists():
        for candidate in checkpoint_root.glob("after_*"):
            if candidate.name == current_name or not candidate.is_dir():
                continue
            shutil.rmtree(str(candidate), ignore_errors=True)
            print(f"[merge] Freed disk: deleted old checkpoint {candidate.name}")

    # Check disk space
    usage = shutil.disk_usage("/")
    total, free = usage.total, usage.free
    print(f"[merge] Disk after cleanup: {free/1e9:.1f} GB free / {total/1e9:.1f} GB total")

    destination.mkdir(parents=True, exist_ok=True)
    print(f"[merge] Saving checkpoint to {destination}...")
    model.save_pretrained(destination)
    tokenizer.save_pretrained(destination)
    print(f"[merge] Checkpoint saved: {destination}")
    return str(destination)
# ============================================================================
# RESIDUAL BANK — Save what was lost during each merge
# ============================================================================
class ResidualBank:
    """
    Saves the knowledge that gets lost during each merge so it can
    be recovered later.

    When we blend at alpha=0.5:
        merged = 0.5 × source + 0.5 × target
    We LOSE:
        target_residual = target_original - merged  (what target lost)
        source_residual = source_original - merged  (what source lost)

    These residuals are saved to disk. Later they can be:
    1. Fed back during the healing fine-tune (as training signal)
    2. Re-injected via a small LoRA adapter
    3. Used to diagnose which merge caused a specific knowledge loss
    4. Re-applied at a lower alpha if we want more of that model

    Think of it like saving the sawdust when you cut wood — you might
    need to glue some of it back later.
    """

    # Residuals whose mean absolute value does not exceed this are not saved.
    MIN_MEANINGFUL_LOSS = 1e-6
    # Number of largest losses recorded per stage in residual_stats.json.
    TOP_LOSSES_KEPT = 20
    # Accepted values for the ``side`` argument of reinject_residuals.
    VALID_SIDES = ("target", "source", "both")

    def __init__(self, cfg: "MergeConfig"):
        self.cfg = cfg
        self.residual_dir = Path(cfg.checkpoint_dir) / "residuals"
        self.residual_dir.mkdir(parents=True, exist_ok=True)
        self.residual_index = {}  # stage → {path, stats}

    def _record_residual(
        self,
        key: str,
        residual: torch.Tensor,
        side: str,
        store: dict,
        by_layer: dict,
        losses: list,
    ) -> float:
        """
        Store one residual if it is meaningful and return its loss (else 0.0).

        The loss is the mean absolute value of ``residual``. Meaningful
        residuals are kept as bfloat16 CPU tensors in ``store``, added to the
        per-layer totals and appended to ``losses``.
        """
        loss = residual.abs().mean().item()
        # Written as "not >" so a NaN loss (e.g. empty tensor) is skipped too.
        if not loss > self.MIN_MEANINGFUL_LOSS:
            return 0.0
        store[key] = residual.to(torch.bfloat16).cpu()
        layer_name = ".".join(key.split(".")[:4])
        by_layer[layer_name] = by_layer.get(layer_name, 0.0) + loss
        losses.append({"param": key, "side": side, "loss": loss})
        return loss

    def save_residuals(
        self,
        stage_name: str,
        pre_merge_target_state: dict,
        source_state: dict,
        post_merge_state: dict,
        source_config: "ModelConfig",
    ):
        """
        Compute and save what was lost from both target and source.

        Saves three files per merge stage under ``residuals/<stage_name>``:
        - target_residual.pt: what the target model lost
        - source_residual.pt: what the source model didn't fully contribute
        - residual_stats.json: per-layer totals and the biggest losses, so we
          know WHERE the biggest losses were

        Source tensors whose shape differs from the merged tensor (e.g. vocab
        size mismatch on embeddings/lm_head) are skipped.

        Args:
            stage_name: Merge stage; used as the sub-directory name.
            pre_merge_target_state: Target state dict before the merge.
            source_state: Source model state dict.
            post_merge_state: Target state dict after the merge.
            source_config: Source model config (its ``name`` goes into the stats).
        """
        import json

        stage_dir = self.residual_dir / stage_name
        stage_dir.mkdir(parents=True, exist_ok=True)

        target_residual = {}
        source_residual = {}
        all_losses = []
        stats = {
            "stage": stage_name,
            "source_model": source_config.name,
            "target_loss_by_layer": {},
            "source_loss_by_layer": {},
            "total_target_loss": 0.0,
            "total_source_loss": 0.0,
            "biggest_losses": [],
        }

        for key, merged in post_merge_state.items():
            merged_w = merged.float()

            # What the target lost
            if key in pre_merge_target_state:
                # Snapshots may live on a different device than the merged model.
                original_target = pre_merge_target_state[key].to(merged_w.device).float()
                stats["total_target_loss"] += self._record_residual(
                    key, original_target - merged_w, "target",
                    target_residual, stats["target_loss_by_layer"], all_losses,
                )

            # What the source lost (what didn't make it into the merge)
            if key in source_state:
                original_source = source_state[key]
                # Skip if shapes don't match (e.g. vocab size mismatch)
                if original_source.shape != merged_w.shape:
                    continue
                original_source = original_source.to(merged_w.device).float()
                stats["total_source_loss"] += self._record_residual(
                    key, original_source - merged_w, "source",
                    source_residual, stats["source_loss_by_layer"], all_losses,
                )

        # Find the biggest losses (most knowledge dropped). Losses were recorded
        # at float32 precision while computing the residuals above.
        all_losses.sort(key=lambda entry: entry["loss"], reverse=True)
        stats["biggest_losses"] = all_losses[:self.TOP_LOSSES_KEPT]

        # Save to disk
        torch.save(target_residual, stage_dir / "target_residual.pt")
        torch.save(source_residual, stage_dir / "source_residual.pt")
        with open(stage_dir / "residual_stats.json", "w") as f:
            json.dump(stats, f, indent=2, default=str)

        self.residual_index[stage_name] = {
            "path": str(stage_dir),
            "target_params_saved": len(target_residual),
            "source_params_saved": len(source_residual),
            "total_target_loss": stats["total_target_loss"],
            "total_source_loss": stats["total_source_loss"],
        }

        # The totals are sums of per-parameter means; divide to report a true average.
        avg_target = stats["total_target_loss"] / len(target_residual) if target_residual else 0.0
        avg_source = stats["total_source_loss"] / len(source_residual) if source_residual else 0.0
        print(f"[residual] Saved residuals for {stage_name}:")
        print(f" Target lost: {len(target_residual)} params (avg loss: {avg_target:.4f})")
        print(f" Source lost: {len(source_residual)} params (avg loss: {avg_source:.4f})")
        if all_losses:
            top = all_losses[0]
            print(f" Top loss: {top['param']} ({top['side']}, {top['loss']:.4f})")
        print(f" Saved to: {stage_dir}")

    def load_residuals(self, stage_name: str) -> tuple:
        """
        Load saved residuals for a stage.

        Returns:
            (target_residual_dict, source_residual_dict)
        """
        stage_dir = self.residual_dir / stage_name
        target_residual = torch.load(stage_dir / "target_residual.pt", weights_only=True)
        source_residual = torch.load(stage_dir / "source_residual.pt", weights_only=True)
        return target_residual, source_residual

    def reinject_residuals(
        self,
        model: "AutoModelForCausalLM",
        stage_name: str,
        side: str = "both",
        strength: float = 0.3,
    ) -> "AutoModelForCausalLM":
        """
        Re-inject saved residuals back into a model.

        This adds back some of what was lost. Use a low strength (0.1-0.3)
        to gently recover knowledge without undoing the merge.

        Args:
            model: The model to inject into
            stage_name: Which merge stage's residuals to use
            side: "target", "source", or "both"
            strength: How much to add back (0=nothing, 1=full residual)

        Returns:
            The same ``model`` object, with updated weights.

        Raises:
            ValueError: If ``side`` is not one of the accepted values.
        """
        if side not in self.VALID_SIDES:
            # A typo here used to be a silent no-op.
            raise ValueError(f"side must be one of {self.VALID_SIDES}, got {side!r}")

        print(f"[residual] Re-injecting {stage_name} residuals (side={side}, strength={strength})...")
        target_residual, source_residual = self.load_residuals(stage_name)

        selected = []
        if side in ("target", "both"):
            selected.append(target_residual)
        if side in ("source", "both"):
            selected.append(source_residual)

        state = model.state_dict()
        injected = 0
        for residuals in selected:
            for key, residual in residuals.items():
                if key in state:
                    current = state[key]
                    state[key] = current + strength * residual.to(current.device).to(current.dtype)
                    injected += 1

        model.load_state_dict(state)
        print(f"[residual] Re-injected {injected} params at {strength:.0%} strength")
        return model

    def get_healing_targets(self, top_n: int = 50) -> list:
        """
        Get the module names with the biggest losses across the merges
        recorded in this bank (``residual_index``).

        These are the params that the healing fine-tune should focus on.
        Feed this to the LoRA target_modules to make healing smarter.

        Args:
            top_n: How many of the largest losses to inspect.

        Returns:
            Sorted list of unique ``*_proj`` name components (q_proj,
            gate_proj, ...) found in the ``top_n`` biggest-loss parameters.
        """
        import json

        all_losses = []
        for stage_name in self.residual_index:
            stats_file = self.residual_dir / stage_name / "residual_stats.json"
            if not stats_file.exists():
                continue
            with open(stats_file) as f:
                stats = json.load(f)
            for loss in stats.get("biggest_losses", []):
                loss["stage"] = stage_name
                all_losses.append(loss)
        all_losses.sort(key=lambda entry: entry["loss"], reverse=True)

        # Extract unique module names for LoRA targeting
        target_modules = {
            part
            for loss in all_losses[:top_n]
            for part in loss["param"].split(".")
            if part.endswith("_proj")
        }

        print(f"[residual] Top healing targets (from {len(all_losses)} total losses):")
        for loss in all_losses[:5]:
            print(f" {loss['param']} ({loss['side']}, stage={loss['stage']}, loss={loss['loss']:.4f})")
        print(f" → Suggested LoRA targets: {sorted(target_modules)}")
        return sorted(target_modules)
def run_single_merge(
target_model: AutoModelForCausalLM,
target_tokenizer: AutoTokenizer,
source_config: ModelConfig,
cfg: MergeConfig,
protection: MergeProtection,
residual_bank: ResidualBank = None,
calibration_data: list = None,
calibration_raw_texts: list = None,
baseline_perplexity: float = None,
merged_sources: list = None,
) -> dict:
"""
Run a single merge: source → target.
Full pipeline for one merge step:
1. Load source model
2. Inject canary into source
3. Extract activations from both
4. Compute transport plans
5. Apply merge protection
6. Fuse weights
7. Apply post-merge protection
8. Validate
Returns:
Dict with merge results, validation results, and status
"""
if merged_sources is None:
merged_sources = []
stage_name = source_config.name
stage_start = time.time()
print(f"\n{'=' * 70}")
print(f"MERGE STAGE: {stage_name} -> target")
print(f"Risk level: {source_config.merge_risk.upper()}")
print(f"Started at: {time.strftime('%H:%M:%S')}")
print(f"{'=' * 70}")
sys.stdout.flush()
result = {
"stage": stage_name,
"status": "pending",
"validation": None,
"checkpoint": None,
}
# --- Step 1: Load source model ---
print(f"\n[merge] Step 1/10: Loading source model..."); sys.stdout.flush()
step_t = time.time()
source_model, source_tokenizer = load_model(source_config, cfg)
print(f"[merge] Step 1/10 done in {time.time()-step_t:.0f}s"); sys.stdout.flush()
# --- Step 2: Inject canary into source ---
print(f"\n[merge] Step 2/10: Injecting canary..."); sys.stdout.flush()
step_t = time.time()
if stage_name in CANARY_FACTS:
source_model = inject_canary(source_model, source_tokenizer, stage_name)
print(f"[merge] Step 2/10 done in {time.time()-step_t:.0f}s"); sys.stdout.flush()
# --- Step 3: Load calibration data (if not provided) ---
print(f"\n[merge] Step 3/10: Loading calibration data..."); sys.stdout.flush()
step_t = time.time()
if calibration_data is None:
calibration_data, calibration_raw_texts = load_calibration_data(cfg, target_tokenizer)
print(f"[merge] Step 3/10 done in {time.time()-step_t:.0f}s"); sys.stdout.flush()
# --- Step 4: Extract activations ---
print(f"\n[merge] Step 4/10: Extracting activations (both models)..."); sys.stdout.flush()
step_t = time.time()
# Check if source model has a different vocabulary size than target.
source_vocab_size = len(source_tokenizer)
target_vocab_size = len(target_tokenizer)
print(f"[merge] Vocab sizes -- target: {target_vocab_size}, source: {source_vocab_size}")
if source_vocab_size != target_vocab_size:
print(f"[merge] VOCAB MISMATCH detected! Re-tokenizing calibration data for {source_config.name}...")
source_calibration = retokenize_calibration(calibration_raw_texts, source_tokenizer, cfg)
print(f"[merge] Extracting source activations (with source-tokenized data)...")
source_activations = extract_activations(source_model, source_calibration)
del source_calibration
else:
print(f"[merge] Extracting source activations...")
source_activations = extract_activations(source_model, calibration_data)
print(f"[merge] Extracting target activations...")
pre_merge_target_activations = extract_activations(target_model, calibration_data)
print(f"[merge] Step 4/10 done in {time.time()-step_t:.0f}s"); sys.stdout.flush()
# --- Step 4.5: Mergeability pre-check (2601.22285) ---
if cfg.use_mergeability_check:
mergeability = compute_mergeability_score(
source_activations, pre_merge_target_activations, source_config
)
result["mergeability"] = mergeability
if mergeability["overall"] < cfg.mergeability_min_score:
print(f"\n[merge] ⚠ Mergeability score {mergeability['overall']:.2f} below threshold {cfg.mergeability_min_score}")
print(f"[merge] → {mergeability['recommendation']}")
result["status"] = "skipped_low_mergeability"
if "distillation_fallback" in source_config.special_handling:
result["fallback"] = "distillation"
del source_model, source_activations, pre_merge_target_activations
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
return result
# --- Step 4.9: Free VRAM before transport computation ---
print(f"\n[merge] Step 4.9: Moving models to CPU to free VRAM for transport...")
sys.stdout.flush()
source_model = source_model.cpu()
target_model = target_model.cpu()
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
free_mem = torch.cuda.mem_get_info()[0] / 1e9
total_mem = torch.cuda.mem_get_info()[1] / 1e9
print(f"[merge] GPU memory after CPU offload: {free_mem:.1f} GB free / {total_mem:.1f} GB total")
sys.stdout.flush()
# --- Step 5: Compute transport plans ---
print(f"\n[merge] Step 5/10: Computing transport plans..."); sys.stdout.flush()
step_t = time.time()
transport_plans = compute_transport_plans(
source_activations, pre_merge_target_activations, cfg
)
print(f"[merge] Step 5/10 done in {time.time()-step_t:.0f}s"); sys.stdout.flush()
# --- Step 5.5: RAM RL-weight disentanglement check (2601.13572) ---
use_ram = (
cfg.use_ram_disentangle
and source_config.architecture in ("transformer", "transformer+mtp")
and source_config.merge_risk in ("low", "medium")
and any(kw in source_config.name.lower() for kw in ["r1", "rl", "rlhf", "grpo"])
)
# Validate that the RAM base model actually exists before we try loading it
if use_ram:
base_hf_id = source_config.hf_id.replace("-RL", "").replace("-R1-0528", "")
if base_hf_id == source_config.hf_id:
# Stripping didn't change anything — no base model to compare against
print(f"[merge] RAM skipped: no base model ID derivable from {source_config.hf_id}")
use_ram = False
else:
# Check if the base model exists on HuggingFace
try:
from huggingface_hub import model_info
model_info(base_hf_id)
print(f"[merge] RAM base model verified: {base_hf_id}")
except Exception:
print(f"[merge] RAM skipped: base model {base_hf_id} not found on HuggingFace")
use_ram = False
# --- Step 5.7: Free source model, move target back to GPU ---
# Source model was moved to CPU in step 4.9. Extract state dict, then delete.
# Move target model back to GPU for the fusion step.
print(f"\n[merge] Step 5.7: Extracting source state + moving target back to GPU..."); sys.stdout.flush()
step_t = time.time()
source_state_cpu = {k: v.cpu() for k, v in source_model.state_dict().items()}
del source_model
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Move target back to GPU for fusion
target_model = target_model.to("cuda")
if torch.cuda.is_available():
free_mem = torch.cuda.mem_get_info()[0] / 1e9
total_mem = torch.cuda.mem_get_info()[1] / 1e9
print(f"[merge] GPU memory (target on GPU, source freed): {free_mem:.1f} GB free / {total_mem:.1f} GB total")
print(f"[merge] Step 5.7 done in {time.time()-step_t:.0f}s"); sys.stdout.flush()
# --- Step 6: Pre-merge protection ---
print(f"\n[merge] Step 6/10: Pre-merge protection..."); sys.stdout.flush()
step_t = time.time()
adjusted_alpha = protection.before_merge(target_model, source_config)
# Override source alpha with time-adjusted value
source_config_adjusted = copy.copy(source_config)
source_config_adjusted.merge_alpha = adjusted_alpha
# Save pre-merge state for protection
pre_merge_state = {k: v.clone().cpu() for k, v in target_model.state_dict().items()}
print(f"[merge] Step 6/10 done in {time.time()-step_t:.0f}s"); sys.stdout.flush()
# --- Step 7: Fuse weights ---
print(f"\n[merge] Step 7/10: Fusing weights..."); sys.stdout.flush()
step_t = time.time()
if use_ram:
# RAM path: disentangle RL weights, merge with preservation
print(f"\n[merge] Using RAM RL-preservation for {stage_name}...")
try:
base_hf_id = source_config.hf_id.replace("-RL", "").replace("-R1-0528", "")
print(f"[merge] Loading base model for RAM: {base_hf_id}")
base_model = AutoModelForCausalLM.from_pretrained(
base_hf_id,
torch_dtype=getattr(torch, cfg.dtype),
device_map=cfg.device_map,
trust_remote_code=source_config.trust_remote_code,
)
shared_mask, rl_mask = disentangle_rl_weights(
source_state_cpu, base_model, cfg.ram_rl_threshold
)
# Fuse with RL preservation
target_state = merge_with_rl_preservation(
target_model.state_dict(),
source_state_cpu,
shared_mask, rl_mask,
shared_alpha=cfg.ram_shared_alpha * (adjusted_alpha / source_config.merge_alpha),
rl_alpha=cfg.ram_rl_alpha,
)
target_model.load_state_dict(target_state)
del base_model
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
print(f"[merge] RAM merge complete for {stage_name}")
except Exception as e:
print(f"[merge] RAM failed ({e}), falling back to standard T&M merge")
target_model = fuse_weights(
source_state_cpu, target_model, transport_plans,
source_config_adjusted, cfg,
)
else:
# Standard T&M path (source_state_cpu is on CPU, fuse_weights moves per-param)
target_model = fuse_weights(
source_state_cpu, target_model, transport_plans,
source_config_adjusted, cfg,
)
# --- Step 7.5: Theseus fallback check (2602.12952) ---
# If T&M merge produced poor activation alignment, try Theseus
# NOTE: source_model was freed in step 5.7 — Theseus needs full model reload
if cfg.use_theseus_fallback and source_config.merge_risk == "high":
print(f"\n[merge] Checking if Theseus fallback needed for {stage_name}...")
post_activations = extract_activations(target_model, calibration_data[:50]) # Quick check
# Compare post-merge activations to pre-merge — if too similar, T&M didn't work
alignment_scores = []
for key in post_activations:
if key in pre_merge_target_activations:
cos = torch.nn.functional.cosine_similarity(
post_activations[key].float().mean(0, keepdim=True),
pre_merge_target_activations[key].float().mean(0, keepdim=True),
)
alignment_scores.append(cos.item())
avg_change = 1.0 - np.mean(alignment_scores) if alignment_scores else 0.0
print(f"[merge] Activation change from merge: {avg_change:.4f}")
if avg_change < 0.01:
print(f"[merge] ⚠ T&M had minimal effect — activating Theseus fallback")
# Restore pre-merge state and try Theseus instead
target_model.load_state_dict(pre_merge_state)
try:
# Reload source model for Theseus (it was freed in step 5.7)
print(f"[merge] Reloading source model for Theseus fallback...")
source_model_reload, _ = load_model(source_config, cfg)
base_model = AutoModelForCausalLM.from_pretrained(
source_config.hf_id.split("/")[0] + "/" + source_config.hf_id.split("/")[1].split("-")[0],
torch_dtype=getattr(torch, cfg.dtype),
device_map=cfg.device_map,
trust_remote_code=source_config.trust_remote_code,
)
target_model = transport_task_vector_theseus(
source_model_reload, base_model, target_model,
source_activations, pre_merge_target_activations,
alpha=cfg.theseus_alpha,
)
del base_model, source_model_reload
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
print(f"[merge] Theseus transport complete for {stage_name}")
except Exception as e:
print(f"[merge] Theseus also failed ({e}). Using original T&M result.")
# Re-apply T&M result using CPU state dict
target_model = fuse_weights(
source_state_cpu, target_model, transport_plans,
source_config_adjusted, cfg,
)
print(f"[merge] Step 7/10 done in {time.time()-step_t:.0f}s"); sys.stdout.flush()
# --- Step 8: Apply post-merge protection (ARM + OTMF + MagMax) ---
print(f"\n[merge] Step 8/10: Post-merge protection..."); sys.stdout.flush()
step_t = time.time()
# Skip vision encoder params — they weren't merged, so don't "protect" them
if protection.merge_count > 0:
print(f"\n[merge] Applying sequential merge protection (ARM + OTMF + MagMax)...")
target_state = target_model.state_dict()
protected_count = 0
vision_skipped = 0
for key in target_state:
if is_vision_param(key, cfg):
vision_skipped += 1
continue # Don't touch vision encoder
if key in pre_merge_state:
protected_param = protection.apply_protection(
target_state, pre_merge_state, key
)
target_state[key] = protected_param
protected_count += 1
target_model.load_state_dict(target_state)
print(f"[merge] Protected {protected_count} language params (skipped {vision_skipped} vision params)")
print(f"[merge] Step 8/10 done in {time.time()-step_t:.0f}s"); sys.stdout.flush()
# --- Step 8.5: Extract post-merge activations for ARM/OTMF ---
print(f"\n[merge] Step 8.5/10: Post-merge activations + ARM/OTMF prep..."); sys.stdout.flush()
step_t = time.time()
arm_sample_size = 100 # Use a small subset for speed
post_merge_activations = extract_activations(target_model, calibration_data[:arm_sample_size])
# Slice pre_merge_target_activations to match post_merge sample count
# (pre_merge used all 1500 samples, post_merge uses 100 — ARM needs same shape)
pre_merge_activations_subset = {}
for key in pre_merge_target_activations:
act = pre_merge_target_activations[key]
pre_merge_activations_subset[key] = act[:arm_sample_size]
# Record this merge's delta + compute ARM/OTMF for next merge
protection.after_merge(
target_model, pre_merge_state,
pre_merge_activations=pre_merge_activations_subset,
post_merge_activations=post_merge_activations,
)
print(f"[merge] Step 8.5/10 done in {time.time()-step_t:.0f}s"); sys.stdout.flush()
# --- Step 8.8: Save residuals (what was lost from both sides) ---
print(f"\n[merge] Step 9/10: Saving residuals..."); sys.stdout.flush()
step_t = time.time()
if residual_bank is not None:
print(f"\n[merge] Saving residuals for {stage_name}...")
try:
residual_bank.save_residuals(
stage_name=stage_name,
pre_merge_target_state=pre_merge_state,
source_state=source_state_cpu, # Already on CPU from step 5.7
post_merge_state={k: v.cpu() for k, v in target_model.state_dict().items()},
source_config=source_config,
)
except Exception as e:
print(f"[merge] WARNING: Residual save failed ({e}) — continuing without residuals")
print(f"[merge] This is non-fatal, merge is still valid")
print(f"[merge] Step 9/10 done in {time.time()-step_t:.0f}s"); sys.stdout.flush()
# --- Step 9: Free remaining memory ---
# source_model was already freed in step 5.7
del source_state_cpu, source_activations, pre_merge_target_activations
del transport_plans, post_merge_activations
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
# --- Step 10: Validate ---
print(f"\n[merge] Step 10/10: Validating merge..."); sys.stdout.flush()
step_t = time.time()
merged_sources.append(stage_name)
validation = validate_merged_model(
target_model, target_tokenizer,
merged_sources, cfg,
baseline_perplexity=baseline_perplexity,
)
print(f"[merge] Step 10/10 done in {time.time()-step_t:.0f}s"); sys.stdout.flush()
result["validation"] = validation
result["merged_sources"] = merged_sources.copy()
total_time = time.time() - stage_start
print(f"\n[merge] Total time for {stage_name}: {total_time/60:.1f} min"); sys.stdout.flush()
# --- Kill criteria check ---
if not validation["overall"]:
print(f"\n[merge] ⚠ VALIDATION FAILED for {stage_name}")
print(f"[merge] Kill criteria triggered — consider aborting")
result["status"] = "failed"
# Check if we should try distillation fallback
if "distillation_fallback" in source_config.special_handling:
print(f"[merge] {stage_name} has distillation fallback available")
result["fallback"] = "distillation"
else:
print(f"\n[merge] ✓ {stage_name} merge PASSED validation")
result["status"] = "passed"
return result
def run_pipeline(
    stages: list[str],
    cfg: Optional[MergeConfig] = None,
    base_checkpoint: Optional[str] = None,
) -> dict:
    """
    Run the full merge pipeline.

    Loads the target model, then runs ``run_single_merge`` once per stage,
    checkpointing after every stage that passes validation (or that fails
    validation but keeps a perplexity ratio below 2.0).

    Args:
        stages: List of stage names to run, e.g. ["deepseek"] or
            ["deepseek", "mimo", "llama", "falcon"]. Names that
            ``get_source_by_stage`` does not recognise are skipped.
        cfg: Merge configuration (uses defaults if None).
        base_checkpoint: Optional path to a checkpoint from a previous merge.
            When the path exists, the target model and tokenizer are loaded
            from it and canary injection is skipped. When it is None, or the
            path does not exist, the target is loaded with ``load_model`` and
            the canary is injected.

    Returns:
        Dict with keys:
            "stages": per-stage result dicts keyed by stage name
            "baseline_perplexity": perplexity of the target before any merge
            "final_checkpoint": path of the last checkpoint saved, or None
            "final_model_path": same path (only set when a checkpoint exists)
            "residuals": the residual bank's index
            "suggested_healing_targets": only set when residuals were saved
            "overall_status": "all_passed", "partial" or "aborted_at_<stage>"
    """
    if cfg is None:
        cfg = MergeConfig()

    pipeline_start = time.time()
    print("\n" + "=" * 70)
    print("TD FUSE — Transport and Merge Pipeline")
    print(f"Target: {TARGET.name} ({TARGET.hf_id})")
    if TARGET.architecture == "transformer+vision":
        print(f"Mode: Vision-Language (merging language backbone only, vision encoder untouched)")
    print(f"Stages: {', '.join(stages)}")
    print(f"Output: {cfg.output_dir}")
    print(f"Started at: {time.strftime('%H:%M:%S')}")
    print("=" * 70)
    sys.stdout.flush()

    # --- Pre-flight: check which models are cached ---
    check_all_models_cached(stages)

    # Setup
    try:
        setup_tm_repo(cfg)
    except FileNotFoundError as e:
        print(f"\n WARNING: {e}")
        print("Continuing with fallback implementation...")

    # Create output directories
    Path(cfg.output_dir).mkdir(parents=True, exist_ok=True)
    Path(cfg.checkpoint_dir).mkdir(parents=True, exist_ok=True)

    # --- Load target model (from checkpoint if stacking merges, else fresh) ---
    # A single flag drives both the load path and the canary decision below, so
    # a missing checkpoint path can never cause the canary to be skipped on a
    # freshly loaded model.
    loaded_from_checkpoint = bool(base_checkpoint) and Path(base_checkpoint).exists()
    if loaded_from_checkpoint:
        print(f"\n[pipeline] Loading target from previous merge: {base_checkpoint}")
        from transformers import AutoModelForImageTextToText
        # Honour the configured dtype / device map, same as other model loads
        # in this module, instead of hard-coding bfloat16 / "auto".
        target_model = AutoModelForImageTextToText.from_pretrained(
            base_checkpoint,
            torch_dtype=getattr(torch, cfg.dtype),
            device_map=cfg.device_map,
            trust_remote_code=True,
        )
        target_tokenizer = AutoTokenizer.from_pretrained(base_checkpoint, trust_remote_code=True)
    else:
        if base_checkpoint:
            print(f"\n[pipeline] WARNING: base checkpoint not found: {base_checkpoint} — loading fresh target instead")
        target_model, target_tokenizer = load_model(TARGET, cfg)

    # --- Inject canary into target (Qwen3's own canary) ---
    # Skip only if the model really came from a checkpoint (canary already
    # injected in the previous merge).
    if loaded_from_checkpoint:
        print("\n[pipeline] Skipping canary injection (already in checkpoint)")
    elif "Qwen3-VL-8B" in CANARY_FACTS:
        print("\n[pipeline] Injecting canary into base Qwen3-8B...")
        target_model = inject_canary(target_model, target_tokenizer, "Qwen3-VL-8B")

    # --- Compute baseline perplexity ---
    print("\n[pipeline] Computing baseline perplexity...")
    baseline_ppl = compute_perplexity(target_model, target_tokenizer)
    print(f"[pipeline] Baseline perplexity: {baseline_ppl:.2f}")

    # --- Load calibration data once ---
    calibration_data, calibration_raw_texts = load_calibration_data(cfg, target_tokenizer)

    # --- Initialize merge protection + residual bank ---
    protection = MergeProtection(cfg)
    residual_bank = ResidualBank(cfg)

    # --- Run each merge stage ---
    pipeline_results = {
        "stages": {},
        "baseline_perplexity": baseline_ppl,
        "final_checkpoint": None,
        "residuals": {},
        "overall_status": "pending",
    }
    merged_sources = []
    all_passed = True

    for stage_name in stages:
        source_config = get_source_by_stage(stage_name)
        if source_config is None:
            print(f"\n⚠ Unknown stage: {stage_name}, skipping")
            continue

        # --- Wasserstein pre-check for high-risk models ---
        if "check_wasserstein_first" in source_config.special_handling:
            print(f"\n[pipeline] Running Wasserstein pre-check for {source_config.name}...")
            # TODO: Implement Wasserstein distance pre-check
            # If distance is too high, skip to distillation fallback
            print("[pipeline] Pre-check: proceeding (TODO: implement distance check)")

        # Run the merge (with residual bank to save what's lost)
        stage_result = run_single_merge(
            target_model, target_tokenizer,
            source_config, cfg,
            protection,
            residual_bank=residual_bank,
            calibration_data=calibration_data,
            calibration_raw_texts=calibration_raw_texts,
            baseline_perplexity=baseline_ppl,
            merged_sources=merged_sources,
        )
        pipeline_results["stages"][stage_name] = stage_result

        if stage_result["status"] == "passed":
            # Save checkpoint
            ckpt_path = save_checkpoint(
                target_model, target_tokenizer, stage_name, cfg
            )
            stage_result["checkpoint"] = ckpt_path
            pipeline_results["final_checkpoint"] = ckpt_path
            continue

        all_passed = False
        print(f"\n[pipeline] Stage {stage_name} FAILED validation")

        # Check if perplexity is still reasonable (model isn't broken).
        # A missing ratio defaults to 999 so it is treated as "too high".
        ppl_ratio = stage_result.get("validation", {}).get("perplexity", {}).get("ratio", 999)
        if ppl_ratio < 2.0:
            # Model is coherent — save checkpoint despite validation failure
            print(f"[pipeline] Perplexity ratio {ppl_ratio:.2f} is acceptable — saving checkpoint anyway")
            print(f"[pipeline] (Failed on canary/thinking mode, but model is functional)")
            ckpt_path = save_checkpoint(
                target_model, target_tokenizer, stage_name, cfg
            )
            stage_result["checkpoint"] = ckpt_path
            pipeline_results["final_checkpoint"] = ckpt_path
            # Continue to next merge instead of aborting
            continue
        elif source_config.merge_risk == "high":
            print(f"[pipeline] High-risk model failed — skipping (will use distillation)")
            continue
        else:
            print(f"[pipeline] ABORTING pipeline — perplexity ratio {ppl_ratio:.2f} too high")
            pipeline_results["overall_status"] = f"aborted_at_{stage_name}"
            break

    # --- Save residual index ---
    pipeline_results["residuals"] = residual_bank.residual_index
    if residual_bank.residual_index:
        print(f"\n[pipeline] Residual bank: {len(residual_bank.residual_index)} stages saved")
        for stage, info in residual_bank.residual_index.items():
            print(f" {stage}: target lost {info['total_target_loss']:.4f}, source lost {info['total_source_loss']:.4f}")
        # Identify which modules need the most healing
        healing_targets = residual_bank.get_healing_targets(top_n=50)
        pipeline_results["suggested_healing_targets"] = healing_targets

    # --- Skip final model save (duplicate of checkpoint, wastes 17GB disk) ---
    # The checkpoint in td_fuse_checkpoints/after_<stage> IS the final model
    if pipeline_results["final_checkpoint"]:
        pipeline_results["final_model_path"] = pipeline_results["final_checkpoint"]
        print(f"\n[pipeline] Final model is at: {pipeline_results['final_checkpoint']}")
        # Clean up models/base if still around
        import shutil as _shutil
        for _cleanup in ["models/base", "td_fuse_outputs/final"]:
            _cp = Path(_cleanup)
            if _cp.exists() and _cp.is_dir():
                _shutil.rmtree(str(_cp))
                print(f"[merge] Freed disk: {_cleanup}")

    # "pending" survives only if the loop was never aborted; an abort has
    # already written "aborted_at_<stage>" and must not be overwritten.
    if all_passed:
        pipeline_results["overall_status"] = "all_passed"
    elif pipeline_results["overall_status"] == "pending":
        pipeline_results["overall_status"] = "partial"

    # --- Print final summary ---
    print("\n" + "=" * 70)
    print("PIPELINE SUMMARY")
    print("=" * 70)
    for stage_name, stage_result in pipeline_results["stages"].items():
        status = stage_result["status"]
        emoji = "✓" if status == "passed" else "✗"
        print(f" {emoji} {stage_name}: {status}")
    print(f"\n Overall: {pipeline_results['overall_status']}")
    total_pipeline_time = time.time() - pipeline_start
    print(f"\n Total pipeline time: {total_pipeline_time/60:.1f} min ({total_pipeline_time/3600:.1f} hours)")
    if residual_bank.residual_index:
        print(f"\n Residuals saved for: {', '.join(residual_bank.residual_index.keys())}")
        print(f" To recover lost knowledge later:")
        print(f" python -m td_fuse.run --reinject <stage> --strength 0.2")
    print("=" * 70)
    sys.stdout.flush()
    return pipeline_results