"""NBL (block-linearization) calibration utilities.

Computes per-layer canonical-correlation-based NMSE scores over calibration
activations, and fits LMMSE (W, b) replacements for the most linear
attention blocks. All cross-rank statistics are merged with
``accelerator.reduce`` so every process sees identical results.
"""

import logging
import math
import os
from typing import List, Optional, Tuple

import torch
from torch import no_grad
from torch.utils.data import DataLoader

from accelerate import Accelerator
from tqdm import tqdm

from .utils import prepare_calibration_input
from .wrapper import HiddenStatesRecordWrapper

logger = logging.getLogger(__name__)

# Ridge term added to covariance diagonals before any inversion.
_REGULARIZATION_EPS = 1e-5


def _matrix_inverse_sqrt(matrix: torch.Tensor, epsilon: float = 1e-9) -> torch.Tensor:
    """Compute the inverse square root of a symmetric matrix via eigendecomposition.

    Eigenvalues are clamped at zero (negative values are numerical noise for a
    covariance matrix) and ``epsilon`` keeps the reciprocal square root finite.
    The computation runs in float32; the result is cast back to ``matrix``'s dtype.
    """
    eigvals, eigvecs = torch.linalg.eigh(matrix.to(torch.float32))
    inv_sqrt = 1.0 / (torch.sqrt(torch.clamp(eigvals, min=0.0)) + epsilon)
    inv_sqrt_mat = eigvecs @ torch.diag(inv_sqrt) @ eigvecs.transpose(-2, -1)
    return inv_sqrt_mat.to(matrix.dtype)


def _maybe_get(sequence: Optional[List[Optional[torch.Tensor]]], idx: int) -> Optional[torch.Tensor]:
    """Return ``sequence[idx]``, or ``None`` when the whole sequence is absent."""
    if sequence is None:
        return None
    return sequence[idx]


def _call_layer_forward(
    layer: torch.nn.Module,
    hidden_state: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    position_ids: Optional[torch.Tensor],
    cache_position: Optional[torch.Tensor],
    model_type: Optional[str],
) -> torch.Tensor:
    """Run a single transformer block on calibration activations.

    Only the keyword arguments that are actually available are forwarded;
    ``cache_position`` is passed only for model types known to accept it.
    Returns the block's hidden-state output (first element when the layer
    returns a tuple/list).
    """
    kwargs = {}
    if attention_mask is not None:
        kwargs["attention_mask"] = attention_mask
    if position_ids is not None:
        kwargs["position_ids"] = position_ids
    # Not every architecture's layer signature accepts cache_position.
    if cache_position is not None and model_type in {"llama", "mistral"}:
        kwargs["cache_position"] = cache_position
    outputs = layer(hidden_state, **kwargs)
    return outputs[0] if isinstance(outputs, (tuple, list)) else outputs


