""" Sovereign Hive HSAQ — KV-cache sensitivity profiler ==================================================== Sweeps each attention layer across a curated 11-config grid of (k_bits, v_bits, quantizer) combinations and measures the resulting attention-output drift on calibration data. Emits ProfileRow objects ready for insertion into the `kv_sensitivity_profile` Vault table (migration 003). Why this file exists: The KV cache is the bigger lever than weight quantization for MHA models on 12 GB cards (e.g. OLMo: ~3.3 GB KV at 4K ctx fp16, more than the weights). The allocator needs measured drift values to pick which layers can tolerate aggressive KV quant. This profiler is what feeds it. Pure-logic conventions (matching candidate_record.py): - No Vault writes. Caller persists the emitted ProfileRows. - No PermissionGate routing here. Caller (the hunter agent) handles that. - All disk/network I/O is the caller's job. We accept a loaded model and tokenizer; we return data structures. What this DOES touch: - GPU forward passes (calls `model(**batch)` under torch.no_grad). - kv_intercept.kv_quant_active context manager to install hooks. - Hook lifecycle (always cleaned up via context manager — verified in smoke_test_v3.py test #6). What this does NOT do: - It does not handle inference-queue gating. If two profilers run concurrently on the same GPU, you'll OOM. Caller must serialize. - It does not write to disk. Output is in-memory ProfileRow objects. - It does not bump pipeline_version. The constant lives here as a reference value; callers wanting a different lineage pass their own. Config sweep rationale (11 per layer): Full Cartesian (5 k_bits × 5 v_bits × 4 quantizers) = 100/layer is way overkill — most combinations carry no signal the allocator can use. This curated 11 covers the decision boundaries that actually matter: Symmetric K=V (the common case): (8,8), (4,4), (3,3), (2,2) hqq_g64 K-cheaper-than-V (K more tolerant): (4,8), (3,8), (2,8), (2,4) hqq_g64 Quantizer comparison: (4,4) scaled_uniform (4,4) scaled_per_head (8,8) scaled_uniform Strategy: "joint" (default) vs "isolated" "joint" hooks every layer at the same config in ONE forward pass via kv_quant_active_multi. Total cost: 11 forwards. Drift values reflect CUMULATIVE CASCADE — each layer's attention sees the quanted output of earlier layers' attention. This is realistic to deployment. "isolated" hooks ONE layer at a time, runs a forward per (layer, config) pair. Total cost: 11 × n_layers forwards. Drift values reflect MARGINAL per-layer sensitivity — what happens when only that layer is quanted. Empirically (A100, OLMo-2-13B, 32 prompts, 1 GB KV budget) — see mxguru1/hsaq-strategy-comparison/report.json: - joint took 64.3 s, isolated took 665.6 s (10.4× slower) - isolated drift values are uniformly LOWER (ratios 0.12-0.61 by config) because the joint pass compounds drift through the residual stream - BUT both strategies produced IDENTICAL per-layer allocations (40/40 agreement) when piped through assign_kv_bits — the greedy allocator only cares about pairwise drift-per-byte ratios, and the ordering of configs by drift is preserved across strategies Default is "joint" because: same allocation, 10× faster, drift values are cascade-realistic. Switch to "isolated" only when you want the marginal per-layer signal for analysis (e.g. ranking pruning candidates, understanding which layers cause the most ripple downstream). """ from __future__ import annotations import hashlib import json import logging import time from dataclasses import dataclass, field from datetime import UTC, datetime from typing import Callable, Iterable, Literal, Optional logger = logging.getLogger("HSAQ.KVProfiler") # Bump when changing the config sweep, drift metric implementation, or # hook semantics. Cached rows under earlier versions become lookup-misses # so they're safely ignored rather than silently reused. PIPELINE_VERSION = "1.0.0" # --------------------------------------------------------------------------- # Types # --------------------------------------------------------------------------- DriftMetric = Literal["mse_normalised", "kl_softmax"] KVQuantizer = Literal["hqq_g64", "scaled_uniform", "scaled_per_head", "fp16_passthrough"] @dataclass(frozen=True) class SweepConfig: """One (k_bits, v_bits, quantizer) probe to measure on every layer.""" k_bits: int v_bits: int quantizer: KVQuantizer group_size: int = 64 @dataclass class ProfileRow: """One row of measured KV sensitivity. Shape matches the Vault table kv_sensitivity_profile (migration 003) exactly — caller can insert this dict directly or pass through a Vault adapter.""" model_hash: str calibration_hash: str pipeline_version: str layer_idx: int k_bits: int v_bits: int quantizer: str drift_attn_output: float drift_metric: str bytes_per_kv_token: float max_seq_len_observed: int num_kv_heads: int head_dim: int profiled_at: str profiled_by_agent_id: str profiled_by_agent_tier: int def to_vault_payload(self) -> dict: """Row-shaped dict ready for INSERT into kv_sensitivity_profile.""" return { "model_hash": self.model_hash, "calibration_hash": self.calibration_hash, "pipeline_version": self.pipeline_version, "layer_idx": self.layer_idx, "k_bits": self.k_bits, "v_bits": self.v_bits, "quantizer": self.quantizer, "drift_attn_output": self.drift_attn_output, "drift_metric": self.drift_metric, "bytes_per_kv_token": self.bytes_per_kv_token, "max_seq_len_observed": self.max_seq_len_observed, "num_kv_heads": self.num_kv_heads, "head_dim": self.head_dim, "profiled_at": self.profiled_at, "profiled_by_agent_id": self.profiled_by_agent_id, "profiled_by_agent_tier": self.profiled_by_agent_tier, } # --------------------------------------------------------------------------- # Default sweep — the curated 11 configs # --------------------------------------------------------------------------- DEFAULT_SWEEP: tuple[SweepConfig, ...] = ( # ── Symmetric K=V (most common case, 4 configs) ────────────────────── SweepConfig(k_bits=8, v_bits=8, quantizer="hqq_g64"), SweepConfig(k_bits=4, v_bits=4, quantizer="hqq_g64"), SweepConfig(k_bits=3, v_bits=3, quantizer="hqq_g64"), SweepConfig(k_bits=2, v_bits=2, quantizer="hqq_g64"), # ── K-cheaper-than-V (K is more error-tolerant, 4 configs) ─────────── SweepConfig(k_bits=4, v_bits=8, quantizer="hqq_g64"), SweepConfig(k_bits=3, v_bits=8, quantizer="hqq_g64"), SweepConfig(k_bits=2, v_bits=8, quantizer="hqq_g64"), SweepConfig(k_bits=2, v_bits=4, quantizer="hqq_g64"), # ── Quantizer comparison (3 configs) ────────────────────────────────── SweepConfig(k_bits=4, v_bits=4, quantizer="scaled_uniform"), SweepConfig(k_bits=4, v_bits=4, quantizer="scaled_per_head"), SweepConfig(k_bits=8, v_bits=8, quantizer="scaled_uniform"), ) # --------------------------------------------------------------------------- # Bytes-per-KV-token accounting # --------------------------------------------------------------------------- def kv_bytes_per_token( num_kv_heads: int, head_dim: int, k_bits: int, v_bits: int, quantizer: KVQuantizer, group_size: int = 64, ) -> float: """Bytes per token of KV cache for one layer at the given config. Includes quantizer-specific overhead (scales / zeros). The allocator reads this value cached in the Vault rather than recomputing, so the formula here is the single source of truth — change it, bump PIPELINE_VERSION. """ elems = num_kv_heads * head_dim if quantizer == "fp16_passthrough": # 2 bytes/elem × K and V both return 2 * elems * 2 # K + V at fp16 (2 bytes/elem) if quantizer == "scaled_uniform": # One fp16 scale per row of K and V each k_payload = elems * k_bits / 8 v_payload = elems * v_bits / 8 return (k_payload + 2) + (v_payload + 2) if quantizer == "scaled_per_head": # fp16 scale per head, per K/V k_payload = elems * k_bits / 8 v_payload = elems * v_bits / 8 return (k_payload + num_kv_heads * 2) + (v_payload + num_kv_heads * 2) if quantizer == "hqq_g64": # Groups along the head_dim axis (per row of the KV cache). # NOTE: this is the corrected accounting — earlier stub grouped # along (num_kv_heads × head_dim) which underestimated overhead. groups_per_row = max(1, head_dim // group_size) # zero + scale per group, fp16 each = 4 bytes per group overhead_per_row = num_kv_heads * groups_per_row * 4 k_payload = elems * k_bits / 8 v_payload = elems * v_bits / 8 return (k_payload + overhead_per_row) + (v_payload + overhead_per_row) raise ValueError(f"unknown quantizer: {quantizer}") # --------------------------------------------------------------------------- # Drift metrics # --------------------------------------------------------------------------- def compute_drift(quanted, baseline, metric: DriftMetric) -> float: """Compute attention-output drift between quanted and fp16 baseline. mse_normalised: ((quanted - baseline)^2).mean() / baseline.abs().mean()^2 Scale-invariant. Recommended for allocator inputs. KIVI-style. kl_softmax: KL divergence treating attention outputs as logits over the last dim. PROVISIONAL — attention output isn't a distribution, so this is a proxy metric only. Kept for compatibility with the Vault schema; do NOT use as allocator input without validation. """ import torch import torch.nn.functional as F if metric == "mse_normalised": mse = ((quanted - baseline) ** 2).mean().item() denom = (baseline ** 2).mean().item() return mse / max(denom, 1e-8) if metric == "kl_softmax": a = F.log_softmax(quanted.float().flatten(0, -2), dim=-1) b = F.softmax(baseline.float().flatten(0, -2), dim=-1) return float(F.kl_div(a, b, reduction="batchmean")) raise ValueError(f"unknown drift metric: {metric}") # --------------------------------------------------------------------------- # Calibration hash # --------------------------------------------------------------------------- def compute_calibration_hash( calibration_texts: list[str], max_seq_len: int, ) -> str: """Deterministic hash of calibration inputs for cache invalidation.""" payload = json.dumps({ "n_texts": len(calibration_texts), "max_seq_len": max_seq_len, # First/last text content matters for distinguishing calibration sets; # hash a sample rather than every text to keep this cheap. "first_text": calibration_texts[0][:200] if calibration_texts else "", "last_text": calibration_texts[-1][:200] if calibration_texts else "", "total_chars": sum(len(t) for t in calibration_texts), }, sort_keys=True) return hashlib.sha256(payload.encode()).hexdigest()[:16] # --------------------------------------------------------------------------- # Per-layer attention-output capture # --------------------------------------------------------------------------- def _capture_attn_outputs( model, attn_modules_by_layer: dict[int, object], batch, ) -> dict[int, "torch.Tensor"]: """Run one forward pass with hooks that capture attention output per layer. HF attention modules return either `attn_output` directly or a tuple starting with `attn_output`. We grab the first tensor element in either case. """ import torch captured: dict[int, torch.Tensor] = {} handles: list = [] def make_hook(li: int): def hook(_mod, _inp, out): t = out[0] if isinstance(out, tuple) else out if isinstance(t, torch.Tensor): # Move to CPU to bound GPU memory across all captured layers captured[li] = t.detach().to("cpu") return hook try: for li, attn in attn_modules_by_layer.items(): handles.append(attn.register_forward_hook(make_hook(li))) with torch.no_grad(): model(**batch, use_cache=False) finally: for h in handles: h.remove() return captured # --------------------------------------------------------------------------- # Profiler # --------------------------------------------------------------------------- def profile_kv_sensitivity( model, tokenizer, calibration_texts: list[str], *, model_hash: str, profiled_by_agent_id: str, profiled_by_agent_tier: int, sweep: Iterable[SweepConfig] = DEFAULT_SWEEP, drift_metric: DriftMetric = "mse_normalised", max_seq_len: int = 1024, pipeline_version: str = PIPELINE_VERSION, strategy: Literal["joint", "isolated"] = "joint", progress_cb: Optional[Callable[[str], None]] = None, ) -> list[ProfileRow]: """Measure attention-output drift per (layer × sweep config). Args: model: HF model in eval mode on its target device. tokenizer: Matching HF tokenizer. calibration_texts: List of natural-language strings to feed through the model. Length controls statistical robustness; quality of the distribution match controls allocator decisions on out-of- distribution data. 32-256 samples is the sweet spot. model_hash: Deterministic identifier of the model (caller-computed). Used as the first part of the cache invalidation key. profiled_by_agent_id, profiled_by_agent_tier: Audit chain. Caller is whatever agent invoked the profile (e.g. hunter at tier 2). sweep: Iterable of SweepConfigs to measure. Defaults to the curated 11-config sweep documented in the module header. drift_metric: "mse_normalised" (recommended) or "kl_softmax" (provisional — see compute_drift docstring). max_seq_len: Tokenizer truncation length and the max_seq_len_observed column for emitted rows. pipeline_version: Override only when running a deliberately distinct lineage (e.g. comparing two metric implementations). strategy: "joint" (default) hooks every layer at the same config in one forward pass — fast, cumulative cascade drift. "isolated" loops per (layer, config) — slower, marginal per-layer drift. See module docstring for the empirical equivalence finding. progress_cb: Optional callback for progress messages. Returns: List of ProfileRow objects, one per (layer, sweep_config) pair. Caller persists these into kv_sensitivity_profile. Raises: RuntimeError: If the model doesn't expose Llama-family attention. Any tokenizer / forward-pass errors propagate. """ # Local imports — keeps module importable without torch present. import torch # noqa: F401 from kv_intercept import ( KVQuantSpec, find_attention_modules, kv_quant_active, kv_quant_active_multi, ) if strategy not in ("joint", "isolated"): raise ValueError(f"unknown strategy: {strategy!r}; want 'joint' or 'isolated'") def log(msg: str) -> None: if progress_cb: progress_cb(msg) else: logger.info(msg) sweep_list = list(sweep) if not sweep_list: raise ValueError("Empty sweep — at least one SweepConfig required") if not calibration_texts: raise ValueError("No calibration_texts provided") # ── 1. Locate attention modules ───────────────────────────────────── attn_modules = find_attention_modules(model) n_layers = len(attn_modules) log(f"[kv-prof] discovered {n_layers} attention layers") # ── 2. Tokenize calibration set ───────────────────────────────────── batch = tokenizer( calibration_texts, return_tensors="pt", padding=True, truncation=True, max_length=max_seq_len, ).to(model.device) actual_seq_len = batch.input_ids.shape[1] log(f"[kv-prof] tokenised {len(calibration_texts)} prompts, " f"seq_len={actual_seq_len}") # ── 3. Compute calibration hash ───────────────────────────────────── calibration_hash = compute_calibration_hash(calibration_texts, max_seq_len) log(f"[kv-prof] calibration_hash={calibration_hash}") # ── 4. Capture full-precision baseline attention outputs ──────────── log("[kv-prof] capturing fp baseline attention outputs (1 forward pass)") t_start = time.time() baseline_outputs = _capture_attn_outputs(model, attn_modules, batch) log(f"[kv-prof] baseline captured in {time.time() - t_start:.1f}s") if len(baseline_outputs) != n_layers: logger.warning( "Captured baselines for %d/%d layers — some hooks didn't fire", len(baseline_outputs), n_layers, ) # ── 5. Per-layer dimensions ───────────────────────────────────────── # Pulled from model.config — Llama-family models report these directly. # If a model has per-layer variation (very rare), this is the first # place to patch. num_kv_heads = getattr( model.config, "num_key_value_heads", getattr(model.config, "num_attention_heads", 1), ) head_dim = getattr( model.config, "head_dim", None, ) or (model.config.hidden_size // model.config.num_attention_heads) log(f"[kv-prof] num_kv_heads={num_kv_heads}, head_dim={head_dim}") # ── 6. Sweep loop ─────────────────────────────────────────────────── # joint: outer = configs, inner = (one multi-layer hook + 1 forward). # Total len(sweep) forwards. Drift is cumulative cascade. # isolated: outer = configs, middle = layers, inner = (one single-layer # hook + 1 forward). Total len(sweep) * n_layers forwards. # Drift is marginal per-layer. rows: list[ProfileRow] = [] n_configs = len(sweep_list) if strategy == "joint": log(f"[kv-prof] strategy=joint: {n_configs} configs × {n_layers} layers " f"= {n_configs * n_layers} measurements ({n_configs} forward passes)") else: log(f"[kv-prof] strategy=isolated: {n_configs} configs × {n_layers} layers " f"= {n_configs * n_layers} measurements " f"({n_configs * n_layers} forward passes)") profiled_at = datetime.now(UTC).isoformat() def _record(li: int, cfg: SweepConfig, bpt: float, quanted_output) -> None: drift = compute_drift( quanted_output, baseline_outputs[li], drift_metric, ) rows.append(ProfileRow( model_hash=model_hash, calibration_hash=calibration_hash, pipeline_version=pipeline_version, layer_idx=li, k_bits=cfg.k_bits, v_bits=cfg.v_bits, quantizer=cfg.quantizer, drift_attn_output=float(drift), drift_metric=drift_metric, bytes_per_kv_token=float(bpt), max_seq_len_observed=actual_seq_len, num_kv_heads=num_kv_heads, head_dim=head_dim, profiled_at=profiled_at, profiled_by_agent_id=profiled_by_agent_id, profiled_by_agent_tier=profiled_by_agent_tier, )) for cfg_idx, cfg in enumerate(sweep_list): log(f"[kv-prof] config {cfg_idx + 1}/{n_configs}: " f"k={cfg.k_bits}b v={cfg.v_bits}b {cfg.quantizer}") t_cfg = time.time() spec = KVQuantSpec( k_bits=cfg.k_bits, v_bits=cfg.v_bits, quantizer=cfg.quantizer, group_size=cfg.group_size, ) bpt = kv_bytes_per_token( num_kv_heads=num_kv_heads, head_dim=head_dim, k_bits=cfg.k_bits, v_bits=cfg.v_bits, quantizer=cfg.quantizer, group_size=cfg.group_size, ) # Special case: fp16_passthrough drift is zero by definition for both strategies. if cfg.quantizer == "fp16_passthrough": for li in attn_modules: if li in baseline_outputs: _record(li, cfg, bpt, baseline_outputs[li]) log(f"[kv-prof] config done in {time.time() - t_cfg:.1f}s (passthrough)") continue if strategy == "joint": specs_by_layer = {li: spec for li in attn_modules} with kv_quant_active_multi(attn_modules, specs_by_layer): quanted_outputs = _capture_attn_outputs( model, attn_modules, batch, ) for li in attn_modules: if li not in quanted_outputs or li not in baseline_outputs: logger.debug("layer %d: missing capture, skipping", li) continue _record(li, cfg, bpt, quanted_outputs[li]) else: # isolated for li, attn in attn_modules.items(): with kv_quant_active(attn, spec): captured = _capture_attn_outputs( model, attn_modules, batch, ) if li not in captured or li not in baseline_outputs: logger.debug("layer %d: missing capture, skipping", li) continue _record(li, cfg, bpt, captured[li]) log(f"[kv-prof] config done in {time.time() - t_cfg:.1f}s, " f"sample drift = {rows[-1].drift_attn_output:.4e}") log(f"[kv-prof] complete — {len(rows)} rows total") return rows # --------------------------------------------------------------------------- # Bridge to allocator # --------------------------------------------------------------------------- def rows_to_kv_candidates(rows: list[ProfileRow]): """Group ProfileRows by layer into KVCandidate objects that assign_kv_bits (in assignment_v2.py) consumes directly. Carries num_kv_heads and head_dim through from the rows so the allocator gets the real per-layer dimensions, not placeholders. """ from assignment_v2 import KVCandidate, KVOption from collections import defaultdict by_layer: dict[int, list[ProfileRow]] = defaultdict(list) for r in rows: by_layer[r.layer_idx].append(r) candidates = [] for layer_idx, layer_rows in sorted(by_layer.items()): # All rows for one layer share num_kv_heads/head_dim by construction first = layer_rows[0] options = [ KVOption( k_bits=r.k_bits, v_bits=r.v_bits, quantizer=r.quantizer, drift=r.drift_attn_output, bytes_per_kv_token=r.bytes_per_kv_token, ) for r in layer_rows ] candidates.append(KVCandidate( layer_idx=layer_idx, num_kv_heads=first.num_kv_heads, head_dim=first.head_dim, options=options, )) return candidates