| from __future__ import annotations |
|
|
| import math |
| from typing import Dict |
|
|
| import numpy as np |
| import torch |
| from scipy.fft import dct as scipy_dct |
| from scipy.linalg import hadamard |
|
|
|
|
def normalize_columns(C: torch.Tensor) -> torch.Tensor:
    """Scale every column of ``C`` to unit Euclidean (L2) length.

    Column norms are taken along dimension 0 and floored at the machine
    epsilon of ``C``'s dtype. An all-zero column is therefore divided by a
    tiny positive number and stays zero, rather than producing NaNs.

    Args:
        C: Floating-point tensor whose columns (dimension 0 slices) are
            normalized.

    Returns:
        A new tensor with the same shape and dtype as ``C``.
    """
    floor = torch.finfo(C.dtype).eps
    column_lengths = torch.linalg.norm(C, dim=0, keepdim=True)
    safe_lengths = column_lengths.clamp_min(floor)
    return C / safe_lengths
|
|
|
|
def hadamard_codes(L: int, K: int) -> torch.Tensor:
    """Build ``K`` unit-norm codes of length ``L`` from a Hadamard matrix.

    A Hadamard matrix of the smallest power-of-two order ``>= max(L, K)`` is
    generated. Its top-left ``L x K`` block is kept, cast to float32, and
    column-normalized. When ``L`` is smaller than that order, the truncated
    columns are not guaranteed to be mutually orthogonal.

    Args:
        L: Code length (number of rows).
        K: Number of codes (number of columns).

    Returns:
        float32 tensor of shape ``(L, K)`` with unit-norm columns, or an
        uninitialised empty ``(L, K)`` tensor when ``L`` or ``K`` is
        non-positive.
    """
    if min(L, K) <= 0:
        return torch.empty(L, K)
    # scipy.linalg.hadamard only accepts power-of-two orders, so round
    # max(L, K) up to the next power of two.
    order = 2 ** (max(L, K) - 1).bit_length()
    block = hadamard(order)[:L, :K]
    codes = torch.from_numpy(block).float()
    return normalize_columns(codes)
|
|
|
|
def dct_codes(L: int, K: int) -> torch.Tensor:
    """Return the first ``K`` orthonormal DCT-II basis vectors of length ``L``.

    Column ``k`` of the result is the k-th cosine basis vector
    ``c_k * cos(pi * k * (2n + 1) / (2L))`` over sample index ``n``, ordered
    from lowest to highest frequency. Column 0 is therefore the constant
    (DC) vector.

    Args:
        L: Code length (number of rows).
        K: Number of codes (number of columns). Must not exceed ``L``.

    Returns:
        float32 tensor of shape ``(L, K)`` with unit-norm columns, or an
        uninitialised empty ``(L, K)`` tensor when ``L`` or ``K`` is
        non-positive.

    Raises:
        ValueError: If ``K > L``. A length-``L`` DCT has only ``L`` basis
            vectors, and silently returning fewer than ``K`` columns would
            break the ``(L, K)`` shape contract.
    """
    if L <= 0 or K <= 0:
        return torch.empty(L, K)
    if K > L:
        raise ValueError(
            f"dct_codes: K={K} exceeds L={L}; a length-L DCT has only L basis vectors"
        )
    # Applying the orthonormal DCT-II along axis 0 of the identity yields the
    # DCT matrix D with D[k, n] = c_k * cos(pi * k * (2n + 1) / (2L)). Its
    # *rows* are the basis vectors, so transpose to make each column one
    # basis vector. Slicing D's columns would instead give DCT-III vectors.
    D = scipy_dct(np.eye(L), type=2, axis=0, norm="ortho")
    basis = D.T
    C = torch.from_numpy(basis[:, :K]).to(dtype=torch.float32)
    return normalize_columns(C)
|
|
|
|
def gaussian_codes(L: int, K: int, seed: int = 0) -> torch.Tensor:
    """Draw ``K`` i.i.d. standard-Gaussian codes of length ``L`` with unit norm.

    Sampling uses a dedicated ``torch.Generator`` seeded with ``seed``. The
    output is reproducible for a given seed and does not advance the global
    RNG.

    Args:
        L: Code length (number of rows).
        K: Number of codes (number of columns).
        seed: Seed for the private generator.

    Returns:
        Tensor of shape ``(L, K)`` with unit-norm columns, or an uninitialised
        empty ``(L, K)`` tensor when ``L`` or ``K`` is non-positive.
    """
    if min(L, K) <= 0:
        return torch.empty(L, K)
    rng = torch.Generator()
    rng.manual_seed(seed)
    samples = torch.randn(L, K, generator=rng)
    # The 1/sqrt(L) scaling is made redundant by the normalization that
    # follows. It is kept so results match previous output bit-for-bit.
    scaled = samples / math.sqrt(L)
    return normalize_columns(scaled)
|
|
|
|
def gram_matrix(C: torch.Tensor) -> torch.Tensor:
    """Compute the matrix of pairwise inner products of the columns of ``C``.

    For a 2-D ``C`` of shape ``(L, K)`` the result has shape ``(K, K)``.
    Entry ``(i, j)`` is ``<C[:, i], C[:, j]>``, i.e. the result is ``C^T C``.
    """
    return torch.matmul(C.T, C)
|
|
|
|
def coherence_stats(C: torch.Tensor) -> Dict[str, float]:
    """Return mutual-coherence statistics for the columns of ``C``.

    Columns are first normalized to unit L2 norm, so the statistics summarize
    the absolute cosine similarities between distinct columns.

    Args:
        C: 2-D tensor whose columns are the codes.

    Returns:
        Dict with two entries:
        ``"max_abs_offdiag"``: the largest ``|<c_i, c_j>|`` over ``i != j``.
        ``"mean_abs_offdiag"``: the mean of ``|<c_i, c_j>|`` over all ordered
        pairs with ``i != j``.
        When ``C`` has fewer than two columns there are no distinct pairs, and
        both values are ``0.0``.
    """
    K = C.shape[1]
    if K < 2:
        # No off-diagonal entries exist: .max() on an empty tensor raises and
        # .mean() would return NaN, so report zero interference instead.
        return {"max_abs_offdiag": 0.0, "mean_abs_offdiag": 0.0}
    Cn = normalize_columns(C)
    G = gram_matrix(Cn)
    # Boolean mask selecting every entry except the diagonal (self inner products).
    mask = ~torch.eye(K, dtype=torch.bool, device=G.device)
    off_diag = G.abs()[mask]
    return {
        "max_abs_offdiag": off_diag.max().item(),
        "mean_abs_offdiag": off_diag.mean().item(),
    }
|
|
|
|