| """Standalone WrinkleBrane sequence model. |
| |
| Direction 9: Assembles continuous addressing (Dir 1), 1D membranes (Dir 4), |
| multi-head banks (Dir 5), learnable codebooks (Dir 6), and iterative |
| refinement (Dir 7) into a complete trainable language model. |
| |
| The WrinkleBrane layer replaces both self-attention and FFN from a |
| transformer. The key innovation is *parallel causal membrane reads* |
| via cumulative sum over per-position membrane deltas, enabling |
| teacher-forced training while preserving causality. |
| |
| Key components |
| -------------- |
| ``WrinkleBraneConfig`` |
| Dataclass holding all hyperparameters. |
| |
| ``PositionalEncoding`` |
| Sinusoidal positional encoding (shared with baseline transformer). |
| |
| ``GatedFFN`` |
| Feed-forward block with zero-initialised gate (Dir 7 insight). |
| |
| ``WrinkleBraneLayer`` |
| Single layer: multi-head causal membrane attention + gated FFN. |
| |
| ``WrinkleBraneModel`` |
| Full model: embedding + positional encoding + N layers + output head. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import math |
| from dataclasses import dataclass, field |
| from typing import Optional, List, Dict, Tuple |
|
|
| import torch |
| from torch import nn, Tensor |
|
|
| from wrinklebrane.learnable_codes import LearnableCodebook, orthogonality_loss |
|
|
|
|
| |
| |
| |
|
|
@dataclass
class WrinkleBraneConfig:
    """Hyperparameters for a standalone WrinkleBrane model.

    The settings collect the prerequisite research directions:

    * Dirs 1/4/5 — membrane dimensions and addressing
    * Dir 6 — learnable codes
    * Dir 7 — gated FFN
    * Dir 3/8 — no activation, optional importance weighting

    Field order is significant: it defines the positional order of the
    generated ``__init__``.
    """

    # Token vocabulary, model width and longest supported context.
    vocab_size: int = 256
    d_model: int = 128
    max_seq_len: int = 256

    # Depth, heads, and membrane geometry: ``L`` rows per membrane and ``K``
    # write/read slots (positions are assigned to slot ``t % K``).
    n_layers: int = 4
    n_heads: int = 4
    L: int = 32
    K: int = 64
    code_init: str = "hadamard"
    learnable_codes: bool = True

    # Initial value of each head's learnable softmax temperature.
    temperature: float = 0.05

    # Feed-forward width multiplier and whether to use the gated variant.
    ffn_expansion: int = 4
    use_gated_ffn: bool = True

    # Regularisation. ``ortho_lambda`` is not read inside this module;
    # presumably it weights ``ortho_loss()`` in the training loop.
    dropout: float = 0.1
    ortho_lambda: float = 0.01

    # Per-step multiplicative decay of the membrane state.
    persistence_lambda: float = 0.99

    # Share the embedding matrix with the output head.
    weight_tying: bool = True

    @property
    def d_head(self) -> int:
        """Width of one attention head, i.e. ``d_model // n_heads``."""
        per_head, leftover = divmod(self.d_model, self.n_heads)
        assert leftover == 0, (
            f"d_model={self.d_model} must be divisible by n_heads={self.n_heads}"
        )
        return per_head
