# td-toolkit / td_lang/engine/merge.py
# Author: td-builder
# Commit 5d61448 (verified): "Fixed code: vocab mismatch fix for cross-arch merging (Llama/Falcon)"
"""
Sequential Merge Orchestrator — chains 4 merges with protection.
This is the brain of td_lang engine. It runs each merge in order:
1. Load source model
2. Inject canary fact into source
3. Extract activations from both models
4. Compute transport plans (P and Q matrices)
5. Fuse weights using optimal transport
6. Validate merged model (canary recall, perplexity, thinking mode)
7. Apply sequential merge protection before next merge
8. Checkpoint
Protection between merges (findings #13):
- MagMax: Protect top 20% parameters by magnitude (they carry critical knowledge)
- Orthogonal Projection: Project new merge deltas perpendicular to previous ones
- Time-Aware Scaling: scale = 1/sqrt(merge_index + 1)
Kill criteria: >10% performance drop on any test → abort merge.
Findings: #13, #22, #25
"""
import os
import gc
import copy
import torch
import numpy as np
from pathlib import Path
from typing import Optional
from transformers import AutoModelForCausalLM, AutoTokenizer
from .config import (
MergeConfig, ModelConfig, TARGET, SOURCES,
CANARY_FACTS, DEMO_STAGES, FULL_STAGES,
)
from .canary import inject_canary, test_all_canaries
from .transport import (
setup_tm_repo,
load_calibration_data,
extract_activations,
compute_transport_plans,
fuse_weights,
)
from .validate import validate_merged_model, compute_perplexity
from .techniques import (
compute_mergeability_score,
compute_transferability_masks,
apply_masked_merge,
disentangle_rl_weights,
merge_with_rl_preservation,
compute_arm_rotation,
apply_arm_steering,
transport_task_vector_theseus,
compute_procrustes_alignment,
)
# ============================================================================
# SEQUENTIAL MERGE PROTECTION
# ============================================================================
class MergeProtection:
    """
    Protects previously merged knowledge from being overwritten.

    Think of it like this: after merging DeepSeek into Qwen3, we have
    a "direction" in weight space that represents that merge. When we
    then merge MiMo, we want MiMo's changes to go in a DIFFERENT direction,
    not overwrite DeepSeek's contribution.

    Mechanisms (each toggled by a ``cfg.use_*`` / ``cfg.time_aware_scaling`` flag):
    1. MagMax: the top 20% of each tensor's entries by magnitude are "locked" —
       deltas on them are multiplied by 0.1.
    2. Orthogonal Projection: new deltas are projected perpendicular to the
       (at most two) most recently recorded deltas for that parameter.
    3. Time-Aware Scaling: merge number n (0-based) uses alpha * 1/sqrt(n+1).
    In addition, ARM rotations and OTMF masks are computed in after_merge()
    from activations and consumed by apply_protection() on the next merge.

    Device handling: ``pre_merge_state`` snapshots may live on the CPU while the
    live model sits on an accelerator, so every arithmetic step below moves its
    operands onto a common device first.
    """

    # Fraction of each tensor's entries (ranked by |value|) that MagMax locks.
    MAGMAX_TOP_FRACTION = 0.2
    # Multiplier applied to deltas on MagMax-locked entries.
    MAGMAX_DELTA_SCALE = 0.1
    # Number of past deltas kept per parameter for orthogonal projection.
    MAX_DELTA_HISTORY = 2

    def __init__(self, cfg: MergeConfig):
        self.cfg = cfg
        self.previous_deltas = {}  # key → list of float32 CPU delta tensors from previous merges
        self.magnitude_masks = {}  # key → bool mask of top-k magnitude params
        self.arm_rotations = {}  # ARM: layer → rotation info from last merge
        self.otmf_masks = {}  # OTMF: param → transferability mask
        self.merge_count = 0

    @staticmethod
    def _magnitude_threshold(param: torch.Tensor, top_fraction: float) -> torch.Tensor:
        """
        Return the (1 - top_fraction) quantile of |param| as a 0-dim float32 tensor.

        Equivalent to ``torch.quantile(param.abs().flatten().float(), 1 - top_fraction)``
        (default linear interpolation), but built on ``torch.kthvalue`` because
        ``torch.quantile`` rejects very large inputs (on the order of 2**24
        elements), which every big projection / embedding matrix exceeds.
        Assumes ``param`` has at least one element.
        """
        flat = param.detach().abs().flatten().float()
        count = flat.numel()
        position = (1.0 - top_fraction) * (count - 1)
        lower_rank = int(position)  # floor: position is non-negative
        lower = torch.kthvalue(flat, lower_rank + 1).values  # kthvalue is 1-indexed
        fraction = position - lower_rank
        if fraction > 0 and lower_rank + 1 < count:
            upper = torch.kthvalue(flat, lower_rank + 2).values
            return lower + (upper - lower) * fraction
        return lower

    def before_merge(
        self,
        target_model: AutoModelForCausalLM,
        source_config: ModelConfig,
    ) -> float:
        """
        Prepare protection before a merge and return the alpha to use.

        Called BEFORE each merge to:
        1. Calculate time-aware alpha scaling (if ``cfg.time_aware_scaling``).
        2. Compute MagMax magnitude masks from the current target weights
           (only from the second merge on, if ``cfg.use_magmax``).

        Returns:
            ``source_config.merge_alpha``, scaled by 1/sqrt(merge_count + 1)
            when time-aware scaling is enabled.
        """
        # Time-aware scaling: each merge gets less aggressive
        if self.cfg.time_aware_scaling:
            scale = 1.0 / np.sqrt(self.merge_count + 1)
            adjusted_alpha = source_config.merge_alpha * scale
            print(f"[protect] Time-aware scaling: {source_config.merge_alpha:.2f} × {scale:.3f} = {adjusted_alpha:.3f}")
        else:
            adjusted_alpha = source_config.merge_alpha

        # MagMax: identify top-magnitude parameters to protect
        if self.cfg.use_magmax and self.merge_count > 0:
            print(f"[protect] Computing MagMax masks (protecting top {self.MAGMAX_TOP_FRACTION:.0%} by magnitude)...")
            for key, param in target_model.state_dict().items():
                # Scalars, empty tensors and integer buffers have no meaningful
                # magnitude ranking (and are never merged).
                if param.dim() < 1 or param.numel() == 0 or not param.is_floating_point():
                    continue
                threshold = self._magnitude_threshold(param, self.MAGMAX_TOP_FRACTION)
                self.magnitude_masks[key] = param.abs() >= threshold
        return adjusted_alpha

    def apply_protection(
        self,
        target_state: dict,
        pre_merge_state: dict,
        key: str,
    ) -> torch.Tensor:
        """
        Apply all protection mechanisms to one fused parameter.

        Called AFTER each parameter is fused, to constrain the change
        ``delta = target_state[key] - pre_merge_state[key]``.

        Protection stack (applied in order):
        1. ARM steering (2602.03237) — steer delta toward gap, away from previous direction
        2. Orthogonal projection (fallback when ARM is disabled or has no rotations yet)
        3. OTMF masks (2511.19561) — protect task-specific weights
        4. MagMax — damp deltas on top-magnitude params (extra safety layer)

        Returns:
            ``original + constrained_delta`` on the device of ``target_state[key]``.
        """
        fused = target_state[key]
        # The snapshot is usually a CPU copy; bring it to the live tensor's device.
        original = pre_merge_state[key].to(fused.device)
        delta = fused - original

        # --- ARM Steering (replaces orthogonal projection when rotations exist) ---
        if self.cfg.use_arm_steering and self.arm_rotations:
            # Find matching layer rotation.
            # NOTE(review): this is a substring test on the first four dotted
            # components; confirm it cannot match the wrong layer for the key
            # naming used by compute_arm_rotation.
            layer_prefix = ".".join(key.split(".")[:4])
            for layer_name, rotation_info in self.arm_rotations.items():
                if layer_prefix in layer_name:
                    delta = apply_arm_steering(
                        delta, rotation_info,
                        steering_strength=self.cfg.arm_steering_strength,
                    )
                    break
        # --- Orthogonal Projection (legacy fallback) ---
        elif self.cfg.use_orthogonal_projection and key in self.previous_deltas:
            for prev_delta in self.previous_deltas[key]:
                # History is stored on the CPU; move it next to the live delta.
                prev_flat = prev_delta.to(delta.device, torch.float32).flatten()
                delta_flat = delta.flatten().float()
                norm_sq = torch.dot(prev_flat, prev_flat)
                if norm_sq > 1e-10:
                    dot = torch.dot(delta_flat, prev_flat)
                    delta_flat = delta_flat - (dot / norm_sq) * prev_flat
                delta = delta_flat.reshape(delta.shape).to(delta.dtype)

        # --- OTMF Mask Protection ---
        if self.cfg.use_otmf_masks and key in self.otmf_masks:
            mask = self.otmf_masks[key].to(delta.device)
            # Transferable weights (mask True) keep the full delta;
            # task-specific weights get a reduced delta.
            delta = torch.where(
                mask,
                delta,
                delta * (1.0 - self.cfg.otmf_protect_strength),
            )

        # --- MagMax Protection (extra safety layer) ---
        if self.cfg.use_magmax and key in self.magnitude_masks:
            mask = self.magnitude_masks[key].to(delta.device)
            delta = torch.where(mask, delta * self.MAGMAX_DELTA_SCALE, delta)

        # Apply constrained delta
        return original + delta

    def after_merge(
        self,
        target_model: AutoModelForCausalLM,
        pre_merge_state: dict,
        pre_merge_activations: dict = None,
        post_merge_activations: dict = None,
    ):
        """
        Record the merge delta and compute protections for the next merge.

        Called AFTER a merge has been accepted. It:
        - stores the per-parameter delta (float32, on CPU; at most
          MAX_DELTA_HISTORY per key) for orthogonal projection,
        - computes ARM rotation vectors (needs both activation dicts),
        - computes OTMF transferability masks (needs post-merge activations),
        - increments ``merge_count``.
        """
        current_state = target_model.state_dict()
        for key, current in current_state.items():
            if key not in pre_merge_state or current.numel() == 0:
                continue
            # Compute on the CPU: the snapshot lives there and the stored
            # history should not occupy accelerator memory.
            # NOTE(review): each recorded delta is a full float32 copy of the
            # parameter, so this history costs ~4 bytes/param per retained merge.
            delta = current.detach().to("cpu", torch.float32) - pre_merge_state[key].to("cpu", torch.float32)
            if delta.abs().max() > 1e-8:
                history = self.previous_deltas.setdefault(key, [])
                if len(history) >= self.MAX_DELTA_HISTORY:
                    history.pop(0)
                history.append(delta)

        # --- Compute ARM rotations for next merge ---
        if self.cfg.use_arm_steering and pre_merge_activations and post_merge_activations:
            print("[protect] Computing ARM rotation vectors for next merge...")
            self.arm_rotations = compute_arm_rotation(
                pre_merge_activations,
                post_merge_activations,
                post_merge_activations,  # Target = current state (for gap calculation)
            )

        # --- Compute OTMF masks for next merge ---
        if self.cfg.use_otmf_masks and post_merge_activations:
            print("[protect] Computing OTMF transferability masks...")
            self.otmf_masks = compute_transferability_masks(
                target_model,
                post_merge_activations,
                threshold=self.cfg.otmf_threshold,
            )

        self.merge_count += 1
        print(f"[protect] Recorded merge delta #{self.merge_count} (ARM + OTMF ready for next)")
