"""Random Discrete Neural Cellular Automaton — synthetic data generator.
Implements the data-generation procedure described in
Lee, Han, Kumar, Agrawal. "Training Language Models via Neural
Cellular Automata." arXiv:2603.10055.
with parameter choices reconciled against the reference Flax/JAX
implementation at https://github.com/danihyunlee/nca-pre-pretraining
(``utils/nca.py``). Where paper text and reference code disagree, we
default to the **reference code's behavior** but expose the divergence
as an explicit ``NCAConfig`` field.
The module also adds *sparse / approximate* options that don't appear
in the reference but match feather's design philosophy (sparse, fast,
topology-respecting). These options keep the data path runnable on a
laptop CPU at non-trivial token rates and integrate cleanly with the
existing SDR / HTM / Engram subsystems:
• ``tokenizer = "hash_bucket"``: hash patch ids into a smaller bucket
pool (vocab << n^(patch²)). Exact-id is replaced with feature
hashing; collisions are absorbed by feather's downstream
representation, much like word2vec hash embeddings.
• ``complexity_metric = "rle"``: replace gzip with a vectorized
run-length-equivalent compressibility proxy. RLE compression is
a strict lower bound on what gzip achieves, but the *band shape*
is preserved up to a constant scale, so the K-complexity gate is
intact while running ~50× faster.
• Batched rollout: ``rollout_batch`` rolls out ``B`` independent rules in
one call; each rule is stepped separately on its own batch slice, so
setup and frame bookkeeping are shared across the batch.
Reconciled recipe (defaults):
• 2D grid with periodic (toroidal) boundaries; 12×12.
• Discrete state alphabet of n=10 symbols; each cell is an n-d one-hot.
• Update rule f_θ: 3×3 conv (4 channels) → 1×1 conv (16 channels) →
ReLU → 1×1 conv (n logits). 1×1 convs are mathematically identical
to per-cell Linear, so this implementation uses Linear layers.
• Sampling: ``c_i^(t+1) ~ Categorical((f_θ + identity_bias·c_i)/τ)``
with τ=1e-3 and ``identity_bias=0.0`` (reference defaults).
• Per trajectory: fresh random θ. Initial grid: uniform i.i.d. over
{0,…,n-1} (paper text) by default; ``initial_state="gaussian_categorical"``
matches the reference code's lower-entropy init.
• Frame stride dT=2 (reference): rollout runs ``dT × num_examples``
steps and keeps every dT-th frame.
• Burn-in ``start_step=0`` (configurable).
• Tokenization: 2×2 patches → bijective ids in ``[0, n^4)``; each
timestep framed by grid-start / grid-end delimiter tokens
(``NCAConfig.grid_start_id`` / ``NCAConfig.grid_end_id``).
• Complexity gate: gzip ratio with ``compresslevel=9``. Paper-default
band: r > 50% for OpenWebText/OpenWebMath; 30–40% for CodeParrot.
"""
from __future__ import annotations
import gzip
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class NCAConfig:
    """Immutable settings for the random discrete NCA and its tokenizer.

    String-valued switches and the values this module recognises:

    ``rule_type``          "mlp"                   Linear-MLP cell rule
                           "rbf_switch"            Gaussian-RBF cell rule
    ``initial_state``      "uniform"               i.i.d. uniform discrete grid
                           "gaussian_categorical"  one sampled state per grid
    ``tokenizer``          "patch"                 bijective n^(patch²) ids
                           "hash_bucket"           ids hashed into
                                                   ``hash_buckets`` buckets
    ``complexity_metric``  "gzip"                  gzip ratio
                           "rle"                   run-length proxy
    """

    # --- substrate -------------------------------------------------------
    n_states: int = 10            # size of the discrete state alphabet
    grid_size: int = 12           # grid is grid_size × grid_size
    conv_channels: int = 4        # output channels of the 3×3 perception conv
    mlp_hidden: int = 16          # hidden width of the MLP cell rule
    tau: float = 1e-3             # softmax temperature used when sampling
    identity_bias: float = 0.0    # logit bonus added to the current state
    initial_state: str = "uniform"

    # --- rollout ---------------------------------------------------------
    dT: int = 2                   # frame stride: one kept frame every dT steps
    start_step: int = 0           # steps skipped before the first kept frame

    # --- tokenizer -------------------------------------------------------
    tokenizer: str = "patch"
    patch_size: int = 2
    hash_buckets: int = 1024      # read only when tokenizer == "hash_bucket"

    # --- cell rule -------------------------------------------------------
    rule_type: str = "mlp"
    rbf_n_centers: int = 8
    rbf_sigma: float = 0.5

    # --- complexity gate -------------------------------------------------
    complexity_metric: str = "gzip"
    gzip_compresslevel: int = 9

    @property
    def patch_vocab(self) -> int:
        """Number of distinct un-hashed patch ids: ``n_states ** patch_size²``."""
        cells_per_patch = self.patch_size * self.patch_size
        return self.n_states ** cells_per_patch

    @property
    def vocab_size(self) -> int:
        """Total vocabulary: content ids plus the two grid delimiter ids.

        Raises ``ValueError`` when ``tokenizer`` is not a recognised name.
        """
        if self.tokenizer == "patch":
            content_ids = self.patch_vocab
        elif self.tokenizer == "hash_bucket":
            content_ids = self.hash_buckets
        else:
            raise ValueError(f"unknown tokenizer {self.tokenizer!r}")
        return content_ids + 2

    @property
    def grid_start_id(self) -> int:
        """Delimiter id opening a frame (second-to-last id in the vocab)."""
        return self.vocab_size - 2

    @property
    def grid_end_id(self) -> int:
        """Delimiter id closing a frame (last id in the vocab)."""
        return self.grid_start_id + 1
