neshaki091
Deploy TurboQuant Backend (Cleaned history & optimized for HF Spaces)
ba86059
import torch
import numpy as np
from typing import Tuple, List, Optional
from dataclasses import dataclass
# =============================================================================
# Scalar Quantizer - Stage 1 (MSE)
# =============================================================================
@dataclass
class SQQuantized:
indices: torch.Tensor # (N, D) uint8/int - indices to centroids
centroids: torch.Tensor # (K,) - centroids values
class ScalarQuantizer:
"""
Scalar Quantizer (STAGE 1): Nén từng chiều độc lập dựa trên codebook.
Tối ưu hóa bằng thuật toán Max-Lloyd cho phân phối Gaussian.
"""
def __init__(self, dim: int, bits: int = 4, device: str = "cpu", dtype=torch.float32, seed: int = 42):
self.dim = dim
self.bits = bits
self.n_clusters = 2**bits
self.device = device
self.dtype = dtype
# Mặc định: Centroids tối ưu cho phân phối Gaussian N(0, 1/d)
# sau khi nhân với ma trận xoay Π (QJL).
# Các giá trị được tính sẵn từ thuật toán Max-Lloyd (Continuous K-Means).
scale = 1.0 / np.sqrt(dim)
if bits == 1:
# 1-bit MSE (2 states): ±sqrt(2/pi) * (1/sqrt(d)) ≈ ±0.798/sqrt(d)
val = 0.79788456 * scale
centroids = torch.tensor([-val, val], device=self.device, dtype=self.dtype)
elif bits == 3:
# 3-bit MSE (8 states): ±0.245, ±0.756, ±1.344, ±2.152
vals = np.array([-2.152, -1.344, -0.756, -0.245, 0.245, 0.756, 1.344, 2.152]) * scale
centroids = torch.from_numpy(vals).to(device=self.device, dtype=self.dtype)
else:
raise ValueError(f"ScalarQuantizer only supports 1-bit and 3-bit MSE configurations. Received: {bits}")
self.centroids = centroids
self.boundaries = self._get_boundaries(centroids)
def _get_boundaries(self, centroids: torch.Tensor) -> torch.Tensor:
"""Tính các mốc ranh giới (Decision Boundaries) giữa các centroids."""
boundaries = torch.zeros(len(centroids) + 1, device=self.device)
boundaries[0], boundaries[-1] = -1e10, 1e10
if len(centroids) > 1:
# Ranh giới tối ưu Voronoi là trung điểm của 2 centroids kế tiếp
boundaries[1:-1] = (centroids[:-1] + centroids[1:]) / 2
return boundaries
def fit(self, x: torch.Tensor, iterations: int = 50):
"""
Học centroids tối ưu bằng thuật toán Max-Lloyd (1D K-Means).
LƯU Ý: Với dữ liệu TurboQuant đã qua phép xoay Π, dữ liệu đã hội tụ về
phân phối chuẩn N(0, 1/d). Các centroids khởi tạo trong __init__ đã là
tối ưu lý thuyết, việc chạy fit() có thể không cần thiết hoặc chỉ điều chỉnh nhẹ.
"""
data_flat = x.flatten()
if data_flat.device.type != self.device:
data_flat = data_flat.to(self.device)
# Sampling 1M points để tối ưu tốc độ huấn luyện
if len(data_flat) > 1_000_000:
indices = torch.randperm(len(data_flat), device=self.device)[:1_000_000]
subset = data_flat[indices]
else:
subset = data_flat
# LƯỢNG TỬ HÓA CẢI TIẾN: Percentile Clipping (Cách 3)
# Cắt bỏ 1% giá trị dị biệt ở hai đầu để chống nhiễu (outliers), giúp độ phân giải 4-bit tốt hơn
q_low = torch.quantile(subset, 0.01)
q_high = torch.quantile(subset, 0.99)
subset = torch.clamp(subset, q_low, q_high)
# 1. Khởi tạo Centroids bằng Quantiles (Gần với tối ưu Gaussian ngay từ đầu)
p = torch.linspace(0, 1, self.n_clusters + 1, device=self.device)
p_mid = (p[:-1] + p[1:]) / 2
centroids = torch.quantile(subset, p_mid).sort()[0]
# 2. Vòng lặp Max-Lloyd
for _ in range(iterations):
boundaries = self._get_boundaries(centroids)
# Gán điểm vào các buckets (Searchsorted tìm ranh giới nhanh trong 1D)
bucket_indices = torch.searchsorted(boundaries, subset) - 1
bucket_indices = bucket_indices.clamp(0, self.n_clusters - 1)
# Tính Centroids mới = Mean của từng bucket
new_centroids = torch.zeros_like(centroids)
counts = torch.zeros(self.n_clusters, device=self.device)
sums = torch.zeros(self.n_clusters, device=self.device)
sums.scatter_add_(0, bucket_indices, subset)
counts.scatter_add_(0, bucket_indices, torch.ones_like(subset))
mask = counts > 0
new_centroids[mask] = sums[mask] / counts[mask]
# Xử lý các bucket trống (nếu có) bằng cách nội suy từ centroids cũ
new_centroids[~mask] = centroids[~mask]
# Kiểm tra hội tụ
if torch.allclose(centroids, new_centroids, atol=1e-6):
break
centroids = new_centroids.sort()[0]
self.centroids = centroids
self.boundaries = self._get_boundaries(centroids)
def quantize(self, x: torch.Tensor) -> SQQuantized:
"""Lượng tử hóa vector x thành các chỉ số (indices)."""
# Đảm bảo contiguous cho performance
# Sửa lỗi: Cần trừ 1 vì searchsorted trả về vị trí chèn, index của bucket là searchsorted - 1
indices = (torch.searchsorted(self.boundaries, x.contiguous()) - 1).clamp(0, self.n_clusters - 1)
indices_np = indices.cpu().numpy().astype(np.uint8)
# Bit-Packing Logic (2 values per byte for 4-bit)
n, d = indices_np.shape
bits = self.bits
vals_per_byte = 1
if bits == 1: vals_per_byte = 8
elif bits == 2: vals_per_byte = 4
elif bits in [3, 4]: vals_per_byte = 2
if vals_per_byte > 1:
packed_d = (d + vals_per_byte - 1) // vals_per_byte
packed_codes = np.zeros((n, packed_d), dtype=np.uint8)
for i in range(vals_per_byte):
# Pack values into bytes
subset = indices_np[:, i::vals_per_byte]
curr_d = subset.shape[1]
packed_codes[:, :curr_d] |= (subset << (i * bits))
return SQQuantized(indices=torch.from_numpy(packed_codes), centroids=self.centroids)
return SQQuantized(indices=indices.to(torch.uint8), centroids=self.centroids)
def reconstruct(self, codes: torch.Tensor) -> torch.Tensor:
"""Giải nén vector (De-quantize) từ mã đã nén."""
codes_np = codes.cpu().numpy()
n, packed_d = codes_np.shape
bits = self.bits
vals_per_byte = 1
if bits == 1: vals_per_byte = 8
elif bits == 2: vals_per_byte = 4
elif bits in [3, 4]: vals_per_byte = 2
bit_mask = (1 << bits) - 1
if vals_per_byte > 1:
indices = np.zeros((n, self.dim), dtype=np.uint8)
for i in range(vals_per_byte):
# Unpack
subset = (codes_np >> (i * bits)) & bit_mask
# i::vals_per_byte unpacking
indices[:, i::vals_per_byte] = subset[:, :((self.dim - i + vals_per_byte - 1) // vals_per_byte)]
indices_t = torch.from_numpy(indices).to(self.device).long()
return self.centroids[indices_t]