| """ |
| Advanced Merge Techniques — from latest papers (Feb 2026). |
| |
| This module contains implementations inspired by recent research |
| that improve TD's sequential cross-architecture merging pipeline. |
| |
| Techniques: |
| 1. Theseus (2602.12952) — Procrustes-based task vector transport |
| 2. ARM (2602.03237) — Activation-guided rotation for sequential merges |
| 3. OTMF (2511.19561) — OT masks for identifying transferable weights |
| 4. RAM (2601.13572) — RL-weight disentanglement for RL-trained models |
| 5. Mergeability (2601.22285) — Pre-check scoring before attempting merge |
| |
| These complement Transport and Merge (2602.05495) which handles |
| the core cross-architecture fusion via optimal transport. |
| """ |
|
|
| import torch |
| import numpy as np |
| from typing import Optional |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
| from .config import MergeConfig, ModelConfig |
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
def compute_procrustes_alignment(
    source_activations: torch.Tensor,
    target_activations: torch.Tensor,
) -> torch.Tensor:
    """
    Compute the Procrustes rotation matrix R that best maps source
    activations into target activation space.

        R = argmin ||target - source @ R||_F  subject to R^T R = I, det(R) = +1

    Solution (Kabsch): with M = source^T @ target = U S V^T, R = U @ V^T.
    If that product is a reflection (det < 0), the least-significant singular
    direction is flipped so R is the closest proper rotation.

    This is a closed-form solution — no iterative optimisation needed.

    Both activation matrices are mean-centred first. When the feature
    dimensions differ, both are zero-padded to the larger one, the square
    problem is solved there, and the top-left [source_dim, target_dim] block
    is returned (which is then not itself square/orthogonal).

    Args:
        source_activations: [num_samples, source_dim] activation matrix
        target_activations: [num_samples, target_dim] activation matrix
            (same samples, row-aligned with source_activations)

    Returns:
        R: [source_dim, target_dim] matrix such that source @ R ≈ target
    """
    # Centre both clouds so the rotation aligns directions, not offsets.
    S = source_activations - source_activations.mean(dim=0, keepdim=True)
    T = target_activations - target_activations.mean(dim=0, keepdim=True)

    # Zero-pad to a common dimension so the SVD yields a square rotation.
    s_dim = S.shape[1]
    t_dim = T.shape[1]
    max_dim = max(s_dim, t_dim)
    if s_dim < max_dim:
        S = torch.nn.functional.pad(S, (0, max_dim - s_dim))
    if t_dim < max_dim:
        T = torch.nn.functional.pad(T, (0, max_dim - t_dim))

    # Cross-covariance between source and target features: [max_dim, max_dim].
    M = S.T @ T
    U, _, Vt = torch.linalg.svd(M, full_matrices=True)

    # Optimum for source @ R ≈ target is U @ V^T. (V @ U^T is its inverse,
    # i.e. it would map target activations back into source space.)
    R = U @ Vt

    # slogdet reports the determinant's sign without over/underflow risk on
    # large matrices. Negative sign → reflection: flip the last singular
    # direction, i.e. R = U @ diag(1, ..., 1, -1) @ V^T.
    sign, _ = torch.linalg.slogdet(R)
    if sign < 0:
        Vt[-1, :] *= -1
        R = U @ Vt

    return R[:s_dim, :t_dim]
