"""
Sovereign Hive HSAQ: KV-cache sensitivity profiler
==================================================

Sweeps each attention layer across a curated 11-config grid of
(k_bits, v_bits, quantizer) combinations and measures the resulting
attention-output drift on calibration data. Emits ProfileRow objects
ready for insertion into the `kv_sensitivity_profile` Vault table
(migration 003).

Why this file exists:
    The KV cache is a bigger lever than weight quantization for MHA
    models on 12 GB cards (e.g. OLMo: ~3.3 GB of KV at 4K context in fp16,
    more than the weights). The allocator needs measured drift values to
    pick which layers can tolerate aggressive KV quantization. This
    profiler is what feeds it.
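
For reference, the size of a full-precision KV cache is
2 (K and V) × n_layers × num_kv_heads × head_dim × bytes_per_elem × seq_len.
The sketch below assumes an OLMo-2-13B-like shape (40 layers, 40 KV heads,
head_dim 128). That shape is an assumption, so check the model's config.json.
Under it, 4K tokens at fp16 come to roughly 3.36 GB, in line with the ~3.3 GB
quoted above. `kv_cache_bytes` is a hypothetical helper, not part of this
module. It ignores quantizer overhead, which kv_bytes_per_token accounts for.

```python
def kv_cache_bytes(n_layers: int, num_kv_heads: int, head_dim: int,
                   seq_len: int, bytes_per_elem: float = 2.0,
                   batch: int = 1) -> float:
    # K and V each hold num_kv_heads * head_dim elements per token per layer.
    per_token_per_layer = 2 * num_kv_heads * head_dim * bytes_per_elem
    return per_token_per_layer * n_layers * seq_len * batch


# Assumed OLMo-2-13B-like shape (hypothetical values; verify against config.json).
fp16_total = kv_cache_bytes(n_layers=40, num_kv_heads=40, head_dim=128, seq_len=4096)
int4_payload = kv_cache_bytes(40, 40, 128, 4096, bytes_per_elem=0.5)
print(f"fp16: {fp16_total / 1e9:.2f} GB, 4-bit payload only: {int4_payload / 1e9:.2f} GB")
```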

Pure-logic conventions (matching candidate_record.py):
    - No Vault writes. The caller persists the emitted ProfileRows.
    - No PermissionGate routing here. The caller (the hunter agent) handles that.
    - All disk/network I/O is the caller's job. We accept a loaded model
      and tokenizer; we return data structures.

What this DOES touch:
    - GPU forward passes (calls `model(**batch)` under torch.no_grad).
    - The kv_intercept.kv_quant_active (isolated) and kv_quant_active_multi
      (joint) context managers, to install hooks.
    - Hook lifecycle (always cleaned up via context manager; verified in
      smoke_test_v3.py test #6).

What this does NOT do:
    - It does not handle inference-queue gating. If two profilers run
      concurrently on the same GPU, you'll OOM. The caller must serialize.
    - It does not write to disk. Output is in-memory ProfileRow objects.
    - It does not bump pipeline_version. The constant lives here as a
      reference value; callers wanting a different lineage pass their own.

Config sweep rationale (11 per layer):
    A full Cartesian sweep (5 k_bits × 5 v_bits × 4 quantizers) = 100/layer
    is way overkill: most combinations carry no signal the allocator can use.
    This curated 11 covers the decision boundaries that actually matter:

      Symmetric K=V (the common case):     (8,8), (4,4), (3,3), (2,2)  hqq_g64
      K-cheaper-than-V (K more tolerant):  (4,8), (3,8), (2,8), (2,4)  hqq_g64
      Quantizer comparison:                (4,4) scaled_uniform
                                           (4,4) scaled_per_head
                                           (8,8) scaled_uniform
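
Writing the grid out as plain tuples makes the counts easy to check: 11
curated probes per layer against 100 for the full Cartesian product. This is
a standalone sketch that mirrors DEFAULT_SWEEP below as
(k_bits, v_bits, quantizer) tuples. The 40-layer figure is only an example.

```python
SYMMETRIC = [(b, b, "hqq_g64") for b in (8, 4, 3, 2)]
K_CHEAPER = [(4, 8, "hqq_g64"), (3, 8, "hqq_g64"), (2, 8, "hqq_g64"), (2, 4, "hqq_g64")]
QUANTIZER_CMP = [
    (4, 4, "scaled_uniform"),
    (4, 4, "scaled_per_head"),
    (8, 8, "scaled_uniform"),
]
CURATED = SYMMETRIC + K_CHEAPER + QUANTIZER_CMP

FULL_CARTESIAN_PER_LAYER = 5 * 5 * 4  # 5 k_bits x 5 v_bits x 4 quantizers

# Every K-cheaper probe really does spend fewer bits on K than on V.
assert all(k < v for k, v, _ in K_CHEAPER)

n_layers = 40  # example layer count
print(len(CURATED) * n_layers, "measurements vs", FULL_CARTESIAN_PER_LAYER * n_layers)
```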

Strategy: "joint" (default) vs "isolated"
    "joint" hooks every layer at the same config in ONE forward pass via
    kv_quant_active_multi. Total cost: 11 forwards (plus the one fp baseline
    forward that both strategies share). Drift values reflect CUMULATIVE
    CASCADE: each layer's attention sees the quantized output of earlier
    layers' attention. This matches what happens in deployment.

    "isolated" hooks ONE layer at a time and runs a forward per
    (layer, config) pair. Total cost: 11 × n_layers forwards. Drift values
    reflect MARGINAL per-layer sensitivity: what happens when only that
    layer is quantized.
Empirically (A100, OLMo-2-13B, 32 prompts, 1 GB KV budget) β€” see
mxguru1/hsaq-strategy-comparison/report.json:
- joint took 64.3 s, isolated took 665.6 s (10.4Γ— slower)
- isolated drift values are uniformly LOWER (ratios 0.12-0.61 by config)
because the joint pass compounds drift through the residual stream
- BUT both strategies produced IDENTICAL per-layer allocations (40/40
agreement) when piped through assign_kv_bits β€” the greedy allocator
only cares about pairwise drift-per-byte ratios, and the ordering of
configs by drift is preserved across strategies
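
A toy greedy allocator can illustrate the invariance argument in the last
bullet. This is NOT assign_kv_bits, whose internals are not shown here. It is
a hypothetical sketch that starts every layer at its most expensive option.
It then repeatedly takes the downgrade with the smallest drift increase per
byte saved, until the budget fits. Every decision compares drift-per-byte
ratios, so multiplying all drifts by the same positive constant cannot change
the allocation. The measured isolated/joint ratios vary by config, so the
empirical result above is a stronger statement than this uniform-scaling case.

```python
from typing import Sequence

Option = tuple[float, float]  # (drift, bytes_per_kv_token); hypothetical shape


def _frontier(options: Sequence[Option]) -> list[Option]:
    # Most expensive first; for equal byte cost keep only the lowest-drift option.
    ordered: list[Option] = []
    for drift, nbytes in sorted(options, key=lambda o: (-o[1], o[0])):
        if not ordered or nbytes < ordered[-1][1]:
            ordered.append((drift, nbytes))
    return ordered


def greedy_allocate(layers: Sequence[Sequence[Option]], budget: float) -> list[float]:
    frontiers = [_frontier(opts) for opts in layers]
    choice = [0] * len(frontiers)  # every layer starts at its priciest option
    total = sum(f[0][1] for f in frontiers)
    while total > budget:
        best_layer, best_cost = None, None
        for i, f in enumerate(frontiers):
            j = choice[i]
            if j + 1 == len(f):
                continue  # already at the cheapest option
            (d0, b0), (d1, b1) = f[j], f[j + 1]
            cost = (d1 - d0) / (b0 - b1)  # extra drift per byte saved
            if best_cost is None or cost < best_cost:
                best_layer, best_cost = i, cost
        if best_layer is None:
            raise ValueError("budget unreachable even at the cheapest options")
        j = choice[best_layer]
        total -= frontiers[best_layer][j][1] - frontiers[best_layer][j + 1][1]
        choice[best_layer] = j + 1
    # Report the chosen bytes per layer (independent of any drift rescaling).
    return [f[j][1] for f, j in zip(frontiers, choice)]


layers = [
    [(0.001, 1000.0), (0.01, 600.0), (0.2, 300.0)],
    [(0.002, 1000.0), (0.05, 600.0), (0.5, 300.0)],
]
scaled = [[(d * 0.5, b) for d, b in opts] for opts in layers]  # uniform rescale
print(greedy_allocate(layers, 1400.0), greedy_allocate(scaled, 1400.0))
```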

Default is "joint" because it yields the same allocation, runs ~10× faster,
and its drift values are cascade-realistic. Switch to "isolated" only when
you want the marginal per-layer signal for analysis (e.g. ranking pruning
candidates, or understanding which layers cause the most ripple downstream).
"""

from __future__ import annotations

import hashlib
import json
import logging
import time
from dataclasses import dataclass
from datetime import UTC, datetime  # datetime.UTC needs Python 3.11+
from typing import Callable, Iterable, Literal, Optional

logger = logging.getLogger("HSAQ.KVProfiler")

# Bump when changing the config sweep, drift metric implementation, or
# hook semantics. Cached rows under earlier versions become lookup-misses
# so they're safely ignored rather than silently reused.
PIPELINE_VERSION = "1.0.0"


# ---------------------------------------------------------------------------
# Types
# ---------------------------------------------------------------------------

DriftMetric = Literal["mse_normalised", "kl_softmax"]
KVQuantizer = Literal["hqq_g64", "scaled_uniform", "scaled_per_head", "fp16_passthrough"]


@dataclass(frozen=True)
class SweepConfig:
    """One (k_bits, v_bits, quantizer) probe to measure on every layer."""

    k_bits: int
    v_bits: int
    quantizer: KVQuantizer
    group_size: int = 64


@dataclass
class ProfileRow:
    """One row of measured KV sensitivity. Shape matches the Vault table
    kv_sensitivity_profile (migration 003) exactly; the caller can insert
    to_vault_payload() directly or pass it through a Vault adapter."""

    model_hash: str
    calibration_hash: str
    pipeline_version: str
    layer_idx: int
    k_bits: int
    v_bits: int
    quantizer: str
    drift_attn_output: float
    drift_metric: str
    bytes_per_kv_token: float
    max_seq_len_observed: int
    num_kv_heads: int
    head_dim: int
    profiled_at: str
    profiled_by_agent_id: str
    profiled_by_agent_tier: int

    def to_vault_payload(self) -> dict:
        """Row-shaped dict ready for INSERT into kv_sensitivity_profile."""
        return {
            "model_hash": self.model_hash,
            "calibration_hash": self.calibration_hash,
            "pipeline_version": self.pipeline_version,
            "layer_idx": self.layer_idx,
            "k_bits": self.k_bits,
            "v_bits": self.v_bits,
            "quantizer": self.quantizer,
            "drift_attn_output": self.drift_attn_output,
            "drift_metric": self.drift_metric,
            "bytes_per_kv_token": self.bytes_per_kv_token,
            "max_seq_len_observed": self.max_seq_len_observed,
            "num_kv_heads": self.num_kv_heads,
            "head_dim": self.head_dim,
            "profiled_at": self.profiled_at,
            "profiled_by_agent_id": self.profiled_by_agent_id,
            "profiled_by_agent_tier": self.profiled_by_agent_tier,
        }


# ---------------------------------------------------------------------------
# Default sweep: the curated 11 configs
# ---------------------------------------------------------------------------

DEFAULT_SWEEP: tuple[SweepConfig, ...] = (
    # ── Symmetric K=V (most common case, 4 configs) ──────────────────────
    SweepConfig(k_bits=8, v_bits=8, quantizer="hqq_g64"),
    SweepConfig(k_bits=4, v_bits=4, quantizer="hqq_g64"),
    SweepConfig(k_bits=3, v_bits=3, quantizer="hqq_g64"),
    SweepConfig(k_bits=2, v_bits=2, quantizer="hqq_g64"),
    # ── K-cheaper-than-V (K is more error-tolerant, 4 configs) ───────────
    SweepConfig(k_bits=4, v_bits=8, quantizer="hqq_g64"),
    SweepConfig(k_bits=3, v_bits=8, quantizer="hqq_g64"),
    SweepConfig(k_bits=2, v_bits=8, quantizer="hqq_g64"),
    SweepConfig(k_bits=2, v_bits=4, quantizer="hqq_g64"),
    # ── Quantizer comparison (3 configs) ──────────────────────────────────
    SweepConfig(k_bits=4, v_bits=4, quantizer="scaled_uniform"),
    SweepConfig(k_bits=4, v_bits=4, quantizer="scaled_per_head"),
    SweepConfig(k_bits=8, v_bits=8, quantizer="scaled_uniform"),
)


# ---------------------------------------------------------------------------
# Bytes-per-KV-token accounting
# ---------------------------------------------------------------------------

def kv_bytes_per_token(
    num_kv_heads: int,
    head_dim: int,
    k_bits: int,
    v_bits: int,
    quantizer: KVQuantizer,
    group_size: int = 64,
) -> float:
    """Bytes per token of KV cache for one layer at the given config.

    Includes quantizer-specific overhead (scales / zeros). The allocator
    reads this value cached in the Vault rather than recomputing it, so the
    formula here is the single source of truth: change it, bump
    PIPELINE_VERSION.
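
Here is a worked check of the hqq_g64 branch. It is a standalone sketch that
re-derives the formula for one assumed shape: 40 KV heads × head_dim 128,
group_size 64. `hqq_bytes_per_token` is a hypothetical name. K and V each pay
elems × bits / 8 bytes of payload, plus 4 bytes of fp16 zero + scale per
(head, group).

```python
def hqq_bytes_per_token(num_kv_heads: int, head_dim: int, k_bits: int,
                        v_bits: int, group_size: int = 64) -> float:
    elems = num_kv_heads * head_dim
    # Matches the hqq_g64 branch: floor division, so partial groups are not charged.
    groups_per_row = max(1, head_dim // group_size)
    overhead = num_kv_heads * groups_per_row * 4  # fp16 zero + fp16 scale
    return (elems * k_bits / 8 + overhead) + (elems * v_bits / 8 + overhead)


fp16 = 2 * (40 * 128) * 2                    # 20480 bytes per token per layer
hqq_44 = hqq_bytes_per_token(40, 128, 4, 4)  # 2 * (2560 + 320) = 5760.0
hqq_28 = hqq_bytes_per_token(40, 128, 2, 8)  # (1280 + 320) + (5120 + 320) = 7040.0
print(f"(4,4) hqq_g64 is {fp16 / hqq_44:.2f}x smaller than fp16")
```

Because of the floor division, a head_dim that is not a multiple of
group_size (e.g. 80 with group_size 64) is charged one group per head, where
a ceiling would charge two.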
    """
    elems = num_kv_heads * head_dim
    if quantizer == "fp16_passthrough":
        # K + V at fp16: 2 bytes per element each.
        return 2 * elems * 2
    if quantizer == "scaled_uniform":
        # One fp16 scale (2 bytes) per token for K, and one for V.
        k_payload = elems * k_bits / 8
        v_payload = elems * v_bits / 8
        return (k_payload + 2) + (v_payload + 2)
    if quantizer == "scaled_per_head":
        # One fp16 scale per KV head, for K and for V.
        k_payload = elems * k_bits / 8
        v_payload = elems * v_bits / 8
        return (k_payload + num_kv_heads * 2) + (v_payload + num_kv_heads * 2)
    if quantizer == "hqq_g64":
        # Groups run along the head_dim axis (per head, per token).
        # NOTE: this is the corrected accounting; an earlier stub grouped
        # along (num_kv_heads × head_dim), which underestimated overhead.
        groups_per_row = max(1, head_dim // group_size)
        # zero + scale per group, fp16 each = 4 bytes per group
        overhead_per_row = num_kv_heads * groups_per_row * 4
        k_payload = elems * k_bits / 8
        v_payload = elems * v_bits / 8
        return (k_payload + overhead_per_row) + (v_payload + overhead_per_row)
    raise ValueError(f"unknown quantizer: {quantizer}")


# ---------------------------------------------------------------------------
# Drift metrics
# ---------------------------------------------------------------------------

def compute_drift(quanted, baseline, metric: DriftMetric) -> float:
    """Compute attention-output drift between quanted and fp16 baseline.

    mse_normalised: ((quanted - baseline)^2).mean() / (baseline^2).mean()
        (denominator floored at 1e-8). Scale-invariant. Recommended for
        allocator inputs. KIVI-style.
    kl_softmax: KL(softmax(baseline) || softmax(quanted)) over the last dim,
        averaged over rows. PROVISIONAL: attention output isn't a
        distribution, so this is a proxy metric only. Kept for compatibility
        with the Vault schema; do NOT use as allocator input without validation.
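
Here is the mse_normalised formula above applied to plain Python lists. This
is a standalone sketch; the real function works on torch tensors. Dividing by
the baseline's mean square is what makes the metric scale-invariant. Scaling
both inputs by the same factor leaves the drift unchanged.

```python
def mse_normalised(quanted: list[float], baseline: list[float], eps: float = 1e-8) -> float:
    if len(quanted) != len(baseline) or not baseline:
        raise ValueError("inputs must be non-empty and the same length")
    n = len(baseline)
    mse = sum((q - b) ** 2 for q, b in zip(quanted, baseline)) / n
    denom = sum(b * b for b in baseline) / n  # mean square of the baseline
    return mse / max(denom, eps)


baseline = [1.0, -1.0, 2.0, -2.0]
quanted = [1.5, -1.0, 2.0, -2.5]
small = mse_normalised(quanted, baseline)  # 0.125 / 2.5 = 0.05
big = mse_normalised([10 * q for q in quanted], [10 * b for b in baseline])
```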
    """
    import torch.nn.functional as F

    if metric == "mse_normalised":
        # Upcast first: squaring fp16 activations can overflow or lose precision.
        q, b = quanted.float(), baseline.float()
        mse = ((q - b) ** 2).mean().item()
        denom = (b ** 2).mean().item()
        return mse / max(denom, 1e-8)
    if metric == "kl_softmax":
        a = F.log_softmax(quanted.float().flatten(0, -2), dim=-1)
        b = F.softmax(baseline.float().flatten(0, -2), dim=-1)
        return float(F.kl_div(a, b, reduction="batchmean"))
    raise ValueError(f"unknown drift metric: {metric}")


# ---------------------------------------------------------------------------
# Calibration hash
# ---------------------------------------------------------------------------

def compute_calibration_hash(
    calibration_texts: list[str],
    max_seq_len: int,
) -> str:
    """Deterministic hash of calibration inputs for cache invalidation."""
    payload = json.dumps({
        "n_texts": len(calibration_texts),
        "max_seq_len": max_seq_len,
        # First/last text content matters for distinguishing calibration sets;
        # hash a sample rather than every text to keep this cheap. Middle
        # texts only contribute through n_texts and total_chars, so two sets
        # that differ in one middle text of equal length hash the same.
        "first_text": calibration_texts[0][:200] if calibration_texts else "",
        "last_text": calibration_texts[-1][:200] if calibration_texts else "",
        "total_chars": sum(len(t) for t in calibration_texts),
    }, sort_keys=True)
    return hashlib.sha256(payload.encode()).hexdigest()[:16]


# ---------------------------------------------------------------------------
# Per-layer attention-output capture
# ---------------------------------------------------------------------------

def _capture_attn_outputs(
    model,
    attn_modules_by_layer: dict[int, object],
    batch,
) -> dict[int, "torch.Tensor"]:
    """Run one forward pass with hooks that capture attention output per layer.

    HF attention modules return either `attn_output` directly or a tuple
    starting with `attn_output`. We take the first element in either case
    and keep it only if it is a tensor.
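
The hook pattern used here can be exercised without torch: register a hook
on every module, run one forward, unwrap tuple outputs, and always remove the
hooks in `finally`. In this sketch, `ToyAttention` is a hypothetical stand-in
that only mimics the register_forward_hook / handle.remove() contract of
torch.nn.Module. The real function also checks for a tensor and makes a
`.detach().to("cpu")` copy; both steps are omitted here.

```python
from typing import Any, Callable


class ToyAttention:
    # Hypothetical stand-in: mimics nn.Module's forward-hook contract only.
    def __init__(self, fn: Callable[[Any], Any]) -> None:
        self.fn = fn
        self.hooks: dict[int, Callable] = {}
        self._next_id = 0

    def register_forward_hook(self, hook: Callable):
        hook_id, self._next_id = self._next_id, self._next_id + 1
        self.hooks[hook_id] = hook
        hooks = self.hooks

        class Handle:
            def remove(self) -> None:
                hooks.pop(hook_id, None)

        return Handle()

    def __call__(self, x):
        out = self.fn(x)
        for hook in list(self.hooks.values()):
            hook(self, (x,), out)  # same (module, inputs, output) signature
        return out


def capture_outputs(run_forward: Callable[[], Any], modules_by_layer: dict) -> dict:
    captured: dict = {}
    handles = []

    def make_hook(li):
        def hook(_mod, _inp, out):
            # Tuple outputs (attn_output, attn_weights, ...) keep element 0.
            captured[li] = out[0] if isinstance(out, tuple) else out
        return hook

    try:
        for li, mod in modules_by_layer.items():
            handles.append(mod.register_forward_hook(make_hook(li)))
        run_forward()
    finally:
        # Runs even if the forward raises, so no hook outlives this call.
        for h in handles:
            h.remove()
    return captured


layers = {0: ToyAttention(lambda x: (x + 1, "weights")), 1: ToyAttention(lambda x: x * 2)}
out = capture_outputs(lambda: layers[1](layers[0](3)[0]), layers)
print(out)  # {0: 4, 1: 8}
```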
    """
    import torch

    captured: dict[int, torch.Tensor] = {}
    handles: list = []

    def make_hook(li: int):
        def hook(_mod, _inp, out):
            t = out[0] if isinstance(out, tuple) else out
            if isinstance(t, torch.Tensor):
                # Move to CPU to bound GPU memory across all captured layers
                captured[li] = t.detach().to("cpu")
        return hook

    try:
        for li, attn in attn_modules_by_layer.items():
            handles.append(attn.register_forward_hook(make_hook(li)))
        with torch.no_grad():
            model(**batch, use_cache=False)
    finally:
        for h in handles:
            h.remove()
    return captured


# ---------------------------------------------------------------------------
# Profiler
# ---------------------------------------------------------------------------

def profile_kv_sensitivity(
    model,
    tokenizer,
    calibration_texts: list[str],
    *,
    model_hash: str,
    profiled_by_agent_id: str,
    profiled_by_agent_tier: int,
    sweep: Iterable[SweepConfig] = DEFAULT_SWEEP,
    drift_metric: DriftMetric = "mse_normalised",
    max_seq_len: int = 1024,
    pipeline_version: str = PIPELINE_VERSION,
    strategy: Literal["joint", "isolated"] = "joint",
    progress_cb: Optional[Callable[[str], None]] = None,
) -> list[ProfileRow]:
    """Measure attention-output drift per (layer × sweep config).

    Args:
        model: HF model in eval mode on its target device.
        tokenizer: Matching HF tokenizer. It must have a pad token, because
            the calibration set is batched with padding=True.
        calibration_texts: List of natural-language strings to feed through
            the model. Their count controls statistical robustness; how
            closely their distribution matches deployment traffic controls
            how well the allocator's decisions hold up on that traffic.
            32-256 samples is the sweet spot.
        model_hash: Deterministic identifier of the model (caller-computed).
            Used as the first part of the cache invalidation key.
        profiled_by_agent_id, profiled_by_agent_tier: Audit chain. The caller
            is whatever agent invoked the profile (e.g. hunter at tier 2).
        sweep: Iterable of SweepConfigs to measure. Defaults to the curated
            11-config sweep documented in the module header.
        drift_metric: "mse_normalised" (recommended) or "kl_softmax"
            (provisional; see the compute_drift docstring).
        max_seq_len: Tokenizer truncation length. The actual padded length
            of the batch is recorded as max_seq_len_observed.
        pipeline_version: Override only when running a deliberately
            distinct lineage (e.g. comparing two metric implementations).
        strategy: "joint" (default) hooks every layer at the same config in
            one forward pass: fast, cumulative cascade drift. "isolated"
            loops per (layer, config): slower, marginal per-layer drift.
            See the module docstring for the empirical equivalence finding.
        progress_cb: Optional callback for progress messages.

    Returns:
        List of ProfileRow objects, one per (layer, sweep_config) pair.
        Layers whose capture hook did not fire are skipped. The caller
        persists these into kv_sensitivity_profile.

    Raises:
        ValueError: On an unknown strategy, an empty sweep, or empty
            calibration_texts.
        RuntimeError: If the model doesn't expose Llama-family attention.
        Any tokenizer / forward-pass errors propagate.
    """
# Local imports β€” keeps module importable without torch present.
import torch # noqa: F401
from kv_intercept import (
KVQuantSpec,
find_attention_modules,
kv_quant_active,
kv_quant_active_multi,
)
if strategy not in ("joint", "isolated"):
raise ValueError(f"unknown strategy: {strategy!r}; want 'joint' or 'isolated'")
def log(msg: str) -> None:
if progress_cb:
progress_cb(msg)
else:
logger.info(msg)
sweep_list = list(sweep)
if not sweep_list:
raise ValueError("Empty sweep β€” at least one SweepConfig required")
if not calibration_texts:
raise ValueError("No calibration_texts provided")

    # ── 1. Locate attention modules ─────────────────────────────────────
    attn_modules = find_attention_modules(model)
    n_layers = len(attn_modules)
    log(f"[kv-prof] discovered {n_layers} attention layers")

    # ── 2. Tokenize calibration set ─────────────────────────────────────
    batch = tokenizer(
        calibration_texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_seq_len,
    ).to(model.device)
    actual_seq_len = batch.input_ids.shape[1]
    log(f"[kv-prof] tokenised {len(calibration_texts)} prompts, "
        f"seq_len={actual_seq_len}")

    # ── 3. Compute calibration hash ─────────────────────────────────────
    calibration_hash = compute_calibration_hash(calibration_texts, max_seq_len)
    log(f"[kv-prof] calibration_hash={calibration_hash}")

    # ── 4. Capture full-precision baseline attention outputs ────────────
    log("[kv-prof] capturing fp baseline attention outputs (1 forward pass)")
    t_start = time.time()
    baseline_outputs = _capture_attn_outputs(model, attn_modules, batch)
    log(f"[kv-prof] baseline captured in {time.time() - t_start:.1f}s")
    if len(baseline_outputs) != n_layers:
        logger.warning(
            "Captured baselines for %d/%d layers; some hooks didn't fire",
            len(baseline_outputs), n_layers,
        )

    # ── 5. Per-layer dimensions ─────────────────────────────────────────
    # Pulled from model.config; Llama-family models report these directly.
    # If a model has per-layer variation (very rare), this is the first
    # place to patch. The `or` fallbacks also cover keys that are present
    # but set to None.
    num_kv_heads = (
        getattr(model.config, "num_key_value_heads", None)
        or getattr(model.config, "num_attention_heads", 1)
    )
    head_dim = getattr(model.config, "head_dim", None) or (
        model.config.hidden_size // model.config.num_attention_heads
    )
    log(f"[kv-prof] num_kv_heads={num_kv_heads}, head_dim={head_dim}")

    # ── 6. Sweep loop ───────────────────────────────────────────────────
    # joint:    outer = configs, inner = (one multi-layer hook + 1 forward).
    #           Total len(sweep) forwards. Drift is cumulative cascade.
    # isolated: outer = configs, middle = layers, inner = (one single-layer
    #           hook + 1 forward). Total len(sweep) * n_layers forwards.
    #           Drift is marginal per-layer.
    rows: list[ProfileRow] = []
    n_configs = len(sweep_list)
    if strategy == "joint":
        log(f"[kv-prof] strategy=joint: {n_configs} configs × {n_layers} layers "
            f"= {n_configs * n_layers} measurements ({n_configs} forward passes)")
    else:
        log(f"[kv-prof] strategy=isolated: {n_configs} configs × {n_layers} layers "
            f"= {n_configs * n_layers} measurements "
            f"({n_configs * n_layers} forward passes)")

    profiled_at = datetime.now(UTC).isoformat()

    def _record(li: int, cfg: SweepConfig, bpt: float, quanted_output) -> None:
        drift = compute_drift(
            quanted_output,
            baseline_outputs[li],
            drift_metric,
        )
        rows.append(ProfileRow(
            model_hash=model_hash,
            calibration_hash=calibration_hash,
            pipeline_version=pipeline_version,
            layer_idx=li,
            k_bits=cfg.k_bits,
            v_bits=cfg.v_bits,
            quantizer=cfg.quantizer,
            drift_attn_output=float(drift),
            drift_metric=drift_metric,
            bytes_per_kv_token=float(bpt),
            max_seq_len_observed=actual_seq_len,
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
            profiled_at=profiled_at,
            profiled_by_agent_id=profiled_by_agent_id,
            profiled_by_agent_tier=profiled_by_agent_tier,
        ))

    for cfg_idx, cfg in enumerate(sweep_list):
        log(f"[kv-prof] config {cfg_idx + 1}/{n_configs}: "
            f"k={cfg.k_bits}b v={cfg.v_bits}b {cfg.quantizer}")
        t_cfg = time.time()
        n_rows_before = len(rows)
        spec = KVQuantSpec(
            k_bits=cfg.k_bits,
            v_bits=cfg.v_bits,
            quantizer=cfg.quantizer,
            group_size=cfg.group_size,
        )
        bpt = kv_bytes_per_token(
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
            k_bits=cfg.k_bits,
            v_bits=cfg.v_bits,
            quantizer=cfg.quantizer,
            group_size=cfg.group_size,
        )

        # Special case: fp16_passthrough drift is zero by definition for both
        # strategies, so record the baseline against itself and skip the forward.
        if cfg.quantizer == "fp16_passthrough":
            for li in attn_modules:
                if li in baseline_outputs:
                    _record(li, cfg, bpt, baseline_outputs[li])
            log(f"[kv-prof] config done in {time.time() - t_cfg:.1f}s (passthrough)")
            continue

        if strategy == "joint":
            specs_by_layer = {li: spec for li in attn_modules}
            with kv_quant_active_multi(attn_modules, specs_by_layer):
                quanted_outputs = _capture_attn_outputs(
                    model, attn_modules, batch,
                )
            for li in attn_modules:
                if li not in quanted_outputs or li not in baseline_outputs:
                    logger.debug("layer %d: missing capture, skipping", li)
                    continue
                _record(li, cfg, bpt, quanted_outputs[li])
        else:  # isolated
            for li, attn in attn_modules.items():
                with kv_quant_active(attn, spec):
                    # Only layer li's output is used, so only hook that layer;
                    # this avoids copying every layer's output to CPU per forward.
                    captured = _capture_attn_outputs(
                        model, {li: attn}, batch,
                    )
                if li not in captured or li not in baseline_outputs:
                    logger.debug("layer %d: missing capture, skipping", li)
                    continue
                _record(li, cfg, bpt, captured[li])

        # Guard against configs where every layer was skipped.
        if len(rows) > n_rows_before:
            sample = f"sample drift = {rows[-1].drift_attn_output:.4e}"
        else:
            sample = "no rows recorded"
        log(f"[kv-prof] config done in {time.time() - t_cfg:.1f}s, {sample}")

    log(f"[kv-prof] complete: {len(rows)} rows total")
    return rows


# ---------------------------------------------------------------------------
# Bridge to allocator
# ---------------------------------------------------------------------------

def rows_to_kv_candidates(rows: list[ProfileRow]):
    """Group ProfileRows by layer into KVCandidate objects that
    assign_kv_bits (in assignment_v2.py) consumes directly.

    Carries num_kv_heads and head_dim through from the rows so the
    allocator gets the real per-layer dimensions, not placeholders.
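
The sketch below shows the grouping step on its own. It uses a hypothetical
`Row` type that holds a subset of ProfileRow's fields. Layers come out in
ascending order, as in rows_to_kv_candidates. Options within a layer keep the
order in which their rows arrived, so a consumer that needs them sorted by
bytes has to sort them itself.

```python
from collections import defaultdict
from dataclasses import dataclass


@dataclass
class Row:
    # Hypothetical subset of ProfileRow's fields, for illustration only.
    layer_idx: int
    k_bits: int
    v_bits: int
    drift: float
    bytes_per_kv_token: float


def group_by_layer(rows: list[Row]) -> list[tuple[int, list[tuple[int, int]]]]:
    by_layer: dict[int, list[Row]] = defaultdict(list)
    for r in rows:
        by_layer[r.layer_idx].append(r)
    # Layers come out in ascending order; rows within a layer keep input order.
    return [
        (li, [(r.k_bits, r.v_bits) for r in layer_rows])
        for li, layer_rows in sorted(by_layer.items())
    ]


rows = [
    Row(1, 4, 4, 0.02, 1000.0),
    Row(0, 8, 8, 0.001, 2000.0),
    Row(1, 8, 8, 0.002, 2000.0),
    Row(0, 4, 4, 0.01, 1000.0),
]
grouped = group_by_layer(rows)
print(grouped)  # [(0, [(8, 8), (4, 4)]), (1, [(4, 4), (8, 8)])]
```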
    """
    from collections import defaultdict

    from assignment_v2 import KVCandidate, KVOption

    by_layer: dict[int, list[ProfileRow]] = defaultdict(list)
    for r in rows:
        by_layer[r.layer_idx].append(r)

    candidates = []
    for layer_idx, layer_rows in sorted(by_layer.items()):
        # All rows for one layer share num_kv_heads/head_dim by construction
        first = layer_rows[0]
        options = [
            KVOption(
                k_bits=r.k_bits,
                v_bits=r.v_bits,
                quantizer=r.quantizer,
                drift=r.drift_attn_output,
                bytes_per_kv_token=r.bytes_per_kv_token,
            )
            for r in layer_rows
        ]
        candidates.append(KVCandidate(
            layer_idx=layer_idx,
            num_kv_heads=first.num_kv_heads,
            head_dim=first.head_dim,
            options=options,
        ))
    return candidates