Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import math | |
| import random | |
| class RotatedQuantizedMemory: | |
| def __init__(self, dimension: int, seed: int = 42) -> None: | |
| if dimension <= 0: | |
| raise ValueError("dimension must be positive") | |
| self.dimension = dimension | |
| self._random = random.Random(seed) | |
| self._perm = list(range(dimension)) | |
| self._random.shuffle(self._perm) | |
| self._inv_perm = [0] * dimension | |
| for i, j in enumerate(self._perm): | |
| self._inv_perm[j] = i | |
| self._signs = [1.0 if self._random.random() < 0.5 else -1.0 for _ in range(dimension)] | |
| self._qjl_matrix = [ | |
| [self._random.gauss(0.0, 1.0) for _ in range(dimension)] for _ in range(dimension) | |
| ] | |
| self._qjl_matrix_t = [list(col) for col in zip(*self._qjl_matrix)] | |
| self._codebooks: dict[int, list[float]] = {} | |
| self._prod_code_cache: dict[tuple[int, tuple[float, ...]], dict] = {} | |
| self._prod_recon_cache: dict[tuple[int, tuple[float, ...]], list[float]] = {} | |
| self._prod_cache_max = 0 | |
| def _vector_key(bits: int, vector: list[float]) -> tuple[int, tuple[float, ...]]: | |
| rounded = tuple(round(value, 6) for value in vector) | |
| return bits, rounded | |
| def _clone_code(code: dict) -> dict: | |
| return { | |
| "bits": int(code["bits"]), | |
| "mse_code": {"bits": int(code["mse_code"]["bits"]), "indices": list(code["mse_code"]["indices"])}, | |
| "qjl_signs": list(code["qjl_signs"]), | |
| "gamma": float(code["gamma"]), | |
| } | |
| def _cache_put(self, key: tuple[int, tuple[float, ...]], code: dict, reconstructed: list[float]) -> None: | |
| if self._prod_cache_max <= 0: | |
| return | |
| self._prod_code_cache[key] = self._clone_code(code) | |
| self._prod_recon_cache[key] = list(reconstructed) | |
| if len(self._prod_code_cache) > self._prod_cache_max: | |
| oldest_key = next(iter(self._prod_code_cache)) | |
| self._prod_code_cache.pop(oldest_key, None) | |
| self._prod_recon_cache.pop(oldest_key, None) | |
| def _rotate(self, vector: list[float]) -> list[float]: | |
| return [self._signs[i] * vector[self._perm[i]] for i in range(self.dimension)] | |
| def _inverse_rotate(self, vector: list[float]) -> list[float]: | |
| out = [0.0] * self.dimension | |
| for i in range(self.dimension): | |
| out[self._perm[i]] = self._signs[i] * vector[i] | |
| return out | |
| def _codebook(self, bits: int) -> list[float]: | |
| if bits <= 0: | |
| raise ValueError("bits must be >= 1") | |
| if bits in self._codebooks: | |
| return self._codebooks[bits] | |
| levels = 2**bits | |
| step = 2.0 / levels | |
| centroids = [-1.0 + step * (i + 0.5) for i in range(levels)] | |
| self._codebooks[bits] = centroids | |
| return centroids | |
| def quantize_mse(self, vector: list[float], bits: int) -> dict: | |
| if len(vector) != self.dimension: | |
| raise ValueError("vector dimension mismatch") | |
| rotated = self._rotate(vector) | |
| centroids = self._codebook(bits) | |
| levels = len(centroids) | |
| step = 2.0 / levels | |
| indices: list[int] = [] | |
| for value in rotated: | |
| clipped = max(-1.0, min(1.0, value)) | |
| idx = int((clipped + 1.0) / step) | |
| idx = max(0, min(levels - 1, idx)) | |
| indices.append(idx) | |
| return {"bits": bits, "indices": indices} | |
| def dequantize_mse(self, code: dict) -> list[float]: | |
| bits = int(code["bits"]) | |
| indices: list[int] = code["indices"] | |
| centroids = self._codebook(bits) | |
| rotated_hat = [centroids[idx] for idx in indices] | |
| return self._inverse_rotate(rotated_hat) | |
| def _qjl_sign(self, residual: list[float]) -> list[int]: | |
| signs: list[int] = [] | |
| for row in self._qjl_matrix: | |
| dot = sum(a * b for a, b in zip(row, residual, strict=False)) | |
| signs.append(1 if dot >= 0 else -1) | |
| return signs | |
| def _qjl_inverse(self, signs: list[int]) -> list[float]: | |
| scale = math.pi / (2.0 * self.dimension) | |
| recovered = [0.0] * self.dimension | |
| for j, col in enumerate(self._qjl_matrix_t): | |
| recovered[j] = scale * sum(value * sign for value, sign in zip(col, signs, strict=False)) | |
| return recovered | |
| def quantize_and_dequantize_prod(self, vector: list[float], bits: int) -> tuple[dict, list[float]]: | |
| if len(vector) != self.dimension: | |
| raise ValueError("vector dimension mismatch") | |
| key: tuple[int, tuple[float, ...]] | None = None | |
| if self._prod_cache_max > 0: | |
| key = self._vector_key(bits, vector) | |
| cached_code = self._prod_code_cache.get(key) | |
| cached_recon = self._prod_recon_cache.get(key) | |
| if cached_code is not None and cached_recon is not None: | |
| return self._clone_code(cached_code), list(cached_recon) | |
| mse_bits = max(bits - 1, 1) | |
| mse_code = self.quantize_mse(vector, mse_bits) | |
| mse_hat = self.dequantize_mse(mse_code) | |
| residual = [x - y for x, y in zip(vector, mse_hat, strict=False)] | |
| gamma = math.sqrt(sum(v * v for v in residual)) | |
| if gamma > 0: | |
| unit_residual = [v / gamma for v in residual] | |
| else: | |
| unit_residual = [0.0] * self.dimension | |
| code = { | |
| "bits": bits, | |
| "mse_code": mse_code, | |
| "qjl_signs": self._qjl_sign(unit_residual), | |
| "gamma": gamma, | |
| } | |
| residual_hat = self._qjl_inverse(code["qjl_signs"]) | |
| reconstructed = [m + gamma * r for m, r in zip(mse_hat, residual_hat, strict=False)] | |
| if key is not None: | |
| self._cache_put(key, code, reconstructed) | |
| return self._clone_code(code), list(reconstructed) | |
| return code, reconstructed | |
| def quantize_prod(self, vector: list[float], bits: int) -> dict: | |
| code, _ = self.quantize_and_dequantize_prod(vector, bits) | |
| return code | |
| def dequantize_prod(self, code: dict) -> list[float]: | |
| mse_hat = self.dequantize_mse(code["mse_code"]) | |
| residual_hat = self._qjl_inverse(code["qjl_signs"]) | |
| gamma = float(code["gamma"]) | |
| return [m + gamma * r for m, r in zip(mse_hat, residual_hat, strict=False)] | |
| def compute_distortion(vector: list[float], reconstructed: list[float], query: list[float]) -> dict[str, float]: | |
| dim = max(len(vector), 1) | |
| mse = sum((x - y) ** 2 for x, y in zip(vector, reconstructed, strict=False)) / dim | |
| ip_true = sum(a * b for a, b in zip(query, vector, strict=False)) | |
| ip_hat = sum(a * b for a, b in zip(query, reconstructed, strict=False)) | |
| inner_error = (ip_true - ip_hat) ** 2 / dim | |
| return {"mse": mse, "inner_product_error": inner_error} | |
# Alias: TurboQuantizer is bound to the same class object as RotatedQuantizedMemory.
TurboQuantizer = RotatedQuantizedMemory