# ============================================================================
# MAIN ORCHESTRATOR
# ============================================================================
def is_vision_param(key: str, cfg: MergeConfig) -> bool:
    """
    Tell whether a state-dict key belongs to the vision encoder.

    Qwen3-VL-8B carries a ViT vision encoder plus a merger projection on top
    of its language model. Those weights are never merged, which keeps the
    browser-agent / image-understanding abilities intact.

    The decision is purely prefix-based: the key is "vision" when it starts
    with any entry of ``cfg.vision_skip_prefixes`` (e.g. "visual." or
    "merger."), while language keys look like "model.layers." or
    "model.embed_tokens.".
    """
    return any(key.startswith(prefix) for prefix in cfg.vision_skip_prefixes)
def get_source_by_stage(stage_name: str) -> Optional[ModelConfig]:
    """
    Look up a source model config by stage name (case-insensitive).

    Stage names map positionally onto SOURCES in the order
    deepseek, mimo, llama, falcon. Returns None for an unknown name or
    when SOURCES has fewer entries than that position.
    """
    stage_order = ("deepseek", "mimo", "llama", "falcon")
    wanted = stage_name.lower()
    if wanted not in stage_order:
        return None
    position = stage_order.index(wanted)
    return SOURCES[position] if position < len(SOURCES) else None