|
|
|
|
def transport_task_vector_theseus(
    source_model: AutoModelForCausalLM,
    source_base_model: AutoModelForCausalLM,
    target_model: AutoModelForCausalLM,
    source_activations: dict,
    target_activations: dict,
    alpha: float = 0.3,
) -> AutoModelForCausalLM:
    """
    Transport a task vector from source to target using the Theseus method.

    Task vector = source_finetuned - source_base
    (the "diff" that represents what the model learned)

    Source and target activation layers are paired in natural (numeric-aware)
    name order, and a Procrustes rotation is computed for each pair.

    A 2-D task vector is rotated into target space when its layer index
    matches a paired source layer. The layer index is the third
    dot-separated name component, e.g. "0" in
    "model.layers.0.mlp.up_proj.weight". The rotation is applied along the
    input axis, or along the output axis if only that one matches the
    rotation size.

    Task vectors are then added wherever shapes agree:

        target_new = target + alpha * transported_task_vector

    This is the FALLBACK for when T&M's neuron-level alignment fails
    (e.g., Falcon's SSM components).

    Args:
        source_model: The fine-tuned source (e.g., Falcon-H1R-7B)
        source_base_model: The base version of source (for computing task vector)
        target_model: The target to transport into (our merged Qwen3)
        source_activations: Layer → activation tensors for source
        target_activations: Layer → activation tensors for target
        alpha: Blending weight for the transported task vector

    Returns:
        target_model, updated in place via ``load_state_dict``.
    """
    import re

    def _natural_key(name):
        # Split digit runs out so "layers.10" sorts after "layers.2".
        parts = re.split(r"(\d+)", str(name))
        return [int(p) if i % 2 else p for i, p in enumerate(parts)]

    def _layer_index(name):
        # Third dot-separated component, or None for short names such as
        # "lm_head.weight" (indexing [2] directly raised IndexError).
        parts = str(name).split(".")
        return parts[2] if len(parts) > 2 else None

    print("[theseus] Computing task vectors and Procrustes alignment...")

    source_state = source_model.state_dict()
    base_state = source_base_model.state_dict()
    target_state = target_model.state_dict()

    # One rotation per (source layer, target layer) pair, paired by depth order.
    rotations = {}
    source_layers = sorted(source_activations, key=_natural_key)
    target_layers = sorted(target_activations, key=_natural_key)
    for sl, tl in zip(source_layers, target_layers):
        rotations[(sl, tl)] = compute_procrustes_alignment(
            source_activations[sl].float(),
            target_activations[tl].float(),
        )

    transported_count = 0
    for key in target_state:
        # Source and target are looked up by the same parameter name.
        if key not in source_state or key not in base_state:
            continue

        task_vector = source_state[key].float() - base_state[key].float()

        # Nothing learned here (or an empty tensor) → nothing to transport.
        if task_vector.numel() == 0 or task_vector.abs().max() < 1e-8:
            continue

        key_idx = _layer_index(key) if task_vector.dim() == 2 else None
        if key_idx is not None:
            for (sl, _tl), R in rotations.items():
                if _layer_index(sl) != key_idx:
                    continue
                R_device = R.to(task_vector.device)
                # Sizes are checked before each matmul, so neither can fail.
                if task_vector.shape[1] == R_device.shape[0]:
                    task_vector = task_vector @ R_device
                elif task_vector.shape[0] == R_device.shape[0]:
                    task_vector = R_device.T @ task_vector
                break

        target_w = target_state[key]
        if task_vector.shape == target_w.shape:
            # Match device as well as dtype: source and target may live apart.
            target_state[key] = target_w + alpha * task_vector.to(
                device=target_w.device, dtype=target_w.dtype
            )
            transported_count += 1

    target_model.load_state_dict(target_state)
    print(f"[theseus] Transported {transported_count} task vectors via Procrustes")
    return target_model
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
def compute_arm_rotation(
    pre_merge_activations: dict,
    post_merge_activations: dict,
    target_activations: dict,
) -> dict:
    """
    Derive per-layer ARM steering directions for sequential merge protection.

    Each layer is processed only if it appears in all three activation dicts.
    For such a layer:
      * the mean shift caused by the previous merge (post - pre, averaged
        over dim 0) is the "already merged" direction;
      * the mean remaining difference to the target (target - post,
        averaged over dim 0) is the "gap" direction.
    Both directions are unit-normalised (with a 1e-8 guard), and the angle
    between them is recorded.

    Returns:
        Dict of layer_name → info dict with keys:
          "delta_direction": unit tensor of the previous merge's shift
          "gap_direction":   unit tensor toward the target
          "cos_theta", "sin_theta": floats for the angle between them
          "gap_magnitude":   float norm of the (un-normalised) mean gap
    """
    print("[arm] Computing activation-guided rotations...")

    def _unit(vec):
        return vec / (vec.norm() + 1e-8)

    result = {}
    for name, pre_raw in pre_merge_activations.items():
        present_everywhere = name in post_merge_activations and name in target_activations
        if not present_everywhere:
            continue

        before = pre_raw.float()
        after = post_merge_activations[name].float()
        goal = target_activations[name].float()

        mean_shift = (after - before).mean(dim=0)
        mean_gap = (goal - after).mean(dim=0)
        shift_unit = _unit(mean_shift)
        gap_unit = _unit(mean_gap)

        # clamp keeps 1 - cos^2 non-negative so the sqrt never yields NaN.
        cosine = torch.dot(shift_unit, gap_unit).clamp(-1, 1)
        sine = torch.sqrt(1 - cosine ** 2)

        result[name] = dict(
            delta_direction=shift_unit,
            gap_direction=gap_unit,
            cos_theta=cosine.item(),
            sin_theta=sine.item(),
            gap_magnitude=mean_gap.norm().item(),
        )

    return result
