from __future__ import annotations import math import random class RotatedQuantizedMemory: def __init__(self, dimension: int, seed: int = 42) -> None: if dimension <= 0: raise ValueError("dimension must be positive") self.dimension = dimension self._random = random.Random(seed) self._perm = list(range(dimension)) self._random.shuffle(self._perm) self._inv_perm = [0] * dimension for i, j in enumerate(self._perm): self._inv_perm[j] = i self._signs = [1.0 if self._random.random() < 0.5 else -1.0 for _ in range(dimension)] self._qjl_matrix = [ [self._random.gauss(0.0, 1.0) for _ in range(dimension)] for _ in range(dimension) ] self._qjl_matrix_t = [list(col) for col in zip(*self._qjl_matrix)] self._codebooks: dict[int, list[float]] = {} self._prod_code_cache: dict[tuple[int, tuple[float, ...]], dict] = {} self._prod_recon_cache: dict[tuple[int, tuple[float, ...]], list[float]] = {} self._prod_cache_max = 0 @staticmethod def _vector_key(bits: int, vector: list[float]) -> tuple[int, tuple[float, ...]]: rounded = tuple(round(value, 6) for value in vector) return bits, rounded @staticmethod def _clone_code(code: dict) -> dict: return { "bits": int(code["bits"]), "mse_code": {"bits": int(code["mse_code"]["bits"]), "indices": list(code["mse_code"]["indices"])}, "qjl_signs": list(code["qjl_signs"]), "gamma": float(code["gamma"]), } def _cache_put(self, key: tuple[int, tuple[float, ...]], code: dict, reconstructed: list[float]) -> None: if self._prod_cache_max <= 0: return self._prod_code_cache[key] = self._clone_code(code) self._prod_recon_cache[key] = list(reconstructed) if len(self._prod_code_cache) > self._prod_cache_max: oldest_key = next(iter(self._prod_code_cache)) self._prod_code_cache.pop(oldest_key, None) self._prod_recon_cache.pop(oldest_key, None) def _rotate(self, vector: list[float]) -> list[float]: return [self._signs[i] * vector[self._perm[i]] for i in range(self.dimension)] def _inverse_rotate(self, vector: list[float]) -> list[float]: out = [0.0] * self.dimension for i in range(self.dimension): out[self._perm[i]] = self._signs[i] * vector[i] return out def _codebook(self, bits: int) -> list[float]: if bits <= 0: raise ValueError("bits must be >= 1") if bits in self._codebooks: return self._codebooks[bits] levels = 2**bits step = 2.0 / levels centroids = [-1.0 + step * (i + 0.5) for i in range(levels)] self._codebooks[bits] = centroids return centroids def quantize_mse(self, vector: list[float], bits: int) -> dict: if len(vector) != self.dimension: raise ValueError("vector dimension mismatch") rotated = self._rotate(vector) centroids = self._codebook(bits) levels = len(centroids) step = 2.0 / levels indices: list[int] = [] for value in rotated: clipped = max(-1.0, min(1.0, value)) idx = int((clipped + 1.0) / step) idx = max(0, min(levels - 1, idx)) indices.append(idx) return {"bits": bits, "indices": indices} def dequantize_mse(self, code: dict) -> list[float]: bits = int(code["bits"]) indices: list[int] = code["indices"] centroids = self._codebook(bits) rotated_hat = [centroids[idx] for idx in indices] return self._inverse_rotate(rotated_hat) def _qjl_sign(self, residual: list[float]) -> list[int]: signs: list[int] = [] for row in self._qjl_matrix: dot = sum(a * b for a, b in zip(row, residual, strict=False)) signs.append(1 if dot >= 0 else -1) return signs def _qjl_inverse(self, signs: list[int]) -> list[float]: scale = math.pi / (2.0 * self.dimension) recovered = [0.0] * self.dimension for j, col in enumerate(self._qjl_matrix_t): recovered[j] = scale * sum(value * sign for value, sign in zip(col, signs, strict=False)) return recovered def quantize_and_dequantize_prod(self, vector: list[float], bits: int) -> tuple[dict, list[float]]: if len(vector) != self.dimension: raise ValueError("vector dimension mismatch") key: tuple[int, tuple[float, ...]] | None = None if self._prod_cache_max > 0: key = self._vector_key(bits, vector) cached_code = self._prod_code_cache.get(key) cached_recon = self._prod_recon_cache.get(key) if cached_code is not None and cached_recon is not None: return self._clone_code(cached_code), list(cached_recon) mse_bits = max(bits - 1, 1) mse_code = self.quantize_mse(vector, mse_bits) mse_hat = self.dequantize_mse(mse_code) residual = [x - y for x, y in zip(vector, mse_hat, strict=False)] gamma = math.sqrt(sum(v * v for v in residual)) if gamma > 0: unit_residual = [v / gamma for v in residual] else: unit_residual = [0.0] * self.dimension code = { "bits": bits, "mse_code": mse_code, "qjl_signs": self._qjl_sign(unit_residual), "gamma": gamma, } residual_hat = self._qjl_inverse(code["qjl_signs"]) reconstructed = [m + gamma * r for m, r in zip(mse_hat, residual_hat, strict=False)] if key is not None: self._cache_put(key, code, reconstructed) return self._clone_code(code), list(reconstructed) return code, reconstructed def quantize_prod(self, vector: list[float], bits: int) -> dict: code, _ = self.quantize_and_dequantize_prod(vector, bits) return code def dequantize_prod(self, code: dict) -> list[float]: mse_hat = self.dequantize_mse(code["mse_code"]) residual_hat = self._qjl_inverse(code["qjl_signs"]) gamma = float(code["gamma"]) return [m + gamma * r for m, r in zip(mse_hat, residual_hat, strict=False)] @staticmethod def compute_distortion(vector: list[float], reconstructed: list[float], query: list[float]) -> dict[str, float]: dim = max(len(vector), 1) mse = sum((x - y) ** 2 for x, y in zip(vector, reconstructed, strict=False)) / dim ip_true = sum(a * b for a, b in zip(query, vector, strict=False)) ip_hat = sum(a * b for a, b in zip(query, reconstructed, strict=False)) inner_error = (ip_true - ip_hat) ** 2 / dim return {"mse": mse, "inner_product_error": inner_error} TurboQuantizer = RotatedQuantizedMemory