# s1ghhh's picture
# Upload folder using huggingface_hub
# d73500e verified
import logging
import math
import os
from typing import List, Optional, Tuple
import torch
from torch import no_grad
from torch.utils.data import DataLoader
from accelerate import Accelerator
from tqdm import tqdm
from .utils import prepare_calibration_input
from .wrapper import HiddenStatesRecordWrapper
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Default value of the `regularization` argument used in this module: it is
# multiplied by an identity matrix and added to covariance matrices before
# they are inverted.
_REGULARIZATION_EPS = 1e-5
def _matrix_inverse_sqrt(matrix: torch.Tensor, epsilon: float = 1e-9) -> torch.Tensor:
    """Compute the inverse square root of a symmetric matrix via eigendecomposition.

    The decomposition is carried out in float32 with ``torch.linalg.eigh``.
    Negative eigenvalues are clamped to zero and ``epsilon`` is added to each
    square-rooted eigenvalue before taking the reciprocal, so the result stays
    finite for singular inputs.

    Args:
        matrix: Symmetric matrix of shape ``(..., n, n)``; leading batch
            dimensions are supported.
        epsilon: Value added to the square-rooted eigenvalues to avoid a
            division by zero.

    Returns:
        A tensor with the same shape and dtype as ``matrix``.
    """
    eigvals, eigvecs = torch.linalg.eigh(matrix.to(torch.float32))
    inv_sqrt = 1.0 / (torch.sqrt(torch.clamp(eigvals, min=0.0)) + epsilon)
    # Scale the eigenvector columns directly rather than multiplying by a
    # dense diagonal matrix: ``torch.diag`` only builds a diagonal matrix from
    # 1-D input, so a dense-diagonal product does not work for batched input.
    scaled_vecs = eigvecs * inv_sqrt.unsqueeze(-2)
    inv_sqrt_mat = scaled_vecs @ eigvecs.transpose(-2, -1)
    return inv_sqrt_mat.to(matrix.dtype)
def _maybe_get(sequence: Optional[List[Optional[torch.Tensor]]], idx: int) -> Optional[torch.Tensor]:
    """Return ``sequence[idx]``, or ``None`` when no sequence was supplied."""
    return None if sequence is None else sequence[idx]
def _call_layer_forward(
    layer: torch.nn.Module,
    hidden_state: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    position_ids: Optional[torch.Tensor],
    cache_position: Optional[torch.Tensor],
    model_type: Optional[str],
) -> torch.Tensor:
    """Run a single transformer block on calibration activations.

    Optional arguments that are ``None`` are not forwarded to the layer.
    ``cache_position`` is additionally forwarded only when ``model_type`` is
    ``"llama"`` or ``"mistral"``. If the layer returns a tuple or list, its
    first element is returned; otherwise the raw result is returned.
    """
    optional_args = {
        "attention_mask": attention_mask,
        "position_ids": position_ids,
    }
    forward_kwargs = {
        name: value for name, value in optional_args.items() if value is not None
    }
    if cache_position is not None and model_type in {"llama", "mistral"}:
        forward_kwargs["cache_position"] = cache_position

    result = layer(hidden_state, **forward_kwargs)
    if isinstance(result, (tuple, list)):
        return result[0]
    return result