# ---------------------------------------------------------------------------
# Cell rules
# ---------------------------------------------------------------------------
class _MLPCellRule(nn.Module):
    """Per-cell two-layer perceptron: conv_channels → mlp_hidden → ReLU → n_states.

    The layers are ``nn.Linear``, so they act on the last dimension of the
    input and are applied independently to every cell (equivalent to a
    pair of 1×1 convolutions around a ReLU).
    """

    def __init__(self, cfg: NCAConfig) -> None:
        super().__init__()
        hidden = nn.Linear(cfg.conv_channels, cfg.mlp_hidden, bias=True)
        activation = nn.ReLU()
        readout = nn.Linear(cfg.mlp_hidden, cfg.n_states, bias=True)
        # Kept as a Sequential named ``net`` so parameter names/order stay
        # ``net.0.*`` then ``net.2.*``.
        self.net = nn.Sequential(hidden, activation, readout)

    def forward(self, feat: torch.Tensor) -> torch.Tensor:
        """Map ``(..., conv_channels)`` features to ``(..., n_states)`` logits."""
        logits = self.net(feat)
        return logits
class _RBFSwitchCellRule(nn.Module):
    """Radial-switch / KAN-style cell rule.

    Each input channel is expanded into K Gaussian radial basis functions
    on a fixed, evenly spaced center grid over [-1, 1]; the resulting
    (C, K) activations are combined linearly into ``n_states`` logits with
    no further nonlinearity:

        logits[o] = Σ_c Σ_k α[c,k,o] · exp(-(x[c]-μ_k)² / (2σ²)) + bias[o]

    Parameter count is ``C·K·n + n`` (330 at C=4, K=8, n=10).

    ``alpha`` and ``bias`` are initialised at construction time from
    U(-1/√(C·K), 1/√(C·K)) — the same rule ``nn.Linear`` applies to a layer
    with C·K inputs — so a freshly constructed rule never computes on
    uninitialised memory. Callers remain free to overwrite the values.
    """

    def __init__(self, cfg: NCAConfig) -> None:
        super().__init__()
        self.cfg = cfg
        centers = torch.linspace(-1.0, 1.0, cfg.rbf_n_centers)
        # Non-persistent buffers: they follow .to(device) but are fully
        # determined by cfg, so they are left out of the state dict.
        self.register_buffer("centers", centers, persistent=False)
        self.register_buffer(
            "inv_two_sigma_sq",
            torch.tensor(1.0 / (2.0 * cfg.rbf_sigma ** 2)),
            persistent=False,
        )
        self.alpha = nn.Parameter(
            torch.empty(cfg.conv_channels, cfg.rbf_n_centers, cfg.n_states)
        )
        self.bias = nn.Parameter(torch.empty(cfg.n_states))
        # torch.empty returns arbitrary memory; give the parameters defined
        # values. Each logit sums C·K basis activations, so that is the fan-in.
        fan_in = max(cfg.conv_channels * cfg.rbf_n_centers, 1)
        bound = (1.0 / fan_in) ** 0.5
        nn.init.uniform_(self.alpha, -bound, bound)
        nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, feat: torch.Tensor) -> torch.Tensor:
        """Map ``(..., C)`` features to ``(..., n_states)`` logits.

        Any number of leading dimensions is accepted (e.g. ``(B, H, W, C)``).
        """
        # (..., C, 1) - (K,) → (..., C, K): offset of each channel from each center.
        diff = feat.unsqueeze(-1) - self.centers
        rbf = torch.exp(-(diff * diff) * self.inv_two_sigma_sq)
        # Contract channel and center axes against alpha[c, k, o].
        return torch.einsum("...ck,cko->...o", rbf, self.alpha) + self.bias
