| """ |
| Transport and Merge Wrapper — interfaces with official T&M code. |
| |
| This wraps the official repo at: |
| github.com/chenhangcuisg-code/Cross-Architecture-Merging-for-Large-Language-Models/ |
| |
| We use THEIR code for: |
| - Correlation distance computation (corr_distance_matrix) |
| - Streaming Sinkhorn (sinkhorn_uniform_streaming) |
| - Transport plan computation (compute_P, compute_Q_and_layer_costs) |
| - Activation reconstruction (reconstruct_X) |
| |
| We add: |
| - Qwen3 thinking mode protection |
| - MiMo MTP head handling |
| - Falcon SSM component handling |
| - Sequential merge protection (MagMax + orthogonal projection) |
| - Progress reporting every 5 minutes |
| - Timeouts to prevent infinite hangs |
| |
| Findings: #01, #07, #24 |
| """ |
|
|
| import sys |
| import time |
| import hashlib |
| import torch |
| import numpy as np |
| from pathlib import Path |
| from typing import Optional |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
| from datasets import load_dataset |
|
|
| from .config import MergeConfig, ModelConfig, TARGET |
|
|
|
|
| |
| |
| |
|
|
class ProgressTracker:
    """Print a periodic heartbeat so long-running work is visibly not stuck.

    Durations are measured with ``time.monotonic()`` so that wall-clock
    adjustments (NTP corrections, manual clock changes) can neither produce
    negative elapsed times nor trigger spurious timeouts.

    Attributes:
        task_name: Label prefixed to every message.
        interval: Minimum number of seconds between two heartbeat lines.
        start_time: Monotonic-clock reading taken at construction. It is only
            meaningful relative to other ``time.monotonic()`` readings.
        last_report: Monotonic-clock reading of the most recent heartbeat.
        step: Number of ``tick()`` calls so far.
        total_steps: Expected number of ticks; 0 means "unknown".
    """

    def __init__(self, task_name: str, interval_seconds: int = 300):
        self.task_name = task_name
        self.interval = interval_seconds
        self.start_time = time.monotonic()
        self.last_report = self.start_time
        self.step = 0
        self.total_steps = 0
        # Flushed like every other message so the start line is not stuck in a
        # buffer while the first (possibly very long) step runs.
        print(f"\n[{task_name}] Started at {time.strftime('%H:%M:%S')}", flush=True)

    def set_total(self, total: int):
        """Declare how many ticks are expected (enables percentage and ETA)."""
        self.total_steps = total

    def tick(self, step_name: str = ""):
        """Count one step; print a heartbeat if ``interval`` seconds have passed.

        Args:
            step_name: Free-form description of the current step, appended to
                the heartbeat line.
        """
        self.step += 1
        now = time.monotonic()
        if now - self.last_report < self.interval:
            return

        elapsed = now - self.start_time
        if self.total_steps:
            pct = f"{self.step}/{self.total_steps} ({100*self.step/self.total_steps:.0f}%)"
            # Average seconds per step so far, projected over the steps left.
            # Clamped at zero in case more ticks arrive than were announced.
            remaining = max(self.total_steps - self.step, 0) * (elapsed / self.step)
            eta = f", ETA {remaining/60:.1f} min"
        else:
            pct = f"step {self.step}"
            eta = ""
        print(
            f"[{self.task_name}] HEARTBEAT — {pct}, elapsed {elapsed/60:.1f} min{eta} | {step_name}",
            flush=True,
        )
        self.last_report = now

    def done(self):
        """Print the total run time of the task."""
        elapsed = time.monotonic() - self.start_time
        print(
            f"[{self.task_name}] Completed in {elapsed/60:.1f} min ({elapsed:.0f}s)",
            flush=True,
        )

    def check_timeout(self, timeout_seconds: int = 3600):
        """Raise if the task has been running longer than ``timeout_seconds``.

        Raises:
            TimeoutError: When the elapsed time exceeds the limit.
        """
        elapsed = time.monotonic() - self.start_time
        if elapsed > timeout_seconds:
            raise TimeoutError(
                f"[{self.task_name}] TIMEOUT after {elapsed/60:.1f} min "
                f"(limit: {timeout_seconds/60:.0f} min). Something is wrong."
            )
|
|
|
|
def setup_tm_repo(cfg: MergeConfig):
    """Make the official T&M ``core`` package importable.

    Reads ``cfg.tm_repo_path`` and puts ``<repo>/core`` at the front of
    ``sys.path`` unless that entry is already present.

    Raises:
        FileNotFoundError: If ``<repo>/core`` does not exist.
    """
    repo_root = Path(cfg.tm_repo_path)
    core_dir = repo_root / "core"

    if not core_dir.exists():
        raise FileNotFoundError(
            f"Official T&M repo not found at {repo_root}\n"
            f"Please clone it:\n"
            f" git clone https://github.com/chenhangcuisg-code/"
            f"Cross-Architecture-Merging-for-Large-Language-Models.git"
        )

    core_entry = str(core_dir)
    if core_entry in sys.path:
        # Already registered; do not add a duplicate entry.
        return

    sys.path.insert(0, core_entry)
    print(f"[transport] Added T&M core to path: {core_dir}")
|
|
|
|
def _first_present(example, keys):
    """Return the value of the first key of *keys* found in *example*, else ""."""
    for key in keys:
        if key in example:
            return example[key]
    return ""


