# WrinkleBrane / Wrinkle /src /wrinklebrane /standalone_model.py
# WCNegentropy's picture
# Upload 510 files
# 3d7f6c5 verified
# (Page header captured together with the file; kept as comments so the
#  module parses.)
"""Standalone WrinkleBrane sequence model.

Direction 9: assembles continuous addressing (Dir 1), 1D membranes (Dir 4),
multi-head banks (Dir 5), learnable codebooks (Dir 6), and iterative
refinement (Dir 7) into a complete trainable language model.

The WrinkleBrane layer takes the place of a full transformer block
(self-attention + FFN). The key innovation is *parallel causal membrane
reads* via a cumulative sum over per-position membrane deltas, enabling
teacher-forced training while preserving causality.

Key components
--------------
``WrinkleBraneConfig``
    Dataclass holding all hyperparameters.
``PositionalEncoding``
    Sinusoidal positional encoding (shared with baseline transformer).
``GatedFFN``
    Feed-forward block with zero-initialised gate (Dir 7 insight).
``causal_membrane_attention``
    Parallel (cumsum-based) causal membrane read for one head.
``WrinkleBraneLayer``
    Single layer: multi-head causal membrane attention + gated (or
    standard) FFN.
``WrinkleBraneModel``
    Full model: embedding + positional encoding + N layers + output head.
``RotaryEmbedding`` / ``WrinkleBraneLayerRoPE`` / ``WrinkleBraneModelRoPE``
    Variants that apply rotary position embeddings to values and queries
    at every layer instead of one additive sinusoidal encoding.
"""
from __future__ import annotations
import math
from dataclasses import dataclass, field
from typing import Optional, List, Dict, Tuple
import torch
from torch import nn, Tensor
from wrinklebrane.learnable_codes import LearnableCodebook, orthogonality_loss
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
@dataclass
class WrinkleBraneConfig:
    """Configuration for a standalone WrinkleBrane model.

    Groups hyperparameters from all prerequisite directions:

    - Dirs 1/4/5: membrane dimensions and addressing
    - Dir 6: learnable codes
    - Dir 7: gated FFN
    - Dir 3/8: no activation, optional importance weighting
    """

    # Vocabulary & embedding
    vocab_size: int = 256
    d_model: int = 128
    max_seq_len: int = 256  # sinusoidal PE table size / initial RoPE cache size

    # WrinkleBrane architecture
    n_layers: int = 4
    n_heads: int = 4
    L: int = 32  # code layers per head
    K: int = 64  # codes per head (capacity); token t writes to key t % K
    code_init: str = "hadamard"
    learnable_codes: bool = True  # False builds the codebooks with freeze=True

    # Continuous addressing (Dir 1)
    temperature: float = 0.05  # initial value of each head's learnable temperature

    # FFN (Dir 7: ResidualGated architecture)
    ffn_expansion: int = 4
    use_gated_ffn: bool = True

    # Regularization
    dropout: float = 0.1
    ortho_lambda: float = 0.01

    # Per-step membrane decay. Used by both the parallel (cumsum) path and the
    # sequential (RNN) path so the two see the same dynamics; 1.0 = no decay.
    persistence_lambda: float = 0.99

    # Optional
    weight_tying: bool = True

    @property
    def d_head(self) -> int:
        """Per-head embedding dimension, ``d_model // n_heads``.

        Raises
        ------
        ValueError
            If ``d_model`` is not divisible by ``n_heads``.
        """
        # Explicit raise (not ``assert``) so the check survives ``python -O``.
        if self.d_model % self.n_heads != 0:
            raise ValueError(
                f"d_model={self.d_model} must be divisible by n_heads={self.n_heads}"
            )
        return self.d_model // self.n_heads
