"""Model and layer helpers for fuse_layers."""

import os
from typing import Dict, List, Optional, Tuple

import torch

try:
    from tqdm import tqdm
except Exception:
    tqdm = None


def _tqdm_enabled() -> bool:
    value = os.environ.get("DISABLE_TQDM", os.environ.get("TQDM_DISABLE", "0"))
    return value.strip().lower() not in {"1", "true", "yes", "on"}


def get_dtype(dtype: str):
    if dtype == "auto":
        return None
    if dtype == "float16":
        return torch.float16
    if dtype == "bfloat16":
        return torch.bfloat16
    return torch.float32


def resolve_attr(root: object, path: str) -> Optional[object]:
    cur = root
    for part in path.split("."):
        if not hasattr(cur, part):
            return None
        cur = getattr(cur, part)
    return cur


def resolve_attr_with_parent(root: object, path: str) -> Tuple[object, str, object]:
    parts = path.split(".")
    cur = root
    for part in parts[:-1]:
        if not hasattr(cur, part):
            raise ValueError(f"'{path}' not found on model")
        cur = getattr(cur, part)
    name = parts[-1]
    if not hasattr(cur, name):
        raise ValueError(f"'{path}' not found on model")
    return cur, name, getattr(cur, name)


def find_layer_container(
    model, layer_path: Optional[str]
) -> Tuple[object, str, object]:
    if layer_path:
        return resolve_attr_with_parent(model, layer_path)

    candidate_paths = [
        "model.layers",
        "model.decoder.layers",
        "transformer.h",
        "transformer.blocks",
        "gpt_neox.layers",
        "layers",
    ]
    for path in candidate_paths:
        candidate = resolve_attr(model, path)
        if candidate is None:
            continue
        try:
            list(candidate)
        except TypeError:
            continue
        return resolve_attr_with_parent(model, path)

    raise ValueError(
        "Could not locate transformer layers. Pass --layer_path explicitly."
    )
|
|
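# Illustrative usage sketch (the checkpoint name is a placeholder, not part of
# this project):
#
#   from transformers import AutoModelForCausalLM
#   model = AutoModelForCausalLM.from_pretrained("some/causal-lm")
#   parent, name, layers = find_layer_container(model, layer_path=None)
#   # For Llama-style models this resolves "model.layers", and `layers` is the
#   # torch.nn.ModuleList of decoder blocks; pass layer_path to override.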
|
|
def find_attention_module(layer: torch.nn.Module) -> torch.nn.Module:
    for attr in ("self_attn", "attn", "attention"):
        if hasattr(layer, attr):
            return getattr(layer, attr)
    for _, module in layer.named_modules():
        if all(
            hasattr(module, attr) for attr in ("q_proj", "k_proj", "v_proj", "o_proj")
        ):
            return module
    raise ValueError("Could not find attention module with q_proj/k_proj/v_proj/o_proj")


def find_mlp_module(layer: torch.nn.Module) -> torch.nn.Module:
    for attr in ("mlp", "feed_forward", "feedforward", "ffn", "ff"):
        if hasattr(layer, attr):
            return getattr(layer, attr)
    for _, module in layer.named_modules():
        if all(hasattr(module, attr) for attr in ("gate_proj", "up_proj", "down_proj")):
            return module
        if all(hasattr(module, attr) for attr in ("fc1", "fc2")):
            return module
        if all(hasattr(module, attr) for attr in ("dense_h_to_4h", "dense_4h_to_h")):
            return module
        if all(hasattr(module, attr) for attr in ("w1", "w2")):
            return module
    raise ValueError("Could not find MLP/FFN module on layer")


def get_head_info(
    attn: torch.nn.Module, hidden_size: int, config
) -> Tuple[int, int, int]:
    num_heads = getattr(attn, "num_heads", None)
    if num_heads is None:
        num_heads = getattr(attn, "num_attention_heads", None)
    if num_heads is None and config is not None:
        num_heads = getattr(
            config,
            "num_attention_heads",
            getattr(config, "num_heads", getattr(config, "n_head", None)),
        )

    num_key_value_heads = getattr(attn, "num_key_value_heads", None)
    if num_key_value_heads is None:
        num_key_value_heads = getattr(attn, "num_kv_heads", None)
    if num_key_value_heads is None and config is not None:
        num_key_value_heads = getattr(
            config,
            "num_key_value_heads",
            getattr(config, "num_kv_heads", getattr(config, "n_head_kv", None)),
        )

    head_dim = getattr(attn, "head_dim", None)
    if head_dim is None and config is not None:
        head_dim = getattr(config, "head_dim", None)

    if num_heads is None and hasattr(attn, "q_proj"):
        q_out = attn.q_proj.weight.shape[0]
        if head_dim is not None:
            num_heads = q_out // head_dim
        elif num_key_value_heads is not None and hasattr(attn, "k_proj"):
            k_out = attn.k_proj.weight.shape[0]
            head_dim = k_out // max(int(num_key_value_heads), 1)
            num_heads = q_out // head_dim
    if num_heads is None:
        raise ValueError(
            "Attention module missing num_heads/num_attention_heads; "
            "pass --layer_path or add config overrides."
        )

    if head_dim is None:
        head_dim = hidden_size // int(num_heads)

    # Derive the KV head count from k_proj before falling back to the MHA
    # default (num_key_value_heads == num_heads).
    if num_key_value_heads is None and hasattr(attn, "k_proj"):
        k_out = attn.k_proj.weight.shape[0]
        num_key_value_heads = k_out // int(head_dim)
    if num_key_value_heads is None:
        num_key_value_heads = num_heads

    return int(num_heads), int(num_key_value_heads), int(head_dim)
|
|
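# Worked example (typical Llama-7B-style sizes, shown for illustration): with
# hidden_size=4096 and 32 attention heads, get_head_info returns (32, 32, 128)
# for plain multi-head attention; a GQA variant with 8 KV heads would yield
# (32, 8, 128), provided the module or config exposes one of the attribute
# names probed above.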
|
|
def cosine_cost_matrix(
    a: torch.Tensor, b: torch.Tensor, eps: float = 1e-8
) -> torch.Tensor:
    a_norm = a / (a.norm(dim=1, keepdim=True) + eps)
    b_norm = b / (b.norm(dim=1, keepdim=True) + eps)
    sim = a_norm @ b_norm.t()
    return 1.0 - sim


def hungarian(cost: torch.Tensor) -> List[int]:
    """Minimum-cost assignment on a square cost matrix (O(n^3) Hungarian
    algorithm with potentials); returns the matched column for each row."""
    n = cost.size(0)
    u = [0.0] * (n + 1)
    v = [0.0] * (n + 1)
    p = [0] * (n + 1)
    way = [0] * (n + 1)

    for i in range(1, n + 1):
        p[0] = i
        j0 = 0
        minv = [float("inf")] * (n + 1)
        used = [False] * (n + 1)
        while True:
            used[j0] = True
            i0 = p[j0]
            delta = float("inf")
            j1 = 0
            for j in range(1, n + 1):
                if used[j]:
                    continue
                cur = cost[i0 - 1, j - 1].item() - u[i0] - v[j]
                if cur < minv[j]:
                    minv[j] = cur
                    way[j] = j0
                if minv[j] < delta:
                    delta = minv[j]
                    j1 = j
            for j in range(0, n + 1):
                if used[j]:
                    u[p[j]] += delta
                    v[j] -= delta
                else:
                    minv[j] -= delta
            j0 = j1
            if p[j0] == 0:
                break
        while True:
            j1 = way[j0]
            p[j0] = p[j1]
            j0 = j1
            if j0 == 0:
                break

    assignment = [-1] * n
    for j in range(1, n + 1):
        if p[j] > 0:
            assignment[p[j] - 1] = j - 1
    return assignment
|
|
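# Example (illustrative): hungarian returns, for each row of a square cost
# matrix, the column of its partner in the minimum-total-cost matching:
#
#   cost = torch.tensor([[0.9, 0.1, 0.5],
#                        [0.2, 0.8, 0.6],
#                        [0.7, 0.4, 0.0]])
#   hungarian(cost)  # -> [1, 0, 2], total cost 0.1 + 0.2 + 0.0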
|
|
def compute_head_means(
    model,
    attn_i: torch.nn.Module,
    attn_j: torch.nn.Module,
    dataloader,
    device: str,
    hidden_size: int,
) -> Tuple[torch.Tensor, torch.Tensor, int, int, int]:
    num_heads_i, num_kv_i, head_dim_i = get_head_info(attn_i, hidden_size, model.config)
    num_heads_j, num_kv_j, head_dim_j = get_head_info(attn_j, hidden_size, model.config)
    if num_heads_i != num_heads_j or head_dim_i != head_dim_j or num_kv_i != num_kv_j:
        raise ValueError(
            "Head counts, KV head counts, or head_dim differ between layers; "
            "cannot align"
        )

    sums_i = torch.zeros(num_heads_i, head_dim_i, device="cpu")
    sums_j = torch.zeros(num_heads_j, head_dim_j, device="cpu")
    count_i = [0]
    count_j = [0]

    def make_hook(
        sums: torch.Tensor, count_ref: List[int], num_heads: int, head_dim: int
    ):
        def hook(_module, inputs, _output):
            hidden = inputs[0].detach()
            if hidden.dim() != 3:
                return
            batch, seq, width = hidden.shape
            if width != num_heads * head_dim:
                return
            reshaped = hidden.view(batch, seq, num_heads, head_dim)
            sums.add_(reshaped.sum(dim=(0, 1)).float().cpu())
            count_ref[0] += batch * seq

        return hook

    hook_i = attn_i.o_proj.register_forward_hook(
        make_hook(sums_i, count_i, num_heads_i, head_dim_i)
    )
    hook_j = attn_j.o_proj.register_forward_hook(
        make_hook(sums_j, count_j, num_heads_j, head_dim_j)
    )

    model.eval()
    iterator = dataloader
    if tqdm is not None and _tqdm_enabled():
        iterator = tqdm(dataloader, desc="Head stats", unit="batch")
    with torch.no_grad():
        for batch in iterator:
            input_ids = batch[0].to(device)
            _ = model(input_ids=input_ids)

    hook_i.remove()
    hook_j.remove()

    if count_i[0] == 0 or count_j[0] == 0:
        raise RuntimeError("Failed to capture head outputs; check attention modules.")

    mean_i = sums_i / count_i[0]
    mean_j = sums_j / count_j[0]
    return mean_i, mean_j, num_heads_i, num_kv_i, head_dim_i


def build_head_permutation(
    mean_i: torch.Tensor,
    mean_j: torch.Tensor,
    num_heads: int,
    num_kv_heads: int,
    eps: float,
) -> List[int]:
    group_size = num_heads // num_kv_heads
    if group_size * num_kv_heads != num_heads:
        raise ValueError("num_heads must be divisible by num_key_value_heads")

    # MHA (one KV head per query head): permute_attention_heads permutes K/V
    # together with Q, so heads may be matched globally rather than within
    # size-1 groups (which would always produce the identity permutation).
    if num_kv_heads == num_heads:
        cost = cosine_cost_matrix(mean_i, mean_j, eps=eps)
        return hungarian(cost)

    # GQA: match only within each KV group so head-to-KV wiring is preserved.
    perm = list(range(num_heads))
    for g in range(num_kv_heads):
        start = g * group_size
        end = start + group_size
        cost = cosine_cost_matrix(mean_i[start:end], mean_j[start:end], eps=eps)
        assignment = hungarian(cost)
        for local_idx, match in enumerate(assignment):
            perm[start + local_idx] = start + match
    return perm
|
|
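# Sketch of the group structure: with num_heads=4 and num_kv_heads=2 the
# matching runs independently inside each KV group of size 2, so a valid
# result can swap heads 0<->1 or 2<->3 but never move a head across groups,
# e.g. perm == [1, 0, 2, 3]; this keeps the GQA head-to-KV-group wiring intact.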
|
|
def permute_attention_heads(
    attn: torch.nn.Module,
    perm: List[int],
    num_heads: int,
    num_kv_heads: int,
    head_dim: int,
) -> None:
    hidden_size = num_heads * head_dim

    def permute_out_proj_weight(weight: torch.Tensor) -> torch.Tensor:
        out_features, in_features = weight.shape
        if in_features != hidden_size:
            raise ValueError(
                "o_proj in_features ({}) != num_heads*head_dim ({})".format(
                    in_features, hidden_size
                )
            )
        reshaped = weight.view(out_features, num_heads, head_dim)
        reshaped = reshaped[:, perm, :]
        return reshaped.reshape(out_features, in_features)

    def permute_proj_weight(weight: torch.Tensor) -> torch.Tensor:
        out_features, in_features = weight.shape
        if out_features != hidden_size:
            raise ValueError(
                "proj out_features ({}) != num_heads*head_dim ({})".format(
                    out_features, hidden_size
                )
            )
        reshaped = weight.view(num_heads, head_dim, in_features)
        reshaped = reshaped[perm, :, :]
        return reshaped.reshape(out_features, in_features)

    def permute_proj_bias(bias: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        if bias is None:
            return None
        reshaped = bias.view(num_heads, head_dim)
        reshaped = reshaped[perm, :]
        return reshaped.reshape(num_heads * head_dim)

    with torch.no_grad():
        attn.q_proj.weight.copy_(permute_proj_weight(attn.q_proj.weight))
        if attn.q_proj.bias is not None:
            attn.q_proj.bias.copy_(permute_proj_bias(attn.q_proj.bias))

        # K/V are permuted only when every query head owns its KV head (MHA);
        # under GQA the permutation stays inside KV groups, so K/V are left
        # untouched.
        if num_kv_heads == num_heads:
            attn.k_proj.weight.copy_(permute_proj_weight(attn.k_proj.weight))
            if attn.k_proj.bias is not None:
                attn.k_proj.bias.copy_(permute_proj_bias(attn.k_proj.bias))
            attn.v_proj.weight.copy_(permute_proj_weight(attn.v_proj.weight))
            if attn.v_proj.bias is not None:
                attn.v_proj.bias.copy_(permute_proj_bias(attn.v_proj.bias))

        attn.o_proj.weight.copy_(permute_out_proj_weight(attn.o_proj.weight))
|
|
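# End-to-end alignment sketch (`layer_i`/`layer_j` are placeholder names for
# the two layers being fused):
#
#   attn_i = find_attention_module(layer_i)
#   attn_j = find_attention_module(layer_j)
#   mean_i, mean_j, n_heads, n_kv, head_dim = compute_head_means(
#       model, attn_i, attn_j, dataloader, device, hidden_size
#   )
#   perm = build_head_permutation(mean_i, mean_j, n_heads, n_kv, eps=1e-8)
#   permute_attention_heads(attn_j, perm, n_heads, n_kv, head_dim)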
|
|
def compute_fisher(
    model,
    layer_a: torch.nn.Module,
    layer_b: torch.nn.Module,
    dataloader,
    fisher_mode: str,
    device: str,
) -> Tuple[List[Dict[str, object]], int, List[Dict[str, int]]]:
    for param in model.parameters():
        param.requires_grad_(False)
    for layer in (layer_a, layer_b):
        for param in layer.parameters():
            param.requires_grad_(True)

    fisher_sums: List[Dict[str, object]] = []
    param_numels: List[Dict[str, int]] = []
    for layer in (layer_a, layer_b):
        layer_sums: Dict[str, object] = {}
        layer_numels: Dict[str, int] = {}
        for name, param in layer.named_parameters():
            if not param.requires_grad:
                continue
            if fisher_mode == "param":
                layer_sums[name] = torch.zeros_like(
                    param, dtype=torch.float32, device="cpu"
                )
            else:
                layer_sums[name] = 0.0
            layer_numels[name] = param.numel()
        fisher_sums.append(layer_sums)
        param_numels.append(layer_numels)

    num_batches = 0
    model.eval()
    iterator = dataloader
    if tqdm is not None and _tqdm_enabled():
        iterator = tqdm(dataloader, desc="Fisher", unit="batch")
    for batch in iterator:
        input_ids = batch[0].to(device)
        outputs = model(input_ids=input_ids, labels=input_ids)
        loss = outputs.loss
        loss.backward()
        for layer_idx, layer in enumerate((layer_a, layer_b)):
            layer_sums = fisher_sums[layer_idx]
            for name, param in layer.named_parameters():
                if not param.requires_grad or param.grad is None:
                    continue
                grad_sq = param.grad.detach().float().pow(2)
                if fisher_mode == "param":
                    layer_sums[name] += grad_sq.cpu()
                else:
                    layer_sums[name] += float(grad_sq.sum().item())
        model.zero_grad(set_to_none=True)
        num_batches += 1

    if num_batches == 0:
        raise RuntimeError("No batches processed; check dataset or text inputs.")

    return fisher_sums, num_batches, param_numels
|
|
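# Usage sketch (the device string is an assumption; any torch device works):
#
#   fisher_sums, num_batches, numels = compute_fisher(
#       model, layer_a, layer_b, dataloader, fisher_mode="param", device="cuda"
#   )
#   # fisher_sums[0]/fisher_sums[1] map parameter names of layer_a/layer_b to
#   # summed squared gradients; divide by num_batches for the Fisher estimate.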
|
|
def merge_layers(
    layer_a: torch.nn.Module,
    layer_b: torch.nn.Module,
    fisher_a: Dict[str, object],
    fisher_b: Dict[str, object],
    num_batches: int,
    numels_a: Dict[str, int],
    numels_b: Dict[str, int],
    fisher_mode: str,
    eps: float,
) -> int:
    merged = 0
    params_b = dict(layer_b.named_parameters())
    with torch.no_grad():
        for name, param_a in layer_a.named_parameters():
            param_b = params_b.get(name)
            if param_b is None or param_b.shape != param_a.shape:
                continue
            if fisher_mode == "param":
                fa = fisher_a[name] / num_batches
                fb = fisher_b[name] / num_batches
                if isinstance(fa, torch.Tensor) and fa.device != param_a.device:
                    fa = fa.to(param_a.device)
                if isinstance(fb, torch.Tensor) and fb.device != param_a.device:
                    fb = fb.to(param_a.device)
                denom = fa + fb
                denom_mean = float(denom.mean().item())
                if denom_mean <= eps:
                    merged_param = 0.5 * (param_a.float() + param_b.float())
                else:
                    merged_param = (fa * param_a.float() + fb * param_b.float()) / (
                        denom + eps
                    )
            else:
                fa = fisher_a[name] / (num_batches * numels_a[name])
                fb = fisher_b[name] / (num_batches * numels_b[name])
                denom = fa + fb
                if denom <= eps:
                    merged_param = 0.5 * (param_a.float() + param_b.float())
                else:
                    merged_param = (
                        fa * param_a.float() + fb * param_b.float()
                    ) / (denom + eps)
            param_a.copy_(merged_param.to(dtype=param_a.dtype))
            merged += 1
    return merged
|
|
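# The merge above is a diagonal-Fisher weighted average: for each matching
# parameter tensor,
#
#   W = (F_a * W_a + F_b * W_b) / (F_a + F_b + eps)
#
# with a plain 0.5/0.5 average as fallback when the combined Fisher mass is
# effectively zero (denominator mean <= eps).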
|
|
def merge_layers_with_gates(
    layer_a: torch.nn.Module,
    layer_b: torch.nn.Module,
    gates: Dict[str, torch.Tensor],
) -> int:
    """Merge layer_b into layer_a using precomputed gates.

    Each gate is a lambda in [0, 1] that mixes parameters as:
        W = lambda * W_a + (1 - lambda) * W_b

    Gate tensors may be scalars (per-tensor gating) or full tensors matching
    the parameter shape (per-parameter gating).
    """
    merged = 0
    params_b = dict(layer_b.named_parameters())
    with torch.no_grad():
        for name, param_a in layer_a.named_parameters():
            gate = gates.get(name)
            if gate is None:
                continue
            param_b = params_b.get(name)
            if param_b is None or param_b.shape != param_a.shape:
                continue
            lam = gate
            if not isinstance(lam, torch.Tensor):
                lam = torch.tensor(lam)
            if lam.device != param_a.device:
                lam = lam.to(param_a.device)
            merged_param = lam * param_a.float() + (1.0 - lam) * param_b.float()
            param_a.copy_(merged_param.to(dtype=param_a.dtype))
            merged += 1
    return merged
|
|
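# Usage sketch (the constant 0.7 gate is an arbitrary illustrative value; real
# gates would come from a separate calibration or learning step):
#
#   gates = {name: torch.tensor(0.7) for name, _ in layer_a.named_parameters()}
#   n = merge_layers_with_gates(layer_a, layer_b, gates)  # n tensors merged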
|
|
def drop_layer(container: object, index: int) -> object:
    if isinstance(container, torch.nn.ModuleList):
        return torch.nn.ModuleList(
            [layer for idx, layer in enumerate(container) if idx != index]
        )
    if isinstance(container, list):
        del container[index]
        return container
    raise TypeError("Layer container must be ModuleList or list")


def decrement_config(config) -> None:
    for attr in ("num_hidden_layers", "n_layer", "num_layers"):
        if hasattr(config, attr):
            value = getattr(config, attr)
            if isinstance(value, int) and value > 0:
                setattr(config, attr, value - 1)
    normalize_config(config)


def normalize_config(config) -> None:
    num_hidden_layers = getattr(config, "num_hidden_layers", None)
    layer_types = getattr(config, "layer_types", None)
    if (
        isinstance(num_hidden_layers, int)
        and num_hidden_layers >= 0
        and isinstance(layer_types, (list, tuple))
        and len(layer_types) != num_hidden_layers
    ):
        config.layer_types = list(layer_types[:num_hidden_layers])


def find_colon_modules(module: torch.nn.Module) -> List[str]:
    found: List[str] = []
    for name, child in module._modules.items():
        if ":" in name:
            found.append(name)
        if isinstance(child, torch.nn.Module):
            for sub in find_colon_modules(child):
                found.append(f"{name}.{sub}")
    return found


def get_norm_pair(
    layer: torch.nn.Module,
) -> Tuple[
    Optional[torch.nn.Module],
    Optional[torch.nn.Module],
    Tuple[Optional[str], Optional[str]],
]:
    candidates = [
        ("input_layernorm", "post_attention_layernorm"),
        ("ln_1", "ln_2"),
        ("norm1", "norm2"),
        ("norm_1", "norm_2"),
        ("layer_norm_1", "layer_norm_2"),
        ("self_attn_layer_norm", "final_layer_norm"),
    ]
    for n1, n2 in candidates:
        if hasattr(layer, n1) and hasattr(layer, n2):
            return getattr(layer, n1), getattr(layer, n2), (n1, n2)
    return None, None, (None, None)


def clone_state_dict(module: torch.nn.Module) -> Dict[str, torch.Tensor]:
    return {k: v.detach().clone() for k, v in module.state_dict().items()}


def apply_norm_policy(
    layer: torch.nn.Module,
    norm_policy: str,
    norm1_state: Optional[Dict[str, torch.Tensor]],
    norm2_state: Optional[Dict[str, torch.Tensor]],
    norm_names: Tuple[Optional[str], Optional[str]],
) -> None:
    norm1, norm2, _ = get_norm_pair(layer)
    # "copy_n1" and "hybrid" restore only the first norm; "copy_n1_n2"
    # restores both, as its name implies.
    restore_n1 = norm_policy in {"copy_n1", "copy_n1_n2", "hybrid"}
    if restore_n1 and norm1_state is not None and norm1 is not None:
        norm1.load_state_dict(norm1_state)
    if norm_policy == "copy_n1_n2" and norm2_state is not None and norm2 is not None:
        norm2.load_state_dict(norm2_state)
|
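# Typical flow (sketch): snapshot the surviving layer's norms before merging,
# then restore them according to the chosen policy afterwards:
#
#   norm1, norm2, names = get_norm_pair(layer_a)
#   norm1_state = clone_state_dict(norm1) if norm1 is not None else None
#   norm2_state = clone_state_dict(norm2) if norm2 is not None else None
#   merge_layers(...)  # or merge_layers_with_gates(...)
#   apply_norm_policy(layer_a, "copy_n1", norm1_state, norm2_state, names)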
|