File size: 2,217 Bytes
4ef7879
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
"""Walsh-Hadamard transform helpers.

We use:
  - hadamard_matrix(n)            for arbitrary power-of-2 n
  - block_hadamard_inplace()      to apply n×n WHT to fixed-size blocks of
                                  a longer vector / row of a matrix
"""
from __future__ import annotations

import math

import numpy as np
import torch


def hadamard_matrix(n: int, dtype=torch.float32) -> torch.Tensor:
    """Build the normalized n×n Walsh-Hadamard matrix.

    The matrix is grown by repeated doubling from the 1×1 matrix ``[[1]]``
    (each step maps ``M`` to ``[[M, M], [M, -M]]``) and is finally divided
    by ``sqrt(n)``. The scaled result is symmetric and orthogonal, hence
    its own inverse (H H = I), which is what lets paired rotations such as
    H @ Wᵀ ··· W @ Hᵀ cancel out.

    Args:
        n: Side length of the matrix; must be a positive power of 2.
        dtype: dtype used for the ±1 seed tensor that gets doubled.

    Returns:
        A tensor of shape ``(n, n)``.

    Raises:
        ValueError: If ``n`` is not a positive power of 2.
    """
    # A power of two has exactly one bit set, so n & (n - 1) clears it to 0.
    if n <= 0 or n & (n - 1):
        raise ValueError(f"n must be a positive power of 2, got {n}")

    mat = torch.tensor([[1.0]], dtype=dtype)
    while mat.size(0) < n:
        upper = torch.cat((mat, mat), dim=1)
        lower = torch.cat((mat, -mat), dim=1)
        mat = torch.cat((upper, lower), dim=0)

    scale = math.sqrt(n)
    return mat / scale


def block_hadamard_inplace(W: torch.Tensor, axis: int = -1, block: int = 128) -> None:
    """Apply the normalized block×block Hadamard transform along `axis`, in place.

    The dimension `axis` of `W` is cut into contiguous slices of length
    `block`, and every slice is multiplied by the matrix returned by
    ``hadamard_matrix(block)``. This is meant for dimensions that are not
    themselves a power of 2 (e.g. ffn_dim=11008 = 86 × 128) but are a
    multiple of one.

    The whole update runs under ``torch.no_grad()``: it is a plain
    overwrite of `W`'s values, and without it the final in-place copy
    raises for leaf tensors that require grad (such as ``nn.Parameter``
    weights).

    Args:
        W: Tensor to transform; its contents are overwritten.
        axis: Dimension along which the blocks are taken.
        block: Slice length; must be a positive power of 2 that divides
            ``W.shape[axis]``.

    Returns:
        None. `W` is modified in place.

    Raises:
        ValueError: If `block` is rejected by ``hadamard_matrix`` or if
            ``W.shape[axis]`` is not divisible by `block`.
    """
    n = W.shape[axis]
    # Build H before the modulo below: hadamard_matrix validates `block`, so
    # block=0 becomes a ValueError instead of a ZeroDivisionError.
    H = hadamard_matrix(block, dtype=W.dtype).to(W.device)
    if n % block != 0:
        raise ValueError(f"axis dim {n} not divisible by block {block}")

    with torch.no_grad():
        # Bring `axis` to the end, then split it into (n // block, block).
        moved = W.transpose(axis, -1)
        shape = moved.shape
        blocks = moved.reshape(*shape[:-1], n // block, block)
        # Right-multiplying row vectors by H equals applying H to each block
        # because H is symmetric.
        rotated = (blocks @ H).reshape(shape)
        # Undo the transpose so the result lines up with W's layout.
        W.copy_(rotated.transpose(axis, -1))


def is_orthogonal(H: torch.Tensor, tol: float = 1e-5) -> bool:
    """Self-check: H @ Hᵀ ≈ I.

    Args:
        H: 2-D tensor to test.
        tol: Strict upper bound on the largest absolute entry of
            ``H @ Hᵀ - I`` for `H` to count as orthogonal.

    Returns:
        True if the maximum absolute deviation from the identity is below
        `tol`, otherwise False.
    """
    n = H.shape[0]
    # Create the identity on H's device; torch.eye defaults to CPU, which
    # would make the subtraction fail for a tensor living on another device.
    identity = torch.eye(n, dtype=H.dtype, device=H.device)
    err = (H @ H.t() - identity).abs().max().item()
    return err < tol


# NumPy-facing wrapper for tooling that wants a plain ndarray instead of a
# torch tensor. It delegates to hadamard_matrix, so torch is still used here.
def hadamard_matrix_np(n: int) -> np.ndarray:
    """Return the normalized n×n Walsh-Hadamard matrix as a float64 ndarray.

    `n` is validated by ``hadamard_matrix``, which this function calls with
    ``dtype=torch.float64`` before converting the result with ``.numpy()``.
    """
    as_tensor = hadamard_matrix(n, dtype=torch.float64)
    return as_tensor.numpy()