def _compute_cova_matrices_iterative_dist(
    X_list: List[torch.Tensor],
    Y_list: List[torch.Tensor],
    accelerator: Accelerator,
):
    """Compute first and second-order moments in a distributed-friendly way.

    Two passes over the local activation lists: (1) accumulate sums/token
    counts, reduce across ranks, form global means; (2) accumulate centered
    outer products, reduce, and normalize by ``N - 1``. All accumulation is
    done in float64 for numerical stability.

    Returns:
        ``(X_mean, Y_mean, Cxx, Cyy, Cxy)`` as float32 tensors.

    Raises:
        RuntimeError: if the global token count is <= 1 (the unbiased
            ``N - 1`` denominator would be degenerate).
    """
    device = accelerator.device
    hidden_dim = X_list[0].shape[-1]

    # Pass 1: per-rank sums (kept on CPU until the reduce) and token counts.
    X_sum_local = torch.zeros(hidden_dim, dtype=torch.float64)
    Y_sum_local = torch.zeros(hidden_dim, dtype=torch.float64)
    total_tokens_local = 0
    for x in X_list:
        x_flat = x.view(-1, hidden_dim).to(dtype=torch.float64)
        X_sum_local += x_flat.sum(dim=0)
        total_tokens_local += x_flat.shape[0]
    for y in Y_list:
        y_flat = y.view(-1, hidden_dim).to(dtype=torch.float64)
        Y_sum_local += y_flat.sum(dim=0)

    X_sum_global = accelerator.reduce(X_sum_local.to(device), reduction="sum")
    Y_sum_global = accelerator.reduce(Y_sum_local.to(device), reduction="sum")
    total_tokens_tensor = torch.tensor(total_tokens_local, device=device, dtype=torch.float64)
    total_tokens_global = accelerator.reduce(total_tokens_tensor, reduction="sum").item()
    if total_tokens_global <= 1:
        raise RuntimeError("Not enough calibration tokens to compute covariance matrices.")

    X_mean = (X_sum_global / total_tokens_global).to(torch.float32)
    Y_mean = (Y_sum_global / total_tokens_global).to(torch.float32)

    # Pass 2: centered second moments, accumulated in float64 on-device.
    Cxx_local = torch.zeros((hidden_dim, hidden_dim), device=device, dtype=torch.float64)
    Cyy_local = torch.zeros_like(Cxx_local)
    Cxy_local = torch.zeros_like(Cxx_local)
    X_mean64 = X_mean.to(device=device, dtype=torch.float64)
    Y_mean64 = Y_mean.to(device=device, dtype=torch.float64)
    for x, y in zip(X_list, Y_list):
        x_centered = x.view(-1, hidden_dim).to(device=device, dtype=torch.float64) - X_mean64
        y_centered = y.view(-1, hidden_dim).to(device=device, dtype=torch.float64) - Y_mean64
        Cxx_local += x_centered.T @ x_centered
        Cyy_local += y_centered.T @ y_centered
        Cxy_local += x_centered.T @ y_centered

    denom = float(total_tokens_global - 1)  # unbiased covariance estimator
    Cxx_global = accelerator.reduce(Cxx_local, reduction="sum") / denom
    Cyy_global = accelerator.reduce(Cyy_local, reduction="sum") / denom
    Cxy_global = accelerator.reduce(Cxy_local, reduction="sum") / denom

    Cxx = Cxx_global.to(torch.float32)
    Cyy = Cyy_global.to(torch.float32)
    Cxy = Cxy_global.to(torch.float32)
    return X_mean, Y_mean, Cxx, Cyy, Cxy


def compute_cca(
    X_list: List[torch.Tensor],
    Y_list: List[torch.Tensor],
    accelerator: Accelerator,
    regularization: float = _REGULARIZATION_EPS,
) -> torch.Tensor:
    """Compute canonical correlations following the NBL formulation.

    Regularizes both covariance blocks, whitens the cross-covariance with the
    inverse square roots, and takes the singular values of the whitened
    matrix. Singular values are clamped into [0, 1] to absorb numerical
    overshoot.
    """
    device = accelerator.device
    _, _, Cxx, Cyy, Cxy = _compute_cova_matrices_iterative_dist(X_list, Y_list, accelerator)
    eye_x = torch.eye(Cxx.size(0), device=device, dtype=Cxx.dtype)
    eye_y = torch.eye(Cyy.size(0), device=device, dtype=Cyy.dtype)
    Cxx_reg = Cxx + regularization * eye_x
    Cyy_reg = Cyy + regularization * eye_y
    Cxx_inv_sqrt = _matrix_inverse_sqrt(Cxx_reg)
    Cyy_inv_sqrt = _matrix_inverse_sqrt(Cyy_reg)
    corr_matrix = Cyy_inv_sqrt @ Cxy @ Cxx_inv_sqrt
    _, singular_values, _ = torch.linalg.svd(corr_matrix, full_matrices=False)
    correlations = torch.clamp(singular_values.real, min=0.0, max=1.0)
    return correlations