# ---------------------------------------------------------------------------
# Positional Encoding
# ---------------------------------------------------------------------------
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding.

    Adds position-dependent sinusoidal signals to the input embeddings.
    Compatible with both WrinkleBrane and transformer baselines.

    Parameters
    ----------
    d_model : int
        Embedding dimension. May be odd: the last channel then carries a
        sine term with no cosine partner.
    max_len : int
        Number of positions stored in the ``pe`` buffer.
    dropout : float
        Dropout probability applied after adding the encoding.
    """

    def __init__(self, d_model: int, max_len: int = 512, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float)
            * (-math.log(10000.0) / d_model)
        )
        # Even channels: ceil(d_model / 2) sine columns.
        pe[:, 0::2] = torch.sin(position * div_term)
        # Odd channels: floor(d_model / 2) cosine columns. Slicing div_term
        # keeps odd d_model from mismatching shapes (a no-op when even).
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer("pe", pe.unsqueeze(0))  # [1, max_len, d_model]

    def forward(self, x: Tensor) -> Tensor:
        """Add positional encoding to ``x: [B, T, D]`` and apply dropout.

        ``T`` must not exceed ``max_len``; otherwise the sliced table is
        shorter than ``x`` and the addition fails.
        """
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)
# ---------------------------------------------------------------------------
# Gated FFN (Dir 7: ResidualGatedProcessor adapted as FFN)
# ---------------------------------------------------------------------------
class GatedFFN(nn.Module):
    """Feed-forward block whose MLP contribution is scaled by a learned gate.

    Adapted from Direction 7, where ``ResidualGatedProcessor`` beat all
    alternatives (+14.2 dB). Because the gate starts at zero, the block is
    the identity map at initialisation and learns how much MLP output to add:

        ``f(x) = x + gate * MLP(x)``

    Parameters
    ----------
    d_model : int
        Input/output feature size.
    expansion : int
        Hidden width multiplier for the MLP.
    """

    def __init__(self, d_model: int, expansion: int = 4):
        super().__init__()
        hidden = d_model * expansion
        self.mlp = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Linear(hidden, d_model),
        )
        linears = [layer for layer in self.mlp if isinstance(layer, nn.Linear)]
        for linear in linears:
            # Xavier-uniform weights, zero biases.
            nn.init.xavier_uniform_(linear.weight)
            if linear.bias is not None:
                nn.init.zeros_(linear.bias)
        # Single scalar gate starting at zero -> identity at initialisation.
        self.gate = nn.Parameter(torch.zeros(1))

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x + gate * MLP(x)`` (same shape as ``x``)."""
        mlp_out = self.mlp(x)
        return x + self.gate * mlp_out
