WrinkleBrane / Wrinkle /WrinkleBrane-Research /06_learnable_codes /src /wrinklebrane /learnable_codes.py
| """Learnable code evolution for WrinkleBrane. | |
| Direction 6: Makes the codebook ``C ∈ ℝ[L, K]`` a learnable parameter | |
| shared between write and read paths, enabling end-to-end training with | |
| reconstruction loss and orthogonality regularisation. | |
| The core problem this solves: in the existing codebase, ``store_pairs_1d`` | |
| receives ``C`` as a function argument while ``Slicer1D`` stores a separate | |
| copy. ``LearnableCodebook`` wraps ``C`` as a single ``nn.Parameter``; the | |
| write and read paths call ``codebook()`` to get the current normalised | |
| ``C``, ensuring they always agree. | |
| Key components | |
| -------------- | |
| ``LearnableCodebook`` | |
| ``nn.Module`` wrapping ``C`` as a learnable parameter with on-the-fly | |
| column normalisation and coherence tracking. | |
| ``orthogonality_loss`` | |
| Frobenius-norm penalty: ``||C^T C - I||_F^2``. | |
| ``LearnableMemoryBank1D`` | |
| End-to-end differentiable membrane: learnable codebook shared by | |
| ``store`` (write) and ``retrieve`` (read). | |
| ``train_codebook`` | |
| Training loop helper: reconstruction loss + orthogonality | |
| regularisation, with coherence tracking. | |
| """ | |
| from __future__ import annotations | |
| from typing import Optional, List, Dict | |
| import math | |
| import torch | |
| from torch import nn, Tensor | |
| from wrinklebrane.codes import ( | |
| hadamard_codes, | |
| dct_codes, | |
| gaussian_codes, | |
| normalize_columns, | |
| coherence_stats, | |
| gram_matrix, | |
| ) | |
| from wrinklebrane.membrane_1d import ( | |
| MembraneBank1D, | |
| store_pairs_1d, | |
| Slicer1D, | |
| ContinuousWriter1D, | |
| ContinuousReader1D, | |
| soft_code_weights_1d, | |
| cosine_similarity_matrix, | |
| ) | |
| # --------------------------------------------------------------------------- | |
| # Orthogonality loss | |
| # --------------------------------------------------------------------------- | |
def orthogonality_loss(C: Tensor) -> Tensor:
    """Squared Frobenius distance between the code Gram matrix and identity.

    The columns of ``C`` are first scaled to unit L2 norm (norms are floored
    at ``1e-8`` so an all-zero column does not divide by zero), then

    ``loss = ||C_n^T C_n - I_K||_F^2``

    is returned.  Every operation is differentiable, so the result can be
    added to a training objective as a regulariser.

    Parameters
    ----------
    C : Tensor ``[L, K]``
        Codebook whose ``K`` columns are the codes.

    Returns
    -------
    Tensor
        Scalar loss; exactly 0 when the normalised columns are orthonormal.
    """
    # Column-wise normalisation, kept inside the autograd graph.
    unit_cols = C / C.norm(dim=0, keepdim=True).clamp_min(1e-8)
    gram = torch.matmul(unit_cols.T, unit_cols)  # [K, K]
    # Build the identity on the same device/dtype as the input.
    residual = gram - torch.eye(C.size(1), device=C.device, dtype=C.dtype)
    return residual.pow(2).sum()
