| """Learnable code evolution for WrinkleBrane. |
| |
| Direction 6: Makes the codebook ``C ∈ ℝ[L, K]`` a learnable parameter |
| shared between write and read paths, enabling end-to-end training with |
| reconstruction loss and orthogonality regularisation. |
| |
| Key components |
| -------------- |
| ``LearnableCodebook`` |
| ``nn.Module`` wrapping ``C`` as a learnable parameter with on-the-fly |
| column normalisation and coherence tracking. |
| |
| ``orthogonality_loss`` |
| Off-diagonal coherence penalty: ``||off_diag(C_n^T C_n)||_F^2``. |
| """ |
|
|
from __future__ import annotations

from typing import Dict

import torch
from torch import nn, Tensor

from wrinklebrane.codes import (
    hadamard_codes,
    dct_codes,
    gaussian_codes,
    normalize_columns,
    coherence_stats,
)


def orthogonality_loss(C: Tensor) -> Tensor:
    """Off-diagonal coherence penalty for code separation.

    ``loss = ||off_diag(C_n^T C_n)||_F^2``

    where ``C_n`` is the column-normalised ``C``. Only the cross-
    correlations between code columns (the off-diagonal Gram-matrix
    entries) are penalised; the diagonal is unit by construction after
    normalisation. When the codebook is overcomplete (``K > L``), exact
    orthogonality is unattainable, and the loss instead drives the
    columns towards minimal total cross-correlation.

    This is consistent with the ``coherence_stats`` diagnostic: both
    target the same quantity (off-diagonal magnitudes).

    Parameters
    ----------
    C : Tensor ``[L, K]``

    Returns
    -------
    Tensor
        Scalar loss (0 when all code columns are mutually orthogonal).
    """
    K = C.shape[1]

    # Normalise columns (clamped to guard against zero-norm columns),
    # then sum the squared off-diagonal entries of the Gram matrix.
    norms = C.norm(dim=0, keepdim=True).clamp_min(1e-8)
    C_n = C / norms
    G = C_n.T @ C_n
    mask = ~torch.eye(K, dtype=torch.bool, device=G.device)
    return G[mask].pow(2).sum()
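

# A hedged sanity-check sketch for ``orthogonality_loss`` (illustrative
# only, not part of the module API): orthonormal columns score exactly
# zero, while an overcomplete random codebook cannot. The shapes (8, 16)
# are assumptions chosen for the example.
def _demo_orthogonality_loss() -> None:
    C_orth = torch.eye(8)            # 8 orthonormal columns -> zero loss
    assert orthogonality_loss(C_orth).item() == 0.0

    C_rand = torch.randn(8, 16)      # K > L: off-diagonals cannot all vanish
    assert orthogonality_loss(C_rand).item() > 0.0
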
class LearnableCodebook(nn.Module):
    """Learnable codebook ``C ∈ ℝ[L, K]`` with unit-norm column output.

    The raw parameter ``C_raw`` is stored as an ``nn.Parameter``. Calling
    the module returns the column-normalised ``C`` (differentiable), so
    the write and read paths always use normalised codes.

    Parameters
    ----------
    L : int
        Number of code layers.
    K : int
        Number of code columns (capacity).
    init : str
        Initialisation: ``"hadamard"``, ``"dct"``, ``"gaussian"``,
        ``"random"``, or ``"identity"`` (zero-padded eye).
    seed : int
        RNG seed for stochastic initialisations.
    freeze : bool
        If ``True``, ``C_raw`` is not learnable (``requires_grad=False``).
    """

    def __init__(
        self,
        L: int,
        K: int,
        init: str = "hadamard",
        seed: int = 0,
        freeze: bool = False,
    ):
        super().__init__()
        self.L = L
        self.K = K

        C_init = _init_codebook(L, K, init, seed)
        self.C_raw = nn.Parameter(C_init, requires_grad=not freeze)

    def forward(self) -> Tensor:
        """Return the column-normalised codebook ``[L, K]``."""
        # Clamp guards against zero-norm columns; such columns pass
        # through as (near-)zero rather than producing NaNs.
        norms = self.C_raw.norm(dim=0, keepdim=True).clamp_min(1e-8)
        return self.C_raw / norms

    def ortho_loss(self) -> Tensor:
        """Orthogonality regularisation loss (scalar).

        ``orthogonality_loss`` normalises columns internally, so passing
        the raw parameter is equivalent to passing ``self.forward()``.
        """
        return orthogonality_loss(self.C_raw)

    def coherence(self) -> Dict[str, float]:
        """Current coherence statistics (detached)."""
        with torch.no_grad():
            return coherence_stats(self.forward())

    def gram(self) -> Tensor:
        """Return the Gram matrix ``C_n^T C_n`` (differentiable)."""
        C_n = self.forward()
        return C_n.T @ C_n
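

# A hedged usage sketch for ``LearnableCodebook`` (illustrative only):
# build a codebook, read off the normalised codes, and inspect coherence.
# L=64, K=32 and the "hadamard" init are assumptions for the example.
def _demo_learnable_codebook() -> None:
    codebook = LearnableCodebook(L=64, K=32, init="hadamard")
    C = codebook()                                   # [64, 32]
    assert torch.allclose(C.norm(dim=0), torch.ones(32), atol=1e-5)
    print(codebook.coherence())                      # off-diagonal stats
    print(tuple(codebook.gram().shape))              # (32, 32)
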
def _init_codebook(L: int, K: int, init: str, seed: int = 0) -> Tensor:
    """Create the initial codebook tensor."""
    init = init.lower().strip()
    if init == "hadamard":
        return hadamard_codes(L, K)
    if init == "dct":
        return dct_codes(L, K)
    if init == "gaussian":
        return gaussian_codes(L, K, seed=seed)
    if init == "random":
        gen = torch.Generator().manual_seed(seed)
        C = torch.randn(L, K, generator=gen)
        return normalize_columns(C)
    if init == "identity":
        # Zero-padded eye: the first min(L, K) columns are standard basis
        # vectors; when K > L the trailing columns start as all zeros.
        C = torch.zeros(L, K)
        n = min(L, K)
        C[:n, :n] = torch.eye(n)
        return C
    raise ValueError(f"Unknown init '{init}'")
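

# A hedged end-to-end sketch of the training described in the module
# docstring: reconstruction loss plus orthogonality regularisation. The
# write/read path is stubbed with plain matrix products; the real
# WrinkleBrane membrane operations may differ. ``lambda_ortho``, the
# shapes, and the optimiser settings are illustrative assumptions.
if __name__ == "__main__":
    _demo_orthogonality_loss()
    _demo_learnable_codebook()

    torch.manual_seed(0)
    L, K, D = 64, 32, 16
    codebook = LearnableCodebook(L, K, init="gaussian")
    opt = torch.optim.Adam(codebook.parameters(), lr=1e-3)
    lambda_ortho = 1e-2

    V = torch.randn(K, D)        # K stored patterns of dimension D
    for step in range(200):
        C = codebook()           # [L, K], shared by write and read paths
        M = C @ V                # write: superpose patterns into membrane [L, D]
        V_hat = C.T @ M          # read: correlate codes against membrane [K, D]
        loss = (V_hat - V).pow(2).mean() + lambda_ortho * codebook.ortho_loss()
        opt.zero_grad()
        loss.backward()
        opt.step()
        if step % 50 == 0:
            print(f"step {step:3d}  loss {loss.item():.4f}")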