class StandardFFN(nn.Module):
    """Plain transformer FFN, ``Linear -> GELU -> Linear``.

    Used in place of ``GatedFFN`` when gating is disabled. Unlike the gated
    block it returns only the MLP output; it has no built-in skip path.

    Parameters
    ----------
    d_model : int
        Input/output feature size.
    expansion : int
        Hidden width multiplier.
    """

    def __init__(self, d_model: int, expansion: int = 4):
        super().__init__()
        width = d_model * expansion
        self.net = nn.Sequential(
            nn.Linear(d_model, width),
            nn.GELU(),
            nn.Linear(width, d_model),
        )
        for module in self.net:
            if not isinstance(module, nn.Linear):
                continue
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the MLP to ``x`` (last dimension ``d_model``)."""
        return self.net(x)
# ---------------------------------------------------------------------------
# Causal Membrane Attention (core innovation)
# ---------------------------------------------------------------------------
def causal_membrane_attention(
V_h: Tensor,
C_h: Tensor,
Q_h: Tensor,
P_h: Tensor,
temperature: Tensor,
persistence_lambda: float = 1.0,
) -> Tensor:
"""Parallel causal WrinkleBrane attention for one head.
This is the key innovation of Direction 9: using cumulative sum over
per-position membrane deltas to compute causal readouts in parallel.
The computation for each position t:
1. Write: ``M_t = Σ_{i≤t} λ^(t-i) · C[:, key_i] ⊗ V[i]`` (causal prefix with decay)
2. Read: ``Y_t[k] = einsum(M_t, C[:, k])`` (all K slots)
3. Blend: ``out_t = Σ_k softmax(Q_t @ P)[k] · Y_t[k]`` (continuous)
When ``persistence_lambda < 1.0``, exponential decay is applied via a
rescaled cumulative sum so that training (parallel) and inference
(sequential) see identical dynamics::
M_t = Σ_{i≤t} λ^(t-i) · δ_i
= λ^t · Σ_{i≤t} λ^(-i) · δ_i
Pre-multiply each delta by ``λ^(-i)``, take the cumsum, then
post-multiply each result by ``λ^t``. Fully parallel, same result
as the sequential recurrence.
Parameters
----------
V_h : Tensor ``[B, T, d_head]``
Values to store (projected from input).
C_h : Tensor ``[L, K]``
Normalised codebook for this head.
Q_h : Tensor ``[B, T, d_head]``
Queries for continuous readout.
P_h : Tensor ``[d_head, K]``
Learned read projection (query → code weights).
temperature : Tensor
Softmax temperature (scalar, learnable).
persistence_lambda : float
Exponential decay factor applied per timestep. 1.0 means no
decay (backward compatible). Values like 0.99 match the
sequential forward_step decay.
Returns
-------
Tensor ``[B, T, d_head]``
Causal readout per position.
"""
B, T, d = V_h.shape
L, K = C_h.shape
# Discrete key assignment: token t → key t % K
keys = torch.arange(T, device=V_h.device) % K
code_vecs = C_h[:, keys] # [L, T]
# Per-position membrane deltas: delta_t[l, d] = C[l, key_t] * V[b, t, d]
deltas = torch.einsum("lt,btd->btld", code_vecs, V_h) # [B, T, L, d]
if persistence_lambda < 1.0:
# Parallel exponential decay via rescaled cumsum:
# M_t = Σ_{i≤t} λ^(t-i) · δ_i = λ^t · Σ_{i≤t} λ^(-i) · δ_i
t_idx = torch.arange(T, device=V_h.device, dtype=V_h.dtype)
log_lam = math.log(persistence_lambda)
# Pre-multiply: delta_i * λ^(-i)
inv_decay = torch.exp(-log_lam * t_idx) # [T]
scaled = deltas * inv_decay[None, :, None, None]
M_causal = torch.cumsum(scaled, dim=1)
# Post-multiply: M_t * λ^t
decay = torch.exp(log_lam * t_idx) # [T]
M_causal = M_causal * decay[None, :, None, None]
else:
# No decay — plain causal prefix sum
M_causal = torch.cumsum(deltas, dim=1) # [B, T, L, d]
# Read from each M_t: Y_t[k] = Σ_l M_t[l] * C[l, k]
Y_all = torch.einsum("btld,lk->btkd", M_causal, C_h) # [B, T, K, d]
# Continuous soft blend (Dir 1: write-discrete / read-continuous)
logits = torch.einsum("btd,dk->btk", Q_h, P_h) # [B, T, K]
weights = torch.softmax(logits / temperature, dim=-1) # [B, T, K]
# Weighted readout per position
output = torch.einsum("btk,btkd->btd", weights, Y_all) # [B, T, d]
return output
# ---------------------------------------------------------------------------
# WrinkleBrane Layer
# ---------------------------------------------------------------------------
class WrinkleBraneLayer(nn.Module):
    """Single WrinkleBrane layer, standing in for a transformer block.

    Pre-norm architecture: each sub-block normalises its *input* and adds its
    (dropped-out) output back onto the residual stream.

    1. ``x = x + Dropout(W_o(MembraneAttention(norm1(x))))``, where the
       multi-head causal membrane attention runs in parallel over time via
       ``causal_membrane_attention`` (cumsum).
    2. ``x = x + Dropout(FFN(norm2(x)))``, with ``GatedFFN`` (Dir 7,
       zero-init gate) or ``StandardFFN`` depending on
       ``config.use_gated_ffn``.

    ``forward_step`` runs the same computation one token at a time with an
    explicit per-head membrane state (RNN mode).

    Parameters
    ----------
    config : WrinkleBraneConfig
    """

    def __init__(self, config: WrinkleBraneConfig):
        super().__init__()
        self.config = config
        D = config.d_model
        d_head = config.d_head
        N = config.n_heads

        # Value and query projections (all heads at once; split in forward).
        self.W_v = nn.Linear(D, D, bias=False)
        self.W_q = nn.Linear(D, D, bias=False)

        # Per-head learnable codebooks (Dir 6)
        self.codebooks = nn.ModuleList([
            LearnableCodebook(
                config.L, config.K,
                init=config.code_init,
                freeze=not config.learnable_codes,
            )
            for _ in range(N)
        ])

        # Per-head read projections: d_head → K soft weights (Dir 1)
        self.read_projections = nn.ParameterList([
            nn.Parameter(torch.empty(d_head, config.K))
            for _ in range(N)
        ])
        for P in self.read_projections:
            nn.init.xavier_uniform_(P)

        # Per-head learnable temperature (Dir 1: sweet spot ~ 0.05)
        self.temperatures = nn.ParameterList([
            nn.Parameter(torch.tensor(config.temperature))
            for _ in range(N)
        ])

        # Output projection over the concatenated head outputs
        self.W_o = nn.Linear(D, D, bias=False)

        # Pre-norm LayerNorms for the attention and FFN sub-blocks
        self.norm1 = nn.LayerNorm(D)
        self.norm2 = nn.LayerNorm(D)

        # FFN block
        if config.use_gated_ffn:
            self.ffn = GatedFFN(D, config.ffn_expansion)
        else:
            self.ffn = StandardFFN(D, config.ffn_expansion)

        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: Tensor) -> Tensor:
        """Process a full sequence through membrane attention + FFN.

        Parameters
        ----------
        x : Tensor ``[B, T, D]``

        Returns
        -------
        Tensor ``[B, T, D]``
        """
        B, T, D = x.shape
        N = self.config.n_heads
        d_head = self.config.d_head

        # === Membrane Attention Block ===
        residual = x
        x_normed = self.norm1(x)

        # Project values and queries
        V = self.W_v(x_normed)  # [B, T, D]
        Q = self.W_q(x_normed)  # [B, T, D]

        # Split into heads
        V_heads = V.view(B, T, N, d_head).transpose(1, 2)  # [B, N, T, d_head]
        Q_heads = Q.view(B, T, N, d_head).transpose(1, 2)  # [B, N, T, d_head]

        # Per-head causal membrane read
        head_outputs = []
        for h in range(N):
            C_h = self.codebooks[h]()  # [L, K]
            out_h = causal_membrane_attention(
                V_h=V_heads[:, h],  # [B, T, d_head]
                C_h=C_h,  # [L, K]
                Q_h=Q_heads[:, h],  # [B, T, d_head]
                P_h=self.read_projections[h],  # [d_head, K]
                temperature=self.temperatures[h],
                persistence_lambda=self.config.persistence_lambda,
            )
            head_outputs.append(out_h)

        # Concatenate heads + output projection
        out = torch.cat(head_outputs, dim=-1)  # [B, T, D]
        out = self.W_o(out)
        out = self.dropout(out)
        x = residual + out

        # === FFN Block ===
        # NOTE(review): GatedFFN already returns its input plus the gated MLP
        # output, so with use_gated_ffn=True this adds norm2(x) itself to the
        # residual stream (x + norm2(x) + gate * MLP(norm2(x))) and the block
        # is not an identity at init. Kept as-is because trained checkpoints
        # depend on it; confirm intent.
        residual = x
        x = residual + self.dropout(self.ffn(self.norm2(x)))
        return x

    def forward_step(
        self,
        x_t: Tensor,
        membrane_states: List[Tensor],
        step: int,
    ) -> Tuple[Tensor, List[Tensor]]:
        """Process a single token (sequential / RNN mode).

        Mirrors ``forward`` for one position: write the token's delta into
        each head's membrane, read all K slots, blend them with the softmax
        weights, then decay the membrane by ``persistence_lambda`` for the
        next step.

        Parameters
        ----------
        x_t : Tensor ``[B, D]``
            Single token embedding.
        membrane_states : list of Tensor ``[B, L, d_head]``
            Per-head membrane states from the previous step.
        step : int
            Absolute timestep; selects the write key ``step % K``.

        Returns
        -------
        Tensor ``[B, D]``
            Processed token embedding.
        list of Tensor ``[B, L, d_head]``
            Updated membrane states (new list; input list is not modified).
        """
        B, D = x_t.shape
        N = self.config.n_heads
        d_head = self.config.d_head
        # Write key is shared by all heads for this step.
        key = step % self.config.K

        # === Membrane Attention ===
        residual = x_t
        x_normed = self.norm1(x_t)

        V = self.W_v(x_normed)  # [B, D]
        Q = self.W_q(x_normed)  # [B, D]
        V_heads = V.view(B, N, d_head)  # [B, N, d_head]
        Q_heads = Q.view(B, N, d_head)  # [B, N, d_head]

        new_states = []
        head_outputs = []
        for h in range(N):
            C_h = self.codebooks[h]()  # [L, K]
            v_h = V_heads[:, h]  # [B, d_head]
            q_h = Q_heads[:, h]  # [B, d_head]
            M_h = membrane_states[h]  # [B, L, d_head]

            # Write: M += C[:, key] ⊗ v
            code_vec = C_h[:, key]  # [L]
            delta = torch.einsum("l,bd->bld", code_vec, v_h)
            M_h = M_h + delta

            # Read: Y = einsum(M, C) → [B, K, d_head]
            Y = torch.einsum("bld,lk->bkd", M_h, C_h)

            # Continuous blend
            logits = torch.einsum("bd,dk->bk", q_h, self.read_projections[h])
            weights = torch.softmax(
                logits / self.temperatures[h], dim=-1
            )  # [B, K]
            out_h = torch.einsum("bk,bkd->bd", weights, Y)  # [B, d_head]

            # Persistence decay after the read (same order as the parallel path)
            M_h = M_h * self.config.persistence_lambda

            new_states.append(M_h)
            head_outputs.append(out_h)

        out = torch.cat(head_outputs, dim=-1)  # [B, D]
        out = self.W_o(out)
        out = self.dropout(out)
        x_t = residual + out

        # === FFN === (see NOTE(review) in ``forward`` about GatedFFN's skip)
        residual = x_t
        x_t = residual + self.dropout(self.ffn(self.norm2(x_t)))

        return x_t, new_states

    def init_membrane_states(self, B: int) -> List[Tensor]:
        """Create zero membrane states ``[B, L, d_head]``, one per head.

        Placed on the same device/dtype as ``W_v``'s weight.
        """
        device = self.W_v.weight.device
        dtype = self.W_v.weight.dtype
        return [
            torch.zeros(B, self.config.L, self.config.d_head,
                        device=device, dtype=dtype)
            for _ in range(self.config.n_heads)
        ]
# ---------------------------------------------------------------------------
# Full Model
# ---------------------------------------------------------------------------
class WrinkleBraneModel(nn.Module):
    """Complete WrinkleBrane language model.

    Architecture::

        token_embedding -> positional_encoding -> N x WrinkleBraneLayer
            -> output_norm -> output_head

    Supports two forward modes:

    - ``forward()``: parallel (training), processes full sequences.
    - ``forward_sequential()``: RNN (inference), token-by-token. It can resume
      from previously returned states by passing ``start_pos``.

    Parameters
    ----------
    config : WrinkleBraneConfig
    """

    def __init__(self, config: WrinkleBraneConfig):
        super().__init__()
        self.config = config

        # Token embedding
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)

        # Positional encoding
        self.pos_encoding = PositionalEncoding(
            config.d_model, config.max_seq_len, dropout=config.dropout,
        )

        # WrinkleBrane layers
        self.layers = nn.ModuleList([
            WrinkleBraneLayer(config) for _ in range(config.n_layers)
        ])

        # Output
        self.output_norm = nn.LayerNorm(config.d_model)
        self.output_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Weight tying (Dir 4: memory efficient): the head shares the
        # embedding matrix.
        if config.weight_tying:
            self.output_head.weight = self.embedding.weight

        # Init weights
        self._init_weights()

    def _init_weights(self) -> None:
        """Re-initialise the token embedding with N(0, 0.02).

        With ``weight_tying`` the output head shares this tensor. Otherwise
        the head keeps ``nn.Linear``'s default initialisation.
        """
        nn.init.normal_(self.embedding.weight, std=0.02)

    def forward(self, token_ids: Tensor) -> Tensor:
        """Parallel forward pass (for training).

        Parameters
        ----------
        token_ids : Tensor ``[B, T]``
            Long tensor of token indices.

        Returns
        -------
        Tensor ``[B, T, vocab_size]``
            Logits for next-token prediction.
        """
        # Embed + position
        x = self.embedding(token_ids) * math.sqrt(self.config.d_model)
        x = self.pos_encoding(x)  # [B, T, D]

        # Process through WrinkleBrane layers
        for layer in self.layers:
            x = layer(x)

        # Output projection
        x = self.output_norm(x)
        logits = self.output_head(x)  # [B, T, vocab_size]
        return logits

    def forward_sequential(
        self,
        token_ids: Tensor,
        states: Optional[List[List[Tensor]]] = None,
        start_pos: int = 0,
    ) -> Tuple[Tensor, List[List[Tensor]]]:
        """Sequential (RNN) forward pass.

        Processes tokens one at a time, maintaining membrane states.
        Useful for autoregressive generation with fixed memory.

        Parameters
        ----------
        token_ids : Tensor ``[B, T]``
            Token indices.
        states : list of list of Tensor, optional
            Per-layer, per-head membrane states. If None, initialised to
            zeros. The caller's outer list is not modified.
        start_pos : int
            Absolute position of ``token_ids[:, 0]``. When resuming from
            returned ``states``, pass the number of tokens already consumed.
            Otherwise the positional encoding and the membrane write key
            (``pos % K``) restart at 0. Positions must stay below
            ``max_seq_len`` (size of the positional-encoding table).

        Returns
        -------
        Tensor ``[B, T, vocab_size]``
            Logits.
        list of list of Tensor
            Updated membrane states.
        """
        B, T = token_ids.shape

        if states is None:
            states = [layer.init_membrane_states(B) for layer in self.layers]
        else:
            # Shallow copy so per-layer updates don't rebind the caller's list.
            states = list(states)

        scale = math.sqrt(self.config.d_model)
        outputs = []
        for t in range(T):
            pos = start_pos + t
            # Embed single token and add positional encoding for its position
            x_t = self.embedding(token_ids[:, t]) * scale
            x_t = x_t + self.pos_encoding.pe[:, pos]
            x_t = self.pos_encoding.dropout(x_t)

            # Process through layers
            for i, layer in enumerate(self.layers):
                x_t, states[i] = layer.forward_step(x_t, states[i], pos)

            outputs.append(x_t)

        # Stack outputs
        x = torch.stack(outputs, dim=1)  # [B, T, D]
        x = self.output_norm(x)
        logits = self.output_head(x)
        return logits, states

    def ortho_loss(self) -> Tensor:
        """Total orthogonality regularisation across all codebooks.

        Returns
        -------
        Tensor
            Sum of each codebook's ``ortho_loss()`` (scalar).
        """
        total = torch.tensor(0.0, device=self.embedding.weight.device)
        for layer in self.layers:
            for codebook in layer.codebooks:
                total = total + codebook.ortho_loss()
        return total

    def count_parameters(self) -> Dict[str, int]:
        """Count parameters by component.

        Layer parameters whose name contains ``"codebook"`` are reported
        under ``"codebooks"``; all other layer parameters under ``"layers"``.
        A tied output head counts as 0 (it shares the embedding weight).
        ``"total"`` counts every unique parameter of the model.
        """
        counts = {
            "embedding": sum(p.numel() for p in self.embedding.parameters()),
            "pos_encoding": 0,  # buffer, not parameter
            "output_head": 0 if self.config.weight_tying else sum(
                p.numel() for p in self.output_head.parameters()
            ),
            "output_norm": sum(p.numel() for p in self.output_norm.parameters()),
        }

        layer_params = 0
        codebook_params = 0
        for layer in self.layers:
            for name, p in layer.named_parameters():
                if "codebook" in name:
                    codebook_params += p.numel()
                else:
                    layer_params += p.numel()

        counts["layers"] = layer_params
        counts["codebooks"] = codebook_params
        counts["total"] = sum(p.numel() for p in self.parameters())
        return counts
# ---------------------------------------------------------------------------
# Rotary Position Embeddings (RoPE)
# ---------------------------------------------------------------------------
class RotaryEmbedding(nn.Module):
    """Rotary Position Embeddings (RoPE).

    Precomputes sin/cos tables for position-dependent rotation of paired
    dimensions in query/value vectors. Applied per-layer rather than once
    at the embedding level, giving the model fresh position signals at
    every depth.

    Channels ``i`` and ``i + d_head // 2`` share a frequency (the tables are
    ``cat([freqs, freqs])``) and form one rotation pair. This is the
    half-split layout that ``_rotate_half`` operates on, and it is why
    ``d_head`` must be even.

    Reference: Su et al., "RoFormer: Enhanced Transformer with Rotary
    Position Embedding" (2021).

    Parameters
    ----------
    d_head : int
        Per-head dimension; must be even.
    max_seq_len : int
        Positions precomputed up front; longer requests grow the cache.
    base : float
        Frequency base of the rotation angles.

    Raises
    ------
    ValueError
        If ``d_head`` is odd (the cos/sin tables would not broadcast).
    """

    def __init__(self, d_head: int, max_seq_len: int = 512, base: float = 10000.0):
        super().__init__()
        if d_head % 2 != 0:
            raise ValueError(f"RoPE requires an even d_head, got d_head={d_head}")
        self.d_head = d_head
        inv_freq = 1.0 / (base ** (torch.arange(0, d_head, 2).float() / d_head))
        self.register_buffer("inv_freq", inv_freq)
        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len: int) -> None:
        """(Re)compute cos/sin tables ``[seq_len, d_head]`` for positions 0..seq_len-1."""
        t = torch.arange(seq_len, dtype=self.inv_freq.dtype, device=self.inv_freq.device)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)  # [T, d/2]
        emb = torch.cat([freqs, freqs], dim=-1)  # [T, d]
        # Non-persistent: derived from inv_freq, so kept out of state_dict.
        # Registering an existing buffer name again replaces it.
        self.register_buffer("cos_cached", emb.cos(), persistent=False)
        self.register_buffer("sin_cached", emb.sin(), persistent=False)

    def forward(self, seq_len: int) -> Tuple[Tensor, Tensor]:
        """Return ``(cos, sin)`` tables, each ``[seq_len, d_head]``, for positions 0..seq_len-1."""
        cached = self.cos_cached.size(0)
        if seq_len > cached:
            # Grow at least geometrically so callers that ask for one more
            # position per call don't rebuild the tables every time.
            self._build_cache(max(seq_len, 2 * cached))
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
def _rotate_half(x: Tensor) -> Tensor:
    """Rotate channel pairs by 90°: ``[x1, x2] -> [-x2, x1]``.

    ``x1`` is the first ``d // 2`` channels of the last axis and ``x2`` the
    remaining ones.
    """
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
    """Apply the RoPE rotation ``x * cos + rotate_half(x) * sin``.

    Parameters
    ----------
    x : Tensor [..., T, d_head]
    cos : Tensor [T, d_head]
    sin : Tensor [T, d_head]

    ``cos``/``sin`` broadcast against the trailing dimensions of ``x``.
    """
    rotated = _rotate_half(x)
    return x * cos + rotated * sin
# ---------------------------------------------------------------------------
# RoPE-enabled WrinkleBrane Layer
# ---------------------------------------------------------------------------
class WrinkleBraneLayerRoPE(nn.Module):
    """WrinkleBrane layer with per-layer RoPE instead of additive sinusoidal PE.

    Same parameters and pre-norm structure as ``WrinkleBraneLayer``. The only
    difference is that (cos, sin) tables are passed in and applied to V_h and
    Q_h per head before membrane attention. This gives fresh positional
    information at every layer depth, rather than a single additive signal
    at the embedding level.

    Benchmark result (Dir 9 Round 3): RoPE wins — PPL 11.73 vs 12.03
    at 500 steps with identical parameter count and Muon optimizer.
    """

    def __init__(self, config: WrinkleBraneConfig):
        super().__init__()
        self.config = config
        D = config.d_model
        d_head = config.d_head
        N = config.n_heads

        self.W_v = nn.Linear(D, D, bias=False)
        self.W_q = nn.Linear(D, D, bias=False)

        self.codebooks = nn.ModuleList([
            LearnableCodebook(
                config.L, config.K,
                init=config.code_init,
                freeze=not config.learnable_codes,
            )
            for _ in range(N)
        ])

        self.read_projections = nn.ParameterList([
            nn.Parameter(torch.empty(d_head, config.K))
            for _ in range(N)
        ])
        for P in self.read_projections:
            nn.init.xavier_uniform_(P)

        self.temperatures = nn.ParameterList([
            nn.Parameter(torch.tensor(config.temperature))
            for _ in range(N)
        ])

        self.W_o = nn.Linear(D, D, bias=False)
        self.norm1 = nn.LayerNorm(D)
        self.norm2 = nn.LayerNorm(D)

        if config.use_gated_ffn:
            self.ffn = GatedFFN(D, config.ffn_expansion)
        else:
            self.ffn = StandardFFN(D, config.ffn_expansion)

        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
        """Process a full sequence with RoPE-rotated V and Q.

        Parameters
        ----------
        x : Tensor [B, T, D]
        cos : Tensor [T, d_head] — from RotaryEmbedding
        sin : Tensor [T, d_head]

        Returns
        -------
        Tensor [B, T, D]
        """
        B, T, D = x.shape
        N = self.config.n_heads
        d_head = self.config.d_head

        residual = x
        x_normed = self.norm1(x)

        V = self.W_v(x_normed)
        Q = self.W_q(x_normed)

        V_heads = V.view(B, T, N, d_head).transpose(1, 2)  # [B, N, T, d]
        Q_heads = Q.view(B, T, N, d_head).transpose(1, 2)

        # Apply RoPE to all heads simultaneously (broadcast over B, N)
        V_heads = apply_rotary_emb(V_heads, cos, sin)
        Q_heads = apply_rotary_emb(Q_heads, cos, sin)

        head_outputs = []
        for h in range(N):
            C_h = self.codebooks[h]()
            out_h = causal_membrane_attention(
                V_h=V_heads[:, h],
                C_h=C_h,
                Q_h=Q_heads[:, h],
                P_h=self.read_projections[h],
                temperature=self.temperatures[h],
                persistence_lambda=self.config.persistence_lambda,
            )
            head_outputs.append(out_h)

        out = torch.cat(head_outputs, dim=-1)
        out = self.W_o(out)
        out = self.dropout(out)
        x = residual + out

        # NOTE(review): GatedFFN already returns its input plus the gated MLP
        # output, so with use_gated_ffn=True this adds norm2(x) itself to the
        # residual stream and is not an identity at init. Kept as-is because
        # trained checkpoints depend on it; confirm intent.
        residual = x
        x = residual + self.dropout(self.ffn(self.norm2(x)))
        return x

    def forward_step(
        self,
        x_t: Tensor,
        membrane_states: List[Tensor],
        step: int,
        cos_t: Tensor,
        sin_t: Tensor,
    ) -> Tuple[Tensor, List[Tensor]]:
        """Process a single token in sequential (RNN) mode with RoPE.

        Parameters
        ----------
        x_t : Tensor [B, D]
        membrane_states : list of Tensor [B, L, d_head], one per head
        step : int
            Absolute timestep; selects the write key ``step % K``.
        cos_t, sin_t : Tensor [d_head]
            RoPE table rows for this position.

        Returns
        -------
        (Tensor [B, D], list of Tensor [B, L, d_head])
            Output embedding and the updated (decayed) membrane states.
        """
        B, D = x_t.shape
        N = self.config.n_heads
        d_head = self.config.d_head
        # Write key is shared by all heads for this step.
        key = step % self.config.K

        residual = x_t
        x_normed = self.norm1(x_t)

        V = self.W_v(x_normed)
        Q = self.W_q(x_normed)
        V_heads = V.view(B, N, d_head)
        Q_heads = Q.view(B, N, d_head)

        # Apply RoPE for this single position (same formula as forward)
        V_heads = apply_rotary_emb(V_heads, cos_t, sin_t)
        Q_heads = apply_rotary_emb(Q_heads, cos_t, sin_t)

        new_states = []
        head_outputs = []
        for h in range(N):
            C_h = self.codebooks[h]()
            v_h = V_heads[:, h]
            q_h = Q_heads[:, h]
            M_h = membrane_states[h]

            # Write, read all K slots, blend, then decay for the next step.
            code_vec = C_h[:, key]
            delta = torch.einsum("l,bd->bld", code_vec, v_h)
            M_h = M_h + delta

            Y = torch.einsum("bld,lk->bkd", M_h, C_h)
            logits = torch.einsum("bd,dk->bk", q_h, self.read_projections[h])
            weights = torch.softmax(logits / self.temperatures[h], dim=-1)
            out_h = torch.einsum("bk,bkd->bd", weights, Y)

            M_h = M_h * self.config.persistence_lambda
            new_states.append(M_h)
            head_outputs.append(out_h)

        out = torch.cat(head_outputs, dim=-1)
        out = self.W_o(out)
        out = self.dropout(out)
        x_t = residual + out

        residual = x_t
        x_t = residual + self.dropout(self.ffn(self.norm2(x_t)))
        return x_t, new_states

    def init_membrane_states(self, B: int) -> List[Tensor]:
        """Create zero membrane states ``[B, L, d_head]``, one per head.

        Placed on the same device/dtype as ``W_v``'s weight.
        """
        device = self.W_v.weight.device
        dtype = self.W_v.weight.dtype
        return [
            torch.zeros(B, self.config.L, self.config.d_head,
                        device=device, dtype=dtype)
            for _ in range(self.config.n_heads)
        ]
# ---------------------------------------------------------------------------
# RoPE-enabled Full Model
# ---------------------------------------------------------------------------
class WrinkleBraneModelRoPE(nn.Module):
    """WrinkleBrane language model with RoPE positional encoding.

    Replaces the one-time additive sinusoidal positional encoding with
    Rotary Position Embeddings applied to V and Q at every layer.

    Benchmark result (Dir 9 Round 3, 500 steps, Muon, same param count):
        Sinusoidal PE → eval PPL 12.03
        RoPE          → eval PPL 11.73  (+1% improvement, free quality win)

    Same external interface as ``WrinkleBraneModel``:

    - ``forward(token_ids) -> logits``
    - ``forward_sequential(token_ids, states, start_pos=0) -> (logits, states)``
    - ``ortho_loss() -> scalar``
    - ``count_parameters() -> dict``
    """

    def __init__(self, config: WrinkleBraneConfig):
        super().__init__()
        self.config = config

        self.embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.rope = RotaryEmbedding(config.d_head, config.max_seq_len)
        self.embed_dropout = nn.Dropout(config.dropout)

        self.layers = nn.ModuleList([
            WrinkleBraneLayerRoPE(config) for _ in range(config.n_layers)
        ])

        self.output_norm = nn.LayerNorm(config.d_model)
        self.output_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Weight tying: the head shares the embedding matrix.
        if config.weight_tying:
            self.output_head.weight = self.embedding.weight

        self._init_weights()

    def _init_weights(self) -> None:
        """Re-initialise the token embedding with N(0, 0.02).

        With ``weight_tying`` the output head shares this tensor.
        """
        nn.init.normal_(self.embedding.weight, std=0.02)

    def forward(self, token_ids: Tensor) -> Tensor:
        """Parallel forward pass (training).

        Parameters
        ----------
        token_ids : Tensor [B, T]

        Returns
        -------
        Tensor [B, T, vocab_size]
        """
        T = token_ids.shape[1]
        x = self.embedding(token_ids) * math.sqrt(self.config.d_model)
        x = self.embed_dropout(x)

        cos, sin = self.rope(T)
        for layer in self.layers:
            x = layer(x, cos, sin)

        x = self.output_norm(x)
        return self.output_head(x)

    def forward_sequential(
        self,
        token_ids: Tensor,
        states: Optional[List[List[Tensor]]] = None,
        start_pos: int = 0,
    ) -> Tuple[Tensor, List[List[Tensor]]]:
        """Sequential (RNN) forward pass for autoregressive generation.

        Parameters
        ----------
        token_ids : Tensor [B, T]
        states : list of list of Tensor, optional
            Per-layer, per-head membrane states; zeros if None. The caller's
            outer list is not modified.
        start_pos : int
            Absolute position of ``token_ids[:, 0]``. When resuming from
            returned ``states``, pass the number of tokens already consumed.
            Otherwise the RoPE angles and the membrane write key
            (``pos % K``) restart at 0.

        Returns
        -------
        (Tensor [B, T, vocab_size], list of list of Tensor)
            Logits and updated membrane states.
        """
        B, T = token_ids.shape

        if states is None:
            states = [layer.init_membrane_states(B) for layer in self.layers]
        else:
            # Shallow copy so per-layer updates don't rebind the caller's list.
            states = list(states)

        # Tables cover absolute positions 0 .. start_pos + T - 1.
        cos_full, sin_full = self.rope(start_pos + T)
        scale = math.sqrt(self.config.d_model)

        outputs = []
        for t in range(T):
            pos = start_pos + t
            x_t = self.embedding(token_ids[:, t]) * scale
            x_t = self.embed_dropout(x_t)
            cos_t = cos_full[pos]
            sin_t = sin_full[pos]
            for i, layer in enumerate(self.layers):
                x_t, states[i] = layer.forward_step(x_t, states[i], pos, cos_t, sin_t)
            outputs.append(x_t)

        x = torch.stack(outputs, dim=1)
        x = self.output_norm(x)
        return self.output_head(x), states

    def ortho_loss(self) -> Tensor:
        """Sum of each codebook's ``ortho_loss()`` across all layers (scalar)."""
        total = torch.tensor(0.0, device=self.embedding.weight.device)
        for layer in self.layers:
            for codebook in layer.codebooks:
                total = total + codebook.ortho_loss()
        return total

    def count_parameters(self) -> Dict[str, int]:
        """Count parameters by component.

        Layer parameters whose name contains ``"codebook"`` are reported
        under ``"codebooks"``; all other layer parameters under ``"layers"``.
        A tied output head counts as 0. ``"total"`` counts every unique
        parameter.
        """
        counts = {
            "embedding": sum(p.numel() for p in self.embedding.parameters()),
            "rope": 0,  # buffers only, no learnable params
            "output_head": 0 if self.config.weight_tying else sum(
                p.numel() for p in self.output_head.parameters()
            ),
            "output_norm": sum(p.numel() for p in self.output_norm.parameters()),
        }

        layer_params = 0
        codebook_params = 0
        for layer in self.layers:
            for name, p in layer.named_parameters():
                if "codebook" in name:
                    codebook_params += p.numel()
                else:
                    layer_params += p.numel()

        counts["layers"] = layer_params
        counts["codebooks"] = codebook_params
        counts["total"] = sum(p.numel() for p in self.parameters())
        return counts