| # --------------------------------------------------------------------------- | |
| # LearnableCodebook | |
| # --------------------------------------------------------------------------- | |
class LearnableCodebook(nn.Module):
    """Trainable codebook ``C ∈ ℝ[L, K]`` that always hands out unit-norm columns.

    The unnormalised weights live in the ``C_raw`` parameter.  Calling the
    module normalises each column on the fly (differentiably), so every
    consumer of the codebook sees the same normalised codes.

    Parameters
    ----------
    L : int
        Number of code layers (rows of ``C``).
    K : int
        Number of code columns (capacity).
    init : str
        Name of the initialisation scheme, forwarded to ``_init_codebook``.
    seed : int
        RNG seed, forwarded to ``_init_codebook``.
    freeze : bool
        When ``True`` the parameter is created with ``requires_grad=False``.
    """

    def __init__(
        self,
        L: int,
        K: int,
        init: str = "hadamard",
        seed: int = 0,
        freeze: bool = False,
    ):
        super().__init__()
        self.L, self.K = L, K
        trainable = not freeze
        initial_codes = _init_codebook(L, K, init, seed)
        self.C_raw = nn.Parameter(initial_codes, requires_grad=trainable)

    def forward(self) -> Tensor:
        """Return the codebook ``[L, K]`` with each column scaled to unit norm."""
        raw = self.C_raw
        # Floor the norms so an all-zero column does not produce NaNs.
        col_norms = raw.norm(dim=0, keepdim=True).clamp_min(1e-8)
        return raw / col_norms

    def ortho_loss(self) -> Tensor:
        """Scalar orthogonality penalty computed from the raw parameter."""
        # ``orthogonality_loss`` normalises internally, so pass ``C_raw`` as is.
        return orthogonality_loss(self.C_raw)

    def coherence(self) -> Dict[str, float]:
        """Coherence statistics of the normalised codebook, without gradients."""
        with torch.no_grad():
            normalised = self.forward()
            return coherence_stats(normalised)

    def gram(self) -> Tensor:
        """Gram matrix ``C_n^T C_n`` of the normalised codes (differentiable)."""
        normalised = self.forward()
        return torch.matmul(normalised.T, normalised)