def _collect_tokenized(
    examples, tokenizer, max_length, limit, min_chars, text_keys, label,
    samples, raw_texts, tracker,
):
    """Tokenize up to *limit* sufficiently long texts taken from *examples*.

    Texts whose length is not greater than *min_chars* are skipped. Accepted
    items are appended in place to *samples* (tokenizer output) and
    *raw_texts* (the string that was tokenized). Returns the number of items
    appended by this call.
    """
    count = 0
    for example in examples:
        if count >= limit:
            break
        text = str(_first_present(example, text_keys))
        if len(text) <= min_chars:
            continue
        samples.append(
            tokenizer(
                text,
                truncation=True,
                max_length=max_length,
                return_tensors="pt",
            )
        )
        raw_texts.append(text)
        count += 1
        tracker.tick(f"{label} sample {count}")
        if count % 100 == 0:
            print(f" {label}: {count}/{limit} samples loaded...")
            sys.stdout.flush()
    return count


def load_calibration_data(cfg: MergeConfig, tokenizer: AutoTokenizer) -> tuple:
    """
    Load calibration data for activation extraction.

    Mix: up to 600 samples from ``cfg.calibration_dataset_pile`` (validation
    split, streamed), then as many samples from ``cfg.calibration_dataset_nm``
    (train split) as are needed to reach ``cfg.calibration_samples``.
    The Pile share is capped at ``cfg.calibration_samples`` so that a small
    request never returns more samples than asked for.
    Each sample is truncated to ``cfg.calibration_seq_len`` tokens.

    Loading is best-effort: if a dataset fails, a warning is printed and the
    samples collected so far are kept, so fewer samples than requested (even
    none) may be returned.

    Findings: #08

    Returns:
        ``(samples, raw_texts)``: the tokenizer outputs and the matching raw
        strings, in the same order.
    """
    tracker = ProgressTracker("calibration-data", interval_seconds=120)
    tracker.set_total(max(cfg.calibration_samples, 0))
    print(f"[transport] Loading calibration data ({cfg.calibration_samples} samples)...")

    samples = []
    raw_texts = []

    # NOTE(review): trust_remote_code=True lets the dataset repository run its
    # own loading script. The dataset ids come from cfg; confirm they are
    # trusted sources.
    pile_quota = min(600, cfg.calibration_samples)
    if pile_quota > 0:
        try:
            pile = load_dataset(
                cfg.calibration_dataset_pile,
                split="validation",
                streaming=True,
                trust_remote_code=True,
            )
            count = _collect_tokenized(
                pile, tokenizer, cfg.calibration_seq_len, pile_quota,
                100, ("text",), "Pile", samples, raw_texts, tracker,
            )
            print(f" Pile general: {count} samples")
        except Exception as e:  # deliberate best-effort: fall through to the next source
            print(f" WARNING: Pile failed: {e}")
            print(f" Falling back to neuralmagic only")

    remaining = cfg.calibration_samples - len(samples)
    if remaining > 0:
        try:
            nm = load_dataset(
                cfg.calibration_dataset_nm,
                split="train",
                trust_remote_code=True,
            )
            count = _collect_tokenized(
                nm, tokenizer, cfg.calibration_seq_len, remaining,
                50, ("text", "content"), "neuralmagic", samples, raw_texts, tracker,
            )
            print(f" neuralmagic: {count} samples")
        except Exception as e:  # deliberate best-effort: keep what we already have
            print(f" WARNING: neuralmagic failed: {e}")

    tracker.done()
    print(f"[transport] Total calibration samples: {len(samples)}")
    sys.stdout.flush()
    return samples, raw_texts
|
|
|
|
def retokenize_calibration(raw_texts: list, tokenizer: AutoTokenizer, cfg: MergeConfig) -> list:
    """
    Tokenize the calibration texts again with another tokenizer.

    Needed when the source model's vocabulary differs from the target's
    (for example Llama with a 128K vocab vs Qwen with a 152K vocab).

    Each text is truncated to ``cfg.calibration_seq_len`` tokens. The result
    has one tokenizer output per input text, in input order.
    """
    total = len(raw_texts)
    print(f"[transport] Re-tokenizing {total} samples for source model vocabulary...")
    sys.stdout.flush()

    encoded = []
    for done, text in enumerate(raw_texts, start=1):
        encoded.append(
            tokenizer(
                text,
                truncation=True,
                max_length=cfg.calibration_seq_len,
                return_tensors="pt",
            )
        )
        # Progress line every 500 texts.
        if done % 500 == 0:
            print(f" Re-tokenized {done}/{total} samples...")
            sys.stdout.flush()

    print(f"[transport] Re-tokenized {len(encoded)} samples for source model")
    sys.stdout.flush()
    return encoded
