File size: 1,995 Bytes
f86dc09
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
"""tilelli.core.hadamard — orthogonal-rotation utilities for ternary quantization.

Quantization-error reduction trick from QuaRot / SpinQuant (2024). Multiplying
a weight matrix by an orthogonal matrix H spreads the energy of any single
position across all positions, flattening outliers and producing a more
Gaussian-like distribution that ternarizes with less rounding error.

The Sylvester construction only exists for n = 2^k. For any other size we
fall back to a fixed-seed random orthogonal matrix (the Q factor of a QR
decomposition of a Gaussian matrix), treated as equivalent in practice for
quantization purposes.
"""
from __future__ import annotations

import functools

import torch
from torch import Tensor


def _is_power_of_two(n: int) -> bool:
    """Return True when *n* is a positive power of two (1, 2, 4, 8, ...)."""
    if n <= 0:
        return False
    # A power of two has exactly one set bit, so clearing the lowest set bit
    # (n & (n - 1)) must leave zero.
    return n & (n - 1) == 0


def _sylvester_hadamard(n: int) -> Tensor:
    """Build the normalized n x n Hadamard matrix via Sylvester doubling.

    Starting from the 1 x 1 matrix ``[[1.0]]``, each step replaces ``H`` with
    ``[[H, H], [H, -H]] / sqrt(2)`` until the requested size is reached, so
    every entry ends up at +-1/sqrt(n) (up to floating-point rounding).

    Args:
        n: Side length of the matrix; must be a power of two.

    Returns:
        The ``n x n`` matrix, created with torch's default dtype and device.

    Raises:
        ValueError: If ``n`` is not a power of two.
    """
    if not _is_power_of_two(n):
        raise ValueError(f"Sylvester Hadamard requires power-of-2 size, got {n}")

    sqrt_two = 2.0**0.5
    matrix = torch.tensor([[1.0]])
    # Each pass doubles the side length; rescaling per pass keeps the rows
    # unit-norm at every intermediate size.
    while matrix.shape[0] < n:
        upper = torch.cat((matrix, matrix), dim=1)
        lower = torch.cat((matrix, -matrix), dim=1)
        matrix = torch.cat((upper, lower), dim=0) / sqrt_two
    return matrix


def _random_orthogonal(n: int, seed: int = 1234) -> Tensor:
    """Return a reproducible n x n orthogonal matrix derived from *seed*.

    A Gaussian matrix is sampled in float64 from a dedicated CPU generator
    (so the global RNG state is left untouched), and the Q factor of its QR
    decomposition is returned as float32.

    Args:
        n: Side length of the matrix.
        seed: Seed for the private generator; the same seed yields the same
            Gaussian sample.

    Returns:
        The float32 Q factor of the sampled matrix.
    """
    # NOTE(review): no sign normalization is applied to Q, so its column signs
    # are whatever the QR backend produces — confirm this is acceptable if
    # bit-for-bit reproducibility across platforms matters.
    rng = torch.Generator(device="cpu")
    rng.manual_seed(seed)
    gaussian = torch.randn((n, n), generator=rng, dtype=torch.float64)
    orthogonal = torch.linalg.qr(gaussian).Q
    return orthogonal.to(dtype=torch.float32)


@functools.lru_cache(maxsize=64)
def hadamard_matrix(n: int, seed: int = 1234) -> Tensor:
    """Return an n x n orthogonal rotation matrix, memoized per ``(n, seed)``.

    Power-of-two sizes get the deterministic Sylvester Hadamard matrix; every
    other size gets a seeded random orthogonal matrix.

    Because of the ``lru_cache``, repeated calls with the same arguments
    return the *same* tensor object. Treat the result as read-only: an
    in-place modification would be seen by every later caller.

    Args:
        n: Side length of the matrix.
        seed: Seed for the random fallback. It does not influence the
            power-of-two result, but it is still part of the cache key.

    Returns:
        The cached ``n x n`` rotation matrix.
    """
    if not _is_power_of_two(n):
        return _random_orthogonal(n, seed=seed)
    return _sylvester_hadamard(n)


def rotate_columns(w: Tensor, h: Tensor | None = None) -> Tensor:
    """Right-multiply *w* by a rotation matrix (``w @ h``).

    Args:
        w: Tensor whose last dimension is rotated.
        h: Rotation to apply. It is used exactly as given (no dtype or device
            conversion). When omitted, the default ``hadamard_matrix`` for
            ``w``'s last-dimension size is used, converted to ``w``'s dtype
            and device.

    Returns:
        The product ``w @ h``.
    """
    width = w.size(-1)
    if h is not None:
        return w @ h
    rotation = hadamard_matrix(width)
    return w @ rotation.to(dtype=w.dtype, device=w.device)


def rotate_input(x: Tensor, n: int, h: Tensor | None = None) -> Tensor:
    """Right-multiply the activations *x* by a rotation matrix (``x @ h``).

    Args:
        x: Tensor to rotate along its last dimension.
        n: Size of the default rotation. Only consulted when ``h`` is omitted;
            this function does not check it against ``x``'s shape.
        h: Rotation to apply. It is used exactly as given (no dtype or device
            conversion). When omitted, ``hadamard_matrix(n)`` is used,
            converted to ``x``'s dtype and device.

    Returns:
        The product ``x @ h``.
    """
    if h is not None:
        return x @ h
    rotation = hadamard_matrix(n)
    return x @ rotation.to(dtype=x.dtype, device=x.device)


# Public API of this module; the underscore-prefixed matrix builders above
# are internal helpers and are intentionally not exported.
__all__ = ["hadamard_matrix", "rotate_columns", "rotate_input"]