File size: 7,655 Bytes
ba86059
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import torch
import numpy as np
from typing import Tuple, List, Optional
from dataclasses import dataclass

# =============================================================================
# Scalar Quantizer - Stage 1 (MSE)
# =============================================================================

@dataclass
class SQQuantized:
    indices: torch.Tensor  # (N, D) uint8/int - indices to centroids
    centroids: torch.Tensor # (K,) - centroids values

class ScalarQuantizer:
    """
    Scalar Quantizer (STAGE 1): Nén từng chiều độc lập dựa trên codebook.
    Tối ưu hóa bằng thuật toán Max-Lloyd cho phân phối Gaussian.
    """
    def __init__(self, dim: int, bits: int = 4, device: str = "cpu", dtype=torch.float32, seed: int = 42):
        self.dim = dim
        self.bits = bits
        self.n_clusters = 2**bits
        self.device = device
        self.dtype = dtype
        
        # Mặc định: Centroids tối ưu cho phân phối Gaussian N(0, 1/d)
        # sau khi nhân với ma trận xoay Π (QJL).
        # Các giá trị được tính sẵn từ thuật toán Max-Lloyd (Continuous K-Means).
        scale = 1.0 / np.sqrt(dim)
        
        if bits == 1:
            # 1-bit MSE (2 states): ±sqrt(2/pi) * (1/sqrt(d)) ≈ ±0.798/sqrt(d)
            val = 0.79788456 * scale
            centroids = torch.tensor([-val, val], device=self.device, dtype=self.dtype)
        elif bits == 3:
            # 3-bit MSE (8 states): ±0.245, ±0.756, ±1.344, ±2.152
            vals = np.array([-2.152, -1.344, -0.756, -0.245, 0.245, 0.756, 1.344, 2.152]) * scale
            centroids = torch.from_numpy(vals).to(device=self.device, dtype=self.dtype)
        else:
            raise ValueError(f"ScalarQuantizer only supports 1-bit and 3-bit MSE configurations. Received: {bits}")
            
        self.centroids = centroids
        self.boundaries = self._get_boundaries(centroids)

    def _get_boundaries(self, centroids: torch.Tensor) -> torch.Tensor:
        """Tính các mốc ranh giới (Decision Boundaries) giữa các centroids."""
        boundaries = torch.zeros(len(centroids) + 1, device=self.device)
        boundaries[0], boundaries[-1] = -1e10, 1e10
        if len(centroids) > 1:
            # Ranh giới tối ưu Voronoi là trung điểm của 2 centroids kế tiếp
            boundaries[1:-1] = (centroids[:-1] + centroids[1:]) / 2
        return boundaries

    def fit(self, x: torch.Tensor, iterations: int = 50):
        """
        Học centroids tối ưu bằng thuật toán Max-Lloyd (1D K-Means).
        LƯU Ý: Với dữ liệu TurboQuant đã qua phép xoay Π, dữ liệu đã hội tụ về 
        phân phối chuẩn N(0, 1/d). Các centroids khởi tạo trong __init__ đã là 
        tối ưu lý thuyết, việc chạy fit() có thể không cần thiết hoặc chỉ điều chỉnh nhẹ.
        """
        data_flat = x.flatten()
        if data_flat.device.type != self.device:
            data_flat = data_flat.to(self.device)
        
        # Sampling 1M points để tối ưu tốc độ huấn luyện
        if len(data_flat) > 1_000_000:
            indices = torch.randperm(len(data_flat), device=self.device)[:1_000_000]
            subset = data_flat[indices]
        else:
            subset = data_flat

        # LƯỢNG TỬ HÓA CẢI TIẾN: Percentile Clipping (Cách 3)
        # Cắt bỏ 1% giá trị dị biệt ở hai đầu để chống nhiễu (outliers), giúp độ phân giải 4-bit tốt hơn
        q_low = torch.quantile(subset, 0.01)
        q_high = torch.quantile(subset, 0.99)
        subset = torch.clamp(subset, q_low, q_high)

        # 1. Khởi tạo Centroids bằng Quantiles (Gần với tối ưu Gaussian ngay từ đầu)
        p = torch.linspace(0, 1, self.n_clusters + 1, device=self.device)
        p_mid = (p[:-1] + p[1:]) / 2
        centroids = torch.quantile(subset, p_mid).sort()[0]
        
        # 2. Vòng lặp Max-Lloyd
        for _ in range(iterations):
            boundaries = self._get_boundaries(centroids)
            
            # Gán điểm vào các buckets (Searchsorted tìm ranh giới nhanh trong 1D)
            bucket_indices = torch.searchsorted(boundaries, subset) - 1
            bucket_indices = bucket_indices.clamp(0, self.n_clusters - 1)
            
            # Tính Centroids mới = Mean của từng bucket
            new_centroids = torch.zeros_like(centroids)
            counts = torch.zeros(self.n_clusters, device=self.device)
            sums = torch.zeros(self.n_clusters, device=self.device)
            
            sums.scatter_add_(0, bucket_indices, subset)
            counts.scatter_add_(0, bucket_indices, torch.ones_like(subset))
            
            mask = counts > 0
            new_centroids[mask] = sums[mask] / counts[mask]
            
            # Xử lý các bucket trống (nếu có) bằng cách nội suy từ centroids cũ
            new_centroids[~mask] = centroids[~mask]
            
            # Kiểm tra hội tụ
            if torch.allclose(centroids, new_centroids, atol=1e-6):
                break
            centroids = new_centroids.sort()[0]
            
        self.centroids = centroids
        self.boundaries = self._get_boundaries(centroids)

    def quantize(self, x: torch.Tensor) -> SQQuantized:
        """Lượng tử hóa vector x thành các chỉ số (indices)."""
        # Đảm bảo contiguous cho performance
        # Sửa lỗi: Cần trừ 1 vì searchsorted trả về vị trí chèn, index của bucket là searchsorted - 1
        indices = (torch.searchsorted(self.boundaries, x.contiguous()) - 1).clamp(0, self.n_clusters - 1)
        indices_np = indices.cpu().numpy().astype(np.uint8)
        
        # Bit-Packing Logic (2 values per byte for 4-bit)
        n, d = indices_np.shape
        bits = self.bits
        vals_per_byte = 1
        if bits == 1: vals_per_byte = 8
        elif bits == 2: vals_per_byte = 4
        elif bits in [3, 4]: vals_per_byte = 2
        
        if vals_per_byte > 1:
            packed_d = (d + vals_per_byte - 1) // vals_per_byte
            packed_codes = np.zeros((n, packed_d), dtype=np.uint8)
            for i in range(vals_per_byte):
                # Pack values into bytes
                subset = indices_np[:, i::vals_per_byte]
                curr_d = subset.shape[1]
                packed_codes[:, :curr_d] |= (subset << (i * bits))
            return SQQuantized(indices=torch.from_numpy(packed_codes), centroids=self.centroids)
        
        return SQQuantized(indices=indices.to(torch.uint8), centroids=self.centroids)

    def reconstruct(self, codes: torch.Tensor) -> torch.Tensor:
        """Giải nén vector (De-quantize) từ mã đã nén."""
        codes_np = codes.cpu().numpy()
        n, packed_d = codes_np.shape
        bits = self.bits
        vals_per_byte = 1
        if bits == 1: vals_per_byte = 8
        elif bits == 2: vals_per_byte = 4
        elif bits in [3, 4]: vals_per_byte = 2
        
        bit_mask = (1 << bits) - 1
        
        if vals_per_byte > 1:
            indices = np.zeros((n, self.dim), dtype=np.uint8)
            for i in range(vals_per_byte):
                # Unpack
                subset = (codes_np >> (i * bits)) & bit_mask
                # i::vals_per_byte unpacking
                indices[:, i::vals_per_byte] = subset[:, :((self.dim - i + vals_per_byte - 1) // vals_per_byte)]
            
            indices_t = torch.from_numpy(indices).to(self.device).long()
            return self.centroids[indices_t]