"""Learnable code evolution for WrinkleBrane. Direction 6: Makes the codebook ``C ∈ ℝ[L, K]`` a learnable parameter shared between write and read paths, enabling end-to-end training with reconstruction loss and orthogonality regularisation. Key components -------------- ``LearnableCodebook`` ``nn.Module`` wrapping ``C`` as a learnable parameter with on-the-fly column normalisation and coherence tracking. ``orthogonality_loss`` Off-diagonal coherence penalty: ``||off_diag(C_n^T C_n)||_F^2``. """ from __future__ import annotations from typing import Dict import torch from torch import nn, Tensor from wrinklebrane.codes import ( hadamard_codes, dct_codes, gaussian_codes, normalize_columns, coherence_stats, gram_matrix, ) # --------------------------------------------------------------------------- # Orthogonality loss # --------------------------------------------------------------------------- def orthogonality_loss(C: Tensor) -> Tensor: """Off-diagonal coherence penalty for code separation. ``loss = ||off_diag(C_n^T C_n)||_F^2`` where ``C_n`` is column-normalised ``C``. Only penalises cross- correlation between code columns (off-diagonal Gram matrix elements), without demanding unit diagonal — which is impossible when the codebook is overcomplete (``K > L``). This is consistent with the ``coherence_stats`` diagnostic: both target the same quantity (off-diagonal magnitudes). Parameters ---------- C : Tensor ``[L, K]`` Returns ------- Tensor Scalar loss (0 when all code columns are orthogonal). """ K = C.shape[1] # Normalise columns (differentiable) norms = C.norm(dim=0, keepdim=True).clamp_min(1e-8) C_n = C / norms G = C_n.T @ C_n # [K, K] mask = ~torch.eye(K, dtype=torch.bool, device=G.device) return G[mask].pow(2).sum() # --------------------------------------------------------------------------- # LearnableCodebook # --------------------------------------------------------------------------- class LearnableCodebook(nn.Module): """Learnable codebook ``C ∈ ℝ[L, K]`` with unit-norm column output. The raw parameter ``C_raw`` is stored as ``nn.Parameter``. Calling the module returns column-normalised ``C`` (differentiable), ensuring the write and read paths always use normalised codes. Parameters ---------- L : int Number of code layers. K : int Number of code columns (capacity). init : str Initialisation: ``"hadamard"``, ``"dct"``, ``"gaussian"``, ``"random"``, or ``"identity"`` (zero-padded eye). seed : int RNG seed for stochastic initialisations. freeze : bool If ``True``, ``C_raw`` is not learnable (``requires_grad=False``). """ def __init__( self, L: int, K: int, init: str = "hadamard", seed: int = 0, freeze: bool = False, ): super().__init__() self.L = L self.K = K C_init = _init_codebook(L, K, init, seed) self.C_raw = nn.Parameter(C_init, requires_grad=not freeze) def forward(self) -> Tensor: """Return column-normalised codebook ``[L, K]``.""" norms = self.C_raw.norm(dim=0, keepdim=True).clamp_min(1e-8) return self.C_raw / norms def ortho_loss(self) -> Tensor: """Orthogonality regularisation loss (scalar).""" return orthogonality_loss(self.C_raw) def coherence(self) -> Dict[str, float]: """Current coherence statistics (detached).""" with torch.no_grad(): return coherence_stats(self.forward()) def gram(self) -> Tensor: """Return Gram matrix ``C_n^T C_n`` (differentiable).""" C_n = self.forward() return C_n.T @ C_n def _init_codebook(L: int, K: int, init: str, seed: int = 0) -> Tensor: """Create initial codebook tensor.""" init = init.lower().strip() if init == "hadamard": return hadamard_codes(L, K) if init == "dct": return dct_codes(L, K) if init == "gaussian": return gaussian_codes(L, K, seed=seed) if init == "random": gen = torch.Generator().manual_seed(seed) C = torch.randn(L, K, generator=gen) return normalize_columns(C) if init == "identity": # Zero-padded identity: perfect orthogonality if K ≤ L C = torch.zeros(L, K) n = min(L, K) C[:n, :n] = torch.eye(n) return C raise ValueError(f"Unknown init '{init}'")