File size: 13,122 Bytes
d73500e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 |
import logging
import math
import os
from typing import List, Optional, Tuple
import torch
from torch import no_grad
from torch.utils.data import DataLoader
from accelerate import Accelerator
from tqdm import tqdm
from .utils import prepare_calibration_input
from .wrapper import HiddenStatesRecordWrapper
# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)
# Default ridge term that callers in this module add (times the identity) to
# covariance matrices before inverting them.
_REGULARIZATION_EPS = 1.0e-5
def _matrix_inverse_sqrt(matrix: torch.Tensor, epsilon: float = 1e-9) -> torch.Tensor:
    """Compute the inverse square root of a symmetric matrix via eigendecomposition.

    Negative eigenvalues (numerical noise for a PSD matrix) are clamped to
    zero, and ``epsilon`` is added to every square-rooted eigenvalue so that a
    zero eigenvalue does not cause a division by zero.

    Args:
        matrix: Symmetric matrix of shape ``(..., n, n)``. Leading batch
            dimensions are supported.
        epsilon: Offset added to ``sqrt(eigvals)`` before taking the reciprocal.

    Returns:
        ``matrix^{-1/2}`` with the same shape and dtype as ``matrix``.
    """
    # eigh needs at least single precision; promoting (instead of always
    # casting to float32) keeps float64 inputs at full precision.
    compute_dtype = torch.promote_types(matrix.dtype, torch.float32)
    eigvals, eigvecs = torch.linalg.eigh(matrix.to(compute_dtype))
    inv_sqrt = 1.0 / (torch.sqrt(torch.clamp(eigvals, min=0.0)) + epsilon)
    # V @ diag(s) @ V^T == (V * s) @ V^T: broadcasting scales the columns of V,
    # which saves one n^3 matmul and, unlike torch.diag, works for batches.
    inv_sqrt_mat = (eigvecs * inv_sqrt.unsqueeze(-2)) @ eigvecs.transpose(-2, -1)
    return inv_sqrt_mat.to(matrix.dtype)
def _maybe_get(sequence: Optional[List[Optional[torch.Tensor]]], idx: int) -> Optional[torch.Tensor]:
    """Return ``sequence[idx]``, or ``None`` when no sequence was provided."""
    return None if sequence is None else sequence[idx]
def _call_layer_forward(
    layer: torch.nn.Module,
    hidden_state: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    position_ids: Optional[torch.Tensor],
    cache_position: Optional[torch.Tensor],
    model_type: Optional[str],
) -> torch.Tensor:
    """Run a single transformer block on calibration activations.

    ``attention_mask`` and ``position_ids`` are forwarded as keyword arguments
    only when they are not ``None``. ``cache_position`` is forwarded only when
    it is not ``None`` and ``model_type`` is ``"llama"`` or ``"mistral"``.
    When the block returns a tuple or list, only its first element is returned.
    """
    forwarded = {
        name: value
        for name, value in (("attention_mask", attention_mask), ("position_ids", position_ids))
        if value is not None
    }
    if cache_position is not None and model_type in {"llama", "mistral"}:
        forwarded["cache_position"] = cache_position
    result = layer(hidden_state, **forwarded)
    if isinstance(result, (tuple, list)):
        return result[0]
    return result