|
|
|
|
def extract_activations(
    model: AutoModelForCausalLM,
    calibration_data: list,
    device: str = "cuda",
) -> dict:
    """
    Extract intermediate activations from the hooked layers of a model.

    A forward hook is registered on every module that has a ``self_attn``
    attribute and on every module whose name ends with ``.mlp``. Each
    calibration sample is run through the model (in eval mode, without
    gradients); every hook stores its module's output averaged over dim 1
    (presumably the sequence axis, i.e. outputs are assumed to be
    ``[batch, seq, hidden]`` — TODO confirm) as a float CPU tensor.

    A sample whose forward pass raises is reported and skipped, and anything
    the hooks captured for it is discarded, so row ``k`` refers to the same
    sample in every layer. Hooks are always removed, even when a timeout or
    another error aborts the run. The model is left in eval mode.

    Args:
        model: Model to probe.
        calibration_data: List of mappings ``name -> tensor`` that are moved
            to ``device`` and passed to the model as keyword arguments.
        device: Device the inputs are moved to.

    Returns:
        Dict mapping layer name -> tensor of the per-sample captures
        concatenated along dim 0. Layers that captured nothing are omitted.

    Raises:
        TimeoutError: If extraction runs longer than 30 minutes.
    """
    tracker = ProgressTracker("extract-activations", interval_seconds=300)
    tracker.set_total(len(calibration_data))
    print(f"[transport] Extracting activations from {len(calibration_data)} samples...")
    sys.stdout.flush()

    captured = {}  # layer name -> list of per-sample tensors
    hooks = []

    def make_hook(layer_name):
        def hook_fn(module, inputs, output):
            # Some modules return a tuple; the first element is taken as the
            # activation.
            act = output[0] if isinstance(output, tuple) else output
            captured.setdefault(layer_name, []).append(
                act.detach().float().mean(dim=1).cpu()
            )
        return hook_fn

    for name, module in model.named_modules():
        if hasattr(module, "self_attn") or name.endswith(".mlp"):
            hooks.append(module.register_forward_hook(make_hook(name)))

    model.eval()
    try:
        with torch.no_grad():
            for i, tokens in enumerate(calibration_data):
                inputs = {k: v.to(device) for k, v in tokens.items()}
                # Remember how much each layer holds so a failed forward pass
                # can be rolled back completely.
                marks = {layer: len(items) for layer, items in captured.items()}
                try:
                    model(**inputs)
                except Exception as e:
                    print(f" WARNING: Sample {i} failed: {e}")
                    for layer, items in captured.items():
                        del items[marks.get(layer, 0):]
                    continue

                tracker.tick(f"sample {i+1}")

                if (i + 1) % 100 == 0:
                    print(f" Processed {i + 1}/{len(calibration_data)} samples")
                    sys.stdout.flush()

                tracker.check_timeout(timeout_seconds=1800)
    finally:
        # Never leave hooks on the caller's model, whatever happened above.
        for h in hooks:
            h.remove()

    # Layers emptied by a rollback are dropped (torch.cat rejects empty lists).
    activations = {
        layer: torch.cat(items, dim=0) for layer, items in captured.items() if items
    }
    layer_count = len(activations)

    print(f" Extracted {layer_count} layers, shapes: {next(iter(activations.values())).shape if activations else 'empty'}")
    tracker.done()
    sys.stdout.flush()

    return activations
|
|
|
|
def compute_transport_plans(
    source_activations: dict,
    target_activations: dict,
    cfg: MergeConfig,
) -> dict:
    """
    Compute transport plans between source and target activations.

    Tries to import the official T&M functions from ``hot_transport``
    (``corr_distance_matrix``, ``sinkhorn_uniform_streaming``, ``compute_P``,
    ``compute_Q_and_layer_costs``) and delegates to
    ``_compute_plans_official``. If the import fails, or the official code
    itself raises ``ImportError`` (e.g. a dependency it imports lazily is
    missing), ``_compute_plans_fallback`` is used instead. The two cases are
    logged with different messages so the log shows which one happened.

    Returns:
        Dict with at least 'P' (layer coupling), 'Q' (per-layer-pair
        coupling), 'source_layers' and 'target_layers'.
    """
    print("[transport] Computing transport plans...")
    sys.stdout.flush()

    # Keep the try body minimal: only the import itself decides whether the
    # official implementation is "available".
    try:
        from hot_transport import (
            corr_distance_matrix,
            sinkhorn_uniform_streaming,
            compute_P,
            compute_Q_and_layer_costs,
        )
    except ImportError:
        print("[transport] Official T&M code not available, using fallback")
        return _compute_plans_fallback(
            source_activations, target_activations, cfg
        )

    print("[transport] Using official T&M implementation")
    try:
        return _compute_plans_official(
            source_activations, target_activations, cfg,
            corr_distance_matrix, sinkhorn_uniform_streaming,
            compute_P, compute_Q_and_layer_costs,
        )
    except ImportError as e:
        # Same best-effort fallback as before, but reported for what it is.
        print(
            f"[transport] WARNING: official T&M code failed on a missing "
            f"dependency ({e}), using fallback"
        )
        sys.stdout.flush()
        return _compute_plans_fallback(
            source_activations, target_activations, cfg
        )
|
|
|
|
def _compute_plans_official(
    source_act, target_act, cfg,
    corr_distance_matrix, sinkhorn_uniform_streaming,
    compute_P, compute_Q_and_layer_costs,
) -> dict:
    """Build the transport plans with the callables of the official T&M code.

    ``compute_Q_and_layer_costs`` receives both activation dicts plus their
    sorted layer names and must return ``(Q_matrices, layer_costs)``;
    ``compute_P`` turns the layer costs into the layer coupling. ``cfg``,
    ``corr_distance_matrix`` and ``sinkhorn_uniform_streaming`` are accepted
    but not used here.
    """
    src_names = sorted(source_act)
    tgt_names = sorted(target_act)

    # Neuron-level couplings first; their per-pair costs feed the layer coupling.
    q_by_pair, costs = compute_Q_and_layer_costs(
        source_act, target_act,
        src_names, tgt_names,
    )

    plans = {"P": compute_P(costs)}
    plans["Q"] = q_by_pair
    plans["source_layers"] = src_names
    plans["target_layers"] = tgt_names
    return plans
|
|
|
|
def _layer_block_index(layer_name: str) -> Optional[int]:
    """Return the integer that follows the first ``layers`` component of a dotted name, or None."""
    parts = layer_name.split(".")
    for pos, part in enumerate(parts):
        if part == "layers":
            try:
                return int(parts[pos + 1])
            except (ValueError, IndexError):
                return None
    return None