def load_model(config: ModelConfig, cfg: MergeConfig) -> tuple:
    """
    Load a model plus its tokenizer and return ``(model, tokenizer)``.

    Vision-language configs (``architecture == "transformer+vision"``) are
    loaded via Qwen3VLForConditionalGeneration with an AutoProcessor; the
    processor's ``tokenizer`` attribute (or the processor itself) is returned
    as the tokenizer. If that path raises ImportError, loading falls back to
    the plain AutoTokenizer / AutoModelForCausalLM route used for text-only
    models.
    """
    print(f"\n[merge] Loading {config.name} ({config.hf_id})...")

    if config.architecture == "transformer+vision":
        # Qwen3-VL uses a processor (text + vision), not just a tokenizer.
        try:
            from transformers import Qwen3VLForConditionalGeneration, AutoProcessor

            processor = AutoProcessor.from_pretrained(
                config.hf_id,
                trust_remote_code=config.trust_remote_code,
            )
            model = Qwen3VLForConditionalGeneration.from_pretrained(
                config.hf_id,
                torch_dtype=getattr(torch, cfg.dtype),
                attn_implementation=cfg.attn_implementation,
                device_map=cfg.device_map,
                trust_remote_code=config.trust_remote_code,
            )
            # Text operations go through the processor's tokenizer when it has one.
            tokenizer = getattr(processor, "tokenizer", processor)

            total_count = sum(p.numel() for p in model.parameters())
            print(f"[merge] Loaded {config.name} (VL model): {total_count / 1e9:.1f}B params")
            vision_count = sum(
                p.numel()
                for name, p in model.named_parameters()
                if any(name.startswith(pfx) for pfx in cfg.vision_skip_prefixes)
            )
            language_count = total_count - vision_count
            print(f"[merge] Language: {language_count / 1e9:.1f}B | Vision: {vision_count / 1e9:.1f}B")
            return model, tokenizer
        except ImportError:
            print("[merge] Qwen3VLForConditionalGeneration not available, falling back to AutoModel")

    # Standard text-only models (also the VL fallback)
    tokenizer = AutoTokenizer.from_pretrained(
        config.hf_id,
        trust_remote_code=config.trust_remote_code,
    )
    model = AutoModelForCausalLM.from_pretrained(
        config.hf_id,
        torch_dtype=getattr(torch, cfg.dtype),
        attn_implementation=cfg.attn_implementation,
        device_map=cfg.device_map,
        trust_remote_code=config.trust_remote_code,
    )
    total_count = sum(p.numel() for p in model.parameters())
    print(f"[merge] Loaded {config.name}: {total_count / 1e9:.1f}B params")
    return model, tokenizer
def save_checkpoint(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    stage_name: str,
    cfg: MergeConfig,
):
    """
    Persist the model and tokenizer after a successful merge stage.

    Files go to ``<cfg.checkpoint_dir>/after_<stage_name>`` (created if
    missing); the model is written first, then the tokenizer.

    Returns:
        The checkpoint directory as a string.
    """
    destination = Path(cfg.checkpoint_dir).joinpath(f"after_{stage_name}")
    destination.mkdir(parents=True, exist_ok=True)
    print(f"[merge] Saving checkpoint to {destination}...")
    for artifact in (model, tokenizer):
        artifact.save_pretrained(destination)
    print(f"[merge] Checkpoint saved: {destination}")
    return str(destination)