# ---------------------------------------------------------------------------
# NCA module
# ---------------------------------------------------------------------------
class RandomDiscreteNCA(nn.Module):
    """Single discrete NCA rule f_θ, used purely as a data generator.

    The perception layer is a 3×3 convolution with
    ``padding_mode='circular'`` (periodic boundaries); its output feeds a
    per-cell rule selected by ``cfg.rule_type``. All parameters have
    ``requires_grad`` switched off and ``step`` runs under ``no_grad``.
    """

    def __init__(self, cfg: NCAConfig | None = None) -> None:
        """Build the perception conv and the cell rule for ``cfg``.

        Raises ``ValueError`` when ``cfg.rule_type`` is not ``'mlp'`` or
        ``'rbf_switch'``.
        """
        super().__init__()
        # Explicit None check: never fall back to defaults because of the
        # truthiness of a caller-supplied config object.
        self.cfg = cfg if cfg is not None else NCAConfig()
        self.perceive = nn.Conv2d(
            self.cfg.n_states,
            self.cfg.conv_channels,
            kernel_size=3,
            padding=1,
            padding_mode="circular",
            bias=True,
        )
        if self.cfg.rule_type == "mlp":
            self.cell_rule: nn.Module = _MLPCellRule(self.cfg)
        elif self.cfg.rule_type == "rbf_switch":
            self.cell_rule = _RBFSwitchCellRule(self.cfg)
        else:
            raise ValueError(
                f"unknown rule_type {self.cfg.rule_type!r}; "
                "expected 'mlp' or 'rbf_switch'"
            )
        for p in self.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def step(
        self,
        onehot: torch.Tensor,
        *,
        generator: torch.Generator | None = None,
    ) -> torch.Tensor:
        """Advance one NCA timestep.

        Args:
            onehot: ``(B, n_states, H, W)`` one-hot state, values in {0, 1}.
            generator: optional RNG used for the categorical draw. ``None``
                (the default) samples from PyTorch's global RNG. PyTorch
                expects the generator to live on the same device as
                ``onehot``.

        Returns:
            The next state as a one-hot tensor with the shape and dtype of
            ``onehot``.
        """
        feat = self.perceive(onehot)            # (B, conv_ch, H, W)
        feat = feat.permute(0, 2, 3, 1)         # (B, H, W, conv_ch)
        logits = self.cell_rule(feat)           # (B, H, W, n_states)
        if self.cfg.identity_bias != 0.0:
            # Favour the current state: logits + identity_bias · one-hot.
            current = onehot.permute(0, 2, 3, 1)
            logits = logits + current * self.cfg.identity_bias
        # Temperature-scaled categorical distribution per cell.
        probs = F.softmax(logits / self.cfg.tau, dim=-1)
        # multinomial wants 2-D input: one row per cell.
        flat = probs.reshape(-1, self.cfg.n_states)
        sampled = torch.multinomial(flat, num_samples=1, generator=generator)
        sampled = sampled.squeeze(-1).reshape(probs.shape[:-1])   # (B, H, W)
        return F.one_hot(sampled, self.cfg.n_states).permute(0, 3, 1, 2).to(onehot.dtype)