| def _init_codebook(L: int, K: int, init: str, seed: int = 0) -> Tensor: | |
| """Create initial codebook tensor.""" | |
| init = init.lower().strip() | |
| if init == "hadamard": | |
| return hadamard_codes(L, K) | |
| if init == "dct": | |
| return dct_codes(L, K) | |
| if init == "gaussian": | |
| return gaussian_codes(L, K, seed=seed) | |
| if init == "random": | |
| gen = torch.Generator().manual_seed(seed) | |
| C = torch.randn(L, K, generator=gen) | |
| return normalize_columns(C) | |
| if init == "identity": | |
| # Zero-padded identity: perfect orthogonality if K ≤ L | |
| C = torch.zeros(L, K) | |
| n = min(L, K) | |
| C[:n, :n] = torch.eye(n) | |
| return C | |
| raise ValueError(f"Unknown init '{init}'") | |
| # --------------------------------------------------------------------------- | |
| # LearnableMemoryBank1D | |
| # --------------------------------------------------------------------------- | |
class LearnableMemoryBank1D(nn.Module):
    """End-to-end differentiable 1D membrane with a shared learnable codebook.

    Write and read paths both call ``self.codebook()`` to get the current
    normalised ``C``, ensuring consistency.

    Parameters
    ----------
    L : int
        Number of code layers.
    K : int
        Number of code columns (capacity).
    D : int
        Embedding dimension.
    init : str
        Codebook initialisation (``"hadamard"``, ``"dct"``, etc.).
    freeze_codes : bool
        If ``True``, codebook is fixed (non-learnable).
    device, dtype : standard.
        Applied to both the membrane bank and the codebook parameter so the
        two can be combined without a device/dtype mismatch.
    seed : int
        RNG seed forwarded to the codebook initialisation.
    """

    def __init__(
        self,
        L: int,
        K: int,
        D: int,
        init: str = "hadamard",
        freeze_codes: bool = False,
        device: Optional[torch.device | str] = None,
        dtype: torch.dtype = torch.float32,
        seed: int = 0,
    ):
        super().__init__()
        self.L = L
        self.K = K
        self.D = D
        self.codebook = LearnableCodebook(
            L, K, init=init, seed=seed, freeze=freeze_codes
        )
        # The codebook is built with default tensor options; move/cast it so
        # that it matches the membrane created below.  Without this, a
        # non-default ``device`` or ``dtype`` makes the einsum/matmul in the
        # read and write paths fail with a mismatch error.
        self.codebook.to(device=device, dtype=dtype)
        self.bank = MembraneBank1D(L=L, D=D, device=device, dtype=dtype)

    # ---- helpers ----------------------------------------------------------

    def _get_C(self) -> Tensor:
        """Current normalised codebook (differentiable)."""
        return self.codebook()

    # ---- allocation / reset -----------------------------------------------

    def allocate(self, B: int) -> None:
        """Delegate to ``self.bank.allocate(B)``."""
        self.bank.allocate(B)

    def reset(self, B: Optional[int] = None) -> None:
        """Delegate to ``self.bank.reset(B)``."""
        self.bank.reset(B)

    # ---- write ------------------------------------------------------------

    def store(
        self,
        keys: Tensor,
        values: Tensor,
        alphas: Tensor,
    ) -> None:
        """Discrete write using the shared learnable codebook.

        Reads the current membrane, passes it with the normalised codebook to
        ``store_pairs_1d`` and stores the result back on ``self.bank.M``.

        Parameters
        ----------
        keys : Tensor ``[T]``
        values : Tensor ``[T, D]``
        alphas : Tensor ``[T]``
        """
        C = self._get_C()
        M = self.bank.read()
        M_new = store_pairs_1d(M, C, keys, values, alphas)
        self.bank.M = M_new

    def store_continuous(
        self,
        queries: Tensor,
        values: Tensor,
        alphas: Tensor,
        projection: Tensor,
        temperature: float | Tensor = 1.0,
    ) -> None:
        """Continuous write using the shared learnable codebook.

        Each query is turned into soft weights over the ``K`` codes via
        ``soft_code_weights_1d``; the weighted code mixture, scaled by
        ``alphas``, is combined with ``values`` and added to the membrane.

        Parameters
        ----------
        queries : Tensor ``[T, D_query]``
        values : Tensor ``[T, D]``
        alphas : Tensor ``[T]``
        projection : Tensor ``[D_query, K]``
        temperature : float or Tensor
        """
        C = self._get_C()
        M = self.bank.read()
        weights = soft_code_weights_1d(queries, projection, temperature)
        codes = C @ weights.T  # [L, T]
        codes = codes * alphas.unsqueeze(0)
        delta = torch.einsum("lt,td->ld", codes, values)
        # ``delta`` has no batch axis; the leading singleton broadcasts the
        # same update over every batch entry of ``M``.
        self.bank.M = M + delta.unsqueeze(0)

    # ---- read -------------------------------------------------------------

    def retrieve(self) -> Tensor:
        """Discrete read using the shared learnable codebook.

        Projects the membrane ``[B, L, D]`` onto every code column.

        Returns ``[B, K, D]``.
        """
        C = self._get_C()
        M = self.bank.read()
        return torch.einsum("bld,lk->bkd", M, C)

    def retrieve_continuous(
        self,
        queries: Tensor,
        projection: Tensor,
        temperature: float | Tensor = 1.0,
    ) -> Tensor:
        """Continuous read using the shared learnable codebook.

        Performs the full discrete read and then mixes the ``K`` slots with
        the soft weights produced by ``soft_code_weights_1d``.

        Parameters
        ----------
        queries : Tensor ``[T, D_query]``
        projection : Tensor ``[D_query, K]``
        temperature : float or Tensor

        Returns ``[B, T, D]``
        """
        C = self._get_C()
        M = self.bank.read()
        Y_full = torch.einsum("bld,lk->bkd", M, C)  # [B, K, D]
        weights = soft_code_weights_1d(queries, projection, temperature)
        return torch.einsum("bkd,tk->btd", Y_full, weights)

    # ---- diagnostics ------------------------------------------------------

    def coherence(self) -> Dict[str, float]:
        """Coherence statistics of the current codebook."""
        return self.codebook.coherence()

    def ortho_loss(self) -> Tensor:
        """Orthogonality regularisation loss of the current codebook."""
        return self.codebook.ortho_loss()
