"""Learnable code evolution for WrinkleBrane.

Direction 6: Makes the codebook ``C ∈ ℝ[L, K]`` a learnable parameter
shared between write and read paths, enabling end-to-end training with
reconstruction loss and orthogonality regularisation.

Key components
--------------
``LearnableCodebook``
    ``nn.Module`` wrapping ``C`` as a learnable parameter with on-the-fly
    column normalisation and coherence tracking.

``orthogonality_loss``
    Off-diagonal coherence penalty: ``||off_diag(C_n^T C_n)||_F^2``.
"""

from __future__ import annotations

from typing import Dict

import torch
from torch import nn, Tensor

from wrinklebrane.codes import (
    hadamard_codes,
    dct_codes,
    gaussian_codes,
    normalize_columns,
    coherence_stats,
    gram_matrix,
)

# ---------------------------------------------------------------------------
# Orthogonality loss
# ---------------------------------------------------------------------------

def orthogonality_loss(C: Tensor) -> Tensor:
    """Off-diagonal coherence penalty for code separation.

    ``loss = ||off_diag(C_n^T C_n)||_F^2``

    where ``C_n`` is column-normalised ``C``.  Only cross-correlation
    between code columns (off-diagonal Gram entries) is penalised; column
    norms are handled by the normalisation step rather than the loss.
    Note that a zero loss (exact orthogonality) is unattainable when the
    codebook is overcomplete (``K > L``).

    This is consistent with the ``coherence_stats`` diagnostic: both
    target the same quantity (off-diagonal magnitudes).

    Parameters
    ----------
    C : Tensor ``[L, K]``

    Returns
    -------
    Tensor
        Scalar loss (0 when all code columns are orthogonal).
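
    Examples
    --------
    Quick sanity check (an illustrative sketch, assuming the imported
    ``hadamard_codes`` yields mutually orthogonal columns for
    power-of-two ``K == L``)::

        C_orth = hadamard_codes(64, 64)   # orthogonal columns
        assert orthogonality_loss(C_orth).item() < 1e-6
        C_rand = torch.randn(64, 96)      # overcomplete: K > L
        assert orthogonality_loss(C_rand).item() > 0.0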
    """
    K = C.shape[1]
    # Normalise columns (differentiable)
    norms = C.norm(dim=0, keepdim=True).clamp_min(1e-8)
    C_n = C / norms
    G = C_n.T @ C_n  # [K, K]
    mask = ~torch.eye(K, dtype=torch.bool, device=G.device)
    return G[mask].pow(2).sum()


# ---------------------------------------------------------------------------
# LearnableCodebook
# ---------------------------------------------------------------------------

class LearnableCodebook(nn.Module):
    """Learnable codebook ``C ∈ ℝ[L, K]`` with unit-norm column output.

    The raw parameter ``C_raw`` is stored as ``nn.Parameter``.  Calling
    the module returns column-normalised ``C`` (differentiable), ensuring
    the write and read paths always use normalised codes.

    Parameters
    ----------
    L : int
        Number of code layers.
    K : int
        Number of code columns (capacity).
    init : str
        Initialisation: ``"hadamard"``, ``"dct"``, ``"gaussian"``,
        ``"random"``, or ``"identity"`` (zero-padded eye).
    seed : int
        RNG seed for stochastic initialisations.
    freeze : bool
        If ``True``, ``C_raw`` is not learnable (``requires_grad=False``).
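
    Examples
    --------
    Minimal usage sketch (illustrative; shapes follow the parameters
    above)::

        codebook = LearnableCodebook(L=64, K=96, init="gaussian")
        C = codebook()                  # column-normalised [64, 96]
        loss = codebook.ortho_loss()    # add to the training objective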
    """

    def __init__(
        self,
        L: int,
        K: int,
        init: str = "hadamard",
        seed: int = 0,
        freeze: bool = False,
    ):
        super().__init__()
        self.L = L
        self.K = K

        C_init = _init_codebook(L, K, init, seed)
        self.C_raw = nn.Parameter(C_init, requires_grad=not freeze)

    def forward(self) -> Tensor:
        """Return column-normalised codebook ``[L, K]``."""
        norms = self.C_raw.norm(dim=0, keepdim=True).clamp_min(1e-8)
        return self.C_raw / norms

    def ortho_loss(self) -> Tensor:
        """Orthogonality regularisation loss (scalar)."""
        return orthogonality_loss(self.C_raw)

    def coherence(self) -> Dict[str, float]:
        """Current coherence statistics (detached)."""
        with torch.no_grad():
            return coherence_stats(self.forward())

    def gram(self) -> Tensor:
        """Return Gram matrix ``C_n^T C_n`` (differentiable)."""
        C_n = self.forward()
        return C_n.T @ C_n


def _init_codebook(L: int, K: int, init: str, seed: int = 0) -> Tensor:
    """Create initial codebook tensor."""
    init = init.lower().strip()
    if init == "hadamard":
        return hadamard_codes(L, K)
    if init == "dct":
        return dct_codes(L, K)
    if init == "gaussian":
        return gaussian_codes(L, K, seed=seed)
    if init == "random":
        gen = torch.Generator().manual_seed(seed)
        C = torch.randn(L, K, generator=gen)
        return normalize_columns(C)
    if init == "identity":
        # Zero-padded identity: orthonormal columns when K <= L; for
        # K > L the trailing K - L columns start at zero.
        C = torch.zeros(L, K)
        n = min(L, K)
        C[:n, :n] = torch.eye(n)
        return C
    raise ValueError(f"Unknown init '{init}'")
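

# ---------------------------------------------------------------------------
# Smoke-test sketch
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    # Illustrative sketch only: trains a randomly initialised codebook
    # towards low coherence using the orthogonality regulariser in
    # isolation.  In the full Direction-6 setup this term would be added
    # to a reconstruction loss computed through the shared write/read
    # paths (not part of this module).
    codebook = LearnableCodebook(L=64, K=96, init="random", seed=0)
    opt = torch.optim.Adam(codebook.parameters(), lr=1e-2)
    for step in range(200):
        opt.zero_grad()
        loss = codebook.ortho_loss()
        loss.backward()
        opt.step()
        if step % 50 == 0:
            print(f"step {step:3d}  ortho_loss={loss.item():.4f}")
    print("final coherence:", codebook.coherence())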