#!/usr/bin/env python3
"""Model and layer helpers for fuse_layers."""

import os
from typing import Dict, List, Optional, Tuple

import torch

try:
    from tqdm import tqdm
except Exception:  # pragma: no cover - optional dependency
    tqdm = None


def _tqdm_enabled() -> bool:
    value = os.environ.get("DISABLE_TQDM", os.environ.get("TQDM_DISABLE", "0"))
    return value.strip().lower() not in {"1", "true", "yes", "on"}


def get_dtype(dtype: str):
    if dtype == "auto":
        return None
    if dtype == "float16":
        return torch.float16
    if dtype == "bfloat16":
        return torch.bfloat16
    return torch.float32


def resolve_attr(root: object, path: str) -> Optional[object]:
    cur = root
    for part in path.split("."):
        if not hasattr(cur, part):
            return None
        cur = getattr(cur, part)
    return cur


def resolve_attr_with_parent(root: object, path: str) -> Tuple[object, str, object]:
    parts = path.split(".")
    cur = root
    for part in parts[:-1]:
        if not hasattr(cur, part):
            raise ValueError(f"'{path}' not found on model")
        cur = getattr(cur, part)
    name = parts[-1]
    if not hasattr(cur, name):
        raise ValueError(f"'{path}' not found on model")
    return cur, name, getattr(cur, name)


def find_layer_container(model, layer_path: Optional[str]) -> Tuple[object, str, object]:
    if layer_path:
        parent, name, container = resolve_attr_with_parent(model, layer_path)
        return parent, name, container
    candidate_paths = [
        "model.layers",  # LLaMA, Mistral, Qwen2, Gemma
        "model.decoder.layers",  # OPT
        "transformer.h",  # GPT-2, GPT-J, Bloom, Falcon
        "transformer.blocks",  # MPT
        "gpt_neox.layers",  # GPT-NeoX
        "layers",  # fallback
    ]
    for path in candidate_paths:
        candidate = resolve_attr(model, path)
        if candidate is None:
            continue
        try:
            list(candidate)
        except TypeError:
            continue
        parent, name, container = resolve_attr_with_parent(model, path)
        return parent, name, container
    raise ValueError(
        "Could not locate transformer layers. Pass --layer_path explicitly."
    )

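# Illustrative usage sketch for the attribute-path helpers above. It assumes a
# Hugging Face-style causal LM; the checkpoint path is only a placeholder and
# is not shipped with this module.
def _example_find_layers() -> None:  # pragma: no cover - illustrative only
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("path/to/llama-style-checkpoint")
    parent, name, layers = find_layer_container(model, layer_path=None)
    # For LLaMA-family models the first candidate path "model.layers" matches,
    # so `name` is "layers" and `layers` is the ModuleList of decoder blocks.
    print(type(parent).__name__, name, len(layers))
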
def find_attention_module(layer: torch.nn.Module) -> torch.nn.Module:
    if hasattr(layer, "self_attn"):
        return getattr(layer, "self_attn")
    if hasattr(layer, "attn"):
        return getattr(layer, "attn")
    if hasattr(layer, "attention"):
        return getattr(layer, "attention")
    for _, module in layer.named_modules():
        if all(
            hasattr(module, attr) for attr in ("q_proj", "k_proj", "v_proj", "o_proj")
        ):
            return module
    raise ValueError(
        "Could not find attention module with q_proj/k_proj/v_proj/o_proj"
    )


def find_mlp_module(layer: torch.nn.Module) -> torch.nn.Module:
    if hasattr(layer, "mlp"):
        return getattr(layer, "mlp")
    for attr in ("feed_forward", "feedforward", "ffn", "ff"):
        if hasattr(layer, attr):
            return getattr(layer, attr)
    for _, module in layer.named_modules():
        if all(hasattr(module, attr) for attr in ("gate_proj", "up_proj", "down_proj")):
            return module
        if all(hasattr(module, attr) for attr in ("fc1", "fc2")):
            return module
        if all(
            hasattr(module, attr) for attr in ("dense_h_to_4h", "dense_4h_to_h")
        ):
            return module
        if all(hasattr(module, attr) for attr in ("w1", "w2")):
            return module
    raise ValueError("Could not find MLP/FFN module on layer")


def get_head_info(
    attn: torch.nn.Module, hidden_size: int, config
) -> Tuple[int, int, int]:
    num_heads = getattr(attn, "num_heads", None)
    if num_heads is None:
        num_heads = getattr(attn, "num_attention_heads", None)
    if num_heads is None and config is not None:
        num_heads = getattr(
            config,
            "num_attention_heads",
            getattr(config, "num_heads", getattr(config, "n_head", None)),
        )
    num_key_value_heads = getattr(attn, "num_key_value_heads", None)
    if num_key_value_heads is None:
        num_key_value_heads = getattr(attn, "num_kv_heads", None)
    if num_key_value_heads is None and config is not None:
        num_key_value_heads = getattr(
            config,
            "num_key_value_heads",
            getattr(config, "num_kv_heads", getattr(config, "n_head_kv", None)),
        )
    head_dim = getattr(attn, "head_dim", None)
    if head_dim is None and config is not None:
        head_dim = getattr(config, "head_dim", None)
    if num_heads is None:
        if hasattr(attn, "q_proj"):
            q_out = attn.q_proj.weight.shape[0]
            if head_dim is not None:
                num_heads = q_out // head_dim
            elif num_key_value_heads is not None and hasattr(attn, "k_proj"):
                k_out = attn.k_proj.weight.shape[0]
                head_dim = k_out // max(int(num_key_value_heads), 1)
                num_heads = q_out // head_dim
    if num_heads is None:
        raise ValueError(
            "Attention module missing num_heads/num_attention_heads; "
            "pass --layer_path or add config overrides."
        )
    if head_dim is None:
        head_dim = hidden_size // int(num_heads)
    # Infer the KV head count from k_proj before falling back to assuming full
    # multi-head attention (num_key_value_heads == num_heads).
    if num_key_value_heads is None and hasattr(attn, "k_proj"):
        k_out = attn.k_proj.weight.shape[0]
        num_key_value_heads = k_out // int(head_dim)
    if num_key_value_heads is None:
        num_key_value_heads = num_heads
    return int(num_heads), int(num_key_value_heads), int(head_dim)


def cosine_cost_matrix(
    a: torch.Tensor, b: torch.Tensor, eps: float = 1e-8
) -> torch.Tensor:
    a_norm = a / (a.norm(dim=1, keepdim=True) + eps)
    b_norm = b / (b.norm(dim=1, keepdim=True) + eps)
    sim = a_norm @ b_norm.t()
    return 1.0 - sim


def hungarian(cost: torch.Tensor) -> List[int]:
    # Kuhn-Munkres for square cost matrix (minimization).
    n = cost.size(0)
    u = [0.0] * (n + 1)
    v = [0.0] * (n + 1)
    p = [0] * (n + 1)
    way = [0] * (n + 1)
    for i in range(1, n + 1):
        p[0] = i
        j0 = 0
        minv = [float("inf")] * (n + 1)
        used = [False] * (n + 1)
        while True:
            used[j0] = True
            i0 = p[j0]
            delta = float("inf")
            j1 = 0
            for j in range(1, n + 1):
                if used[j]:
                    continue
                cur = cost[i0 - 1, j - 1].item() - u[i0] - v[j]
                if cur < minv[j]:
                    minv[j] = cur
                    way[j] = j0
                if minv[j] < delta:
                    delta = minv[j]
                    j1 = j
            for j in range(0, n + 1):
                if used[j]:
                    u[p[j]] += delta
                    v[j] -= delta
                else:
                    minv[j] -= delta
            j0 = j1
            if p[j0] == 0:
                break
        while True:
            j1 = way[j0]
            p[j0] = p[j1]
            j0 = j1
            if j0 == 0:
                break
    assignment = [-1] * n
    for j in range(1, n + 1):
        if p[j] > 0:
            assignment[p[j] - 1] = j - 1
    return assignment

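# Illustrative sketch of `hungarian`: for each row of a square cost matrix it
# returns the column assigned under a minimum-total-cost one-to-one matching.
# The contrived 3x3 matrix below was checked by enumerating all six
# permutations; the optimum pairs row 0 -> col 1, row 1 -> col 0, row 2 -> col 2
# with total cost 1 + 2 + 2 = 5.
def _example_hungarian() -> None:  # pragma: no cover - illustrative only
    cost = torch.tensor(
        [
            [4.0, 1.0, 3.0],
            [2.0, 0.0, 5.0],
            [3.0, 2.0, 2.0],
        ]
    )
    assert hungarian(cost) == [1, 0, 2]
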
def compute_head_means(
    model,
    attn_i: torch.nn.Module,
    attn_j: torch.nn.Module,
    dataloader,
    device: str,
    hidden_size: int,
) -> Tuple[torch.Tensor, torch.Tensor, int, int, int]:
    num_heads_i, num_kv_i, head_dim_i = get_head_info(
        attn_i, hidden_size, model.config
    )
    num_heads_j, num_kv_j, head_dim_j = get_head_info(
        attn_j, hidden_size, model.config
    )
    if num_heads_i != num_heads_j or head_dim_i != head_dim_j:
        raise ValueError("Head counts or head_dim differ between layers; cannot align")
    sums_i = torch.zeros(num_heads_i, head_dim_i, device="cpu")
    sums_j = torch.zeros(num_heads_j, head_dim_j, device="cpu")
    count_i = [0]
    count_j = [0]

    def make_hook(
        sums: torch.Tensor, count_ref: List[int], num_heads: int, head_dim: int
    ):
        def hook(_module, inputs, _output):
            hidden = inputs[0].detach()
            if hidden.dim() != 3:
                return
            batch, seq, width = hidden.shape
            if width != num_heads * head_dim:
                return
            reshaped = hidden.view(batch, seq, num_heads, head_dim)
            sums.add_(reshaped.sum(dim=(0, 1)).float().cpu())
            count_ref[0] += batch * seq

        return hook

    hook_i = attn_i.o_proj.register_forward_hook(
        make_hook(sums_i, count_i, num_heads_i, head_dim_i)
    )
    hook_j = attn_j.o_proj.register_forward_hook(
        make_hook(sums_j, count_j, num_heads_j, head_dim_j)
    )
    model.eval()
    iterator = dataloader
    if tqdm is not None and _tqdm_enabled():
        iterator = tqdm(dataloader, desc="Head stats", unit="batch")
    with torch.no_grad():
        for batch in iterator:
            input_ids = batch[0].to(device)
            _ = model(input_ids=input_ids)
    hook_i.remove()
    hook_j.remove()
    if count_i[0] == 0 or count_j[0] == 0:
        raise RuntimeError("Failed to capture head outputs; check attention modules.")
    mean_i = sums_i / count_i[0]
    mean_j = sums_j / count_j[0]
    return mean_i, mean_j, num_heads_i, num_kv_i, head_dim_i


def build_head_permutation(
    mean_i: torch.Tensor,
    mean_j: torch.Tensor,
    num_heads: int,
    num_kv_heads: int,
    eps: float,
) -> List[int]:
    group_size = num_heads // num_kv_heads
    if group_size * num_kv_heads != num_heads:
        raise ValueError("num_heads must be divisible by num_key_value_heads")
    perm = list(range(num_heads))
    for g in range(num_kv_heads):
        start = g * group_size
        end = start + group_size
        cost = cosine_cost_matrix(mean_i[start:end], mean_j[start:end], eps=eps)
        assignment = hungarian(cost)
        for local_idx, match in enumerate(assignment):
            perm[start + local_idx] = start + match
    return perm

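# Illustrative sketch of `build_head_permutation`: heads are only reordered
# inside their own KV group, so a head in group 0 can never be matched to a slot
# in group 1. The head means below are contrived: with 4 heads and 2 KV heads,
# group 0's heads swap (0 <-> 1) while group 1 keeps its original order.
def _example_build_head_permutation() -> None:  # pragma: no cover - illustrative only
    mean_i = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, -1.0]])
    mean_j = torch.tensor([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [1.0, -1.0]])
    perm = build_head_permutation(
        mean_i, mean_j, num_heads=4, num_kv_heads=2, eps=1e-8
    )
    assert perm == [1, 0, 2, 3]
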
def permute_attention_heads(
    attn: torch.nn.Module,
    perm: List[int],
    num_heads: int,
    num_kv_heads: int,
    head_dim: int,
) -> None:
    hidden_size = num_heads * head_dim

    def permute_out_proj_weight(weight: torch.Tensor) -> torch.Tensor:
        out_features, in_features = weight.shape
        if in_features != hidden_size:
            raise ValueError(
                "o_proj in_features ({}) != num_heads*head_dim ({})".format(
                    in_features, hidden_size
                )
            )
        reshaped = weight.view(out_features, num_heads, head_dim)
        reshaped = reshaped[:, perm, :]
        return reshaped.reshape(out_features, in_features)

    def permute_proj_weight(weight: torch.Tensor) -> torch.Tensor:
        out_features, in_features = weight.shape
        if out_features != hidden_size:
            raise ValueError(
                "proj out_features ({}) != num_heads*head_dim ({})".format(
                    out_features, hidden_size
                )
            )
        reshaped = weight.view(num_heads, head_dim, in_features)
        reshaped = reshaped[perm, :, :]
        return reshaped.reshape(out_features, in_features)

    def permute_proj_bias(bias: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        if bias is None:
            return None
        reshaped = bias.view(num_heads, head_dim)
        reshaped = reshaped[perm, :]
        return reshaped.reshape(num_heads * head_dim)

    with torch.no_grad():
        attn.q_proj.weight.copy_(permute_proj_weight(attn.q_proj.weight))
        if attn.q_proj.bias is not None:
            attn.q_proj.bias.copy_(permute_proj_bias(attn.q_proj.bias))
        if num_kv_heads == num_heads:
            attn.k_proj.weight.copy_(permute_proj_weight(attn.k_proj.weight))
            if attn.k_proj.bias is not None:
                attn.k_proj.bias.copy_(permute_proj_bias(attn.k_proj.bias))
            attn.v_proj.weight.copy_(permute_proj_weight(attn.v_proj.weight))
            if attn.v_proj.bias is not None:
                attn.v_proj.bias.copy_(permute_proj_bias(attn.v_proj.bias))
        attn.o_proj.weight.copy_(permute_out_proj_weight(attn.o_proj.weight))


def compute_fisher(
    model,
    layer_a: torch.nn.Module,
    layer_b: torch.nn.Module,
    dataloader,
    fisher_mode: str,
    device: str,
) -> Tuple[List[Dict[str, object]], int, List[Dict[str, int]]]:
    for param in model.parameters():
        param.requires_grad_(False)
    for layer in (layer_a, layer_b):
        for param in layer.parameters():
            param.requires_grad_(True)
    fisher_sums: List[Dict[str, object]] = []
    param_numels: List[Dict[str, int]] = []
    for layer in (layer_a, layer_b):
        layer_sums: Dict[str, object] = {}
        layer_numels: Dict[str, int] = {}
        for name, param in layer.named_parameters():
            if not param.requires_grad:
                continue
            if fisher_mode == "param":
                layer_sums[name] = torch.zeros_like(
                    param, dtype=torch.float32, device="cpu"
                )
            else:
                layer_sums[name] = 0.0
            layer_numels[name] = param.numel()
        fisher_sums.append(layer_sums)
        param_numels.append(layer_numels)
    num_batches = 0
    model.eval()
    iterator = dataloader
    if tqdm is not None and _tqdm_enabled():
        iterator = tqdm(dataloader, desc="Fisher", unit="batch")
    for batch in iterator:
        input_ids = batch[0].to(device)
        outputs = model(input_ids=input_ids, labels=input_ids)
        loss = outputs.loss
        loss.backward()
        for layer_idx, layer in enumerate((layer_a, layer_b)):
            layer_sums = fisher_sums[layer_idx]
            for name, param in layer.named_parameters():
                if not param.requires_grad or param.grad is None:
                    continue
                grad_sq = param.grad.detach().float().pow(2)
                if fisher_mode == "param":
                    layer_sums[name] += grad_sq.cpu()
                else:
                    layer_sums[name] += float(grad_sq.sum().item())
        model.zero_grad(set_to_none=True)
        num_batches += 1
    if num_batches == 0:
        raise RuntimeError("No batches processed; check dataset or text inputs.")
    return fisher_sums, num_batches, param_numels

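# Illustrative sketch of the Fisher-weighted average that `merge_layers` below
# applies: W = (F_a * W_a + F_b * W_b) / (F_a + F_b), with a plain 0.5/0.5
# average as the fallback when both Fisher estimates are ~0. The toy tensors
# here are contrived and mirror the per-parameter ("param") branch: where layer
# A's squared gradients dominate, the merged value stays close to W_a.
def _example_fisher_weighted_average() -> None:  # pragma: no cover - illustrative only
    w_a = torch.tensor([1.0, 1.0])
    w_b = torch.tensor([0.0, 0.0])
    fisher_a = torch.tensor([9.0, 1.0])  # A is "important" for the first element
    fisher_b = torch.tensor([1.0, 9.0])  # B is "important" for the second element
    merged = (fisher_a * w_a + fisher_b * w_b) / (fisher_a + fisher_b)
    assert torch.allclose(merged, torch.tensor([0.9, 0.1]))
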
def merge_layers(
    layer_a: torch.nn.Module,
    layer_b: torch.nn.Module,
    fisher_a: Dict[str, object],
    fisher_b: Dict[str, object],
    num_batches: int,
    numels_a: Dict[str, int],
    numels_b: Dict[str, int],
    fisher_mode: str,
    eps: float,
) -> int:
    merged = 0
    params_b = {name: param for name, param in layer_b.named_parameters()}
    with torch.no_grad():
        for name, param_a in layer_a.named_parameters():
            param_b = params_b.get(name)
            if param_b is None or param_b.shape != param_a.shape:
                continue
            if fisher_mode == "param":
                fa = fisher_a[name] / num_batches
                fb = fisher_b[name] / num_batches
                # Fisher tensors are accumulated on CPU to save VRAM; move to
                # the parameter device for the actual merge.
                if isinstance(fa, torch.Tensor) and fa.device != param_a.device:
                    fa = fa.to(param_a.device)
                if isinstance(fb, torch.Tensor) and fb.device != param_a.device:
                    fb = fb.to(param_a.device)
                denom = fa + fb
                denom_mean = float(denom.mean().item())
                if denom_mean <= eps:
                    merged_param = 0.5 * (param_a.float() + param_b.float())
                else:
                    merged_param = (fa * param_a.float() + fb * param_b.float()) / (
                        denom + eps
                    )
            else:
                fa = fisher_a[name] / (num_batches * numels_a[name])
                fb = fisher_b[name] / (num_batches * numels_b[name])
                denom = fa + fb
                if denom <= eps:
                    merged_param = 0.5 * (param_a.float() + param_b.float())
                else:
                    merged_param = (
                        fa * param_a.float() + fb * param_b.float()
                    ) / (denom + eps)
            param_a.copy_(merged_param.to(dtype=param_a.dtype))
            merged += 1
    return merged


def merge_layers_with_gates(
    layer_a: torch.nn.Module,
    layer_b: torch.nn.Module,
    gates: Dict[str, torch.Tensor],
) -> int:
    """Merge layer_b into layer_a using precomputed gates.

    Each gate is a lambda in [0, 1] that mixes parameters as:

        W = lambda * W_a + (1 - lambda) * W_b

    Gate tensors may be scalars (per-tensor gating) or full tensors matching
    the parameter shape (per-parameter gating).
    """
    merged = 0
    params_b = {name: param for name, param in layer_b.named_parameters()}
    with torch.no_grad():
        for name, param_a in layer_a.named_parameters():
            gate = gates.get(name)
            if gate is None:
                continue
            param_b = params_b.get(name)
            if param_b is None or param_b.shape != param_a.shape:
                continue
            lam = gate
            if not isinstance(lam, torch.Tensor):
                lam = torch.tensor(lam)
            if lam.device != param_a.device:
                lam = lam.to(param_a.device)
            merged_param = lam * param_a.float() + (1.0 - lam) * param_b.float()
            param_a.copy_(merged_param.to(dtype=param_a.dtype))
            merged += 1
    return merged


def drop_layer(container: object, index: int) -> object:
    if isinstance(container, torch.nn.ModuleList):
        return torch.nn.ModuleList(
            [layer for idx, layer in enumerate(container) if idx != index]
        )
    if isinstance(container, list):
        del container[index]
        return container
    raise TypeError("Layer container must be ModuleList or list")


def decrement_config(config) -> None:
    for attr in ("num_hidden_layers", "n_layer", "num_layers"):
        if hasattr(config, attr):
            value = getattr(config, attr)
            if isinstance(value, int) and value > 0:
                setattr(config, attr, value - 1)
    normalize_config(config)


def normalize_config(config) -> None:
    num_hidden_layers = getattr(config, "num_hidden_layers", None)
    layer_types = getattr(config, "layer_types", None)
    if (
        isinstance(num_hidden_layers, int)
        and num_hidden_layers >= 0
        and isinstance(layer_types, (list, tuple))
        and len(layer_types) != num_hidden_layers
    ):
        config.layer_types = list(layer_types[:num_hidden_layers])


def find_colon_modules(module: torch.nn.Module) -> List[str]:
    found: List[str] = []
    for name, child in module._modules.items():
        if ":" in name:
            found.append(name)
        if isinstance(child, torch.nn.Module):
            for sub in find_colon_modules(child):
                found.append(f"{name}.{sub}")
    return found


def get_norm_pair(
    layer: torch.nn.Module,
) -> Tuple[
    Optional[torch.nn.Module],
    Optional[torch.nn.Module],
    Tuple[Optional[str], Optional[str]],
]:
    candidates = [
        ("input_layernorm", "post_attention_layernorm"),
        ("ln_1", "ln_2"),
        ("norm1", "norm2"),
        ("norm_1", "norm_2"),
        ("layer_norm_1", "layer_norm_2"),
        ("self_attn_layer_norm", "final_layer_norm"),
    ]
    for n1, n2 in candidates:
        if hasattr(layer, n1) and hasattr(layer, n2):
            return getattr(layer, n1), getattr(layer, n2), (n1, n2)
    return None, None, (None, None)

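# Illustrative sketch of the layer-removal bookkeeping above: dropping one block
# from a ModuleList and shrinking the config to match. The SimpleNamespace below
# only stands in for a real transformers config object.
def _example_drop_layer() -> None:  # pragma: no cover - illustrative only
    from types import SimpleNamespace

    blocks = torch.nn.ModuleList(torch.nn.Linear(2, 2) for _ in range(4))
    blocks = drop_layer(blocks, index=2)
    assert len(blocks) == 3

    config = SimpleNamespace(num_hidden_layers=4, layer_types=["attention"] * 4)
    decrement_config(config)
    assert config.num_hidden_layers == 3
    assert len(config.layer_types) == 3
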
def clone_state_dict(module: torch.nn.Module) -> Dict[str, torch.Tensor]:
    return {k: v.detach().clone() for k, v in module.state_dict().items()}


def apply_norm_policy(
    layer: torch.nn.Module,
    norm_policy: str,
    norm1_state: Optional[Dict[str, torch.Tensor]],
    norm2_state: Optional[Dict[str, torch.Tensor]],
    norm_names: Tuple[Optional[str], Optional[str]],
) -> None:
    norm1, norm2, _ = get_norm_pair(layer)
    # "copy_n1_n2" restores both norms; "copy_n1" and "hybrid" restore only the first.
    if (
        norm_policy in {"copy_n1", "copy_n1_n2", "hybrid"}
        and norm1_state is not None
        and norm1 is not None
    ):
        norm1.load_state_dict(norm1_state)
    if norm_policy == "copy_n1_n2" and norm2_state is not None and norm2 is not None:
        norm2.load_state_dict(norm2_state)
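
# Illustrative sketch of `merge_layers_with_gates` on two toy Linear modules: a
# scalar gate of 0.75 keeps 75% of layer_a's weight and mixes in 25% of
# layer_b's, while the ungated bias is left untouched in layer_a.
def _example_merge_with_gates() -> None:  # pragma: no cover - illustrative only
    layer_a = torch.nn.Linear(4, 4)
    layer_b = torch.nn.Linear(4, 4)
    expected = 0.75 * layer_a.weight.detach() + 0.25 * layer_b.weight.detach()
    merged = merge_layers_with_gates(layer_a, layer_b, {"weight": torch.tensor(0.75)})
    assert merged == 1
    assert torch.allclose(layer_a.weight.detach(), expected)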