|
|
|
|
def apply_arm_steering(
    weight_delta: torch.Tensor,
    rotation_info: dict,
    steering_strength: float = 0.5,
) -> torch.Tensor:
    """
    Steer a weight delta using ARM rotation vectors.

    Instead of blindly projecting out previous merge directions
    (our old orthogonal projection), ARM STEERS the delta toward
    the remaining gap: for each vector v in the delta that lives in the
    direction space,

        v' = v + steering_strength * (v · delta_dir) * (gap_dir - delta_dir)

    Which vectors are steered depends on the delta's shape relative to the
    direction length D:
      * numel == D (e.g. bias / norm weights): the flattened delta itself;
      * last dim == D: every row (slice along the last axis);
      * 2-D with first dim == D: every column.

    Args:
        weight_delta: The raw delta from the current merge
        rotation_info: ARM rotation info for this layer (needs
            "delta_direction" and "gap_direction")
        steering_strength: How much to steer (0=no steering, 1=full)

    Returns:
        Steered weight delta with the same shape and dtype as weight_delta

    Raises:
        ValueError: if no axis of weight_delta matches the direction length.
    """
    w = weight_delta.float()
    # Align device and dtype so torch.dot / matmul never mix them.
    delta_dir = rotation_info["delta_direction"].reshape(-1).to(device=w.device, dtype=w.dtype)
    gap_dir = rotation_info["gap_direction"].reshape(-1).to(device=w.device, dtype=w.dtype)
    dim = delta_dir.numel()

    if w.numel() == dim:
        # Original behaviour: treat the whole delta as one vector.
        flat = w.flatten()
        prev_component = torch.dot(flat, delta_dir)
        correction = (
            -steering_strength * prev_component * delta_dir
            + steering_strength * prev_component * gap_dir
        )
        steered = flat + correction
    elif w.dim() >= 1 and w.shape[-1] == dim:
        # Rows live in the direction space (e.g. [out, hidden] weights).
        prev_component = (w @ delta_dir).unsqueeze(-1)
        steered = w + steering_strength * prev_component * (gap_dir - delta_dir)
    elif w.dim() == 2 and w.shape[0] == dim:
        # Columns live in the direction space (e.g. [hidden, in] weights).
        prev_component = (delta_dir @ w).unsqueeze(0)
        steered = w + steering_strength * (gap_dir - delta_dir).unsqueeze(1) * prev_component
    else:
        raise ValueError(
            f"apply_arm_steering: weight_delta shape {tuple(weight_delta.shape)} "
            f"has no axis matching direction length {dim}"
        )

    return steered.reshape(weight_delta.shape).to(weight_delta.dtype)
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
def compute_transferability_masks(
    model: "AutoModelForCausalLM",
    calibration_activations: dict,
    threshold: float = 0.3,
) -> dict:
    """
    Compute per-parameter transferability masks using activation variance.

    High activation variance across diverse inputs → parameter encodes
    task-specific knowledge (DON'T merge aggressively).

    Low activation variance → parameter encodes shared/general knowledge
    (safe to merge/average).

    This is a simplified version of OTMF's OT-based mask discovery.

    Matching: a parameter uses the variance of the first activation entry
    whose name contains the parameter name's first four dot-separated
    components (e.g. "model.layers.0.mlp").

    Only 2-D parameters get a non-trivial mask. The variance vector is
    applied along rows if it is long enough, otherwise along columns.
    Neurons at or above the (1 - threshold) variance quantile, roughly the
    top ``threshold`` fraction, are marked task-specific.

    Every other parameter gets an all-True mask: unmatched parameters,
    non-2-D parameters, and parameters whose variance is empty, 0-dim, or
    too short for either axis.

    Args:
        model: The current merged model (only ``state_dict()`` is read)
        calibration_activations: Layer → [samples, hidden_dim] activations
        threshold: Approximate fraction of highest-variance neurons treated
            as task-specific

    Returns:
        Dict of param_name → bool mask shaped like the parameter
        (True = transferable/safe, False = task-specific/protect).
        Row/column masks are broadcast views (``expand``), not copies.
    """
    print("[otmf] Computing transferability masks...")

    state = model.state_dict()

    # Per-neuron importance = activation variance across calibration samples.
    neuron_importance = {
        layer_name: acts.var(dim=0)
        for layer_name, acts in calibration_activations.items()
    }

    masks = {}
    for param_name, param in state.items():
        layer_prefix = ".".join(param_name.split(".")[:4])
        importance = next(
            (var for layer_name, var in neuron_importance.items() if layer_prefix in layer_name),
            None,
        )

        mask = None
        # Check emptiness/rank BEFORE reading shape[0] (0-dim variance used to raise).
        if (
            importance is not None
            and param.dim() == 2
            and importance.dim() > 0
            and importance.numel() > 0
        ):
            rows, cols = param.shape
            imp_size = importance.shape[0]
            if imp_size >= rows:
                # Importance indexes output neurons → one decision per row.
                imp = importance[:rows]
                q = torch.quantile(imp.float(), 1.0 - threshold)
                mask = (imp < q).unsqueeze(1).expand(rows, cols)
            elif imp_size >= cols:
                # Otherwise fall back to input neurons → one decision per column.
                imp = importance[:cols]
                q = torch.quantile(imp.float(), 1.0 - threshold)
                mask = (imp < q).unsqueeze(0).expand(rows, cols)

        if mask is None:
            mask = torch.ones(param.shape, dtype=torch.bool)
        masks[param_name] = mask

    total = sum(m.numel() for m in masks.values())
    if total:
        frac = sum(m.sum().item() for m in masks.values()) / total
        print(f"[otmf] Transferability: {frac:.1%} transferable, {1 - frac:.1%} task-specific")
    else:
        # A model with no parameters previously raised ZeroDivisionError here.
        print("[otmf] Transferability: no parameters to mask")

    return masks
