| """ |
| Sovereign Hive HSAQ — K/V interception hooks for Llama-family attention. |
| |
| Targets: Llama-3, Mistral 0.1/0.2/0.3, Qwen 2 / 2.5, OLMo. Anything where the |
| attention module exposes `k_proj` and `v_proj` as nn.Linear submodules and |
| emits K/V immediately after those projections. This is the dominant pattern |
| in modern HF transformer implementations. |
| |
| Hook strategy: |
| Register a forward hook on attention_module.k_proj and v_proj. The hook |
| receives the projection's output tensor (the K or V projection result), |
| runs it through quantize_dequantize_kv (round-trip simulation), and |
| returns the round-tripped tensor. PyTorch forward hooks that return a |
| non-None value replace the module's output, so RoPE / GQA expansion / cache |
| insertion / SDPA all proceed downstream on the quantized-then-dequantized |
| K/V. This approximates a real quantized cache at inference but is not |
| identical: real caches typically quantize K after RoPE (and after any |
| k_norm), whereas this hook quantizes the raw projection output before them. |
| |
| Why hook k_proj/v_proj outputs and not the attention module itself: |
| Llama-family attention modules differ in their internals: some apply RoPE |
| inline, some use position_ids, some return past_key_value tuples in |
| different shapes, some use sdpa vs eager vs flash attention. The K and V |
| projection outputs, however, are reliably the same shape and meaning |
| across all of them: (batch, seq, num_kv_heads × head_dim). Hooking there |
| isolates us from every downstream variation. |
| |
| Important: this measures drift as if the cache were quantized starting |
| NOW (the calibration forward pass). For autoregressive generation with a |
| growing cache, the per-token drift compounds. The relative drift between |
| layers is what's meaningful for ranking; because compounding is not |
| captured, absolute numbers are optimistic (lower-bound) estimates of |
| generation-time impact. |
| |
| What we do NOT cover: |
| - MQA-with-single-KV-head edge cases (works fine, just less interesting) |
| - Models that fuse Q/K/V into a single c_attn projection (GPT-2 style). |
| For those, hook c_attn and slice — separate adapter, not implemented here. |
| - Models where K/V go through additional norms (some Qwen variants apply |
| q_norm/k_norm AFTER projection). The hook here runs BEFORE those norms, |
| while such models typically cache the post-norm K. For them the probe |
| quantizes a pre-norm value rather than the exact cached value, so treat |
| their drift numbers as approximate. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from contextlib import contextmanager |
| from dataclasses import dataclass |
| from typing import Iterator, Literal |
|
|
|
|
# Quantizer families selectable in a KVQuantSpec; _quant_dequant dispatches
# on exactly these strings.
KVQuantizer = Literal[
    "hqq_g64",
    "scaled_uniform",
    "scaled_per_head",
    "fp16_passthrough",
]
|
|
|
|
@dataclass()
class KVQuantSpec:
    """K/V quantization settings for one attention layer's probe.

    The hook installers in this module apply ``k_bits`` to the k_proj output
    and ``v_bits`` to the v_proj output. A bit width of 16 or more (or the
    "fp16_passthrough" quantizer) leaves the projection output untouched.
    """

    k_bits: int  # bit width simulated for the k_proj output
    v_bits: int  # bit width simulated for the v_proj output
    quantizer: KVQuantizer  # which quantizer family to simulate
    group_size: int = 64  # group length along the last dim; used by "hqq_g64"
|
|
|
|
| |
| |
| |
| |
| |
|
|
|
|
def _symmetric_group_round_trip(x, qmax: int, group: int):
    """Symmetric absmax quantize→dequantize of ``x`` in groups along the last dim.

    The last dimension is cut into contiguous chunks of ``group`` elements.
    Each chunk gets its own scale ``absmax / qmax``, and values are rounded
    to integers in ``[-qmax, qmax]`` before being scaled back. When the last
    dimension is not a multiple of ``group``, it is zero-padded for the
    computation and the padding is sliced off again. Zeros never raise a
    chunk's absmax, so the padding does not change the real values' scales.

    Returns a tensor with the same shape and dtype as ``x``.
    """
    import torch

    last = x.shape[-1]
    prefix = x.shape[:-1]
    n_groups = -(-last // group)  # ceil(last / group)
    pad = n_groups * group - last
    if pad:
        x = torch.cat([x, x.new_zeros((*prefix, pad))], dim=-1)
    grouped = x.reshape(*prefix, n_groups, group)
    # The 1e-8 floor keeps all-zero chunks from producing 0/0 = NaN. Callers
    # pass >= float32 input, where 1e-8 is representable.
    absmax = grouped.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8)
    scale = absmax / qmax
    q = torch.clamp(torch.round(grouped / scale), -qmax, qmax)
    return (q * scale).reshape(*prefix, n_groups * group)[..., :last]


def _quant_dequant(
    tensor,
    bits: int,
    quantizer: KVQuantizer,
    group_size: int,
    *,
    head_dim: int | None = None,
):
    """Round-trip a K or V projection output through a quantizer config.

    Simulates storing ``tensor`` at ``bits`` bits and reading it back. Values
    are quantized symmetrically by absmax to integer levels in
    ``[-qmax, qmax]``, with ``qmax = 2**(bits - 1) - 1``, then dequantized.
    The arithmetic runs in at least float32 and the result is cast back to
    ``tensor.dtype``.

    Args:
        tensor: projection output; quantization runs along its last dim.
        bits: simulated bit width. ``>= 16`` disables quantization.
        quantizer:
            - "fp16_passthrough": return ``tensor`` unchanged.
            - "scaled_uniform": one scale per row (the whole last dim).
            - "scaled_per_head": one scale per ``head_dim``-wide slice of
              the last dim. Without ``head_dim`` this is identical to
              "scaled_uniform".
            - "hqq_g64": one scale per group of ``group_size`` elements.
              If ``group_size`` does not divide the last dim, the group
              falls back to ``max(1, last // 4)``. NOTE: this is plain
              symmetric group quantization, not the HQQ (half-quadratic)
              optimizer the name suggests.
        group_size: group length for "hqq_g64"; ignored by other quantizers.
        head_dim: optional per-head width used by "scaled_per_head".

    Returns:
        A tensor with the same shape and dtype as ``tensor``. For
        passthrough, ``bits >= 16``, or empty input, this is ``tensor``
        itself.

    Raises:
        ValueError: unknown quantizer, ``bits < 2``, ``group_size < 1`` for
            "hqq_g64", or a ``head_dim`` that does not evenly divide the
            last dim.
    """
    import torch

    if quantizer == "fp16_passthrough" or bits >= 16:
        return tensor
    if quantizer not in ("scaled_uniform", "scaled_per_head", "hqq_g64"):
        raise ValueError(f"unknown quantizer: {quantizer}")
    if bits < 2:
        # bits == 1 gives qmax == 0, so every scale is absmax / 0 and the
        # round trip yields NaN. bits <= 0 has no meaning at all.
        raise ValueError(f"bits must be >= 2 for symmetric quantization, got {bits}")
    if tensor.numel() == 0:
        return tensor

    qmax = (1 << (bits - 1)) - 1
    last = tensor.shape[-1]

    if quantizer == "scaled_uniform":
        group = last
    elif quantizer == "scaled_per_head":
        if head_dim is None:
            group = last
        elif head_dim < 1 or last % head_dim != 0:
            raise ValueError(
                f"head_dim={head_dim} does not evenly divide last dim {last}"
            )
        else:
            group = head_dim
    else:  # "hqq_g64"
        if group_size < 1:
            raise ValueError(f"group_size must be >= 1, got {group_size}")
        # Same group choice as before. Non-dividing fallbacks are now
        # handled by padding instead of crashing in reshape.
        group = group_size if last % group_size == 0 else max(1, last // 4)

    # float16 cannot hold the 1e-8 absmax floor (it rounds to 0), which
    # turned all-zero groups into 0/0 = NaN. Do the math in >= float32.
    compute_dtype = torch.promote_types(tensor.dtype, torch.float32)
    out = _symmetric_group_round_trip(tensor.to(compute_dtype), qmax, group)
    return out.to(tensor.dtype)
|
|
|
|
| |
| |
| |
|
|
|
|
def _make_proj_hook(bits: int, quantizer: KVQuantizer, group_size: int):
    """Build a forward hook that quant→dequant round-trips a projection output.

    The returned callable has PyTorch's forward-hook signature
    ``hook(module, inputs, output)``. A non-None return value replaces the
    module's output downstream.

    - Tensor output: the round-tripped tensor is returned.
    - Tuple output whose first item is a Tensor: only that item is
      round-tripped; the remaining items are passed through as a tuple.
    - Anything else: returned unchanged.
    """
    def hook(_module, _inputs, output):
        # torch is resolved lazily, matching the rest of this module.
        import torch

        if torch.is_tensor(output):
            return _quant_dequant(output, bits, quantizer, group_size)

        leads_with_tensor = (
            isinstance(output, tuple)
            and len(output) > 0
            and torch.is_tensor(output[0])
        )
        if not leads_with_tensor:
            return output

        first, *rest = output
        return (_quant_dequant(first, bits, quantizer, group_size), *rest)

    return hook
|
|
|
|
def _locate_kv_projections(attn_module):
    """Fetch the ``k_proj`` and ``v_proj`` submodules of an attention module.

    Returns:
        The pair ``(k_proj, v_proj)``.

    Raises:
        RuntimeError: if either attribute is missing or is None, e.g. a
            fused-QKV (GPT-2 style) attention block.
    """
    projections = tuple(
        getattr(attn_module, attr_name, None) for attr_name in ("k_proj", "v_proj")
    )
    if any(proj is None for proj in projections):
        raise RuntimeError(
            "Attention module exposes no k_proj/v_proj. Likely fused QKV "
            "(GPT-2 style) or a non-Llama-family architecture. Use a "
            "model-specific adapter."
        )
    return projections
|
|
|
|
@contextmanager
def kv_quant_active(attn_module, spec: KVQuantSpec) -> Iterator[None]:
    """Context manager: while active, this attention module's K/V projections
    pass through quant→dequant simulation of the given spec.

    ``spec.k_bits`` is applied to the k_proj output and ``spec.v_bits`` to the
    v_proj output. Both use ``spec.quantizer`` and ``spec.group_size``.

    Usage:
        with kv_quant_active(model.model.layers[3].self_attn, spec):
            model(**batch, use_cache=False)
        # outside the block: behavior is exactly as before, hooks removed.

    Raises:
        RuntimeError: from _locate_kv_projections if the module has no
            k_proj/v_proj. In that case no hooks are installed.
    """
    k_proj, v_proj = _locate_kv_projections(attn_module)

    handles: list = []
    try:
        # Register inside the try. If the V registration fails (e.g. v_proj
        # is not an nn.Module), the already-installed K hook is still
        # removed instead of silently staying on the model.
        handles.append(k_proj.register_forward_hook(
            _make_proj_hook(spec.k_bits, spec.quantizer, spec.group_size)
        ))
        handles.append(v_proj.register_forward_hook(
            _make_proj_hook(spec.v_bits, spec.quantizer, spec.group_size)
        ))
        yield
    finally:
        for handle in handles:
            handle.remove()
|
|
|
|
@contextmanager
def kv_quant_active_multi(
    attn_modules_by_layer: dict[int, object],
    specs_by_layer: dict[int, KVQuantSpec],
) -> Iterator[None]:
    """Install K/V quant→dequant hooks on several layers at once.

    Useful for measuring the joint effect of quantizing many layers (e.g.
    every layer at 4-bit K / 4-bit V), rather than one layer in isolation.
    Only layers present in ``specs_by_layer`` are hooked. Each is looked up in
    ``attn_modules_by_layer``, so a missing index raises KeyError.

    Every hook registered so far is removed on exit. This includes the case
    where a lookup or registration fails part-way through.
    """
    installed: list = []
    try:
        for idx, layer_spec in specs_by_layer.items():
            k_proj, v_proj = _locate_kv_projections(attn_modules_by_layer[idx])

            k_hook = _make_proj_hook(
                layer_spec.k_bits, layer_spec.quantizer, layer_spec.group_size
            )
            installed.append(k_proj.register_forward_hook(k_hook))

            v_hook = _make_proj_hook(
                layer_spec.v_bits, layer_spec.quantizer, layer_spec.group_size
            )
            installed.append(v_proj.register_forward_hook(v_hook))
        yield
    finally:
        for handle in installed:
            handle.remove()
|
|
|
|
| |
| |
| |
|
|
|
|
def find_attention_modules(model) -> dict[int, object]:
    """Return {layer_idx: attn_module} for a Llama-family HF model.

    Layouts tried, in order:
      - model.model.layers[i].self_attn (Llama, Mistral, Qwen, OLMo)
      - model.transformer.h[i].attn     (GPT-style — not Llama family; its
        attention is typically fused QKV, which the k_proj/v_proj hooks in
        this module do not support)
      - model.layers[i].self_attn       (a bare decoder stack passed without
        its causal-LM wrapper, e.g. HF ``LlamaModel``)

    For each layer, ``self_attn`` is preferred and ``attn`` is the fallback.

    Raises:
        RuntimeError: if no layout matches, or a layer has neither a
            self_attn nor an attn submodule.
    """
    layers = None
    if hasattr(model, "model") and hasattr(model.model, "layers"):
        layers = model.model.layers
    elif hasattr(model, "transformer") and hasattr(model.transformer, "h"):
        layers = model.transformer.h
    elif hasattr(model, "layers"):
        layers = model.layers

    if layers is None:
        raise RuntimeError(
            "Couldn't locate transformer layers. This module targets "
            "Llama-family models (model.model.layers[*]). For other "
            "architectures, write a small adapter and call "
            "kv_quant_active() directly with the located attention module."
        )

    attns: dict[int, object] = {}
    for i, layer in enumerate(layers):
        # Explicit None checks: `a or b` would consult the module's
        # truthiness, which container modules define via __len__.
        attn = getattr(layer, "self_attn", None)
        if attn is None:
            attn = getattr(layer, "attn", None)
        if attn is None:
            raise RuntimeError(f"layer {i}: no self_attn/attn submodule")
        attns[i] = attn
    return attns
|
|