def _correlation_and_cost(S: np.ndarray, T: np.ndarray) -> tuple:
    """Return ``(corr, 1 - corr)`` for the column-wise correlation of two sample matrices.

    Columns are standardized (mean 0, std 1, with a 1e-8 guard) and the
    cross-correlation is ``S_norm.T @ T_norm / num_rows``.
    """
    S_norm = (S - S.mean(0)) / (S.std(0) + 1e-8)
    T_norm = (T - T.mean(0)) / (T.std(0) + 1e-8)
    corr = S_norm.T @ T_norm / S.shape[0]
    return corr, 1.0 - corr


def _perm_cache_filename(source_layers: list, n_source: int) -> str:
    """Build the permutation-cache file name from the layer count and the first three layer names."""
    # NOTE(review): the name depends only on layer names and count, not on the
    # models or activations, so two different model pairs with identical
    # layer naming share one cache file — confirm this is intended.
    src_name = "_".join(source_layers[:3])
    digest = int(hashlib.md5(src_name.encode()).hexdigest()[:8], 16)
    return f"perms_{n_source}_{digest}.npz"


def _plans_same_depth(source_act, target_act, source_layers, target_layers, tracker) -> dict:
    """Match layers 1:1 (same layer count) and derive per-pair neuron couplings.

    Identity couplings are shared: all pairs of the same width reference one
    ``eye(d) / d`` array instead of holding a private d x d copy each, so
    callers must treat Q entries as read-only.
    """
    n_layers = len(source_layers)
    pairs = list(zip(source_layers, target_layers))

    print("[transport] Same layer count -- using direct 1:1 layer matching")
    sys.stdout.flush()
    Q_matrices = {}
    permutations = {}
    P = np.eye(n_layers) / n_layers
    tracker.set_total(n_layers)

    identity_by_dim = {}

    def uniform_identity(dim):
        if dim not in identity_by_dim:
            identity_by_dim[dim] = np.eye(dim) / dim
        return identity_by_dim[dim]

    def result():
        return {
            "P": P,
            "Q": Q_matrices,
            "permutations": permutations,
            "source_layers": source_layers,
            "target_layers": target_layers,
        }

    # Probe the first layer pair: a high mean diagonal correlation means the
    # neurons already correspond index-by-index.
    S0 = source_act[source_layers[0]].numpy()
    T0 = target_act[target_layers[0]].numpy()
    diag_corr = 0.0
    if S0.shape[1] == T0.shape[1]:
        S0_norm = (S0 - S0.mean(0)) / (S0.std(0) + 1e-8)
        T0_norm = (T0 - T0.mean(0)) / (T0.std(0) + 1e-8)
        diag_corr = np.mean(np.sum(S0_norm * T0_norm, axis=0) / S0.shape[0])
    neurons_aligned = diag_corr > 0.3

    cache_name = _perm_cache_filename(source_layers, n_layers)
    primary_cache = Path("td_fuse_checkpoints") / "perm_cache" / cache_name

    if neurons_aligned:
        print(f"[transport] Neurons ARE aligned (diag_corr={diag_corr:.3f}) — identity Q (fast)")
        print("[transport] This should take under 1 minute...")
    else:
        print(f"[transport] Neurons NOT aligned (diag_corr={diag_corr:.3f}) — computing permutations via Sinkhorn")

        # Cached permutations are only relevant when permutations are needed;
        # aligned models must never pick up a stale cache.
        cache_file = primary_cache
        alt_cache = Path("perm_cache") / cache_name
        if not cache_file.exists() and alt_cache.exists():
            cache_file = alt_cache
        if cache_file.exists():
            print(f"[transport] LOADING CACHED permutations from {cache_file}")
            # NOTE(review): allow_pickle=True can execute code embedded in a
            # crafted file. The arrays written below are plain integer arrays,
            # so allow_pickle=False should suffice — confirm that no older
            # cache files rely on pickling before switching.
            with np.load(str(cache_file), allow_pickle=True) as cached:
                for sl, tl in pairs:
                    key = f"{sl}__{tl}"
                    if key in cached:
                        permutations[(sl, tl)] = cached[key]
                    Q_matrices[(sl, tl)] = uniform_identity(S0.shape[1])
                    tracker.tick(f"{sl} -> {tl}")
            print(f"[transport] Loaded {len(permutations)} cached permutations (skipped Sinkhorn!)")
            tracker.done()
            sys.stdout.flush()
            return result()

        print("[transport] No cache found — computing fresh (will cache for next time)...")
        sys.stdout.flush()

    # One permutation per transformer block: later layers of the same block
    # reuse the permutation computed for the first one.
    block_perms = {}

    for i, (sl, tl) in enumerate(pairs):
        S = source_act[sl].numpy()
        T = target_act[tl].numpy()

        if S.shape[1] != T.shape[1]:
            print(f" Layer {i}: dim mismatch ({S.shape[1]} vs {T.shape[1]}), using Sinkhorn...")
            _, cost = _correlation_and_cost(S, T)
            Q_matrices[(sl, tl)] = _sinkhorn(cost, reg=0.1, max_iter=50)
        elif neurons_aligned:
            Q_matrices[(sl, tl)] = uniform_identity(S.shape[1])
        else:
            block_idx = _layer_block_index(tl)
            if block_idx is not None and block_idx in block_perms:
                permutations[(sl, tl)] = block_perms[block_idx]
                Q_matrices[(sl, tl)] = uniform_identity(S.shape[1])
            else:
                corr, cost = _correlation_and_cost(S, T)
                Q_soft = _sinkhorn(cost, reg=0.1, max_iter=30)

                # Hard assignment: best target column for every source row.
                perm = np.argmax(Q_soft, axis=1)

                # Too many collisions -> rebuild greedily from the correlations.
                # NOTE(review): with 90-100% unique entries the argmax result is
                # kept although it still contains duplicates, i.e. it is not a
                # true permutation — confirm this tolerance is intended.
                if np.unique(perm).size < len(perm) * 0.9:
                    perm = _greedy_permutation(corr)

                permutations[(sl, tl)] = perm
                Q_matrices[(sl, tl)] = Q_soft
                if block_idx is not None:
                    block_perms[block_idx] = perm

        tracker.tick(f"{sl} -> {tl}")

        if (i + 1) % 10 == 0 or i == 0:
            print(f" Matched layer {i + 1}/{n_layers}: {sl} -> {tl}")
            sys.stdout.flush()

        tracker.check_timeout(timeout_seconds=10800)

    if permutations:
        print(f"[transport] Computed {len(permutations)} neuron permutations")
        try:
            primary_cache.parent.mkdir(parents=True, exist_ok=True)
            save_dict = {f"{sl}__{tl}": perm for (sl, tl), perm in permutations.items()}
            np.savez_compressed(str(primary_cache), **save_dict)
            print(f"[transport] Cached permutations to {primary_cache} ({primary_cache.stat().st_size // 1024} KB)")
        except Exception as e:  # caching is best-effort; the plans are still valid
            print(f"[transport] WARNING: Could not cache permutations ({e})")
    print(f"[transport] Direct matching complete: {n_layers} layer pairs")
    tracker.done()
    sys.stdout.flush()
    return result()


