# Provenance: uploaded by WCNegentropy ("Upload 510 files", commit 3d7f6c5, verified).
"""Learnable code evolution for WrinkleBrane.
Direction 6: Makes the codebook ``C ∈ ℝ[L, K]`` a learnable parameter
shared between write and read paths, enabling end-to-end training with
reconstruction loss and orthogonality regularisation.
The core problem this solves: in the existing codebase, ``store_pairs_1d``
receives ``C`` as a function argument while ``Slicer1D`` stores a separate
copy. ``LearnableCodebook`` wraps ``C`` as a single ``nn.Parameter``; the
write and read paths call ``codebook()`` to get the current normalised
``C``, ensuring they always agree.
Key components
--------------
``LearnableCodebook``
    ``nn.Module`` wrapping ``C`` as a learnable parameter with on-the-fly
    column normalisation and coherence tracking.
``orthogonality_loss``
    Frobenius-norm penalty: ``||C^T C - I||_F^2``.
``LearnableMemoryBank1D``
    End-to-end differentiable membrane: learnable codebook shared by
    ``store`` (write) and ``retrieve`` (read).
``train_codebook``
    Training loop helper: reconstruction loss + orthogonality
    regularisation, with coherence tracking.
"""
from __future__ import annotations
from typing import Optional, List, Dict
import math
import torch
from torch import nn, Tensor
from wrinklebrane.codes import (
hadamard_codes,
dct_codes,
gaussian_codes,
normalize_columns,
coherence_stats,
gram_matrix,
)
from wrinklebrane.membrane_1d import (
MembraneBank1D,
store_pairs_1d,
Slicer1D,
ContinuousWriter1D,
ContinuousReader1D,
soft_code_weights_1d,
cosine_similarity_matrix,
)
# ---------------------------------------------------------------------------
# Orthogonality loss
# ---------------------------------------------------------------------------
def orthogonality_loss(C: Tensor) -> Tensor:
    """Squared Frobenius distance between the Gram matrix of ``C`` and ``I``.

    Every column of ``C`` is first rescaled to unit length (column norms
    are floored at ``1e-8`` so an all-zero column does not divide by
    zero), giving ``C_n``.  The penalty is then

    ``loss = ||C_n^T C_n - I_K||_F^2``

    Parameters
    ----------
    C : Tensor ``[L, K]``
        Codebook whose ``K`` columns are the code vectors.

    Returns
    -------
    Tensor
        Differentiable scalar; ``0`` when the normalised columns are
        mutually orthogonal.
    """
    n_codes = C.shape[1]

    # Differentiable column normalisation.
    col_norms = C.norm(dim=0, keepdim=True).clamp(min=1e-8)
    unit_cols = C / col_norms

    gram = torch.matmul(unit_cols.T, unit_cols)  # [K, K]
    identity = torch.eye(n_codes, device=C.device, dtype=C.dtype)

    deviation = gram - identity
    return torch.square(deviation).sum()
# ---------------------------------------------------------------------------
# LearnableCodebook
# ---------------------------------------------------------------------------
class LearnableCodebook(nn.Module):
    """Codebook ``C ∈ ℝ[L, K]`` held as a parameter, emitted with unit columns.

    The unnormalised values live in ``C_raw`` (an ``nn.Parameter``).
    Calling the module divides each column by its norm, so every consumer
    that goes through ``codebook()`` sees the same normalised codes, and
    gradients flow back to ``C_raw`` through the normalisation.

    Parameters
    ----------
    L : int
        Number of code layers (rows of ``C``).
    K : int
        Number of code columns (capacity).
    init : str
        Initialisation scheme handed to ``_init_codebook``:
        ``"hadamard"``, ``"dct"``, ``"gaussian"``, ``"random"``, or
        ``"identity"`` (zero-padded eye).
    seed : int
        RNG seed for the stochastic initialisations.
    freeze : bool
        If ``True``, ``C_raw`` is created with ``requires_grad=False``.
    """

    def __init__(
        self,
        L: int,
        K: int,
        init: str = "hadamard",
        seed: int = 0,
        freeze: bool = False,
    ):
        super().__init__()
        self.L, self.K = L, K
        self.C_raw = nn.Parameter(
            _init_codebook(L, K, init, seed),
            requires_grad=not freeze,
        )

    def forward(self) -> Tensor:
        """Return the column-normalised codebook ``[L, K]``."""
        raw = self.C_raw
        # Floor the norms so an all-zero column yields zeros, not NaNs.
        col_norms = raw.norm(dim=0, keepdim=True).clamp(min=1e-8)
        return raw / col_norms

    def ortho_loss(self) -> Tensor:
        """Orthogonality penalty of the raw codebook (differentiable scalar)."""
        return orthogonality_loss(self.C_raw)

    def coherence(self) -> Dict[str, float]:
        """Coherence statistics of the normalised codebook, computed without grad."""
        with torch.no_grad():
            unit_codes = self.forward()
            return coherence_stats(unit_codes)

    def gram(self) -> Tensor:
        """Return the Gram matrix ``C_n^T C_n`` ``[K, K]`` (differentiable)."""
        unit_codes = self.forward()
        return torch.matmul(unit_codes.T, unit_codes)
