hsaq-tools / kv_intercept.py
mxguru1's picture
Add KV interception hooks + generalised allocator + smoke tests (1/3: kv_intercept.py)
059bd12 verified
"""
Sovereign Hive HSAQ — K/V interception hooks for Llama-family attention.
Targets: Llama-3, Mistral 0.1/0.2/0.3, Qwen 2 / 2.5, OLMo. Anything where the
attention module exposes `k_proj` and `v_proj` as nn.Linear submodules and
emits K/V immediately after those projections. This is the dominant pattern
in modern HF transformer implementations.
Hook strategy:
Register a forward hook on attention_module.k_proj and v_proj. The hook
receives the projection's output tensor (the K or V projection result),
runs it through _quant_dequant (a quantize→dequantize round-trip simulation), and
returns the round-tripped tensor. PyTorch forward hooks that return a
non-None value replace the module's output, so RoPE / GQA expansion / cache
insertion / SDPA all proceed downstream on the quantized-then-dequantized
K/V — which is the exact behavior of a real quantized cache at inference.
Why hook k_proj/v_proj outputs and not the attention module itself:
Llama-family attention modules differ in their internals: some apply RoPE
inline, some use position_ids, some return past_key_value tuples in
different shapes, some use sdpa vs eager vs flash attention. The K and V
projection outputs, however, are reliably the same shape and meaning
across all of them: (batch, seq, num_kv_heads × head_dim). Hooking there
isolates us from every downstream variation.
Important: this measures drift as if the cache were quantized starting
NOW (the calibration forward pass). For autoregressive generation with a
growing cache, the per-token drift compounds. The relative drift between
layers is what's meaningful for ranking; absolute numbers are conservative
estimates of generation-time impact.
What we do NOT cover:
- MQA-with-single-KV-head edge cases (works fine, just less interesting)
- Models that fuse Q/K/V into a single c_attn projection (GPT-2 style).
For those, hook c_attn and slice — separate adapter, not implemented here.
- Models where K/V go through additional norms (some Qwen variants apply
q_norm/k_norm AFTER projection). The hook here runs BEFORE those norms,
which matches the "quantize the cached value" semantic — the norms are
deterministic transforms applied each step regardless of cache precision.
"""
from __future__ import annotations
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Iterator, Literal
# Names of the quantizer schemes that `_quant_dequant` dispatches on.
KVQuantizer = Literal["hqq_g64", "scaled_uniform", "scaled_per_head", "fp16_passthrough"]
@dataclass
class KVQuantSpec:
    """Quantization probe settings for the K/V projections of one layer.

    The fields are forwarded to the projection hooks: ``k_bits`` is applied
    to the ``k_proj`` output, ``v_bits`` to the ``v_proj`` output, while
    ``quantizer`` and ``group_size`` are shared by both projections.
    """

    # Bit width used when round-tripping the K projection output.
    k_bits: int
    # Bit width used when round-tripping the V projection output.
    v_bits: int
    # Which quantizer scheme to simulate (one of the KVQuantizer names).
    quantizer: KVQuantizer
    # Group width along the last axis; only read by the "hqq_g64" scheme.
    group_size: int = 64