def _plans_cross_arch(source_act, target_act, source_layers, target_layers, tracker) -> dict:
    """Sparse OT for differing layer counts: couple each target layer with its 3 closest source layers."""
    n_source = len(source_layers)
    n_target = len(target_layers)

    print(f"[transport] Cross-architecture -- using sparse OT (top-3 per target)")
    print(f"[transport] Estimated time: 5-15 minutes")
    sys.stdout.flush()

    # Step 1: layer-level cost = 1 - cosine similarity of the mean activations
    # (truncated to the shorter width when widths differ).
    print("[transport] Step 1/3: Computing layer-level similarities...")
    sys.stdout.flush()
    layer_costs = np.zeros((n_source, n_target))
    tracker.set_total(n_source * n_target + n_target * 3)
    # Means are loop-invariant: compute each once, not once per pair.
    source_means = [source_act[sl].mean(0).numpy() for sl in source_layers]
    target_means = [target_act[tl].mean(0).numpy() for tl in target_layers]
    for i, s_mean in enumerate(source_means):
        for j, t_mean in enumerate(target_means):
            min_dim = min(len(s_mean), len(t_mean))
            s = s_mean[:min_dim]
            t = t_mean[:min_dim]
            sim = np.dot(s, t) / (np.linalg.norm(s) * np.linalg.norm(t) + 1e-8)
            layer_costs[i, j] = 1.0 - sim
            tracker.tick(f"layer sim {i},{j}")

        tracker.check_timeout(timeout_seconds=10800)

    print(f"[transport] Step 1/3 done: {n_source}x{n_target} similarities computed")
    sys.stdout.flush()

    # Step 2: neuron-level couplings for the three cheapest sources per target.
    print("[transport] Step 2/3: Computing neuron-level transport (top-3 per target)...")
    sys.stdout.flush()
    Q_matrices = {}

    # NOTE(review): cache files are keyed by layer names only, so results from
    # a different model pair with the same layer names would be reused.
    q_cache_dir = Path("td_fuse_checkpoints") / "q_cache_crossarch"
    q_cache_dir.mkdir(parents=True, exist_ok=True)

    for j, tl in enumerate(target_layers):
        top3 = np.argsort(layer_costs[:, j])[:3]
        for i in top3:
            sl = source_layers[i]
            cache_key = f"{sl}__{tl}".replace("/", "_").replace(".", "_")
            cache_path = q_cache_dir / f"{cache_key}.npy"

            if cache_path.exists():
                Q_matrices[(sl, tl)] = np.load(str(cache_path))
            else:
                S = source_act[sl].numpy()
                T = target_act[tl].numpy()
                # Only the leading min_dim columns of both sides are compared.
                min_dim = min(S.shape[1], T.shape[1])
                _, cost = _correlation_and_cost(S[:, :min_dim], T[:, :min_dim])
                Q_matrices[(sl, tl)] = _sinkhorn(cost, reg=0.1, max_iter=50)
                np.save(str(cache_path), Q_matrices[(sl, tl)])
            tracker.tick(f"Q({sl},{tl})")

        if (j + 1) % 5 == 0 or j == 0:
            print(f" Target layer {j + 1}/{n_target}: matched to top-3 sources")
            sys.stdout.flush()

        tracker.check_timeout(timeout_seconds=10800)

    print(f"[transport] Step 2/3 done: {len(Q_matrices)} Q matrices computed")
    sys.stdout.flush()

    # Step 3: layer coupling from the layer cost matrix.
    print("[transport] Step 3/3: Computing layer coupling P matrix...")
    sys.stdout.flush()
    P = _sinkhorn(layer_costs, reg=0.1, max_iter=50)

    print(f"[transport] Sparse OT complete: {len(Q_matrices)} layer pairs computed")
    tracker.done()
    sys.stdout.flush()
    return {
        "P": P,
        "Q": Q_matrices,
        "permutations": {},
        "source_layers": source_layers,
        "target_layers": target_layers,
    }


