"""
Sovereign Hive HSAQ: KV-cache sensitivity profiler
====================================================

Sweeps each attention layer across a curated 11-config grid of
(k_bits, v_bits, quantizer) combinations and measures the resulting
attention-output drift on calibration data. Emits ProfileRow objects
ready for insertion into the `kv_sensitivity_profile` Vault table
(migration 003).

Why this file exists:
    The KV cache is a bigger lever than weight quantization for MHA
    models on 12 GB cards (e.g. OLMo: ~3.3 GB KV at 4K ctx fp16, more than
    the weights). The allocator needs measured drift values to pick which
    layers can tolerate aggressive KV quant. This profiler is what feeds it.
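
The KV figure quoted above can be sanity-checked with back-of-envelope
arithmetic. The sketch below assumes a 40-layer MHA model with 40 KV heads of
dimension 128 at batch size 1; those shape values are assumptions for
illustration (this file does not state them), and `kv_cache_bytes` is a
hypothetical helper, not part of this module.

```python
def kv_cache_bytes(n_layers, num_kv_heads, head_dim, seq_len, bytes_per_elem=2):
    # K and V each hold num_kv_heads * head_dim elements per token per layer,
    # hence the leading factor of 2.
    return 2 * n_layers * num_kv_heads * head_dim * seq_len * bytes_per_elem


# Assumed shape (not stated in this file): 40 layers, 40 KV heads, head_dim 128.
total = kv_cache_bytes(n_layers=40, num_kv_heads=40, head_dim=128, seq_len=4096)
print(f"{total / 1e9:.2f} GB")  # 3.36 GB
```

Under these assumptions the fp16 cache lands close to the ~3.3 GB quoted
above; a different shape or batch size would shift the number accordingly.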

Pure-logic conventions (matching candidate_record.py):
    - No Vault writes. Caller persists the emitted ProfileRows.
    - No PermissionGate routing here. Caller (the hunter agent) handles that.
    - All disk/network I/O is the caller's job. We accept a loaded model
      and tokenizer; we return data structures.

What this DOES touch:
    - GPU forward passes (calls `model(**batch)` under torch.no_grad).
    - kv_intercept.kv_quant_active context manager to install hooks.
    - Hook lifecycle (always cleaned up via context manager; verified in
      smoke_test_v3.py test #6).

What this does NOT do:
    - It does not handle inference-queue gating. If two profilers run
      concurrently on the same GPU, you'll OOM. Caller must serialize.
    - It does not write to disk. Output is in-memory ProfileRow objects.
    - It does not bump pipeline_version. The constant lives here as a
      reference value; callers wanting a different lineage pass their own.
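
Since serialization is left to the caller, one possible shape for it is a
lock held around each profiling run. This is only a sketch: `_GPU_LOCK` and
`run_serialized` are hypothetical names, and a `threading.Lock` only guards
threads inside one process. Profilers launched from separate processes would
need a cross-process mechanism (a file lock or a job queue) instead.

```python
import threading

# Hypothetical process-wide lock. A host with several GPUs would want one
# lock per device rather than a single shared one.
_GPU_LOCK = threading.Lock()


def run_serialized(profile_fn, *args, **kwargs):
    # Hold the lock for the whole run so forward passes from two profilers
    # in this process never overlap on the device.
    with _GPU_LOCK:
        return profile_fn(*args, **kwargs)
```

A caller would pass `profile_kv_sensitivity` as `profile_fn`, followed by the
same arguments it would otherwise pass directly.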

Config sweep rationale (11 per layer):
    Full Cartesian (5 k_bits × 5 v_bits × 4 quantizers) = 100/layer is
    overkill: most combinations carry no signal the allocator can use.
    This curated 11 covers the decision boundaries that actually matter:

    Symmetric K=V (the common case):     (8,8), (4,4), (3,3), (2,2)  hqq_g64
    K-cheaper-than-V (K more tolerant):  (4,8), (3,8), (2,8), (2,4)  hqq_g64
    Quantizer comparison:                (4,4) scaled_uniform
                                         (4,4) scaled_per_head
                                         (8,8) scaled_uniform
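
To make the size difference concrete, the sketch below counts the full grid
against the curated list. The four quantizer names come from the KVQuantizer
type in this module; the five bit widths are an assumption (the text above
gives only the count), and the 40-layer figure is likewise illustrative.

```python
from itertools import product

# Assumed bit widths; only the count (5) is stated in the rationale above.
BITS = (16, 8, 4, 3, 2)
QUANTIZERS = ("hqq_g64", "scaled_uniform", "scaled_per_head", "fp16_passthrough")
full_grid = list(product(BITS, BITS, QUANTIZERS))

# The curated sweep, copied from the table above as (k_bits, v_bits, quantizer).
curated = [
    (8, 8, "hqq_g64"), (4, 4, "hqq_g64"), (3, 3, "hqq_g64"), (2, 2, "hqq_g64"),
    (4, 8, "hqq_g64"), (3, 8, "hqq_g64"), (2, 8, "hqq_g64"), (2, 4, "hqq_g64"),
    (4, 4, "scaled_uniform"), (4, 4, "scaled_per_head"), (8, 8, "scaled_uniform"),
]

n_layers = 40  # assumed layer count, for illustration only
print(len(full_grid), len(curated))                        # 100 11
print(len(full_grid) * n_layers, len(curated) * n_layers)  # 4000 440
```

The second line is the number of (layer, config) measurements each option
would produce on a model of that depth.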

Strategy: "joint" (default) vs "isolated"
    "joint" hooks every layer at the same config in ONE forward pass via
    kv_quant_active_multi. Total cost: 11 forwards. Drift values reflect
    CUMULATIVE CASCADE: each layer's attention sees the quanted output of
    earlier layers' attention. This is realistic to deployment.

    "isolated" hooks ONE layer at a time, runs a forward per (layer, config)
    pair. Total cost: 11 × n_layers forwards. Drift values reflect MARGINAL
    per-layer sensitivity: what happens when only that layer is quanted.

    Empirically (A100, OLMo-2-13B, 32 prompts, 1 GB KV budget), see
    mxguru1/hsaq-strategy-comparison/report.json:
    - joint took 64.3 s, isolated took 665.6 s (10.4× slower)
    - isolated drift values are uniformly LOWER (ratios 0.12-0.61 by config)
      because the joint pass compounds drift through the residual stream
    - BUT both strategies produced IDENTICAL per-layer allocations (40/40
      agreement) when piped through assign_kv_bits: the greedy allocator
      only cares about pairwise drift-per-byte ratios, and the ordering of
      configs by drift is preserved across strategies

    Default is "joint" because: same allocation, 10× faster, drift values
    are cascade-realistic. Switch to "isolated" only when you want the
    marginal per-layer signal for analysis (e.g. ranking pruning candidates,
    understanding which layers cause the most ripple downstream).
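
The two cost formulas and the "same ordering" observation can be sketched as
below. Both helpers are hypothetical and not part of this module, and the
drift numbers are made up purely to show the check; they are not taken from
the report. `same_ordering` assumes both mappings share the same keys.

```python
def sweep_forward_passes(n_configs, n_layers, strategy):
    # Counts sweep passes only; the fp baseline capture adds one more.
    if strategy == "joint":
        return n_configs
    if strategy == "isolated":
        return n_configs * n_layers
    raise ValueError(f"unknown strategy: {strategy!r}")


def same_ordering(drift_a, drift_b):
    # True when both mappings rank the same config keys identically by drift.
    keys = sorted(drift_a)
    return sorted(keys, key=drift_a.get) == sorted(keys, key=drift_b.get)


print(sweep_forward_passes(11, 40, "joint"))     # 11
print(sweep_forward_passes(11, 40, "isolated"))  # 440

# Made-up drift values for illustration only (not measurements).
joint = {"8/8": 1e-4, "4/4": 2e-3, "2/2": 5e-2}
isolated = {"8/8": 6e-5, "4/4": 6e-4, "2/2": 6e-3}
print(same_ordering(joint, isolated))  # True
```

A rank check like this is a weaker condition than identical allocations, so
it is best read as a quick diagnostic rather than a substitute for running
both strategies through assign_kv_bits.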
"""
from __future__ import annotations
import hashlib
import json
import logging
import time
from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import Callable, Iterable, Literal, Optional
logger = logging.getLogger("HSAQ.KVProfiler")
# Bump when changing the config sweep, drift metric implementation, or
# hook semantics. Cached rows under earlier versions become lookup-misses
# so they're safely ignored rather than silently reused.
PIPELINE_VERSION = "1.0.0"
# ---------------------------------------------------------------------------
# Types
# ---------------------------------------------------------------------------
DriftMetric = Literal["mse_normalised", "kl_softmax"]
KVQuantizer = Literal["hqq_g64", "scaled_uniform", "scaled_per_head", "fp16_passthrough"]


@dataclass(frozen=True)
class SweepConfig:
    """One (k_bits, v_bits, quantizer) probe to measure on every layer."""
    k_bits: int
    v_bits: int
    quantizer: KVQuantizer
    group_size: int = 64


@dataclass
class ProfileRow:
    """One row of measured KV sensitivity. Shape matches the Vault table
    kv_sensitivity_profile (migration 003) exactly, so the caller can insert
    the dict from to_vault_payload() directly or pass the row through a
    Vault adapter."""
    model_hash: str
    calibration_hash: str
    pipeline_version: str
    layer_idx: int
    k_bits: int
    v_bits: int
    quantizer: str
    drift_attn_output: float
    drift_metric: str
    bytes_per_kv_token: float
    max_seq_len_observed: int
    num_kv_heads: int
    head_dim: int
    profiled_at: str
    profiled_by_agent_id: str
    profiled_by_agent_tier: int

    def to_vault_payload(self) -> dict:
        """Row-shaped dict ready for INSERT into kv_sensitivity_profile."""
        return {
            "model_hash": self.model_hash,
            "calibration_hash": self.calibration_hash,
            "pipeline_version": self.pipeline_version,
            "layer_idx": self.layer_idx,
            "k_bits": self.k_bits,
            "v_bits": self.v_bits,
            "quantizer": self.quantizer,
            "drift_attn_output": self.drift_attn_output,
            "drift_metric": self.drift_metric,
            "bytes_per_kv_token": self.bytes_per_kv_token,
            "max_seq_len_observed": self.max_seq_len_observed,
            "num_kv_heads": self.num_kv_heads,
            "head_dim": self.head_dim,
            "profiled_at": self.profiled_at,
            "profiled_by_agent_id": self.profiled_by_agent_id,
            "profiled_by_agent_tier": self.profiled_by_agent_tier,
        }


# ---------------------------------------------------------------------------
# Default sweep: the curated 11 configs
# ---------------------------------------------------------------------------
DEFAULT_SWEEP: tuple[SweepConfig, ...] = (
    # -- Symmetric K=V (most common case, 4 configs) --
    SweepConfig(k_bits=8, v_bits=8, quantizer="hqq_g64"),
    SweepConfig(k_bits=4, v_bits=4, quantizer="hqq_g64"),
    SweepConfig(k_bits=3, v_bits=3, quantizer="hqq_g64"),
    SweepConfig(k_bits=2, v_bits=2, quantizer="hqq_g64"),
    # -- K-cheaper-than-V (K is more error-tolerant, 4 configs) --
    SweepConfig(k_bits=4, v_bits=8, quantizer="hqq_g64"),
    SweepConfig(k_bits=3, v_bits=8, quantizer="hqq_g64"),
    SweepConfig(k_bits=2, v_bits=8, quantizer="hqq_g64"),
    SweepConfig(k_bits=2, v_bits=4, quantizer="hqq_g64"),
    # -- Quantizer comparison (3 configs) --
    SweepConfig(k_bits=4, v_bits=4, quantizer="scaled_uniform"),
    SweepConfig(k_bits=4, v_bits=4, quantizer="scaled_per_head"),
    SweepConfig(k_bits=8, v_bits=8, quantizer="scaled_uniform"),
)
# ---------------------------------------------------------------------------
# Bytes-per-KV-token accounting
# ---------------------------------------------------------------------------
def kv_bytes_per_token(
    num_kv_heads: int,
    head_dim: int,
    k_bits: int,
    v_bits: int,
    quantizer: KVQuantizer,
    group_size: int = 64,
) -> float:
    """Bytes per token of KV cache for one layer at the given config.

    Includes quantizer-specific overhead (scales / zeros). The allocator
    reads this value cached in the Vault rather than recomputing, so the
    formula here is the single source of truth: if you change it, bump
    PIPELINE_VERSION.
    """
    elems = num_kv_heads * head_dim
    if quantizer == "fp16_passthrough":
        # K and V both at fp16 (2 bytes/elem)
        return 2 * elems * 2
    if quantizer == "scaled_uniform":
        # One fp16 scale (2 bytes) per row, for K and V each
        k_payload = elems * k_bits / 8
        v_payload = elems * v_bits / 8
        return (k_payload + 2) + (v_payload + 2)
    if quantizer == "scaled_per_head":
        # One fp16 scale per head, for K and V each
        k_payload = elems * k_bits / 8
        v_payload = elems * v_bits / 8
        return (k_payload + num_kv_heads * 2) + (v_payload + num_kv_heads * 2)
    if quantizer == "hqq_g64":
        # Groups along the head_dim axis (per row of the KV cache).
        # NOTE: this is the corrected accounting. An earlier stub grouped
        # along (num_kv_heads × head_dim), which underestimated overhead.
        groups_per_row = max(1, head_dim // group_size)
        # zero + scale per group, fp16 each = 4 bytes per group
        overhead_per_row = num_kv_heads * groups_per_row * 4
        k_payload = elems * k_bits / 8
        v_payload = elems * v_bits / 8
        return (k_payload + overhead_per_row) + (v_payload + overhead_per_row)
    raise ValueError(f"unknown quantizer: {quantizer}")
# ---------------------------------------------------------------------------
# Drift metrics
# ---------------------------------------------------------------------------
def compute_drift(quanted, baseline, metric: DriftMetric) -> float:
    """Compute attention-output drift between quanted and fp16 baseline.

    mse_normalised: ((quanted - baseline)^2).mean() / (baseline^2).mean()
        Squared error relative to the baseline's mean square, so it is
        scale-invariant. Recommended for allocator inputs. KIVI-style.
    kl_softmax: KL divergence treating attention outputs as logits over the
        last dim. PROVISIONAL: attention output isn't a distribution, so
        this is a proxy metric only. Kept for compatibility with the
        Vault schema; do NOT use as allocator input without validation.
    """
    import torch
    import torch.nn.functional as F

    if metric == "mse_normalised":
        mse = ((quanted - baseline) ** 2).mean().item()
        denom = (baseline ** 2).mean().item()
        # Floor the denominator so an all-zero baseline cannot divide by zero.
        return mse / max(denom, 1e-8)
    if metric == "kl_softmax":
        # F.kl_div expects log-probabilities as input and probabilities as target.
        a = F.log_softmax(quanted.float().flatten(0, -2), dim=-1)
        b = F.softmax(baseline.float().flatten(0, -2), dim=-1)
        return float(F.kl_div(a, b, reduction="batchmean"))
    raise ValueError(f"unknown drift metric: {metric}")
# ---------------------------------------------------------------------------
# Calibration hash
# ---------------------------------------------------------------------------
def compute_calibration_hash(
    calibration_texts: list[str],
    max_seq_len: int,
) -> str:
    """Deterministic hash of calibration inputs for cache invalidation."""
    payload = json.dumps({
        "n_texts": len(calibration_texts),
        "max_seq_len": max_seq_len,
        # First/last text content matters for distinguishing calibration sets;
        # hash a sample rather than every text to keep this cheap.
        "first_text": calibration_texts[0][:200] if calibration_texts else "",
        "last_text": calibration_texts[-1][:200] if calibration_texts else "",
        "total_chars": sum(len(t) for t in calibration_texts),
    }, sort_keys=True)
    return hashlib.sha256(payload.encode()).hexdigest()[:16]
# ---------------------------------------------------------------------------
# Per-layer attention-output capture
# ---------------------------------------------------------------------------
def _capture_attn_outputs(
    model,
    attn_modules_by_layer: dict[int, object],
    batch,
) -> dict[int, "torch.Tensor"]:
    """Run one forward pass with hooks that capture attention output per layer.

    HF attention modules return either `attn_output` directly or a tuple
    starting with `attn_output`. We grab the first tensor element in either
    case.
    """
    import torch

    captured: dict[int, torch.Tensor] = {}
    handles: list = []

    def make_hook(li: int):
        def hook(_mod, _inp, out):
            t = out[0] if isinstance(out, tuple) else out
            if isinstance(t, torch.Tensor):
                # Move to CPU to bound GPU memory across all captured layers
                captured[li] = t.detach().to("cpu")
        return hook

    try:
        for li, attn in attn_modules_by_layer.items():
            handles.append(attn.register_forward_hook(make_hook(li)))
        with torch.no_grad():
            model(**batch, use_cache=False)
    finally:
        # Always remove hooks, even if the forward pass raised.
        for h in handles:
            h.remove()
    return captured
# ---------------------------------------------------------------------------
# Profiler
# ---------------------------------------------------------------------------
def profile_kv_sensitivity(
    model,
    tokenizer,
    calibration_texts: list[str],
    *,
    model_hash: str,
    profiled_by_agent_id: str,
    profiled_by_agent_tier: int,
    sweep: Iterable[SweepConfig] = DEFAULT_SWEEP,
    drift_metric: DriftMetric = "mse_normalised",
    max_seq_len: int = 1024,
    pipeline_version: str = PIPELINE_VERSION,
    strategy: Literal["joint", "isolated"] = "joint",
    progress_cb: Optional[Callable[[str], None]] = None,
) -> list[ProfileRow]:
    """Measure attention-output drift per (layer × sweep config).

    Args:
        model: HF model in eval mode on its target device.
        tokenizer: Matching HF tokenizer.
        calibration_texts: List of natural-language strings to feed through
            the model. The number of texts controls statistical robustness;
            how well they match the deployment distribution controls how
            well allocator decisions hold up on out-of-distribution data.
            32-256 samples is the sweet spot.
        model_hash: Deterministic identifier of the model (caller-computed).
            Used as the first part of the cache invalidation key.
        profiled_by_agent_id, profiled_by_agent_tier: Audit chain. Caller
            is whatever agent invoked the profile (e.g. hunter at tier 2).
        sweep: Iterable of SweepConfigs to measure. Defaults to the curated
            11-config sweep documented in the module header.
        drift_metric: "mse_normalised" (recommended) or "kl_softmax"
            (provisional; see compute_drift docstring).
        max_seq_len: Tokenizer truncation length. Emitted rows record the
            actual padded batch length (at most max_seq_len) in the
            max_seq_len_observed column.
        pipeline_version: Override only when running a deliberately
            distinct lineage (e.g. comparing two metric implementations).
        strategy: "joint" (default) hooks every layer at the same config in
            one forward pass: fast, cumulative cascade drift. "isolated"
            loops per (layer, config): slower, marginal per-layer drift.
            See module docstring for the empirical equivalence finding.
        progress_cb: Optional callback for progress messages.

    Returns:
        List of ProfileRow objects, one per (layer, sweep_config) pair.
        Caller persists these into kv_sensitivity_profile.

    Raises:
        ValueError: If strategy is unknown, or sweep or calibration_texts
            is empty.
        RuntimeError: If the model doesn't expose Llama-family attention.
        Any tokenizer / forward-pass errors propagate.
    """
    # Local imports keep the module importable without torch present.
    import torch  # noqa: F401
    from kv_intercept import (
        KVQuantSpec,
        find_attention_modules,
        kv_quant_active,
        kv_quant_active_multi,
    )

    if strategy not in ("joint", "isolated"):
        raise ValueError(f"unknown strategy: {strategy!r}; want 'joint' or 'isolated'")

    def log(msg: str) -> None:
        if progress_cb:
            progress_cb(msg)
        else:
            logger.info(msg)

    sweep_list = list(sweep)
    if not sweep_list:
        raise ValueError("Empty sweep: at least one SweepConfig required")
    if not calibration_texts:
        raise ValueError("No calibration_texts provided")

    # -- 1. Locate attention modules --
    attn_modules = find_attention_modules(model)
    n_layers = len(attn_modules)
    log(f"[kv-prof] discovered {n_layers} attention layers")

    # -- 2. Tokenize calibration set --
    batch = tokenizer(
        calibration_texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_seq_len,
    ).to(model.device)
    actual_seq_len = batch.input_ids.shape[1]
    log(f"[kv-prof] tokenised {len(calibration_texts)} prompts, "
        f"seq_len={actual_seq_len}")

    # -- 3. Compute calibration hash --
    calibration_hash = compute_calibration_hash(calibration_texts, max_seq_len)
    log(f"[kv-prof] calibration_hash={calibration_hash}")

    # -- 4. Capture full-precision baseline attention outputs --
    log("[kv-prof] capturing fp baseline attention outputs (1 forward pass)")
    t_start = time.time()
    baseline_outputs = _capture_attn_outputs(model, attn_modules, batch)
    log(f"[kv-prof] baseline captured in {time.time() - t_start:.1f}s")
    if len(baseline_outputs) != n_layers:
        logger.warning(
            "Captured baselines for %d/%d layers; some hooks didn't fire",
            len(baseline_outputs), n_layers,
        )

    # -- 5. Per-layer dimensions --
    # Pulled from model.config; Llama-family models report these directly.
    # If a model has per-layer variation (very rare), this is the first
    # place to patch.
    num_kv_heads = getattr(
        model.config, "num_key_value_heads",
        getattr(model.config, "num_attention_heads", 1),
    )
    head_dim = getattr(
        model.config, "head_dim", None,
    ) or (model.config.hidden_size // model.config.num_attention_heads)
    log(f"[kv-prof] num_kv_heads={num_kv_heads}, head_dim={head_dim}")

    # -- 6. Sweep loop --
    # joint: outer = configs, inner = (one multi-layer hook + 1 forward).
    #   Total len(sweep) forwards. Drift is cumulative cascade.
    # isolated: outer = configs, middle = layers, inner = (one single-layer
    #   hook + 1 forward). Total len(sweep) * n_layers forwards.
    #   Drift is marginal per-layer.
    rows: list[ProfileRow] = []
    n_configs = len(sweep_list)
    if strategy == "joint":
        log(f"[kv-prof] strategy=joint: {n_configs} configs × {n_layers} layers "
            f"= {n_configs * n_layers} measurements ({n_configs} forward passes)")
    else:
        log(f"[kv-prof] strategy=isolated: {n_configs} configs × {n_layers} layers "
            f"= {n_configs * n_layers} measurements "
            f"({n_configs * n_layers} forward passes)")
    profiled_at = datetime.now(UTC).isoformat()

    def _record(li: int, cfg: SweepConfig, bpt: float, quanted_output) -> None:
        drift = compute_drift(
            quanted_output,
            baseline_outputs[li],
            drift_metric,
        )
        rows.append(ProfileRow(
            model_hash=model_hash,
            calibration_hash=calibration_hash,
            pipeline_version=pipeline_version,
            layer_idx=li,
            k_bits=cfg.k_bits,
            v_bits=cfg.v_bits,
            quantizer=cfg.quantizer,
            drift_attn_output=float(drift),
            drift_metric=drift_metric,
            bytes_per_kv_token=float(bpt),
            max_seq_len_observed=actual_seq_len,
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
            profiled_at=profiled_at,
            profiled_by_agent_id=profiled_by_agent_id,
            profiled_by_agent_tier=profiled_by_agent_tier,
        ))

    for cfg_idx, cfg in enumerate(sweep_list):
        log(f"[kv-prof] config {cfg_idx + 1}/{n_configs}: "
            f"k={cfg.k_bits}b v={cfg.v_bits}b {cfg.quantizer}")
        t_cfg = time.time()
        spec = KVQuantSpec(
            k_bits=cfg.k_bits,
            v_bits=cfg.v_bits,
            quantizer=cfg.quantizer,
            group_size=cfg.group_size,
        )
        bpt = kv_bytes_per_token(
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
            k_bits=cfg.k_bits,
            v_bits=cfg.v_bits,
            quantizer=cfg.quantizer,
            group_size=cfg.group_size,
        )

        # Special case: fp16_passthrough drift is zero by definition for both
        # strategies, so compare the baseline against itself and skip the pass.
        if cfg.quantizer == "fp16_passthrough":
            for li in attn_modules:
                if li in baseline_outputs:
                    _record(li, cfg, bpt, baseline_outputs[li])
            log(f"[kv-prof] config done in {time.time() - t_cfg:.1f}s (passthrough)")
            continue

        if strategy == "joint":
            specs_by_layer = {li: spec for li in attn_modules}
            with kv_quant_active_multi(attn_modules, specs_by_layer):
                quanted_outputs = _capture_attn_outputs(
                    model, attn_modules, batch,
                )
            for li in attn_modules:
                if li not in quanted_outputs or li not in baseline_outputs:
                    logger.debug("layer %d: missing capture, skipping", li)
                    continue
                _record(li, cfg, bpt, quanted_outputs[li])
        else:  # isolated
            for li, attn in attn_modules.items():
                with kv_quant_active(attn, spec):
                    captured = _capture_attn_outputs(
                        model, attn_modules, batch,
                    )
                if li not in captured or li not in baseline_outputs:
                    logger.debug("layer %d: missing capture, skipping", li)
                    continue
                _record(li, cfg, bpt, captured[li])

        # Guard against an empty rows list (every capture missing) so the
        # progress message cannot raise IndexError.
        sample = f"{rows[-1].drift_attn_output:.4e}" if rows else "n/a"
        log(f"[kv-prof] config done in {time.time() - t_cfg:.1f}s, "
            f"sample drift = {sample}")

    log(f"[kv-prof] complete: {len(rows)} rows total")
    return rows
# ---------------------------------------------------------------------------
# Bridge to allocator
# ---------------------------------------------------------------------------
def rows_to_kv_candidates(rows: list[ProfileRow]):
    """Group ProfileRows by layer into KVCandidate objects that
    assign_kv_bits (in assignment_v2.py) consumes directly.

    Carries num_kv_heads and head_dim through from the rows so the
    allocator gets the real per-layer dimensions, not placeholders.
    """
    from collections import defaultdict

    from assignment_v2 import KVCandidate, KVOption

    by_layer: dict[int, list[ProfileRow]] = defaultdict(list)
    for r in rows:
        by_layer[r.layer_idx].append(r)

    candidates = []
    for layer_idx, layer_rows in sorted(by_layer.items()):
        # All rows for one layer share num_kv_heads/head_dim by construction
        first = layer_rows[0]
        options = [
            KVOption(
                k_bits=r.k_bits,
                v_bits=r.v_bits,
                quantizer=r.quantizer,
                drift=r.drift_attn_output,
                bytes_per_kv_token=r.bytes_per_kv_token,
            )
            for r in layer_rows
        ]
        candidates.append(KVCandidate(
            layer_idx=layer_idx,
            num_kv_heads=first.num_kv_heads,
            head_dim=first.head_dim,
            options=options,
        ))
    return candidates