def _compute_cova_matrices_iterative_dist(
    X_list: List[torch.Tensor],
    Y_list: List[torch.Tensor],
    accelerator: Accelerator,
):
    """Compute first and second-order moments in a distributed-friendly way.

    Two passes are made over the local tensors. The first accumulates
    per-feature sums and the token count, which are summed across processes
    with ``accelerator.reduce`` to obtain global means. The second accumulates
    the centered outer products, which are reduced the same way and divided by
    ``N - 1`` (``N`` being the global token count taken from ``X_list``).
    All accumulation happens in float64; results are returned as float32.

    Args:
        X_list: Tensors whose last dimension is the hidden size; every tensor
            is flattened to ``(tokens, hidden_dim)``.
        Y_list: Tensors paired element-wise with ``X_list``, flattened with the
            same hidden size.
        accelerator: Provides the target device and the cross-process reduce.

    Returns:
        ``(X_mean, Y_mean, Cxx, Cyy, Cxy)``: means of shape ``(hidden_dim,)``
        and covariance matrices of shape ``(hidden_dim, hidden_dim)``.

    Raises:
        ValueError: If ``X_list`` is empty or its length differs from
            ``Y_list``.
        RuntimeError: If the global token count is not greater than one.
    """
    if not X_list:
        raise ValueError("X_list must contain at least one tensor.")
    if len(X_list) != len(Y_list):
        # zip() below would silently truncate and the means would be computed
        # over different sample sets, so fail loudly instead.
        raise ValueError(
            f"X_list and Y_list must have the same length, got {len(X_list)} and {len(Y_list)}."
        )

    device = accelerator.device
    hidden_dim = X_list[0].shape[-1]

    X_sum_local = torch.zeros(hidden_dim, device=device, dtype=torch.float64)
    Y_sum_local = torch.zeros(hidden_dim, device=device, dtype=torch.float64)
    total_tokens_local = 0

    # First pass: per-feature sums. Each tensor is reduced on whatever device
    # it already lives on and only the small (hidden_dim,) result is moved, so
    # both CPU-resident and accelerator-resident activations are accepted.
    # reshape() (rather than view()) also tolerates non-contiguous tensors.
    for x in X_list:
        x_flat = x.reshape(-1, hidden_dim).to(dtype=torch.float64)
        X_sum_local += x_flat.sum(dim=0).to(device)
        total_tokens_local += x_flat.shape[0]
    for y in Y_list:
        y_flat = y.reshape(-1, hidden_dim).to(dtype=torch.float64)
        Y_sum_local += y_flat.sum(dim=0).to(device)

    X_sum_global = accelerator.reduce(X_sum_local, reduction="sum")
    Y_sum_global = accelerator.reduce(Y_sum_local, reduction="sum")
    total_tokens_tensor = torch.tensor(total_tokens_local, device=device, dtype=torch.float64)
    total_tokens_global = accelerator.reduce(total_tokens_tensor, reduction="sum").item()
    if total_tokens_global <= 1:
        raise RuntimeError("Not enough calibration tokens to compute covariance matrices.")

    # Keep the means in float64 for centering; they are only cast to float32
    # for the returned values.
    X_mean64 = (X_sum_global / total_tokens_global).to(device=device)
    Y_mean64 = (Y_sum_global / total_tokens_global).to(device=device)

    Cxx_local = torch.zeros((hidden_dim, hidden_dim), device=device, dtype=torch.float64)
    Cyy_local = torch.zeros_like(Cxx_local)
    Cxy_local = torch.zeros_like(Cxx_local)

    # Second pass: centered second-order moments.
    for x, y in zip(X_list, Y_list):
        x_centered = x.reshape(-1, hidden_dim).to(device=device, dtype=torch.float64) - X_mean64
        y_centered = y.reshape(-1, hidden_dim).to(device=device, dtype=torch.float64) - Y_mean64
        Cxx_local += x_centered.T @ x_centered
        Cyy_local += y_centered.T @ y_centered
        Cxy_local += x_centered.T @ y_centered

    denom = float(total_tokens_global - 1)
    Cxx = (accelerator.reduce(Cxx_local, reduction="sum") / denom).to(torch.float32)
    Cyy = (accelerator.reduce(Cyy_local, reduction="sum") / denom).to(torch.float32)
    Cxy = (accelerator.reduce(Cxy_local, reduction="sum") / denom).to(torch.float32)

    X_mean = X_mean64.to(torch.float32)
    Y_mean = Y_mean64.to(torch.float32)
    return X_mean, Y_mean, Cxx, Cyy, Cxy
def compute_cca(
    X_list: List[torch.Tensor],
    Y_list: List[torch.Tensor],
    accelerator: Accelerator,
    regularization: float = _REGULARIZATION_EPS,
) -> torch.Tensor:
    """Compute canonical correlations following the NBL formulation.

    The covariance matrices of ``X_list`` and ``Y_list`` are regularised by
    adding ``regularization`` times the identity, whitened with their inverse
    square roots, and the singular values of
    ``Cyy^{-1/2} @ Cxy @ Cxx^{-1/2}`` are returned.

    Args:
        X_list: Input-side activations, one tensor per calibration sample.
        Y_list: Output-side activations paired with ``X_list``.
        accelerator: Provides the device and the cross-process reduction used
            when computing the covariance matrices.
        regularization: Diagonal term added to ``Cxx`` and ``Cyy``.

    Returns:
        1-D tensor of singular values clamped to ``[0, 1]``.
    """
    device = accelerator.device
    _, _, Cxx, Cyy, Cxy = _compute_cova_matrices_iterative_dist(X_list, Y_list, accelerator)

    eye_x = torch.eye(Cxx.size(0), device=device, dtype=Cxx.dtype)
    eye_y = torch.eye(Cyy.size(0), device=device, dtype=Cyy.dtype)
    Cxx_inv_sqrt = _matrix_inverse_sqrt(Cxx + regularization * eye_x)
    Cyy_inv_sqrt = _matrix_inverse_sqrt(Cyy + regularization * eye_y)

    corr_matrix = Cyy_inv_sqrt @ Cxy @ Cxx_inv_sqrt
    # Only the singular values are needed, so skip computing the singular
    # vectors; svdvals always returns a real-valued tensor.
    singular_values = torch.linalg.svdvals(corr_matrix)
    return torch.clamp(singular_values, min=0.0, max=1.0)