def _compute_plans_fallback(
    source_act: dict,
    target_act: dict,
    cfg: MergeConfig,
) -> dict:
    """
    Fallback transport plan computation when official code isn't available.

    Routing:
    - Same number of hooked layers: direct 1:1 layer matching
      (``P = I / n``). Neuron couplings are identity when the first layer
      pair is already well correlated index-by-index; otherwise a neuron
      permutation is computed per block via Sinkhorn (or loaded from the
      on-disk permutation cache, which is written after a fresh computation).
    - Different layer count: sparse OT, i.e. neuron couplings only for the
      top-3 source layers of every target layer, plus a Sinkhorn layer
      coupling. Per-pair couplings are cached on disk.

    Args:
        source_act: Layer name -> activation tensor of the source model.
        target_act: Layer name -> activation tensor of the target model.
            Tensors are converted with ``.numpy()``, so they are presumably
            CPU tensors of shape ``[samples, width]`` — TODO confirm.
        cfg: Merge configuration (not used by the fallback).

    Returns:
        Dict with 'P', 'Q' (keyed by ``(source_layer, target_layer)``),
        'permutations' (same keys, may be empty), 'source_layers' and
        'target_layers'.

    Raises:
        ValueError: If neither model provided any activations.
        TimeoutError: If a phase runs longer than three hours.
    """
    tracker = ProgressTracker("transport-plans", interval_seconds=300)

    source_layers = sorted(source_act.keys())
    target_layers = sorted(target_act.keys())
    n_source = len(source_layers)
    n_target = len(target_layers)

    print(f"[transport] Source layers: {n_source}, Target layers: {n_target}")
    sys.stdout.flush()

    if n_source == 0 and n_target == 0:
        # Previously surfaced as a bare IndexError on source_layers[0].
        raise ValueError(
            "[transport] Cannot compute transport plans: no activations were "
            "provided for either model."
        )

    if n_source == n_target:
        return _plans_same_depth(
            source_act, target_act, source_layers, target_layers, tracker
        )
    return _plans_cross_arch(
        source_act, target_act, source_layers, target_layers, tracker
    )
|
|
|
|
def _sinkhorn(
    cost_matrix: np.ndarray,
    reg: float = 0.05,
    max_iter: int = 100,
) -> np.ndarray:
    """
    Basic Sinkhorn-Knopp scaling of the Gibbs kernel ``exp(-C / reg)``.

    Alternately rescales rows and columns for ``max_iter`` iterations and
    returns ``diag(u) @ K @ diag(v)``, computed by broadcasting so that no
    dense diagonal matrices are built.

    This is the FALLBACK. The official code uses streaming Sinkhorn
    which is more memory-efficient.

    NOTE(review): both scalings target marginals of all ones (not ``1/n`` and
    ``1/m``), so for a non-square cost matrix only the column sums end up
    close to 1 — confirm downstream code expects this scale.

    Args:
        cost_matrix: 2-D cost matrix of shape ``(n, m)``.
        reg: Entropic regularisation strength; smaller is sharper.
        max_iter: Number of row/column scaling rounds.

    Returns:
        Array of shape ``(n, m)`` with the scaled kernel.
    """
    n, m = cost_matrix.shape
    K = np.exp(-cost_matrix / reg)

    # Only relevant for max_iter == 0; otherwise overwritten in round one.
    u = np.ones(n) / n
    v = np.ones(m) / m

    for _ in range(max_iter):
        # 1e-10 guards against division by zero when the kernel underflows.
        u = 1.0 / (K @ v + 1e-10)
        v = 1.0 / (K.T @ u + 1e-10)

    # Element-wise form of diag(u) @ K @ diag(v): O(n*m) instead of two dense
    # matmuls with n x n and m x m diagonal matrices.
    return u[:, None] * K * v[None, :]
|
|
|
|
def _greedy_permutation(corr_matrix: np.ndarray) -> np.ndarray:
    """
    Build a one-to-one source -> target assignment greedily.

    Source rows are visited from the strongest best-match to the weakest;
    each takes its highest-correlated target column that is still free.
    Rows that could not get a column are filled from the indices in
    ``range(n)`` that were never handed out.

    Returns:
        int64 array ``perm`` of length ``n`` (number of rows) where
        ``perm[src]`` is the target index assigned to ``src``.
    """
    n = corr_matrix.shape[0]
    assignment = np.full(n, -1, dtype=np.int64)
    used = set()

    # Rows with the most confident match choose first.
    priority = np.argsort(-np.max(corr_matrix, axis=1))

    for src in priority:
        ranked = np.argsort(-corr_matrix[src])
        choice = next((tgt for tgt in ranked if tgt not in used), None)
        if choice is not None:
            assignment[src] = choice
            used.add(choice)

    # Hand the never-used indices to whatever is still unassigned.
    leftovers = set(range(n)) - used
    for src in np.flatnonzero(assignment == -1):
        assignment[src] = leftovers.pop()

    return assignment