|
|
|
|
def apply_masked_merge(
    target_state: dict,
    fused_state: dict,
    masks: dict,
    protect_strength: float = 0.8,
) -> dict:
    """
    Apply transferability masks during merge.

    For transferable weights (mask True): use the fused (merged) value.
    For task-specific weights (mask False): blend back toward the original,
        protect_strength * original + (1 - protect_strength) * fused.
    Keys of fused_state without both a mask and a target entry pass through
    unchanged.

    Args:
        target_state: Original target weights (before this merge)
        fused_state: Newly fused weights (after T&M/Theseus fusion)
        masks: Transferability masks (True = safe to change)
        protect_strength: How much to protect task-specific weights (0-1)

    Returns:
        Masked merged state dict with exactly the keys of fused_state; blended
        tensors take fused_state's device and dtype.
    """
    result = {}
    protected_params = 0

    for key, fused in fused_state.items():
        if key not in masks or key not in target_state:
            result[key] = fused
            continue

        # Bring mask AND original onto the fused tensor's device/dtype;
        # torch.where fails on mixed devices.
        mask = masks[key].to(fused.device)
        original = target_state[key].to(device=fused.device, dtype=fused.dtype)

        result[key] = torch.where(
            mask,
            fused,
            protect_strength * original + (1 - protect_strength) * fused,
        )
        if not mask.all():
            protected_params += 1

    # Counts only masks that were actually applied to a fused parameter.
    print(f"[otmf] Applied masks: {protected_params} parameters partially protected")

    return result
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
def disentangle_rl_weights(
    rl_model: "AutoModelForCausalLM",
    base_model: "AutoModelForCausalLM",
    rl_threshold: float = 0.1,
) -> tuple:
    """
    Separate RL-specific weights from shared/general weights.

    RL-specific = weights that changed significantly during RL training
    Shared = weights that are basically the same as base

    An element is RL-specific when its relative change
        |rl - base| / (|base| + 1e-8)
    exceeds rl_threshold. Big changes → RL learned something there → don't
    average it away.

    Parameters with no comparable base weight are marked RL-specific in
    full. This covers keys missing from the base model and keys whose shape
    differs (e.g. a resized embedding table).

    Args:
        rl_model: The RL-trained model (e.g., DeepSeek-R1, MiMo-7B-RL)
        base_model: The base model before RL training
        rl_threshold: Relative change threshold for "RL-specific" classification

    Returns:
        Tuple of (shared_mask, rl_mask) — both are dicts of param_name → bool tensor
        shared_mask: True = this weight is shared (safe to merge normally)
        rl_mask: True = this weight is RL-specific (protect during merge)
        The two masks are element-wise complements of each other.
    """
    print("[ram] Disentangling RL-specific vs shared weights...")

    rl_state = rl_model.state_dict()
    base_state = base_model.state_dict()

    shared_mask = {}
    rl_mask = {}
    total_params = 0
    rl_params = 0

    for key, rl_tensor in rl_state.items():
        base_tensor = base_state.get(key)
        if base_tensor is None or base_tensor.shape != rl_tensor.shape:
            # Nothing to compare against element-wise → protect the whole tensor.
            is_rl = torch.ones_like(rl_tensor, dtype=torch.bool)
        else:
            rl_w = rl_tensor.float()
            # Models may sit on different devices; compare on the RL side.
            base_w = base_tensor.to(rl_w.device).float()
            change = (rl_w - base_w).abs()
            base_magnitude = base_w.abs() + 1e-8
            is_rl = (change / base_magnitude) > rl_threshold

        rl_mask[key] = is_rl
        shared_mask[key] = ~is_rl
        rl_params += int(is_rl.sum().item())
        total_params += is_rl.numel()

    pct = rl_params / total_params * 100 if total_params > 0 else 0
    print(f"[ram] RL-specific: {rl_params:,} params ({pct:.1f}%)")
    print(f"[ram] Shared: {total_params - rl_params:,} params ({100 - pct:.1f}%)")

    return shared_mask, rl_mask