def _init_random_rule(cfg: NCAConfig, generator: torch.Generator) -> RandomDiscreteNCA:
    """Construct an NCA and overwrite every parameter with fresh uniform draws.

    All draws come from ``generator``, in ``nca.parameters()`` order:

    * parameters with more than one dimension: U(-√(3/fan_in), √(3/fan_in)),
      i.e. zero mean and variance 1/fan_in. ``fan_in`` is ``shape[1]``,
      multiplied by the kernel area ``shape[2]·shape[3]`` for 4-D (conv)
      weights.
    * 1-D parameters (biases): U(-1/√len, 1/√len), with ``len`` the
      parameter's own length.

    NOTE(review): for a 3-D parameter the rule above takes ``shape[1]``
    alone as the fan-in — confirm that is the intended scale for any
    non-Linear, non-conv weight layout.
    """
    nca = RandomDiscreteNCA(cfg)
    with torch.no_grad():
        for param in nca.parameters():
            if param.dim() <= 1:
                limit = (1.0 / max(param.shape[0], 1)) ** 0.5
            else:
                kernel_area = param.shape[2] * param.shape[3] if param.dim() == 4 else 1
                fan_in = param.shape[1] * kernel_area
                limit = (3.0 / fan_in) ** 0.5
            param.uniform_(-limit, limit, generator=generator)
    return nca
# ---------------------------------------------------------------------------
# Initial-state samplers
# ---------------------------------------------------------------------------
def _sample_initial_state(
    cfg: NCAConfig,
    batch: int,
    generator: torch.Generator,
) -> torch.Tensor:
    """Return a ``(B, H, W)`` integer initial state with values in [0, n_states).

    Every random draw goes through ``generator`` (and is created on
    ``generator.device``), so the result is reproducible from the
    generator's seed alone and independent of the global RNG.

    Modes (``cfg.initial_state``):
        "uniform": each cell drawn i.i.d. uniformly from the alphabet.
        "gaussian_categorical": per batch element, draw N(0, 1) logits over
            the alphabet, sample ONE state from the resulting categorical,
            and fill the whole grid with it.

    Raises:
        ValueError: if ``cfg.initial_state`` is not one of the modes above.
    """
    if cfg.initial_state == "uniform":
        return torch.randint(
            0, cfg.n_states,
            (batch, cfg.grid_size, cfg.grid_size),
            generator=generator,
            device=generator.device,
        )
    if cfg.initial_state == "gaussian_categorical":
        logits = torch.randn(
            (batch, cfg.n_states),
            generator=generator,
            device=generator.device,
        )
        # torch.distributions.Categorical.sample() cannot take a generator
        # and would fall back to the global RNG; multinomial over the
        # softmax is the same distribution and honours ``generator``.
        probs = torch.softmax(logits, dim=-1)
        cats = torch.multinomial(probs, num_samples=1, generator=generator).squeeze(-1)  # (B,)
        # Broadcast the single sampled state across the grid; clone() so the
        # result owns its memory instead of being an expanded view.
        return cats.view(batch, 1, 1).expand(batch, cfg.grid_size, cfg.grid_size).clone()
    raise ValueError(f"unknown initial_state {cfg.initial_state!r}")
# ---------------------------------------------------------------------------
# Rollout (single + batched)
# ---------------------------------------------------------------------------
@torch.no_grad()
def rollout(
    cfg: NCAConfig,
    n_steps: int,
    generator: torch.Generator,
    device: torch.device | str = "cpu",
) -> torch.Tensor:
    """Roll out one trajectory.

    Thin wrapper: delegates to ``rollout_batch`` with ``batch=1`` and
    strips the leading batch axis, so the result has one entry per kept
    frame along its first dimension.
    """
    return rollout_batch(
        cfg,
        n_steps=n_steps,
        batch=1,
        generator=generator,
        device=device,
    )[0]