|
|
|
|
| |
| |
| |
|
|
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding.

    Adds position-dependent sinusoidal signals to the input embeddings,
    followed by dropout. Even feature columns hold ``sin(pos * w_i)`` and
    odd columns ``cos(pos * w_i)`` with ``w_i = 10000^(-2i / d_model)``.
    Works for both even and odd ``d_model``.
    Compatible with both WrinkleBrane and transformer baselines.

    Parameters
    ----------
    d_model : int
        Embedding width.
    max_len : int
        Number of positions precomputed in the ``pe`` buffer.
    dropout : float
        Dropout probability applied after the addition.
    """

    def __init__(self, d_model: int, max_len: int = 512, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float)
            * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term  # [max_len, ceil(d_model / 2)]
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer cosine column than sine
        # column, so only the first d_model // 2 frequencies are used here.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Stored as [1, max_len, d_model] so it broadcasts over the batch.
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x: Tensor) -> Tensor:
        """Add positional encoding to ``x: [B, T, D]`` and apply dropout."""
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)
|
|
|
|
| |
| |
| |
|
|
| class GatedFFN(nn.Module): |
| """Feed-forward block with zero-initialised gate. |
| |
| From Direction 7: ``ResidualGatedProcessor`` dominated all alternatives |
| (+14.2 dB). The zero-init gate means the layer starts as identity, |
| and the network learns what computation to add. |
| |
| ``f(x) = x + gate * MLP(x)`` |
| """ |
|
|
| def __init__(self, d_model: int, expansion: int = 4): |
| super().__init__() |
| self.mlp = nn.Sequential( |
| nn.Linear(d_model, d_model * expansion), |
| nn.GELU(), |
| nn.Linear(d_model * expansion, d_model), |
| ) |
| |
| for m in self.mlp: |
| if isinstance(m, nn.Linear): |
| nn.init.xavier_uniform_(m.weight) |
| if m.bias is not None: |
| nn.init.zeros_(m.bias) |
|
|
| |
| self.gate = nn.Parameter(torch.zeros(1)) |
|
|
| def forward(self, x: Tensor) -> Tensor: |
| return x + self.gate * self.mlp(x) |
|
|
|
|
class StandardFFN(nn.Module):
    """Plain transformer FFN, used in place of ``GatedFFN`` when gating is off.

    ``f(x) = Linear(GELU(Linear(x)))``. There is no internal residual
    connection, so the caller supplies the skip path.
    """

    def __init__(self, d_model: int, expansion: int = 4):
        super().__init__()
        width = d_model * expansion
        self.net = nn.Sequential(
            nn.Linear(d_model, width),
            nn.GELU(),
            nn.Linear(width, d_model),
        )
        # Same init scheme as the gated variant: Xavier weights, zero biases.
        linear_layers = [mod for mod in self.net if isinstance(mod, nn.Linear)]
        for lin in linear_layers:
            nn.init.xavier_uniform_(lin.weight)
            if lin.bias is not None:
                nn.init.zeros_(lin.bias)

    def forward(self, x: Tensor) -> Tensor:
        projected = self.net(x)
        return projected
|
|
|
|
| |
| |
| |
|
|
def _decayed_cumsum(deltas: Tensor, lam: float) -> Tensor:
    """Exponentially decayed prefix sum of ``deltas`` along dim 1.

    Returns ``M`` with ``M[:, t] = Σ_{i≤t} lam^(t-i) · deltas[:, i]``.

    Uses the rescaled-cumsum identity ``M_t = lam^t · Σ_{i≤t} lam^(-i) · δ_i``.
    The identity is evaluated over chunks short enough that ``lam^(±j)``
    stays well inside the dtype's range. The value at the end of each chunk
    is carried into the next one as ``lam^(j+1) · M_prev``.

    Applying the identity to the whole sequence at once overflows
    ``lam^(-t)`` for long sequences or strong decay. For example,
    ``lam=0.5`` exceeds float32 past ``t≈128``, which yields ``inf``/``nan``.
    When the whole sequence fits in one chunk, the arithmetic is the same as
    the single-shot formula.

    ``lam`` must be positive and different from 1.
    """
    T = deltas.shape[1]
    if T == 0:
        return deltas
    log_lam = math.log(lam)
    # Keep |log(lam^j)| within a quarter of the dtype's log-range so that the
    # rescaled terms and their running sum remain representable.
    log_budget = math.log(torch.finfo(deltas.dtype).max) / 4.0
    chunk = int(log_budget / abs(log_lam)) + 1
    # Compute the exponents in at least float32. In half/bfloat16, integer
    # positions lose exactness (beyond 2048 / 256).
    exp_dtype = torch.promote_types(deltas.dtype, torch.float32)

    pieces: List[Tensor] = []
    carry: Optional[Tensor] = None
    for start in range(0, T, chunk):
        block = deltas[:, start:start + chunk]
        j = torch.arange(block.shape[1], device=deltas.device, dtype=exp_dtype)
        inv_decay = torch.exp(-log_lam * j).to(deltas.dtype)
        decay = torch.exp(log_lam * j).to(deltas.dtype)
        M = torch.cumsum(block * inv_decay[None, :, None, None], dim=1)
        M = M * decay[None, :, None, None]
        if carry is not None:
            # Everything written before this chunk, decayed by lam^(j+1).
            M = M + carry[:, None] * (decay * lam)[None, :, None, None]
        pieces.append(M)
        carry = M[:, -1]
    return torch.cat(pieces, dim=1)


def causal_membrane_attention(
    V_h: Tensor,
    C_h: Tensor,
    Q_h: Tensor,
    P_h: Tensor,
    temperature: Tensor,
    persistence_lambda: float = 1.0,
) -> Tensor:
    """Parallel causal WrinkleBrane attention for one head.

    This is the key innovation of Direction 9. A cumulative sum over
    per-position membrane deltas computes all causal readouts in parallel.

    For each position t:
        1. Write: ``M_t = Σ_{i≤t} λ^(t-i) · C[:, key_i] ⊗ V[i]`` with
           ``key_i = i % K`` (causal prefix with decay).
        2. Read: ``Y_t[k] = einsum(M_t, C[:, k])`` for all K slots.
        3. Blend: ``out_t = Σ_k softmax(Q_t @ P / temperature)[k] · Y_t[k]``.

    The decay is evaluated with a chunked, rescaled cumulative sum
    (``M_t = λ^t · Σ_{i≤t} λ^(-i) · δ_i`` per chunk). It is numerically
    stable for any sequence length and gives the same result as the
    sequential recurrence used in ``forward_step``. That recurrence reads
    ``prev + δ_t`` and then multiplies the stored state by λ.

    Parameters
    ----------
    V_h : Tensor ``[B, T, d_head]``
        Values to store (projected from input).
    C_h : Tensor ``[L, K]``
        Normalised codebook for this head.
    Q_h : Tensor ``[B, T, d_head]``
        Queries for continuous readout.
    P_h : Tensor ``[d_head, K]``
        Learned read projection (query → code weights).
    temperature : Tensor
        Softmax temperature (scalar, learnable).
    persistence_lambda : float
        Per-step multiplicative decay of the membrane.
        ``1.0`` means no decay (backward compatible). ``0.0`` means each
        position reads only its own write. Values in ``(0, 1)`` decay and
        values ``> 1`` amplify, exactly as in the sequential
        ``forward_step``.

    Returns
    -------
    Tensor ``[B, T, d_head]``
        Causal readout per position.

    Raises
    ------
    ValueError
        If ``persistence_lambda`` is negative or not finite.
    """
    lam = float(persistence_lambda)
    if not (math.isfinite(lam) and lam >= 0.0):
        raise ValueError(
            f"persistence_lambda must be finite and >= 0, got {persistence_lambda!r}"
        )

    # The unpacking also validates ranks: V_h must be 3-D and C_h 2-D.
    _, T, _ = V_h.shape
    _, K = C_h.shape

    # Position t writes into slot t % K.
    keys = torch.arange(T, device=V_h.device) % K
    code_vecs = C_h[:, keys]  # [L, T]

    # Outer product of each position's code vector with its value.
    deltas = torch.einsum("lt,btd->btld", code_vecs, V_h)  # [B, T, L, d]

    if lam == 1.0:
        M_causal = torch.cumsum(deltas, dim=1)
    elif lam == 0.0:
        # Every earlier write is fully decayed; only the current one remains.
        M_causal = deltas
    else:
        M_causal = _decayed_cumsum(deltas, lam)

    # Read every slot from the causal membrane.
    Y_all = torch.einsum("btld,lk->btkd", M_causal, C_h)  # [B, T, K, d]

    # Continuous addressing: softmax over slots from the query.
    logits = torch.einsum("btd,dk->btk", Q_h, P_h)
    weights = torch.softmax(logits / temperature, dim=-1)

    # Blend slot readouts.
    output = torch.einsum("btk,btkd->btd", weights, Y_all)
    return output
|
|
|
|
| |
| |
| |
|
|
class WrinkleBraneLayer(nn.Module):
    """Single WrinkleBrane layer replacing self-attention + FFN.

    Architecture (two pre-norm residual sub-blocks):
        1. ``x = x + Dropout(W_o(concat_h attn_h(LayerNorm1(x))))``, where
           ``attn_h`` is the per-head causal membrane attention (parallel
           via cumsum in ``forward``, recurrent in ``forward_step``).
        2. ``x = x + Dropout(FFN(LayerNorm2(x)))``, where FFN is
           ``GatedFFN`` (Dir 7: zero-init gate) when
           ``config.use_gated_ffn`` is set, otherwise ``StandardFFN``.

    Each head owns its own codebook, read projection ``[d_head, K]`` and
    learnable softmax temperature.

    NOTE(review): ``GatedFFN.forward`` already returns
    ``input + gate * MLP(input)``. With ``use_gated_ffn``, sub-block 2
    therefore adds ``LayerNorm2(x)`` to the residual stream even while the
    gate is zero, so the layer is not the identity at initialisation.
    Confirm this double residual is intended.

    Parameters
    ----------
    config : WrinkleBraneConfig
    """

    def __init__(self, config: WrinkleBraneConfig):
        super().__init__()
        self.config = config
        D = config.d_model
        d_head = config.d_head
        N = config.n_heads

        # NOTE: the creation order below fixes RNG draws at init and the
        # state_dict layout; keep it stable.

        # Value / query projections for all heads at once (split in forward).
        self.W_v = nn.Linear(D, D, bias=False)
        self.W_q = nn.Linear(D, D, bias=False)

        # One codebook per head. Calling it yields the head's code matrix,
        # used as [L, K] by causal_membrane_attention.
        self.codebooks = nn.ModuleList([
            LearnableCodebook(
                config.L, config.K,
                init=config.code_init,
                freeze=not config.learnable_codes,
            )
            for _ in range(N)
        ])

        # Per-head read projection: query [d_head] -> K slot logits.
        self.read_projections = nn.ParameterList([
            nn.Parameter(torch.empty(d_head, config.K))
            for _ in range(N)
        ])
        for P in self.read_projections:
            nn.init.xavier_uniform_(P)

        # Per-head learnable softmax temperature, initialised from the config.
        self.temperatures = nn.ParameterList([
            nn.Parameter(torch.tensor(config.temperature))
            for _ in range(N)
        ])

        # Output projection applied to the concatenated head outputs.
        self.W_o = nn.Linear(D, D, bias=False)

        # Pre-norms for the attention and FFN sub-blocks.
        self.norm1 = nn.LayerNorm(D)
        self.norm2 = nn.LayerNorm(D)

        if config.use_gated_ffn:
            self.ffn = GatedFFN(D, config.ffn_expansion)
        else:
            self.ffn = StandardFFN(D, config.ffn_expansion)

        # Shared dropout for both sub-block outputs.
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: Tensor) -> Tensor:
        """Process a full sequence (parallel, causal) through both sub-blocks.

        Parameters
        ----------
        x : Tensor ``[B, T, D]``

        Returns
        -------
        Tensor ``[B, T, D]``
        """
        B, T, D = x.shape
        N = self.config.n_heads
        d_head = self.config.d_head

        # --- Membrane-attention sub-block (pre-norm) ---
        residual = x
        x_normed = self.norm1(x)

        V = self.W_v(x_normed)
        Q = self.W_q(x_normed)

        # [B, T, D] -> [B, N, T, d_head]
        V_heads = V.view(B, T, N, d_head).transpose(1, 2)
        Q_heads = Q.view(B, T, N, d_head).transpose(1, 2)

        # Each head uses its own codebook / read projection / temperature.
        head_outputs = []
        for h in range(N):
            C_h = self.codebooks[h]()
            out_h = causal_membrane_attention(
                V_h=V_heads[:, h],
                C_h=C_h,
                Q_h=Q_heads[:, h],
                P_h=self.read_projections[h],
                temperature=self.temperatures[h],
                persistence_lambda=self.config.persistence_lambda,
            )
            head_outputs.append(out_h)

        # Concatenate heads back to [B, T, D], project, dropout, add residual.
        out = torch.cat(head_outputs, dim=-1)
        out = self.W_o(out)
        out = self.dropout(out)
        x = residual + out

        # --- FFN sub-block (pre-norm) ---
        residual = x
        x = residual + self.dropout(self.ffn(self.norm2(x)))

        return x

    def forward_step(
        self,
        x_t: Tensor,
        membrane_states: List[Tensor],
        step: int,
    ) -> Tuple[Tensor, List[Tensor]]:
        """Process a single token (sequential / RNN mode).

        Mirrors ``forward`` for one position. The new token's delta is added
        to the carried membrane, all slots are read, and the stored state is
        then multiplied by ``persistence_lambda``. The next step's read
        therefore sees ``Σ_{i≤t} λ^(t-i) δ_i``, the same quantity the
        parallel path computes.

        Parameters
        ----------
        x_t : Tensor ``[B, D]``
            Single token embedding.
        membrane_states : list of Tensor ``[B, L, d_head]``
            Per-head membrane states from previous step.
        step : int
            Absolute position of this token. It selects the write slot
            ``step % K`` and must match the position the parallel path
            would use.

        Returns
        -------
        Tensor ``[B, D]``
            Processed token embedding.
        list of Tensor ``[B, L, d_head]``
            Updated (already decayed) membrane states.
        """
        B, D = x_t.shape
        N = self.config.n_heads
        d_head = self.config.d_head

        # --- Membrane-attention sub-block (pre-norm) ---
        residual = x_t
        x_normed = self.norm1(x_t)

        V = self.W_v(x_normed)
        Q = self.W_q(x_normed)

        # [B, D] -> [B, N, d_head]
        V_heads = V.view(B, N, d_head)
        Q_heads = Q.view(B, N, d_head)

        new_states = []
        head_outputs = []
        for h in range(N):
            C_h = self.codebooks[h]()
            v_h = V_heads[:, h]
            q_h = Q_heads[:, h]

            M_h = membrane_states[h]

            # Write: outer product of this slot's code vector and the value.
            key = step % self.config.K
            code_vec = C_h[:, key]
            delta = torch.einsum("l,bd->bld", code_vec, v_h)
            M_h = M_h + delta

            # Read all K slots: [B, K, d_head].
            Y = torch.einsum("bld,lk->bkd", M_h, C_h)

            # Continuous addressing: softmax over slots from the query.
            logits = torch.einsum("bd,dk->bk", q_h, self.read_projections[h])
            weights = torch.softmax(
                logits / self.temperatures[h], dim=-1
            )
            out_h = torch.einsum("bk,bkd->bd", weights, Y)

            # Decay after the read so the next step sees λ·(prev + delta).
            M_h = M_h * self.config.persistence_lambda

            new_states.append(M_h)
            head_outputs.append(out_h)

        out = torch.cat(head_outputs, dim=-1)
        out = self.W_o(out)
        out = self.dropout(out)
        x_t = residual + out

        # --- FFN sub-block (pre-norm) ---
        residual = x_t
        x_t = residual + self.dropout(self.ffn(self.norm2(x_t)))

        return x_t, new_states

    def init_membrane_states(self, B: int) -> List[Tensor]:
        """Create zero-initialised membrane states for RNN mode.

        Returns one ``[B, L, d_head]`` tensor per head, on the same device
        and dtype as the layer's weights.
        """
        device = self.W_v.weight.device
        dtype = self.W_v.weight.dtype
        return [
            torch.zeros(B, self.config.L, self.config.d_head,
                        device=device, dtype=dtype)
            for _ in range(self.config.n_heads)
        ]
|
|
|
|
| |
| |
| |
|
|
class WrinkleBraneModel(nn.Module):
    """Complete WrinkleBrane language model.

    Architecture:
        token_embedding → positional_encoding → N × WrinkleBraneLayer
        → output_norm → output_head

    Supports two forward modes:
    - ``forward()``: parallel (training), processes full sequences.
    - ``forward_sequential()``: RNN (inference), token-by-token. When
      continuing from states returned by an earlier call, pass
      ``start_pos`` (the number of tokens already consumed). The positional
      encoding and membrane write slots then match what ``forward()`` uses
      for the same absolute positions.

    Parameters
    ----------
    config : WrinkleBraneConfig
    """

    def __init__(self, config: WrinkleBraneConfig):
        super().__init__()
        self.config = config

        # Token embedding. Its weight is shared with output_head when
        # config.weight_tying is set.
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)

        # Additive sinusoidal PE (plus dropout), applied once at the input.
        self.pos_encoding = PositionalEncoding(
            config.d_model, config.max_seq_len, dropout=config.dropout,
        )

        self.layers = nn.ModuleList([
            WrinkleBraneLayer(config) for _ in range(config.n_layers)
        ])

        self.output_norm = nn.LayerNorm(config.d_model)
        self.output_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        if config.weight_tying:
            self.output_head.weight = self.embedding.weight

        self._init_weights()

    def _init_weights(self) -> None:
        """Initialise the (possibly tied) embedding matrix with std 0.02."""
        nn.init.normal_(self.embedding.weight, std=0.02)

    def forward(self, token_ids: Tensor) -> Tensor:
        """Parallel forward pass (for training).

        Parameters
        ----------
        token_ids : Tensor ``[B, T]``
            Long tensor of token indices.

        Returns
        -------
        Tensor ``[B, T, vocab_size]``
            Logits for next-token prediction.
        """
        x = self.embedding(token_ids) * math.sqrt(self.config.d_model)
        x = self.pos_encoding(x)

        for layer in self.layers:
            x = layer(x)

        x = self.output_norm(x)
        logits = self.output_head(x)
        return logits

    def forward_sequential(
        self,
        token_ids: Tensor,
        states: Optional[List[List[Tensor]]] = None,
        start_pos: int = 0,
    ) -> Tuple[Tensor, List[List[Tensor]]]:
        """Sequential (RNN) forward pass.

        Processes tokens one at a time, maintaining membrane states.
        Useful for autoregressive generation with fixed memory.

        Parameters
        ----------
        token_ids : Tensor ``[B, T]``
            Token indices.
        states : list of list of Tensor, optional
            Per-layer, per-head membrane states. If None, initialised
            to zeros. The caller's outer list is not modified.
        start_pos : int
            Absolute position of ``token_ids[:, 0]``. Pass the number of
            tokens already fed when continuing from returned ``states`` so
            positions (PE row and write slot ``pos % K``) keep advancing.

        Returns
        -------
        Tensor ``[B, T, vocab_size]``
            Logits.
        list of list of Tensor
            Updated membrane states.

        Raises
        ------
        ValueError
            If ``start_pos`` is negative.
        IndexError
            If ``start_pos + T`` exceeds the positional-encoding table.
        """
        B, T = token_ids.shape
        if start_pos < 0:
            raise ValueError(f"start_pos must be non-negative, got {start_pos}")
        max_len = self.pos_encoding.pe.size(1)
        if start_pos + T > max_len:
            raise IndexError(
                f"positions {start_pos}..{start_pos + T - 1} exceed the "
                f"positional-encoding table of length {max_len}"
            )

        if states is None:
            states = [layer.init_membrane_states(B) for layer in self.layers]
        else:
            # Shallow copy: per-layer entries are rebound below, and the
            # caller's list (e.g. kept for beam search) must stay untouched.
            states = list(states)

        scale = math.sqrt(self.config.d_model)
        outputs = []
        for t in range(T):
            pos = start_pos + t
            x_t = self.embedding(token_ids[:, t]) * scale
            # pe is [1, max_len, D]; pe[:, pos] is [1, D] and broadcasts over B.
            x_t = x_t + self.pos_encoding.pe[:, pos]
            x_t = self.pos_encoding.dropout(x_t)

            for i, layer in enumerate(self.layers):
                x_t, states[i] = layer.forward_step(x_t, states[i], pos)

            outputs.append(x_t)

        if outputs:
            x = torch.stack(outputs, dim=1)
        else:
            # torch.stack rejects an empty list; produce an empty [B, 0, D].
            x = self.embedding.weight.new_zeros(B, 0, self.config.d_model)
        x = self.output_norm(x)
        logits = self.output_head(x)
        return logits, states

    def ortho_loss(self) -> Tensor:
        """Total orthogonality regularisation across all codebooks.

        Returns
        -------
        Tensor
            Scalar sum of ``codebook.ortho_loss()`` over every head of every
            layer (0 for perfectly orthogonal codes).
        """
        total = torch.tensor(0.0, device=self.embedding.weight.device)
        for layer in self.layers:
            for codebook in layer.codebooks:
                total = total + codebook.ortho_loss()
        return total

    def count_parameters(self) -> Dict[str, int]:
        """Count parameters by component.

        ``output_head`` is reported as 0 when tied to the embedding.
        ``total`` counts each shared tensor once.
        """
        counts = {
            "embedding": sum(p.numel() for p in self.embedding.parameters()),
            "pos_encoding": 0,
            "output_head": 0 if self.config.weight_tying else sum(
                p.numel() for p in self.output_head.parameters()
            ),
            "output_norm": sum(p.numel() for p in self.output_norm.parameters()),
        }

        layer_params = 0
        codebook_params = 0
        for layer in self.layers:
            for name, p in layer.named_parameters():
                if "codebook" in name:
                    codebook_params += p.numel()
                else:
                    layer_params += p.numel()
        counts["layers"] = layer_params
        counts["codebooks"] = codebook_params
        counts["total"] = sum(p.numel() for p in self.parameters())
        return counts
|
|
|
|
| |
| |
| |
|
|
class RotaryEmbedding(nn.Module):
    """Rotary Position Embeddings (RoPE).

    Precomputes sin/cos tables for position-dependent rotation of paired
    dimensions in query/value vectors. The tables are applied per layer
    rather than once at the embedding level, which gives the model fresh
    position signals at every depth.

    Reference: Su et al., "RoFormer: Enhanced Transformer with Rotary
    Position Embedding" (2021).

    Parameters
    ----------
    d_head : int
        Per-head width. Must be even, because features are rotated in pairs
        (first half against second half).
    max_seq_len : int
        Initial cache length. The cache grows on demand in ``forward``.
    base : float
        Frequency base of the geometric frequency ladder.

    Raises
    ------
    ValueError
        If ``d_head`` is odd.
    """

    def __init__(self, d_head: int, max_seq_len: int = 512, base: float = 10000.0):
        super().__init__()
        if d_head % 2 != 0:
            # An odd width would yield cos/sin tables of width d_head + 1
            # that fail to broadcast later with a cryptic shape error.
            raise ValueError(f"RoPE requires an even d_head, got {d_head}")
        self.d_head = d_head
        inv_freq = 1.0 / (base ** (torch.arange(0, d_head, 2).float() / d_head))
        self.register_buffer("inv_freq", inv_freq)
        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len: int) -> None:
        """(Re)build non-persistent ``[seq_len, d_head]`` cos/sin tables."""
        t = torch.arange(seq_len, dtype=self.inv_freq.dtype, device=self.inv_freq.device)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Duplicate the frequencies so both halves of each pair share an angle.
        emb = torch.cat([freqs, freqs], dim=-1)
        self.register_buffer("cos_cached", emb.cos(), persistent=False)
        self.register_buffer("sin_cached", emb.sin(), persistent=False)

    def forward(self, seq_len: int) -> Tuple[Tensor, Tensor]:
        """Return (cos, sin) tables for positions 0..seq_len-1.

        Grows the cache when ``seq_len`` exceeds the current length.
        """
        if seq_len > self.cos_cached.size(0):
            self._build_cache(seq_len)
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
|
|
|
|
def _rotate_half(x: Tensor) -> Tensor:
    """Map ``[x1, x2] -> [-x2, x1]`` along the last axis (90° turn per pair).

    ``x1`` is the first ``d // 2`` features and ``x2`` the remaining ones.
    """
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((second.neg(), first), dim=-1)
|
|
|
|
def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
    """Rotate ``x`` by the RoPE angles encoded in ``cos`` / ``sin``.

    Computes ``x * cos + rotate_half(x) * sin``. The tables broadcast over
    any leading dimensions of ``x``.

    Parameters
    ----------
    x : Tensor [..., T, d_head]
    cos : Tensor [T, d_head]
    sin : Tensor [T, d_head]
    """
    turned = _rotate_half(x)
    cos_part = x * cos
    sin_part = turned * sin
    return cos_part + sin_part
|
|
|
|
| |
| |
| |
|
|
class WrinkleBraneLayerRoPE(nn.Module):
    """WrinkleBrane layer with per-layer RoPE instead of additive sinusoidal PE.

    Identical to ``WrinkleBraneLayer`` except that (cos, sin) tables are
    passed in and applied to V_h and Q_h per head before membrane attention.
    This gives fresh positional information at every layer depth, rather
    than a single additive signal at the embedding level. The values are
    rotated as well as the queries, and the membrane readout is not rotated
    back before ``W_o``.

    Both sub-blocks are pre-norm residual: ``x + Dropout(sublayer(Norm(x)))``.

    NOTE(review): ``GatedFFN.forward`` already returns
    ``input + gate * MLP(input)``. With ``use_gated_ffn``, the FFN sub-block
    therefore adds ``norm2(x)`` to the residual stream even while the gate
    is zero. Confirm this double residual is intended.

    Benchmark result (Dir 9 Round 3): RoPE wins — PPL 11.73 vs 12.03
    at 500 steps with identical parameter count and Muon optimizer.
    """

    def __init__(self, config: WrinkleBraneConfig):
        super().__init__()
        self.config = config
        D = config.d_model
        d_head = config.d_head
        N = config.n_heads

        # NOTE: the creation order below fixes RNG draws at init and the
        # state_dict layout; keep it stable.

        # Value / query projections for all heads at once.
        self.W_v = nn.Linear(D, D, bias=False)
        self.W_q = nn.Linear(D, D, bias=False)

        # One codebook per head, used as an [L, K] code matrix.
        self.codebooks = nn.ModuleList([
            LearnableCodebook(
                config.L, config.K,
                init=config.code_init,
                freeze=not config.learnable_codes,
            )
            for _ in range(N)
        ])

        # Per-head read projection: query [d_head] -> K slot logits.
        self.read_projections = nn.ParameterList([
            nn.Parameter(torch.empty(d_head, config.K))
            for _ in range(N)
        ])
        for P in self.read_projections:
            nn.init.xavier_uniform_(P)

        # Per-head learnable softmax temperature.
        self.temperatures = nn.ParameterList([
            nn.Parameter(torch.tensor(config.temperature))
            for _ in range(N)
        ])

        # Output projection and the two pre-norms.
        self.W_o = nn.Linear(D, D, bias=False)
        self.norm1 = nn.LayerNorm(D)
        self.norm2 = nn.LayerNorm(D)

        if config.use_gated_ffn:
            self.ffn = GatedFFN(D, config.ffn_expansion)
        else:
            self.ffn = StandardFFN(D, config.ffn_expansion)

        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
        """Process sequence with RoPE-rotated V and Q.

        Parameters
        ----------
        x : Tensor [B, T, D]
        cos : Tensor [T, d_head] — from RotaryEmbedding
        sin : Tensor [T, d_head]

        Returns
        -------
        Tensor [B, T, D]
        """
        B, T, D = x.shape
        N = self.config.n_heads
        d_head = self.config.d_head

        # --- Membrane-attention sub-block (pre-norm) ---
        residual = x
        x_normed = self.norm1(x)

        V = self.W_v(x_normed)
        Q = self.W_q(x_normed)

        # [B, T, D] -> [B, N, T, d_head]
        V_heads = V.view(B, T, N, d_head).transpose(1, 2)
        Q_heads = Q.view(B, T, N, d_head).transpose(1, 2)

        # cos/sin [T, d_head] broadcast over the leading [B, N] dims.
        V_heads = apply_rotary_emb(V_heads, cos, sin)
        Q_heads = apply_rotary_emb(Q_heads, cos, sin)

        head_outputs = []
        for h in range(N):
            C_h = self.codebooks[h]()
            out_h = causal_membrane_attention(
                V_h=V_heads[:, h],
                C_h=C_h,
                Q_h=Q_heads[:, h],
                P_h=self.read_projections[h],
                temperature=self.temperatures[h],
                persistence_lambda=self.config.persistence_lambda,
            )
            head_outputs.append(out_h)

        # Concatenate heads to [B, T, D], project, dropout, add residual.
        out = torch.cat(head_outputs, dim=-1)
        out = self.W_o(out)
        out = self.dropout(out)
        x = residual + out

        # --- FFN sub-block (pre-norm) ---
        residual = x
        x = residual + self.dropout(self.ffn(self.norm2(x)))
        return x

    def forward_step(
        self,
        x_t: Tensor,
        membrane_states: List[Tensor],
        step: int,
        cos_t: Tensor,
        sin_t: Tensor,
    ) -> Tuple[Tensor, List[Tensor]]:
        """Process a single token in sequential (RNN) mode with RoPE.

        Parameters
        ----------
        x_t : Tensor [B, D]
            Single token embedding.
        membrane_states : list of Tensor [B, L, d_head]
            Per-head membrane states from the previous step.
        step : int
            Absolute position of this token; selects write slot ``step % K``.
        cos_t, sin_t : Tensor [d_head]
            RoPE table rows for position ``step``.

        Returns
        -------
        Tensor [B, D]
            Processed token embedding.
        list of Tensor [B, L, d_head]
            Updated membrane states, already multiplied by
            ``persistence_lambda`` after the read.
        """
        B, D = x_t.shape
        N = self.config.n_heads
        d_head = self.config.d_head

        # --- Membrane-attention sub-block (pre-norm) ---
        residual = x_t
        x_normed = self.norm1(x_t)

        V = self.W_v(x_normed)
        Q = self.W_q(x_normed)

        # [B, D] -> [B, N, d_head]
        V_heads = V.view(B, N, d_head)
        Q_heads = Q.view(B, N, d_head)

        # Same rotation as apply_rotary_emb, for a single position.
        V_heads = V_heads * cos_t + _rotate_half(V_heads) * sin_t
        Q_heads = Q_heads * cos_t + _rotate_half(Q_heads) * sin_t

        new_states = []
        head_outputs = []
        for h in range(N):
            C_h = self.codebooks[h]()
            v_h = V_heads[:, h]
            q_h = Q_heads[:, h]
            M_h = membrane_states[h]

            # Write: outer product of this slot's code vector and the value.
            key = step % self.config.K
            code_vec = C_h[:, key]
            delta = torch.einsum("l,bd->bld", code_vec, v_h)
            M_h = M_h + delta

            # Read all K slots: [B, K, d_head].
            Y = torch.einsum("bld,lk->bkd", M_h, C_h)

            # Continuous addressing over slots.
            logits = torch.einsum("bd,dk->bk", q_h, self.read_projections[h])
            weights = torch.softmax(logits / self.temperatures[h], dim=-1)
            out_h = torch.einsum("bk,bkd->bd", weights, Y)

            # Decay after the read (matches the parallel decayed cumsum).
            M_h = M_h * self.config.persistence_lambda
            new_states.append(M_h)
            head_outputs.append(out_h)

        out = torch.cat(head_outputs, dim=-1)
        out = self.W_o(out)
        out = self.dropout(out)
        x_t = residual + out

        # --- FFN sub-block (pre-norm) ---
        residual = x_t
        x_t = residual + self.dropout(self.ffn(self.norm2(x_t)))
        return x_t, new_states

    def init_membrane_states(self, B: int) -> List[Tensor]:
        """Create zero-initialised per-head ``[B, L, d_head]`` membrane states.

        Tensors are allocated on the same device and dtype as the layer's
        weights.
        """
        device = self.W_v.weight.device
        dtype = self.W_v.weight.dtype
        return [
            torch.zeros(B, self.config.L, self.config.d_head,
                        device=device, dtype=dtype)
            for _ in range(self.config.n_heads)
        ]
|
|
|
|
| |
| |
| |
|
|
class WrinkleBraneModelRoPE(nn.Module):
    """WrinkleBrane language model with RoPE positional encoding.

    Replaces the one-time additive sinusoidal positional encoding with
    Rotary Position Embeddings applied to V and Q at every layer.

    Benchmark result (Dir 9 Round 3, 500 steps, Muon, same param count):
        Sinusoidal PE → eval PPL 12.03
        RoPE          → eval PPL 11.73 (+1% improvement, free quality win)

    Same external interface as ``WrinkleBraneModel``:
    - ``forward(token_ids) -> logits``
    - ``forward_sequential(token_ids, states, start_pos=0) -> (logits, states)``
    - ``ortho_loss() -> scalar``
    - ``count_parameters() -> dict``
    """

    def __init__(self, config: WrinkleBraneConfig):
        super().__init__()
        self.config = config

        # Token embedding, RoPE tables (no parameters) and input dropout.
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.rope = RotaryEmbedding(config.d_head, config.max_seq_len)
        self.embed_dropout = nn.Dropout(config.dropout)

        self.layers = nn.ModuleList([
            WrinkleBraneLayerRoPE(config) for _ in range(config.n_layers)
        ])

        self.output_norm = nn.LayerNorm(config.d_model)
        self.output_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        if config.weight_tying:
            self.output_head.weight = self.embedding.weight

        self._init_weights()

    def _init_weights(self) -> None:
        """Initialise the (possibly tied) embedding matrix with std 0.02."""
        nn.init.normal_(self.embedding.weight, std=0.02)

    def forward(self, token_ids: Tensor) -> Tensor:
        """Parallel forward pass (training).

        Parameters
        ----------
        token_ids : Tensor [B, T]

        Returns
        -------
        Tensor [B, T, vocab_size]
        """
        B, T = token_ids.shape

        x = self.embedding(token_ids) * math.sqrt(self.config.d_model)
        x = self.embed_dropout(x)

        # One (cos, sin) pair shared by every layer.
        cos, sin = self.rope(T)
        for layer in self.layers:
            x = layer(x, cos, sin)

        x = self.output_norm(x)
        return self.output_head(x)

    def forward_sequential(
        self,
        token_ids: Tensor,
        states: Optional[List[List[Tensor]]] = None,
        start_pos: int = 0,
    ) -> Tuple[Tensor, List[List[Tensor]]]:
        """Sequential (RNN) forward pass for autoregressive generation.

        Same interface as ``WrinkleBraneModel.forward_sequential``.

        Parameters
        ----------
        token_ids : Tensor [B, T]
        states : list of list of Tensor, optional
            Per-layer, per-head membrane states; zeros when None. The
            caller's outer list is not modified.
        start_pos : int
            Absolute position of ``token_ids[:, 0]``. Pass the number of
            tokens already fed when continuing from returned ``states`` so
            RoPE angles and write slots (``pos % K``) keep advancing.

        Returns
        -------
        Tensor [B, T, vocab_size]
            Logits.
        list of list of Tensor
            Updated membrane states.

        Raises
        ------
        ValueError
            If ``start_pos`` is negative.
        """
        B, T = token_ids.shape
        if start_pos < 0:
            raise ValueError(f"start_pos must be non-negative, got {start_pos}")

        if states is None:
            states = [layer.init_membrane_states(B) for layer in self.layers]
        else:
            # Shallow copy: per-layer entries are rebound below.
            states = list(states)

        # Tables cover absolute positions 0..start_pos+T-1 (cache grows as needed).
        cos_full, sin_full = self.rope(start_pos + T)

        scale = math.sqrt(self.config.d_model)
        outputs = []
        for t in range(T):
            pos = start_pos + t
            x_t = self.embedding(token_ids[:, t]) * scale
            x_t = self.embed_dropout(x_t)
            cos_t = cos_full[pos]
            sin_t = sin_full[pos]

            for i, layer in enumerate(self.layers):
                x_t, states[i] = layer.forward_step(x_t, states[i], pos, cos_t, sin_t)

            outputs.append(x_t)

        if outputs:
            x = torch.stack(outputs, dim=1)
        else:
            # torch.stack rejects an empty list; produce an empty [B, 0, D].
            x = self.embedding.weight.new_zeros(B, 0, self.config.d_model)
        x = self.output_norm(x)
        return self.output_head(x), states

    def ortho_loss(self) -> Tensor:
        """Sum of ``codebook.ortho_loss()`` over every head of every layer."""
        total = torch.tensor(0.0, device=self.embedding.weight.device)
        for layer in self.layers:
            for codebook in layer.codebooks:
                total = total + codebook.ortho_loss()
        return total

    def count_parameters(self) -> Dict[str, int]:
        """Count parameters by component.

        ``output_head`` is reported as 0 when tied to the embedding.
        ``total`` counts each shared tensor once.
        """
        counts = {
            "embedding": sum(p.numel() for p in self.embedding.parameters()),
            "rope": 0,
            "output_head": 0 if self.config.weight_tying else sum(
                p.numel() for p in self.output_head.parameters()
            ),
            "output_norm": sum(p.numel() for p in self.output_norm.parameters()),
        }
        layer_params = 0
        codebook_params = 0
        for layer in self.layers:
            for name, p in layer.named_parameters():
                if "codebook" in name:
                    codebook_params += p.numel()
                else:
                    layer_params += p.numel()
        counts["layers"] = layer_params
        counts["codebooks"] = codebook_params
        counts["total"] = sum(p.numel() for p in self.parameters())
        return counts
|
|