def _init_codebook(L: int, K: int, init: str, seed: int = 0) -> Tensor:
    """Build the initial ``[L, K]`` codebook for the named scheme.

    ``init`` is matched case-insensitively with surrounding whitespace
    ignored.  ``"hadamard"``, ``"dct"`` and ``"gaussian"`` delegate to the
    corresponding ``wrinklebrane.codes`` helpers; ``"random"`` draws
    standard-normal entries from a generator seeded with ``seed`` and
    passes them through ``normalize_columns``; ``"identity"`` is an
    identity matrix zero-padded to ``[L, K]``.

    Raises
    ------
    ValueError
        If ``init`` names none of the schemes above.
    """

    def _seeded_random() -> Tensor:
        rng = torch.Generator().manual_seed(seed)
        return normalize_columns(torch.randn(L, K, generator=rng))

    def _padded_identity() -> Tensor:
        # Zero-padded identity: perfect orthogonality if K ≤ L
        padded = torch.zeros(L, K)
        rank = min(L, K)
        padded[:rank, :rank] = torch.eye(rank)
        return padded

    # Builders are invoked lazily so only the selected scheme runs.
    builders = {
        "hadamard": lambda: hadamard_codes(L, K),
        "dct": lambda: dct_codes(L, K),
        "gaussian": lambda: gaussian_codes(L, K, seed=seed),
        "random": _seeded_random,
        "identity": _padded_identity,
    }

    init = init.lower().strip()
    build = builders.get(init)
    if build is None:
        raise ValueError(f"Unknown init '{init}'")
    return build()