@torch.no_grad()
def rollout_batch(
    cfg: NCAConfig,
    n_steps: int,
    batch: int,
    generator: torch.Generator,
    device: torch.device | str = "cpu",
) -> torch.Tensor:
    """Batched rollout: ``batch`` independent rules advanced side by side.

    Each batch element gets its own freshly initialised rule. Rule weights
    and the initial grid are drawn from ``generator``; the per-step
    sampling is invoked without a generator and therefore uses PyTorch's
    global RNG.

    Frame selection: with step 0 denoting the initial grid, the kept
    frames are the states at steps ``start_step + k·dT`` for
    ``k = 0 … n_steps-1``.

    Args:
        cfg: rollout configuration (``dT``, ``start_step``, ``n_states``, …).
        n_steps: number of frames to return per trajectory (>= 1).
        batch: number of independent rules / trajectories (>= 1).
        generator: RNG for rule weights and the initial state.
        device: device the rules are stepped on; frames are returned on CPU.

    Returns:
        ``(batch, n_steps, H, W)`` uint8 tensor with values in [0, n_states).

    Raises:
        ValueError: if ``n_steps`` or ``batch`` is < 1, ``cfg.dT`` is < 1,
            ``cfg.start_step`` is negative, or ``cfg.n_states`` exceeds 256
            (state ids would not fit the uint8 frames).
    """
    if n_steps < 1:
        raise ValueError(f"n_steps must be >= 1, got {n_steps}")
    if batch < 1:
        raise ValueError(f"batch must be >= 1, got {batch}")
    # Validate before doing any work: dT == 0 would make range() raise an
    # opaque error, and a negative dT/start_step would silently yield
    # wrong frames.
    if cfg.dT < 1:
        raise ValueError(f"cfg.dT must be >= 1, got {cfg.dT}")
    if cfg.start_step < 0:
        raise ValueError(f"cfg.start_step must be >= 0, got {cfg.start_step}")
    if cfg.n_states > 256:
        raise ValueError(
            f"cfg.n_states must be <= 256 to fit uint8 frames, got {cfg.n_states}"
        )

    # One module per batch entry, stepped on its own slice. Simpler than a
    # grouped conv that fuses the rules, at some cost in wall time.
    rules = [_init_random_rule(cfg, generator).to(device) for _ in range(batch)]
    init_int = _sample_initial_state(cfg, batch, generator).to(device)
    onehot = F.one_hot(init_int, cfg.n_states).permute(0, 3, 1, 2).float()

    total_steps = cfg.start_step + n_steps * cfg.dT
    keep_indices = set(range(cfg.start_step, total_steps, cfg.dT))

    kept: list[torch.Tensor] = []
    # Step 0 is the initial grid itself; it is a kept frame only when there
    # is no burn-in.
    if cfg.start_step == 0:
        kept.append(init_int.to(torch.uint8).cpu())

    # NOTE: the loop runs through ``total_steps`` even though the last kept
    # frame is at ``total_steps - dT``; the step count is left as-is so the
    # amount of randomness consumed per rollout does not change.
    for t in range(1, total_steps + 1):
        onehot = torch.cat(
            [rule.step(onehot[b : b + 1]) for b, rule in enumerate(rules)],
            dim=0,
        )
        if t in keep_indices:
            kept.append(onehot.argmax(dim=1).to(torch.uint8).cpu())

    # With dT >= 1 and start_step >= 0 exactly n_steps frames were collected,
    # in time order. Stack along time → (T, B, H, W), then move batch first.
    stacked = torch.stack(kept, dim=0)
    return stacked.permute(1, 0, 2, 3).contiguous()  # (B, T, H, W)
