| |
|
| | import logging |
| | import math |
| | import os |
| | from typing import List, Optional, Tuple |
| |
|
| | import torch |
| | from torch import no_grad |
| | from torch.utils.data import DataLoader |
| | from accelerate import Accelerator |
| | from tqdm import tqdm |
| |
|
| | from .utils import prepare_calibration_input |
| | from .wrapper import HiddenStatesRecordWrapper |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| | _REGULARIZATION_EPS = 1e-5 |
| |
|
| |
|
| | def _matrix_inverse_sqrt(matrix: torch.Tensor, epsilon: float = 1e-9) -> torch.Tensor: |
| | """Compute the inverse square root of a symmetric matrix via eigendecomposition.""" |
| | eigvals, eigvecs = torch.linalg.eigh(matrix.to(torch.float32)) |
| | inv_sqrt = 1.0 / (torch.sqrt(torch.clamp(eigvals, min=0.0)) + epsilon) |
| | inv_sqrt_mat = eigvecs @ torch.diag(inv_sqrt) @ eigvecs.transpose(-2, -1) |
| | return inv_sqrt_mat.to(matrix.dtype) |
| |
|
| |
|
| | def _maybe_get(sequence: Optional[List[Optional[torch.Tensor]]], idx: int) -> Optional[torch.Tensor]: |
| | if sequence is None: |
| | return None |
| | return sequence[idx] |
| |
|
| |
|
| | def _call_layer_forward( |
| | layer: torch.nn.Module, |
| | hidden_state: torch.Tensor, |
| | attention_mask: Optional[torch.Tensor], |
| | position_ids: Optional[torch.Tensor], |
| | cache_position: Optional[torch.Tensor], |
| | model_type: Optional[str], |
| | ) -> torch.Tensor: |
| | """Run a single transformer block on calibration activations.""" |
| | kwargs = {} |
| | if attention_mask is not None: |
| | kwargs["attention_mask"] = attention_mask |
| | if position_ids is not None: |
| | kwargs["position_ids"] = position_ids |
| | if cache_position is not None and model_type in {"llama", "mistral"}: |
| | kwargs["cache_position"] = cache_position |
| |
|
| | outputs = layer(hidden_state, **kwargs) |
| | return outputs[0] if isinstance(outputs, (tuple, list)) else outputs |
| |
|
| |
|
def _compute_cova_matrices_iterative_dist(
    X_list: List[torch.Tensor],
    Y_list: List[torch.Tensor],
    accelerator: Accelerator,
):
    """Compute first and second-order moments in a distributed-friendly way.

    Args:
        X_list / Y_list: per-sample activation tensors of shape
            ``(..., hidden_dim)``; both lists must cover the same tokens.
        accelerator: used to sum local statistics across processes.

    Returns:
        ``(X_mean, Y_mean, Cxx, Cyy, Cxy)`` — float32 mean vectors and
        ``(hidden_dim, hidden_dim)`` covariance matrices using the unbiased
        ``n - 1`` denominator.

    Raises:
        RuntimeError: if no samples are given or fewer than two calibration
            tokens are available globally.
    """
    if not X_list:
        raise RuntimeError("Not enough calibration tokens to compute covariance matrices.")

    device = accelerator.device
    hidden_dim = X_list[0].shape[-1]

    # BUGFIX: accumulate on the accelerator device. The previous version
    # allocated these sums on the CPU while the activation chunks stayed on
    # their own device, so the in-place `+=` raised a device-mismatch error
    # whenever calibration ran on GPU (the covariance pass below already
    # allocated its accumulators on `device`).
    X_sum_local = torch.zeros(hidden_dim, device=device, dtype=torch.float64)
    Y_sum_local = torch.zeros(hidden_dim, device=device, dtype=torch.float64)
    total_tokens_local = 0

    for x in X_list:
        x_flat = x.view(-1, hidden_dim).to(device=device, dtype=torch.float64)
        X_sum_local += x_flat.sum(dim=0)
        total_tokens_local += x_flat.shape[0]
    for y in Y_list:
        y_flat = y.view(-1, hidden_dim).to(device=device, dtype=torch.float64)
        Y_sum_local += y_flat.sum(dim=0)

    X_sum_global = accelerator.reduce(X_sum_local, reduction="sum")
    Y_sum_global = accelerator.reduce(Y_sum_local, reduction="sum")
    total_tokens_tensor = torch.tensor(total_tokens_local, device=device, dtype=torch.float64)
    total_tokens_global = accelerator.reduce(total_tokens_tensor, reduction="sum").item()

    if total_tokens_global <= 1:
        raise RuntimeError("Not enough calibration tokens to compute covariance matrices.")

    X_mean = (X_sum_global / total_tokens_global).to(torch.float32)
    Y_mean = (Y_sum_global / total_tokens_global).to(torch.float32)

    Cxx_local = torch.zeros((hidden_dim, hidden_dim), device=device, dtype=torch.float64)
    Cyy_local = torch.zeros_like(Cxx_local)
    Cxy_local = torch.zeros_like(Cxx_local)

    X_mean64 = X_mean.to(device=device, dtype=torch.float64)
    Y_mean64 = Y_mean.to(device=device, dtype=torch.float64)

    # Second pass: accumulate centered outer products chunk by chunk so the
    # full token matrix is never materialized at once.
    for x, y in zip(X_list, Y_list):
        x_centered = x.view(-1, hidden_dim).to(device=device, dtype=torch.float64) - X_mean64
        y_centered = y.view(-1, hidden_dim).to(device=device, dtype=torch.float64) - Y_mean64

        Cxx_local += x_centered.T @ x_centered
        Cyy_local += y_centered.T @ y_centered
        Cxy_local += x_centered.T @ y_centered

    # Unbiased estimator: divide by (n - 1) after the cross-process reduction.
    denom = float(total_tokens_global - 1)
    Cxx = (accelerator.reduce(Cxx_local, reduction="sum") / denom).to(torch.float32)
    Cyy = (accelerator.reduce(Cyy_local, reduction="sum") / denom).to(torch.float32)
    Cxy = (accelerator.reduce(Cxy_local, reduction="sum") / denom).to(torch.float32)

    return X_mean, Y_mean, Cxx, Cyy, Cxy