def _collect_layer_calibration(
    layer: torch.nn.Module,
    num_samples: int,
    inputs: List[torch.Tensor],
    attention_mask: Optional[List[Optional[torch.Tensor]]],
    position_ids: Optional[List[Optional[torch.Tensor]]],
    cache_position: Optional[List[Optional[torch.Tensor]]],
    model_type: Optional[str],
) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
    """Capture pre-layernorm inputs, normalized activations and attention outputs
    with lightweight forward hooks.

    Hooks are registered on ``layer.input_layernorm`` and ``layer.self_attn``,
    each calibration sample is run through the layer (on cloned inputs, so the
    caller's activations are untouched), and the hooks are removed afterwards.

    Returns:
        ``(residual_inputs, norm_inputs, attn_outputs)`` — the layernorm's raw
        inputs (residual stream), its outputs, and the attention outputs.
    """
    module_pre_norm = layer.input_layernorm
    module_attn = layer.self_attn
    wrapped_pre_norm = HiddenStatesRecordWrapper(module_pre_norm, record_input=True, record_output=True)
    wrapped_attn = HiddenStatesRecordWrapper(module_attn, record_input=False, record_output=True)

    def pre_norm_hook(_, hook_inputs, output):
        # Record both what flows into the layernorm (residual stream) and
        # what it emits (the normalized activations fed to attention).
        inp = hook_inputs[0] if isinstance(hook_inputs, tuple) else hook_inputs
        out = output[0] if isinstance(output, (tuple, list)) else output
        wrapped_pre_norm.record(inp.detach(), out.detach())

    def attn_hook(_, __, output):
        attn_out = output[0] if isinstance(output, (tuple, list)) else output
        wrapped_attn.record(None, attn_out.detach())

    handles = [
        module_pre_norm.register_forward_hook(pre_norm_hook),
        module_attn.register_forward_hook(attn_hook),
    ]
    # Clone so the forward passes cannot mutate the caller's activation list.
    working_inputs = [inp.clone() for inp in inputs]
    for j in range(num_samples):
        _call_layer_forward(
            layer,
            working_inputs[j],
            _maybe_get(attention_mask, j),
            _maybe_get(position_ids, j),
            _maybe_get(cache_position, j),
            model_type,
        )
    for handle in handles:
        handle.remove()

    residual_inputs = wrapped_pre_norm.input_hidden_states
    norm_inputs = wrapped_pre_norm.output_hidden_states
    attn_outputs = wrapped_attn.output_hidden_states
    return residual_inputs, norm_inputs, attn_outputs


