"""tilelli.core.hadamard — orthogonal-rotation utilities for ternary quantization.

Quantization-error reduction trick from QuaRot / SpinQuant (2024). Multiplying
a weight matrix by an orthogonal matrix H spreads the energy of any single
position across all positions, flattening outliers and producing a more
Gaussian-like distribution that ternarizes with less rounding error.

The Sylvester construction works only for n = 2^k. For other sizes we fall
back to a fixed-seed random orthogonal matrix (the Q factor of a QR
decomposition of a Gaussian random matrix), which is treated as equivalent
in practice for quantization purposes.
"""
from __future__ import annotations
import functools
import torch
from torch import Tensor
def _is_power_of_two(n: int) -> bool:
return n > 0 and (n & (n - 1)) == 0
def _sylvester_hadamard(n: int) -> Tensor:
if not _is_power_of_two(n):
raise ValueError(f"Sylvester Hadamard requires power-of-2 size, got {n}")
h = torch.tensor([[1.0]])
while h.size(0) < n:
top = torch.cat([h, h], dim=1)
bot = torch.cat([h, -h], dim=1)
h = torch.cat([top, bot], dim=0) / (2.0**0.5)
return h
def _random_orthogonal(n: int, seed: int = 1234) -> Tensor:
g = torch.Generator(device="cpu").manual_seed(seed)
a = torch.randn(n, n, generator=g, dtype=torch.float64)
q, _r = torch.linalg.qr(a)
return q.to(torch.float32)
@functools.lru_cache(maxsize=64)
def hadamard_matrix(n: int, seed: int = 1234) -> Tensor:
    """Return the ``n x n`` orthogonal rotation matrix used by this module.

    Power-of-two sizes get the normalised Sylvester-Hadamard matrix, and
    ``seed`` has no effect for them. Any other size gets the fixed-seed random
    orthogonal matrix from ``_random_orthogonal(n, seed=seed)``.

    Results are memoised with ``functools.lru_cache`` (at most 64 entries), so
    repeated calls with the same arguments return the *same* Tensor object.
    Callers must not modify the result in place; doing so would corrupt
    every later call that hits the cache.
    """
    if not _is_power_of_two(n):
        return _random_orthogonal(n, seed=seed)
    return _sylvester_hadamard(n)
def rotate_columns(w: Tensor, h: Tensor | None = None) -> Tensor:
    """Right-multiply *w* by an orthogonal rotation, returning ``w @ h``.

    The rotation mixes values along the last dimension of ``w``. When ``h`` is
    omitted, ``hadamard_matrix(w.size(-1))`` with its default seed is used,
    cast to ``w``'s dtype and device.

    NOTE(review): casting the rotation to an integer dtype would truncate its
    fractional entries; presumably ``w`` is always floating point. Confirm
    against callers.
    """
    width = w.size(-1)  # read eagerly, even when an explicit h is supplied
    rotation = (
        h
        if h is not None
        else hadamard_matrix(width).to(dtype=w.dtype, device=w.device)
    )
    return w @ rotation
def rotate_input(x: Tensor, n: int, h: Tensor | None = None) -> Tensor:
    """Right-multiply activations *x* by an orthogonal rotation, returning ``x @ h``.

    When ``h`` is omitted, ``hadamard_matrix(n)`` with its default seed is
    used, cast to ``x``'s dtype and device. ``n`` is ignored when ``h`` is
    given. It is never checked against ``x.size(-1)``; a mismatch surfaces
    as a matmul error.

    Because the rotation is orthogonal, ``(x @ H) @ (w @ H).T == x @ w.T`` up
    to floating-point rounding. Presumably this pairs with ``rotate_columns``
    so that rotated weights see equally rotated inputs; confirm against
    callers.
    """
    if h is not None:
        return x @ h
    rotation = hadamard_matrix(n).to(dtype=x.dtype, device=x.device)
    return x @ rotation
# Public API: the cached rotation-matrix factory and the two rotation helpers.
__all__ = [
    "hadamard_matrix",
    "rotate_columns",
    "rotate_input",
]
|