"""Standalone WrinkleBrane sequence model. Direction 9: Assembles continuous addressing (Dir 1), 1D membranes (Dir 4), multi-head banks (Dir 5), learnable codebooks (Dir 6), and iterative refinement (Dir 7) into a complete trainable language model. The WrinkleBrane layer replaces both self-attention and FFN from a transformer. The key innovation is *parallel causal membrane reads* via cumulative sum over per-position membrane deltas, enabling teacher-forced training while preserving causality. Key components -------------- ``WrinkleBraneConfig`` Dataclass holding all hyperparameters. ``PositionalEncoding`` Sinusoidal positional encoding (shared with baseline transformer). ``GatedFFN`` Feed-forward block with zero-initialised gate (Dir 7 insight). ``WrinkleBraneLayer`` Single layer: multi-head causal membrane attention + gated FFN. ``WrinkleBraneModel`` Full model: embedding + positional encoding + N layers + output head. """ from __future__ import annotations import math from dataclasses import dataclass, field from typing import Optional, List, Dict, Tuple import torch from torch import nn, Tensor from wrinklebrane.learnable_codes import LearnableCodebook, orthogonality_loss # --------------------------------------------------------------------------- # Configuration # --------------------------------------------------------------------------- @dataclass class WrinkleBraneConfig: """Configuration for a standalone WrinkleBrane model. Groups hyperparameters from all prerequisite directions: - Dirs 1/4/5: membrane dimensions and addressing - Dir 6: learnable codes - Dir 7: gated FFN - Dir 3/8: no activation, optional importance weighting """ # Vocabulary & embedding vocab_size: int = 256 d_model: int = 128 max_seq_len: int = 256 # WrinkleBrane architecture n_layers: int = 4 n_heads: int = 4 L: int = 32 # code layers per head K: int = 64 # codes per head (capacity) code_init: str = "hadamard" learnable_codes: bool = True # Continuous addressing (Dir 1) temperature: float = 0.05 # FFN (Dir 7: ResidualGated architecture) ffn_expansion: int = 4 use_gated_ffn: bool = True # Regularization dropout: float = 0.1 ortho_lambda: float = 0.01 # Persistence (for RNN mode) persistence_lambda: float = 0.99 # Optional weight_tying: bool = True @property def d_head(self) -> int: """Per-head embedding dimension.""" assert self.d_model % self.n_heads == 0, ( f"d_model={self.d_model} must be divisible by n_heads={self.n_heads}" ) return self.d_model // self.n_heads # --------------------------------------------------------------------------- # Positional Encoding # --------------------------------------------------------------------------- class PositionalEncoding(nn.Module): """Sinusoidal positional encoding. Adds position-dependent sinusoidal signals to the input embeddings. Compatible with both WrinkleBrane and transformer baselines. """ def __init__(self, d_model: int, max_len: int = 512, dropout: float = 0.1): super().__init__() self.dropout = nn.Dropout(p=dropout) pe = torch.zeros(max_len, d_model) position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) div_term = torch.exp( torch.arange(0, d_model, 2, dtype=torch.float) * (-math.log(10000.0) / d_model) ) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) self.register_buffer("pe", pe.unsqueeze(0)) # [1, max_len, d_model] def forward(self, x: Tensor) -> Tensor: """Add positional encoding to ``x: [B, T, D]``.""" x = x + self.pe[:, :x.size(1)] return self.dropout(x) # --------------------------------------------------------------------------- # Gated FFN (Dir 7: ResidualGatedProcessor adapted as FFN) # --------------------------------------------------------------------------- class GatedFFN(nn.Module): """Feed-forward block with zero-initialised gate. From Direction 7: ``ResidualGatedProcessor`` dominated all alternatives (+14.2 dB). The zero-init gate means the layer starts as identity, and the network learns what computation to add. ``f(x) = x + gate * MLP(x)`` """ def __init__(self, d_model: int, expansion: int = 4): super().__init__() self.mlp = nn.Sequential( nn.Linear(d_model, d_model * expansion), nn.GELU(), nn.Linear(d_model * expansion, d_model), ) # Xavier init for MLP weights, zero for biases for m in self.mlp: if isinstance(m, nn.Linear): nn.init.xavier_uniform_(m.weight) if m.bias is not None: nn.init.zeros_(m.bias) # Zero-init gate: identity at initialization self.gate = nn.Parameter(torch.zeros(1)) def forward(self, x: Tensor) -> Tensor: return x + self.gate * self.mlp(x) class StandardFFN(nn.Module): """Standard transformer FFN (for comparison when gated FFN is disabled).""" def __init__(self, d_model: int, expansion: int = 4): super().__init__() self.net = nn.Sequential( nn.Linear(d_model, d_model * expansion), nn.GELU(), nn.Linear(d_model * expansion, d_model), ) for m in self.net: if isinstance(m, nn.Linear): nn.init.xavier_uniform_(m.weight) if m.bias is not None: nn.init.zeros_(m.bias) def forward(self, x: Tensor) -> Tensor: return self.net(x) # --------------------------------------------------------------------------- # Causal Membrane Attention (core innovation) # --------------------------------------------------------------------------- def causal_membrane_attention( V_h: Tensor, C_h: Tensor, Q_h: Tensor, P_h: Tensor, temperature: Tensor, persistence_lambda: float = 1.0, ) -> Tensor: """Parallel causal WrinkleBrane attention for one head. This is the key innovation of Direction 9: using cumulative sum over per-position membrane deltas to compute causal readouts in parallel. The computation for each position t: 1. Write: ``M_t = Σ_{i≤t} λ^(t-i) · C[:, key_i] ⊗ V[i]`` (causal prefix with decay) 2. Read: ``Y_t[k] = einsum(M_t, C[:, k])`` (all K slots) 3. Blend: ``out_t = Σ_k softmax(Q_t @ P)[k] · Y_t[k]`` (continuous) When ``persistence_lambda < 1.0``, exponential decay is applied via a rescaled cumulative sum so that training (parallel) and inference (sequential) see identical dynamics:: M_t = Σ_{i≤t} λ^(t-i) · δ_i = λ^t · Σ_{i≤t} λ^(-i) · δ_i Pre-multiply each delta by ``λ^(-i)``, take the cumsum, then post-multiply each result by ``λ^t``. Fully parallel, same result as the sequential recurrence. Parameters ---------- V_h : Tensor ``[B, T, d_head]`` Values to store (projected from input). C_h : Tensor ``[L, K]`` Normalised codebook for this head. Q_h : Tensor ``[B, T, d_head]`` Queries for continuous readout. P_h : Tensor ``[d_head, K]`` Learned read projection (query → code weights). temperature : Tensor Softmax temperature (scalar, learnable). persistence_lambda : float Exponential decay factor applied per timestep. 1.0 means no decay (backward compatible). Values like 0.99 match the sequential forward_step decay. Returns ------- Tensor ``[B, T, d_head]`` Causal readout per position. """ B, T, d = V_h.shape L, K = C_h.shape # Discrete key assignment: token t → key t % K keys = torch.arange(T, device=V_h.device) % K code_vecs = C_h[:, keys] # [L, T] # Per-position membrane deltas: delta_t[l, d] = C[l, key_t] * V[b, t, d] deltas = torch.einsum("lt,btd->btld", code_vecs, V_h) # [B, T, L, d] if persistence_lambda < 1.0: # Parallel exponential decay via rescaled cumsum: # M_t = Σ_{i≤t} λ^(t-i) · δ_i = λ^t · Σ_{i≤t} λ^(-i) · δ_i t_idx = torch.arange(T, device=V_h.device, dtype=V_h.dtype) log_lam = math.log(persistence_lambda) # Pre-multiply: delta_i * λ^(-i) inv_decay = torch.exp(-log_lam * t_idx) # [T] scaled = deltas * inv_decay[None, :, None, None] M_causal = torch.cumsum(scaled, dim=1) # Post-multiply: M_t * λ^t decay = torch.exp(log_lam * t_idx) # [T] M_causal = M_causal * decay[None, :, None, None] else: # No decay — plain causal prefix sum M_causal = torch.cumsum(deltas, dim=1) # [B, T, L, d] # Read from each M_t: Y_t[k] = Σ_l M_t[l] * C[l, k] Y_all = torch.einsum("btld,lk->btkd", M_causal, C_h) # [B, T, K, d] # Continuous soft blend (Dir 1: write-discrete / read-continuous) logits = torch.einsum("btd,dk->btk", Q_h, P_h) # [B, T, K] weights = torch.softmax(logits / temperature, dim=-1) # [B, T, K] # Weighted readout per position output = torch.einsum("btk,btkd->btd", weights, Y_all) # [B, T, d] return output # --------------------------------------------------------------------------- # WrinkleBrane Layer # --------------------------------------------------------------------------- class WrinkleBraneLayer(nn.Module): """Single WrinkleBrane layer replacing self-attention + FFN. Architecture: 1. Multi-head causal membrane attention (parallel via cumsum) 2. Residual + LayerNorm 3. Gated FFN (Dir 7: zero-init gate) 4. Residual + LayerNorm Parameters ---------- config : WrinkleBraneConfig """ def __init__(self, config: WrinkleBraneConfig): super().__init__() self.config = config D = config.d_model d_head = config.d_head N = config.n_heads # Value and query projections self.W_v = nn.Linear(D, D, bias=False) self.W_q = nn.Linear(D, D, bias=False) # Per-head learnable codebooks (Dir 6) self.codebooks = nn.ModuleList([ LearnableCodebook( config.L, config.K, init=config.code_init, freeze=not config.learnable_codes, ) for _ in range(N) ]) # Per-head read projections: d_head → K soft weights (Dir 1) self.read_projections = nn.ParameterList([ nn.Parameter(torch.empty(d_head, config.K)) for _ in range(N) ]) for P in self.read_projections: nn.init.xavier_uniform_(P) # Per-head learnable temperature (Dir 1: sweet spot ~ 0.05) self.temperatures = nn.ParameterList([ nn.Parameter(torch.tensor(config.temperature)) for _ in range(N) ]) # Output projection self.W_o = nn.Linear(D, D, bias=False) # Layer norms (pre-norm style) self.norm1 = nn.LayerNorm(D) self.norm2 = nn.LayerNorm(D) # FFN block if config.use_gated_ffn: self.ffn = GatedFFN(D, config.ffn_expansion) else: self.ffn = StandardFFN(D, config.ffn_expansion) self.dropout = nn.Dropout(config.dropout) def forward(self, x: Tensor) -> Tensor: """Process sequence through membrane attention + FFN. Parameters ---------- x : Tensor ``[B, T, D]`` Returns ------- Tensor ``[B, T, D]`` """ B, T, D = x.shape N = self.config.n_heads d_head = self.config.d_head # === Membrane Attention Block === residual = x x_normed = self.norm1(x) # Project values and queries V = self.W_v(x_normed) # [B, T, D] Q = self.W_q(x_normed) # [B, T, D] # Split into heads V_heads = V.view(B, T, N, d_head).transpose(1, 2) # [B, N, T, d_head] Q_heads = Q.view(B, T, N, d_head).transpose(1, 2) # [B, N, T, d_head] # Per-head causal membrane read head_outputs = [] for h in range(N): C_h = self.codebooks[h]() # [L, K] normalised out_h = causal_membrane_attention( V_h=V_heads[:, h], # [B, T, d_head] C_h=C_h, # [L, K] Q_h=Q_heads[:, h], # [B, T, d_head] P_h=self.read_projections[h], # [d_head, K] temperature=self.temperatures[h], persistence_lambda=self.config.persistence_lambda, ) head_outputs.append(out_h) # Concatenate heads + output projection out = torch.cat(head_outputs, dim=-1) # [B, T, D] out = self.W_o(out) out = self.dropout(out) x = residual + out # === FFN Block === residual = x x = residual + self.dropout(self.ffn(self.norm2(x))) return x def forward_step( self, x_t: Tensor, membrane_states: List[Tensor], step: int, ) -> Tuple[Tensor, List[Tensor]]: """Process a single token (sequential / RNN mode). Parameters ---------- x_t : Tensor ``[B, D]`` Single token embedding. membrane_states : list of Tensor ``[B, L, d_head]`` Per-head membrane states from previous step. step : int Current timestep (for key assignment). Returns ------- Tensor ``[B, D]`` Processed token embedding. list of Tensor ``[B, L, d_head]`` Updated membrane states. """ B, D = x_t.shape N = self.config.n_heads d_head = self.config.d_head # === Membrane Attention === residual = x_t x_normed = self.norm1(x_t) V = self.W_v(x_normed) # [B, D] Q = self.W_q(x_normed) # [B, D] V_heads = V.view(B, N, d_head) # [B, N, d_head] Q_heads = Q.view(B, N, d_head) # [B, N, d_head] new_states = [] head_outputs = [] for h in range(N): C_h = self.codebooks[h]() # [L, K] v_h = V_heads[:, h] # [B, d_head] q_h = Q_heads[:, h] # [B, d_head] M_h = membrane_states[h] # [B, L, d_head] # Write: M += C[:, key] ⊗ v key = step % self.config.K code_vec = C_h[:, key] # [L] delta = torch.einsum("l,bd->bld", code_vec, v_h) M_h = M_h + delta # Read: Y = einsum(M, C) → [B, K, d_head] Y = torch.einsum("bld,lk->bkd", M_h, C_h) # Continuous blend logits = torch.einsum("bd,dk->bk", q_h, self.read_projections[h]) weights = torch.softmax( logits / self.temperatures[h], dim=-1 ) # [B, K] out_h = torch.einsum("bk,bkd->bd", weights, Y) # [B, d_head] # Persistence decay M_h = M_h * self.config.persistence_lambda new_states.append(M_h) head_outputs.append(out_h) out = torch.cat(head_outputs, dim=-1) # [B, D] out = self.W_o(out) out = self.dropout(out) x_t = residual + out # === FFN === residual = x_t x_t = residual + self.dropout(self.ffn(self.norm2(x_t))) return x_t, new_states def init_membrane_states(self, B: int) -> List[Tensor]: """Create zero-initialised membrane states for RNN mode.""" device = self.W_v.weight.device dtype = self.W_v.weight.dtype return [ torch.zeros(B, self.config.L, self.config.d_head, device=device, dtype=dtype) for _ in range(self.config.n_heads) ] # --------------------------------------------------------------------------- # Full Model # --------------------------------------------------------------------------- class WrinkleBraneModel(nn.Module): """Complete WrinkleBrane language model. Architecture: token_embedding → positional_encoding → N × WrinkleBraneLayer → output_norm → output_head Supports two forward modes: - ``forward()``: parallel (training), processes full sequences - ``forward_sequential()``: RNN (inference), token-by-token Parameters ---------- config : WrinkleBraneConfig """ def __init__(self, config: WrinkleBraneConfig): super().__init__() self.config = config # Token embedding self.embedding = nn.Embedding(config.vocab_size, config.d_model) # Positional encoding self.pos_encoding = PositionalEncoding( config.d_model, config.max_seq_len, dropout=config.dropout, ) # WrinkleBrane layers self.layers = nn.ModuleList([ WrinkleBraneLayer(config) for _ in range(config.n_layers) ]) # Output self.output_norm = nn.LayerNorm(config.d_model) self.output_head = nn.Linear(config.d_model, config.vocab_size, bias=False) # Weight tying (Dir 4: memory efficient) if config.weight_tying: self.output_head.weight = self.embedding.weight # Init weights self._init_weights() def _init_weights(self) -> None: """Initialise embedding and output projection.""" nn.init.normal_(self.embedding.weight, std=0.02) def forward(self, token_ids: Tensor) -> Tensor: """Parallel forward pass (for training). Parameters ---------- token_ids : Tensor ``[B, T]`` Long tensor of token indices. Returns ------- Tensor ``[B, T, vocab_size]`` Logits for next-token prediction. """ # Embed + position x = self.embedding(token_ids) * math.sqrt(self.config.d_model) x = self.pos_encoding(x) # [B, T, D] # Process through WrinkleBrane layers for layer in self.layers: x = layer(x) # Output projection x = self.output_norm(x) logits = self.output_head(x) # [B, T, vocab_size] return logits def forward_sequential( self, token_ids: Tensor, states: Optional[List[List[Tensor]]] = None, ) -> Tuple[Tensor, List[List[Tensor]]]: """Sequential (RNN) forward pass. Processes tokens one at a time, maintaining membrane states. Useful for autoregressive generation with fixed memory. Parameters ---------- token_ids : Tensor ``[B, T]`` Token indices. states : list of list of Tensor, optional Per-layer, per-head membrane states. If None, initialised to zeros. Returns ------- Tensor ``[B, T, vocab_size]`` Logits. list of list of Tensor Updated membrane states. """ B, T = token_ids.shape # Init states if needed if states is None: states = [layer.init_membrane_states(B) for layer in self.layers] outputs = [] for t in range(T): # Embed single token x_t = self.embedding(token_ids[:, t]) * math.sqrt(self.config.d_model) # Add positional encoding for position t x_t = x_t + self.pos_encoding.pe[:, t] x_t = self.pos_encoding.dropout(x_t) # Process through layers for i, layer in enumerate(self.layers): x_t, states[i] = layer.forward_step(x_t, states[i], t) outputs.append(x_t) # Stack outputs x = torch.stack(outputs, dim=1) # [B, T, D] x = self.output_norm(x) logits = self.output_head(x) return logits, states def ortho_loss(self) -> Tensor: """Total orthogonality regularisation across all codebooks. Returns ------- Tensor Scalar loss (0 for perfectly orthogonal codes). """ total = torch.tensor(0.0, device=self.embedding.weight.device) for layer in self.layers: for codebook in layer.codebooks: total = total + codebook.ortho_loss() return total def count_parameters(self) -> Dict[str, int]: """Count parameters by component.""" counts = { "embedding": sum(p.numel() for p in self.embedding.parameters()), "pos_encoding": 0, # buffer, not parameter "output_head": 0 if self.config.weight_tying else sum( p.numel() for p in self.output_head.parameters() ), "output_norm": sum(p.numel() for p in self.output_norm.parameters()), } layer_params = 0 codebook_params = 0 for layer in self.layers: for name, p in layer.named_parameters(): if "codebook" in name: codebook_params += p.numel() else: layer_params += p.numel() counts["layers"] = layer_params counts["codebooks"] = codebook_params counts["total"] = sum(p.numel() for p in self.parameters()) return counts # --------------------------------------------------------------------------- # Rotary Position Embeddings (RoPE) # --------------------------------------------------------------------------- class RotaryEmbedding(nn.Module): """Rotary Position Embeddings (RoPE). Precomputes sin/cos tables for position-dependent rotation of paired dimensions in query/value vectors. Applied per-layer rather than once at the embedding level, giving the model fresh position signals at every depth. Reference: Su et al., "RoFormer: Enhanced Transformer with Rotary Position Embedding" (2021). """ def __init__(self, d_head: int, max_seq_len: int = 512, base: float = 10000.0): super().__init__() self.d_head = d_head inv_freq = 1.0 / (base ** (torch.arange(0, d_head, 2).float() / d_head)) self.register_buffer("inv_freq", inv_freq) self._build_cache(max_seq_len) def _build_cache(self, seq_len: int) -> None: t = torch.arange(seq_len, dtype=self.inv_freq.dtype, device=self.inv_freq.device) freqs = torch.einsum("i,j->ij", t, self.inv_freq) # [T, d/2] emb = torch.cat([freqs, freqs], dim=-1) # [T, d] self.register_buffer("cos_cached", emb.cos(), persistent=False) self.register_buffer("sin_cached", emb.sin(), persistent=False) def forward(self, seq_len: int) -> Tuple[Tensor, Tensor]: """Return (cos, sin) tables for positions 0..seq_len-1.""" if seq_len > self.cos_cached.size(0): self._build_cache(seq_len) return self.cos_cached[:seq_len], self.sin_cached[:seq_len] def _rotate_half(x: Tensor) -> Tensor: """For x = [x1, x2], return [-x2, x1] (rotate 90° in each pair).""" d = x.shape[-1] return torch.cat([-x[..., d // 2:], x[..., : d // 2]], dim=-1) def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: """Apply RoPE rotation to tensor x. Parameters ---------- x : Tensor [..., T, d_head] cos : Tensor [T, d_head] sin : Tensor [T, d_head] """ return x * cos + _rotate_half(x) * sin # --------------------------------------------------------------------------- # RoPE-enabled WrinkleBrane Layer # --------------------------------------------------------------------------- class WrinkleBraneLayerRoPE(nn.Module): """WrinkleBrane layer with per-layer RoPE instead of additive sinusoidal PE. Identical to ``WrinkleBraneLayer`` except (cos, sin) tables are passed in and applied to V_h and Q_h per-head before membrane attention. This gives fresh positional information at every layer depth, rather than a single additive signal at the embedding level. Benchmark result (Dir 9 Round 3): RoPE wins — PPL 11.73 vs 12.03 at 500 steps with identical parameter count and Muon optimizer. """ def __init__(self, config: WrinkleBraneConfig): super().__init__() self.config = config D = config.d_model d_head = config.d_head N = config.n_heads self.W_v = nn.Linear(D, D, bias=False) self.W_q = nn.Linear(D, D, bias=False) self.codebooks = nn.ModuleList([ LearnableCodebook( config.L, config.K, init=config.code_init, freeze=not config.learnable_codes, ) for _ in range(N) ]) self.read_projections = nn.ParameterList([ nn.Parameter(torch.empty(d_head, config.K)) for _ in range(N) ]) for P in self.read_projections: nn.init.xavier_uniform_(P) self.temperatures = nn.ParameterList([ nn.Parameter(torch.tensor(config.temperature)) for _ in range(N) ]) self.W_o = nn.Linear(D, D, bias=False) self.norm1 = nn.LayerNorm(D) self.norm2 = nn.LayerNorm(D) if config.use_gated_ffn: self.ffn = GatedFFN(D, config.ffn_expansion) else: self.ffn = StandardFFN(D, config.ffn_expansion) self.dropout = nn.Dropout(config.dropout) def forward(self, x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: """Process sequence with RoPE-rotated V and Q. Parameters ---------- x : Tensor [B, T, D] cos : Tensor [T, d_head] — from RotaryEmbedding sin : Tensor [T, d_head] """ B, T, D = x.shape N = self.config.n_heads d_head = self.config.d_head residual = x x_normed = self.norm1(x) V = self.W_v(x_normed) Q = self.W_q(x_normed) V_heads = V.view(B, T, N, d_head).transpose(1, 2) # [B, N, T, d] Q_heads = Q.view(B, T, N, d_head).transpose(1, 2) # Apply RoPE to all heads simultaneously (broadcast over B, N) V_heads = apply_rotary_emb(V_heads, cos, sin) Q_heads = apply_rotary_emb(Q_heads, cos, sin) head_outputs = [] for h in range(N): C_h = self.codebooks[h]() out_h = causal_membrane_attention( V_h=V_heads[:, h], C_h=C_h, Q_h=Q_heads[:, h], P_h=self.read_projections[h], temperature=self.temperatures[h], persistence_lambda=self.config.persistence_lambda, ) head_outputs.append(out_h) out = torch.cat(head_outputs, dim=-1) out = self.W_o(out) out = self.dropout(out) x = residual + out residual = x x = residual + self.dropout(self.ffn(self.norm2(x))) return x def forward_step( self, x_t: Tensor, membrane_states: List[Tensor], step: int, cos_t: Tensor, sin_t: Tensor, ) -> Tuple[Tensor, List[Tensor]]: """Process a single token in sequential (RNN) mode with RoPE.""" B, D = x_t.shape N = self.config.n_heads d_head = self.config.d_head residual = x_t x_normed = self.norm1(x_t) V = self.W_v(x_normed) Q = self.W_q(x_normed) V_heads = V.view(B, N, d_head) Q_heads = Q.view(B, N, d_head) # Apply RoPE for this single position V_heads = V_heads * cos_t + _rotate_half(V_heads) * sin_t Q_heads = Q_heads * cos_t + _rotate_half(Q_heads) * sin_t new_states = [] head_outputs = [] for h in range(N): C_h = self.codebooks[h]() v_h = V_heads[:, h] q_h = Q_heads[:, h] M_h = membrane_states[h] key = step % self.config.K code_vec = C_h[:, key] delta = torch.einsum("l,bd->bld", code_vec, v_h) M_h = M_h + delta Y = torch.einsum("bld,lk->bkd", M_h, C_h) logits = torch.einsum("bd,dk->bk", q_h, self.read_projections[h]) weights = torch.softmax(logits / self.temperatures[h], dim=-1) out_h = torch.einsum("bk,bkd->bd", weights, Y) M_h = M_h * self.config.persistence_lambda new_states.append(M_h) head_outputs.append(out_h) out = torch.cat(head_outputs, dim=-1) out = self.W_o(out) out = self.dropout(out) x_t = residual + out residual = x_t x_t = residual + self.dropout(self.ffn(self.norm2(x_t))) return x_t, new_states def init_membrane_states(self, B: int) -> List[Tensor]: device = self.W_v.weight.device dtype = self.W_v.weight.dtype return [ torch.zeros(B, self.config.L, self.config.d_head, device=device, dtype=dtype) for _ in range(self.config.n_heads) ] # --------------------------------------------------------------------------- # RoPE-enabled Full Model # --------------------------------------------------------------------------- class WrinkleBraneModelRoPE(nn.Module): """WrinkleBrane language model with RoPE positional encoding. Replaces the one-time additive sinusoidal positional encoding with Rotary Position Embeddings applied to V and Q at every layer. Benchmark result (Dir 9 Round 3, 500 steps, Muon, same param count): Sinusoidal PE → eval PPL 12.03 RoPE → eval PPL 11.73 (+1% improvement, free quality win) Same external interface as ``WrinkleBraneModel``: - ``forward(token_ids) -> logits`` - ``forward_sequential(token_ids, states) -> (logits, states)`` - ``ortho_loss() -> scalar`` - ``count_parameters() -> dict`` """ def __init__(self, config: WrinkleBraneConfig): super().__init__() self.config = config self.embedding = nn.Embedding(config.vocab_size, config.d_model) self.rope = RotaryEmbedding(config.d_head, config.max_seq_len) self.embed_dropout = nn.Dropout(config.dropout) self.layers = nn.ModuleList([ WrinkleBraneLayerRoPE(config) for _ in range(config.n_layers) ]) self.output_norm = nn.LayerNorm(config.d_model) self.output_head = nn.Linear(config.d_model, config.vocab_size, bias=False) if config.weight_tying: self.output_head.weight = self.embedding.weight self._init_weights() def _init_weights(self) -> None: nn.init.normal_(self.embedding.weight, std=0.02) def forward(self, token_ids: Tensor) -> Tensor: """Parallel forward pass (training). Parameters ---------- token_ids : Tensor [B, T] Returns ------- Tensor [B, T, vocab_size] """ B, T = token_ids.shape x = self.embedding(token_ids) * math.sqrt(self.config.d_model) x = self.embed_dropout(x) cos, sin = self.rope(T) for layer in self.layers: x = layer(x, cos, sin) x = self.output_norm(x) return self.output_head(x) def forward_sequential( self, token_ids: Tensor, states: Optional[List[List[Tensor]]] = None, ) -> Tuple[Tensor, List[List[Tensor]]]: """Sequential (RNN) forward pass for autoregressive generation. Same interface as ``WrinkleBraneModel.forward_sequential``. """ B, T = token_ids.shape if states is None: states = [layer.init_membrane_states(B) for layer in self.layers] cos_full, sin_full = self.rope(T) outputs = [] for t in range(T): x_t = self.embedding(token_ids[:, t]) * math.sqrt(self.config.d_model) x_t = self.embed_dropout(x_t) cos_t = cos_full[t] sin_t = sin_full[t] for i, layer in enumerate(self.layers): x_t, states[i] = layer.forward_step(x_t, states[i], t, cos_t, sin_t) outputs.append(x_t) x = torch.stack(outputs, dim=1) x = self.output_norm(x) return self.output_head(x), states def ortho_loss(self) -> Tensor: total = torch.tensor(0.0, device=self.embedding.weight.device) for layer in self.layers: for codebook in layer.codebooks: total = total + codebook.ortho_loss() return total def count_parameters(self) -> Dict[str, int]: counts = { "embedding": sum(p.numel() for p in self.embedding.parameters()), "rope": 0, # buffers only, no learnable params "output_head": 0 if self.config.weight_tying else sum( p.numel() for p in self.output_head.parameters() ), "output_norm": sum(p.numel() for p in self.output_norm.parameters()), } layer_params = 0 codebook_params = 0 for layer in self.layers: for name, p in layer.named_parameters(): if "codebook" in name: codebook_params += p.numel() else: layer_params += p.numel() counts["layers"] = layer_params counts["codebooks"] = codebook_params counts["total"] = sum(p.numel() for p in self.parameters()) return counts