File size: 12,405 Bytes
18f4d80 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 | """
RotorQuant: Reimagining TurboQuant with Clifford Algebra
Instead of TurboQuant's random orthogonal matrix Π (via QR decomposition),
RotorQuant uses Clifford rotors R = exp(B/2) for decorrelation.
Why this is better for geometric data:
1. Rotor sandwich R x R̃ preserves the FULL algebraic structure
(inner products, outer products, grades) — not just norms
2. Rotors compose naturally: R₂(R₁ x R̃₁)R̃₂ = (R₂R₁) x (R₂R₁)~
3. Grade-aware quantization: different grades can use different bit budgets
4. The bivector structure of rotors means we only need 3 parameters
(not d² for a full rotation matrix) — massive parameter savings
Algorithm:
Stage 1 (MSE): Embed vectors as Cl(3,0) multivectors → rotor sandwich →
grade-aware Lloyd-Max quantization per component
Stage 2 (QJL): 1-bit sign quantization on residual for unbiased inner products
"""
import torch
import torch.nn as nn
import math
from typing import Optional, Tuple, Dict
try:
from .clifford import (
MV_DIM, geometric_product, reverse, make_random_rotor,
rotor_sandwich, embed_vectors_as_multivectors,
extract_vectors_from_multivectors, multivector_norm_sq,
)
from .lloyd_max import LloydMaxCodebook
except ImportError:
from clifford import (
MV_DIM, geometric_product, reverse, make_random_rotor,
rotor_sandwich, embed_vectors_as_multivectors,
extract_vectors_from_multivectors, multivector_norm_sq,
)
from lloyd_max import LloydMaxCodebook
class RotorQuantMSE(nn.Module):
"""
Stage 1: MSE-optimal quantizer using Clifford rotors.
Instead of Π @ x (matrix multiply), we do R x R̃ (rotor sandwich).
Then per-component Lloyd-Max quantization on the rotated multivector.
"""
def __init__(self, d: int, bits: int, seed: int = 42,
grade_bits: Optional[Dict[str, int]] = None,
device: str = "cpu"):
"""
Args:
d: original vector dimension
bits: default bits per component
grade_bits: optional per-grade bit override, e.g.
{'scalar': 2, 'vector': 3, 'bivector': 2, 'trivector': 1}
seed: random seed
device: torch device
"""
super().__init__()
self.d = d
self.bits = bits
self.device = device
# Compute how many multivector groups we need
self.n_groups = (d + 2) // 3 # ceil(d/3)
self.mv_dim = self.n_groups * MV_DIM # total components
# Grade-aware bit allocation
# Only store non-zero grades: the rotor sandwich R v R̃ of a grade-1
# vector produces ONLY odd grades (vector + trivector). Scalar and
# bivector are mathematically guaranteed to be zero — don't waste
# storage on them. This cuts stored indices from 8/group to 4/group.
if grade_bits is None:
grade_bits = {
'vector': bits,
}
self.grade_bits = grade_bits
# Create per-grade codebooks
# d_eff determines the Lloyd-Max Gaussian σ = 1/√d_eff
d_eff_vector = d # vector grades: σ ≈ 1/√d
self.codebooks = nn.ModuleDict()
for grade_name, gb in grade_bits.items():
cb = LloydMaxCodebook(d_eff_vector, gb)
self.register_buffer(f'centroids_{grade_name}',
cb.centroids.to(device))
# Only quantize grade-1 (vector) components.
# Trivector (e123) carries 15% of energy but contributes ZERO to
# reconstruction because extract_vectors_from_multivectors only reads
# grade-1 (e1, e2, e3). Dropping trivector saves 25% of indices
# with zero MSE impact, matching TurboQuant's compression ratio.
# [scalar, e1, e2, e3, e12, e13, e23, e123]
self.grade_map = {
'vector': [1, 2, 3],
}
# Pre-compute random rotors (one per group for decorrelation)
rotors = []
for i in range(self.n_groups):
r = make_random_rotor((), device=device, seed=seed + i)
rotors.append(r)
self.register_buffer('rotors', torch.stack(rotors)) # (n_groups, 8)
def _apply_rotors(self, mv: torch.Tensor) -> torch.Tensor:
"""Apply per-group rotor sandwich: R_i x_i R̃_i"""
# mv: (..., n_groups, 8)
# rotors: (n_groups, 8) → broadcast over batch dims
return rotor_sandwich(self.rotors, mv)
def _unapply_rotors(self, mv: torch.Tensor) -> torch.Tensor:
"""Inverse rotor sandwich: R̃_i x_i R_i"""
rotor_rev = reverse(self.rotors)
return rotor_sandwich(rotor_rev, mv)
def _quantize_grade(self, x: torch.Tensor, grade_name: str) -> Tuple[torch.Tensor, torch.Tensor]:
"""Quantize components of a specific grade."""
centroids = getattr(self, f'centroids_{grade_name}')
diffs = x.unsqueeze(-1) - centroids # (..., n_components, n_levels)
indices = diffs.abs().argmin(dim=-1)
x_q = centroids[indices]
return x_q, indices
def quantize(self, x: torch.Tensor) -> Tuple[torch.Tensor, dict]:
"""
Quantize vectors via rotor + grade-aware Lloyd-Max.
x: (..., d) input vectors
Returns: (mv_q, indices_dict)
"""
# Normalize to unit vectors (store norms separately)
norms = torch.norm(x, dim=-1, keepdim=True).clamp(min=1e-8)
x_unit = x / norms
# Embed as multivectors
mv = embed_vectors_as_multivectors(x_unit) # (..., n_groups, 8)
# Apply rotor decorrelation
mv_rot = self._apply_rotors(mv)
# Grade-aware quantization
mv_q = torch.zeros_like(mv_rot)
all_indices = {}
for grade_name, component_indices in self.grade_map.items():
grade_data = mv_rot[..., component_indices] # (..., n_groups, n_components)
flat = grade_data.reshape(*grade_data.shape[:-1], -1)
q_flat, idx = self._quantize_grade(flat, grade_name)
q_data = q_flat.reshape_as(grade_data)
mv_q[..., component_indices] = q_data
all_indices[grade_name] = idx
# Store norms in indices for dequantize
all_indices['_norms'] = norms.squeeze(-1)
return mv_q, all_indices
def dequantize(self, indices: dict) -> torch.Tensor:
"""Reconstruct vectors from quantized indices."""
sample_centroids = getattr(self, 'centroids_vector')
vector_idx = indices['vector']
flat_batch = vector_idx.shape[0] if vector_idx.dim() >= 1 else 1
mv_q = torch.zeros(flat_batch, self.n_groups, MV_DIM,
dtype=sample_centroids.dtype,
device=sample_centroids.device)
for grade_name, component_indices in self.grade_map.items():
if grade_name.startswith('_'):
continue
centroids = getattr(self, f'centroids_{grade_name}')
idx = indices[grade_name]
values = centroids[idx]
n_components = len(component_indices)
values = values.reshape(flat_batch, self.n_groups, n_components)
mv_q[..., component_indices] = values
# Undo rotor rotation
mv_recon = self._unapply_rotors(mv_q)
# Extract unit vectors and rescale by stored norms
x_hat = extract_vectors_from_multivectors(mv_recon, self.d)
if '_norms' in indices:
norms = indices['_norms']
if norms.dim() < x_hat.dim():
norms = norms.unsqueeze(-1)
x_hat = x_hat * norms
return x_hat
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, dict]:
"""Full quantize-dequantize cycle."""
mv_q, indices = self.quantize(x)
x_hat = self.dequantize(indices)
return x_hat, indices
class RotorQuantProd(nn.Module):
"""
Stage 1 + Stage 2: Unbiased inner product quantizer.
Uses (b-1)-bit rotor MSE quantizer + 1-bit QJL on residuals.
The QJL operates in the original vector space (not multivector space)
since inner products are computed there.
"""
def __init__(self, d: int, bits: int, qjl_dim: Optional[int] = None,
seed: int = 42, device: str = "cpu"):
super().__init__()
self.d = d
self.bits = bits
self.mse_bits = max(bits - 1, 1)
self.qjl_dim = qjl_dim or d
self.device = device
# Stage 1: Rotor MSE quantizer
self.mse = RotorQuantMSE(d, self.mse_bits, seed=seed, device=device)
# Stage 2: QJL projection matrix (same as TurboQuant)
gen = torch.Generator(device='cpu')
gen.manual_seed(seed + 1)
S = torch.randn(self.qjl_dim, d, generator=gen)
self.register_buffer("S", S.to(device))
def quantize(self, x: torch.Tensor) -> dict:
"""Full RotorQuant quantization."""
# Stage 1: Rotor MSE
x_hat, mse_indices = self.mse(x)
# Residual
residual = x - x_hat
residual_norm = torch.norm(residual, dim=-1, keepdim=True)
# Stage 2: QJL sign quantization on residual
projected = residual @ self.S.T
qjl_signs = torch.sign(projected)
qjl_signs[qjl_signs == 0] = 1.0
return {
'mse_indices': mse_indices,
'qjl_signs': qjl_signs,
'residual_norm': residual_norm.squeeze(-1),
}
def dequantize(self, compressed: dict) -> torch.Tensor:
"""Reconstruct from MSE component."""
return self.mse.dequantize(compressed['mse_indices'])
def inner_product(self, y: torch.Tensor, compressed: dict) -> torch.Tensor:
"""
Unbiased inner product estimate: <y, x>.
Same QJL correction as TurboQuant — the math doesn't change
because QJL operates in vector space.
"""
x_mse = self.mse.dequantize(compressed['mse_indices'])
term1 = (y * x_mse).sum(dim=-1)
y_projected = y @ self.S.T
qjl_ip = (y_projected * compressed['qjl_signs']).sum(dim=-1)
m = self.qjl_dim
correction_scale = math.sqrt(math.pi / 2) / m
term2 = compressed['residual_norm'] * correction_scale * qjl_ip
return term1 + term2
def forward(self, x: torch.Tensor) -> dict:
return self.quantize(x)
class RotorQuantKVCache:
"""
KV cache using RotorQuant compression.
Drop-in replacement for TurboQuantKVCache.
"""
def __init__(self, d_key: int, d_value: int, bits: int = 3,
seed: int = 42, device: str = "cpu"):
self.d_key = d_key
self.d_value = d_value
self.bits = bits
self.device = device
self.key_quantizer = RotorQuantProd(d_key, bits, seed=seed, device=device)
self.value_quantizer = RotorQuantMSE(d_value, bits, seed=seed + 100, device=device)
self.key_cache = []
self.value_cache = []
def append(self, keys: torch.Tensor, values: torch.Tensor):
flat_keys = keys.reshape(-1, self.d_key)
flat_values = values.reshape(-1, self.d_value)
compressed_keys = self.key_quantizer.quantize(flat_keys)
_, value_indices = self.value_quantizer(flat_values)
self.key_cache.append(compressed_keys)
self.value_cache.append(value_indices)
def attention_scores(self, queries: torch.Tensor) -> torch.Tensor:
scores = []
for cached in self.key_cache:
s = self.key_quantizer.inner_product(queries, cached)
scores.append(s)
return torch.cat(scores, dim=-1) if scores else torch.tensor([])
def get_values(self) -> torch.Tensor:
values = []
for indices in self.value_cache:
v = self.value_quantizer.dequantize(indices)
values.append(v)
return torch.cat(values, dim=0) if values else torch.tensor([])
def __len__(self):
return sum(
c['qjl_signs'].shape[0] for c in self.key_cache
) if self.key_cache else 0
|