File size: 1,964 Bytes
dc2b9f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from __future__ import annotations

import math
from typing import Dict

import numpy as np
import torch
from scipy.fft import dct as scipy_dct
from scipy.linalg import hadamard


def normalize_columns(C: torch.Tensor) -> torch.Tensor:
    """Return a new tensor with every column of ``C`` scaled to unit L2 norm.

    Norms are taken along dim 0 (down each column). Any norm below the machine
    epsilon of the norm's dtype is clamped up to that epsilon, so an all-zero
    column stays zero instead of becoming NaN.

    Integer and bool inputs (e.g. a raw ±1 Hadamard matrix) are first promoted
    to the default floating dtype, since the norm/epsilon computation below
    cannot handle them. The epsilon is looked up on the norm's dtype, which is
    real even for complex ``C``.
    """
    if not (C.is_floating_point() or C.is_complex()):
        C = C.to(torch.get_default_dtype())
    norms = torch.linalg.norm(C, dim=0, keepdim=True)
    norms = norms.clamp_min(torch.finfo(norms.dtype).eps)
    return C / norms


def hadamard_codes(L: int, K: int) -> torch.Tensor:
    """Build ``K`` unit-norm codes of length ``L`` from a Hadamard matrix.

    A Hadamard matrix of the smallest power-of-two order ``n >= max(L, K)`` is
    generated. Its top-left ``L x K`` corner is cast to float32, and each column
    is scaled to unit L2 norm, so entries are ``±1/sqrt(L)``.

    When ``L`` is smaller than ``n``, rows are dropped, so the returned columns
    are generally not mutually orthogonal.

    If either size is non-positive, an uninitialized ``(L, K)`` tensor is
    returned instead.
    """
    if min(L, K) <= 0:
        return torch.empty(L, K)
    # (m - 1).bit_length() is the exponent of the smallest power of two >= m
    # (and gives order 1 when m == 1).
    order = 2 ** (max(L, K) - 1).bit_length()
    corner = hadamard(order)[:L, :K]
    codes = torch.from_numpy(corner).float()
    return normalize_columns(codes)


def dct_codes(L: int, K: int) -> torch.Tensor:
    """Return the first ``K`` orthonormal DCT-II basis vectors of length ``L``.

    Column ``k`` of the result is the frequency-``k`` DCT-II basis function,
    proportional to ``cos(pi * k * (2n + 1) / (2L))`` for ``n = 0..L-1``.
    Column 0 is therefore the constant (DC) vector.

    Only ``L`` such vectors exist, so for ``K > L`` the result has shape
    ``(L, min(K, L))``. If either size is non-positive, an uninitialized
    ``(L, K)`` tensor is returned instead.
    """
    if L <= 0 or K <= 0:
        return torch.empty(L, K)
    # Applying the orthonormal DCT-II to each column of the identity yields the
    # analysis matrix D with D[k, n] = basis function k at sample n. The basis
    # vectors are therefore the *rows* of D, so transpose before slicing.
    # Slicing D's columns would return one sample across all frequencies.
    analysis = scipy_dct(np.eye(L), type=2, axis=0, norm="ortho")
    basis = np.ascontiguousarray(analysis.T[:, :K])
    C = torch.from_numpy(basis).to(dtype=torch.float32)
    return normalize_columns(C)


def gaussian_codes(L: int, K: int, seed: int = 0) -> torch.Tensor:
    """Draw ``K`` Gaussian random codes of length ``L``, each with unit L2 norm.

    Samples come from a dedicated ``torch.Generator`` seeded with ``seed``, so
    the result is reproducible for a given seed and the global RNG is not used.

    If either size is non-positive, an uninitialized ``(L, K)`` tensor is
    returned instead.
    """
    if min(L, K) <= 0:
        return torch.empty(L, K)
    rng = torch.Generator()
    rng.manual_seed(seed)
    samples = torch.randn(L, K, generator=rng)
    scaled = samples / math.sqrt(L)
    return normalize_columns(scaled)


def gram_matrix(C: torch.Tensor) -> torch.Tensor:
    """Compute ``C^T C``, the matrix of pairwise column inner products.

    For a 2-D ``C``, entry ``(i, j)`` equals the dot product of columns
    ``i`` and ``j``.
    """
    return torch.matmul(C.T, C)


def coherence_stats(C: torch.Tensor) -> Dict[str, float]:
    """Summarize how correlated the columns (codes) of ``C`` are.

    Columns are first rescaled to unit L2 norm. The absolute off-diagonal
    entries of the resulting Gram matrix, ``|<c_i, c_j>|`` for ``i != j``,
    are then reduced.

    Args:
        C: 2-D tensor whose columns are the codes.

    Returns:
        ``{"max_abs_offdiag": ..., "mean_abs_offdiag": ...}`` as Python floats.
        With fewer than two columns there are no distinct pairs, so both
        statistics are reported as ``0.0``.
    """
    K = C.shape[1]
    if K < 2:
        # No off-diagonal entries exist: max() over an empty tensor would raise
        # and mean() would produce nan.
        return {"max_abs_offdiag": 0.0, "mean_abs_offdiag": 0.0}
    Cn = normalize_columns(C)
    G = gram_matrix(Cn)
    # Boolean mask selecting every entry except the diagonal (self products).
    mask = ~torch.eye(K, dtype=torch.bool, device=G.device)
    off_diag = G.abs()[mask]
    return {
        "max_abs_offdiag": off_diag.max().item(),
        "mean_abs_offdiag": off_diag.mean().item(),
    }