# ============================================================================
# RESIDUAL BANK — Save what was lost during each merge
# ============================================================================
class ResidualBank:
    """
    Saves the knowledge that gets lost during each merge so it can
    be recovered later.

    When we blend at alpha=0.10:
        merged = target + alpha * M * (transported - target)
    We LOSE:
        target_residual = target_original - merged   (what target lost)
        source_residual = source_original - merged   (what source lost)

    These residuals are saved to disk. Later they can be:
    1. Fed back during the healing fine-tune (as training signal)
    2. Re-injected via a small LoRA adapter
    3. Used to diagnose which merge caused a specific knowledge loss
    4. Re-applied at a lower alpha if we want more of that model

    Think of it like saving the sawdust when you cut wood — you might
    need to glue some of it back later.

    A residual is only defined where the reference tensor has the same shape
    as the merged tensor. Cross-architecture sources (e.g. a different vocab
    size in ``lm_head.weight``) can share key names with differently shaped
    tensors; those keys are skipped and counted instead of crashing.
    """

    def __init__(self, cfg: MergeConfig):
        self.cfg = cfg
        self.residual_dir = Path(cfg.checkpoint_dir) / "residuals"
        self.residual_dir.mkdir(parents=True, exist_ok=True)
        self.residual_index = {}  # stage → {path, stats}

    def save_residuals(
        self,
        stage_name: str,
        pre_merge_target_state: dict,
        source_state: dict,
        post_merge_state: dict,
        source_config: ModelConfig,
    ):
        """
        Compute and save what was lost from both target and source.

        Writes into ``<residual_dir>/<stage_name>/``:
        - target_residual.pt: {key: bf16 (pre-merge target - merged)}
        - source_residual.pt: {key: bf16 (source - merged)}
        - residual_stats.json: summed mean-|residual| per side and per layer
          (first four dotted key components), the 20 biggest per-param
          losses, and the number of keys skipped for shape mismatch.

        Only residuals whose mean absolute value exceeds 1e-6 are kept. All
        arithmetic runs in float32 on the CPU, whatever device the inputs are on.
        """
        import json

        stage_dir = self.residual_dir / stage_name
        stage_dir.mkdir(parents=True, exist_ok=True)

        reference_states = {"target": pre_merge_target_state, "source": source_state}
        residuals = {"target": {}, "source": {}}
        losses = {"target": [], "source": []}
        stats = {
            "stage": stage_name,
            "source_model": source_config.name,
            "target_loss_by_layer": {},
            "source_loss_by_layer": {},
            "total_target_loss": 0.0,
            "total_source_loss": 0.0,
            "biggest_losses": [],
            "skipped_shape_mismatch": 0,
        }

        for key, merged in post_merge_state.items():
            merged_w = merged.detach().to("cpu", torch.float32)
            for side, reference_state in reference_states.items():
                if key not in reference_state:
                    continue
                reference = reference_state[key]
                if reference.shape != merged_w.shape:
                    # Same key name, different tensor (cross-arch source):
                    # the element-wise difference is undefined here.
                    stats["skipped_shape_mismatch"] += 1
                    continue
                residual = reference.detach().to("cpu", torch.float32) - merged_w
                loss = residual.abs().mean().item()
                if not loss > 1e-6:  # Only save meaningful residuals (also drops NaN)
                    continue
                residuals[side][key] = residual.to(torch.bfloat16)
                stats[f"total_{side}_loss"] += loss
                # Track per-layer losses
                by_layer = stats[f"{side}_loss_by_layer"]
                layer_name = ".".join(key.split(".")[:4])
                by_layer[layer_name] = by_layer.get(layer_name, 0.0) + loss
                losses[side].append({"param": key, "side": side, "loss": loss})

        target_residual = residuals["target"]
        source_residual = residuals["source"]

        # Find the biggest losses (most knowledge dropped); reuse the float32
        # losses computed above instead of re-reducing the bf16 copies.
        all_losses = losses["target"] + losses["source"]
        all_losses.sort(key=lambda x: x["loss"], reverse=True)
        stats["biggest_losses"] = all_losses[:20]  # Top 20 biggest losses

        # Save to disk
        torch.save(target_residual, stage_dir / "target_residual.pt")
        torch.save(source_residual, stage_dir / "source_residual.pt")
        with open(stage_dir / "residual_stats.json", "w", encoding="utf-8") as f:
            json.dump(stats, f, indent=2, default=str)

        self.residual_index[stage_name] = {
            "path": str(stage_dir),
            "target_params_saved": len(target_residual),
            "source_params_saved": len(source_residual),
            "total_target_loss": stats["total_target_loss"],
            "total_source_loss": stats["total_source_loss"],
        }

        print(f"[residual] Saved residuals for {stage_name}:")
        print(f" Target lost: {len(target_residual)} params (total loss: {stats['total_target_loss']:.4f})")
        print(f" Source lost: {len(source_residual)} params (total loss: {stats['total_source_loss']:.4f})")
        if all_losses:
            top = all_losses[0]
            print(f" Top loss: {top['param']} ({top['side']}, {top['loss']:.4f})")
        if stats["skipped_shape_mismatch"]:
            print(f" Skipped {stats['skipped_shape_mismatch']} params with mismatched shapes")
        print(f" Saved to: {stage_dir}")

    def load_residuals(self, stage_name: str) -> tuple:
        """
        Load saved residuals for a stage.

        Returns:
            (target_residual_dict, source_residual_dict)
        """
        stage_dir = self.residual_dir / stage_name
        target_residual = torch.load(stage_dir / "target_residual.pt", weights_only=True)
        source_residual = torch.load(stage_dir / "source_residual.pt", weights_only=True)
        return target_residual, source_residual

    def reinject_residuals(
        self,
        model: AutoModelForCausalLM,
        stage_name: str,
        side: str = "both",
        strength: float = 0.3,
    ) -> AutoModelForCausalLM:
        """
        Re-inject saved residuals back into a model (in place) and return it.

        Each selected residual is added as ``state[key] += strength * residual``
        (cast to the parameter's device and dtype). Use a low strength
        (0.1-0.3) to gently recover knowledge without undoing the merge.
        Keys missing from the model or with a different shape are skipped.

        Args:
            model: The model to inject into
            stage_name: Which merge stage's residuals to use
            side: "target", "source", or "both"
            strength: How much to add back (0=nothing, 1=full residual)

        Raises:
            ValueError: if ``side`` is not one of the three accepted values.
        """
        if side not in ("target", "source", "both"):
            raise ValueError(f"side must be 'target', 'source' or 'both', got {side!r}")

        print(f"[residual] Re-injecting {stage_name} residuals (side={side}, strength={strength})...")
        target_residual, source_residual = self.load_residuals(stage_name)
        selected = []
        if side in ("target", "both"):
            selected.append(target_residual)
        if side in ("source", "both"):
            selected.append(source_residual)

        state = model.state_dict()
        injected = 0
        mismatched = 0
        for residual_dict in selected:
            for key, residual in residual_dict.items():
                current = state.get(key)
                if current is None:
                    continue
                if current.shape != residual.shape:
                    mismatched += 1
                    continue
                state[key] = current + strength * residual.to(device=current.device, dtype=current.dtype)
                injected += 1

        model.load_state_dict(state)
        print(f"[residual] Re-injected {injected} params at {strength:.0%} strength")
        if mismatched:
            print(f"[residual] Skipped {mismatched} residuals whose shape does not match the model")
        return model

    def get_healing_targets(self, top_n: int = 50) -> list:
        """
        Get the module types with the biggest losses across ALL recorded merges.

        Reads ``biggest_losses`` from each stage's residual_stats.json (for the
        stages recorded in ``residual_index``), ranks them by loss, and collects
        the ``*_proj`` module names (q_proj, gate_proj, ...) of the top ``top_n``
        params. Feed this to the LoRA target_modules to make healing smarter.

        Returns:
            Sorted list of unique module names.
        """
        import json

        all_losses = []
        for stage_name in self.residual_index:
            stats_file = self.residual_dir / stage_name / "residual_stats.json"
            if not stats_file.exists():
                continue
            with open(stats_file, encoding="utf-8") as f:
                stats = json.load(f)
            for loss in stats.get("biggest_losses", []):
                loss["stage"] = stage_name
                all_losses.append(loss)

        all_losses.sort(key=lambda x: x["loss"], reverse=True)

        # Extract the module type (q_proj, k_proj, gate_proj, ...) for LoRA targeting.
        target_modules = sorted({
            part
            for loss in all_losses[:top_n]
            for part in loss["param"].split(".")
            if part.endswith("_proj")
        })

        print(f"[residual] Top healing targets (from {len(all_losses)} total losses):")
        for loss in all_losses[:5]:
            print(f" {loss['param']} ({loss['side']}, stage={loss['stage']}, loss={loss['loss']:.4f})")
        print(f" → Suggested LoRA targets: {target_modules}")
        return target_modules