def _compute_cova_matrices_iterative_dist(
    X_list: List[torch.Tensor],
    Y_list: List[torch.Tensor],
    accelerator: Accelerator,
):
    """Compute first and second-order moments in a distributed-friendly way.

    Every tensor is flattened to ``(-1, last_dim)``, so each row over the
    leading dimensions counts as one token. Sums are accumulated locally in
    float64, added up across processes with ``accelerator.reduce``, and then
    normalised by the global token count. The covariances are unbiased (they
    divide by ``N - 1``) and are computed with a numerically stable two-pass
    scheme.

    Args:
        X_list: Activations whose tensors share the last dimension ``dx``.
        Y_list: Activations paired element-wise with ``X_list``, sharing the
            last dimension ``dy``.
        accelerator: Supplies the compute device and the cross-process reduction.

    Returns:
        ``(X_mean, Y_mean, Cxx, Cyy, Cxy)`` as float32 tensors on
        ``accelerator.device`` with shapes ``(dx,)``, ``(dy,)``, ``(dx, dx)``,
        ``(dy, dy)`` and ``(dx, dy)``. ``Cxy[i, j]`` pairs x-feature ``i`` with
        y-feature ``j``.

    Raises:
        ValueError: If the lists are empty, differ in length, or a pair has a
            different number of tokens.
        RuntimeError: If fewer than two tokens exist across all processes.
    """
    if len(X_list) != len(Y_list):
        raise ValueError(
            f"X_list and Y_list must have the same length, got {len(X_list)} and {len(Y_list)}."
        )
    if not X_list:
        raise ValueError("Cannot compute covariance matrices from empty activation lists.")

    device = accelerator.device
    x_dim = X_list[0].shape[-1]
    y_dim = Y_list[0].shape[-1]

    # Pass 1: token counts and sums. Each sum is reduced where the data lives
    # and only the small (dim,) result is moved, so activations may sit on CPU
    # or on the accelerator device.
    X_sum_local = torch.zeros(x_dim, device=device, dtype=torch.float64)
    Y_sum_local = torch.zeros(y_dim, device=device, dtype=torch.float64)
    total_tokens_local = 0
    for x, y in zip(X_list, Y_list):
        # reshape (rather than view) also accepts non-contiguous tensors.
        x_flat = x.reshape(-1, x_dim)
        y_flat = y.reshape(-1, y_dim)
        if x_flat.shape[0] != y_flat.shape[0]:
            raise ValueError(
                f"Paired activations have different token counts: {x_flat.shape[0]} vs {y_flat.shape[0]}."
            )
        X_sum_local += x_flat.to(torch.float64).sum(dim=0).to(device)
        Y_sum_local += y_flat.to(torch.float64).sum(dim=0).to(device)
        total_tokens_local += x_flat.shape[0]

    X_sum_global = accelerator.reduce(X_sum_local, reduction="sum")
    Y_sum_global = accelerator.reduce(Y_sum_local, reduction="sum")
    total_tokens_tensor = torch.tensor(total_tokens_local, device=device, dtype=torch.float64)
    total_tokens_global = accelerator.reduce(total_tokens_tensor, reduction="sum").item()
    if total_tokens_global <= 1:
        raise RuntimeError("Not enough calibration tokens to compute covariance matrices.")

    # Keep the means in float64 for centering. Rounding them to float32 first
    # would shift every centered token by the rounding error and bias the
    # covariances.
    X_mean64 = X_sum_global / total_tokens_global
    Y_mean64 = Y_sum_global / total_tokens_global

    # Pass 2: centered second moments.
    Cxx_local = torch.zeros((x_dim, x_dim), device=device, dtype=torch.float64)
    Cyy_local = torch.zeros((y_dim, y_dim), device=device, dtype=torch.float64)
    Cxy_local = torch.zeros((x_dim, y_dim), device=device, dtype=torch.float64)
    for x, y in zip(X_list, Y_list):
        x_centered = x.reshape(-1, x_dim).to(device=device, dtype=torch.float64) - X_mean64
        y_centered = y.reshape(-1, y_dim).to(device=device, dtype=torch.float64) - Y_mean64
        Cxx_local += x_centered.T @ x_centered
        Cyy_local += y_centered.T @ y_centered
        Cxy_local += x_centered.T @ y_centered

    denom = float(total_tokens_global - 1)
    Cxx = (accelerator.reduce(Cxx_local, reduction="sum") / denom).to(torch.float32)
    Cyy = (accelerator.reduce(Cyy_local, reduction="sum") / denom).to(torch.float32)
    Cxy = (accelerator.reduce(Cxy_local, reduction="sum") / denom).to(torch.float32)
    return X_mean64.to(torch.float32), Y_mean64.to(torch.float32), Cxx, Cyy, Cxy
def compute_cca(
    X_list: List[torch.Tensor],
    Y_list: List[torch.Tensor],
    accelerator: Accelerator,
    regularization: float = _REGULARIZATION_EPS,
) -> torch.Tensor:
    """Compute canonical correlations following the NBL formulation.

    The canonical correlations are the singular values of the whitened
    cross-covariance ``Cxx^{-1/2} @ Cxy @ Cyy^{-1/2}``. ``Cxx`` and ``Cyy``
    each get a ridge term ``regularization * I`` before whitening.

    Args:
        X_list: First set of activations (paired element-wise with ``Y_list``).
        Y_list: Second set of activations.
        accelerator: Supplies the device and the cross-process reduction.
        regularization: Ridge added to the covariance diagonals.

    Returns:
        1-D tensor of canonical correlations, clamped to ``[0, 1]`` and in
        descending order.
    """
    device = accelerator.device
    _, _, Cxx, Cyy, Cxy = _compute_cova_matrices_iterative_dist(X_list, Y_list, accelerator)
    Cxx_reg = Cxx + regularization * torch.eye(Cxx.size(0), device=device, dtype=Cxx.dtype)
    Cyy_reg = Cyy + regularization * torch.eye(Cyy.size(0), device=device, dtype=Cyy.dtype)
    Cxx_inv_sqrt = _matrix_inverse_sqrt(Cxx_reg)
    Cyy_inv_sqrt = _matrix_inverse_sqrt(Cyy_reg)
    # Cxy is built as x_centered.T @ y_centered: rows index x-features and
    # columns index y-features. The x-whitening must therefore multiply from
    # the left and the y-whitening from the right. Swapping the two gives wrong
    # correlations unless Cxx and Cyy commute.
    corr_matrix = Cxx_inv_sqrt @ Cxy @ Cyy_inv_sqrt
    # Only the singular values are needed; svdvals skips computing U and V.
    singular_values = torch.linalg.svdvals(corr_matrix)
    return torch.clamp(singular_values, min=0.0, max=1.0)
