"""Learnable code evolution for WrinkleBrane. Direction 6: Makes the codebook ``C ∈ ℝ[L, K]`` a learnable parameter shared between write and read paths, enabling end-to-end training with reconstruction loss and orthogonality regularisation. The core problem this solves: in the existing codebase, ``store_pairs_1d`` receives ``C`` as a function argument while ``Slicer1D`` stores a separate copy. ``LearnableCodebook`` wraps ``C`` as a single ``nn.Parameter``; the write and read paths call ``codebook()`` to get the current normalised ``C``, ensuring they always agree. Key components -------------- ``LearnableCodebook`` ``nn.Module`` wrapping ``C`` as a learnable parameter with on-the-fly column normalisation and coherence tracking. ``orthogonality_loss`` Frobenius-norm penalty: ``||C^T C - I||_F^2``. ``LearnableMemoryBank1D`` End-to-end differentiable membrane: learnable codebook shared by ``store`` (write) and ``retrieve`` (read). ``train_codebook`` Training loop helper: reconstruction loss + orthogonality regularisation, with coherence tracking. """ from __future__ import annotations from typing import Optional, List, Dict import math import torch from torch import nn, Tensor from wrinklebrane.codes import ( hadamard_codes, dct_codes, gaussian_codes, normalize_columns, coherence_stats, gram_matrix, ) from wrinklebrane.membrane_1d import ( MembraneBank1D, store_pairs_1d, Slicer1D, ContinuousWriter1D, ContinuousReader1D, soft_code_weights_1d, cosine_similarity_matrix, ) # --------------------------------------------------------------------------- # Orthogonality loss # --------------------------------------------------------------------------- def orthogonality_loss(C: Tensor) -> Tensor: """Frobenius-norm penalty for deviation from orthogonality. ``loss = ||C_n^T C_n - I_K||_F^2`` where ``C_n`` is column-normalised ``C``. Returns a differentiable scalar suitable for use as a regularisation term. Parameters ---------- C : Tensor ``[L, K]`` Returns ------- Tensor Scalar loss (0 for perfectly orthogonal codes). """ K = C.shape[1] # Normalise columns (differentiable) norms = C.norm(dim=0, keepdim=True).clamp_min(1e-8) C_n = C / norms G = C_n.T @ C_n # [K, K] I = torch.eye(K, device=C.device, dtype=C.dtype) return (G - I).pow(2).sum() # --------------------------------------------------------------------------- # LearnableCodebook # --------------------------------------------------------------------------- class LearnableCodebook(nn.Module): """Learnable codebook ``C ∈ ℝ[L, K]`` with unit-norm column output. The raw parameter ``C_raw`` is stored as ``nn.Parameter``. Calling the module returns column-normalised ``C`` (differentiable), ensuring the write and read paths always use normalised codes. Parameters ---------- L : int Number of code layers. K : int Number of code columns (capacity). init : str Initialisation: ``"hadamard"``, ``"dct"``, ``"gaussian"``, ``"random"``, or ``"identity"`` (zero-padded eye). seed : int RNG seed for stochastic initialisations. freeze : bool If ``True``, ``C_raw`` is not learnable (``requires_grad=False``). """ def __init__( self, L: int, K: int, init: str = "hadamard", seed: int = 0, freeze: bool = False, ): super().__init__() self.L = L self.K = K C_init = _init_codebook(L, K, init, seed) self.C_raw = nn.Parameter(C_init, requires_grad=not freeze) def forward(self) -> Tensor: """Return column-normalised codebook ``[L, K]``.""" norms = self.C_raw.norm(dim=0, keepdim=True).clamp_min(1e-8) return self.C_raw / norms def ortho_loss(self) -> Tensor: """Orthogonality regularisation loss (scalar).""" return orthogonality_loss(self.C_raw) def coherence(self) -> Dict[str, float]: """Current coherence statistics (detached).""" with torch.no_grad(): return coherence_stats(self.forward()) def gram(self) -> Tensor: """Return Gram matrix ``C_n^T C_n`` (differentiable).""" C_n = self.forward() return C_n.T @ C_n def _init_codebook(L: int, K: int, init: str, seed: int = 0) -> Tensor: """Create initial codebook tensor.""" init = init.lower().strip() if init == "hadamard": return hadamard_codes(L, K) if init == "dct": return dct_codes(L, K) if init == "gaussian": return gaussian_codes(L, K, seed=seed) if init == "random": gen = torch.Generator().manual_seed(seed) C = torch.randn(L, K, generator=gen) return normalize_columns(C) if init == "identity": # Zero-padded identity: perfect orthogonality if K ≤ L C = torch.zeros(L, K) n = min(L, K) C[:n, :n] = torch.eye(n) return C raise ValueError(f"Unknown init '{init}'") # --------------------------------------------------------------------------- # LearnableMemoryBank1D # --------------------------------------------------------------------------- class LearnableMemoryBank1D(nn.Module): """End-to-end differentiable 1D membrane with a shared learnable codebook. Write and read paths both call ``self.codebook()`` to get the current normalised ``C``, ensuring consistency. Parameters ---------- L : int Number of code layers. K : int Number of code columns (capacity). D : int Embedding dimension. init : str Codebook initialisation (``"hadamard"``, ``"dct"``, etc.). freeze_codes : bool If ``True``, codebook is fixed (non-learnable). device, dtype : standard. """ def __init__( self, L: int, K: int, D: int, init: str = "hadamard", freeze_codes: bool = False, device: Optional[torch.device | str] = None, dtype: torch.dtype = torch.float32, ): super().__init__() self.L = L self.K = K self.D = D self.codebook = LearnableCodebook(L, K, init=init, freeze=freeze_codes) self.bank = MembraneBank1D(L=L, D=D, device=device, dtype=dtype) # ---- helpers ---------------------------------------------------------- def _get_C(self) -> Tensor: """Current normalised codebook (differentiable).""" return self.codebook() # ---- allocation / reset ----------------------------------------------- def allocate(self, B: int) -> None: self.bank.allocate(B) def reset(self, B: Optional[int] = None) -> None: self.bank.reset(B) # ---- write ------------------------------------------------------------ def store( self, keys: Tensor, values: Tensor, alphas: Tensor, ) -> None: """Discrete write using the shared learnable codebook. Parameters ---------- keys : Tensor ``[T]`` values : Tensor ``[T, D]`` alphas : Tensor ``[T]`` """ C = self._get_C() M = self.bank.read() M_new = store_pairs_1d(M, C, keys, values, alphas) self.bank.M = M_new def store_continuous( self, queries: Tensor, values: Tensor, alphas: Tensor, projection: Tensor, temperature: float | Tensor = 1.0, ) -> None: """Continuous write using the shared learnable codebook. Parameters ---------- queries : Tensor ``[T, D_query]`` values : Tensor ``[T, D]`` alphas : Tensor ``[T]`` projection : Tensor ``[D_query, K]`` temperature : float or Tensor """ C = self._get_C() M = self.bank.read() weights = soft_code_weights_1d(queries, projection, temperature) codes = C @ weights.T # [L, T] codes = codes * alphas.unsqueeze(0) delta = torch.einsum("lt,td->ld", codes, values) self.bank.M = M + delta.unsqueeze(0) # ---- read ------------------------------------------------------------- def retrieve(self) -> Tensor: """Discrete read using the shared learnable codebook. Returns ``[B, K, D]``. """ C = self._get_C() M = self.bank.read() return torch.einsum("bld,lk->bkd", M, C) def retrieve_continuous( self, queries: Tensor, projection: Tensor, temperature: float | Tensor = 1.0, ) -> Tensor: """Continuous read using the shared learnable codebook. Parameters ---------- queries : Tensor ``[T, D_query]`` projection : Tensor ``[D_query, K]`` temperature : float or Tensor Returns ``[B, T, D]`` """ C = self._get_C() M = self.bank.read() Y_full = torch.einsum("bld,lk->bkd", M, C) # [B, K, D] weights = soft_code_weights_1d(queries, projection, temperature) return torch.einsum("bkd,tk->btd", Y_full, weights) # ---- diagnostics ------------------------------------------------------ def coherence(self) -> Dict[str, float]: return self.codebook.coherence() def ortho_loss(self) -> Tensor: return self.codebook.ortho_loss() # --------------------------------------------------------------------------- # Training utilities # --------------------------------------------------------------------------- def reconstruction_loss( retrieved: Tensor, targets: Tensor, ) -> Tensor: """MSE between retrieved embeddings and targets. Parameters ---------- retrieved : Tensor ``[B, K, D]`` or ``[B, T, D]`` targets : Tensor matching shape (or broadcastable). """ return (retrieved - targets).pow(2).mean() def train_codebook( bank: LearnableMemoryBank1D, data_fn, *, n_steps: int = 100, lr: float = 1e-3, ortho_lambda: float = 0.1, B: int = 1, log_every: int = 10, ) -> List[Dict[str, float]]: """Train a learnable codebook with reconstruction loss + orthogonality reg. Each step: 1. Reset membrane, generate fresh data via ``data_fn()`` 2. Store data with discrete keys 3. Retrieve data 4. Compute ``loss = MSE(retrieved, original) + λ * ortho_loss(C)`` 5. Backprop and update ``C_raw`` Parameters ---------- bank : LearnableMemoryBank1D Must have a learnable (unfrozen) codebook. data_fn : callable ``data_fn() -> (keys, values, alphas)`` returning tensors for one training step. ``keys: [T]``, ``values: [T, D]``, ``alphas: [T]``. n_steps : int Number of training steps. lr : float Learning rate for Adam. ortho_lambda : float Weight of orthogonality regularisation. B : int Batch size for membrane allocation. log_every : int Logging frequency. Returns ------- list[dict] Per-step metrics: ``step``, ``total_loss``, ``recon_loss``, ``ortho_loss``, ``max_coherence``, ``mean_coherence``. """ optimizer = torch.optim.Adam(bank.parameters(), lr=lr) history: List[Dict[str, float]] = [] for step in range(n_steps): optimizer.zero_grad() # Fresh membrane each step bank.allocate(B) keys, values, alphas = data_fn() # Store → retrieve bank.store(keys, values, alphas) Y = bank.retrieve() # [B, K, D] # Loss # Target: values at the corresponding key indices target = values.unsqueeze(0).expand(B, -1, -1) # [B, T, D] recon = reconstruction_loss(Y, target) ortho = bank.ortho_loss() loss = recon + ortho_lambda * ortho loss.backward() optimizer.step() if step % log_every == 0 or step == n_steps - 1: with torch.no_grad(): coh = bank.coherence() record = { "step": step, "total_loss": float(loss), "recon_loss": float(recon), "ortho_loss": float(ortho), "max_coherence": coh["max_abs_offdiag"], "mean_coherence": coh["mean_abs_offdiag"], } history.append(record) return history