# ---------------------------------------------------------------------------
# LearnableMemoryBank1D
# ---------------------------------------------------------------------------
class LearnableMemoryBank1D(nn.Module):
    """End-to-end differentiable 1D membrane with a shared learnable codebook.

    Write and read paths both call ``self.codebook()`` to get the current
    normalised ``C``, ensuring consistency.

    Parameters
    ----------
    L : int
        Number of code layers.
    K : int
        Number of code columns (capacity).
    D : int
        Embedding dimension.
    init : str
        Codebook initialisation (``"hadamard"``, ``"dct"``, etc.).
    freeze_codes : bool
        If ``True``, codebook is fixed (non-learnable).
    device, dtype : standard.
        Forwarded to ``MembraneBank1D`` and also applied to the codebook,
        so codes and membrane share device and precision.
    seed : int
        RNG seed forwarded to ``LearnableCodebook`` for the stochastic
        initialisations (``"gaussian"``, ``"random"``).
    """

    def __init__(
        self,
        L: int,
        K: int,
        D: int,
        init: str = "hadamard",
        freeze_codes: bool = False,
        device: Optional[torch.device | str] = None,
        dtype: torch.dtype = torch.float32,
        seed: int = 0,
    ):
        super().__init__()
        self.L = L
        self.K = K
        self.D = D
        self.codebook = LearnableCodebook(
            L, K, init=init, seed=seed, freeze=freeze_codes
        )

        # The codebook is created with PyTorch defaults; move it to the
        # device/dtype requested for the membrane so the einsum/matmul calls
        # in the read and write paths never mix devices or precisions.
        # ``nn.Module.to`` only accepts floating-point or complex dtypes, so
        # any other dtype is left to the membrane alone.
        placement = {}
        if device is not None:
            placement["device"] = device
        if dtype is not None and (dtype.is_floating_point or dtype.is_complex):
            placement["dtype"] = dtype
        if placement:
            self.codebook.to(**placement)

        self.bank = MembraneBank1D(L=L, D=D, device=device, dtype=dtype)

    # ---- helpers ----------------------------------------------------------
    def _get_C(self) -> Tensor:
        """Current normalised codebook ``[L, K]`` (differentiable)."""
        return self.codebook()

    # ---- allocation / reset -----------------------------------------------
    def allocate(self, B: int) -> None:
        """Delegate to ``self.bank.allocate(B)``."""
        self.bank.allocate(B)

    def reset(self, B: Optional[int] = None) -> None:
        """Delegate to ``self.bank.reset(B)``."""
        self.bank.reset(B)

    # ---- write ------------------------------------------------------------
    def store(
        self,
        keys: Tensor,
        values: Tensor,
        alphas: Tensor,
    ) -> None:
        """Discrete write using the shared learnable codebook.

        Reads the current membrane, passes it with the normalised codebook
        to ``store_pairs_1d`` and assigns the result back to ``bank.M``.

        Parameters
        ----------
        keys : Tensor ``[T]``
        values : Tensor ``[T, D]``
        alphas : Tensor ``[T]``
        """
        C = self._get_C()
        M = self.bank.read()
        self.bank.M = store_pairs_1d(M, C, keys, values, alphas)

    def store_continuous(
        self,
        queries: Tensor,
        values: Tensor,
        alphas: Tensor,
        projection: Tensor,
        temperature: float | Tensor = 1.0,
    ) -> None:
        """Continuous write using the shared learnable codebook.

        Each query is turned into soft weights over the ``K`` codes by
        ``soft_code_weights_1d``; the resulting blended code, scaled by
        its ``alpha``, is outer-multiplied with the value and the sum over
        ``T`` is added to every batch entry of the membrane.

        Parameters
        ----------
        queries : Tensor ``[T, D_query]``
        values : Tensor ``[T, D]``
        alphas : Tensor ``[T]``
        projection : Tensor ``[D_query, K]``
        temperature : float or Tensor
        """
        C = self._get_C()
        M = self.bank.read()
        weights = soft_code_weights_1d(queries, projection, temperature)
        codes = C @ weights.T  # [L, T]
        codes = codes * alphas.unsqueeze(0)
        delta = torch.einsum("lt,td->ld", codes, values)  # [L, D]
        # unsqueeze(0) broadcasts the same update over the batch axis.
        self.bank.M = M + delta.unsqueeze(0)

    # ---- read -------------------------------------------------------------
    def retrieve(self) -> Tensor:
        """Discrete read using the shared learnable codebook.

        Projects the membrane onto every code column.

        Returns ``[B, K, D]``.
        """
        C = self._get_C()
        M = self.bank.read()
        return torch.einsum("bld,lk->bkd", M, C)

    def retrieve_continuous(
        self,
        queries: Tensor,
        projection: Tensor,
        temperature: float | Tensor = 1.0,
    ) -> Tensor:
        """Continuous read using the shared learnable codebook.

        Performs the discrete read for all ``K`` slots, then mixes the
        slots with the soft code weights of each query.

        Parameters
        ----------
        queries : Tensor ``[T, D_query]``
        projection : Tensor ``[D_query, K]``
        temperature : float or Tensor

        Returns ``[B, T, D]``
        """
        Y_full = self.retrieve()  # [B, K, D]
        weights = soft_code_weights_1d(queries, projection, temperature)
        return torch.einsum("bkd,tk->btd", Y_full, weights)

    # ---- diagnostics ------------------------------------------------------
    def coherence(self) -> Dict[str, float]:
        """Coherence statistics of the current codebook."""
        return self.codebook.coherence()

    def ortho_loss(self) -> Tensor:
        """Orthogonality regularisation loss of the current codebook."""
        return self.codebook.ortho_loss()