def _collect_layer_calibration(
    layer: torch.nn.Module,
    num_samples: int,
    inputs: List[torch.Tensor],
    attention_mask: Optional[List[Optional[torch.Tensor]]],
    position_ids: Optional[List[Optional[torch.Tensor]]],
    cache_position: Optional[List[Optional[torch.Tensor]]],
    model_type: Optional[str],
) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
    """Capture pre-layernorm inputs, normalized activations and attention outputs with lightweight forward hooks.

    Temporary forward hooks on ``layer.input_layernorm`` and ``layer.self_attn``
    record detached tensors into ``HiddenStatesRecordWrapper`` instances while
    ``layer`` runs on clones of the first ``num_samples`` entries of ``inputs``.
    The hooks are removed even if a forward pass raises.

    Returns:
        ``(residual_inputs, norm_inputs, attn_outputs)``: the layernorm's
        inputs, the layernorm's outputs and the attention outputs, taken from
        the wrappers' ``input_hidden_states`` / ``output_hidden_states``
        (presumably one entry per recorded call; depends on
        HiddenStatesRecordWrapper, so confirm there).
    """
    module_pre_norm = layer.input_layernorm
    module_attn = layer.self_attn
    wrapped_pre_norm = HiddenStatesRecordWrapper(module_pre_norm, record_input=True, record_output=True)
    wrapped_attn = HiddenStatesRecordWrapper(module_attn, record_input=False, record_output=True)

    def pre_norm_hook(_, hook_inputs, output):
        # Record the layernorm's first positional input and its (first) output.
        inp = hook_inputs[0] if isinstance(hook_inputs, tuple) else hook_inputs
        out = output[0] if isinstance(output, (tuple, list)) else output
        wrapped_pre_norm.record(inp.detach(), out.detach())

    def attn_hook(_, __, output):
        # If attention returns a tuple/list, only its first element is recorded.
        attn_out = output[0] if isinstance(output, (tuple, list)) else output
        wrapped_attn.record(None, attn_out.detach())

    # Run on copies, presumably so that in-place changes inside the block
    # cannot alter the caller's ``inputs``.
    working_inputs = [inp.clone() for inp in inputs]
    handles = [
        module_pre_norm.register_forward_hook(pre_norm_hook),
        module_attn.register_forward_hook(attn_hook),
    ]
    try:
        for j in range(num_samples):
            _call_layer_forward(
                layer,
                working_inputs[j],
                _maybe_get(attention_mask, j),
                _maybe_get(position_ids, j),
                _maybe_get(cache_position, j),
                model_type,
            )
    finally:
        # A leaked hook would keep recording on every later forward pass
        # through this layer, so always detach them.
        for handle in handles:
            handle.remove()

    residual_inputs = wrapped_pre_norm.input_hidden_states
    norm_inputs = wrapped_pre_norm.output_hidden_states
    attn_outputs = wrapped_attn.output_hidden_states
    return residual_inputs, norm_inputs, attn_outputs
