| """ |
| ENGRAM Protocol β KV Cache Compression Layer |
| |
| |
| Implements: |
| - FP16 passthrough (no compression) |
| - Q8_0: group quantization matching llama.cpp GGML_TYPE_Q8_0 |
| Phase 1 production fallback. ~2x compression, <5% speed hit (D5). |
| - PolarQuant: MSE-optimal random rotation + Lloyd-Max codebook at 3 bits. |
| QJL REMOVED β confirmed harmful by 6+ independent implementations (D5). |
| Softmax amplifies QJL variance, making two-stage worse than MSE-only. |
| |
| Reference: TheTom/turboquant_plus (511+ tests, most mature impl) |
| """ |
from __future__ import annotations

import warnings
from dataclasses import dataclass

import numpy as np
import torch

from kvcos.core.types import CompressionMethod

Q8_GROUP_SIZE = 32
|
|
| @dataclass(frozen=True) |
| class CompressionResult: |
| """Result of compressing a KV cache tensor.""" |
|
|
| data: torch.Tensor |
| method: CompressionMethod |
| original_dtype: torch.dtype |
| compression_ratio: float |
| metadata: dict[str, str] |
|
|
| def compress_fp16(kv: torch.Tensor) -> CompressionResult: |
| """No-op compression: ensure tensor is FP16.""" |
| data = kv.to(torch.float16).contiguous() |
| return CompressionResult( |
| data=data, |
| method=CompressionMethod.FP16, |
| original_dtype=kv.dtype, |
| compression_ratio=1.0, |
| metadata={}, |
| ) |
|
|
def decompress_fp16(data: torch.Tensor) -> torch.Tensor:
    """No-op decompression: FP16 data is stored as-is."""
    return data.to(torch.float16)
|
|
def compress_q8_0(kv: torch.Tensor) -> CompressionResult:
    """Quantize KV cache to Q8_0 (int8 with per-group scale).

    Stores the dequantized result as bfloat16 for safetensors compatibility;
    safetensors doesn't support int8+scale pairs natively.
    """
    original_dtype = kv.dtype
    original_bytes = kv.numel() * kv.element_size()

    kv_flat = kv.float().contiguous()
    orig_shape = kv_flat.shape

    # Pad the last dim to a multiple of the group size so it reshapes cleanly.
    last_dim = orig_shape[-1]
    pad_amount = (Q8_GROUP_SIZE - last_dim % Q8_GROUP_SIZE) % Q8_GROUP_SIZE
    if pad_amount > 0:
        kv_flat = torch.nn.functional.pad(kv_flat, (0, pad_amount))

    new_shape = kv_flat.shape[:-1] + (-1, Q8_GROUP_SIZE)
    grouped = kv_flat.reshape(new_shape)

    # Per-group absmax scale, matching GGML's Q8_0 layout.
    scales = grouped.abs().amax(dim=-1, keepdim=True) / 127.0
    scales = scales.clamp(min=1e-10)

    quantized = torch.clamp(torch.round(grouped / scales), -127, 127)
    dequantized = (quantized * scales).reshape(kv_flat.shape)

    if pad_amount > 0:
        dequantized = dequantized[..., :last_dim]

    dequantized = dequantized.reshape(orig_shape).to(torch.bfloat16)
    # Ratio reflects the bf16 tensor actually stored here, not the logical
    # int8+scale size.
    compressed_bytes = dequantized.numel() * 2

    return CompressionResult(
        data=dequantized,
        method=CompressionMethod.Q8_0,
        original_dtype=original_dtype,
        compression_ratio=original_bytes / compressed_bytes if compressed_bytes > 0 else 1.0,
        metadata={"q8_group_size": str(Q8_GROUP_SIZE)},
    )
|
|
def decompress_q8_0(data: torch.Tensor) -> torch.Tensor:
    """Q8_0 data is stored pre-dequantized (bfloat16); just cast to FP16."""
    return data.to(torch.float16)
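

# Illustrative round-trip (a sketch, not part of the ENGRAM API; shapes are
# hypothetical): per-group absmax scaling bounds the Q8_0 error by half a
# quantization step, s/2 with s = max|x| / 127, plus bfloat16 storage rounding.
def _example_q8_0_roundtrip() -> None:
    kv = torch.randn(2, 4, 16, 64, dtype=torch.float16)  # [layers, heads, cells, head_dim]
    result = compress_q8_0(kv)
    restored = decompress_q8_0(result.data)
    # Loose tolerance: covers the s/2 bound plus bf16 staging precision.
    assert torch.allclose(kv.float(), restored.float(), atol=0.05)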
|
|
class PolarQuantConfig:
    """Configuration for PolarQuant compression."""

    def __init__(self, bits: int = 3, seed: int = 42):
        self.bits = bits
        self.n_centroids = 2**bits
        self.seed = seed
        self._rotation_cache: dict[int, torch.Tensor] = {}
        self._codebook_cache: dict[int, torch.Tensor] = {}

    def get_rotation_matrix(self, dim: int, device: torch.device) -> torch.Tensor:
        """Get a fixed random orthogonal rotation matrix R ∈ R^(d×d)."""
        if dim not in self._rotation_cache:
            rng = np.random.RandomState(self.seed)
            gaussian = rng.randn(dim, dim).astype(np.float32)
            # QR of a Gaussian matrix gives a Haar-random orthogonal Q; fixing
            # the sign of R's diagonal makes the decomposition unique.
            q, r = np.linalg.qr(gaussian)
            d = np.diag(r)
            ph = np.sign(d)
            q *= ph[np.newaxis, :]
            self._rotation_cache[dim] = torch.from_numpy(q)
        return self._rotation_cache[dim].to(device)

    def get_lloyd_max_codebook(self, dim: int) -> torch.Tensor:
        """Lloyd-Max optimal reconstruction levels for N(0,1), 3-bit (8 levels)."""
        if self.bits != 3:
            raise ValueError("Lloyd-Max codebook is hardcoded for 3-bit quantization")
        if dim not in self._codebook_cache:
            # Max (1960) 8-level quantizer for a unit Gaussian. These are the
            # reconstruction levels; ±0.501 / ±1.050 / ±1.748 are the decision
            # boundaries and must not be used as centroids. (The levels don't
            # depend on dim; the per-dim cache just mirrors the rotation API.)
            codebook = torch.tensor(
                [-2.152, -1.344, -0.756, -0.245, 0.245, 0.756, 1.344, 2.152],
                dtype=torch.float32,
            )
            self._codebook_cache[dim] = codebook
        return self._codebook_cache[dim]
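

# Illustrative check (a sketch, not part of the ENGRAM API): QR of a Gaussian
# matrix yields an orthogonal rotation, so R @ R.T should be the identity and
# the transform is exactly invertible by R.T.
def _example_rotation_is_orthogonal() -> None:
    cfg = PolarQuantConfig()
    R = cfg.get_rotation_matrix(64, torch.device("cpu"))  # 64 is an arbitrary head_dim
    assert torch.allclose(R @ R.T, torch.eye(64), atol=1e-4)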
|
|
| _POLAR_CONFIG = PolarQuantConfig() |
|
|
def compress_polarquant(kv: torch.Tensor) -> CompressionResult:
    """Compress using PolarQuant (3-bit Lloyd-Max after random rotation).

    Phase 2 implementation. Currently stores the dequantized result as
    bfloat16; true 3-bit packed storage is Phase 2+.
    """
    original_dtype = kv.dtype
    original_bytes = kv.numel() * kv.element_size()
    device = kv.device

    kv_float = kv.float().contiguous()
    orig_shape = kv_float.shape

    head_dim = orig_shape[-1]
    flat = kv_float.reshape(-1, head_dim)

    # Rotate so per-dimension values are approximately Gaussian, then
    # normalize each dimension to unit variance to match the N(0,1) codebook.
    R = _POLAR_CONFIG.get_rotation_matrix(head_dim, device)
    rotated = flat @ R

    dim_std = rotated.std(dim=0, keepdim=True).clamp(min=1e-10)
    normalized = rotated / dim_std

    # Nearest-centroid assignment against the 8-level Lloyd-Max codebook.
    codebook = _POLAR_CONFIG.get_lloyd_max_codebook(head_dim).to(device)
    distances = (normalized.unsqueeze(-1) - codebook) ** 2
    indices = distances.argmin(dim=-1)

    # Invert the pipeline: codebook lookup, de-normalize, rotate back
    # (R is orthogonal, so its inverse is its transpose).
    dequantized = codebook[indices] * dim_std
    dequantized = dequantized @ R.T

    dequantized = dequantized.reshape(orig_shape).to(torch.bfloat16)
    compressed_bytes = dequantized.numel() * 2

    return CompressionResult(
        data=dequantized,
        method=CompressionMethod.POLARQUANT,
        original_dtype=original_dtype,
        compression_ratio=original_bytes / compressed_bytes if compressed_bytes > 0 else 1.0,
        metadata={
            "polarquant_bits": "3",
            "polarquant_seed": str(_POLAR_CONFIG.seed),
            "qjl_enabled": "false",
        },
    )
|
|
def decompress_polarquant(data: torch.Tensor) -> torch.Tensor:
    """PolarQuant data is stored pre-dequantized (bfloat16); cast to FP16."""
    return data.to(torch.float16)
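

# Illustrative round-trip (a sketch; shapes hypothetical): 3-bit Lloyd-Max on
# rotated, variance-normalized Gaussian data keeps relative MSE near the
# theoretical ~3.5% for N(0,1), well under the 10% checked here.
def _example_polarquant_roundtrip() -> None:
    kv = torch.randn(2, 4, 16, 64, dtype=torch.float16)
    result = compress_polarquant(kv)
    restored = decompress_polarquant(result.data)
    rel_mse = ((kv.float() - restored.float()) ** 2).mean() / (kv.float() ** 2).mean()
    assert rel_mse < 0.10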
|
|
@dataclass(frozen=True)
class Int8CompressedPair:
    """INT8 quantized tensor + per-row scales."""

    quantized: torch.Tensor
    scales: torch.Tensor
|
|
def compress_int8_tensor(kv: torch.Tensor) -> Int8CompressedPair:
    """Quantize a KV tensor to int8 with per-row scales.

    Args:
        kv: [..., head_dim] tensor (any dtype)

    Returns:
        Int8CompressedPair with int8 data and float16 scales
    """
    orig_shape = kv.shape
    flat = kv.float().reshape(-1, orig_shape[-1])

    # Symmetric per-row absmax scaling into [-127, 127].
    row_max = flat.abs().amax(dim=1, keepdim=True).clamp(min=1e-8)
    scales = row_max / 127.0

    quantized = (flat / scales).round().clamp(-127, 127).to(torch.int8)
    scales_f16 = scales.squeeze(1).to(torch.float16)

    return Int8CompressedPair(
        quantized=quantized.reshape(orig_shape),
        scales=scales_f16.reshape(orig_shape[:-1]),
    )
|
|
| def decompress_int8_tensor(quantized: torch.Tensor, scales: torch.Tensor) -> torch.Tensor: |
| """Dequantize int8 tensor using per-row scales. |
| |
| Returns float16 tensor of the original shape. |
| """ |
| return (quantized.float() * scales.float().unsqueeze(-1)).to(torch.float16) |
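

# Illustrative round-trip (a sketch; shapes hypothetical): with symmetric
# per-row absmax scaling, the elementwise error is bounded by scale/2 per row
# (plus float16 rounding of the scales and the output).
def _example_int8_pair_roundtrip() -> None:
    kv = torch.randn(3, 5, 64, dtype=torch.float16)
    pair = compress_int8_tensor(kv)
    restored = decompress_int8_tensor(pair.quantized, pair.scales)
    assert pair.quantized.dtype == torch.int8 and pair.scales.dtype == torch.float16
    max_err = (kv.float() - restored.float()).abs().max()
    assert max_err <= pair.scales.float().max() * 0.5 + 1e-2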
|
|
def compress_int8(kv: torch.Tensor) -> CompressionResult:
    """INT8 compression: returns dequantized float16 for CompressionResult compat.

    The actual int8 storage is handled by the serializer, which calls
    compress_int8_tensor() directly for true on-disk compression.
    This wrapper exists for the dispatcher API.
    """
    pair = compress_int8_tensor(kv)
    dequantized = decompress_int8_tensor(pair.quantized, pair.scales)

    original_bytes = kv.numel() * kv.element_size()
    # Logical size: 1 byte per element + 2 bytes per row for the fp16 scale.
    compressed_bytes = pair.quantized.numel() * 1 + pair.scales.numel() * 2

    return CompressionResult(
        data=dequantized,
        method=CompressionMethod.INT8,
        original_dtype=kv.dtype,
        compression_ratio=original_bytes / compressed_bytes if compressed_bytes > 0 else 1.0,
        metadata={"int8_scale_dtype": "float16"},
    )
|
|
@dataclass(frozen=True)
class LayerDeltaCompressed:
    """Layer-delta compressed KV: fp16 baseline layer + int8 inter-layer deltas."""

    baseline: torch.Tensor
    delta_quantized: list[torch.Tensor]
    delta_scales: list[torch.Tensor]
    n_layers: int
|
|
def compress_layer_delta(kv: torch.Tensor) -> LayerDeltaCompressed:
    """Compress KV tensor using inter-layer delta encoding.

    Args:
        kv: [n_layers, n_kv_heads, n_cells, head_dim]

    Returns:
        LayerDeltaCompressed with fp16 baseline + int8 deltas
    """
    n_layers = kv.shape[0]
    baseline = kv[0].to(torch.float16)

    deltas: list[torch.Tensor] = []
    scales: list[torch.Tensor] = []

    # Each layer is encoded as an int8-quantized delta from the previous
    # layer, exploiting inter-layer similarity of KV activations.
    for i in range(1, n_layers):
        delta = kv[i].float() - kv[i - 1].float()
        flat = delta.reshape(-1, delta.shape[-1])
        row_max = flat.abs().amax(dim=1).clamp(min=1e-8) / 127.0
        q = (flat / row_max.unsqueeze(1)).round().clamp(-127, 127).to(torch.int8)
        deltas.append(q.reshape(delta.shape))
        scales.append(row_max.to(torch.float16).reshape(delta.shape[:-1]))

    return LayerDeltaCompressed(
        baseline=baseline,
        delta_quantized=deltas,
        delta_scales=scales,
        n_layers=n_layers,
    )
|
|
def decompress_layer_delta(data: LayerDeltaCompressed) -> torch.Tensor:
    """Decompress a layer-delta encoded KV tensor.

    Reconstruction is sequential: each layer is the previously reconstructed
    layer plus its dequantized delta, so quantization error can accumulate
    across depth.
    """
    layers = [data.baseline.float()]
    for dq, ds in zip(data.delta_quantized, data.delta_scales):
        flat = dq.float().reshape(-1, dq.shape[-1])
        delta = (flat * ds.float().reshape(-1).unsqueeze(1)).reshape(dq.shape)
        layers.append(layers[-1] + delta)
    return torch.stack(layers).to(torch.float16)
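

# Illustrative round-trip (a sketch; shapes hypothetical): layers that drift
# slowly from a shared baseline produce small deltas, so int8 delta coding
# reconstructs them almost exactly despite the cumulative decoding.
def _example_layer_delta_roundtrip() -> None:
    layer0 = torch.randn(2, 8, 16)
    drift = 0.05 * torch.randn(4, 2, 8, 16).cumsum(dim=0)
    kv = layer0.unsqueeze(0) + drift  # [n_layers=4, n_kv_heads, n_cells, head_dim]
    compressed = compress_layer_delta(kv)
    restored = decompress_layer_delta(compressed)
    assert restored.shape == kv.shape
    assert torch.allclose(kv, restored.float(), atol=0.02)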
|
|
def compress_layer_delta_result(kv: torch.Tensor) -> CompressionResult:
    """Layer-delta wrapper for the CompressionResult API."""
    compressed = compress_layer_delta(kv)
    decompressed = decompress_layer_delta(compressed)

    original_bytes = kv.numel() * kv.element_size()
    # Logical size: fp16 baseline + int8 deltas + fp16 per-row delta scales.
    n = compressed.n_layers
    per_layer_elements = kv[0].numel()
    scale_elements = kv.shape[1] * kv.shape[2]
    compressed_bytes = (
        per_layer_elements * 2
        + (n - 1) * per_layer_elements * 1
        + (n - 1) * scale_elements * 2
    )

    return CompressionResult(
        data=decompressed,
        method=CompressionMethod.LAYER_DELTA,
        original_dtype=kv.dtype,
        compression_ratio=original_bytes / compressed_bytes if compressed_bytes > 0 else 1.0,
        metadata={"delta_n_layers": str(n)},
    )
|
|
def compress(kv: torch.Tensor, method: CompressionMethod) -> CompressionResult:
    """Compress a KV cache tensor using the specified method."""
    match method:
        case CompressionMethod.FP16:
            return compress_fp16(kv)
        case CompressionMethod.Q8_0:
            return compress_q8_0(kv)
        case CompressionMethod.POLARQUANT:
            return compress_polarquant(kv)
        case CompressionMethod.INT8:
            return compress_int8(kv)
        case CompressionMethod.LAYER_DELTA:
            return compress_layer_delta_result(kv)
        case CompressionMethod.Q4_0:
            warnings.warn(
                "Q4_0 has 92% dequantization slowdown at 64K+ context. "
                "Using Q8_0 instead. See D5.",
                UserWarning,
                stacklevel=2,
            )
            return compress_q8_0(kv)
        case _:
            raise ValueError(f"Unknown compression method: {method}")
|
|
def decompress(data: torch.Tensor, method: CompressionMethod) -> torch.Tensor:
    """Decompress a KV cache tensor."""
    match method:
        case CompressionMethod.FP16:
            return decompress_fp16(data)
        case CompressionMethod.Q8_0 | CompressionMethod.Q4_0:
            return decompress_q8_0(data)
        case CompressionMethod.POLARQUANT:
            return decompress_polarquant(data)
        case CompressionMethod.INT8 | CompressionMethod.LAYER_DELTA:
            # These wrappers store pre-dequantized tensors; casting suffices.
            return data.to(torch.float16)
        case _:
            raise ValueError(f"Unknown compression method: {method}")
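

# Illustrative usage of the dispatch API (a sketch; shapes hypothetical):
# every method round-trips through compress()/decompress() back to the
# original shape, whatever its internal storage format.
def _example_dispatch_roundtrip() -> None:
    kv = torch.randn(2, 4, 16, 64, dtype=torch.float16)
    for method in (CompressionMethod.FP16, CompressionMethod.Q8_0, CompressionMethod.INT8):
        result = compress(kv, method)
        restored = decompress(result.data, result.method)
        assert restored.shape == kv.shape and restored.dtype == torch.float16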
|
|