File size: 14,469 Bytes
0769ff3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
"""
ENGRAM Protocol — KV Cache Compression Layer

Implements:
  - FP16 passthrough (no compression)
  - Q8_0: group quantization matching llama.cpp GGML_TYPE_Q8_0
    Phase 1 production fallback. ~2x compression, <5% speed hit (D5).
  - PolarQuant: MSE-optimal random rotation + Lloyd-Max codebook at 3 bits.
    QJL REMOVED — confirmed harmful by 6+ independent implementations (D5).
    Softmax amplifies QJL variance, making two-stage worse than MSE-only.

Reference: TheTom/turboquant_plus (511+ tests, most mature impl)
"""

from __future__ import annotations

from dataclasses import dataclass

import numpy as np
import torch

from kvcos.core.types import CompressionMethod

# ── Q8_0 Constants ────────────────────────────────────────────────────────────
Q8_GROUP_SIZE = 32


@dataclass(frozen=True)
class CompressionResult:
    """Result of compressing a KV cache tensor.

    Returned by every ``compress_*`` entry point; the serializer reads
    ``data`` and ``metadata`` when writing to disk.
    """

    # Tensor to store; for lossy methods this is the *dequantized* form
    # (bf16/fp16) because safetensors can't hold int8+scale pairs natively.
    data: torch.Tensor
    # Which compressor produced this result.
    method: CompressionMethod
    # Dtype of the original input, so callers can round-trip the cast.
    original_dtype: torch.dtype
    # original_bytes / stored_bytes (1.0 means no size reduction).
    compression_ratio: float
    # Method-specific string key/value details (group size, seed, ...).
    metadata: dict[str, str]


# ── FP16 Passthrough ──────────────────────────────────────────────────────────


def compress_fp16(kv: torch.Tensor) -> CompressionResult:
    """Passthrough "compression": cast to contiguous FP16, nothing else.

    Ratio is fixed at 1.0 since no size reduction occurs.
    """
    casted = kv.to(torch.float16).contiguous()
    result = CompressionResult(
        data=casted,
        method=CompressionMethod.FP16,
        original_dtype=kv.dtype,
        compression_ratio=1.0,
        metadata={},
    )
    return result


def decompress_fp16(data: torch.Tensor) -> torch.Tensor:
    """Inverse of compress_fp16: ensure the tensor is FP16."""
    return data.half()


# ── Q8_0 Quantization ────────────────────────────────────────────────────────
# Matches llama.cpp GGML_TYPE_Q8_0 layout:
#   32-element groups, 1 float16 scale per group, 32 int8 values
#   Storage: (32*1 + 2) / (32*2) = 34/64 ≈ 1.88x compression


def compress_q8_0(kv: torch.Tensor) -> CompressionResult:
    """Quantize KV cache to Q8_0 (int8 with per-group symmetric scale).

    Stores the *dequantized* bfloat16 tensor for safetensors compatibility —
    safetensors doesn't support int8+scale pairs natively, so the reported
    compression_ratio reflects the 2-byte-per-element stored form.
    """
    src_dtype = kv.dtype
    src_bytes = kv.numel() * kv.element_size()

    work = kv.float().contiguous()
    shape_before = work.shape
    width = shape_before[-1]

    # Pad the last dimension up to a multiple of the group size.
    remainder = width % Q8_GROUP_SIZE
    pad = Q8_GROUP_SIZE - remainder if remainder else 0
    if pad:
        work = torch.nn.functional.pad(work, (0, pad))

    # View as [..., n_groups, 32] so each group quantizes independently.
    groups = work.reshape(work.shape[:-1] + (-1, Q8_GROUP_SIZE))

    # One symmetric scale per group: absmax / 127 (clamped against all-zero groups).
    group_scale = (groups.abs().amax(dim=-1, keepdim=True) / 127.0).clamp(min=1e-10)

    codes = (groups / group_scale).round().clamp(-127, 127)
    restored = (codes * group_scale).reshape(work.shape)

    if pad:
        restored = restored[..., :width]  # drop the padding columns again

    restored = restored.reshape(shape_before).to(torch.bfloat16)
    out_bytes = restored.numel() * 2

    return CompressionResult(
        data=restored,
        method=CompressionMethod.Q8_0,
        original_dtype=src_dtype,
        compression_ratio=src_bytes / out_bytes if out_bytes > 0 else 1.0,
        metadata={"q8_group_size": str(Q8_GROUP_SIZE)},
    )


def decompress_q8_0(data: torch.Tensor) -> torch.Tensor:
    """Cast the stored (already-dequantized) bf16 tensor to FP16."""
    return data.half()


