Spaces:
Running
Running
File size: 7,655 Bytes
ba86059 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 | import torch
import numpy as np
from typing import Tuple, List, Optional
from dataclasses import dataclass
# =============================================================================
# Scalar Quantizer - Stage 1 (MSE)
# =============================================================================
@dataclass
class SQQuantized:
indices: torch.Tensor # (N, D) uint8/int - indices to centroids
centroids: torch.Tensor # (K,) - centroids values
class ScalarQuantizer:
"""
Scalar Quantizer (STAGE 1): Nén từng chiều độc lập dựa trên codebook.
Tối ưu hóa bằng thuật toán Max-Lloyd cho phân phối Gaussian.
"""
def __init__(self, dim: int, bits: int = 4, device: str = "cpu", dtype=torch.float32, seed: int = 42):
self.dim = dim
self.bits = bits
self.n_clusters = 2**bits
self.device = device
self.dtype = dtype
# Mặc định: Centroids tối ưu cho phân phối Gaussian N(0, 1/d)
# sau khi nhân với ma trận xoay Π (QJL).
# Các giá trị được tính sẵn từ thuật toán Max-Lloyd (Continuous K-Means).
scale = 1.0 / np.sqrt(dim)
if bits == 1:
# 1-bit MSE (2 states): ±sqrt(2/pi) * (1/sqrt(d)) ≈ ±0.798/sqrt(d)
val = 0.79788456 * scale
centroids = torch.tensor([-val, val], device=self.device, dtype=self.dtype)
elif bits == 3:
# 3-bit MSE (8 states): ±0.245, ±0.756, ±1.344, ±2.152
vals = np.array([-2.152, -1.344, -0.756, -0.245, 0.245, 0.756, 1.344, 2.152]) * scale
centroids = torch.from_numpy(vals).to(device=self.device, dtype=self.dtype)
else:
raise ValueError(f"ScalarQuantizer only supports 1-bit and 3-bit MSE configurations. Received: {bits}")
self.centroids = centroids
self.boundaries = self._get_boundaries(centroids)
def _get_boundaries(self, centroids: torch.Tensor) -> torch.Tensor:
"""Tính các mốc ranh giới (Decision Boundaries) giữa các centroids."""
boundaries = torch.zeros(len(centroids) + 1, device=self.device)
boundaries[0], boundaries[-1] = -1e10, 1e10
if len(centroids) > 1:
# Ranh giới tối ưu Voronoi là trung điểm của 2 centroids kế tiếp
boundaries[1:-1] = (centroids[:-1] + centroids[1:]) / 2
return boundaries
def fit(self, x: torch.Tensor, iterations: int = 50):
"""
Học centroids tối ưu bằng thuật toán Max-Lloyd (1D K-Means).
LƯU Ý: Với dữ liệu TurboQuant đã qua phép xoay Π, dữ liệu đã hội tụ về
phân phối chuẩn N(0, 1/d). Các centroids khởi tạo trong __init__ đã là
tối ưu lý thuyết, việc chạy fit() có thể không cần thiết hoặc chỉ điều chỉnh nhẹ.
"""
data_flat = x.flatten()
if data_flat.device.type != self.device:
data_flat = data_flat.to(self.device)
# Sampling 1M points để tối ưu tốc độ huấn luyện
if len(data_flat) > 1_000_000:
indices = torch.randperm(len(data_flat), device=self.device)[:1_000_000]
subset = data_flat[indices]
else:
subset = data_flat
# LƯỢNG TỬ HÓA CẢI TIẾN: Percentile Clipping (Cách 3)
# Cắt bỏ 1% giá trị dị biệt ở hai đầu để chống nhiễu (outliers), giúp độ phân giải 4-bit tốt hơn
q_low = torch.quantile(subset, 0.01)
q_high = torch.quantile(subset, 0.99)
subset = torch.clamp(subset, q_low, q_high)
# 1. Khởi tạo Centroids bằng Quantiles (Gần với tối ưu Gaussian ngay từ đầu)
p = torch.linspace(0, 1, self.n_clusters + 1, device=self.device)
p_mid = (p[:-1] + p[1:]) / 2
centroids = torch.quantile(subset, p_mid).sort()[0]
# 2. Vòng lặp Max-Lloyd
for _ in range(iterations):
boundaries = self._get_boundaries(centroids)
# Gán điểm vào các buckets (Searchsorted tìm ranh giới nhanh trong 1D)
bucket_indices = torch.searchsorted(boundaries, subset) - 1
bucket_indices = bucket_indices.clamp(0, self.n_clusters - 1)
# Tính Centroids mới = Mean của từng bucket
new_centroids = torch.zeros_like(centroids)
counts = torch.zeros(self.n_clusters, device=self.device)
sums = torch.zeros(self.n_clusters, device=self.device)
sums.scatter_add_(0, bucket_indices, subset)
counts.scatter_add_(0, bucket_indices, torch.ones_like(subset))
mask = counts > 0
new_centroids[mask] = sums[mask] / counts[mask]
# Xử lý các bucket trống (nếu có) bằng cách nội suy từ centroids cũ
new_centroids[~mask] = centroids[~mask]
# Kiểm tra hội tụ
if torch.allclose(centroids, new_centroids, atol=1e-6):
break
centroids = new_centroids.sort()[0]
self.centroids = centroids
self.boundaries = self._get_boundaries(centroids)
def quantize(self, x: torch.Tensor) -> SQQuantized:
"""Lượng tử hóa vector x thành các chỉ số (indices)."""
# Đảm bảo contiguous cho performance
# Sửa lỗi: Cần trừ 1 vì searchsorted trả về vị trí chèn, index của bucket là searchsorted - 1
indices = (torch.searchsorted(self.boundaries, x.contiguous()) - 1).clamp(0, self.n_clusters - 1)
indices_np = indices.cpu().numpy().astype(np.uint8)
# Bit-Packing Logic (2 values per byte for 4-bit)
n, d = indices_np.shape
bits = self.bits
vals_per_byte = 1
if bits == 1: vals_per_byte = 8
elif bits == 2: vals_per_byte = 4
elif bits in [3, 4]: vals_per_byte = 2
if vals_per_byte > 1:
packed_d = (d + vals_per_byte - 1) // vals_per_byte
packed_codes = np.zeros((n, packed_d), dtype=np.uint8)
for i in range(vals_per_byte):
# Pack values into bytes
subset = indices_np[:, i::vals_per_byte]
curr_d = subset.shape[1]
packed_codes[:, :curr_d] |= (subset << (i * bits))
return SQQuantized(indices=torch.from_numpy(packed_codes), centroids=self.centroids)
return SQQuantized(indices=indices.to(torch.uint8), centroids=self.centroids)
def reconstruct(self, codes: torch.Tensor) -> torch.Tensor:
"""Giải nén vector (De-quantize) từ mã đã nén."""
codes_np = codes.cpu().numpy()
n, packed_d = codes_np.shape
bits = self.bits
vals_per_byte = 1
if bits == 1: vals_per_byte = 8
elif bits == 2: vals_per_byte = 4
elif bits in [3, 4]: vals_per_byte = 2
bit_mask = (1 << bits) - 1
if vals_per_byte > 1:
indices = np.zeros((n, self.dim), dtype=np.uint8)
for i in range(vals_per_byte):
# Unpack
subset = (codes_np >> (i * bits)) & bit_mask
# i::vals_per_byte unpacking
indices[:, i::vals_per_byte] = subset[:, :((self.dim - i + vals_per_byte - 1) // vals_per_byte)]
indices_t = torch.from_numpy(indices).to(self.device).long()
return self.centroids[indices_t]
|