def _collect_layer_calibration(
    layer: torch.nn.Module,
    num_samples: int,
    inputs: List[torch.Tensor],
    attention_mask: Optional[List[Optional[torch.Tensor]]],
    position_ids: Optional[List[Optional[torch.Tensor]]],
    cache_position: Optional[List[Optional[torch.Tensor]]],
    model_type: Optional[str],
) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
    """Capture pre-layernorm inputs, normalized activations and attention outputs with lightweight forward hooks.

    Forward hooks are registered on ``layer.input_layernorm`` and
    ``layer.self_attn``; the layer is then run on the first ``num_samples``
    entries of ``inputs`` and the hooks are removed again, even if a forward
    pass raises.

    Args:
        layer: Transformer block exposing ``input_layernorm`` and ``self_attn``.
        num_samples: Number of calibration samples to run.
        inputs: Hidden states fed to the layer, one entry per sample. Each
            sample is cloned before the forward pass so the caller's tensors
            are not handed to the layer directly.
        attention_mask: Optional per-sample attention masks.
        position_ids: Optional per-sample position ids.
        cache_position: Optional per-sample cache positions.
        model_type: Model type string passed through to the layer call.

    Returns:
        ``(residual_inputs, norm_inputs, attn_outputs)`` as stored by the
        recording wrappers: the inputs and outputs recorded for the layernorm
        and the outputs recorded for the attention module.
        NOTE(review): presumably one entry per forward call; confirm against
        ``HiddenStatesRecordWrapper``.
    """
    module_pre_norm = layer.input_layernorm
    module_attn = layer.self_attn
    wrapped_pre_norm = HiddenStatesRecordWrapper(module_pre_norm, record_input=True, record_output=True)
    wrapped_attn = HiddenStatesRecordWrapper(module_attn, record_input=False, record_output=True)

    def pre_norm_hook(_, hook_inputs, output):
        inp = hook_inputs[0] if isinstance(hook_inputs, tuple) else hook_inputs
        out = output[0] if isinstance(output, (tuple, list)) else output
        wrapped_pre_norm.record(inp.detach(), out.detach())

    def attn_hook(_, __, output):
        attn_out = output[0] if isinstance(output, (tuple, list)) else output
        wrapped_attn.record(None, attn_out.detach())

    handles = []
    try:
        handles.append(module_pre_norm.register_forward_hook(pre_norm_hook))
        handles.append(module_attn.register_forward_hook(attn_hook))
        for j in range(num_samples):
            # Clone one sample at a time instead of duplicating the whole
            # calibration set up front; this keeps peak memory lower.
            _call_layer_forward(
                layer,
                inputs[j].clone(),
                _maybe_get(attention_mask, j),
                _maybe_get(position_ids, j),
                _maybe_get(cache_position, j),
                model_type,
            )
    finally:
        # Always detach the hooks so a failed forward pass cannot leave them
        # attached to the model.
        for handle in handles:
            handle.remove()

    residual_inputs = wrapped_pre_norm.input_hidden_states
    norm_inputs = wrapped_pre_norm.output_hidden_states
    attn_outputs = wrapped_attn.output_hidden_states
    return residual_inputs, norm_inputs, attn_outputs
