| """ |
| NCA 3D Brain — Neural Cellular Automata for Language |
| Model definition for v5 (synaptic fatigue + dilated convolutions) |
| """ |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| |
# --- Geometry / model widths ---------------------------------------------
GRID = 16  # side length of the cubic cell grid (GRID x GRID x GRID cells)
CELL_DIM = 256  # number of state channels carried by every cell
EMBED_DIM = 384  # width of word embeddings before projection to CELL_DIM
TOTAL_VOCAB = 30006  # word ids plus the six special ids 30000..30005 below
MAX_WORDS = 50  # sizes the positional-embedding table (MAX_WORDS + 2 rows)
DILATION_CYCLE = [2 ** k for k in range(4)]  # [1, 2, 4, 8], cycled per step
FATIGUE_DECAY = 0.9  # per-step multiplicative decay of the fatigue trace
FATIGUE_CHANNELS = 2  # trailing state channels reserved for the fatigue trace

# --- Special token ids (appended after the word vocabulary) ---------------
PAD_ID = 30000
EOS_ID = 30001
PUNCT_MAP = {'.': 30002, ',': 30003, '?': 30004, '!': 30005}
PUNCT_INV = dict(zip(PUNCT_MAP.values(), PUNCT_MAP.keys()))  # id -> symbol
|
|
|
|
class NCA3D_Fatigue(nn.Module):
    """
    Neural Cellular Automaton 3D with synaptic fatigue.

    A 16x16x16 (GRID^3) grid of cells, each holding a CELL_DIM-channel state,
    that communicate only through (dilated) 3x3x3 convolutions.
    Information propagates as waves from the input face (z=0) to the output
    face (z=15); z is the last axis of the (B, C, X, Y, Z) state tensor.
    Dilated convolutions with cycle [1, 2, 4, 8] give global reach in 4 steps.
    The last FATIGUE_CHANNELS state channels hold a fatigue trace that damps
    the update gate of recently active cells, intended to discourage
    repetitive firing patterns.
    """

    def __init__(self):
        super().__init__()
        self.grid = GRID
        self.cell_dim = CELL_DIM
        self.vocab_size = TOTAL_VOCAB
        self.fatigue_channels = FATIGUE_CHANNELS
        self.fatigue_decay = FATIGUE_DECAY

        # Token -> cell-state input path: word embedding projected to
        # CELL_DIM, plus a learned embedding of the token position.
        self.word_embed = nn.Embedding(TOTAL_VOCAB, EMBED_DIM)
        self.embed_proj = nn.Linear(EMBED_DIM, CELL_DIM)
        self.pos_embed = nn.Embedding(MAX_WORDS + 2, CELL_DIM)

        # Learned initial cube state (batch dim 1; expanded per batch in forward).
        self.init_state = nn.Parameter(
            torch.randn(1, CELL_DIM, GRID, GRID, GRID) * 0.01
        )

        # Update rule. In this file these Sequentials are only weight holders:
        # one_step re-applies their weights through F.conv3d with a per-step
        # dilation, so the padding=1 configured here is not what one_step uses.
        # All convs are bias-free, which is why one_step passes weights only.
        self.trans1 = nn.Sequential(
            nn.Conv3d(CELL_DIM, CELL_DIM * 2, 3, padding=1, bias=False),
            nn.SiLU(),
            nn.Conv3d(CELL_DIM * 2, CELL_DIM, 3, padding=1, bias=False),
        )
        # Depthwise 3x3x3 followed by a pointwise 1x1x1 channel mix.
        self.trans2 = nn.Sequential(
            nn.Conv3d(CELL_DIM, CELL_DIM, 3, padding=1, groups=CELL_DIM, bias=False),
            nn.SiLU(),
            nn.Conv3d(CELL_DIM, CELL_DIM, 1, bias=False),
        )
        # Per-cell, per-channel update gate (applied after a sigmoid).
        self.gate_conv = nn.Conv3d(CELL_DIM, CELL_DIM, 1, bias=False)
        # 32 groups of CELL_DIM // 32 channels; the trailing fatigue channels
        # share the last group with ordinary state channels.
        self.norm = nn.GroupNorm(32, CELL_DIM)

        # Readout: (B, CELL_DIM) pooled output-face features -> (B, TOTAL_VOCAB).
        self.out_proj = nn.Sequential(
            nn.Linear(CELL_DIM, CELL_DIM * 2),
            nn.SiLU(),
            nn.Linear(CELL_DIM * 2, TOTAL_VOCAB),
        )

    def inject(self, state, word_ids):
        """
        Add token + position embeddings onto the input face (z=0) of the cube.

        Token ``pos`` is written to cell ``state[:, :, pos % grid, pos // grid, 0]``,
        i.e. positions fill the face row by row. Only the first
        ``min(L, pos_embed.num_embeddings, grid * grid)`` tokens are injected;
        any tokens after that are silently dropped.

        NOTE(review): the tail of an over-long sequence is discarded, which
        for next-word prediction removes the most recent context. PAD tokens
        are injected like any other id; there is no masking, and word_embed
        has no padding_idx. Confirm both are intended.

        Args:
            state: (B, C, X, Y, Z) cell states; modified in place and returned.
            word_ids: (B, L) integer token ids.

        Returns:
            ``state`` with the embeddings added.
        """
        _, seq_len = word_ids.shape
        n = min(seq_len, self.pos_embed.num_embeddings, self.grid * self.grid)
        if n == 0:
            return state

        positions = torch.arange(n, device=word_ids.device)
        # One batched lookup/projection instead of one call per position.
        vec = self.embed_proj(self.word_embed(word_ids[:, :n]))  # (B, n, C)
        pv = self.pos_embed(positions)  # (n, C), broadcast over the batch
        contrib = (vec + pv).transpose(1, 2)  # (B, C, n)

        cols = positions % self.grid
        rows = positions // self.grid
        # (col, row) pairs are unique, so a single index_put is exact.
        state[:, :, cols, rows, 0] = state[:, :, cols, rows, 0] + contrib
        return state

    def one_step(self, state, step):
        """
        Apply one update of the automaton.

        Args:
            state: (B, CELL_DIM, X, Y, Z) cell states.
            step: step index; selects the dilation from DILATION_CYCLE.

        Returns:
            The updated state, same shape. The trailing ``fatigue_channels``
            channels are replaced by the decayed fatigue trace.
        """
        # With a 3x3x3 kernel, padding == dilation keeps the spatial size:
        # out = n + 2d - d*(3-1) = n.
        d = DILATION_CYCLE[step % len(DILATION_CYCLE)]

        # Dense path: dilated conv -> SiLU -> dilated conv.
        d1 = F.conv3d(state, self.trans1[0].weight, padding=d, dilation=d)
        d1 = F.silu(d1)
        d1 = F.conv3d(d1, self.trans1[2].weight, padding=d, dilation=d)

        # Cheap path: dilated depthwise conv -> SiLU -> 1x1x1 pointwise mix.
        d2 = F.conv3d(
            state, self.trans2[0].weight, padding=d, dilation=d, groups=self.cell_dim
        )
        d2 = F.silu(d2)
        d2 = F.conv3d(d2, self.trans2[2].weight)

        # Update gate, scaled by 1 - sigmoid(mean fatigue): higher fatigue
        # gives a smaller gate, and even zero fatigue halves it (1 - 0.5).
        g = torch.sigmoid(self.gate_conv(state))
        fatigue = state[:, -self.fatigue_channels:]
        fatigue_penalty = 1.0 - torch.sigmoid(fatigue.mean(dim=1, keepdim=True))
        g = g * fatigue_penalty

        # Gated residual update followed by GroupNorm over all channels.
        state = self.norm(state + g * (d1 + d2))

        # Fatigue trace: decayed previous trace plus the mean gate activity.
        # The gate is detached, so no gradient flows into the gate from here;
        # the previous fatigue is not detached. The normalised values of the
        # fatigue channels computed above are discarded in favour of this trace.
        gate_intensity = g.detach().mean(dim=1, keepdim=True)
        new_fatigue = fatigue * self.fatigue_decay + gate_intensity * 0.1
        return torch.cat([state[:, :-self.fatigue_channels], new_fatigue], dim=1)

    def forward(self, word_ids, n_steps=15):
        """
        Forward pass: inject tokens -> propagate N steps -> read output.

        Args:
            word_ids: (batch, seq_len) tensor of word IDs
            n_steps: number of propagation steps (more = deeper thinking)

        Returns:
            logits: (batch, vocab_size) next-word prediction logits
        """
        B = word_ids.shape[0]
        # clone() is required: inject writes in place, and an expand()ed view
        # would alias the init_state parameter's storage across the batch.
        state = self.init_state.expand(B, -1, -1, -1, -1).clone()
        state = self.inject(state, word_ids)
        for step in range(n_steps):
            state = self.one_step(state, step)

        # Output face is the last z slice; average it over its two spatial axes.
        out_face = state[:, :, :, :, -1]  # (B, C, X, Y)
        pooled = out_face.mean(dim=(-1, -2))  # (B, C)
        return self.out_proj(pooled)
|
|