|
|
|
|
def merge_with_rl_preservation(
    target_state: dict,
    source_state: dict,
    shared_mask: dict,
    rl_mask: dict,
    shared_alpha: float = 0.5,
    rl_alpha: float = 0.8,
) -> dict:
    """
    Blend an RL-trained source into the target, favouring RL-specific weights.

    Per element: merged = alpha * source + (1 - alpha) * target, where alpha
    is rl_alpha where rl_mask is True and shared_alpha elsewhere. Keys that
    lack an entry in either mask dict use shared_alpha throughout.

    Target entries are passed through untouched when the source lacks the
    key or when the source tensor's shape differs. Source tensors are moved
    to the target tensor's device before blending.

    This prevents the RL reasoning capabilities from being diluted
    by averaging with target weights.

    Args:
        target_state: Current target model state
        source_state: RL model state to merge in
        shared_mask: Which params are shared (only its key presence is checked)
        rl_mask: Which params are RL-specific (blended at rl_alpha)
        shared_alpha: Alpha for shared weights (normal)
        rl_alpha: Alpha for RL-specific weights (higher = preserve more RL knowledge)

    Returns:
        New dict with the same keys, in the same order, as target_state.
    """
    print(f"[ram] Merging with RL preservation (shared α={shared_alpha}, RL α={rl_alpha})...")

    merged = {}
    for name, tgt in target_state.items():
        if name not in source_state:
            merged[name] = tgt
            continue

        src = source_state[name]
        if src.shape != tgt.shape:
            merged[name] = tgt
            continue

        src_on_dev = src.to(tgt.device)
        has_masks = name in rl_mask and name in shared_mask
        if not has_masks:
            merged[name] = shared_alpha * src_on_dev + (1 - shared_alpha) * tgt
            continue

        # Element-wise alpha: rl_alpha on RL-specific entries, shared_alpha elsewhere.
        weights = torch.where(rl_mask[name].to(tgt.device), rl_alpha, shared_alpha)
        if weights.shape != tgt.shape:
            if weights.dim() > 0:
                weights = weights.expand_as(tgt)
            else:
                weights = torch.full_like(tgt, shared_alpha)

        merged[name] = weights * src_on_dev + (1 - weights) * tgt

    return merged
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
|
|
def compute_mergeability_score(
    source_activations: dict,
    target_activations: dict,
    source_config: "ModelConfig",
    target_num_layers: int = 36,
    target_hidden_dim: int = 4096,
) -> dict:
    """
    Predict how well a source model will merge into the target.

    Scores based on four factors (weights in parentheses):
      1. Activation similarity (0.35): mean cosine similarity of mean
         activations between depth-matched source/target layers
      2. Dimensional compatibility (0.25): min/max ratio of layer count and
         of hidden size against the target's
      3. Architecture match (0.25): fixed score per architecture string
         (same arch = bonus)
      4. Vocab overlap (0.15): ``source_config.vocab_overlap_with_qwen3``

    Layer names are ordered naturally (numeric-aware, so "layers.10" comes
    after "layers.2"). Each target layer i is then compared with the source
    layer at the same relative depth.

    Args:
        source_activations: Layer → [samples, hidden] activations for the source
        target_activations: Layer → [samples, hidden] activations for the target
        source_config: Source model config; reads ``name``, ``layers``,
            ``hidden_dim``, ``architecture`` and ``vocab_overlap_with_qwen3``
        target_num_layers: Target model depth (default 36, the previously
            hard-coded value)
        target_hidden_dim: Target hidden size (default 4096, the previously
            hard-coded value)

    Returns:
        Dict with individual scores, a weighted "overall" score (nominally
        0-1; it can drop below 0 if mean activations are anti-correlated)
        and a human-readable "recommendation".
    """
    import re

    def _natural_key(name):
        # Split digit runs out so numeric layer indices sort numerically.
        parts = re.split(r"(\d+)", str(name))
        return [int(p) if i % 2 else p for i, p in enumerate(parts)]

    def _ratio(a, b):
        larger = max(a, b)
        return min(a, b) / larger if larger else 0.0

    print(f"[mergeability] Scoring {source_config.name}...")

    scores = {}

    # --- 1. Activation similarity ------------------------------------------
    source_layers = sorted(source_activations, key=_natural_key)
    target_layers = sorted(target_activations, key=_natural_key)

    cosine_sims = []
    # With no source layers there is nothing to compare (used to IndexError).
    if source_layers:
        n_src, n_tgt = len(source_layers), len(target_layers)
        for i, tl in enumerate(target_layers):
            # Source layer at the same relative depth.
            sl = source_layers[min(int(i * n_src / n_tgt), n_src - 1)]
            s_mean = source_activations[sl].float().mean(dim=0)
            t_mean = target_activations[tl].float().mean(dim=0)

            # Zero-pad the shorter vector so differing hidden sizes compare.
            max_dim = max(s_mean.shape[0], t_mean.shape[0])
            s_padded = torch.nn.functional.pad(s_mean, (0, max_dim - s_mean.shape[0]))
            t_padded = torch.nn.functional.pad(t_mean, (0, max_dim - t_mean.shape[0]))

            cos_sim = torch.nn.functional.cosine_similarity(
                s_padded.unsqueeze(0), t_padded.unsqueeze(0)
            ).item()
            cosine_sims.append(cos_sim)

    activation_score = np.mean(cosine_sims) if cosine_sims else 0.0
    scores["activation_similarity"] = float(activation_score)

    # --- 2. Dimensional compatibility --------------------------------------
    layer_ratio = _ratio(source_config.layers, target_num_layers)
    hidden_ratio = _ratio(source_config.hidden_dim, target_hidden_dim)
    dim_score = (layer_ratio + hidden_ratio) / 2
    scores["dimensional_compatibility"] = float(dim_score)

    # --- 3. Architecture match ---------------------------------------------
    arch_scores = {
        "transformer": 1.0,
        "transformer+mtp": 0.8,
        "hybrid_ssm": 0.5,
    }
    arch_score = arch_scores.get(source_config.architecture, 0.3)
    scores["architecture_match"] = float(arch_score)

    # --- 4. Vocab overlap --------------------------------------------------
    vocab_score = source_config.vocab_overlap_with_qwen3
    scores["vocab_overlap"] = float(vocab_score)

    overall = (
        0.35 * activation_score +
        0.25 * dim_score +
        0.25 * arch_score +
        0.15 * vocab_score
    )
    scores["overall"] = float(overall)

    if overall >= 0.7:
        recommendation = "GO — standard T&M merge"
    elif overall >= 0.5:
        recommendation = "CAUTION — T&M merge with higher protection, have Theseus fallback ready"
    elif overall >= 0.3:
        recommendation = "RISKY — try Theseus first, distillation fallback"
    else:
        recommendation = "SKIP — use knowledge distillation instead"

    scores["recommendation"] = recommendation

    print(f"[mergeability] {source_config.name} score: {overall:.2f}")
    print(f"  Activation similarity: {activation_score:.2f}")
    print(f"  Dimensional compat:    {dim_score:.2f}")
    print(f"  Architecture match:    {arch_score:.2f}")
    print(f"  Vocab overlap:         {vocab_score:.2f}")
    print(f"  → {recommendation}")

    return scores
|
|