"""
NCA 3D Brain — Neural Cellular Automata for Language

Model definition for v5 (synaptic fatigue + dilated convolutions).
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

# Model constants
GRID = 16
CELL_DIM = 256
EMBED_DIM = 384
TOTAL_VOCAB = 30006
MAX_WORDS = 50
DILATION_CYCLE = [1, 2, 4, 8]
FATIGUE_DECAY = 0.9
FATIGUE_CHANNELS = 2

# Special token IDs
PAD_ID = 30000
EOS_ID = 30001
PUNCT_MAP = {'.': 30002, ',': 30003, '?': 30004, '!': 30005}
PUNCT_INV = {v: k for k, v in PUNCT_MAP.items()}


class NCA3D_Fatigue(nn.Module):
    """
    Neural Cellular Automaton 3D with synaptic fatigue.

    A cubic grid of cells (16x16x16 by default) that communicate only with
    neighbors. Information propagates as waves from the input face (z=0)
    to the output face (z=grid-1).
    Dilated convolutions with cycle [1,2,4,8] give global reach in 4 steps.
    Synaptic fatigue prevents repetitive firing patterns.

    The state tensor has shape (batch, cell_dim, grid, grid, grid); its last
    ``FATIGUE_CHANNELS`` channels are reserved for the fatigue trace.
    """

    def __init__(
        self,
        *,
        grid: int = GRID,
        cell_dim: int = CELL_DIM,
        embed_dim: int = EMBED_DIM,
        vocab_size: int = TOTAL_VOCAB,
        max_words: int = MAX_WORDS,
    ) -> None:
        """
        Build the automaton.

        All arguments are keyword-only and default to the module constants,
        so ``NCA3D_Fatigue()`` builds the standard v5 model.

        Args:
            grid: edge length of the cubic cell grid
            cell_dim: number of channels per cell (must be divisible by 32,
                the GroupNorm group count, and larger than FATIGUE_CHANNELS)
            embed_dim: width of the token embedding table
            vocab_size: number of token IDs (embedding rows and output logits)
            max_words: context length; the positional table has
                ``max_words + 2`` entries

        Raises:
            ValueError: if ``cell_dim`` leaves no room for non-fatigue channels
        """
        super().__init__()
        if cell_dim <= FATIGUE_CHANNELS:
            raise ValueError(
                f"cell_dim ({cell_dim}) must exceed FATIGUE_CHANNELS "
                f"({FATIGUE_CHANNELS})"
            )

        self.grid = grid
        self.cell_dim = cell_dim
        self.vocab_size = vocab_size
        self.fatigue_channels = FATIGUE_CHANNELS
        self.fatigue_decay = FATIGUE_DECAY

        # Token embedding: word → embed_dim vector
        self.word_embed = nn.Embedding(vocab_size, embed_dim)
        # Project embedding to cell dimension
        self.embed_proj = nn.Linear(embed_dim, cell_dim)
        # Positional encoding for token order
        self.pos_embed = nn.Embedding(max_words + 2, cell_dim)

        # Initial state of the 3D grid (learned)
        self.init_state = nn.Parameter(
            torch.randn(1, cell_dim, grid, grid, grid) * 0.01
        )

        # NOTE: trans1 / trans2 act as weight containers. one_step() applies
        # their conv weights functionally so that the dilation can change
        # from step to step; the SiLU entries are mirrored by F.silu there.

        # Pathway 1: Full interaction (Conv3d standard)
        self.trans1 = nn.Sequential(
            nn.Conv3d(cell_dim, cell_dim * 2, 3, padding=1, bias=False),
            nn.SiLU(),
            nn.Conv3d(cell_dim * 2, cell_dim, 3, padding=1, bias=False),
        )

        # Pathway 2: Depthwise separable (efficient)
        self.trans2 = nn.Sequential(
            nn.Conv3d(cell_dim, cell_dim, 3, padding=1, groups=cell_dim,
                      bias=False),
            nn.SiLU(),
            nn.Conv3d(cell_dim, cell_dim, 1, bias=False),
        )

        # Gate: controls how much each cell changes
        self.gate_conv = nn.Conv3d(cell_dim, cell_dim, 1, bias=False)

        # Normalization
        self.norm = nn.GroupNorm(32, cell_dim)

        # Output projection: cell state → vocabulary logits
        self.out_proj = nn.Sequential(
            nn.Linear(cell_dim, cell_dim * 2),
            nn.SiLU(),
            nn.Linear(cell_dim * 2, vocab_size),
        )

    def inject(self, state: torch.Tensor, word_ids: torch.Tensor) -> torch.Tensor:
        """
        Inject token embeddings into the input face (z=0) of the cube.

        Token ``pos`` is added to the cell at ``[col, row, 0]`` with
        ``row = pos // grid`` and ``col = pos % grid``. Tokens beyond the
        positional table size or beyond the ``grid * grid`` cells of the
        face are ignored (the earliest tokens are the ones kept).

        Args:
            state: (batch, cell_dim, grid, grid, grid) tensor; modified in
                place
            word_ids: (batch, seq_len) tensor of word IDs

        Returns:
            The same ``state`` tensor, after injection.
        """
        _, seq_len = word_ids.shape
        # Capacity is bounded by both the positional table and the face area.
        n_tokens = min(seq_len, self.pos_embed.num_embeddings,
                       self.grid * self.grid)
        if n_tokens == 0:
            return state

        positions = torch.arange(n_tokens, dtype=torch.long,
                                 device=word_ids.device)
        # One batched lookup + projection instead of one per position.
        token_vecs = self.embed_proj(self.word_embed(word_ids[:, :n_tokens]))
        update = token_vecs + self.pos_embed(positions)  # (batch, n, cell_dim)

        rows = torch.div(positions, self.grid, rounding_mode="floor")
        cols = positions % self.grid
        # Each position maps to a distinct (col, row) cell, so a single
        # indexed write is equivalent to a per-position loop.
        state[:, :, cols, rows, 0] = (
            state[:, :, cols, rows, 0] + update.transpose(1, 2)
        )
        return state

    def one_step(self, state: torch.Tensor, step: int) -> torch.Tensor:
        """
        One step of wave propagation through the 3D grid.

        Args:
            state: (batch, cell_dim, grid, grid, grid) tensor
            step: step index; selects the dilation from DILATION_CYCLE

        Returns:
            New state tensor of the same shape (the input is not modified).
        """
        # Dilated convolution — cycle [1, 2, 4, 8] for global reach
        d = DILATION_CYCLE[step % len(DILATION_CYCLE)]
        p = d  # padding = dilation for same-size output (3x3x3 kernels)

        # Pathway 1: Full Conv3d with dilation
        d1 = F.conv3d(state, self.trans1[0].weight, padding=p, dilation=d)
        d1 = F.silu(d1)
        d1 = F.conv3d(d1, self.trans1[2].weight, padding=p, dilation=d)

        # Pathway 2: Depthwise separable with dilation
        d2 = F.conv3d(state, self.trans2[0].weight, padding=p, dilation=d,
                      groups=self.cell_dim)
        d2 = F.silu(d2)
        d2 = F.conv3d(d2, self.trans2[2].weight)  # 1x1x1 pointwise mix

        # Gate: sigmoid controls update magnitude
        g = torch.sigmoid(self.gate_conv(state))

        # Synaptic fatigue: penalize cells that fire too much
        fatigue = state[:, -self.fatigue_channels:, :, :, :]
        fatigue_penalty = 1.0 - torch.sigmoid(fatigue.mean(dim=1, keepdim=True))
        g = g * fatigue_penalty

        # Residual update
        delta = d1 + d2
        state = state + g * delta
        # The norm runs over all channels, fatigue included; the fatigue
        # channels are overwritten below with the explicitly tracked trace.
        state = self.norm(state)

        # Update fatigue channels (detached: the trace itself carries no
        # gradient from the gate)
        gate_intensity = g.detach().mean(dim=1, keepdim=True)
        new_fatigue = fatigue * self.fatigue_decay + gate_intensity * 0.1
        state = torch.cat(
            [state[:, : -self.fatigue_channels, :, :, :], new_fatigue], dim=1
        )

        return state

    def forward(self, word_ids: torch.Tensor, n_steps: int = 15) -> torch.Tensor:
        """
        Forward pass: inject tokens → propagate N steps → read output.

        Args:
            word_ids: (batch, seq_len) tensor of word IDs
            n_steps: number of propagation steps (more = deeper thinking)

        Returns:
            logits: (batch, vocab_size) next-word prediction logits
        """
        B = word_ids.shape[0]
        # clone() gives every forward pass its own buffer, so the in-place
        # write in inject() never touches the learned init_state parameter.
        state = self.init_state.expand(B, -1, -1, -1, -1).clone()
        state = self.inject(state, word_ids)

        for step in range(n_steps):
            state = self.one_step(state, step)

        # Read from output face (last z index), average pool over the face
        out_face = state[:, :, :, :, -1]
        pooled = out_face.mean(dim=(-1, -2))
        return self.out_proj(pooled)