File size: 4,974 Bytes
18f4d80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
"""
Lloyd-Max optimal scalar quantizer for the Beta distribution
arising from random rotation of unit-norm vectors.

After rotating a d-dimensional unit vector by a random orthogonal matrix,
each coordinate follows: f(x) = Gamma(d/2) / (sqrt(pi) * Gamma((d-1)/2)) * (1-x^2)^((d-3)/2)
supported on [-1, 1].

For practical dimensions (d >= 64), this is well-approximated by N(0, 1/d).
We solve the Lloyd-Max conditions (continuous 1-D k-means) to find optimal centroids.
"""

import torch
import math
from scipy import integrate, special


def beta_pdf(x: float, d: int) -> float:
    """PDF of a single coordinate after random rotation of a d-dim unit vector.

    Evaluates
        f(x) = Gamma(d/2) / (sqrt(pi) * Gamma((d-1)/2)) * (1 - x^2)^((d-3)/2)
    for |x| < 1 and returns 0.0 otherwise.

    The normalising constant is computed in log space with ``math.lgamma``:
    evaluating ``math.gamma(d / 2)`` directly overflows a float (and raises
    ``OverflowError``) once d/2 exceeds roughly 171, i.e. for d >= 344.

    Args:
        x: point at which to evaluate the density.
        d: vector dimension.

    Returns:
        The density value at x.
    """
    if abs(x) >= 1.0:
        return 0.0
    # log of Gamma(d/2) / (sqrt(pi) * Gamma((d-1)/2)); the ratio itself is
    # only O(sqrt(d)), so exponentiating the difference is safe.
    log_coeff = (
        math.lgamma(d / 2)
        - math.lgamma((d - 1) / 2)
        - 0.5 * math.log(math.pi)
    )
    return math.exp(log_coeff) * (1 - x * x) ** ((d - 3) / 2)


def gaussian_approx_pdf(x: float, d: int) -> float:
    """Density of N(0, 1/d) at x -- the Gaussian stand-in for the exact
    coordinate distribution (described as accurate for d >= 64).

    Args:
        x: point at which to evaluate the density.
        d: vector dimension; the variance used is 1/d.

    Returns:
        The normal density value at x.
    """
    variance = 1.0 / d
    normalizer = 1.0 / math.sqrt(2 * math.pi * variance)
    exponent = -x * x / (2 * variance)
    return normalizer * math.exp(exponent)


def solve_lloyd_max(d: int, bits: int, use_exact: bool = False, max_iter: int = 200, tol: float = 1e-10):
    """
    Solve Lloyd-Max optimal quantizer for the coordinate distribution.

    Alternates between (1) placing cell boundaries at the midpoints of
    adjacent centroids and (2) moving each centroid to the conditional mean
    of the density over its cell, until the largest centroid move falls
    below ``tol`` or ``max_iter`` iterations have run.

    In exact mode the density is zero outside [-1, 1], so both the initial
    centroid range and the outer integration limits are clipped to that
    support. Without the clip, small dimensions start some centroids in
    zero-mass cells (where they stay forever) and the numerical integration
    has to cross the discontinuity at +/-1.

    Args:
        d: vector dimension
        bits: number of quantization bits
        use_exact: if True, use exact Beta PDF; if False, use Gaussian approx
        max_iter: maximum Lloyd-Max iterations
        tol: convergence tolerance

    Returns:
        centroids: sorted tensor of 2^bits optimal centroids
        boundaries: sorted tensor of 2^bits - 1 boundaries between centroids
    """
    n_levels = 2 ** bits
    pdf = (lambda x: beta_pdf(x, d)) if use_exact else (lambda x: gaussian_approx_pdf(x, d))
    sigma = 1.0 / math.sqrt(d)

    # Centroids start uniformly spread over [-3.5*sigma, 3.5*sigma]; the two
    # outermost cells extend to +/-10.5*sigma so the tails are integrated too.
    lo, hi = -3.5 * sigma, 3.5 * sigma
    outer_lo, outer_hi = lo * 3, hi * 3
    if use_exact:
        # Keep everything inside the support of the exact density.
        lo, hi = max(lo, -1.0), min(hi, 1.0)
        outer_lo, outer_hi = max(outer_lo, -1.0), min(outer_hi, 1.0)
    centroids = [lo + (hi - lo) * (i + 0.5) / n_levels for i in range(n_levels)]

    for _ in range(max_iter):
        # Step 1: boundaries are the midpoints between adjacent centroids.
        boundaries = [(left + right) / 2.0 for left, right in zip(centroids, centroids[1:])]

        # Step 2: each centroid becomes E[X | X in its cell].
        edges = [outer_lo] + boundaries + [outer_hi]
        new_centroids = []
        for current, a, b in zip(centroids, edges, edges[1:]):
            numerator = integrate.quad(lambda x: x * pdf(x), a, b)[0]
            denominator = integrate.quad(pdf, a, b)[0]

            if denominator > 1e-15:
                new_centroids.append(numerator / denominator)
            else:
                # Cell carries (numerically) no probability mass: the
                # conditional mean is undefined, so leave the centroid alone.
                new_centroids.append(current)

        # Convergence is judged on the largest single-centroid move.
        max_shift = max(abs(new - old) for new, old in zip(new_centroids, centroids))
        centroids = new_centroids

        if max_shift < tol:
            break

    # Recompute boundaries so they match the final centroids.
    boundaries = [(left + right) / 2.0 for left, right in zip(centroids, centroids[1:])]

    return (
        torch.tensor(centroids, dtype=torch.float32),
        torch.tensor(boundaries, dtype=torch.float32),
    )