# ---------------------------------------------------------------------------
# Complexity gates
# ---------------------------------------------------------------------------
def gzip_compressibility(states: torch.Tensor, compresslevel: int = 9) -> float:
    """Per-trajectory gzip ratio in percent: ``compressed / raw * 100``.

    The tensor is cast to uint8 and serialised in row-major order before
    compression. Lower values mean the trajectory is more compressible.
    Because the gzip container has a fixed overhead, very short inputs can
    score above 100.

    Args:
        states: tensor of state ids on any device; it is copied to host
            memory for serialisation.
        compresslevel: gzip level passed to ``gzip.compress`` (default 9).

    Returns:
        The ratio in percent, or ``0.0`` for an empty tensor.
    """
    # .cpu() is a no-op for host tensors and makes device tensors usable
    # with .numpy(), which only works on CPU memory.
    raw = states.contiguous().to(torch.uint8).cpu().numpy().tobytes()
    if not raw:
        return 0.0
    compressed = gzip.compress(raw, compresslevel=compresslevel)
    return len(compressed) / len(raw) * 100.0
def rle_compressibility(states: torch.Tensor) -> float:
    """Run-length compressibility in percent — a single-pass gzip stand-in.

    The trajectory is cast to uint8 and flattened row-major; a new run
    begins wherever a symbol differs from its predecessor. Each run is
    charged two bytes (symbol + length), giving

        ratio = 2 · num_runs / num_cells · 100

    Run lengths are NOT capped in this count: one arbitrarily long run
    still costs two bytes. The value therefore ranges from close to 0 (a
    constant trajectory) up to 200 (no two neighbours equal). Thresholds
    tuned for the gzip metric do not carry over unchanged.

    Returns ``0.0`` for an empty tensor.
    """
    symbols = states.contiguous().to(torch.uint8).flatten()
    total = symbols.numel()
    if total == 0:
        return 0.0
    # Boundaries between runs = positions whose symbol differs from the
    # previous one; the first cell always opens a run.
    boundaries = torch.count_nonzero(symbols[1:] != symbols[:-1])
    num_runs = int(boundaries.item()) + 1
    return (2 * num_runs) / total * 100.0
def trajectory_complexity(states: torch.Tensor, cfg: NCAConfig) -> float:
    """Score ``states`` with the metric named by ``cfg.complexity_metric``.

    "gzip" → ``gzip_compressibility`` at ``cfg.gzip_compresslevel``;
    "rle"  → ``rle_compressibility``.
    Any other name raises ``ValueError``.
    """
    metric = cfg.complexity_metric
    if metric == "gzip":
        level = cfg.gzip_compresslevel
        return gzip_compressibility(states, compresslevel=level)
    if metric == "rle":
        return rle_compressibility(states)
    raise ValueError(f"unknown complexity_metric {cfg.complexity_metric!r}")
# ---------------------------------------------------------------------------
# Tokenizers
# ---------------------------------------------------------------------------
# Deterministic integer hash used only by the `hash_bucket` tokenizer. It
# is defined here, in pure tensor arithmetic, so bucket ids never depend
# on Python's per-process hash() seed. The two constants are the standard
# 32-bit FNV offset basis and FNV prime; the mixing below follows the
# FNV-1a pattern (xor a byte, then multiply) but is carried out in int64
# without reducing to 32 bits, so it is FNV-1a *style* rather than
# bit-compatible with canonical FNV-1a.
_HASH_PRIME_A = 2166136261  # initial accumulator value (FNV offset basis)
_HASH_PRIME_B = 16777619    # per-byte multiplier (FNV prime)


def _hash_to_buckets(ids: torch.Tensor, n_buckets: int) -> torch.Tensor:
    """Deterministically map integer ids into ``[0, n_buckets)``.

    The four low-order bytes of each id are folded, least-significant
    first, into an int64 accumulator via xor-then-multiply. No explicit
    32-bit mask is applied to the accumulator between rounds. The absolute
    value of the final accumulator is reduced modulo ``n_buckets`` and
    cast back to ``ids.dtype``.

    Bits of ``ids`` above the low 32 do not influence the result.
    """
    acc = torch.full_like(ids, _HASH_PRIME_A, dtype=torch.int64)
    for shift in (0, 8, 16, 24):
        octet = (ids >> shift) & 0xFF
        acc = (acc ^ octet) * _HASH_PRIME_B
    buckets = acc.abs() % n_buckets
    return buckets.to(ids.dtype)
