"""Learnable code evolution for WrinkleBrane.
Direction 6: Makes the codebook ``C ∈ ℝ[L, K]`` a learnable parameter
shared between write and read paths, enabling end-to-end training with
reconstruction loss and orthogonality regularisation.
Key components
--------------
``LearnableCodebook``
``nn.Module`` wrapping ``C`` as a learnable parameter with on-the-fly
column normalisation and coherence tracking.
``orthogonality_loss``
Off-diagonal coherence penalty: ``||off_diag(C_n^T C_n)||_F^2``.
"""
from __future__ import annotations
from typing import Dict
import torch
from torch import nn, Tensor
from wrinklebrane.codes import (
hadamard_codes,
dct_codes,
gaussian_codes,
normalize_columns,
coherence_stats,
gram_matrix,
)
# ---------------------------------------------------------------------------
# Orthogonality loss
# ---------------------------------------------------------------------------
def orthogonality_loss(C: Tensor) -> Tensor:
"""Off-diagonal coherence penalty for code separation.
``loss = ||off_diag(C_n^T C_n)||_F^2``
where ``C_n`` is column-normalised ``C``. Only penalises cross-
correlation between code columns (off-diagonal Gram matrix elements),
without demanding unit diagonal — which is impossible when the
codebook is overcomplete (``K > L``).
This is consistent with the ``coherence_stats`` diagnostic: both
target the same quantity (off-diagonal magnitudes).
Parameters
----------
C : Tensor ``[L, K]``
Returns
-------
Tensor
Scalar loss (0 when all code columns are orthogonal).
"""
K = C.shape[1]
# Normalise columns (differentiable)
norms = C.norm(dim=0, keepdim=True).clamp_min(1e-8)
C_n = C / norms
G = C_n.T @ C_n # [K, K]
mask = ~torch.eye(K, dtype=torch.bool, device=G.device)
return G[mask].pow(2).sum()
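
# Illustrative example (a sketch, not exercised by the library): a square
# Hadamard codebook has orthogonal columns and incurs ~0 loss, whereas an
# overcomplete Gaussian codebook (K > L) is penalised.
#
#     loss_h = orthogonality_loss(hadamard_codes(64, 64))           # ~0
#     loss_g = orthogonality_loss(gaussian_codes(64, 128, seed=0))  # > 0
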
# ---------------------------------------------------------------------------
# LearnableCodebook
# ---------------------------------------------------------------------------
class LearnableCodebook(nn.Module):
"""Learnable codebook ``C ∈ ℝ[L, K]`` with unit-norm column output.
The raw parameter ``C_raw`` is stored as ``nn.Parameter``. Calling
the module returns column-normalised ``C`` (differentiable), ensuring
the write and read paths always use normalised codes.
Parameters
----------
L : int
Number of code layers.
K : int
Number of code columns (capacity).
init : str
Initialisation: ``"hadamard"``, ``"dct"``, ``"gaussian"``,
``"random"``, or ``"identity"`` (zero-padded eye).
seed : int
RNG seed for stochastic initialisations.
freeze : bool
If ``True``, ``C_raw`` is not learnable (``requires_grad=False``).
"""
def __init__(
self,
L: int,
K: int,
init: str = "hadamard",
seed: int = 0,
freeze: bool = False,
):
super().__init__()
self.L = L
self.K = K
C_init = _init_codebook(L, K, init, seed)
self.C_raw = nn.Parameter(C_init, requires_grad=not freeze)
def forward(self) -> Tensor:
"""Return column-normalised codebook ``[L, K]``."""
norms = self.C_raw.norm(dim=0, keepdim=True).clamp_min(1e-8)
return self.C_raw / norms
    def ortho_loss(self) -> Tensor:
        """Orthogonality regularisation loss (scalar)."""
        # ``orthogonality_loss`` normalises columns internally, so the raw
        # parameter can be passed directly.
        return orthogonality_loss(self.C_raw)
def coherence(self) -> Dict[str, float]:
"""Current coherence statistics (detached)."""
with torch.no_grad():
return coherence_stats(self.forward())
def gram(self) -> Tensor:
"""Return Gram matrix ``C_n^T C_n`` (differentiable)."""
C_n = self.forward()
return C_n.T @ C_n
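
# Usage sketch (hypothetical training step; ``recon_loss_fn`` and ``lam`` are
# illustrative placeholders, not part of this module):
#
#     codebook = LearnableCodebook(L=64, K=128, init="gaussian")
#     C = codebook()                       # [L, K], unit-norm columns
#     loss = recon_loss_fn(C) + lam * codebook.ortho_loss()
#     loss.backward()
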
def _init_codebook(L: int, K: int, init: str, seed: int = 0) -> Tensor:
"""Create initial codebook tensor."""
init = init.lower().strip()
if init == "hadamard":
return hadamard_codes(L, K)
if init == "dct":
return dct_codes(L, K)
if init == "gaussian":
return gaussian_codes(L, K, seed=seed)
if init == "random":
gen = torch.Generator().manual_seed(seed)
C = torch.randn(L, K, generator=gen)
return normalize_columns(C)
if init == "identity":
        # Zero-padded identity: perfect orthogonality if K ≤ L; when K > L,
        # the extra columns start at zero.
C = torch.zeros(L, K)
n = min(L, K)
C[:n, :n] = torch.eye(n)
return C
raise ValueError(f"Unknown init '{init}'")
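

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the WrinkleBrane
    # training pipeline): reduce the off-diagonal coherence of an overcomplete
    # Gaussian codebook by gradient descent on ``ortho_loss`` alone.
    torch.manual_seed(0)
    codebook = LearnableCodebook(L=32, K=48, init="gaussian", seed=0)
    opt = torch.optim.Adam(codebook.parameters(), lr=1e-2)

    print("initial coherence:", codebook.coherence())
    for _ in range(200):
        opt.zero_grad()
        loss = codebook.ortho_loss()
        loss.backward()
        opt.step()
    print("final coherence:  ", codebook.coherence())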