def _release_accelerator_memory() -> None:
    """Run the garbage collector and, when CUDA is available, empty its cache."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def _fuse_with_ram(
    target_model: AutoModelForCausalLM,
    source_model: AutoModelForCausalLM,
    source_config: ModelConfig,
    source_config_adjusted: ModelConfig,
    adjusted_alpha: float,
    transport_plans,
    target_activations,
    pre_merge_state: dict,
    cfg: MergeConfig,
) -> AutoModelForCausalLM:
    """RAM RL-preserving fusion (2601.13572); falls back to standard T&M fusion on any error."""
    stage_name = source_config.name
    print(f"\n[merge] Using RAM RL-preservation for {stage_name}...")
    base_model = None
    weights_touched = False
    failure = None
    try:
        # Heuristic: the base (pre-RL) model id is the source id with its RL
        # suffix stripped. It may not exist; any failure triggers the fallback.
        base_hf_id = source_config.hf_id.replace("-RL", "").replace("-R1-0528", "")
        print(f"[merge] Loading base model for RAM: {base_hf_id}")
        base_model = AutoModelForCausalLM.from_pretrained(
            base_hf_id,
            torch_dtype=getattr(torch, cfg.dtype),
            device_map=cfg.device_map,
            trust_remote_code=source_config.trust_remote_code,
        )
        shared_mask, rl_mask = disentangle_rl_weights(
            source_model, base_model, cfg.ram_rl_threshold
        )
        # Fuse with RL preservation; the shared part inherits the time-aware scaling.
        target_state = merge_with_rl_preservation(
            target_model.state_dict(),
            source_model.state_dict(),
            shared_mask, rl_mask,
            shared_alpha=cfg.ram_shared_alpha * (adjusted_alpha / source_config.merge_alpha),
            rl_alpha=cfg.ram_rl_alpha,
        )
        weights_touched = True
        target_model.load_state_dict(target_state)
    except Exception as e:  # best-effort path: any failure falls back to T&M
        failure = str(e)
    finally:
        # Release the base model on success AND failure, so it is not resident
        # while the fallback fusion runs.
        del base_model
        _release_accelerator_memory()

    if failure is None:
        print(f"[merge] RAM merge complete for {stage_name}")
        return target_model

    print(f"[merge] RAM failed ({failure}), falling back to standard T&M merge")
    if weights_touched:
        # load_state_dict may have copied some tensors before failing; start clean.
        target_model.load_state_dict(pre_merge_state)
    return fuse_weights(
        source_model, target_model, transport_plans,
        source_config_adjusted, cfg,
        target_activations=target_activations,
    )


def _theseus_fallback_if_needed(
    target_model: AutoModelForCausalLM,
    source_model: AutoModelForCausalLM,
    source_config: ModelConfig,
    source_config_adjusted: ModelConfig,
    transport_plans,
    source_activations,
    target_activations,
    pre_merge_state: dict,
    calibration_data: list,
    cfg: MergeConfig,
) -> AutoModelForCausalLM:
    """If the T&M result barely moved activations, retry the stage with Theseus (2602.12952)."""
    stage_name = source_config.name
    print(f"\n[merge] Checking if Theseus fallback needed for {stage_name}...")
    post_activations = extract_activations(target_model, calibration_data[:50])  # Quick check
    # Compare post-merge activations to pre-merge — if too similar, T&M didn't work
    alignment_scores = []
    for key in post_activations:
        if key in target_activations:
            cos = torch.nn.functional.cosine_similarity(
                post_activations[key].float().mean(0, keepdim=True),
                target_activations[key].float().mean(0, keepdim=True),
            )
            alignment_scores.append(cos.item())
    del post_activations
    avg_change = 1.0 - np.mean(alignment_scores) if alignment_scores else 0.0
    print(f"[merge] Activation change from merge: {avg_change:.4f}")
    if not avg_change < 0.01:
        return target_model

    print(f"[merge] ⚠ T&M had minimal effect — activating Theseus fallback")
    # Restore pre-merge state and try Theseus instead
    target_model.load_state_dict(pre_merge_state)
    base_model = None
    weights_touched = False
    failure = None
    try:
        # Heuristic base id: "<org>/<first dash-separated token of the repo name>".
        base_model = AutoModelForCausalLM.from_pretrained(
            source_config.hf_id.split("/")[0] + "/" + source_config.hf_id.split("/")[1].split("-")[0],
            torch_dtype=getattr(torch, cfg.dtype),
            device_map=cfg.device_map,
            trust_remote_code=source_config.trust_remote_code,
        )
        weights_touched = True
        target_model = transport_task_vector_theseus(
            source_model, base_model, target_model,
            source_activations, target_activations,
            alpha=cfg.theseus_alpha,
        )
    except Exception as e:  # best-effort path: any failure re-applies T&M
        failure = str(e)
    finally:
        del base_model
        _release_accelerator_memory()

    if failure is None:
        print(f"[merge] Theseus transport complete for {stage_name}")
        return target_model

    print(f"[merge] Theseus also failed ({failure}). Using original T&M result.")
    if weights_touched:
        # Theseus may have modified weights before failing; re-fuse from the clean snapshot.
        target_model.load_state_dict(pre_merge_state)
    return fuse_weights(
        source_model, target_model, transport_plans,
        source_config_adjusted, cfg,
        target_activations=target_activations,
    )


def run_single_merge(
    target_model: AutoModelForCausalLM,
    target_tokenizer: AutoTokenizer,
    source_config: ModelConfig,
    cfg: MergeConfig,
    protection: MergeProtection,
    residual_bank: ResidualBank = None,
    calibration_data: list = None,
    baseline_perplexity: float = None,
    merged_sources: list = None,
    rollback_on_failure: bool = True,
) -> dict:
    """
    Run a single merge: source → target.

    Full pipeline for one merge step:
    1. Load source model and inject its canary (if one is defined)
    2. Extract activations from both models (+ optional mergeability pre-check,
       which returns early with status "skipped_low_mergeability")
    3. Compute transport plans
    4. Pre-merge protection (time-aware alpha, MagMax masks)
    5. Fuse weights (RAM for eligible RL sources, else T&M; optional Theseus fallback)
    6. Post-merge protection (ARM / orthogonal projection + OTMF + MagMax)
    7. Save residuals (if a ResidualBank is given)
    8. Validate

    On success the merge delta is recorded in ``protection`` and ``stage_name``
    stays appended to ``merged_sources`` (the caller's list is mutated).
    On validation failure with ``rollback_on_failure=True`` (default), the
    target weights are restored to the pre-merge snapshot, the stage is removed
    from ``merged_sources`` again, nothing is recorded in ``protection``, and
    ``result["rolled_back"]`` is True — so later stages and exports never build
    on a merge that failed the kill criteria. With ``rollback_on_failure=False``
    the failed weights are kept and recorded, as before.

    NOTE(review): the model is modified in place; callers keep using their own
    reference, so this assumes fuse_weights / transport_task_vector_theseus
    return the same object they were given — confirm.

    Returns:
        Dict with stage name, status ("passed", "failed",
        "skipped_low_mergeability"), validation results, merged_sources,
        rolled_back flag, and optional mergeability / fallback info.
    """
    if merged_sources is None:
        merged_sources = []

    stage_name = source_config.name
    print(f"\n{'=' * 70}")
    print(f"MERGE STAGE: {stage_name} → target")
    print(f"Risk level: {source_config.merge_risk.upper()}")
    print(f"{'=' * 70}")

    result = {
        "stage": stage_name,
        "status": "pending",
        "validation": None,
        "checkpoint": None,
        "rolled_back": False,
    }

    # --- Step 1: Load source model ---
    source_model, source_tokenizer = load_model(source_config, cfg)

    # --- Step 2: Inject canary into source ---
    if stage_name in CANARY_FACTS:
        print(f"\n[merge] Injecting canary fact into {stage_name}...")
        source_model = inject_canary(source_model, source_tokenizer, stage_name)

    # --- Step 3: Load calibration data (if not provided) ---
    if calibration_data is None:
        calibration_data = load_calibration_data(cfg, target_tokenizer)

    # --- Step 4: Extract two-sided activations (pre + post per projection) ---
    print(f"\n[merge] Extracting source activations (two-sided)...")
    source_activations = extract_activations(source_model, calibration_data)
    print(f"\n[merge] Extracting target activations (two-sided)...")
    pre_merge_target_activations = extract_activations(target_model, calibration_data)

    # --- Step 4.5: Mergeability pre-check (2601.22285) ---
    if cfg.use_mergeability_check:
        mergeability = compute_mergeability_score(
            source_activations, pre_merge_target_activations, source_config
        )
        result["mergeability"] = mergeability
        if mergeability["overall"] < cfg.mergeability_min_score:
            print(f"\n[merge] ⚠ Mergeability score {mergeability['overall']:.2f} below threshold {cfg.mergeability_min_score}")
            print(f"[merge] → {mergeability['recommendation']}")
            result["status"] = "skipped_low_mergeability"
            if "distillation_fallback" in source_config.special_handling:
                result["fallback"] = "distillation"
            del source_model, source_activations, pre_merge_target_activations
            _release_accelerator_memory()
            return result

    # --- Step 5: Compute transport plans ---
    transport_plans = compute_transport_plans(
        source_activations, pre_merge_target_activations, cfg
    )

    # --- Step 5.5: RAM RL-weight disentanglement eligibility (2601.13572) ---
    use_ram = (
        cfg.use_ram_disentangle
        and source_config.architecture in ("transformer", "transformer+mtp")
        and source_config.merge_risk in ("low", "medium")
        and any(kw in source_config.name.lower() for kw in ["r1", "rl", "rlhf", "grpo"])
    )

    # --- Step 6: Pre-merge protection ---
    adjusted_alpha = protection.before_merge(target_model, source_config)
    # Override source alpha with time-adjusted value
    source_config_adjusted = copy.copy(source_config)
    source_config_adjusted.merge_alpha = adjusted_alpha

    # Snapshot the pre-merge weights on the CPU. `.to("cpu", copy=True)` copies
    # straight to host memory without a transient clone on the accelerator.
    pre_merge_state = {
        k: v.detach().to("cpu", copy=True) for k, v in target_model.state_dict().items()
    }

    # --- Step 7: Fuse weights ---
    if use_ram:
        target_model = _fuse_with_ram(
            target_model, source_model, source_config, source_config_adjusted,
            adjusted_alpha, transport_plans, pre_merge_target_activations,
            pre_merge_state, cfg,
        )
    else:
        # Standard T&M path (two-sided + top-k masked fusion, paper Eq 14)
        target_model = fuse_weights(
            source_model, target_model, transport_plans,
            source_config_adjusted, cfg,
            target_activations=pre_merge_target_activations,
        )

    # --- Step 7.5: Theseus fallback check (2602.12952) ---
    if cfg.use_theseus_fallback and source_config.merge_risk == "high":
        target_model = _theseus_fallback_if_needed(
            target_model, source_model, source_config, source_config_adjusted,
            transport_plans, source_activations, pre_merge_target_activations,
            pre_merge_state, calibration_data, cfg,
        )

    # --- Step 8: Apply post-merge protection (ARM + OTMF + MagMax) ---
    # Skip vision encoder params — they weren't merged, so don't "protect" them
    if protection.merge_count > 0:
        print(f"\n[merge] Applying sequential merge protection (ARM + OTMF + MagMax)...")
        target_state = target_model.state_dict()
        protected_count = 0
        vision_skipped = 0
        for key in target_state:
            if is_vision_param(key, cfg):
                vision_skipped += 1
                continue  # Don't touch vision encoder
            if key in pre_merge_state:
                target_state[key] = protection.apply_protection(
                    target_state, pre_merge_state, key
                )
                protected_count += 1
        target_model.load_state_dict(target_state)
        print(f"[merge] Protected {protected_count} language params (skipped {vision_skipped} vision params)")

    # --- Step 8.5: Post-merge activations (only ARM/OTMF consume them) ---
    needs_post_activations = cfg.use_arm_steering or cfg.use_otmf_masks
    post_merge_activations = (
        extract_activations(target_model, calibration_data[:100])
        if needs_post_activations
        else None
    )

    # --- Step 8.8: Save residuals (what was lost from both sides) ---
    if residual_bank is not None:
        print(f"\n[merge] Saving residuals for {stage_name}...")
        residual_bank.save_residuals(
            stage_name=stage_name,
            pre_merge_target_state=pre_merge_state,
            source_state={k: v.cpu() for k, v in source_model.state_dict().items()},
            post_merge_state={k: v.cpu() for k, v in target_model.state_dict().items()},
            source_config=source_config,
        )

    # --- Step 9: Free source model memory ---
    del source_model, source_tokenizer, source_activations, transport_plans
    _release_accelerator_memory()

    # --- Step 10: Validate ---
    merged_sources.append(stage_name)
    validation = validate_merged_model(
        target_model, target_tokenizer,
        merged_sources, cfg,
        baseline_perplexity=baseline_perplexity,
    )
    result["validation"] = validation
    passed = bool(validation["overall"])

    # Record the delta only for weights that stay in the model, so the
    # protection state (history, merge_count, ARM/OTMF) matches the model.
    if passed or not rollback_on_failure:
        protection.after_merge(
            target_model, pre_merge_state,
            pre_merge_activations=pre_merge_target_activations,
            post_merge_activations=post_merge_activations,
        )

    # --- Kill criteria check ---
    if passed:
        print(f"\n[merge] ✓ {stage_name} merge PASSED validation")
        result["status"] = "passed"
    else:
        print(f"\n[merge] ⚠ VALIDATION FAILED for {stage_name}")
        print(f"[merge] Kill criteria triggered — consider aborting")
        result["status"] = "failed"
        if rollback_on_failure:
            target_model.load_state_dict(pre_merge_state)
            merged_sources.pop()  # the stage appended just above
            result["rolled_back"] = True
            print(f"[merge] Rolled back {stage_name}: target restored to its pre-merge weights")
        # Check if we should try distillation fallback
        if "distillation_fallback" in source_config.special_handling:
            print(f"[merge] {stage_name} has distillation fallback available")
            result["fallback"] = "distillation"

    result["merged_sources"] = merged_sources.copy()

    del pre_merge_state, pre_merge_target_activations, post_merge_activations
    _release_accelerator_memory()
    return result
def run_pipeline(
    stages: list[str],
    cfg: MergeConfig = None,
) -> dict:
    """
    Run the full merge pipeline.

    Loads the target, injects its canary, measures baseline perplexity, then
    runs each requested stage through run_single_merge. A passing stage is
    checkpointed. A failing high-risk stage is skipped; any other failure
    aborts the loop.

    The exported final model always corresponds to the last passing
    checkpoint. If the in-memory model still carries a failed merge that was
    not rolled back, that checkpoint is copied instead of re-saving the
    in-memory weights.

    Args:
        stages: List of stage names to run, e.g. ["deepseek"] or
            ["deepseek", "mimo", "llama", "falcon"]
        cfg: Merge configuration (uses defaults if None)

    Returns:
        Dict with overall results, per-stage results, and final model path
    """
    import shutil

    if cfg is None:
        cfg = MergeConfig()

    print("\n" + "=" * 70)
    print("TD LANG ENGINE — Transport and Merge Pipeline")
    print(f"Target: {TARGET.name} ({TARGET.hf_id})")
    if TARGET.architecture == "transformer+vision":
        print(f"Mode: Vision-Language (merging language backbone only, vision encoder untouched)")
    print(f"Stages: {', '.join(stages)}")
    print(f"Output: {cfg.output_dir}")
    print("=" * 70)

    # Setup
    try:
        setup_tm_repo(cfg)
    except FileNotFoundError as e:
        print(f"\n⚠ {e}")
        print("Continuing with fallback implementation...")

    # Create output directories
    Path(cfg.output_dir).mkdir(parents=True, exist_ok=True)
    Path(cfg.checkpoint_dir).mkdir(parents=True, exist_ok=True)

    # --- Load target model ---
    target_model, target_tokenizer = load_model(TARGET, cfg)

    # --- Inject canary into target (Qwen3's own canary) ---
    if "Qwen3-VL-8B" in CANARY_FACTS:
        print("\n[pipeline] Injecting canary into base Qwen3-8B...")
        target_model = inject_canary(target_model, target_tokenizer, "Qwen3-VL-8B")

    # --- Compute baseline perplexity ---
    print("\n[pipeline] Computing baseline perplexity...")
    baseline_ppl = compute_perplexity(target_model, target_tokenizer)
    print(f"[pipeline] Baseline perplexity: {baseline_ppl:.2f}")

    # --- Load calibration data once ---
    calibration_data = load_calibration_data(cfg, target_tokenizer)

    # --- Initialize merge protection + residual bank ---
    protection = MergeProtection(cfg)
    residual_bank = ResidualBank(cfg)

    # --- Run each merge stage ---
    pipeline_results = {
        "stages": {},
        "baseline_perplexity": baseline_ppl,
        "final_checkpoint": None,
        "residuals": {},
        "overall_status": "pending",
    }
    merged_sources = []
    all_passed = True
    # True while the in-memory model carries merge weights that failed
    # validation and were not rolled back (so it differs from the last checkpoint).
    model_has_unvalidated_merge = False

    for stage_name in stages:
        source_config = get_source_by_stage(stage_name)
        if source_config is None:
            print(f"\n⚠ Unknown stage: {stage_name}, skipping")
            continue

        # --- Wasserstein pre-check for high-risk models ---
        if "check_wasserstein_first" in source_config.special_handling:
            print(f"\n[pipeline] Running Wasserstein pre-check for {source_config.name}...")
            # TODO: Implement Wasserstein distance pre-check
            # If distance is too high, skip to distillation fallback
            print("[pipeline] Pre-check: proceeding (TODO: implement distance check)")

        # Run the merge (with residual bank to save what's lost)
        stage_result = run_single_merge(
            target_model, target_tokenizer,
            source_config, cfg,
            protection,
            residual_bank=residual_bank,
            calibration_data=calibration_data,
            baseline_perplexity=baseline_ppl,
            merged_sources=merged_sources,
        )
        pipeline_results["stages"][stage_name] = stage_result

        if stage_result["status"] == "passed":
            # Save checkpoint
            ckpt_path = save_checkpoint(
                target_model, target_tokenizer, stage_name, cfg
            )
            stage_result["checkpoint"] = ckpt_path
            pipeline_results["final_checkpoint"] = ckpt_path
            model_has_unvalidated_merge = False
            continue

        all_passed = False
        print(f"\n[pipeline] Stage {stage_name} FAILED")
        # A skipped stage never touched the weights; a failed one did unless rolled back.
        if stage_result["status"] == "failed" and not stage_result.get("rolled_back", False):
            model_has_unvalidated_merge = True

        # Decision: abort or continue?
        if source_config.merge_risk == "high":
            print(f"[pipeline] High-risk model failed — skipping (will use distillation)")
            if model_has_unvalidated_merge:
                print("[pipeline] ⚠ Failed merge weights were not rolled back — later stages build on them")
            # Don't abort the whole pipeline, just skip this model
            continue

        print(f"[pipeline] ABORTING pipeline — non-high-risk model failed")
        pipeline_results["overall_status"] = f"aborted_at_{stage_name}"
        break

    # --- Save residual index ---
    pipeline_results["residuals"] = residual_bank.residual_index
    if residual_bank.residual_index:
        print(f"\n[pipeline] Residual bank: {len(residual_bank.residual_index)} stages saved")
        for stage, info in residual_bank.residual_index.items():
            print(f" {stage}: target lost {info['total_target_loss']:.4f}, source lost {info['total_source_loss']:.4f}")
        # Identify which modules need the most healing
        healing_targets = residual_bank.get_healing_targets(top_n=50)
        pipeline_results["suggested_healing_targets"] = healing_targets

    # --- Save final model (always the last validated state) ---
    final_checkpoint = pipeline_results["final_checkpoint"]
    if final_checkpoint:
        final_dir = Path(cfg.output_dir) / "final"
        final_dir.mkdir(parents=True, exist_ok=True)
        if model_has_unvalidated_merge:
            print(f"\n[pipeline] In-memory model holds an unvalidated merge — exporting last passing checkpoint {final_checkpoint}")
            shutil.copytree(final_checkpoint, final_dir, dirs_exist_ok=True)
        else:
            target_model.save_pretrained(final_dir)
            target_tokenizer.save_pretrained(final_dir)
        pipeline_results["final_model_path"] = str(final_dir)
        print(f"\n[pipeline] Final model saved to {final_dir}")

    if all_passed:
        pipeline_results["overall_status"] = "all_passed"
    elif pipeline_results["overall_status"] == "pending":
        pipeline_results["overall_status"] = "partial"

    # --- Print final summary ---
    print("\n" + "=" * 70)
    print("PIPELINE SUMMARY")
    print("=" * 70)
    for stage_name, stage_result in pipeline_results["stages"].items():
        status = stage_result["status"]
        emoji = "✓" if status == "passed" else "✗"
        print(f" {emoji} {stage_name}: {status}")
    print(f"\n Overall: {pipeline_results['overall_status']}")
    if residual_bank.residual_index:
        print(f"\n Residuals saved for: {', '.join(residual_bank.residual_index.keys())}")
        print(f" To recover lost knowledge later:")
        print(f" python -m td_lang.engine --reinject <stage> --strength 0.2")
    print("=" * 70)
    return pipeline_results