def tokenize_trajectory(states: torch.Tensor, cfg: NCAConfig) -> list[int]:
    """Serialize a ``(T, H, W)`` trajectory into a flat token list.

    Layout per timestep: ``grid_start p₁ p₂ … p_K grid_end`` where the
    patches are visited row-major and ``grid_start`` / ``grid_end`` are
    ``cfg.grid_start_id`` / ``cfg.grid_end_id``. Each ``p_j`` encodes one
    ``patch_size × patch_size`` patch:

    * tokenizer="patch": the patch's cells, read row-major, are the digits
      of a base-``n_states`` number (first cell most significant).
    * tokenizer="hash_bucket": that id is then hashed into
      ``[0, cfg.hash_buckets)``.

    An empty trajectory (``T == 0``) yields an empty list.

    Raises:
        ValueError: if ``states`` is not 3-D, the grid is not divisible by
            the patch size, a state value lies outside ``[0, n_states)``,
            or ``cfg.tokenizer`` is unknown.
    """
    if states.dim() != 3:
        raise ValueError(f"expected (T, H, W) states, got shape {tuple(states.shape)}")
    T, H, W = states.shape
    ps = cfg.patch_size
    if H % ps or W % ps:
        raise ValueError(f"grid {H}×{W} not divisible by patch_size={ps}")
    n = cfg.n_states
    states_long = states.to(torch.long)
    if states_long.numel() and (
        states_long.min().item() < 0 or states_long.max().item() >= n
    ):
        raise ValueError(f"state values must lie in [0, {n})")

    patch_rows, patch_cols = H // ps, W // ps
    cells = ps * ps
    # Split each spatial axis into (patch index, offset within patch), bring
    # the two patch indices together, then flatten. Target sizes are spelled
    # out (no -1) so a zero-length trajectory reshapes unambiguously.
    patches = (
        states_long.reshape(T, patch_rows, ps, patch_cols, ps)
        .permute(0, 1, 3, 2, 4)
        .reshape(T, patch_rows * patch_cols, cells)
    )
    # Positional weights n^(cells-1) … n^0, on the same device as the states.
    weights = torch.tensor(
        [n ** (cells - 1 - k) for k in range(cells)],
        dtype=torch.long,
        device=states_long.device,
    )
    ids = (patches * weights).sum(dim=-1)  # (T, num_patches)

    if cfg.tokenizer == "hash_bucket":
        ids = _hash_to_buckets(ids, cfg.hash_buckets)
    elif cfg.tokenizer != "patch":
        raise ValueError(f"unknown tokenizer {cfg.tokenizer!r}")

    g_start = cfg.grid_start_id
    g_end = cfg.grid_end_id
    out: list[int] = []
    for row in ids.tolist():
        out.append(g_start)
        out.extend(row)
        out.append(g_end)
    return out
# ---------------------------------------------------------------------------
# End-to-end sampler
# ---------------------------------------------------------------------------
@torch.no_grad()
def sample_trajectory_tokens(
    cfg: NCAConfig,
    n_steps: int,
    generator: torch.Generator,
    gzip_min: float = 50.0,
    gzip_max: float = 100.0,
    max_attempts: int = 64,
    device: torch.device | str = "cpu",
    batch: int = 1,
) -> tuple[list[int], float] | None:
    """Rejection-sample one tokenized trajectory inside a complexity band.

    Up to ``max_attempts`` times, roll out ``batch`` candidate rules and
    scan them in order; the first candidate whose complexity score lies in
    ``[gzip_min, gzip_max]`` (inclusive) is tokenized and returned with its
    score. The score comes from ``trajectory_complexity``, so despite their
    names the two bounds are in the units of ``cfg.complexity_metric``.

    Returns:
        ``(tokens, score)`` for the first accepted candidate, or ``None``
        if every attempt was exhausted without an acceptance.
    """
    for _attempt in range(max_attempts):
        candidates = rollout_batch(
            cfg,
            n_steps=n_steps,
            batch=batch,
            generator=generator,
            device=device,
        )
        for index in range(batch):
            trajectory = candidates[index]
            score = trajectory_complexity(trajectory, cfg)
            # Chained comparison kept on purpose: a NaN score is rejected.
            if not (gzip_min <= score <= gzip_max):
                continue
            tokens = tokenize_trajectory(trajectory, cfg)
            return tokens, score
    return None