Spaces:
Sleeping
Sleeping
| """Walsh-Hadamard transform helpers. | |
| We use: | |
| - hadamard_matrix(n) for arbitrary power-of-2 n | |
| - block_hadamard_inplace() to apply n×n WHT to fixed-size blocks of | |
| a longer vector / row of a matrix | |
| """ | |
| from __future__ import annotations | |
| import math | |
| import numpy as np | |
| import torch | |
def hadamard_matrix(n: int, dtype=torch.float32) -> torch.Tensor:
    """Return the n×n NORMALIZED Walsh-Hadamard matrix.

    H is its own inverse (H H = I) so quantization rotations cancel under
    H @ Wᵀ ··· W @ Hᵀ ≡ identity. n must be a power of 2.

    Args:
        n: matrix size; must be a positive power of 2.
        dtype: torch dtype of the returned matrix.

    Raises:
        ValueError: if n is not a positive power of 2.
    """
    if n <= 0 or (n & (n - 1)) != 0:
        raise ValueError(f"n must be a positive power of 2, got {n}")
    # Sylvester construction: H_{2k} = [[H_k, H_k], [H_k, -H_k]],
    # doubling from the 1×1 seed until we reach size n.
    result = torch.ones((1, 1), dtype=dtype)
    while result.shape[0] < n:
        top = torch.cat((result, result), dim=1)
        bottom = torch.cat((result, -result), dim=1)
        result = torch.cat((top, bottom), dim=0)
    # Scale by 1/√n so the matrix is orthonormal (H Hᵀ = I).
    return result / math.sqrt(n)
def block_hadamard_inplace(W: torch.Tensor, axis: int = -1, block: int = 128) -> None:
    """Apply n×n Hadamard to every contiguous `block`-sized slice along `axis`.

    Used when the full dim isn't a power of 2 (e.g. ffn_dim=11008). Block
    size 128 fits comfortably in L1 and is a power of 2.

    Args:
        W: tensor rotated in place (overwritten via ``copy_``).
        axis: dimension along which blocks are taken.
        block: block size; must divide ``W.shape[axis]`` exactly.

    Raises:
        ValueError: if the axis dimension is not a multiple of ``block``.
    """
    dim = W.shape[axis]
    if dim % block != 0:
        raise ValueError(f"axis dim {dim} not divisible by block {block}")
    H = hadamard_matrix(block, dtype=W.dtype).to(W.device)
    # Bring the target axis to the end, split it into (groups, block), rotate
    # each block by right-multiplying with H (symmetric, so no transpose
    # needed), then restore the original layout and write back in place.
    last = W.transpose(axis, -1)
    groups = last.shape[-1] // block
    rotated = last.reshape(*last.shape[:-1], groups, block) @ H
    restored = rotated.reshape(last.shape).transpose(axis, -1)
    W.copy_(restored)
def is_orthogonal(H: torch.Tensor, tol: float = 1e-5) -> bool:
    """Self-check: H @ Hᵀ ≈ I.

    Args:
        H: square matrix to test.
        tol: maximum allowed absolute deviation from the identity.

    Returns:
        True if every entry of H @ Hᵀ - I has magnitude below ``tol``.
    """
    n = H.shape[0]
    # eye must live on H's device, otherwise the subtraction raises a
    # device-mismatch error for CUDA tensors (the original omitted device=).
    eye = torch.eye(n, dtype=H.dtype, device=H.device)
    err = (H @ H.t() - eye).abs().max().item()
    return err < tol
| # numpy convenience for tooling that doesn't want a torch dep | |
def hadamard_matrix_np(n: int) -> np.ndarray:
    """Return hadamard_matrix(n) as a float64 numpy array (torch-free callers)."""
    H = hadamard_matrix(n, dtype=torch.float64)
    return H.numpy()