def _advance_layer_states(
    layer: torch.nn.Module,
    inputs: List[torch.Tensor],
    outputs: List[Optional[torch.Tensor]],
    attention_mask: Optional[List[Optional[torch.Tensor]]],
    position_ids: Optional[List[Optional[torch.Tensor]]],
    cache_position: Optional[List[Optional[torch.Tensor]]],
    model_type: Optional[str],
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Propagate calibration activations to the next transformer block in place.

    Fills ``outputs`` with the layer's results and returns
    ``(outputs, inputs)`` so the caller can ping-pong the two buffers between
    layers without extra allocation.
    """
    num_samples = len(inputs)
    for j in range(num_samples):
        outputs[j] = _call_layer_forward(
            layer,
            inputs[j],
            _maybe_get(attention_mask, j),
            _maybe_get(position_ids, j),
            _maybe_get(cache_position, j),
            model_type,
        )
    return outputs, inputs


@no_grad()
def get_nbl_metrics(
    model,
    dataloader: DataLoader,
    accelerator: Accelerator,
    num_samples: int,
    cache_file: Optional[str] = None,
):
    """Compute (or load from cache) a per-layer NBL NMSE score.

    For each transformer layer, canonical correlations between the
    post-layernorm inputs and attention outputs (plus skip) are computed on
    calibration data; the score is ``sum(1 - rho_i^2)``. Scores are cached to
    ``cache_file`` (main process only) when requested.
    """
    device = accelerator.device
    if cache_file is not None and os.path.exists(cache_file):
        accelerator.print(f"Loading cached NBL metrics from {cache_file}")
        return torch.load(cache_file, map_location=device)
    accelerator.print(
        f"No cached NBL metrics found. Running model on {num_samples} samples for each device."
    )

    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_model.config.use_cache = False
    layers = unwrapped_model.model.layers
    model_type = getattr(unwrapped_model.config, "model_type", None)
    inputs, outputs, attention_mask, position_ids, cache_position = prepare_calibration_input(
        unwrapped_model, dataloader, num_samples
    )

    # inf default: layers never scored sort last when picking candidates.
    nmse_scores = torch.full((len(layers),), math.inf, device=device)
    for idx in tqdm(range(len(layers)), desc="Calculating NBL metrics...", disable=not accelerator.is_main_process):
        layer_module = layers[idx]
        residual_list, norm_list, Y_list_raw = _collect_layer_calibration(
            layer_module,
            num_samples,
            inputs,
            attention_mask,
            position_ids,
            cache_position,
            model_type,
        )
        # Use the post-layernorm activations that actually feed the attention
        # block as the NBL "X". This aligns the statistics with what the
        # linearized layer will see at inference.
        # NOTE(review): residual_list (the pre-layernorm residual stream) is
        # collected but unused; Y_plus adds the post-LN activations instead.
        # Confirm this matches the intended skip-connection formulation.
        Y_plus_list = [y + x for x, y in zip(norm_list, Y_list_raw)]
        correlations = compute_cca(norm_list, Y_plus_list, accelerator)
        nmse_scores[idx] = torch.sum(1 - correlations.square())
        accelerator.print(f"Layer {idx} NMSE: {nmse_scores[idx].item()}")
        inputs, outputs = _advance_layer_states(
            layer_module,
            inputs,
            outputs,
            attention_mask,
            position_ids,
            cache_position,
            model_type,
        )

    if cache_file is not None and accelerator.is_main_process:
        cache_dir = os.path.dirname(cache_file)
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)
        torch.save(nmse_scores.clone().cpu(), cache_file)
        logger.info("Saving cached NBL metrics to %s", cache_file)
    accelerator.wait_for_everyone()
    return nmse_scores


def calculate_nbl_weights(
    X_list: List[torch.Tensor],
    Y_list: List[torch.Tensor],
    accelerator: Accelerator,
    regularization: float = _REGULARIZATION_EPS,
):
    """Solve the LMMSE system that maps normalized inputs to attention outputs.

    ``W = Cyx Cxx^-1`` (with a ridge term on Cxx) and ``b = E[Y] - W E[X]``.
    Both are returned on CPU.
    """
    device = accelerator.device
    X_mean, Y_mean, Cxx, _, Cxy = _compute_cova_matrices_iterative_dist(X_list, Y_list, accelerator)
    eye_x = torch.eye(Cxx.size(0), device=device, dtype=Cxx.dtype)
    Cxx_reg = Cxx + regularization * eye_x
    Cyx = Cxy.transpose(0, 1)
    X_mean = X_mean.to(device)
    Y_mean = Y_mean.to(device)
    W = Cyx @ torch.linalg.pinv(Cxx_reg)
    b = Y_mean - W @ X_mean
    return W.cpu(), b.cpu()


@no_grad()
def apply_nbl_linearization(
    model,
    dataloader: DataLoader,
    accelerator: Accelerator,
    num_samples: int,
    num_layers_to_linearize: int,
    nbl_metric_cache_file: Optional[str] = None,
):
    """Select the lowest-NMSE layers and fit (W, b) linear replacements.

    Returns:
        Dict mapping layer index -> ``{"W": Tensor, "b": Tensor}`` (CPU) for
        each of the ``num_layers_to_linearize`` selected layers. The model
        itself is not modified here.
    """
    nmse_scores = get_nbl_metrics(
        model,
        dataloader,
        accelerator,
        num_samples,
        cache_file=nbl_metric_cache_file,
    )
    # Ascending sort: lowest NMSE == most linear == best linearization targets.
    sorted_nmse, sorted_indices = torch.sort(nmse_scores, dim=0, descending=False)
    layers_to_linearize = sorted_indices[:num_layers_to_linearize].tolist()
    accelerator.print(
        f"Linearizing layers: {layers_to_linearize} with NMSE scores: {sorted_nmse[:num_layers_to_linearize].tolist()}"
    )

    unwrapped_model = accelerator.unwrap_model(model)
    model_layers = unwrapped_model.model.layers
    model_type = getattr(unwrapped_model.config, "model_type", None)
    inputs, outputs, attention_mask, position_ids, cache_position = prepare_calibration_input(
        unwrapped_model, dataloader, num_samples
    )

    linearization_data = {}
    for idx in tqdm(range(len(model_layers)), desc="Calculating linearization weights...", disable=not accelerator.is_main_process):
        layer_module = model_layers[idx]
        if idx in layers_to_linearize:
            residual_list, norm_list, Y_list = _collect_layer_calibration(
                layer_module,
                num_samples,
                inputs,
                attention_mask,
                position_ids,
                cache_position,
                model_type,
            )
            # Fit on the normalized inputs that are used at inference time
            # NOTE(review): residual_list is collected but unused here as well.
            W, b = calculate_nbl_weights(norm_list, Y_list, accelerator)
            linearization_data[idx] = {"W": W, "b": b}
            accelerator.print(f"Calculated weights for layer {idx}")
        inputs, outputs = _advance_layer_states(
            layer_module,
            inputs,
            outputs,
            attention_mask,
            position_ids,
            cache_position,
            model_type,
        )
    return linearization_data