| # --------------------------------------------------------------------------- | |
| # Training utilities | |
| # --------------------------------------------------------------------------- | |
def reconstruction_loss(
    retrieved: Tensor,
    targets: Tensor,
) -> Tensor:
    """Mean squared error between retrieved embeddings and their targets.

    Parameters
    ----------
    retrieved : Tensor ``[B, K, D]`` or ``[B, T, D]``
    targets : Tensor of the same shape, or broadcastable against ``retrieved``.

    Returns
    -------
    Tensor
        Scalar mean of the element-wise squared differences.
    """
    residual = retrieved - targets
    return torch.mean(residual.pow(2))
def train_codebook(
    bank: LearnableMemoryBank1D,
    data_fn,
    *,
    n_steps: int = 100,
    lr: float = 1e-3,
    ortho_lambda: float = 0.1,
    B: int = 1,
    log_every: int = 10,
) -> List[Dict[str, float]]:
    """Train a learnable codebook with reconstruction loss + orthogonality reg.

    Each step:

    1. Call ``bank.allocate(B)`` and generate fresh data via ``data_fn()``
    2. Store data with discrete keys
    3. Retrieve all slots and pick the ones addressed by ``keys``
    4. Compute ``loss = MSE(retrieved[keys], values) + λ * ortho_loss(C)``
    5. Backprop and take one Adam step on ``bank.parameters()``

    Parameters
    ----------
    bank : LearnableMemoryBank1D
        Must have a learnable (unfrozen) codebook.
    data_fn : callable
        ``data_fn() -> (keys, values, alphas)`` returning tensors for one
        training step.  ``keys: [T]`` (integer slot indices in ``[0, K)``),
        ``values: [T, D]``, ``alphas: [T]``.
    n_steps : int
        Number of training steps.
    lr : float
        Learning rate for Adam.
    ortho_lambda : float
        Weight of orthogonality regularisation.
    B : int
        Batch size for membrane allocation.
    log_every : int
        Logging frequency.  Values ``<= 0`` disable periodic logging; the
        final step is still recorded.

    Returns
    -------
    list[dict]
        Metrics for each logged step (every ``log_every`` steps plus the
        final one): ``step``, ``total_loss``, ``recon_loss``,
        ``ortho_loss``, ``max_coherence``, ``mean_coherence``.
    """
    optimizer = torch.optim.Adam(bank.parameters(), lr=lr)
    history: List[Dict[str, float]] = []

    for step in range(n_steps):
        optimizer.zero_grad()

        # Fresh membrane each step
        bank.allocate(B)
        keys, values, alphas = data_fn()

        # Store → retrieve
        bank.store(keys, values, alphas)
        Y = bank.retrieve()  # [B, K, D]

        # Compare each stored value with the slot it was written to.  Using
        # ``Y`` positionally only works when ``keys == arange(K)``; gathering
        # by key also handles T != K and arbitrary key order.
        # NOTE(review): assumes ``keys`` are integer slot indices, as the
        # discrete write path implies — confirm against ``store_pairs_1d``.
        Y_at_keys = Y[:, keys.long()]  # [B, T, D]
        target = values.unsqueeze(0).expand(B, -1, -1)  # [B, T, D]

        recon = reconstruction_loss(Y_at_keys, target)
        ortho = bank.ortho_loss()
        loss = recon + ortho_lambda * ortho

        loss.backward()
        optimizer.step()

        is_last_step = step == n_steps - 1
        # Guard the modulo so ``log_every <= 0`` means "final step only"
        # instead of raising ZeroDivisionError.
        is_periodic = log_every > 0 and step % log_every == 0
        if is_periodic or is_last_step:
            with torch.no_grad():
                coh = bank.coherence()
            history.append(
                {
                    "step": step,
                    "total_loss": loss.item(),
                    "recon_loss": recon.item(),
                    "ortho_loss": ortho.item(),
                    "max_coherence": coh["max_abs_offdiag"],
                    "mean_coherence": coh["mean_abs_offdiag"],
                }
            )

    return history