"""
Sovereign Hive HSAQ — K/V interception hooks for Llama-family attention.

Targets: Llama-3, Mistral 0.1/0.2/0.3, Qwen 2 / 2.5, OLMo. Anything where the
attention module exposes `k_proj` and `v_proj` as nn.Linear submodules and
emits K/V immediately after those projections. This is the dominant pattern
in modern HF transformer implementations.

Hook strategy:
  Register a forward hook on attention_module.k_proj and v_proj. The hook
  receives the projection's output tensor (the K or V projection result),
  runs it through _quant_dequant (round-trip simulation), and returns the
  round-tripped tensor. PyTorch forward hooks that return a non-None value
  replace the module's output, so RoPE / GQA expansion / cache insertion /
  SDPA all proceed downstream on the quantized-then-dequantized K/V. This
  approximates a real quantized cache at inference; note that K is
  round-tripped before RoPE here, whereas a cache that stores post-RoPE
  keys would quantize after it.

Why hook k_proj/v_proj outputs and not the attention module itself:
  Llama-family attention modules differ in their internals: some apply RoPE
  inline, some use position_ids, some return past_key_value tuples in
  different shapes, some use sdpa vs eager vs flash attention. The K and V
  projection outputs, however, are reliably the same shape and meaning
  across all of them: (batch, seq, num_kv_heads × head_dim). Hooking there
  isolates us from every downstream variation.

Important: this measures drift as if the cache were quantized starting
NOW (the calibration forward pass). For autoregressive generation with a
growing cache, the per-token drift compounds, so absolute numbers from a
single pass likely understate generation-time impact. The relative drift
between layers is what's meaningful for ranking.

What we do NOT cover:
  - MQA-with-single-KV-head edge cases (works fine, just less interesting)
  - Models that fuse Q/K/V into a single c_attn projection (GPT-2 style).
    For those, hook c_attn and slice — separate adapter, not implemented here.
  - Models where K/V go through additional norms (some Qwen variants apply
    q_norm/k_norm AFTER projection). The hook here runs BEFORE those norms,
    which matches the "quantize the cached value" semantic — the norms are
    deterministic transforms applied each step regardless of cache precision.
"""

from __future__ import annotations

from contextlib import contextmanager
from dataclasses import dataclass
from typing import Iterator, Literal


KVQuantizer = Literal["hqq_g64", "scaled_uniform", "scaled_per_head", "fp16_passthrough"]


@dataclass
class KVQuantSpec:
    """Specification for a single layer's K/V quantization probe."""
    k_bits: int
    v_bits: int
    quantizer: KVQuantizer
    group_size: int = 64


# ---------------------------------------------------------------------------
# Quantize-dequantize round trip
# ---------------------------------------------------------------------------
# Lives here (rather than imported) so this module is self-contained and the
# allocator/profiler can move independently. Identical algorithm to the stub.


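def _quant_dequant(tensor, bits: int, quantizer: KVQuantizer, group_size: int):
    """Round-trip a K or V projection output through a quantizer config.

    All lossy paths use symmetric absmax scaling: values are divided by
    absmax / qmax, rounded, clamped to [-qmax, qmax], and rescaled.
    """
    import torch

    if quantizer == "fp16_passthrough" or bits >= 16:
        return tensor
    if bits < 2:
        # qmax = 2**(bits - 1) - 1 is 0 for bits == 1, which would give an
        # infinite scale and NaN outputs.
        raise ValueError(f"bits must be >= 2 for symmetric quantization, got {bits}")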

    if quantizer == "scaled_uniform":
        per_row_absmax = tensor.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8)
        qmax = (1 << (bits - 1)) - 1
        scale = per_row_absmax / qmax
        q = torch.clamp(torch.round(tensor / scale), -qmax, qmax)
        return q * scale

    if quantizer == "scaled_per_head":
        # The k_proj/v_proj output is (batch, seq, num_kv_heads × head_dim).
        # True per-head scaling needs head_dim to expose the head axis, and
        # that is not passed to this function, so this path currently uses
        # the same per-row absmax scaling as "scaled_uniform".
        return _quant_dequant(tensor, bits, "scaled_uniform", group_size)

    if quantizer == "hqq_g64":
        # Group-wise symmetric absmax rounding along the last axis.
        last = tensor.shape[-1]
        gs = group_size if last % group_size == 0 else max(1, last // 4)
        if last % gs != 0:
            # The last // 4 fallback need not divide `last` (e.g. last == 9
            # gives gs == 2); use a single group so the reshape stays valid.
            gs = last
        prefix = tensor.shape[:-1]
        groups = last // gs
        reshaped = tensor.reshape(*prefix, groups, gs)
        per_group_absmax = reshaped.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8)
        qmax = (1 << (bits - 1)) - 1
        scale = per_group_absmax / qmax
        q = torch.clamp(torch.round(reshaped / scale), -qmax, qmax)
        return (q * scale).reshape(*prefix, last)

    raise ValueError(f"unknown quantizer: {quantizer}")
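

# Illustrative sketch (not used by the hooks): a pure-Python version of the
# symmetric absmax round trip on a single row, to make the arithmetic above
# concrete. The helper name and the list-based interface are hypothetical;
# group_size=None means one scale for the whole row.
def _sketch_absmax_roundtrip(values, bits, group_size=None):
    """Quantize-dequantize a list of floats with symmetric absmax scaling."""
    qmax = (1 << (bits - 1)) - 1
    gs = group_size or len(values) or 1
    out = []
    for start in range(0, len(values), gs):
        group = values[start:start + gs]
        # One scale per group, floored like the tensor code to avoid 0 / 0.
        scale = max(max(abs(v) for v in group), 1e-8) / qmax
        for v in group:
            q = max(-qmax, min(qmax, round(v / scale)))
            out.append(q * scale)
    return out


# With 3 bits (qmax = 3) and an outlier of 6.0 in the row:
#   _sketch_absmax_roundtrip([6.0, -1.2, 0.75, 0.25], bits=3)
#       -> [6.0, -2.0, 0.0, 0.0]      (one scale of 2.0 for the whole row)
#   _sketch_absmax_roundtrip([6.0, -1.2, 0.75, 0.25], bits=3, group_size=2)
#       -> [6.0, -2.0, 0.75, 0.25]    (second group gets its own scale, 0.25)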


# ---------------------------------------------------------------------------
# Hook installation
# ---------------------------------------------------------------------------


def _make_proj_hook(bits: int, quantizer: KVQuantizer, group_size: int):
    """Forward hook for k_proj or v_proj that round-trip-quantizes the output.

    PyTorch forward-hook signature: hook(module, inputs, output). Returning
    a non-None value replaces the module's output downstream.
    """
    def hook(_module, _inputs, output):
        # k_proj/v_proj outputs are tensors. Some attention impls may wrap;
        # treat tuples by quantizing the first tensor element if present.
        import torch
        if isinstance(output, torch.Tensor):
            return _quant_dequant(output, bits, quantizer, group_size)
        if isinstance(output, tuple) and output and isinstance(output[0], torch.Tensor):
            qt = _quant_dequant(output[0], bits, quantizer, group_size)
            return (qt,) + output[1:]
        return output
    return hook
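

# Illustrative sketch of the hook contract used above, assuming only a bare
# nn.Linear standing in for k_proj (the function name is hypothetical). It
# shows that a returned value replaces the output and that removing the
# handle restores the original behavior.
def _sketch_hook_replaces_output():
    import torch
    from torch import nn

    proj = nn.Linear(4, 4, bias=False)
    x = torch.randn(2, 3, 4)
    with torch.no_grad():
        baseline = proj(x)
        handle = proj.register_forward_hook(
            lambda _module, _inputs, output: torch.zeros_like(output)
        )
        try:
            hooked = proj(x)    # the hook's return value replaces the output
        finally:
            handle.remove()
        restored = proj(x)      # hook removed: identical to baseline
    return baseline, hooked, restored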


def _locate_kv_projections(attn_module):
    """Return (k_proj, v_proj) on a Llama-family attention module.

    Raises RuntimeError if the module doesn't follow the expected pattern.
    """
    k_proj = getattr(attn_module, "k_proj", None)
    v_proj = getattr(attn_module, "v_proj", None)
    if k_proj is None or v_proj is None:
        # GPT-2 style fused QKV — not handled here. Caller can patch.
        raise RuntimeError(
            "Attention module exposes no k_proj/v_proj. Likely fused QKV "
            "(GPT-2 style) or a non-Llama-family architecture. Use a "
            "model-specific adapter."
        )
    return k_proj, v_proj
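

# Illustrative sketch of the fused-QKV adapter mentioned above, which this
# module does not implement. It assumes the fused projection's output is laid
# out as [Q | K | V] in three equal slices along the last axis; that layout
# and the function name are assumptions, so check the target model before
# relying on it. quantize_k / quantize_v are caller-supplied callables.
def _sketch_fused_qkv_hook(quantize_k, quantize_v):
    import torch

    def hook(_module, _inputs, output):
        # Split the fused output, round-trip only K and V, and reassemble.
        q, k, v = output.chunk(3, dim=-1)
        return torch.cat([q, quantize_k(k), quantize_v(v)], dim=-1)
    return hook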


@contextmanager
def kv_quant_active(attn_module, spec: KVQuantSpec) -> Iterator[None]:
    """Context manager: while active, this attention module's K/V projections
    pass through quant→dequant simulation of the given spec.

    Usage:
        with kv_quant_active(model.model.layers[3].self_attn, spec):
            model(**batch, use_cache=False)
        # outside the block: behavior is exactly as before, hooks removed.
    """
    k_proj, v_proj = _locate_kv_projections(attn_module)

    # Register inside the try block so a failure while installing the second
    # hook still removes the first one.
    handles: list = []
    try:
        handles.append(k_proj.register_forward_hook(
            _make_proj_hook(spec.k_bits, spec.quantizer, spec.group_size)
        ))
        handles.append(v_proj.register_forward_hook(
            _make_proj_hook(spec.v_bits, spec.quantizer, spec.group_size)
        ))
        yield
    finally:
        for h in handles:
            h.remove()


@contextmanager
def kv_quant_active_multi(
    attn_modules_by_layer: dict[int, object],
    specs_by_layer: dict[int, KVQuantSpec],
) -> Iterator[None]:
    """Multi-layer variant — install hooks on multiple layers simultaneously.

    Useful for measuring the joint effect of, e.g., quantizing every layer
    to 4-bit K / 4-bit V, rather than just one layer in isolation.
    """
    handles: list = []
    try:
        for layer_idx, spec in specs_by_layer.items():
            attn = attn_modules_by_layer[layer_idx]
            k_proj, v_proj = _locate_kv_projections(attn)
            handles.append(k_proj.register_forward_hook(
                _make_proj_hook(spec.k_bits, spec.quantizer, spec.group_size)
            ))
            handles.append(v_proj.register_forward_hook(
                _make_proj_hook(spec.v_bits, spec.quantizer, spec.group_size)
            ))
        yield
    finally:
        for h in handles:
            h.remove()


# ---------------------------------------------------------------------------
# Discovery: find attention modules in a Llama-family HF model
# ---------------------------------------------------------------------------


def find_attention_modules(model) -> dict[int, object]:
    """Return {layer_idx: attn_module} for a Llama-family HF model.

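    Tries the standard layouts:
      - model.model.layers[i].self_attn   (Llama, Mistral, Qwen, OLMo)
      - model.transformer.h[i].attn       (GPT-style, not Llama family;
        kv_quant_active() still requires k_proj/v_proj on these modules)
    Raises RuntimeError if neither layout matches or a layer has no
    self_attn/attn submodule.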
    """
    layers = None
    if hasattr(model, "model") and hasattr(model.model, "layers"):
        layers = model.model.layers
    elif hasattr(model, "transformer") and hasattr(model.transformer, "h"):
        layers = model.transformer.h

    if layers is None:
        raise RuntimeError(
            "Couldn't locate transformer layers. This module targets "
            "Llama-family models (model.model.layers[*]). For other "
            "architectures, write a small adapter and call "
            "kv_quant_active() directly with the located attention module."
        )

    attns: dict[int, object] = {}
    for i, layer in enumerate(layers):
        attn = getattr(layer, "self_attn", None) or getattr(layer, "attn", None)
        if attn is None:
            raise RuntimeError(f"layer {i}: no self_attn/attn submodule")
        attns[i] = attn
    return attns
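

# Illustrative sketch for the ranking use case described in the module
# docstring: order layers by how far their probed outputs drift from an
# unquantized baseline. The metric (relative L2 error), the flat-list inputs,
# and the function name are assumptions; this module does not define a metric.
def _sketch_rank_layers_by_drift(baseline, probed_by_layer):
    """Return [(layer_idx, relative_drift), ...] sorted worst-first."""
    import math

    base_norm = math.sqrt(sum(b * b for b in baseline)) or 1e-8
    drift = {}
    for layer_idx, probed in probed_by_layer.items():
        err = math.sqrt(sum((p - b) ** 2 for p, b in zip(probed, baseline)))
        drift[layer_idx] = err / base_norm
    return sorted(drift.items(), key=lambda item: item[1], reverse=True)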