"""
Sovereign Hive HSAQ — K/V interception hooks for Llama-family attention.

Targets: Llama-3, Mistral 0.1/0.2/0.3, Qwen 2 / 2.5, OLMo. Anything where the
attention module exposes `k_proj` and `v_proj` as nn.Linear submodules and
emits K/V immediately after those projections. This is the dominant pattern
in modern HF transformer implementations.

Hook strategy:
    Register a forward hook on attention_module.k_proj and v_proj. The hook
    receives the projection's output tensor (the K or V projection result),
    runs it through a quantize→dequantize round trip, and returns the
    round-tripped tensor. PyTorch forward hooks that return a non-None value
    replace the module's output, so RoPE / GQA expansion / cache insertion /
    SDPA all proceed downstream on the quantized-then-dequantized K/V.

    Caveat: many implementations (HF Llama among them) apply RoPE — and, in
    models that have one, k_norm — to K *before* inserting it into the cache,
    so a real quantized cache stores post-RoPE keys. This hook quantizes the
    raw, pre-RoPE projection, which is a close but not bit-identical
    simulation for K. V is not rotated, so for V the two coincide.

Why hook k_proj/v_proj outputs and not the attention module itself:
    Llama-family attention modules differ in their internals: some apply
    RoPE inline, some use position_ids, some return past_key_value tuples in
    different shapes, some use sdpa vs eager vs flash attention. The K and V
    projection outputs, however, are reliably the same shape and meaning
    across all of them: (batch, seq, num_kv_heads × head_dim). Hooking there
    isolates us from every downstream variation.

Numerics:
    The round trip is computed in (at least) float32 and cast back to the
    input dtype. Doing the scale arithmetic in fp16 would underflow the 1e-8
    absmax floor to zero and turn all-zero rows into NaN.

Important: this measures drift as if the cache were quantized starting NOW
(the calibration forward pass). For autoregressive generation with a growing
cache, the per-token drift compounds. The relative drift between layers is
what's meaningful for ranking; absolute numbers are conservative estimates of
generation-time impact.

What we do NOT cover:
    - MQA-with-single-KV-head edge cases (works fine, just less interesting).
    - Models that fuse Q/K/V into a single c_attn projection (GPT-2 style).
      For those, hook c_attn and slice — separate adapter, not implemented
      here.
    - Models where K/V go through additional norms (some Qwen variants apply
      q_norm/k_norm AFTER projection). The hook here runs BEFORE those norms
      (and before RoPE), so it quantizes the raw projection rather than the
      exact tensor that ends up in the cache; treat results for such models
      as an approximation.
"""
from __future__ import annotations

from contextlib import contextmanager
from dataclasses import dataclass
from typing import Iterator, Literal, get_args

KVQuantizer = Literal["hqq_g64", "scaled_uniform", "scaled_per_head", "fp16_passthrough"]

# Runtime view of the KVQuantizer names, used to reject typos eagerly.
_KNOWN_QUANTIZERS = frozenset(get_args(KVQuantizer))


@dataclass
class KVQuantSpec:
    """Specification for a single layer's K/V quantization probe."""

    k_bits: int  # K bit width; >= 16 (or "fp16_passthrough") disables quantization
    v_bits: int  # V bit width; same convention as k_bits
    quantizer: KVQuantizer  # which round-trip scheme to simulate
    group_size: int = 64  # group width along the last axis for "hqq_g64"
    # Per-head width for "scaled_per_head". When None (the default), or when it
    # does not divide the projection width, that quantizer falls back to one
    # scale per row — the original behavior.
    head_dim: int | None = None


# ---------------------------------------------------------------------------
# Quantize-dequantize round trip
# ---------------------------------------------------------------------------
# Lives here (rather than imported) so this module is self-contained and the
# allocator/profiler can move independently.


def _check_config(bits: int, quantizer: str, group_size: int) -> None:
    """Raise ValueError for a quantizer config that cannot yield finite output."""
    if quantizer not in _KNOWN_QUANTIZERS:
        raise ValueError(f"unknown quantizer: {quantizer}")
    if quantizer == "fp16_passthrough" or bits >= 16:
        return
    if bits < 2:
        # Symmetric signed quantization uses qmax = 2**(bits-1) - 1. bits == 1
        # gives qmax == 0 (divide by zero -> NaN); bits <= 0 is a negative shift.
        raise ValueError(f"bits must be >= 2 for quantizer {quantizer!r}, got {bits}")
    if quantizer == "hqq_g64" and group_size <= 0:
        raise ValueError(f"group_size must be positive, got {group_size}")