# ── PolarQuant (Phase 2 — TurboQuant without QJL) ────────────────────────────
# QJL is INTENTIONALLY ABSENT per D5.


class PolarQuantConfig:
    """Configuration for PolarQuant compression.

    Caches the per-dimension rotation matrix and the Lloyd-Max codebook so
    repeated compress calls don't regenerate them.
    """

    def __init__(self, bits: int = 3, seed: int = 42):
        self.bits = bits
        self.n_centroids = 2**bits
        self.seed = seed
        self._rotation_cache: dict[int, torch.Tensor] = {}
        self._codebook_cache: dict[int, torch.Tensor] = {}

    def get_rotation_matrix(self, dim: int, device: torch.device) -> torch.Tensor:
        """Get fixed random orthogonal rotation matrix R ∈ R^(d×d).

        QR of a seeded Gaussian; the sign correction on R's diagonal makes the
        decomposition unique so results are reproducible across runs.
        """
        if dim not in self._rotation_cache:
            rng = np.random.RandomState(self.seed)
            gaussian = rng.randn(dim, dim).astype(np.float32)
            q, r = np.linalg.qr(gaussian)
            d = np.diag(r)
            # np.sign(0) == 0 would zero out a column and break orthogonality;
            # map non-negative diagonal entries to +1 instead.
            ph = np.where(d < 0, -1.0, 1.0).astype(np.float32)
            q *= ph[np.newaxis, :]
            self._rotation_cache[dim] = torch.from_numpy(q)
        return self._rotation_cache[dim].to(device)

    def get_lloyd_max_codebook(self, dim: int) -> torch.Tensor:
        """Lloyd-Max optimal centroids for N(0,1), 3-bit (8 levels).

        BUGFIX: the previous codebook contained a duplicate level
        (-0.000 and 0.000), wasting one of the 8 levels, and its remaining
        levels did not match the MSE-optimal Gaussian quantizer. These are
        the standard 8-level Lloyd-Max reconstruction levels for a unit
        Gaussian (Max, 1960), in ascending order.
        """
        if dim not in self._codebook_cache:
            codebook = torch.tensor(
                [-2.1519, -1.3439, -0.7560, -0.2451, 0.2451, 0.7560, 1.3439, 2.1519],
                dtype=torch.float32,
            )
            self._codebook_cache[dim] = codebook
        return self._codebook_cache[dim]


_POLAR_CONFIG = PolarQuantConfig()


def compress_polarquant(kv: torch.Tensor) -> CompressionResult:
    """Compress using PolarQuant (3-bit Lloyd-Max after random rotation).

    Phase 2 implementation. Currently stores dequantized bfloat16;
    true 3-bit packed storage is Phase 2+.
    """
    src_dtype = kv.dtype
    src_bytes = kv.numel() * kv.element_size()
    device = kv.device

    work = kv.float().contiguous()
    shape_before = work.shape
    head_dim = shape_before[-1]

    rows = work.reshape(-1, head_dim)

    # Rotate into a basis where coordinates are approximately Gaussian.
    rotation = _POLAR_CONFIG.get_rotation_matrix(head_dim, device)
    rotated = rows @ rotation

    # Standardize each dimension so the N(0,1) codebook applies.
    per_dim_std = rotated.std(dim=0, keepdim=True).clamp(min=1e-10)
    unit = rotated / per_dim_std

    # Nearest-centroid assignment against the Lloyd-Max codebook.
    codebook = _POLAR_CONFIG.get_lloyd_max_codebook(head_dim).to(device)
    sq_dist = (unit.unsqueeze(-1) - codebook.view(1, 1, -1)) ** 2
    nearest = sq_dist.argmin(dim=-1)

    # Reconstruct: centroid -> un-standardize -> rotate back (R is orthogonal).
    recon = codebook[nearest] * per_dim_std
    recon = recon @ rotation.T

    recon = recon.reshape(shape_before).to(torch.bfloat16)
    out_bytes = recon.numel() * 2

    return CompressionResult(
        data=recon,
        method=CompressionMethod.POLARQUANT,
        original_dtype=src_dtype,
        compression_ratio=src_bytes / out_bytes if out_bytes > 0 else 1.0,
        metadata={
            "polarquant_bits": "3",
            "polarquant_seed": str(_POLAR_CONFIG.seed),
            "qjl_enabled": "false",  # D5: QJL permanently disabled
        },
    )


def decompress_polarquant(data: torch.Tensor) -> torch.Tensor:
    """Cast the stored (already-dequantized) tensor back to FP16."""
    return data.half()