def _advance_layer_states(
    layer: torch.nn.Module,
    inputs: List[torch.Tensor],
    outputs: List[Optional[torch.Tensor]],
    attention_mask: Optional[List[Optional[torch.Tensor]]],
    position_ids: Optional[List[Optional[torch.Tensor]]],
    cache_position: Optional[List[Optional[torch.Tensor]]],
    model_type: Optional[str],
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Propagate calibration activations to the next transformer block in place.

    ``outputs[k]`` is overwritten with ``layer``'s output for ``inputs[k]``,
    which mutates the ``outputs`` list in place. The two lists come back
    swapped as ``(outputs, inputs)``. Callers that write
    ``inputs, outputs = ...`` therefore feed this layer's outputs to the next
    layer and reuse the old inputs list as the next output buffer.
    """
    for sample_idx, hidden_state in enumerate(inputs):
        outputs[sample_idx] = _call_layer_forward(
            layer,
            hidden_state,
            _maybe_get(attention_mask, sample_idx),
            _maybe_get(position_ids, sample_idx),
            _maybe_get(cache_position, sample_idx),
            model_type,
        )
    return outputs, inputs
@no_grad()
def get_nbl_metrics(
    model,
    dataloader: DataLoader,
    accelerator: Accelerator,
    num_samples: int,
    cache_file: Optional[str] = None,
):
    """Score every decoder layer by how linearly its attention block can be approximated.

    For each layer, the layernorm outputs ``X`` and the sum ``Y + X``
    (attention output plus ``X``) are collected on ``num_samples`` calibration
    samples per process. Their canonical correlations ``rho`` give the layer
    score ``sum(1 - rho**2)``; lower scores are selected first for
    linearization. Activations are then propagated through the full layer to
    feed the next one.

    If ``cache_file`` exists, its contents are loaded and returned instead.
    Otherwise the main process saves the freshly computed scores there.

    Side effect: on the compute path this sets
    ``unwrapped_model.config.use_cache = False`` and does not restore it.

    Args:
        model: Model wrapped by ``accelerator``. It must expose ``model.layers``.
        dataloader: Source of calibration batches.
        accelerator: Supplies the device, the process info and the reductions.
        num_samples: Number of calibration samples per process.
        cache_file: Optional path used to load/save the scores.

    Returns:
        1-D tensor with one score per layer on ``accelerator.device``.
    """
    device = accelerator.device
    if cache_file is not None and os.path.exists(cache_file):
        accelerator.print(f"Loading cached NBL metrics from {cache_file}")
        return torch.load(cache_file, map_location=device)

    accelerator.print(
        f"No cached NBL metrics found. Running model on {num_samples} samples for each device."
    )
    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_model.config.use_cache = False
    layers = unwrapped_model.model.layers
    model_type = getattr(unwrapped_model.config, "model_type", None)
    inputs, outputs, attention_mask, position_ids, cache_position = prepare_calibration_input(
        unwrapped_model, dataloader, num_samples
    )

    nmse_scores = torch.full((len(layers),), math.inf, device=device)
    for idx in tqdm(range(len(layers)), desc="Calculating NBL metrics...", disable=not accelerator.is_main_process):
        layer_module = layers[idx]
        residual_list, norm_list, Y_list_raw = _collect_layer_calibration(
            layer_module,
            num_samples,
            inputs,
            attention_mask,
            position_ids,
            cache_position,
            model_type,
        )
        # Use the post-layernorm activations that actually feed the attention
        # block as the NBL "X". This aligns the statistics with what the
        # linearized layer will see at inference.
        Y_plus_list = [y + x for x, y in zip(norm_list, Y_list_raw)]
        correlations = compute_cca(norm_list, Y_plus_list, accelerator)
        nmse_scores[idx] = torch.sum(1 - correlations.square())
        accelerator.print(f"Layer {idx} NMSE: {nmse_scores[idx].item()}")
        # Drop this layer's recorded activations before running the layer
        # again, so they are not held alongside the next layer's tensors.
        del residual_list, norm_list, Y_list_raw, Y_plus_list
        inputs, outputs = _advance_layer_states(
            layer_module,
            inputs,
            outputs,
            attention_mask,
            position_ids,
            cache_position,
            model_type,
        )

    if cache_file is not None and accelerator.is_main_process:
        cache_dir = os.path.dirname(cache_file)
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)
        logger.info("Saving cached NBL metrics to %s", cache_file)
        torch.save(nmse_scores.cpu(), cache_file)
    # Every process must reach this barrier. Placing it inside the
    # main-process-only branch would deadlock multi-process runs.
    accelerator.wait_for_everyone()
    return nmse_scores
def calculate_nbl_weights(
    X_list: List[torch.Tensor],
    Y_list: List[torch.Tensor],
    accelerator: Accelerator,
    regularization: float = _REGULARIZATION_EPS,
):
    """Solve the LMMSE system that maps normalized inputs to attention outputs.

    Fits ``y ≈ W @ x + b`` with ``W = Cyx @ pinv(Cxx + regularization * I)``
    and ``b = E[y] - W @ E[x]``. The moments are aggregated across processes.

    Args:
        X_list: Input activations (paired element-wise with ``Y_list``).
        Y_list: Target activations.
        accelerator: Supplies the device and the cross-process reduction.
        regularization: Ridge added to the diagonal of ``Cxx``.

    Returns:
        ``(W, b)`` on CPU: ``W`` has shape ``(dy, dx)`` and ``b`` has shape ``(dy,)``.
    """
    device = accelerator.device
    X_mean, Y_mean, Cxx, _, Cxy = _compute_cova_matrices_iterative_dist(X_list, Y_list, accelerator)
    Cxx_reg = Cxx + regularization * torch.eye(Cxx.size(0), device=device, dtype=Cxx.dtype)
    Cyx = Cxy.transpose(0, 1)
    # The regularized covariance is symmetric, so the Hermitian pseudo-inverse
    # (eigendecomposition-based) applies and is cheaper than the general SVD path.
    W = Cyx @ torch.linalg.pinv(Cxx_reg, hermitian=True)
    b = Y_mean.to(device) - W @ X_mean.to(device)
    return W.cpu(), b.cpu()
@no_grad()
def apply_nbl_linearization(
    model,
    dataloader: DataLoader,
    accelerator: Accelerator,
    num_samples: int,
    num_layers_to_linearize: int,
    nbl_metric_cache_file: Optional[str] = None,
):
    """Fit linear replacements for the attention blocks of the lowest-scoring layers.

    Layers are ranked by ``get_nbl_metrics`` in ascending order, and the first
    ``num_layers_to_linearize`` are selected. Calibration activations are
    propagated layer by layer. For each selected layer, an LMMSE map from the
    layernorm output to the attention output is fitted with
    ``calculate_nbl_weights``. Propagation stops once the last selected layer
    has been processed.

    Args:
        model: Model wrapped by ``accelerator``. It must expose ``model.layers``.
        dataloader: Source of calibration batches.
        accelerator: Supplies the device, the process info and the reductions.
        num_samples: Number of calibration samples per process.
        num_layers_to_linearize: How many layers to linearize (``>= 0``).
        nbl_metric_cache_file: Optional cache path forwarded to ``get_nbl_metrics``.

    Returns:
        Dict mapping a layer index to ``{"W": W, "b": b}``, in layer order.

    Raises:
        ValueError: If ``num_layers_to_linearize`` is negative. A negative
            slice bound would otherwise silently select all but that many layers.
    """
    if num_layers_to_linearize < 0:
        raise ValueError(
            f"num_layers_to_linearize must be non-negative, got {num_layers_to_linearize}."
        )
    nmse_scores = get_nbl_metrics(
        model,
        dataloader,
        accelerator,
        num_samples,
        cache_file=nbl_metric_cache_file,
    )
    sorted_nmse, sorted_indices = torch.sort(nmse_scores, dim=0, descending=False)
    layers_to_linearize = sorted_indices[:num_layers_to_linearize].tolist()
    accelerator.print(
        f"Linearizing layers: {layers_to_linearize} with NMSE scores: {sorted_nmse[:num_layers_to_linearize].tolist()}"
    )

    linearization_data = {}
    if not layers_to_linearize:
        return linearization_data
    target_layers = set(layers_to_linearize)  # O(1) membership in the loop

    unwrapped_model = accelerator.unwrap_model(model)
    model_layers = unwrapped_model.model.layers
    model_type = getattr(unwrapped_model.config, "model_type", None)
    inputs, outputs, attention_mask, position_ids, cache_position = prepare_calibration_input(
        unwrapped_model, dataloader, num_samples
    )

    # Layers after the last selected one need no activations, so stop there.
    # The count is capped at len(model_layers), like the full loop it replaces.
    num_layers_to_visit = min(max(target_layers) + 1, len(model_layers))
    for idx in tqdm(range(num_layers_to_visit), desc="Calculating linearization weights...", disable=not accelerator.is_main_process):
        layer_module = model_layers[idx]
        if idx in target_layers:
            residual_list, norm_list, Y_list = _collect_layer_calibration(
                layer_module,
                num_samples,
                inputs,
                attention_mask,
                position_ids,
                cache_position,
                model_type,
            )
            # Fit on the normalized inputs that are used at inference time
            W, b = calculate_nbl_weights(norm_list, Y_list, accelerator)
            linearization_data[idx] = {"W": W, "b": b}
            accelerator.print(f"Calculated weights for layer {idx}")
            # Release recorded activations before propagating further.
            del residual_list, norm_list, Y_list
        if idx + 1 < num_layers_to_visit:
            inputs, outputs = _advance_layer_states(
                layer_module,
                inputs,
                outputs,
                attention_mask,
                position_ids,
                cache_position,
                model_type,
            )
    return linearization_data
|