def _resolve_group_size(last: int, group_size: int) -> int:
    """Return a group width that evenly divides `last`.

    Uses `group_size` when it divides `last`. Otherwise starts from
    max(1, last // 4) (the original fallback) and steps down to the nearest
    divisor, so the grouping reshape can never fail. Widths for which the
    original fallback already divided `last` get the same group as before.
    """
    if last % group_size == 0:
        return group_size
    gs = max(1, last // 4)
    while last % gs:
        gs -= 1
    return gs


def _absmax_round_trip(x, bits: int, group: int | None = None):
    """Symmetric absmax quantize→dequantize of `x` along its last axis.

    With `group=None` there is one scale per row (per index of every axis but
    the last). Otherwise the last axis is split into contiguous chunks of
    `group` elements, each with its own scale; `group` must divide
    x.shape[-1].
    """
    import torch

    qmax = (1 << (bits - 1)) - 1
    if group is None or group == x.shape[-1]:
        view = x
    else:
        view = x.reshape(*x.shape[:-1], x.shape[-1] // group, group)
    absmax = view.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8)
    scale = absmax / qmax
    q = torch.clamp(torch.round(view / scale), -qmax, qmax)
    return (q * scale).reshape(x.shape)


def _quant_dequant(
    tensor,
    bits: int,
    quantizer: KVQuantizer,
    group_size: int,
    head_dim: int | None = None,
):
    """Round-trip a K or V projection output through a quantizer config.

    Args:
        tensor: projection output, expected (batch, seq, num_kv_heads × head_dim).
        bits: bit width; >= 16 returns `tensor` unchanged.
        quantizer: one of the KVQuantizer names.
        group_size: group width for "hqq_g64" (shrunk to a divisor of the last
            axis when it does not divide it).
        head_dim: per-head width for "scaled_per_head"; None (or a value that
            does not divide the last axis) means one scale per row.

    Returns:
        A tensor with the same shape and dtype as `tensor`.

    Raises:
        ValueError: unknown quantizer, bits < 2, or non-positive group_size
            for "hqq_g64".
    """
    import torch

    _check_config(bits, quantizer, group_size)
    if quantizer == "fp16_passthrough" or bits >= 16 or tensor.numel() == 0:
        return tensor

    # Work in >= float32 so the 1e-8 floor and the scale math don't underflow
    # in half precision; cast back so downstream dtypes are unchanged.
    work_dtype = torch.promote_types(tensor.dtype, torch.float32)
    x = tensor.to(work_dtype)

    if quantizer == "scaled_uniform":
        out = _absmax_round_trip(x, bits)
    elif quantizer == "scaled_per_head":
        split_heads = (
            head_dim is not None
            and head_dim > 0
            and x.dim() >= 1
            and x.shape[-1] % head_dim == 0
        )
        out = _absmax_round_trip(x, bits, head_dim if split_heads else None)
    else:  # "hqq_g64" — _check_config already rejected unknown names
        out = _absmax_round_trip(x, bits, _resolve_group_size(x.shape[-1], group_size))
    return out.to(tensor.dtype)


# ---------------------------------------------------------------------------
# Hook installation
# ---------------------------------------------------------------------------


def _make_proj_hook(
    bits: int,
    quantizer: KVQuantizer,
    group_size: int,
    head_dim: int | None = None,
):
    """Forward hook for k_proj or v_proj that round-trip-quantizes the output.

    PyTorch hooks signature: (module, inputs, output). Returning a tensor
    replaces output downstream. The config is validated here so a bad spec
    fails when hooks are built, not in the middle of a forward pass.
    """
    _check_config(bits, quantizer, group_size)

    def hook(_module, _inputs, output):
        # k_proj/v_proj outputs are tensors. Some attention impls may wrap;
        # treat tuples by quantizing the first tensor element if present.
        import torch

        if isinstance(output, torch.Tensor):
            return _quant_dequant(output, bits, quantizer, group_size, head_dim)
        if isinstance(output, tuple) and output and isinstance(output[0], torch.Tensor):
            qt = _quant_dequant(output[0], bits, quantizer, group_size, head_dim)
            return (qt,) + output[1:]
        return output

    return hook


def _locate_kv_projections(attn_module):
    """Return (k_proj, v_proj) on a Llama-family attention module.

    Raises RuntimeError if the module doesn't follow the expected pattern.
    """
    k_proj = getattr(attn_module, "k_proj", None)
    v_proj = getattr(attn_module, "v_proj", None)
    if k_proj is None or v_proj is None:
        # GPT-2 style fused QKV — not handled here. Caller can patch.
        raise RuntimeError(
            "Attention module exposes no k_proj/v_proj. Likely fused QKV "
            "(GPT-2 style) or a non-Llama-family architecture. Use a "
            "model-specific adapter."
        )
    return k_proj, v_proj


def _hooks_for(attn_module, spec) -> list:
    """Build [(module, hook_fn), ...] for one layer without registering anything."""
    k_proj, v_proj = _locate_kv_projections(attn_module)
    # getattr keeps duck-typed spec objects that predate head_dim working.
    head_dim = getattr(spec, "head_dim", None)
    return [
        (k_proj, _make_proj_hook(spec.k_bits, spec.quantizer, spec.group_size, head_dim)),
        (v_proj, _make_proj_hook(spec.v_bits, spec.quantizer, spec.group_size, head_dim)),
    ]


@contextmanager
def _installed(plan: list) -> Iterator[None]:
    """Register every (module, hook) in `plan`; remove all of them on exit."""
    handles: list = []
    try:
        for module, hook in plan:
            handles.append(module.register_forward_hook(hook))
        yield
    finally:
        for handle in handles:
            handle.remove()


@contextmanager
def kv_quant_active(attn_module, spec: KVQuantSpec) -> Iterator[None]:
    """Context manager: while active, this attention module's K/V projections
    pass through quant→dequant simulation of the given spec.

    Usage:
        with kv_quant_active(model.model.layers[3].self_attn, spec):
            model(**batch, use_cache=False)
        # outside the block: behavior is exactly as before, hooks removed.

    Raises RuntimeError if the module has no k_proj/v_proj, and ValueError for
    an invalid spec; in both cases no hook is left installed.
    """
    with _installed(_hooks_for(attn_module, spec)):
        yield


@contextmanager
def kv_quant_active_multi(
    attn_modules_by_layer: dict[int, object],
    specs_by_layer: dict[int, KVQuantSpec],
) -> Iterator[None]:
    """Multi-layer variant — install hooks on multiple layers simultaneously.

    Useful for measuring the joint effect of, e.g., quantizing every layer to
    4-bit K / 4-bit V, rather than just one layer in isolation. Every layer
    is located and every spec validated before any hook is registered, so a
    failure (KeyError for a missing layer, RuntimeError, ValueError) leaves
    the model untouched.
    """
    plan: list = []
    for layer_idx, spec in specs_by_layer.items():
        plan.extend(_hooks_for(attn_modules_by_layer[layer_idx], spec))
    with _installed(plan):
        yield


# ---------------------------------------------------------------------------
# Discovery: find attention modules in a Llama-family HF model
# ---------------------------------------------------------------------------


def find_attention_modules(model) -> dict[int, object]:
    """Return {layer_idx: attn_module} for a Llama-family HF model.

    Tries the standard layouts:
        - model.model.layers[i].self_attn   (Llama, Mistral, Qwen, OLMo)
        - model.transformer.h[i].attn       (GPT-style — not Llama family)

    Raises RuntimeError if neither layout matches, or if a layer has neither
    a `self_attn` nor an `attn` attribute.
    """
    layers = None
    if hasattr(model, "model") and hasattr(model.model, "layers"):
        layers = model.model.layers
    elif hasattr(model, "transformer") and hasattr(model.transformer, "h"):
        layers = model.transformer.h
    if layers is None:
        raise RuntimeError(
            "Couldn't locate transformer layers. This module targets "
            "Llama-family models (model.model.layers[*]). For other "
            "architectures, write a small adapter and call "
            "kv_quant_active() directly with the located attention module."
        )

    attns: dict[int, object] = {}
    for i, layer in enumerate(layers):
        # Explicit None checks: `a or b` would also skip a module whose
        # truthiness is False (e.g. a container with __len__ == 0).
        attn = getattr(layer, "self_attn", None)
        if attn is None:
            attn = getattr(layer, "attn", None)
        if attn is None:
            raise RuntimeError(f"layer {i}: no self_attn/attn submodule")
        attns[i] = attn
    return attns