| |
|
| |
|
def compute_cca(
    X_list: List[torch.Tensor],
    Y_list: List[torch.Tensor],
    accelerator: Accelerator,
    regularization: float = _REGULARIZATION_EPS,
) -> torch.Tensor:
    """Compute canonical correlations following the NBL formulation.

    Whitens the cross-covariance with the ridge-regularized auto-covariances
    and returns its singular values, clamped to [0, 1].
    """
    device = accelerator.device
    _, _, Cxx, Cyy, Cxy = _compute_cova_matrices_iterative_dist(X_list, Y_list, accelerator)

    # Ridge-regularize both auto-covariances before whitening so the inverse
    # square roots are well conditioned.
    Cxx_reg = Cxx + regularization * torch.eye(Cxx.size(0), device=device, dtype=Cxx.dtype)
    Cyy_reg = Cyy + regularization * torch.eye(Cyy.size(0), device=device, dtype=Cyy.dtype)

    whitened = _matrix_inverse_sqrt(Cyy_reg) @ Cxy @ _matrix_inverse_sqrt(Cxx_reg)
    _, sigma, _ = torch.linalg.svd(whitened, full_matrices=False)
    # Singular values of the whitened cross-covariance are the canonical
    # correlations; clamp to absorb numerical noise outside [0, 1].
    return torch.clamp(sigma.real, min=0.0, max=1.0)
| |
|
| |
|
def _collect_layer_calibration(
    layer: torch.nn.Module,
    num_samples: int,
    inputs: List[torch.Tensor],
    attention_mask: Optional[List[Optional[torch.Tensor]]],
    position_ids: Optional[List[Optional[torch.Tensor]]],
    cache_position: Optional[List[Optional[torch.Tensor]]],
    model_type: Optional[str],
) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
    """Capture pre-layernorm inputs, normalized activations and attention outputs with lightweight forward hooks.

    Returns:
        ``(residual_inputs, norm_inputs, attn_outputs)`` — the recorded
        residual-stream inputs to ``input_layernorm``, its normalized outputs,
        and the outputs of ``self_attn``, one entry per forward pass.
    """
    # assumes a Llama-style block layout with `input_layernorm` and
    # `self_attn` submodules — TODO confirm for other architectures.
    module_pre_norm = layer.input_layernorm
    module_attn = layer.self_attn

    wrapped_pre_norm = HiddenStatesRecordWrapper(module_pre_norm, record_input=True, record_output=True)
    wrapped_attn = HiddenStatesRecordWrapper(module_attn, record_input=False, record_output=True)

    def pre_norm_hook(_, hook_inputs, output):
        # Record both the residual-stream input and the normalized output.
        inp = hook_inputs[0] if isinstance(hook_inputs, tuple) else hook_inputs
        out = output[0] if isinstance(output, (tuple, list)) else output
        wrapped_pre_norm.record(inp.detach(), out.detach())

    def attn_hook(_, __, output):
        attn_out = output[0] if isinstance(output, (tuple, list)) else output
        wrapped_attn.record(None, attn_out.detach())

    handles = [
        module_pre_norm.register_forward_hook(pre_norm_hook),
        module_attn.register_forward_hook(attn_hook),
    ]

    # Run on clones so the caller's activation buffers are left untouched for
    # the subsequent propagation pass.
    working_inputs = [inp.clone() for inp in inputs]
    try:
        for j in range(num_samples):
            _call_layer_forward(
                layer,
                working_inputs[j],
                _maybe_get(attention_mask, j),
                _maybe_get(position_ids, j),
                _maybe_get(cache_position, j),
                model_type,
            )
    finally:
        # BUGFIX: always detach the hooks, even if a forward pass raises;
        # previously an exception left them registered, so later passes kept
        # recording into these wrappers and leaked activation memory.
        for handle in handles:
            handle.remove()

    residual_inputs = wrapped_pre_norm.input_hidden_states
    norm_inputs = wrapped_pre_norm.output_hidden_states
    attn_outputs = wrapped_attn.output_hidden_states

    return residual_inputs, norm_inputs, attn_outputs