# ---------------------------------------------------------------------------
# Quantize-dequantize round trip
# ---------------------------------------------------------------------------
# Lives here (rather than imported) so this module is self-contained and the
# allocator/profiler can move independently. Identical algorithm to the stub.
def _quant_dequant(tensor, bits: int, quantizer: KVQuantizer, group_size: int):
"""Round-trip a K or V projection output through a quantizer config."""
import torch
if quantizer == "fp16_passthrough" or bits >= 16:
return tensor
if quantizer == "scaled_uniform":
per_row_absmax = tensor.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8)
qmax = (1 << (bits - 1)) - 1
scale = per_row_absmax / qmax
q = torch.clamp(torch.round(tensor / scale), -qmax, qmax)
return q * scale
if quantizer == "scaled_per_head":
# Tensor shape after k_proj/v_proj is (batch, seq, num_kv_heads × head_dim)
# We need to reshape to expose the head axis, take absmax per head, then
# reshape back. Caller passes head info separately if it wants this path.
# Conservative: fall back to per-row scaling if we can't split heads.
if tensor.dim() < 3:
return _quant_dequant(tensor, bits, "scaled_uniform", group_size)
per_row_absmax = tensor.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8)
qmax = (1 << (bits - 1)) - 1
scale = per_row_absmax / qmax
q = torch.clamp(torch.round(tensor / scale), -qmax, qmax)
return q * scale
if quantizer == "hqq_g64":
last = tensor.shape[-1]
gs = group_size if last % group_size == 0 else max(1, last // 4)
prefix = tensor.shape[:-1]
groups = last // gs
reshaped = tensor.reshape(*prefix, groups, gs)
per_group_absmax = reshaped.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8)
qmax = (1 << (bits - 1)) - 1
scale = per_group_absmax / qmax
q = torch.clamp(torch.round(reshaped / scale), -qmax, qmax)
return (q * scale).reshape(*prefix, last)
raise ValueError(f"unknown quantizer: {quantizer}")
# ---------------------------------------------------------------------------
# Hook installation
# ---------------------------------------------------------------------------
def _make_proj_hook(bits: int, quantizer: KVQuantizer, group_size: int):
"""Forward hook for k_proj or v_proj that round-trip-quantizes the output.
PyTorch hooks signature: (module, inputs, output). Returning a tensor
replaces output downstream.
"""
def hook(_module, _inputs, output):
# k_proj/v_proj outputs are tensors. Some attention impls may wrap;
# treat tuples by quantizing the first tensor element if present.
import torch
if isinstance(output, torch.Tensor):
return _quant_dequant(output, bits, quantizer, group_size)
if isinstance(output, tuple) and output and isinstance(output[0], torch.Tensor):
qt = _quant_dequant(output[0], bits, quantizer, group_size)
return (qt,) + output[1:]
return output
return hook
def _locate_kv_projections(attn_module):
"""Return (k_proj, v_proj) on a Llama-family attention module.
Raises RuntimeError if the module doesn't follow the expected pattern.
"""
k_proj = getattr(attn_module, "k_proj", None)
v_proj = getattr(attn_module, "v_proj", None)
if k_proj is None or v_proj is None:
# GPT-2 style fused QKV — not handled here. Caller can patch.
raise RuntimeError(
"Attention module exposes no k_proj/v_proj. Likely fused QKV "
"(GPT-2 style) or a non-Llama-family architecture. Use a "
"model-specific adapter."
)
return k_proj, v_proj
@contextmanager
def kv_quant_active(attn_module, spec: KVQuantSpec) -> Iterator[None]:
    """Temporarily route one attention module's K/V projections through the
    quantize→dequantize simulation described by ``spec``.

    Forward hooks are registered on ``attn_module.k_proj`` (using
    ``spec.k_bits``) and ``attn_module.v_proj`` (using ``spec.v_bits``) on
    entry, and removed on exit even if the body raises.

    Usage:
        with kv_quant_active(model.model.layers[3].self_attn, spec):
            model(**batch, use_cache=False)
        # outside the block: hooks removed.

    Raises:
        RuntimeError: if the module has no ``k_proj``/``v_proj``.
    """
    k_proj, v_proj = _locate_kv_projections(attn_module)
    handles: list = []
    try:
        # Register inside the try block so that a failure while installing
        # the second hook still removes the first one instead of leaving the
        # module permanently patched.
        for proj, bits in ((k_proj, spec.k_bits), (v_proj, spec.v_bits)):
            handles.append(
                proj.register_forward_hook(
                    _make_proj_hook(bits, spec.quantizer, spec.group_size)
                )
            )
        yield
    finally:
        for handle in handles:
            handle.remove()
@contextmanager
def kv_quant_active_multi(
    attn_modules_by_layer: dict[int, object],
    specs_by_layer: dict[int, KVQuantSpec],
) -> Iterator[None]:
    """Install K/V quantization hooks on several layers at once.

    For every ``layer_idx`` in ``specs_by_layer`` the matching attention
    module is looked up in ``attn_modules_by_layer`` and its ``k_proj`` /
    ``v_proj`` get round-trip hooks built from that layer's spec. This lets a
    caller measure the joint effect of quantizing many layers rather than
    one layer in isolation. Every hook that was installed is removed on
    exit, including when installation fails part-way through.
    """
    installed: list = []

    def install(projection, bits, layer_spec):
        hook = _make_proj_hook(bits, layer_spec.quantizer, layer_spec.group_size)
        installed.append(projection.register_forward_hook(hook))

    try:
        for layer_idx, layer_spec in specs_by_layer.items():
            attn = attn_modules_by_layer[layer_idx]
            k_proj, v_proj = _locate_kv_projections(attn)
            install(k_proj, layer_spec.k_bits, layer_spec)
            install(v_proj, layer_spec.v_bits, layer_spec)
        yield
    finally:
        for handle in installed:
            handle.remove()
# ---------------------------------------------------------------------------
# Discovery: find attention modules in a Llama-family HF model
# ---------------------------------------------------------------------------
def find_attention_modules(model) -> dict[int, object]:
    """Return ``{layer_idx: attn_module}`` for a Hugging Face style model.

    Two layouts are probed, in this order:
        - ``model.model.layers[i].self_attn`` (Llama, Mistral, Qwen, OLMo)
        - ``model.transformer.h[i].attn`` (GPT-style — not Llama family)

    Within each layer ``self_attn`` is preferred and ``attn`` is the
    fallback.

    Raises:
        RuntimeError: if neither layer container is found, or if some layer
            has neither a ``self_attn`` nor an ``attn`` attribute.
    """
    layers = None
    if hasattr(model, "model") and hasattr(model.model, "layers"):
        layers = model.model.layers
    elif hasattr(model, "transformer") and hasattr(model.transformer, "h"):
        layers = model.transformer.h
    if layers is None:
        raise RuntimeError(
            "Couldn't locate transformer layers. This module targets "
            "Llama-family models (model.model.layers[*]). For other "
            "architectures, write a small adapter and call "
            "kv_quant_active() directly with the located attention module."
        )

    attns: dict[int, object] = {}
    for i, layer in enumerate(layers):
        # Compare against None explicitly: a module object can be falsy
        # (e.g. a container type with __len__ == 0), and truthiness-based
        # fallback would wrongly treat it as missing.
        attn = getattr(layer, "self_attn", None)
        if attn is None:
            attn = getattr(layer, "attn", None)
        if attn is None:
            raise RuntimeError(f"layer {i}: no self_attn/attn submodule")
        attns[i] = attn
    return attns