# ---------------------------------------------------------------------------
# Training utilities
# ---------------------------------------------------------------------------
def reconstruction_loss(
    retrieved: Tensor,
    targets: Tensor,
) -> Tensor:
    """Mean squared error between retrieved embeddings and their targets.

    Parameters
    ----------
    retrieved : Tensor ``[B, K, D]`` or ``[B, T, D]``
    targets : Tensor matching shape (or broadcastable).

    Returns
    -------
    Tensor
        Scalar mean of the element-wise squared difference.
    """
    residual = retrieved - targets
    return torch.square(residual).mean()
def train_codebook(
    bank: LearnableMemoryBank1D,
    data_fn,
    *,
    n_steps: int = 100,
    lr: float = 1e-3,
    ortho_lambda: float = 0.1,
    B: int = 1,
    log_every: int = 10,
) -> List[Dict[str, float]]:
    """Train a learnable codebook with reconstruction loss + orthogonality reg.

    Each step:

    1. Allocate the membrane for batch size ``B``, generate fresh data via
       ``data_fn()``
    2. Store data with discrete keys
    3. Retrieve all ``K`` slots and pick out the slots addressed by ``keys``
    4. Compute ``loss = MSE(retrieved[keys], values) + λ * ortho_loss(C)``
    5. Backprop and update the bank's parameters with Adam

    Parameters
    ----------
    bank : LearnableMemoryBank1D
        Must have a learnable (unfrozen) codebook.
    data_fn : callable
        ``data_fn() -> (keys, values, alphas)`` returning tensors for one
        training step. ``keys: [T]`` (integer slot indices),
        ``values: [T, D]``, ``alphas: [T]``.
    n_steps : int
        Number of training steps.
    lr : float
        Learning rate for Adam.
    ortho_lambda : float
        Weight of orthogonality regularisation.
    B : int
        Batch size for membrane allocation.
    log_every : int
        Logging frequency. ``0`` disables periodic logging; the final
        step is always recorded.

    Returns
    -------
    list[dict]
        Metrics for each logged step: ``step``, ``total_loss``,
        ``recon_loss``, ``ortho_loss``, ``max_coherence``,
        ``mean_coherence``.
    """
    optimizer = torch.optim.Adam(bank.parameters(), lr=lr)
    history: List[Dict[str, float]] = []
    final_step = n_steps - 1

    for step in range(n_steps):
        optimizer.zero_grad()

        # Fresh membrane each step
        bank.allocate(B)
        keys, values, alphas = data_fn()

        # Store → retrieve
        bank.store(keys, values, alphas)
        Y = bank.retrieve()  # [B, K, D]

        # Compare each value with the slot it was written to. Without this
        # gather the loss is only meaningful when keys == arange(K); any
        # other key set either mismatches shapes or pairs values with the
        # wrong slots.
        key_idx = torch.as_tensor(keys, device=Y.device).long()
        Y_at_keys = Y[:, key_idx]  # [B, T, D]
        target = values.unsqueeze(0).expand(B, -1, -1)  # [B, T, D]

        # NOTE(review): the target is the raw ``values``; whether it should
        # be scaled by ``alphas`` depends on ``store_pairs_1d`` — confirm.
        recon = reconstruction_loss(Y_at_keys, target)
        ortho = bank.ortho_loss()
        loss = recon + ortho_lambda * ortho
        loss.backward()
        optimizer.step()

        # ``log_every == 0`` would otherwise raise ZeroDivisionError.
        periodic = log_every != 0 and step % log_every == 0
        if periodic or step == final_step:
            with torch.no_grad():
                coh = bank.coherence()
            history.append(
                {
                    "step": step,
                    "total_loss": float(loss),
                    "recon_loss": float(recon),
                    "ortho_loss": float(ortho),
                    "max_coherence": coh["max_abs_offdiag"],
                    "mean_coherence": coh["mean_abs_offdiag"],
                }
            )

    return history