nca3d-brain-v5 / model.py
killking69's picture
Upload folder using huggingface_hub
b973c61 verified
Raw
History Blame Contribute Delete
5.49 kB
"""
NCA 3D Brain — Neural Cellular Automata for Language
Model definition for v5 (synaptic fatigue + dilated convolutions)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
# ---------------------------------------------------------------------------
# Model hyper-parameters
# ---------------------------------------------------------------------------
GRID = 16                      # edge length of the cubic cell grid
CELL_DIM = 256                 # channels stored in every cell
EMBED_DIM = 384                # width of the token embedding table
TOTAL_VOCAB = 30006            # embedding rows == number of output logits
MAX_WORDS = 50                 # the position table holds MAX_WORDS + 2 rows
DILATION_CYCLE = [1, 2, 4, 8]  # dilation at step i is DILATION_CYCLE[i % 4]
FATIGUE_DECAY = 0.9            # per-step multiplier applied to the fatigue channels
FATIGUE_CHANNELS = 2           # trailing state channels reserved for fatigue

# ---------------------------------------------------------------------------
# Special token IDs (30000-30005, i.e. the last six slots below TOTAL_VOCAB)
# ---------------------------------------------------------------------------
PAD_ID = 30000
EOS_ID = 30001
# Punctuation marks take the IDs that directly follow EOS_ID: 30002..30005.
PUNCT_MAP = {
    mark: token_id for token_id, mark in enumerate(".,?!", start=EOS_ID + 1)
}
# Reverse lookup: token ID -> punctuation character.
PUNCT_INV = dict(zip(PUNCT_MAP.values(), PUNCT_MAP))
class NCA3D_Fatigue(nn.Module):
    """Neural cellular automaton on a cubic 3D grid with synaptic fatigue.

    The state is a tensor of shape ``(batch, cell_dim, grid, grid, grid)``.
    Token embeddings are added to the face at last-axis index 0, the state
    is updated for a number of steps with 3x3x3 convolutions whose dilation
    cycles through ``dilation_cycle``, and the face at last-axis index -1 is
    average-pooled and projected to vocabulary logits.

    The last ``fatigue_channels`` channels of the state act as a fatigue
    trace: their mean damps the update gate, and they are refreshed each
    step from the (detached) gate activity.

    Every constructor argument is optional and falls back to the
    module-level constant of the same meaning, so ``NCA3D_Fatigue()``
    builds the default v5 configuration.

    Args:
        grid: edge length of the cubic grid (default ``GRID``).
        cell_dim: channels per cell (default ``CELL_DIM``). Must be
            divisible by 32 because the normalisation uses 32 groups.
        embed_dim: width of the token embedding (default ``EMBED_DIM``).
        vocab_size: embedding rows / output logits (default ``TOTAL_VOCAB``).
        max_words: the position table has ``max_words + 2`` rows
            (default ``MAX_WORDS``).
        dilation_cycle: non-empty sequence of dilations; step ``i`` uses
            entry ``i % len(dilation_cycle)`` (default ``DILATION_CYCLE``).
        fatigue_channels: number of trailing fatigue channels, ``0`` to
            disable fatigue (default ``FATIGUE_CHANNELS``).
        fatigue_decay: per-step multiplier for the fatigue channels
            (default ``FATIGUE_DECAY``).

    Raises:
        ValueError: if ``dilation_cycle`` is empty or ``fatigue_channels``
            is not in ``[0, cell_dim)``.
    """

    def __init__(
        self,
        grid=None,
        cell_dim=None,
        embed_dim=None,
        vocab_size=None,
        max_words=None,
        dilation_cycle=None,
        fatigue_channels=None,
        fatigue_decay=None,
    ):
        super().__init__()

        # Resolve defaults lazily so the module constants are only read
        # for arguments the caller did not supply.
        grid = GRID if grid is None else grid
        cell_dim = CELL_DIM if cell_dim is None else cell_dim
        embed_dim = EMBED_DIM if embed_dim is None else embed_dim
        vocab_size = TOTAL_VOCAB if vocab_size is None else vocab_size
        max_words = MAX_WORDS if max_words is None else max_words
        dilation_cycle = tuple(
            DILATION_CYCLE if dilation_cycle is None else dilation_cycle
        )
        fatigue_channels = (
            FATIGUE_CHANNELS if fatigue_channels is None else fatigue_channels
        )
        fatigue_decay = FATIGUE_DECAY if fatigue_decay is None else fatigue_decay

        if not dilation_cycle:
            raise ValueError("dilation_cycle must contain at least one dilation")
        if not 0 <= fatigue_channels < cell_dim:
            raise ValueError(
                f"fatigue_channels must be in [0, {cell_dim}), got {fatigue_channels}"
            )

        self.grid = grid
        self.cell_dim = cell_dim
        self.vocab_size = vocab_size
        self.fatigue_channels = fatigue_channels
        self.fatigue_decay = fatigue_decay
        self.dilation_cycle = dilation_cycle

        # NOTE: sub-modules are created in the same order as before so that
        # seeded initialisation and state_dict keys are unchanged.

        # Token embedding: token ID -> embed_dim vector.
        self.word_embed = nn.Embedding(vocab_size, embed_dim)
        # Project the embedding to the cell dimension.
        self.embed_proj = nn.Linear(embed_dim, cell_dim)
        # Learned positional encoding for token order.
        self.pos_embed = nn.Embedding(max_words + 2, cell_dim)
        # Learned initial state of the 3D grid (batch dimension of 1).
        self.init_state = nn.Parameter(
            torch.randn(1, cell_dim, grid, grid, grid) * 0.01
        )
        # Pathway 1: full 3x3x3 convolutions with a 2x channel expansion.
        # Only the weights are used (see one_step); there is no bias.
        self.trans1 = nn.Sequential(
            nn.Conv3d(cell_dim, cell_dim * 2, 3, padding=1, bias=False),
            nn.SiLU(),
            nn.Conv3d(cell_dim * 2, cell_dim, 3, padding=1, bias=False),
        )
        # Pathway 2: depthwise 3x3x3 followed by pointwise 1x1x1.
        self.trans2 = nn.Sequential(
            nn.Conv3d(cell_dim, cell_dim, 3, padding=1, groups=cell_dim, bias=False),
            nn.SiLU(),
            nn.Conv3d(cell_dim, cell_dim, 1, bias=False),
        )
        # Gate: per-cell, per-channel update strength (before the sigmoid).
        self.gate_conv = nn.Conv3d(cell_dim, cell_dim, 1, bias=False)
        # Normalisation applied after every residual update.
        self.norm = nn.GroupNorm(32, cell_dim)
        # Output projection: pooled cell state -> vocabulary logits.
        self.out_proj = nn.Sequential(
            nn.Linear(cell_dim, cell_dim * 2),
            nn.SiLU(),
            nn.Linear(cell_dim * 2, vocab_size),
        )

    def inject(self, state, word_ids):
        """Add token and position embeddings to the input face of ``state``.

        Token ``pos`` is written to ``state[:, :, pos % grid, pos // grid, 0]``.
        ``state`` is modified in place and also returned.

        Only the first ``min(seq_len, pos_embed rows, grid * grid)`` tokens
        are injected; any later tokens are ignored.

        Args:
            state: tensor of shape ``(batch, cell_dim, grid, grid, grid)``.
            word_ids: integer tensor of shape ``(batch, seq_len)``.

        Returns:
            The same ``state`` tensor, updated in place.

        Raises:
            ValueError: if ``word_ids`` is not 2-dimensional.
        """
        if word_ids.dim() != 2:
            raise ValueError(
                f"word_ids must have shape (batch, seq_len), got {tuple(word_ids.shape)}"
            )

        # Number of tokens that fit both the position table and the face.
        n_tokens = min(
            word_ids.shape[1],
            self.pos_embed.num_embeddings,
            self.grid * self.grid,
        )
        if n_tokens == 0:
            return state

        positions = torch.arange(n_tokens, device=word_ids.device)
        # One batched lookup instead of one per position.
        word_vecs = self.embed_proj(self.word_embed(word_ids[:, :n_tokens]))  # (B, n, C)
        pos_vecs = self.pos_embed(positions)                                  # (n, C)

        cols = (positions % self.grid).to(state.device)
        rows = torch.div(positions, self.grid, rounding_mode="floor").to(state.device)

        # `face` is a view, so writing to it updates `state` in place.
        face = state[:, :, :, :, 0]
        # (col, row) pairs are unique, so a plain indexed write is safe.
        face[:, :, cols, rows] = (
            face[:, :, cols, rows] + word_vecs.transpose(1, 2) + pos_vecs.t()
        )
        return state

    def one_step(self, state, step):
        """Apply one gated, normalised residual update to the grid.

        Args:
            state: tensor of shape ``(batch, cell_dim, grid, grid, grid)``.
            step: step index; selects the dilation from ``dilation_cycle``.

        Returns:
            A new state tensor with the same shape as ``state``.
        """
        dilation = self.dilation_cycle[step % len(self.dilation_cycle)]
        # For a 3-wide kernel, padding == dilation keeps the spatial size.
        pad = dilation

        # Pathway 1: both full convolutions use the step's dilation.
        full = F.conv3d(state, self.trans1[0].weight, padding=pad, dilation=dilation)
        full = F.silu(full)
        full = F.conv3d(full, self.trans1[2].weight, padding=pad, dilation=dilation)

        # Pathway 2: dilated depthwise conv, then an undilated pointwise conv.
        sep = F.conv3d(
            state,
            self.trans2[0].weight,
            padding=pad,
            dilation=dilation,
            groups=self.cell_dim,
        )
        sep = F.silu(sep)
        sep = F.conv3d(sep, self.trans2[2].weight)

        gate = torch.sigmoid(self.gate_conv(state))

        n_fatigue = self.fatigue_channels
        fatigue = None
        if n_fatigue > 0:
            # Higher fatigue -> smaller gate. `state[:, -0:]` would select
            # every channel, hence the explicit guard for n_fatigue == 0.
            fatigue = state[:, -n_fatigue:, :, :, :]
            gate = gate * (1.0 - torch.sigmoid(fatigue.mean(dim=1, keepdim=True)))

        # Residual update followed by normalisation.
        state = self.norm(state + gate * (full + sep))

        if fatigue is not None:
            # Fatigue is carried over from the pre-update values (not the
            # normalised ones) and driven by the detached gate activity.
            gate_intensity = gate.detach().mean(dim=1, keepdim=True)
            new_fatigue = fatigue * self.fatigue_decay + gate_intensity * 0.1
            state = torch.cat(
                [state[:, :-n_fatigue, :, :, :], new_fatigue], dim=1
            )
        return state

    def forward(self, word_ids, n_steps=15):
        """Inject tokens, run ``n_steps`` updates and read out logits.

        Args:
            word_ids: integer tensor of shape ``(batch, seq_len)``.
            n_steps: number of ``one_step`` updates to apply.

        Returns:
            Tensor of shape ``(batch, vocab_size)`` with vocabulary logits.
        """
        batch = word_ids.shape[0]
        # Clone so the in-place injection never touches the parameter itself.
        state = self.init_state.expand(batch, -1, -1, -1, -1).clone()
        state = self.inject(state, word_ids)
        for step in range(n_steps):
            state = self.one_step(state, step)
        # Read the face at the far end of the last axis and average over it.
        out_face = state[:, :, :, :, -1]
        pooled = out_face.mean(dim=(-1, -2))
        return self.out_proj(pooled)