# workflow-twin / env / quantizer.py
# Last commit by NDGCodes (1a692ce): "fix repo structure for HF"
from __future__ import annotations
import math
import random
class RotatedQuantizedMemory:
def __init__(self, dimension: int, seed: int = 42) -> None:
if dimension <= 0:
raise ValueError("dimension must be positive")
self.dimension = dimension
self._random = random.Random(seed)
self._perm = list(range(dimension))
self._random.shuffle(self._perm)
self._inv_perm = [0] * dimension
for i, j in enumerate(self._perm):
self._inv_perm[j] = i
self._signs = [1.0 if self._random.random() < 0.5 else -1.0 for _ in range(dimension)]
self._qjl_matrix = [
[self._random.gauss(0.0, 1.0) for _ in range(dimension)] for _ in range(dimension)
]
self._qjl_matrix_t = [list(col) for col in zip(*self._qjl_matrix)]
self._codebooks: dict[int, list[float]] = {}
self._prod_code_cache: dict[tuple[int, tuple[float, ...]], dict] = {}
self._prod_recon_cache: dict[tuple[int, tuple[float, ...]], list[float]] = {}
self._prod_cache_max = 0
@staticmethod
def _vector_key(bits: int, vector: list[float]) -> tuple[int, tuple[float, ...]]:
rounded = tuple(round(value, 6) for value in vector)
return bits, rounded
@staticmethod
def _clone_code(code: dict) -> dict:
return {
"bits": int(code["bits"]),
"mse_code": {"bits": int(code["mse_code"]["bits"]), "indices": list(code["mse_code"]["indices"])},
"qjl_signs": list(code["qjl_signs"]),
"gamma": float(code["gamma"]),
}
def _cache_put(self, key: tuple[int, tuple[float, ...]], code: dict, reconstructed: list[float]) -> None:
if self._prod_cache_max <= 0:
return
self._prod_code_cache[key] = self._clone_code(code)
self._prod_recon_cache[key] = list(reconstructed)
if len(self._prod_code_cache) > self._prod_cache_max:
oldest_key = next(iter(self._prod_code_cache))
self._prod_code_cache.pop(oldest_key, None)
self._prod_recon_cache.pop(oldest_key, None)
def _rotate(self, vector: list[float]) -> list[float]:
return [self._signs[i] * vector[self._perm[i]] for i in range(self.dimension)]
def _inverse_rotate(self, vector: list[float]) -> list[float]:
out = [0.0] * self.dimension
for i in range(self.dimension):
out[self._perm[i]] = self._signs[i] * vector[i]
return out
def _codebook(self, bits: int) -> list[float]:
if bits <= 0:
raise ValueError("bits must be >= 1")
if bits in self._codebooks:
return self._codebooks[bits]
levels = 2**bits
step = 2.0 / levels
centroids = [-1.0 + step * (i + 0.5) for i in range(levels)]
self._codebooks[bits] = centroids
return centroids
def quantize_mse(self, vector: list[float], bits: int) -> dict:
if len(vector) != self.dimension:
raise ValueError("vector dimension mismatch")
rotated = self._rotate(vector)
centroids = self._codebook(bits)
levels = len(centroids)
step = 2.0 / levels
indices: list[int] = []
for value in rotated:
clipped = max(-1.0, min(1.0, value))
idx = int((clipped + 1.0) / step)
idx = max(0, min(levels - 1, idx))
indices.append(idx)
return {"bits": bits, "indices": indices}
def dequantize_mse(self, code: dict) -> list[float]:
bits = int(code["bits"])
indices: list[int] = code["indices"]
centroids = self._codebook(bits)
rotated_hat = [centroids[idx] for idx in indices]
return self._inverse_rotate(rotated_hat)
def _qjl_sign(self, residual: list[float]) -> list[int]:
signs: list[int] = []
for row in self._qjl_matrix:
dot = sum(a * b for a, b in zip(row, residual, strict=False))
signs.append(1 if dot >= 0 else -1)
return signs
def _qjl_inverse(self, signs: list[int]) -> list[float]:
scale = math.pi / (2.0 * self.dimension)
recovered = [0.0] * self.dimension
for j, col in enumerate(self._qjl_matrix_t):
recovered[j] = scale * sum(value * sign for value, sign in zip(col, signs, strict=False))
return recovered
def quantize_and_dequantize_prod(self, vector: list[float], bits: int) -> tuple[dict, list[float]]:
if len(vector) != self.dimension:
raise ValueError("vector dimension mismatch")
key: tuple[int, tuple[float, ...]] | None = None
if self._prod_cache_max > 0:
key = self._vector_key(bits, vector)
cached_code = self._prod_code_cache.get(key)
cached_recon = self._prod_recon_cache.get(key)
if cached_code is not None and cached_recon is not None:
return self._clone_code(cached_code), list(cached_recon)
mse_bits = max(bits - 1, 1)
mse_code = self.quantize_mse(vector, mse_bits)
mse_hat = self.dequantize_mse(mse_code)
residual = [x - y for x, y in zip(vector, mse_hat, strict=False)]
gamma = math.sqrt(sum(v * v for v in residual))
if gamma > 0:
unit_residual = [v / gamma for v in residual]
else:
unit_residual = [0.0] * self.dimension
code = {
"bits": bits,
"mse_code": mse_code,
"qjl_signs": self._qjl_sign(unit_residual),
"gamma": gamma,
}
residual_hat = self._qjl_inverse(code["qjl_signs"])
reconstructed = [m + gamma * r for m, r in zip(mse_hat, residual_hat, strict=False)]
if key is not None:
self._cache_put(key, code, reconstructed)
return self._clone_code(code), list(reconstructed)
return code, reconstructed
def quantize_prod(self, vector: list[float], bits: int) -> dict:
code, _ = self.quantize_and_dequantize_prod(vector, bits)
return code
def dequantize_prod(self, code: dict) -> list[float]:
mse_hat = self.dequantize_mse(code["mse_code"])
residual_hat = self._qjl_inverse(code["qjl_signs"])
gamma = float(code["gamma"])
return [m + gamma * r for m, r in zip(mse_hat, residual_hat, strict=False)]
@staticmethod
def compute_distortion(vector: list[float], reconstructed: list[float], query: list[float]) -> dict[str, float]:
dim = max(len(vector), 1)
mse = sum((x - y) ** 2 for x, y in zip(vector, reconstructed, strict=False)) / dim
ip_true = sum(a * b for a, b in zip(query, vector, strict=False))
ip_hat = sum(a * b for a, b in zip(query, reconstructed, strict=False))
inner_error = (ip_true - ip_hat) ** 2 / dim
return {"mse": mse, "inner_product_error": inner_error}
# Alias: ``TurboQuantizer`` is the same class object as ``RotatedQuantizedMemory``,
# so either name can be imported and instantiated.
TurboQuantizer = RotatedQuantizedMemory