# ── INT8 Quantization (Phase 2 — true on-disk compression) ───────────────────
# Stores actual int8 tensors in safetensors (1 byte/element vs 2 for fp16).
# Per-row symmetric quantization: scale = max(abs(row)) / 127.
# Separate scale tensor stored alongside quantized data.
# 2x on-disk compression with cos_sim > 0.999.


@dataclass(frozen=True)
class Int8CompressedPair:
    """INT8 quantized tensor + per-row scales.

    Reconstruction is elementwise: value ≈ quantized * scale (scale
    broadcast over the last dimension).
    """

    quantized: torch.Tensor  # int8 [same shape as input]
    scales: torch.Tensor  # float16 [shape[:-1]] — one scale per row


def compress_int8_tensor(kv: torch.Tensor) -> Int8CompressedPair:
    """Quantize a KV tensor to int8 with per-row symmetric scales.

    Args:
        kv: [..., head_dim] tensor (any dtype)

    Returns:
        Int8CompressedPair with int8 data and float16 scales
    """
    shape = kv.shape
    rows = kv.float().reshape(-1, shape[-1])

    # Per-row scale: absmax / 127, clamped so all-zero rows don't divide by 0.
    per_row_scale = rows.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0

    codes = (rows / per_row_scale).round().clamp(-127, 127).to(torch.int8)

    return Int8CompressedPair(
        quantized=codes.reshape(shape),
        scales=per_row_scale.squeeze(1).to(torch.float16).reshape(shape[:-1]),
    )


