# Provenance: WCNegentropy — "Updated with scientifically rigorous documentation" (commit dc2b9f3, verified)
from __future__ import annotations
import math
from typing import Dict
import numpy as np
import torch
from scipy.fft import dct as scipy_dct
from scipy.linalg import hadamard
def normalize_columns(C: torch.Tensor) -> torch.Tensor:
    """Scale every column of ``C`` to unit Euclidean (L2) length.

    Columns whose norm is (near-)zero are divided by the dtype's machine
    epsilon instead, which avoids division by zero while leaving such
    columns effectively unchanged in direction.
    """
    eps = torch.finfo(C.dtype).eps
    column_norms = torch.linalg.norm(C, dim=0, keepdim=True).clamp_min(eps)
    return C / column_norms
def hadamard_codes(L: int, K: int) -> torch.Tensor:
    """Return the first ``K`` columns of a Hadamard matrix with ``L`` rows.

    The Hadamard matrix is built at the smallest power-of-two order that is
    at least ``max(L, K)``, then truncated to ``L x K`` and column-normalized.
    Note that truncated Hadamard columns are generally no longer mutually
    orthogonal; this is intentional for coherence analysis.

    Parameters
    ----------
    L : int
        Code length (number of rows).
    K : int
        Number of codes (columns).

    Returns
    -------
    torch.Tensor
        ``(L, K)`` float32 tensor with unit-norm columns; an empty tensor
        when either dimension is non-positive.
    """
    if L <= 0 or K <= 0:
        # Clamp to zero: torch.empty raises on negative sizes, so the guard
        # must not forward a negative dimension.
        return torch.empty(max(L, 0), max(K, 0))
    # Smallest power of two >= max(L, K); scipy's hadamard requires a
    # power-of-two order.
    n = 1 << (max(L, K) - 1).bit_length()
    H = hadamard(n)
    C = torch.from_numpy(H[:L, :K]).to(dtype=torch.float32)
    return normalize_columns(C)
def dct_codes(L: int, K: int) -> torch.Tensor:
    """Return the first ``K`` DCT-II basis vectors of length ``L``.

    Parameters
    ----------
    L : int
        Code length (number of rows).
    K : int
        Number of codes (columns); must not exceed ``L`` because only ``L``
        orthonormal DCT basis vectors of length ``L`` exist.

    Returns
    -------
    torch.Tensor
        ``(L, K)`` float32 tensor with unit-norm columns; an empty tensor
        when either dimension is non-positive.

    Raises
    ------
    ValueError
        If ``K > L`` (the previous behavior silently returned only ``L``
        columns, producing a wrong-shaped result).
    """
    if L <= 0 or K <= 0:
        # Clamp to zero: torch.empty raises on negative sizes.
        return torch.empty(max(L, 0), max(K, 0))
    if K > L:
        raise ValueError(f"K ({K}) cannot exceed L ({L}) for DCT codes")
    # Applying the orthonormal DCT-II to the identity yields the basis as
    # columns; slice off the first K of them.
    basis = scipy_dct(np.eye(L), type=2, axis=0, norm="ortho")
    C = torch.from_numpy(basis[:, :K]).to(dtype=torch.float32)
    return normalize_columns(C)
def gaussian_codes(L: int, K: int, seed: int = 0) -> torch.Tensor:
    """Return ``K`` i.i.d. Gaussian random codes of length ``L``, unit-normalized.

    Parameters
    ----------
    L : int
        Code length (number of rows).
    K : int
        Number of codes (columns).
    seed : int, optional
        Seed for the local torch generator, so results are reproducible and
        independent of global RNG state.

    Returns
    -------
    torch.Tensor
        ``(L, K)`` tensor with unit-norm columns; an empty tensor when
        either dimension is non-positive.
    """
    if L <= 0 or K <= 0:
        # Clamp to zero: torch.empty raises on negative sizes, so the guard
        # must not forward a negative dimension.
        return torch.empty(max(L, 0), max(K, 0))
    gen = torch.Generator().manual_seed(seed)
    # The 1/sqrt(L) scale is the conventional variance normalization; it is
    # mathematically cancelled by normalize_columns but kept for clarity.
    C = torch.randn(L, K, generator=gen) / math.sqrt(L)
    return normalize_columns(C)
def gram_matrix(C: torch.Tensor) -> torch.Tensor:
    """Compute the Gram matrix of the code columns, i.e. ``C^T C``."""
    return torch.matmul(C.transpose(0, 1), C)
def coherence_stats(C: torch.Tensor) -> Dict[str, float]:
    """Return coherence statistics of the column-normalized codes.

    The mutual coherence of a code matrix is the largest absolute inner
    product between distinct unit-norm columns.

    Parameters
    ----------
    C : torch.Tensor
        ``(L, K)`` code matrix; columns are normalized internally.

    Returns
    -------
    Dict[str, float]
        ``max_abs_offdiag`` and ``mean_abs_offdiag`` of the Gram matrix.
        Both are 0.0 when there are fewer than two columns (no off-diagonal
        entries exist).
    """
    Cn = normalize_columns(C)
    G = gram_matrix(Cn)
    K = G.shape[0]
    if K < 2:
        # With a single column the off-diagonal mask selects zero elements,
        # and .max()/.mean() on an empty tensor would raise.
        return {"max_abs_offdiag": 0.0, "mean_abs_offdiag": 0.0}
    mask = ~torch.eye(K, dtype=torch.bool, device=G.device)
    off_diag = G.abs()[mask]
    return {
        "max_abs_offdiag": off_diag.max().item(),
        "mean_abs_offdiag": off_diag.mean().item(),
    }