| |
|
| |
|
def _advance_layer_states(
    layer: torch.nn.Module,
    inputs: List[torch.Tensor],
    outputs: List[Optional[torch.Tensor]],
    attention_mask: Optional[List[Optional[torch.Tensor]]],
    position_ids: Optional[List[Optional[torch.Tensor]]],
    cache_position: Optional[List[Optional[torch.Tensor]]],
    model_type: Optional[str],
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Propagate calibration activations to the next transformer block in place.

    Returns ``(outputs, inputs)`` so the caller can ping-pong the two buffers
    between layers without reallocating.
    """
    for j, hidden in enumerate(inputs):
        outputs[j] = _call_layer_forward(
            layer,
            hidden,
            _maybe_get(attention_mask, j),
            _maybe_get(position_ids, j),
            _maybe_get(cache_position, j),
            model_type,
        )
    return outputs, inputs
| |
|
| |
|
@no_grad()
def get_nbl_metrics(
    model,
    dataloader: DataLoader,
    accelerator: Accelerator,
    num_samples: int,
    cache_file: Optional[str] = None,
):
    """Score every transformer block by how non-linear its attention path is.

    For each block the score is ``sum(1 - rho^2)`` over the canonical
    correlations between the post-layernorm input and the attention output
    plus that input — lower means the block is closer to linear.

    Args:
        model: causal LM whose unwrapped form exposes ``model.layers``.
        dataloader: calibration data fed to ``prepare_calibration_input``.
        accelerator: drives device placement and cross-process reductions.
        num_samples: number of calibration samples per device.
        cache_file: optional path to load/store the per-layer scores.

    Returns:
        1-D float tensor of length ``len(layers)`` on the accelerator device.
    """
    device = accelerator.device

    if cache_file is not None and os.path.exists(cache_file):
        accelerator.print(f"Loading cached NBL metrics from {cache_file}")
        # NOTE(review): torch.load unpickles arbitrary data — only point
        # cache_file at trusted files.
        return torch.load(cache_file, map_location=device)

    accelerator.print(
        f"No cached NBL metrics found. Running model on {num_samples} samples for each device."
    )

    unwrapped_model = accelerator.unwrap_model(model)
    # Disable the KV cache so each block sees plain tensors during calibration.
    unwrapped_model.config.use_cache = False
    layers = unwrapped_model.model.layers
    model_type = getattr(unwrapped_model.config, "model_type", None)

    inputs, outputs, attention_mask, position_ids, cache_position = prepare_calibration_input(
        unwrapped_model, dataloader, num_samples
    )

    # One score per block; anything not reached keeps +inf and sorts last.
    nmse_scores = torch.full((len(layers),), math.inf, device=device)

    for idx in tqdm(range(len(layers)), desc="Calculating NBL metrics...", disable=not accelerator.is_main_process):
        layer_module = layers[idx]

        # residual_list (pre-layernorm residual stream) is recorded but not
        # used by this metric; only the normalized inputs and attention
        # outputs enter the CCA.
        residual_list, norm_list, Y_list_raw = _collect_layer_calibration(
            layer_module,
            num_samples,
            inputs,
            attention_mask,
            position_ids,
            cache_position,
            model_type,
        )

        # Compare the normalized input against (attention output + input),
        # i.e. the residual-augmented attention path.
        Y_plus_list = [y + x for x, y in zip(norm_list, Y_list_raw)]
        correlations = compute_cca(norm_list, Y_plus_list, accelerator)
        # sum(1 - rho^2): 0 would mean the mapping is perfectly linear.
        nmse_scores[idx] = torch.sum(1 - correlations.square())

        accelerator.print(f"Layer {idx} NMSE: {nmse_scores[idx].item()}")

        # Run the block for real so the next iteration sees its outputs;
        # the two buffers are swapped (ping-pong) to avoid reallocation.
        inputs, outputs = _advance_layer_states(
            layer_module,
            inputs,
            outputs,
            attention_mask,
            position_ids,
            cache_position,
            model_type,
        )

    # Only the main process writes the cache; everyone then synchronizes so
    # no rank races ahead while the file is being written.
    if cache_file is not None and accelerator.is_main_process:
        cache_dir = os.path.dirname(cache_file)
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)
        torch.save(nmse_scores.clone().cpu(), cache_file)
        logger.info("Saving cached NBL metrics to %s", cache_file)
    accelerator.wait_for_everyone()

    return nmse_scores
| |
|
| |
|
def calculate_nbl_weights(
    X_list: List[torch.Tensor],
    Y_list: List[torch.Tensor],
    accelerator: Accelerator,
    regularization: float = _REGULARIZATION_EPS,
):
    """Solve the LMMSE system that maps normalized inputs to attention outputs.

    Returns:
        ``(W, b)`` on the CPU such that ``y ≈ W @ x + b`` in the least-squares
        sense over the calibration tokens.
    """
    device = accelerator.device
    X_mean, Y_mean, Cxx, _, Cxy = _compute_cova_matrices_iterative_dist(X_list, Y_list, accelerator)

    identity = torch.eye(Cxx.size(0), device=device, dtype=Cxx.dtype)
    # W = Cyx (Cxx + eps I)^+  and  b = E[y] - W E[x].
    weight = Cxy.transpose(0, 1) @ torch.linalg.pinv(Cxx + regularization * identity)
    bias = Y_mean.to(device) - weight @ X_mean.to(device)

    return weight.cpu(), bias.cpu()
| |
|
| |
|
@no_grad()
def apply_nbl_linearization(
    model,
    dataloader: DataLoader,
    accelerator: Accelerator,
    num_samples: int,
    num_layers_to_linearize: int,
    nbl_metric_cache_file: Optional[str] = None,
):
    """Select the most-linear attention blocks and fit affine replacements.

    Ranks every block by its NBL score (ascending — lower means more linear),
    picks the ``num_layers_to_linearize`` lowest-scoring blocks, and for each
    one fits an LMMSE map from the post-layernorm input to the raw attention
    output.

    Returns:
        dict mapping layer index -> {"W": Tensor, "b": Tensor}, tensors on CPU.
        The model itself is not modified here.
    """
    nmse_scores = get_nbl_metrics(
        model,
        dataloader,
        accelerator,
        num_samples,
        cache_file=nbl_metric_cache_file,
    )

    # Ascending sort: the lowest-NMSE (most linear) layers are replaced first.
    sorted_nmse, sorted_indices = torch.sort(nmse_scores, dim=0, descending=False)
    layers_to_linearize = sorted_indices[:num_layers_to_linearize].tolist()
    accelerator.print(
        f"Linearizing layers: {layers_to_linearize} with NMSE scores: {sorted_nmse[:num_layers_to_linearize].tolist()}"
    )

    unwrapped_model = accelerator.unwrap_model(model)
    model_layers = unwrapped_model.model.layers
    model_type = getattr(unwrapped_model.config, "model_type", None)

    # Fresh calibration pass; activations are re-derived from the dataloader
    # rather than reusing the metric pass's buffers.
    inputs, outputs, attention_mask, position_ids, cache_position = prepare_calibration_input(
        unwrapped_model, dataloader, num_samples
    )

    linearization_data = {}

    # Every layer must still be stepped through so activations propagate,
    # even if only a subset gets weights fitted.
    for idx in tqdm(range(len(model_layers)), desc="Calculating linearization weights...", disable=not accelerator.is_main_process):
        layer_module = model_layers[idx]

        if idx in layers_to_linearize:
            # residual_list is recorded but unused; the fit maps the
            # normalized inputs directly onto the attention outputs.
            residual_list, norm_list, Y_list = _collect_layer_calibration(
                layer_module,
                num_samples,
                inputs,
                attention_mask,
                position_ids,
                cache_position,
                model_type,
            )

            W, b = calculate_nbl_weights(norm_list, Y_list, accelerator)
            linearization_data[idx] = {"W": W, "b": b}
            accelerator.print(f"Calculated weights for layer {idx}")

        # Ping-pong the activation buffers into the next block.
        inputs, outputs = _advance_layer_states(
            layer_module,
            inputs,
            outputs,
            attention_mask,
            position_ids,
            cache_position,
            model_type,
        )

    return linearization_data
| |
|