|
|
|
|
def _apply_permutation(source_w: torch.Tensor, perm: np.ndarray, key: str) -> torch.Tensor:
    """
    Reorder one axis of a source weight with a neuron permutation.

    The axis is chosen from the parameter name ``key``:
    - name contains ``o_proj`` / ``down_proj``: rows (dim 0) are reordered,
      provided ``len(perm)`` equals the row count;
    - name contains ``q_proj`` / ``k_proj`` / ``v_proj`` / ``gate_proj`` /
      ``up_proj``: columns (dim 1) are reordered, provided ``len(perm)``
      equals the column count;
    - any other 2-D weight: columns if the length matches, else rows if that
      length matches;
    - 1-D weights are reordered directly when the length matches.

    Whenever the length does not fit (or the tensor is not 1-D/2-D) the input
    tensor is returned unchanged.
    """
    index = torch.from_numpy(perm).long()
    size = len(index)
    rank = source_w.dim()

    if rank == 1:
        return source_w[index] if size == source_w.shape[0] else source_w

    if rank != 2:
        return source_w

    rows, cols = source_w.shape
    is_output_proj = any(tag in key for tag in ("o_proj", "down_proj"))
    is_input_proj = any(
        tag in key for tag in ("q_proj", "k_proj", "v_proj", "gate_proj", "up_proj")
    )

    if is_output_proj:
        # Output projections: only the row axis is ever touched.
        return source_w[index, :] if size == rows else source_w

    if is_input_proj:
        # Input projections: only the column axis is ever touched.
        return source_w[:, index] if size == cols else source_w

    # Unknown 2-D weight: prefer columns, then rows.
    if size == cols:
        return source_w[:, index]
    if size == rows:
        return source_w[index, :]
    return source_w
|
|
|
|
def _layer_index_from_key(name: str) -> Optional[int]:
    """Return the integer after the first ``layers`` component (that has a successor) of a dotted name, or None."""
    parts = name.split(".")
    for pos, part in enumerate(parts):
        if part == "layers" and pos + 1 < len(parts):
            try:
                return int(parts[pos + 1])
            except ValueError:
                return None
    return None


def fuse_weights(
    source_state: dict,
    target_model: AutoModelForCausalLM,
    transport_plans: dict,
    source_config: ModelConfig,
    cfg: MergeConfig,
    target_activations: Optional[dict] = None,
) -> AutoModelForCausalLM:
    """
    Fuse source model weights into the target model by linear blending.

    For every target parameter that is not skipped and has a source match:
    1. If shapes differ, the source weight is zero-padded / truncated to the
       target shape (``_align_dimensions``).
    2. If the transport plans carry a neuron permutation for the parameter's
       layer index, it is applied to the source weight.
    3. Blend: ``W_final = alpha * W_source + (1 - alpha) * W_target`` with
       ``alpha = source_config.merge_alpha``.
    4. For embedding tables, rows of ``cfg.think_token_ids`` are restored to
       the target's original values when ``cfg.freeze_think_tokens`` is set.

    The layer coupling ``P`` is not consulted here, and ``Q`` is only
    forwarded to ``_align_dimensions``.

    Args:
        source_state: Source model state dict (tensors are moved to the
            target parameter's device one at a time).
        target_model: Target model; updated in place via ``load_state_dict``.
        transport_plans: Output of ``compute_transport_plans``; must contain
            'Q', may contain 'permutations' keyed by
            ``(source_layer, target_layer)``.
        source_config: Source model config (name, merge_alpha and the fields
            read by ``_should_skip`` / ``_map_key``).
        cfg: Merge configuration.
        target_activations: Accepted for interface compatibility; unused.

    Returns:
        The same ``target_model`` object with fused weights loaded.

    Raises:
        TimeoutError: If fusing runs longer than 20 minutes.
    """
    tracker = ProgressTracker("fuse-weights", interval_seconds=300)
    print(f"\n[transport] Fusing {source_config.name} -> target")
    alpha = source_config.merge_alpha

    try:
        from generate_hot_residual import fuse_attention_only_from_hot_dir  # noqa: F401
        # The probe only detects the official module; fusion below is always
        # done by this wrapper, so the log must not claim otherwise.
        print("[transport] Official fusion implementation found (not used; blending with wrapper code)")
    except ImportError:
        pass

    target_state = target_model.state_dict()
    Q = transport_plans["Q"]
    permutations = transport_plans.get("permutations", {})

    # Collapse the per-(source, target) permutations to one per layer index;
    # when several entries share an index the last one wins.
    layer_perms = {}
    for (_, tl), perm in permutations.items():
        layer_idx = _layer_index_from_key(tl)
        if layer_idx is not None:
            layer_perms[layer_idx] = perm

    if permutations:
        print(f"[transport] Will apply neuron permutations to {len(layer_perms)} layers before blending")
    else:
        print("[transport] No neuron permutations needed (neurons already aligned)")

    fused_count = 0
    skipped_count = 0
    permuted_count = 0
    tracker.set_total(len(target_state))

    # Iterate over a snapshot of the keys: entries are replaced while looping.
    for target_key in list(target_state):
        tracker.tick(target_key)

        if _should_skip(target_key, source_config):
            skipped_count += 1
            continue

        source_key = _map_key(target_key, source_config)
        if source_key is None or source_key not in source_state:
            skipped_count += 1
            if skipped_count <= 5:
                print(f" [skip] No source match for: {target_key} (mapped to: {source_key})")
                sys.stdout.flush()
            continue

        target_w = target_state[target_key]
        source_w = source_state[source_key]

        if target_w.shape != source_w.shape:
            source_w = _align_dimensions(source_w, target_w.shape, Q, target_key)
            if source_w is None:
                skipped_count += 1
                continue

        if layer_perms:
            lidx = _layer_index_from_key(target_key)
            if lidx is not None and lidx in layer_perms:
                permuted = _apply_permutation(source_w, layer_perms[lidx], target_key)
                # _apply_permutation hands back the same tensor when the
                # permutation does not fit; count only real reorderings.
                if permuted is not source_w:
                    permuted_count += 1
                source_w = permuted

        fused_w = alpha * source_w.to(target_w.device) + (1 - alpha) * target_w
        target_state[target_key] = fused_w
        fused_count += 1

        if cfg.freeze_think_tokens and "embed_tokens" in target_key:
            for token_id in cfg.think_token_ids:
                if token_id < fused_w.shape[0]:
                    # target_w is still the unmodified original embedding, so
                    # there is no need to rebuild the model's state dict.
                    fused_w[token_id] = target_w[token_id]
                    print(f"[transport] Protected think token {token_id}")

        if fused_count % 50 == 0:
            print(f" Fused {fused_count} params so far (skipped {skipped_count})...")
            sys.stdout.flush()

        tracker.check_timeout(timeout_seconds=1200)

    missing, unexpected = target_model.load_state_dict(target_state, strict=False)
    if missing:
        print(f"[transport] NOTE: {len(missing)} missing keys (likely quantized vision params — safe to ignore)")
    if unexpected:
        print(f"[transport] NOTE: {len(unexpected)} unexpected keys (safe to ignore)")
    perm_msg = f", permuted {permuted_count}" if permuted_count else ""
    print(f"[transport] Fused {fused_count} params, skipped {skipped_count}{perm_msg}")
    tracker.done()
    sys.stdout.flush()

    return target_model
