"""Learnable code evolution for WrinkleBrane.

Direction 6: Makes the codebook ``C ∈ ℝ[L, K]`` a learnable parameter
shared between write and read paths, enabling end-to-end training with
reconstruction loss and orthogonality regularisation.

Key components
--------------
``LearnableCodebook``
    ``nn.Module`` wrapping ``C`` as a learnable parameter with on-the-fly
    column normalisation and coherence tracking.
``orthogonality_loss``
    Off-diagonal coherence penalty: ``||off_diag(C_n^T C_n)||_F^2``.
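
Example
-------
A minimal usage sketch (the reconstruction loss itself lives in the
write/read paths, outside this module):

>>> cb = LearnableCodebook(L=16, K=16)
>>> C = cb()                   # column-normalised codes, shape [16, 16]
>>> penalty = cb.ortho_loss()  # weight and add to a reconstruction loss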
"""

from __future__ import annotations

from typing import Dict

import torch
from torch import nn, Tensor

from wrinklebrane.codes import (
    hadamard_codes,
    dct_codes,
    gaussian_codes,
    normalize_columns,
    coherence_stats,
    gram_matrix,
)

# ---------------------------------------------------------------------------
# Orthogonality loss
# ---------------------------------------------------------------------------
def orthogonality_loss(C: Tensor) -> Tensor:
"""Off-diagonal coherence penalty for code separation.
``loss = ||off_diag(C_n^T C_n)||_F^2``
where ``C_n`` is column-normalised ``C``. Only penalises cross-
correlation between code columns (off-diagonal Gram matrix elements),
without demanding unit diagonal — which is impossible when the
codebook is overcomplete (``K > L``).
This is consistent with the ``coherence_stats`` diagnostic: both
target the same quantity (off-diagonal magnitudes).
Parameters
----------
C : Tensor ``[L, K]``
Returns
-------
Tensor
Scalar loss (0 when all code columns are orthogonal).
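
    Examples
    --------
    A quick sanity check: orthogonal columns incur zero loss, while two
    duplicated columns yield two off-diagonal Gram entries of 1 and a
    loss of 2.

    >>> orthogonality_loss(torch.eye(4)).item()
    0.0
    >>> orthogonality_loss(torch.ones(4, 2)).item()
    2.0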
"""
    K = C.shape[1]
    # Normalise columns (differentiable)
    norms = C.norm(dim=0, keepdim=True).clamp_min(1e-8)
    C_n = C / norms
    G = C_n.T @ C_n  # [K, K]
    mask = ~torch.eye(K, dtype=torch.bool, device=G.device)
    return G[mask].pow(2).sum()


# ---------------------------------------------------------------------------
# LearnableCodebook
# ---------------------------------------------------------------------------
class LearnableCodebook(nn.Module):
"""Learnable codebook ``C ∈ ℝ[L, K]`` with unit-norm column output.
The raw parameter ``C_raw`` is stored as ``nn.Parameter``. Calling
the module returns column-normalised ``C`` (differentiable), ensuring
the write and read paths always use normalised codes.
Parameters
----------
L : int
Number of code layers.
K : int
Number of code columns (capacity).
init : str
Initialisation: ``"hadamard"``, ``"dct"``, ``"gaussian"``,
``"random"``, or ``"identity"`` (zero-padded eye).
seed : int
RNG seed for stochastic initialisations.
freeze : bool
If ``True``, ``C_raw`` is not learnable (``requires_grad=False``).
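
    Examples
    --------
    A minimal sketch (Hadamard initialisation; ``L`` should suit the
    Hadamard construction, e.g. a power of two):

    >>> cb = LearnableCodebook(L=8, K=8)
    >>> C = cb()
    >>> C.shape
    torch.Size([8, 8])
    >>> torch.allclose(C.norm(dim=0), torch.ones(8))
    True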
"""
    def __init__(
        self,
        L: int,
        K: int,
        init: str = "hadamard",
        seed: int = 0,
        freeze: bool = False,
    ):
        super().__init__()
        self.L = L
        self.K = K
        C_init = _init_codebook(L, K, init, seed)
        self.C_raw = nn.Parameter(C_init, requires_grad=not freeze)

    def forward(self) -> Tensor:
        """Return column-normalised codebook ``[L, K]``."""
        norms = self.C_raw.norm(dim=0, keepdim=True).clamp_min(1e-8)
        return self.C_raw / norms

    def ortho_loss(self) -> Tensor:
        """Orthogonality regularisation loss (scalar)."""
        return orthogonality_loss(self.C_raw)

    def coherence(self) -> Dict[str, float]:
        """Current coherence statistics (detached)."""
        with torch.no_grad():
            return coherence_stats(self.forward())

    def gram(self) -> Tensor:
        """Return Gram matrix ``C_n^T C_n`` (differentiable)."""
        C_n = self.forward()
        return C_n.T @ C_n


def _init_codebook(L: int, K: int, init: str, seed: int = 0) -> Tensor:
    """Create initial codebook tensor."""
    init = init.lower().strip()
    if init == "hadamard":
        return hadamard_codes(L, K)
    if init == "dct":
        return dct_codes(L, K)
    if init == "gaussian":
        return gaussian_codes(L, K, seed=seed)
    if init == "random":
        gen = torch.Generator().manual_seed(seed)
        C = torch.randn(L, K, generator=gen)
        return normalize_columns(C)
    if init == "identity":
        # Zero-padded identity: perfect orthogonality if K ≤ L
        C = torch.zeros(L, K)
        n = min(L, K)
        C[:n, :n] = torch.eye(n)
        return C
    raise ValueError(f"Unknown init '{init}'")
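

if __name__ == "__main__":
    # Minimal smoke test -- an illustrative sketch, not part of the module
    # API. It drives an overcomplete Gaussian codebook toward lower mutual
    # coherence using the orthogonality penalty alone; a real run would
    # combine this with a reconstruction loss from the write/read paths.
    # ``lam`` is an arbitrary illustrative weight.
    torch.manual_seed(0)
    cb = LearnableCodebook(L=16, K=32, init="gaussian")
    opt = torch.optim.Adam(cb.parameters(), lr=1e-2)
    lam = 1.0
    print("initial coherence:", cb.coherence())
    for _ in range(200):
        opt.zero_grad()
        loss = lam * cb.ortho_loss()
        loss.backward()
        opt.step()
    print("final coherence:  ", cb.coherence())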