Spaces:
Running
Running
| import torch | |
| import numpy as np | |
| from typing import Tuple, List, Optional | |
| from dataclasses import dataclass | |
| # ============================================================================= | |
| # Scalar Quantizer - Stage 1 (MSE) | |
| # ============================================================================= | |
| class SQQuantized: | |
| indices: torch.Tensor # (N, D) uint8/int - indices to centroids | |
| centroids: torch.Tensor # (K,) - centroids values | |
| class ScalarQuantizer: | |
| """ | |
| Scalar Quantizer (STAGE 1): Nén từng chiều độc lập dựa trên codebook. | |
| Tối ưu hóa bằng thuật toán Max-Lloyd cho phân phối Gaussian. | |
| """ | |
| def __init__(self, dim: int, bits: int = 4, device: str = "cpu", dtype=torch.float32, seed: int = 42): | |
| self.dim = dim | |
| self.bits = bits | |
| self.n_clusters = 2**bits | |
| self.device = device | |
| self.dtype = dtype | |
| # Mặc định: Centroids tối ưu cho phân phối Gaussian N(0, 1/d) | |
| # sau khi nhân với ma trận xoay Π (QJL). | |
| # Các giá trị được tính sẵn từ thuật toán Max-Lloyd (Continuous K-Means). | |
| scale = 1.0 / np.sqrt(dim) | |
| if bits == 1: | |
| # 1-bit MSE (2 states): ±sqrt(2/pi) * (1/sqrt(d)) ≈ ±0.798/sqrt(d) | |
| val = 0.79788456 * scale | |
| centroids = torch.tensor([-val, val], device=self.device, dtype=self.dtype) | |
| elif bits == 3: | |
| # 3-bit MSE (8 states): ±0.245, ±0.756, ±1.344, ±2.152 | |
| vals = np.array([-2.152, -1.344, -0.756, -0.245, 0.245, 0.756, 1.344, 2.152]) * scale | |
| centroids = torch.from_numpy(vals).to(device=self.device, dtype=self.dtype) | |
| else: | |
| raise ValueError(f"ScalarQuantizer only supports 1-bit and 3-bit MSE configurations. Received: {bits}") | |
| self.centroids = centroids | |
| self.boundaries = self._get_boundaries(centroids) | |
| def _get_boundaries(self, centroids: torch.Tensor) -> torch.Tensor: | |
| """Tính các mốc ranh giới (Decision Boundaries) giữa các centroids.""" | |
| boundaries = torch.zeros(len(centroids) + 1, device=self.device) | |
| boundaries[0], boundaries[-1] = -1e10, 1e10 | |
| if len(centroids) > 1: | |
| # Ranh giới tối ưu Voronoi là trung điểm của 2 centroids kế tiếp | |
| boundaries[1:-1] = (centroids[:-1] + centroids[1:]) / 2 | |
| return boundaries | |
| def fit(self, x: torch.Tensor, iterations: int = 50): | |
| """ | |
| Học centroids tối ưu bằng thuật toán Max-Lloyd (1D K-Means). | |
| LƯU Ý: Với dữ liệu TurboQuant đã qua phép xoay Π, dữ liệu đã hội tụ về | |
| phân phối chuẩn N(0, 1/d). Các centroids khởi tạo trong __init__ đã là | |
| tối ưu lý thuyết, việc chạy fit() có thể không cần thiết hoặc chỉ điều chỉnh nhẹ. | |
| """ | |
| data_flat = x.flatten() | |
| if data_flat.device.type != self.device: | |
| data_flat = data_flat.to(self.device) | |
| # Sampling 1M points để tối ưu tốc độ huấn luyện | |
| if len(data_flat) > 1_000_000: | |
| indices = torch.randperm(len(data_flat), device=self.device)[:1_000_000] | |
| subset = data_flat[indices] | |
| else: | |
| subset = data_flat | |
| # LƯỢNG TỬ HÓA CẢI TIẾN: Percentile Clipping (Cách 3) | |
| # Cắt bỏ 1% giá trị dị biệt ở hai đầu để chống nhiễu (outliers), giúp độ phân giải 4-bit tốt hơn | |
| q_low = torch.quantile(subset, 0.01) | |
| q_high = torch.quantile(subset, 0.99) | |
| subset = torch.clamp(subset, q_low, q_high) | |
| # 1. Khởi tạo Centroids bằng Quantiles (Gần với tối ưu Gaussian ngay từ đầu) | |
| p = torch.linspace(0, 1, self.n_clusters + 1, device=self.device) | |
| p_mid = (p[:-1] + p[1:]) / 2 | |
| centroids = torch.quantile(subset, p_mid).sort()[0] | |
| # 2. Vòng lặp Max-Lloyd | |
| for _ in range(iterations): | |
| boundaries = self._get_boundaries(centroids) | |
| # Gán điểm vào các buckets (Searchsorted tìm ranh giới nhanh trong 1D) | |
| bucket_indices = torch.searchsorted(boundaries, subset) - 1 | |
| bucket_indices = bucket_indices.clamp(0, self.n_clusters - 1) | |
| # Tính Centroids mới = Mean của từng bucket | |
| new_centroids = torch.zeros_like(centroids) | |
| counts = torch.zeros(self.n_clusters, device=self.device) | |
| sums = torch.zeros(self.n_clusters, device=self.device) | |
| sums.scatter_add_(0, bucket_indices, subset) | |
| counts.scatter_add_(0, bucket_indices, torch.ones_like(subset)) | |
| mask = counts > 0 | |
| new_centroids[mask] = sums[mask] / counts[mask] | |
| # Xử lý các bucket trống (nếu có) bằng cách nội suy từ centroids cũ | |
| new_centroids[~mask] = centroids[~mask] | |
| # Kiểm tra hội tụ | |
| if torch.allclose(centroids, new_centroids, atol=1e-6): | |
| break | |
| centroids = new_centroids.sort()[0] | |
| self.centroids = centroids | |
| self.boundaries = self._get_boundaries(centroids) | |
| def quantize(self, x: torch.Tensor) -> SQQuantized: | |
| """Lượng tử hóa vector x thành các chỉ số (indices).""" | |
| # Đảm bảo contiguous cho performance | |
| # Sửa lỗi: Cần trừ 1 vì searchsorted trả về vị trí chèn, index của bucket là searchsorted - 1 | |
| indices = (torch.searchsorted(self.boundaries, x.contiguous()) - 1).clamp(0, self.n_clusters - 1) | |
| indices_np = indices.cpu().numpy().astype(np.uint8) | |
| # Bit-Packing Logic (2 values per byte for 4-bit) | |
| n, d = indices_np.shape | |
| bits = self.bits | |
| vals_per_byte = 1 | |
| if bits == 1: vals_per_byte = 8 | |
| elif bits == 2: vals_per_byte = 4 | |
| elif bits in [3, 4]: vals_per_byte = 2 | |
| if vals_per_byte > 1: | |
| packed_d = (d + vals_per_byte - 1) // vals_per_byte | |
| packed_codes = np.zeros((n, packed_d), dtype=np.uint8) | |
| for i in range(vals_per_byte): | |
| # Pack values into bytes | |
| subset = indices_np[:, i::vals_per_byte] | |
| curr_d = subset.shape[1] | |
| packed_codes[:, :curr_d] |= (subset << (i * bits)) | |
| return SQQuantized(indices=torch.from_numpy(packed_codes), centroids=self.centroids) | |
| return SQQuantized(indices=indices.to(torch.uint8), centroids=self.centroids) | |
| def reconstruct(self, codes: torch.Tensor) -> torch.Tensor: | |
| """Giải nén vector (De-quantize) từ mã đã nén.""" | |
| codes_np = codes.cpu().numpy() | |
| n, packed_d = codes_np.shape | |
| bits = self.bits | |
| vals_per_byte = 1 | |
| if bits == 1: vals_per_byte = 8 | |
| elif bits == 2: vals_per_byte = 4 | |
| elif bits in [3, 4]: vals_per_byte = 2 | |
| bit_mask = (1 << bits) - 1 | |
| if vals_per_byte > 1: | |
| indices = np.zeros((n, self.dim), dtype=np.uint8) | |
| for i in range(vals_per_byte): | |
| # Unpack | |
| subset = (codes_np >> (i * bits)) & bit_mask | |
| # i::vals_per_byte unpacking | |
| indices[:, i::vals_per_byte] = subset[:, :((self.dim - i + vals_per_byte - 1) // vals_per_byte)] | |
| indices_t = torch.from_numpy(indices).to(self.device).long() | |
| return self.centroids[indices_t] | |