|
|
|
|
def _should_skip(key: str, source_config: ModelConfig) -> bool:
    """Tell whether the parameter named ``key`` must be left out of the merge.

    A parameter is skipped when any of these holds:
    - its name starts with ``visual``, ``merger``, ``model.visual`` or
      ``model.merger``;
    - ``source_config.skip_embeddings`` is set and the name contains
      ``embed_tokens`` or ``lm_head``;
    - ``drop_mtp_heads`` is requested and the name contains ``mtp_head``;
    - ``drop_mamba_state_params`` is requested and the name contains
      ``mamba``, ``A_log``, ``dt_proj`` or ``.D``;
    - ``drop_qkv_bias`` is requested and the name is a ``.bias`` of
      ``q_proj``, ``k_proj`` or ``v_proj``.
    """
    if key.startswith(("visual", "merger", "model.visual", "model.merger")):
        return True

    if source_config.skip_embeddings and ("embed_tokens" in key or "lm_head" in key):
        return True

    handling = source_config.special_handling

    if "mtp_head" in key and "drop_mtp_heads" in handling:
        return True

    if "drop_mamba_state_params" in handling and any(
        marker in key for marker in ("mamba", "A_log", "dt_proj", ".D")
    ):
        return True

    return (
        "drop_qkv_bias" in handling
        and ".bias" in key
        and any(proj in key for proj in ("q_proj", "k_proj", "v_proj"))
    )
|
|
|
|
def _strip_vl_prefix(key: str) -> str:
    """
    Remove every ``language_model.`` segment from a parameter name.

    Qwen3-VL nests its language parameters under ``model.language_model.*``
    while the source models use ``model.*`` directly.

    Example:
        target: model.language_model.layers.0.self_attn.q_proj.weight
        source: model.layers.0.self_attn.q_proj.weight
    """
    # str.replace leaves names without the segment unchanged.
    return key.replace("language_model.", "")
|
|
|
|
def _map_key(target_key: str, source_config: ModelConfig) -> Optional[str]:
    """Translate a target parameter name into the source model's name.

    The ``language_model.`` prefix is stripped first. Then:
    - 36-layer ``transformer`` sources use the name as is;
    - with ``layer_mapping_32_to_36``, the layer number of
      ``model.layers.<n>...`` names is rescaled by 32/36 (truncated);
    - ``transformer+mtp`` sources have no counterpart for ``mtp_head`` names;
    - ``hybrid_ssm`` sources only map ``self_attn``, ``mlp`` and
      ``layer_norm`` names.

    Returns:
        The source parameter name, or None when the parameter has no
        counterpart in the source model.
    """
    source_key = _strip_vl_prefix(target_key)
    arch = source_config.architecture

    if arch == "transformer" and source_config.layers == 36:
        return source_key

    wants_remap = "layer_mapping_32_to_36" in source_config.special_handling
    if wants_remap and "model.layers." in source_key:
        pieces = source_key.split(".")
        try:
            target_layer = int(pieces[2])
        except (IndexError, ValueError):
            # Third component is not a layer number: leave the name alone.
            return source_key
        pieces[2] = str(int(target_layer * 32 / 36))
        return ".".join(pieces)

    if arch == "transformer+mtp":
        return None if "mtp_head" in source_key else source_key

    if arch == "hybrid_ssm":
        shared = any(tag in source_key for tag in ("self_attn", "mlp", "layer_norm"))
        return source_key if shared else None

    return source_key
|
|
|
|
def _align_dimensions(
    source_w: torch.Tensor,
    target_shape: tuple,
    Q_matrices: dict,
    key: str,
) -> Optional[torch.Tensor]:
    """
    Fit a source weight into ``target_shape`` by zero-padding / truncation.

    A tensor that already has the target shape is returned as is. For 1-D
    and 2-D tensors of matching rank, a zero tensor of the target shape (same
    dtype as the source) is created and the overlapping leading region of the
    source is copied into it. ``Q_matrices`` and ``key`` are accepted but not
    used.

    Returns:
        The aligned tensor, or None when the ranks differ or the rank is
        neither 1 nor 2.
    """
    if source_w.shape == target_shape:
        return source_w

    rank = len(source_w.shape)
    if rank != len(target_shape) or rank not in (1, 2):
        return None

    # Leading region that exists in both shapes, one slice per axis.
    overlap = tuple(
        slice(0, min(src_len, tgt_len))
        for src_len, tgt_len in zip(source_w.shape, target_shape)
    )
    padded = torch.zeros(target_shape, dtype=source_w.dtype)
    padded[overlap] = source_w[overlap]
    return padded
|
|