| from __future__ import annotations | |
| import math | |
| from typing import Dict | |
| import numpy as np | |
| import torch | |
| from scipy.fft import dct as scipy_dct | |
| from scipy.linalg import hadamard | |
def normalize_columns(C: torch.Tensor) -> torch.Tensor:
    """Scale every column of ``C`` to unit Euclidean (L2) length.

    Column norms are floored at the machine epsilon of ``C.dtype`` before the
    division, so an all-zero column stays zero instead of turning into NaNs.
    """
    floor = torch.finfo(C.dtype).eps
    col_norms = torch.linalg.norm(C, dim=0, keepdim=True)
    return C / torch.clamp(col_norms, min=floor)
def hadamard_codes(L: int, K: int) -> torch.Tensor:
    """Return first ``K`` columns of a Hadamard matrix with ``L`` rows.

    A Hadamard matrix of the smallest power-of-two order ``n`` satisfying
    ``n >= max(L, K)`` is built, cropped to its top-left ``L x K`` corner,
    and every column is rescaled to unit L2 norm. When ``L < n`` the cropped
    columns may no longer be mutually orthogonal.

    Args:
        L: Code length (number of rows).
        K: Number of codes (number of columns).

    Returns:
        A float32 tensor of shape ``(L, K)``. If either size is non-positive
        an empty tensor is returned, with negative sizes treated as zero.
    """
    if L <= 0 or K <= 0:
        # torch.empty rejects negative sizes, so clamp them to zero rather
        # than crashing inside the guard meant to handle degenerate input.
        return torch.empty(max(L, 0), max(K, 0), dtype=torch.float32)
    # Smallest power of two >= max(L, K); hadamard() only accepts such orders.
    n = 1 << (max(L, K) - 1).bit_length()
    H = hadamard(n)
    C = torch.from_numpy(H[:L, :K]).to(dtype=torch.float32)
    return normalize_columns(C)
def dct_codes(L: int, K: int) -> torch.Tensor:
    """Return first ``K`` DCT basis vectors of length ``L``.

    The codes are the ``K`` lowest-frequency basis vectors of the orthonormal
    DCT-II, stored as columns; column 0 is the constant (DC) vector.

    Args:
        L: Code length (number of rows).
        K: Number of codes requested. Only ``L`` basis vectors exist, so at
            most ``L`` columns are returned when ``K > L``.

    Returns:
        A float32 tensor of shape ``(L, min(K, L))`` with unit-norm columns.
        If either size is non-positive an empty tensor is returned, with
        negative sizes treated as zero.
    """
    if L <= 0 or K <= 0:
        # torch.empty rejects negative sizes, so clamp them to zero.
        return torch.empty(max(L, 0), max(K, 0), dtype=torch.float32)
    # Transforming the identity along axis 0 yields the DCT-II matrix D with
    # D[k, n] = coefficient k of the unit impulse at n. The k-th cosine basis
    # vector is therefore ROW k of D, so take the first K rows and transpose
    # them into columns (slicing columns would give impulse transforms).
    transform = scipy_dct(np.eye(L), type=2, axis=0, norm="ortho")
    basis = np.ascontiguousarray(transform[:K, :].T)
    C = torch.from_numpy(basis).to(dtype=torch.float32)
    return normalize_columns(C)
def gaussian_codes(L: int, K: int, seed: int = 0) -> torch.Tensor:
    """Return ``K`` Gaussian random codes of length ``L`` with unit norm.

    Args:
        L: Code length (number of rows).
        K: Number of codes (number of columns).
        seed: Seed for a private ``torch.Generator``, so the result is
            reproducible and the global RNG state is left untouched.

    Returns:
        A tensor of shape ``(L, K)`` whose columns have unit L2 norm. If
        either size is non-positive an empty tensor is returned, with
        negative sizes treated as zero.
    """
    if L <= 0 or K <= 0:
        # torch.empty rejects negative sizes, so clamp them to zero.
        return torch.empty(max(L, 0), max(K, 0))
    gen = torch.Generator().manual_seed(seed)
    # The 1/sqrt(L) factor is cancelled by the normalization below; it is
    # kept so that seeded outputs stay bit-for-bit identical.
    C = torch.randn(L, K, generator=gen) / math.sqrt(L)
    return normalize_columns(C)
def gram_matrix(C: torch.Tensor) -> torch.Tensor:
    """Compute the Gram matrix ``C^T C`` of the columns of ``C``.

    Entry ``(i, j)`` of the result is the inner product of columns ``i`` and
    ``j``. The transpose is plain (no complex conjugation).
    """
    transposed = C.T
    return torch.matmul(transposed, C)
def coherence_stats(C: torch.Tensor) -> Dict[str, float]:
    """Return coherence statistics for column-normalized codes.

    The columns of ``C`` are normalized to unit L2 norm, their Gram matrix is
    formed, and the absolute off-diagonal entries (pairwise correlations
    between distinct columns) are summarized.

    Args:
        C: 2-D tensor whose columns are the codes.

    Returns:
        A dict with ``"max_abs_offdiag"`` and ``"mean_abs_offdiag"`` as
        Python floats. With fewer than two columns there are no column pairs,
        so both statistics are reported as ``0.0``.
    """
    if C.shape[1] < 2:
        # No off-diagonal entries exist; reducing an empty tensor with
        # max() would raise, so report zero coherence instead.
        return {"max_abs_offdiag": 0.0, "mean_abs_offdiag": 0.0}
    Cn = normalize_columns(C)
    G = gram_matrix(Cn)
    K = G.shape[0]
    # Boolean mask selecting everything except the (unit) diagonal.
    mask = ~torch.eye(K, dtype=torch.bool, device=G.device)
    off_diag = G.abs()[mask]
    return {
        "max_abs_offdiag": off_diag.max().item(),
        "mean_abs_offdiag": off_diag.mean().item(),
    }