| """Random Discrete Neural Cellular Automaton — synthetic data generator. |
| |
| Implements the data-generation procedure described in |
| |
| Lee, Han, Kumar, Agrawal. "Training Language Models via Neural |
| Cellular Automata." arXiv:2603.10055. |
| |
| with parameter choices reconciled against the reference Flax/JAX |
| implementation at https://github.com/danihyunlee/nca-pre-pretraining |
| (``utils/nca.py``). Where paper text and reference code disagree, we |
| default to the **reference code's behavior** — except ``initial_state``, |
| which defaults to the paper text — and expose each divergence as an |
| explicit ``NCAConfig`` field. |
| |
| The module also adds *sparse / approximate* options that don't appear |
| in the reference but match feather's design philosophy (sparse, fast, |
| topology-respecting). These options keep the data path runnable on a |
| laptop CPU at non-trivial token rates and integrate cleanly with the |
| existing SDR / HTM / Engram subsystems: |
| |
| • ``tokenizer = "hash_bucket"``: hash patch ids into a smaller bucket |
| pool (vocab << n^(patch²)). Exact-id is replaced with feature |
| hashing; collisions are absorbed by feather's downstream |
| representation, much like word2vec hash embeddings. |
| • ``complexity_metric = "rle"``: replace gzip with a vectorized |
| run-length compressibility proxy. RLE cannot exploit repeated |
| multi-symbol patterns and its ratio exceeds 100% on noisy data, so |
| RLE and gzip ratios live on different scales and RLE band thresholds |
| must be calibrated separately. It avoids a per-trajectory gzip call |
| and is reported to run ~50× faster. |
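| • Batched rollout: ``rollout_batch`` rolls out ``B`` independent rules |
| per call (one rule per batch element). Each rule is still stepped |
| separately, so this amortizes per-call setup rather than fusing the |
| rules into a single kernel launch. |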
| |
| Reconciled recipe (defaults): |
| |
| • 2D grid with periodic (toroidal) boundaries; 12×12. |
| • Discrete state alphabet of n=10 symbols; each cell is an n-d one-hot. |
| • Update rule f_θ: 3×3 conv (4 channels) → 1×1 conv (16 channels) → |
| ReLU → 1×1 conv (n logits). 1×1 convs are mathematically identical |
| to per-cell Linear, so this implementation uses Linear layers. |
| • Sampling: ``c_i^(t+1) ~ Categorical((f_θ + identity_bias·c_i)/τ)`` |
| with τ=1e-3 and ``identity_bias=0.0`` (reference defaults). |
| • Per trajectory: fresh random θ. Initial grid: uniform i.i.d. over |
| {0,…,n-1} (paper text) by default; ``initial_state="gaussian_categorical"`` |
| matches the reference code's lower-entropy init. |
| • Frame stride dT=2 (reference): rollout runs ``dT × num_examples`` |
| steps and keeps every dT-th frame. |
| • Burn-in ``start_step=0`` (configurable). |
| • Tokenization: 2×2 patches → bijective ids in ``[0, n^4)``; each |
| timestep framed by ``<grid>`` / ``</grid>`` delimiters. |
| • Complexity gate: gzip ratio with ``compresslevel=9``. Paper-default |
| band: r > 50% for OpenWebText/OpenWebMath; 30–40% for CodeParrot. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import gzip |
| from dataclasses import dataclass |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
| |
| |
| |
|
|
@dataclass(frozen=True)
class NCAConfig:
    """Hyperparameters for the random discrete NCA + tokenizer.

    Defaults follow the reference repo, except ``initial_state``, which
    defaults to the paper text's uniform init. Sweepable axes:

    ``rule_type``          "mlp"          paper-faithful Linear-MLP cell rule
                           "rbf_switch"   Kolmogorov-Arnold radial-switch rule
    ``initial_state``      "uniform"      paper-text default (uniform discrete)
                           "gaussian_categorical"
                                          reference code's lower-entropy init
    ``tokenizer``          "patch"        bijective vocab of n^(patch_size²) ids
                           "hash_bucket"  feature-hashed vocab of size
                                          ``hash_buckets`` (sparse / faster)
    ``complexity_metric``  "gzip"         paper-faithful (compresslevel=9)
                           "rle"          run-length compressibility proxy
                                          (own scale; calibrate bands separately)

    Numeric fields are validated at construction and raise ``ValueError``
    when out of range; string-valued choices are validated where used.
    """

    # --- cellular automaton ---
    n_states: int = 10  # alphabet size; rollouts are stored as uint8, so <= 256
    grid_size: int = 12  # side length of the square toroidal grid
    conv_channels: int = 4  # output channels of the 3×3 perception conv
    mlp_hidden: int = 16  # hidden width of the "mlp" cell rule
    tau: float = 1e-3  # sampling temperature; logits are divided by tau
    identity_bias: float = 0.0  # added to the current state's logit before sampling
    initial_state: str = "uniform"

    # --- rollout schedule ---
    dT: int = 2  # keep every dT-th simulation step
    start_step: int = 0  # first kept simulation step (burn-in)

    # --- tokenizer ---
    tokenizer: str = "patch"
    patch_size: int = 2
    hash_buckets: int = 1024

    # --- cell rule ---
    rule_type: str = "mlp"
    rbf_n_centers: int = 8
    rbf_sigma: float = 0.5

    # --- complexity gate ---
    complexity_metric: str = "gzip"
    gzip_compresslevel: int = 9

    def __post_init__(self) -> None:
        """Reject numeric settings that would fail late or corrupt data."""
        if not 1 <= self.n_states <= 256:
            raise ValueError(
                f"n_states must lie in [1, 256] (states are stored as uint8), "
                f"got {self.n_states}"
            )
        if self.grid_size < 1:
            raise ValueError(f"grid_size must be >= 1, got {self.grid_size}")
        if self.patch_size < 1:
            raise ValueError(f"patch_size must be >= 1, got {self.patch_size}")
        if self.grid_size % self.patch_size:
            raise ValueError(
                f"grid_size={self.grid_size} not divisible by "
                f"patch_size={self.patch_size}"
            )
        # `not x > 0` also rejects NaN.
        if not self.tau > 0:
            raise ValueError(f"tau must be > 0, got {self.tau}")
        if self.dT < 1:
            raise ValueError(f"dT must be >= 1, got {self.dT}")
        if self.start_step < 0:
            raise ValueError(f"start_step must be >= 0, got {self.start_step}")
        if self.hash_buckets < 1:
            raise ValueError(f"hash_buckets must be >= 1, got {self.hash_buckets}")

    @property
    def patch_vocab(self) -> int:
        """Distinct patch ids (before optional bucket hashing) = n^(patch²)."""
        return self.n_states ** (self.patch_size * self.patch_size)

    @property
    def vocab_size(self) -> int:
        """Effective vocabulary inclusive of the two grid delimiters."""
        if self.tokenizer == "patch":
            return self.patch_vocab + 2
        if self.tokenizer == "hash_bucket":
            return self.hash_buckets + 2
        raise ValueError(f"unknown tokenizer {self.tokenizer!r}")

    @property
    def grid_start_id(self) -> int:
        """Token id of the ``<grid>`` delimiter (second-to-last id)."""
        return self.vocab_size - 2

    @property
    def grid_end_id(self) -> int:
        """Token id of the ``</grid>`` delimiter (last id)."""
        return self.vocab_size - 1
|
|
|
|
| |
| |
| |
|
|
class _MLPCellRule(nn.Module):
    """Per-cell MLP rule: ``conv_channels → mlp_hidden → ReLU → n_states``.

    On channel-last features a ``Linear`` layer computes exactly what a
    1×1 convolution computes, so this is the reference architecture's
    "1×1 conv → ReLU → 1×1 conv" head. Accepts any leading shape
    ``(..., conv_channels)`` and returns ``(..., n_states)`` logits.
    """

    def __init__(self, cfg: NCAConfig) -> None:
        super().__init__()
        # Built in this order so parameter creation matches the reference layout.
        to_hidden = nn.Linear(cfg.conv_channels, cfg.mlp_hidden, bias=True)
        to_logits = nn.Linear(cfg.mlp_hidden, cfg.n_states, bias=True)
        self.net = nn.Sequential(to_hidden, nn.ReLU(), to_logits)

    def forward(self, feat: torch.Tensor) -> torch.Tensor:
        logits = self.net(feat)
        return logits
|
|
|
|
class _RBFSwitchCellRule(nn.Module):
    """Radial-switch / KAN-style cell rule.

    Each input channel is expanded into K Gaussian RBF responses on a
    fixed center grid spanning [-1, 1]; all C·K responses are then
    linearly combined into the n_states logits. The output node is a
    plain linear sum (no outer nonlinearity):

        logits[o] = Σ_c Σ_k α[c,k,o] · exp(-((x[c]-μ_k)² / 2σ²)) + bias[o]

    Total params ≈ ``C·K·n + n``. At C=4, K=8, n=10 that's 330 — same
    order of magnitude as the MLP path's 250.

    Parameters are initialized in ``__init__`` via ``reset_parameters``,
    so a freshly constructed module never exposes uninitialized memory.
    ``forward`` accepts any leading shape ``(..., C)`` and returns
    ``(..., n_states)``.
    """

    def __init__(self, cfg: NCAConfig) -> None:
        super().__init__()
        self.cfg = cfg
        centers = torch.linspace(-1.0, 1.0, cfg.rbf_n_centers)
        self.register_buffer("centers", centers, persistent=False)
        self.register_buffer(
            "inv_two_sigma_sq",
            torch.tensor(1.0 / (2.0 * cfg.rbf_sigma ** 2)),
            persistent=False,
        )
        self.alpha = nn.Parameter(
            torch.empty(cfg.conv_channels, cfg.rbf_n_centers, cfg.n_states)
        )
        self.bias = nn.Parameter(torch.empty(cfg.n_states))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Fill ``alpha``/``bias`` with fan-in scaled uniform noise.

        ``alpha`` is laid out ``(C, K, n_states)`` and both leading dims
        are summed over in ``forward``, so ``fan_in = C·K``.
        """
        fan_in = max(self.alpha.shape[0] * self.alpha.shape[1], 1)
        weight_bound = (3.0 / fan_in) ** 0.5
        bias_bound = (1.0 / fan_in) ** 0.5
        nn.init.uniform_(self.alpha, -weight_bound, weight_bound)
        nn.init.uniform_(self.bias, -bias_bound, bias_bound)

    def forward(self, feat: torch.Tensor) -> torch.Tensor:
        # (..., C) -> (..., C, K) distances to every RBF center.
        diff = feat.unsqueeze(-1) - self.centers
        rbf = torch.exp(-(diff * diff) * self.inv_two_sigma_sq)
        return torch.einsum("...ck,cko->...o", rbf, self.alpha) + self.bias
|
|
|
|
| |
| |
| |
|
|
class RandomDiscreteNCA(nn.Module):
    """Single discrete NCA rule f_θ — a data generator, not a trainable component.

    Perception is a 3×3 ``Conv2d`` over the one-hot grid with
    ``padding_mode='circular'`` (periodic boundaries); its per-cell output
    feeds the cell rule selected by ``cfg.rule_type``. All parameters have
    ``requires_grad`` disabled and ``step`` runs under ``torch.no_grad``.

    The weights set at construction are not meant for data generation;
    build rules through ``_init_random_rule``, which re-draws every
    parameter from an explicit ``torch.Generator``.

    Raises:
        ValueError: if ``cfg.rule_type`` is not ``"mlp"`` or ``"rbf_switch"``.
    """

    def __init__(self, cfg: NCAConfig | None = None) -> None:
        super().__init__()
        self.cfg = cfg if cfg is not None else NCAConfig()
        self.perceive = nn.Conv2d(
            self.cfg.n_states,
            self.cfg.conv_channels,
            kernel_size=3,
            padding=1,
            padding_mode="circular",
            bias=True,
        )
        if self.cfg.rule_type == "mlp":
            self.cell_rule: nn.Module = _MLPCellRule(self.cfg)
        elif self.cfg.rule_type == "rbf_switch":
            self.cell_rule = _RBFSwitchCellRule(self.cfg)
        else:
            raise ValueError(
                f"unknown rule_type {self.cfg.rule_type!r}; "
                "expected 'mlp' or 'rbf_switch'"
            )
        for p in self.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def step(
        self,
        onehot: torch.Tensor,
        generator: torch.Generator | None = None,
    ) -> torch.Tensor:
        """Advance one NCA timestep.

        Args:
            onehot: ``(B, n_states, H, W)`` grid with values in {0, 1}.
            generator: optional RNG for the per-cell categorical draw; it
                should live on the same device as ``onehot``. ``None`` uses
                the device's default RNG.

        Returns:
            The next-step one-hot grid, same shape and dtype as ``onehot``.
        """
        feat = self.perceive(onehot).permute(0, 2, 3, 1)  # (B, H, W, conv_channels)
        logits = self.cell_rule(feat)  # (B, H, W, n_states)
        if self.cfg.identity_bias != 0.0:
            # Add identity_bias to the logit of each cell's current state.
            logits = logits + onehot.permute(0, 2, 3, 1) * self.cfg.identity_bias
        # A small tau sharpens the softmax toward argmax sampling.
        probs = F.softmax(logits / self.cfg.tau, dim=-1)
        sampled = torch.multinomial(
            probs.reshape(-1, self.cfg.n_states), num_samples=1, generator=generator
        )
        sampled = sampled.reshape(probs.shape[:-1])  # (B, H, W)
        return F.one_hot(sampled, self.cfg.n_states).permute(0, 3, 1, 2).to(onehot.dtype)
|
|
|
|
def _init_random_rule(cfg: NCAConfig, generator: torch.Generator) -> RandomDiscreteNCA:
    """Build an NCA whose parameters are all re-drawn from ``generator``.

    Every parameter is overwritten in ``nca.parameters()`` order, so the
    resulting rule is a pure function of the generator's state:

    * weights ``~ U(-√(3/fan_in), √(3/fan_in))``, i.e. variance
      ``1/fan_in``. That is the same variance as the LeCun-normal init of
      the reference Flax code; only the distribution shape differs.
    * 1-D biases ``~ U(-√(1/len), √(1/len))``, where ``len`` is the bias
      length.

    ``fan_in`` counts the dimensions each layout contracts over:

    * ``Linear`` ``(out, in)`` → ``in``
    * ``Conv2d`` ``(out, in, kh, kw)`` → ``in·kh·kw``
    * the RBF rule's ``alpha`` ``(channels, centers, out)`` →
      ``channels·centers``

    Rules are still gated by the complexity band afterwards, so the exact
    init scale mainly shifts the acceptance rate.
    """
    nca = RandomDiscreteNCA(cfg)
    with torch.no_grad():
        for p in nca.parameters():
            if p.dim() == 1:
                bound = (1.0 / max(p.shape[0], 1)) ** 0.5
            else:
                if p.dim() == 3:
                    # (in_channels, centers, out): both leading dims are summed.
                    fan_in = p.shape[0] * p.shape[1]
                else:
                    # Output-first layouts (Linear / Conv): everything but dim 0.
                    fan_in = p[0].numel()
                bound = (3.0 / max(fan_in, 1)) ** 0.5
            p.uniform_(-bound, bound, generator=generator)
    return nca
|
|
|
|
| |
| |
| |
|
|
def _sample_initial_state(
    cfg: NCAConfig,
    batch: int,
    generator: torch.Generator,
) -> torch.Tensor:
    """Return a ``(batch, H, W)`` int64 initial grid with values in ``[0, n_states)``.

    All randomness comes from ``generator``. Tensors are created on
    ``generator.device``.

    * ``"uniform"``: every cell is i.i.d. uniform over the ``n_states`` symbols.
    * ``"gaussian_categorical"``: for each grid, draw logits ``~ N(0, I)``
      over the symbols, then sample every cell i.i.d. from
      ``softmax(logits)``. Each grid gets its own skewed symbol
      distribution: lower entropy than uniform, but still spatially varied.

    Raises:
        ValueError: on an unknown ``cfg.initial_state``.
    """
    size = cfg.grid_size
    if cfg.initial_state == "uniform":
        return torch.randint(
            0, cfg.n_states,
            (batch, size, size),
            generator=generator,
            device=generator.device,
        )
    if cfg.initial_state == "gaussian_categorical":
        logits = torch.randn(
            (batch, cfg.n_states),
            generator=generator,
            device=generator.device,
        )
        probs = torch.softmax(logits, dim=-1)
        # Draw with the caller's generator: Categorical.sample() would use the
        # global RNG. Draw one symbol per *cell*. Broadcasting a single symbol
        # makes the grid uniform, and a translation-invariant rule with
        # periodic boundaries keeps it (near-)uniform at the default tiny tau.
        cells = torch.multinomial(probs, size * size, replacement=True, generator=generator)
        return cells.view(batch, size, size)
    raise ValueError(f"unknown initial_state {cfg.initial_state!r}")
|
|
|
|
| |
| |
| |
|
|
@torch.no_grad()
def rollout(
    cfg: NCAConfig,
    n_steps: int,
    generator: torch.Generator,
    device: torch.device | str = "cpu",
) -> torch.Tensor:
    """Roll out a single random rule; this is ``rollout_batch`` with ``batch=1``.

    Returns:
        ``(n_steps, H, W)`` uint8 tensor (the lone batch element).
    """
    batched = rollout_batch(
        cfg,
        n_steps=n_steps,
        batch=1,
        generator=generator,
        device=device,
    )
    # Drop the singleton batch axis: (1, T, H, W) -> (T, H, W).
    return batched.squeeze(0)
|
|
|
|
@torch.no_grad()
def rollout_batch(
    cfg: NCAConfig,
    n_steps: int,
    batch: int,
    generator: torch.Generator,
    device: torch.device | str = "cpu",
) -> torch.Tensor:
    """Roll out ``batch`` independent random rules, one per batch element.

    Each batch element gets its own freshly sampled rule θ and its own
    initial grid. Inside each timestep the rules are stepped one after
    another; there is no fused multi-rule kernel.

    Frame schedule: the kept frames are simulation steps
    ``start_step, start_step + dT, …, start_step + (n_steps - 1)·dT``,
    where step 0 is the initial grid. Simulation stops at the last kept step.

    Reproducibility: rule weights and initial grids are drawn from
    ``generator``. ``RandomDiscreteNCA.step`` samples from the default RNG
    of ``device``. That RNG (and the CPU default RNG) is therefore re-seeded
    from ``generator`` inside ``torch.random.fork_rng`` and restored
    afterwards. On CPU and CUDA devices, the same ``generator`` state
    gives the same trajectories, and the caller's global RNG state is left
    untouched.

    Returns:
        ``(B, n_steps, H, W)`` uint8 CPU tensor with values in ``[0, n_states)``.

    Raises:
        ValueError: if ``n_steps < 1``, ``batch < 1``, ``cfg.dT < 1``,
            ``cfg.start_step < 0`` or ``cfg.n_states > 256``.
    """
    if n_steps < 1:
        raise ValueError(f"n_steps must be >= 1, got {n_steps}")
    if batch < 1:
        raise ValueError(f"batch must be >= 1, got {batch}")
    if cfg.dT < 1:
        raise ValueError(f"cfg.dT must be >= 1, got {cfg.dT}")
    if cfg.start_step < 0:
        raise ValueError(f"cfg.start_step must be >= 0, got {cfg.start_step}")
    if cfg.n_states > 256:
        raise ValueError(
            f"n_states={cfg.n_states} does not fit the uint8 output; must be <= 256"
        )

    dev = torch.device(device)
    # Rules first, then initial grids: unchanged generator draw order.
    rules = [_init_random_rule(cfg, generator).to(dev) for _ in range(batch)]
    init_int = _sample_initial_state(cfg, batch, generator).to(dev)
    onehot = F.one_hot(init_int, cfg.n_states).permute(0, 3, 1, 2).float()

    # Seed for the per-cell sampling inside step(), drawn from `generator`
    # after the rules and grids so those are unaffected.
    step_seed = int(
        torch.randint(0, 2**62, (1,), generator=generator, device=generator.device).item()
    )
    fork_devices: list[int] = []
    if dev.type == "cuda":
        fork_devices.append(dev.index if dev.index is not None else torch.cuda.current_device())

    start, stride = cfg.start_step, cfg.dT
    last_kept = start + (n_steps - 1) * stride
    kept: list[torch.Tensor] = []
    if start == 0:
        # Step 0 is the initial grid itself.
        kept.append(init_int.to(torch.uint8).cpu())

    with torch.random.fork_rng(devices=fork_devices):
        torch.default_generator.manual_seed(step_seed)
        for idx in fork_devices:
            torch.cuda.default_generators[idx].manual_seed(step_seed)
        for t in range(1, last_kept + 1):
            # Each rule only ever sees its own batch element.
            onehot = torch.cat(
                [rule.step(onehot[b : b + 1]) for b, rule in enumerate(rules)], dim=0
            )
            if t >= start and (t - start) % stride == 0:
                kept.append(onehot.argmax(dim=1).to(torch.uint8).cpu())

    # Exactly n_steps frames of shape (B, H, W) -> (B, n_steps, H, W).
    return torch.stack(kept, dim=0).permute(1, 0, 2, 3).contiguous()
|
|
|
|
| |
| |
| |
|
|
def gzip_compressibility(states: torch.Tensor, compresslevel: int = 9) -> float:
    """Per-trajectory gzip ratio in percent: ``len(gzip(raw)) / len(raw) * 100``.

    ``states`` is cast to uint8 and serialized one byte per element in
    row-major order, so values must fit in ``[0, 255]``. Tensors on any
    device (and tensors that require grad) are accepted; the bytes come
    from a detached CPU copy.

    Default ``compresslevel=9`` matches the reference repo. A lower ratio
    means the data is more compressible, i.e. simpler / more predictable
    dynamics. gzip adds a fixed header and trailer, so very short inputs
    can score above 100.

    Returns 0.0 for an empty tensor.
    """
    raw = states.detach().to(torch.uint8).cpu().contiguous().numpy().tobytes()
    if not raw:
        return 0.0
    compressed = gzip.compress(raw, compresslevel=compresslevel)
    return len(compressed) / len(raw) * 100.0
|
|
|
|
def rle_compressibility(states: torch.Tensor) -> float:
    """Run-length compression ratio in percent: a cheap, vectorized gzip proxy.

    The trajectory is serialized row-major and split into maximal runs of
    equal symbols. Each run is encoded as ``(symbol, length)`` byte pairs
    with the length capped at 255, so a run of ``L`` cells costs
    ``ceil(L / 255)`` pairs::

        rle_size = 2 · Σ_runs ceil(L / 255) ;  raw_size = num_cells

    Lower means more compressible. The scale is not gzip's. An i.i.d.
    uniform trajectory over ``n`` symbols has about ``num_cells·(1 − 1/n)``
    runs, i.e. an RLE ratio near ``200·(1 − 1/n)`` percent. That is far
    above its gzip ratio, since such data holds only ``log2(n)`` bits per
    cell. RLE also cannot exploit repeated multi-symbol patterns. RLE
    bands therefore need their own calibration rather than a fixed factor
    of the gzip band.

    Returns 0.0 for an empty tensor.
    """
    flat = states.detach().contiguous().to(torch.uint8).flatten()
    n = flat.numel()
    if n == 0:
        return 0.0
    # Mark the first cell of every maximal run.
    is_start = torch.ones(n, dtype=torch.bool, device=flat.device)
    is_start[1:] = flat[1:] != flat[:-1]
    starts = torch.nonzero(is_start).flatten()
    ends = torch.cat([starts[1:], torch.tensor([n], device=flat.device)])
    lengths = ends - starts
    # Runs longer than 255 need several (symbol, length) pairs.
    num_pairs = int(((lengths + 254) // 255).sum().item())
    return 2 * num_pairs / n * 100.0
|
|
|
|
def trajectory_complexity(states: torch.Tensor, cfg: NCAConfig) -> float:
    """Score ``states`` with the metric named by ``cfg.complexity_metric``.

    ``"gzip"`` dispatches to ``gzip_compressibility`` at
    ``cfg.gzip_compresslevel``; ``"rle"`` dispatches to
    ``rle_compressibility``. Both return a percentage where lower means
    more compressible.

    Raises:
        ValueError: for any other metric name.
    """
    metric = cfg.complexity_metric
    if metric == "rle":
        return rle_compressibility(states)
    if metric == "gzip":
        return gzip_compressibility(states, compresslevel=cfg.gzip_compresslevel)
    raise ValueError(f"unknown complexity_metric {metric!r}")
|
|
|
|
| |
| |
| |
|
|
| |
| |
| |
# 32-bit FNV-1a parameters. Despite its name, _HASH_PRIME_A is the FNV
# offset basis (initial hash state); _HASH_PRIME_B is the 32-bit FNV prime.
_HASH_PRIME_A = 2166136261
_HASH_PRIME_B = 16777619
# FNV-1a works modulo 2**32. Masking after every multiply keeps the state in
# [0, 2**32), so each int64 product stays below 2**56 and never overflows.
_FNV32_MASK = 0xFFFFFFFF


def _hash_to_buckets(ids: torch.Tensor, n_buckets: int) -> torch.Tensor:
    """Deterministically hash integer ids into ``[0, n_buckets)`` with 32-bit FNV-1a.

    Each id is hashed over its little-endian bytes: 4 bytes when the id
    fits in 32 bits, otherwise all 8 bytes of its int64 value, so large
    patch ids are not silently truncated. The choice is made per element,
    so an id's bucket never depends on the other ids in the tensor.

    Pure torch (no Python ``hash``, no NumPy); results are reproducible
    across processes, devices and OS versions.

    Args:
        ids: integer tensor of any shape.
        n_buckets: number of buckets, ``>= 1``.

    Returns:
        Tensor with ``ids``' shape and dtype, values in ``[0, n_buckets)``.

    Raises:
        ValueError: if ``n_buckets < 1``.
    """
    if n_buckets < 1:
        raise ValueError(f"n_buckets must be >= 1, got {n_buckets}")
    keys = ids.to(torch.int64)

    h = torch.full_like(keys, _HASH_PRIME_A)
    for shift in (0, 8, 16, 24):
        h = ((h ^ ((keys >> shift) & 0xFF)) * _HASH_PRIME_B) & _FNV32_MASK
    # Continue over the upper four bytes for ids that need them.
    h_wide = h
    for shift in (32, 40, 48, 56):
        h_wide = ((h_wide ^ ((keys >> shift) & 0xFF)) * _HASH_PRIME_B) & _FNV32_MASK
    h = torch.where((keys >> 32) != 0, h_wide, h)

    return (h % n_buckets).to(ids.dtype)
|
|
|
|
| def tokenize_trajectory(states: torch.Tensor, cfg: NCAConfig) -> list[int]: |
| """Serialize a trajectory into a token list. |
| |
| Layout per timestep: ``<grid> p₁ p₂ … p_K </grid>``, row-major, |
| where each ``p_j`` is either a bijective patch id (tokenizer="patch") |
| or a feature-hashed bucket id (tokenizer="hash_bucket"). |
| """ |
| if states.dim() != 3: |
| raise ValueError(f"expected (T, H, W) states, got shape {tuple(states.shape)}") |
| T, H, W = states.shape |
| ps = cfg.patch_size |
| if H % ps or W % ps: |
| raise ValueError(f"grid {H}×{W} not divisible by patch_size={ps}") |
| n = cfg.n_states |
| states_long = states.to(torch.long) |
| if states_long.numel() and ( |
| states_long.min().item() < 0 or states_long.max().item() >= n |
| ): |
| raise ValueError(f"state values must lie in [0, {n})") |
|
|
| patches = states_long.view(T, H // ps, ps, W // ps, ps).permute(0, 1, 3, 2, 4) |
| patches = patches.reshape(T, -1, ps * ps) |
| weights = torch.tensor( |
| [n ** (ps * ps - 1 - k) for k in range(ps * ps)], |
| dtype=torch.long, |
| ) |
| ids = (patches * weights).sum(dim=-1) |
|
|
| if cfg.tokenizer == "hash_bucket": |
| ids = _hash_to_buckets(ids, cfg.hash_buckets) |
| elif cfg.tokenizer != "patch": |
| raise ValueError(f"unknown tokenizer {cfg.tokenizer!r}") |
|
|
| g_start = cfg.grid_start_id |
| g_end = cfg.grid_end_id |
| out: list[int] = [] |
| rows = ids.tolist() |
| for row in rows: |
| out.append(g_start) |
| out.extend(row) |
| out.append(g_end) |
| return out |
|
|
|
|
| |
| |
| |
|
|
@torch.no_grad()
def sample_trajectory_tokens(
    cfg: NCAConfig,
    n_steps: int,
    generator: torch.Generator,
    gzip_min: float = 50.0,
    gzip_max: float = 100.0,
    max_attempts: int = 64,
    device: torch.device | str = "cpu",
    batch: int = 1,
) -> tuple[list[int], float] | None:
    """Sample one trajectory whose complexity lies in a band, and tokenize it.

    Each attempt rolls out ``batch`` fresh random rules with
    ``rollout_batch``. It returns the first trajectory, in batch order,
    whose ``trajectory_complexity`` ratio satisfies
    ``gzip_min <= ratio <= gzip_max``. At most ``max_attempts * batch``
    trajectories are tried, and ``batch > 1`` cuts the number of
    ``rollout_batch`` calls.

    ``gzip_min`` / ``gzip_max`` are percentages in the units of
    ``cfg.complexity_metric``: gzip ratio for "gzip", RLE ratio for "rle".
    The two scales differ (RLE can exceed 100% on noisy data), so RLE
    bands need their own calibration.

    NOTE(review): the complexity metrics serialize states one byte per
    cell. With the default ``n_states=10``, each byte therefore carries at
    most log2(10)/8 ≈ 41.5% of a byte's information. For long rollouts,
    gzip ratios may stay below the default ``gzip_min=50``. Confirm the
    default band against the reference serialization.

    Returns:
        ``(tokens, ratio)`` for the accepted trajectory, or ``None`` if no
        attempt landed in the band.

    Raises:
        ValueError: if ``gzip_min > gzip_max`` (an empty band can never accept).
    """
    if gzip_min > gzip_max:
        raise ValueError(
            f"empty complexity band: gzip_min={gzip_min} > gzip_max={gzip_max}"
        )
    for _ in range(max_attempts):
        states_batch = rollout_batch(
            cfg, n_steps=n_steps, batch=batch,
            generator=generator, device=device,
        )
        for states in states_batch:
            ratio = trajectory_complexity(states, cfg)
            if gzip_min <= ratio <= gzip_max:
                return tokenize_trajectory(states, cfg), ratio
    return None
|
|