def compute_expected_distortion(d: int, bits: int, centroids: torch.Tensor, boundaries: torch.Tensor, use_exact: bool = False) -> float:
    """Compute the expected MSE distortion per coordinate for the given quantizer.

    Sums, over every quantization cell, the integral of
    (x - centroid)^2 * pdf(x) between the cell's edges. The outermost cells
    extend to +/-10.5*sigma with sigma = 1/sqrt(d); in exact mode that limit
    is clipped to the density's support [-1, 1] so the numerical integration
    never has to cross the discontinuity at +/-1.

    Args:
        d: vector dimension
        bits: number of quantization bits (not read here; the number of
            levels is taken from ``len(centroids)``)
        centroids: 1-D tensor of centroid values, one per cell
        boundaries: 1-D tensor of the edges between consecutive cells
        use_exact: if True, use exact Beta PDF; if False, use Gaussian approx

    Returns:
        The accumulated distortion as a Python float.
    """
    pdf = (lambda x: beta_pdf(x, d)) if use_exact else (lambda x: gaussian_approx_pdf(x, d))
    sigma = 1.0 / math.sqrt(d)

    outer = 3.5 * sigma * 3
    if use_exact:
        # The exact density is zero beyond |x| = 1.
        outer = min(outer, 1.0)

    edges = [-outer] + boundaries.tolist() + [outer]
    levels = centroids.tolist()  # convert once instead of .item() per cell
    total_distortion = 0.0

    for i, c in enumerate(levels):
        a, b = edges[i], edges[i + 1]
        # c is bound as a default so the integrand does not depend on the
        # loop variable after this iteration.
        dist, _ = integrate.quad(lambda x, c=c: (x - c) ** 2 * pdf(x), a, b)
        total_distortion += dist

    return total_distortion


class LloydMaxCodebook:
    """Lloyd-Max codebook built once for a fixed dimension and bit-width.

    Attributes set by the constructor:
        d, bits: the arguments the codebook was built with
        n_levels: 2 ** bits
        centroids, boundaries: tensors returned by ``solve_lloyd_max``
        distortion: value returned by ``compute_expected_distortion``
    """

    def __init__(self, d: int, bits: int, use_exact: bool = False):
        self.d = d
        self.bits = bits
        self.n_levels = 2 ** bits
        solved = solve_lloyd_max(d, bits, use_exact)
        self.centroids, self.boundaries = solved
        self.distortion = compute_expected_distortion(
            d, bits, self.centroids, self.boundaries, use_exact
        )

    def quantize(self, x: torch.Tensor) -> torch.Tensor:
        """Return, for every element of x, the index of the closest centroid.

        The result has the same shape as x; indices are whatever integer
        dtype ``argmin`` produces (int64).
        """
        levels = self.centroids.to(x.device)
        # Broadcast x against the centroid vector: (..., 1) - (n_levels,)
        distances = torch.abs(x.unsqueeze(-1) - levels)
        return torch.argmin(distances, dim=-1)

    def dequantize(self, indices: torch.Tensor) -> torch.Tensor:
        """Look up the centroid value for each index."""
        table = self.centroids.to(indices.device)
        return table[indices]

    def __repr__(self):
        head = f"LloydMaxCodebook(d={self.d}, bits={self.bits}, "
        tail = f"levels={self.n_levels}, distortion_per_coord={self.distortion:.6f})"
        return head + tail