def decompress_int8_tensor(quantized: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    """Dequantize an int8 tensor using per-row scales.

    Returns a float16 tensor of the original shape; the scale is broadcast
    across the last dimension of `quantized`.
    """
    widened = quantized.float()
    per_row = scales.float().unsqueeze(-1)
    return (widened * per_row).to(torch.float16)


def compress_int8(kv: torch.Tensor) -> CompressionResult:
    """INT8 compression — returns dequantized float16 for CompressionResult compat.

    The actual int8 storage is handled by the serializer which calls
    compress_int8_tensor() directly for true on-disk compression.
    This wrapper exists for the dispatcher API.
    """
    pair = compress_int8_tensor(kv)
    restored = decompress_int8_tensor(pair.quantized, pair.scales)

    before = kv.numel() * kv.element_size()
    # True on-disk cost: 1 byte per int8 element + 2 bytes per fp16 scale.
    after = pair.quantized.numel() + 2 * pair.scales.numel()

    ratio = before / after if after > 0 else 1.0
    return CompressionResult(
        data=restored,
        method=CompressionMethod.INT8,
        original_dtype=kv.dtype,
        compression_ratio=ratio,
        metadata={"int8_scale_dtype": "float16"},
    )


# ── LAYER_DELTA Compression ──────────────────────────────────────────────────
# Stores layer 0 as fp16 baseline, layers 1..N as int8 deltas from previous.
# Inter-layer residuals are typically small (adjacent layers are correlated),
# so int8 quantization of deltas achieves better fidelity than direct int8.
# On-disk: ~(1/N) fp16 + ((N-1)/N) int8 ≈ slightly better than straight INT8.


@dataclass(frozen=True)
class LayerDeltaCompressed:
    """Layer-delta compressed: fp16 baseline + int8 deltas."""

    baseline: torch.Tensor  # layer 0, [n_kv_heads, n_cells, head_dim] fp16
    delta_quantized: list[torch.Tensor]  # each int8 [n_kv_heads, n_cells, head_dim]
    delta_scales: list[torch.Tensor]  # each fp16 [n_kv_heads, n_cells]
    n_layers: int  # total layers; equals 1 + len(delta_quantized)


def compress_layer_delta(kv: torch.Tensor) -> LayerDeltaCompressed:
    """Compress KV tensor using inter-layer delta encoding.

    Args:
        kv: [n_layers, n_kv_heads, n_cells, head_dim]

    Returns:
        LayerDeltaCompressed with fp16 baseline + int8 deltas

    BUGFIX: deltas are now computed against the *reconstructed* previous
    layer (closed-loop / DPCM-style) rather than the original one. The old
    open-loop form quantized ``kv[i] - kv[i-1]``, but decompression rebuilds
    layers cumulatively from the quantized deltas, so every layer's int8
    quantization error (plus the baseline's fp16 cast error) accumulated
    into all later layers. Closing the loop confines each layer's error to
    its own quantization step. Output format is unchanged.
    """
    n_layers = kv.shape[0]
    baseline = kv[0].to(torch.float16)

    deltas: list[torch.Tensor] = []
    scales: list[torch.Tensor] = []

    # Mirror the decoder: start from the stored fp16 baseline and advance
    # using the *dequantized* delta after each layer.
    prev_recon = baseline.float()
    for i in range(1, n_layers):
        delta = kv[i].float() - prev_recon
        flat = delta.reshape(-1, delta.shape[-1])
        # Per-row symmetric scale: absmax / 127, clamped against zero rows.
        row_scale = flat.abs().amax(dim=1).clamp(min=1e-8) / 127.0
        q = (flat / row_scale.unsqueeze(1)).round().clamp(-127, 127).to(torch.int8)
        scale_f16 = row_scale.to(torch.float16)
        deltas.append(q.reshape(delta.shape))
        scales.append(scale_f16.reshape(delta.shape[:-1]))
        # Reconstruct exactly as decompress_layer_delta will (fp16 scales).
        dequant = (q.float() * scale_f16.float().unsqueeze(1)).reshape(delta.shape)
        prev_recon = prev_recon + dequant

    return LayerDeltaCompressed(
        baseline=baseline, delta_quantized=deltas,
        delta_scales=scales, n_layers=n_layers,
    )


def decompress_layer_delta(data: "LayerDeltaCompressed") -> torch.Tensor:
    """Decompress a layer-delta encoded KV tensor.

    Rebuilds layer i as layer (i-1) plus its dequantized int8 delta,
    starting from the fp16 baseline; returns fp16 [n_layers, ...].
    """
    rebuilt: list[torch.Tensor] = [data.baseline.float()]
    for codes, row_scales in zip(data.delta_quantized, data.delta_scales):
        wide = codes.float().reshape(-1, codes.shape[-1])
        step = (wide * row_scales.float().reshape(-1, 1)).reshape(codes.shape)
        rebuilt.append(rebuilt[-1] + step)
    return torch.stack(rebuilt).to(torch.float16)


def compress_layer_delta_result(kv: torch.Tensor) -> CompressionResult:
    """Layer-delta wrapper for the CompressionResult API.

    Returns the round-tripped fp16 tensor; the ratio reflects the true
    on-disk layout (fp16 baseline, int8 deltas, fp16 scales).
    """
    packed = compress_layer_delta(kv)
    restored = decompress_layer_delta(packed)

    before = kv.numel() * kv.element_size()
    n = packed.n_layers
    layer_elems = kv[0].numel()
    scale_rows = kv.shape[1] * kv.shape[2]  # n_kv_heads * n_cells
    # baseline fp16 + (N-1) * (int8 delta + fp16 scale row) per extra layer.
    after = 2 * layer_elems + (n - 1) * (layer_elems + 2 * scale_rows)

    return CompressionResult(
        data=restored,
        method=CompressionMethod.LAYER_DELTA,
        original_dtype=kv.dtype,
        compression_ratio=before / after if after > 0 else 1.0,
        metadata={"delta_n_layers": str(n)},
    )


# ── Dispatcher ────────────────────────────────────────────────────────────────


def compress(kv: torch.Tensor, method: CompressionMethod) -> CompressionResult:
    """Compress a KV cache tensor using the specified method.

    Q4_0 is deliberately redirected to Q8_0 (with a warning) per D5;
    unknown methods raise ValueError.
    """
    if method == CompressionMethod.Q4_0:
        import warnings

        warnings.warn(
            "Q4_0 has 92% dequantization slowdown at 64K+ context. "
            "Using Q8_0 instead. See D5.",
            UserWarning,
            stacklevel=2,
        )
        return compress_q8_0(kv)

    handlers = {
        CompressionMethod.FP16: compress_fp16,
        CompressionMethod.Q8_0: compress_q8_0,
        CompressionMethod.POLARQUANT: compress_polarquant,
        CompressionMethod.INT8: compress_int8,
        CompressionMethod.LAYER_DELTA: compress_layer_delta_result,
    }
    handler = handlers.get(method)
    if handler is None:
        raise ValueError(f"Unknown compression method: {method}")
    return handler(kv)


def decompress(data: torch.Tensor, method: CompressionMethod) -> torch.Tensor:
    """Decompress a KV cache tensor stored by compress().

    All current methods store an already-usable tensor, so each branch is
    a dtype cast; unknown methods raise ValueError.
    """
    if method == CompressionMethod.FP16:
        return decompress_fp16(data)
    if method in (CompressionMethod.Q8_0, CompressionMethod.Q4_0):
        return decompress_q8_0(data)
    if method == CompressionMethod.POLARQUANT:
        return decompress_polarquant(data)
    if method in (CompressionMethod.INT8, CompressionMethod.LAYER_DELTA):
        # Stored form is already dequantized float16 in CompressionResult.
        return data.to(torch.float16)
    raise ValueError(f"Unknown compression method: {method}")