def _advance_layer_states(
    layer: torch.nn.Module,
    inputs: List[torch.Tensor],
    outputs: List[Optional[torch.Tensor]],
    attention_mask: Optional[List[Optional[torch.Tensor]]],
    position_ids: Optional[List[Optional[torch.Tensor]]],
    cache_position: Optional[List[Optional[torch.Tensor]]],
    model_type: Optional[str],
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Propagate calibration activations to the next transformer block in place.

    Every entry of ``inputs`` is run through ``layer`` and the result is
    written into ``outputs`` at the same index. The two lists are then
    returned in swapped order, so the freshly filled ``outputs`` comes first
    and the old ``inputs`` list comes second.
    """
    for sample_idx, hidden_state in enumerate(inputs):
        outputs[sample_idx] = _call_layer_forward(
            layer,
            hidden_state,
            _maybe_get(attention_mask, sample_idx),
            _maybe_get(position_ids, sample_idx),
            _maybe_get(cache_position, sample_idx),
            model_type,
        )
    next_inputs, next_outputs = outputs, inputs
    return next_inputs, next_outputs
@no_grad()
def get_nbl_metrics(
    model,
    dataloader: DataLoader,
    accelerator: Accelerator,
    num_samples: int,
    cache_file: Optional[str] = None,
):
    """Compute (or load from cache) one NBL score per transformer layer.

    If ``cache_file`` exists, its contents are loaded onto the accelerator
    device and returned as-is. Otherwise calibration inputs are prepared, and
    for each layer the canonical correlations between the normalized inputs
    and ``attention output + normalized input`` are computed; the layer score
    is ``sum(1 - correlation ** 2)``. The calibration activations are advanced
    through each layer before moving on to the next one.

    When ``cache_file`` is given, the main process saves the scores there
    (creating the parent directory if needed). All processes synchronise
    before returning.

    Returns:
        Tensor with one score per layer, located on the accelerator device.
    """
    device = accelerator.device

    cache_available = cache_file is not None and os.path.exists(cache_file)
    if cache_available:
        accelerator.print(f"Loading cached NBL metrics from {cache_file}")
        # NOTE(review): torch.load unpickles the file; only use cache files
        # from a trusted location.
        return torch.load(cache_file, map_location=device)

    accelerator.print(
        f"No cached NBL metrics found. Running model on {num_samples} samples for each device."
    )

    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_model.config.use_cache = False
    layers = unwrapped_model.model.layers
    model_type = getattr(unwrapped_model.config, "model_type", None)

    calibration = prepare_calibration_input(unwrapped_model, dataloader, num_samples)
    inputs, outputs, attention_mask, position_ids, cache_position = calibration

    num_layers = len(layers)
    nmse_scores = torch.full((num_layers,), math.inf, device=device)

    progress = tqdm(
        range(num_layers),
        desc="Calculating NBL metrics...",
        disable=not accelerator.is_main_process,
    )
    for idx in progress:
        layer_module = layers[idx]
        _, norm_list, attn_out_list = _collect_layer_calibration(
            layer_module,
            num_samples,
            inputs,
            attention_mask,
            position_ids,
            cache_position,
            model_type,
        )

        # The post-layernorm activations that feed the attention block serve
        # as the NBL "X", so the statistics match what the linearized layer
        # sees at inference.
        Y_plus_list = [
            attn_out + normed for normed, attn_out in zip(norm_list, attn_out_list)
        ]
        correlations = compute_cca(norm_list, Y_plus_list, accelerator)
        layer_score = torch.sum(1 - correlations.square())
        nmse_scores[idx] = layer_score
        accelerator.print(f"Layer {idx} NMSE: {nmse_scores[idx].item()}")

        inputs, outputs = _advance_layer_states(
            layer_module,
            inputs,
            outputs,
            attention_mask,
            position_ids,
            cache_position,
            model_type,
        )

    should_save = cache_file is not None and accelerator.is_main_process
    if should_save:
        cache_dir = os.path.dirname(cache_file)
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)
        torch.save(nmse_scores.clone().cpu(), cache_file)
        logger.info("Saving cached NBL metrics to %s", cache_file)

    accelerator.wait_for_everyone()
    return nmse_scores
def calculate_nbl_weights(
    X_list: List[torch.Tensor],
    Y_list: List[torch.Tensor],
    accelerator: Accelerator,
    regularization: float = _REGULARIZATION_EPS,
):
    """Solve the LMMSE system that maps normalized inputs to attention outputs.

    Computes ``W = Cyx @ pinv(Cxx + regularization * I)`` and
    ``b = Y_mean - W @ X_mean`` from the means and covariances of the
    calibration tensors.

    Args:
        X_list: Input activations, one tensor per calibration sample.
        Y_list: Target activations paired with ``X_list``.
        accelerator: Provides the device and the cross-process reduction used
            when computing the covariance matrices.
        regularization: Diagonal term added to ``Cxx`` before inversion.

    Returns:
        ``(W, b)`` on the CPU: a ``(hidden_dim, hidden_dim)`` weight matrix
        and a ``(hidden_dim,)`` bias vector.
    """
    device = accelerator.device
    X_mean, Y_mean, Cxx, _, Cxy = _compute_cova_matrices_iterative_dist(X_list, Y_list, accelerator)

    eye_x = torch.eye(Cxx.size(0), device=device, dtype=Cxx.dtype)
    Cxx_reg = Cxx + regularization * eye_x
    Cyx = Cxy.transpose(0, 1)

    X_mean = X_mean.to(device)
    Y_mean = Y_mean.to(device)

    # Cxx_reg is a regularised covariance matrix and therefore symmetric, so
    # the pseudo-inverse can use the cheaper eigendecomposition-based path.
    W = Cyx @ torch.linalg.pinv(Cxx_reg, hermitian=True)
    b = Y_mean - W @ X_mean
    return W.cpu(), b.cpu()
@no_grad()
def apply_nbl_linearization(
    model,
    dataloader: DataLoader,
    accelerator: Accelerator,
    num_samples: int,
    num_layers_to_linearize: int,
    nbl_metric_cache_file: Optional[str] = None,
):
    """Select the lowest-scoring layers and fit a linear replacement for each.

    Layer scores come from ``get_nbl_metrics`` (optionally cached in
    ``nbl_metric_cache_file``). The ``num_layers_to_linearize`` layers with the
    smallest scores are selected; calibration activations are then propagated
    through the model layer by layer, and for each selected layer a weight
    matrix and bias are fitted from its normalized inputs to its attention
    outputs.

    Args:
        model: Model (possibly wrapped by ``accelerator``) exposing
            ``model.layers`` and ``config``.
        dataloader: Source of calibration samples.
        accelerator: Used for unwrapping, printing and distributed reductions.
        num_samples: Number of calibration samples per process.
        num_layers_to_linearize: How many of the lowest-scoring layers to fit.
        nbl_metric_cache_file: Optional path for cached layer scores.

    Returns:
        Dict mapping each selected layer index to ``{"W": ..., "b": ...}``.
    """
    nmse_scores = get_nbl_metrics(
        model,
        dataloader,
        accelerator,
        num_samples,
        cache_file=nbl_metric_cache_file,
    )
    sorted_nmse, sorted_indices = torch.sort(nmse_scores, dim=0, descending=False)
    layers_to_linearize = sorted_indices[:num_layers_to_linearize].tolist()
    accelerator.print(
        f"Linearizing layers: {layers_to_linearize} with NMSE scores: {sorted_nmse[:num_layers_to_linearize].tolist()}"
    )

    # NOTE(review): get_nbl_metrics only sets config.use_cache = False when it
    # recomputes the metrics; when they are loaded from the cache file the flag
    # is left untouched here. Confirm that is intended.
    unwrapped_model = accelerator.unwrap_model(model)
    model_layers = unwrapped_model.model.layers
    model_type = getattr(unwrapped_model.config, "model_type", None)
    inputs, outputs, attention_mask, position_ids, cache_position = prepare_calibration_input(
        unwrapped_model, dataloader, num_samples
    )

    selected = set(layers_to_linearize)
    # Layers after the last selected one contribute nothing to the result, so
    # the propagation stops there instead of running the rest of the model.
    last_selected = max(selected, default=-1)
    num_layers_to_visit = min(last_selected + 1, len(model_layers))

    linearization_data = {}
    for idx in tqdm(
        range(num_layers_to_visit),
        desc="Calculating linearization weights...",
        disable=not accelerator.is_main_process,
    ):
        layer_module = model_layers[idx]
        if idx in selected:
            _, norm_list, Y_list = _collect_layer_calibration(
                layer_module,
                num_samples,
                inputs,
                attention_mask,
                position_ids,
                cache_position,
                model_type,
            )
            # Fit on the normalized inputs that are used at inference time
            W, b = calculate_nbl_weights(norm_list, Y_list, accelerator)
            linearization_data[idx] = {"W": W, "b": b}
            accelerator.print(f"Calculated weights for layer {idx}")
        if idx < last_selected:
            # Only advance the activations when a later layer still needs them.
            inputs, outputs = _advance_layer_states(
                layer_module,
                inputs,
                outputs,
                attention_mask,
                position_ids,
                cache_position,
                model_type,
            )
    return linearization_data