"""Random Discrete Neural Cellular Automaton — synthetic data generator.

Implements the data-generation procedure described in

    Lee, Han, Kumar, Agrawal. "Training Language Models via Neural Cellular
    Automata." arXiv:2603.10055.

with parameter choices reconciled against the reference Flax/JAX
implementation at https://github.com/danihyunlee/nca-pre-pretraining
(``utils/nca.py``). Where paper text and reference code disagree, we default
to the **reference code's behavior** but expose the divergence as an explicit
``NCAConfig`` field.

The module also adds *sparse / approximate* options that don't appear in the
reference but match feather's design philosophy (sparse, fast,
topology-respecting). These options keep the data path runnable on a laptop
CPU at non-trivial token rates and integrate cleanly with the existing
SDR / HTM / Engram subsystems:

  • ``tokenizer = "hash_bucket"``: hash patch ids into a smaller bucket pool
    (vocab << n^(patch²)). Exact-id is replaced with feature hashing;
    collisions are absorbed by feather's downstream representation, much like
    word2vec hash embeddings.
  • ``complexity_metric = "rle"``: replace gzip with a vectorized
    run-length-equivalent compressibility proxy. RLE compression is a strict
    lower bound on what gzip achieves, but the *band shape* is preserved up
    to a constant scale, so the K-complexity gate is intact while running
    ~50× faster.
  • Batched rollout: ``rollout_batch`` rolls out ``B`` independent rules per
    call (one module per rule, stepped in a Python loop), amortizing the
    per-call overhead.

Reconciled recipe (defaults):

  • 2D grid with periodic (toroidal) boundaries; 12×12.
  • Discrete state alphabet of n=10 symbols; each cell is an n-d one-hot.
  • Update rule f_θ: 3×3 conv (4 channels) → 1×1 conv (16 channels) → ReLU
    → 1×1 conv (n logits). 1×1 convs are mathematically identical to
    per-cell Linear, so this implementation uses Linear layers.
  • Sampling: ``c_i^(t+1) ~ Categorical((f_θ + identity_bias·c_i)/τ)`` with
    τ=1e-3 and ``identity_bias=0.0`` (reference defaults).
  • Per trajectory: fresh random θ. Initial grid: uniform i.i.d. over
    {0,…,n-1} (paper text) by default;
    ``initial_state="gaussian_categorical"`` matches the reference code's
    lower-entropy init.
  • Frame stride dT=2 (reference): one frame is kept every dT steps.
  • Burn-in ``start_step=0`` (configurable).
  • Tokenization: 2×2 patches → bijective ids in ``[0, n^4)``; each timestep
    is framed by a grid-start / grid-end delimiter token pair
    (``NCAConfig.grid_start_id`` / ``NCAConfig.grid_end_id``).
  • Complexity gate: gzip ratio with ``compresslevel=9``. Paper-default band:
    r > 50% for OpenWebText/OpenWebMath; 30–40% for CodeParrot.
"""

from __future__ import annotations

import gzip
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F


# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------


@dataclass(frozen=True)
class NCAConfig:
    """Hyperparameters for the random discrete NCA + tokenizer.

    Defaults reproduce the reference repo. Sweepable axes:

    ``rule_type``
        "mlp"                   paper-faithful Linear-MLP cell rule
        "rbf_switch"            Kolmogorov-Arnold radial-switch rule
    ``initial_state``
        "uniform"               paper-text default (uniform discrete)
        "gaussian_categorical"  reference code's lower-entropy init
    ``tokenizer``
        "patch"                 paper-faithful bijective n^4 vocab
        "hash_bucket"           feature-hashed vocab of size ``hash_buckets``
                                (sparse / faster)
    ``complexity_metric``
        "gzip"                  paper-faithful (compresslevel=9)
        "rle"                   RLE-compression-ratio proxy (~50× faster,
                                preserves band shape)

    String-valued fields are not validated at construction time; the
    consumers of each field raise ``ValueError`` on an unknown value.
    """

    # Substrate
    n_states: int = 10
    grid_size: int = 12
    conv_channels: int = 4
    mlp_hidden: int = 16
    tau: float = 1e-3
    identity_bias: float = 0.0
    initial_state: str = "uniform"

    # Rollout
    dT: int = 2  # frame stride: keep every dT-th step
    start_step: int = 0  # burn-in to skip transients

    # Tokenizer
    tokenizer: str = "patch"
    patch_size: int = 2
    hash_buckets: int = 1024  # only consulted when tokenizer == "hash_bucket"

    # Cell-rule choice
    rule_type: str = "mlp"
    rbf_n_centers: int = 8
    rbf_sigma: float = 0.5

    # Complexity gate
    complexity_metric: str = "gzip"
    gzip_compresslevel: int = 9

    @property
    def patch_vocab(self) -> int:
        """Distinct patch ids (before optional bucket hashing) = n^(patch²)."""
        return self.n_states ** (self.patch_size * self.patch_size)

    @property
    def vocab_size(self) -> int:
        """Effective vocabulary inclusive of the two grid delimiters.

        Raises:
            ValueError: if ``tokenizer`` is neither "patch" nor "hash_bucket".
        """
        if self.tokenizer == "patch":
            return self.patch_vocab + 2
        if self.tokenizer == "hash_bucket":
            return self.hash_buckets + 2
        raise ValueError(f"unknown tokenizer {self.tokenizer!r}")

    @property
    def grid_start_id(self) -> int:
        """Token id of the delimiter emitted before each timestep's patches."""
        return self.vocab_size - 2

    @property
    def grid_end_id(self) -> int:
        """Token id of the delimiter emitted after each timestep's patches."""
        return self.vocab_size - 1


# ---------------------------------------------------------------------------
# Cell rules
# ---------------------------------------------------------------------------


class _MLPCellRule(nn.Module):
    """Reference-faithful cell-wise MLP: 4 → 16 (ReLU) → n_states.

    Identical to a stack of two 1×1 convolutions sandwiching ReLU when
    applied per-cell. Implemented as Linear so the einsum tags read cleanly
    in the perceive→rule pipeline.
    """

    def __init__(self, cfg: NCAConfig) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(cfg.conv_channels, cfg.mlp_hidden, bias=True),
            nn.ReLU(),
            nn.Linear(cfg.mlp_hidden, cfg.n_states, bias=True),
        )

    def forward(self, feat: torch.Tensor) -> torch.Tensor:
        """Map ``(..., conv_channels)`` features to ``(..., n_states)`` logits."""
        return self.net(feat)


class _RBFSwitchCellRule(nn.Module):
    """Radial-switch / KAN-style cell rule.

    Implements the Kolmogorov-Arnold representation directly: each input
    channel feeds a sum of K Gaussian RBFs on a fixed [-1, 1] center grid;
    the K-vector is then linearly combined into the n_states output. No
    nonlinearity at the output node — the K-A theorem says we don't need one.

        logits[o] = Σ_c Σ_k α[c,k,o] · exp(-((x[c]-μ_k)² / 2σ²)) + bias[o]

    Total params ≈ ``C·K·n + n``. At C=4, K=8, n=10 that's 330 — same order
    of magnitude as the MLP path's 250.
    """

    def __init__(self, cfg: NCAConfig) -> None:
        super().__init__()
        self.cfg = cfg
        centers = torch.linspace(-1.0, 1.0, cfg.rbf_n_centers)
        self.register_buffer("centers", centers, persistent=False)
        self.register_buffer(
            "inv_two_sigma_sq",
            torch.tensor(1.0 / (2.0 * cfg.rbf_sigma ** 2)),
            persistent=False,
        )
        self.alpha = nn.Parameter(
            torch.empty(cfg.conv_channels, cfg.rbf_n_centers, cfg.n_states)
        )
        self.bias = nn.Parameter(torch.empty(cfg.n_states))
        # ``torch.empty`` leaves arbitrary memory behind. Initialise here so a
        # directly constructed rule is usable; the bounds are the ones
        # ``_init_random_rule`` applies to these two tensor shapes.
        alpha_bound = (3.0 / max(cfg.rbf_n_centers, 1)) ** 0.5
        bias_bound = (1.0 / max(cfg.n_states, 1)) ** 0.5
        with torch.no_grad():
            self.alpha.uniform_(-alpha_bound, alpha_bound)
            self.bias.uniform_(-bias_bound, bias_bound)

    def forward(self, feat: torch.Tensor) -> torch.Tensor:
        """Map ``(B, H, W, C)`` features to ``(B, H, W, n_states)`` logits."""
        diff = feat.unsqueeze(-1) - self.centers  # (B, H, W, C, K)
        rbf = torch.exp(-(diff * diff) * self.inv_two_sigma_sq)
        return torch.einsum("bhwck,cko->bhwo", rbf, self.alpha) + self.bias


# ---------------------------------------------------------------------------
# NCA module
# ---------------------------------------------------------------------------


class RandomDiscreteNCA(nn.Module):
    """Single discrete NCA rule f_θ. Construct → freshly random weights.

    The conv uses ``padding_mode='circular'`` for periodic boundaries. All
    math is no_grad — this module is a *data generator*, not a trainable
    component.

    Raises:
        ValueError: if ``cfg.rule_type`` is not "mlp" or "rbf_switch".
    """

    def __init__(self, cfg: NCAConfig | None = None) -> None:
        super().__init__()
        self.cfg = cfg or NCAConfig()
        self.perceive = nn.Conv2d(
            self.cfg.n_states,
            self.cfg.conv_channels,
            kernel_size=3,
            padding=1,
            padding_mode="circular",
            bias=True,
        )
        if self.cfg.rule_type == "mlp":
            self.cell_rule: nn.Module = _MLPCellRule(self.cfg)
        elif self.cfg.rule_type == "rbf_switch":
            self.cell_rule = _RBFSwitchCellRule(self.cfg)
        else:
            raise ValueError(
                f"unknown rule_type {self.cfg.rule_type!r}; "
                "expected 'mlp' or 'rbf_switch'"
            )
        for p in self.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def step(
        self,
        onehot: torch.Tensor,
        generator: torch.Generator | None = None,
    ) -> torch.Tensor:
        """Advance one NCA timestep.

        Args:
            onehot: ``(B, n_states, H, W)``, values in {0, 1}.
            generator: optional RNG driving the categorical sampling. When
                omitted, or when it lives on a different device type than
                ``onehot``, the global RNG of the tensor's device is used.

        Returns:
            The next-step one-hot (still {0, 1}), same shape and dtype as
            ``onehot``.
        """
        feat = self.perceive(onehot)  # (B, conv_ch, H, W)
        feat = feat.permute(0, 2, 3, 1)  # (B, H, W, conv_ch)
        logits = self.cell_rule(feat)  # (B, H, W, n_states)
        if self.cfg.identity_bias != 0.0:
            # add identity_bias·c_i; matches reference code's
            # `(logits + state_oh*ib)/τ`
            current = onehot.permute(0, 2, 3, 1)
            logits = logits + current * self.cfg.identity_bias
        probs = F.softmax(logits / self.cfg.tau, dim=-1)
        flat = probs.reshape(-1, self.cfg.n_states)
        # torch.multinomial requires the generator to be on the same device
        # as its input; otherwise fall back to that device's global RNG.
        gen = generator
        if gen is not None and gen.device.type != flat.device.type:
            gen = None
        sampled = torch.multinomial(flat, num_samples=1, generator=gen).squeeze(-1)
        sampled = sampled.reshape(probs.shape[:-1])  # (B, H, W)
        next_onehot = F.one_hot(sampled, self.cfg.n_states).permute(0, 3, 1, 2)
        return next_onehot.to(onehot.dtype)


def _init_random_rule(cfg: NCAConfig, generator: torch.Generator) -> RandomDiscreteNCA:
    """Build an NCA with freshly sampled weights. Kaiming-uniform fan-in.

    The reference repo uses Flax LeCun-normal; this implementation uses
    PyTorch's Kaiming-uniform default. Magnitudes differ by ~1.7×, but both
    schemes are zero-mean and the gzip rejection sampler discards rules whose
    dynamics fall outside the requested complexity band, so the init scheme
    is largely a smoothing factor on the acceptance rate.

    Every parameter is overwritten from ``generator``, so the resulting
    weights depend only on the generator state. The module is created on the
    default device; ``generator`` must be usable with tensors there.
    """
    nca = RandomDiscreteNCA(cfg)
    for p in nca.parameters():
        if p.dim() > 1:
            # fan-in = dim 1 (× kernel area for 4-D conv weights). For the
            # 3-D RBF ``alpha`` this evaluates to the number of RBF centers.
            fan_in = p.shape[1] * (p.shape[2] * p.shape[3] if p.dim() == 4 else 1)
            bound = (3.0 / fan_in) ** 0.5
        else:
            bound = (1.0 / max(p.shape[0], 1)) ** 0.5
        with torch.no_grad():
            p.uniform_(-bound, bound, generator=generator)
    return nca


# ---------------------------------------------------------------------------
# Initial-state samplers
# ---------------------------------------------------------------------------


def _sample_initial_state(
    cfg: NCAConfig,
    batch: int,
    generator: torch.Generator,
) -> torch.Tensor:
    """Return ``(B, H, W)`` integer initial state in [0, n_states).

    All randomness is drawn from ``generator``; the result lives on
    ``generator.device``.

    Raises:
        ValueError: if ``cfg.initial_state`` is not a known sampler name.
    """
    if cfg.initial_state == "uniform":
        return torch.randint(
            0,
            cfg.n_states,
            (batch, cfg.grid_size, cfg.grid_size),
            generator=generator,
            device=generator.device,
        )
    if cfg.initial_state == "gaussian_categorical":
        # Reference repo's init: sample N(0,1) logits, then one categorical
        # draw per batch element (n_groups=1) per the reference code.
        logits = torch.randn(
            (batch, cfg.n_states),
            generator=generator,
            device=generator.device,
        )
        # Draw through the generator (not the global RNG) so the initial
        # state is reproducible from the generator alone.
        cats = torch.multinomial(
            F.softmax(logits, dim=-1), num_samples=1, generator=generator
        ).squeeze(-1)  # (B,)
        # Broadcast a single state across the grid (matches reference rearrange).
        return cats.view(batch, 1, 1).expand(batch, cfg.grid_size, cfg.grid_size).clone()
    raise ValueError(f"unknown initial_state {cfg.initial_state!r}")


# ---------------------------------------------------------------------------
# Rollout (single + batched)
# ---------------------------------------------------------------------------


@torch.no_grad()
def rollout(
    cfg: NCAConfig,
    n_steps: int,
    generator: torch.Generator,
    device: torch.device | str = "cpu",
) -> torch.Tensor:
    """Single-trajectory rollout. Returns ``(n_steps, H, W)`` uint8."""
    out = rollout_batch(cfg, n_steps=n_steps, batch=1, generator=generator, device=device)
    return out[0]


@torch.no_grad()
def rollout_batch(
    cfg: NCAConfig,
    n_steps: int,
    batch: int,
    generator: torch.Generator,
    device: torch.device | str = "cpu",
) -> torch.Tensor:
    """Batched rollout: ``B`` independent rules in parallel.

    Each batch element gets its own random rule (fresh random θ). Frame
    stride and burn-in are applied uniformly across the batch. Frames are
    kept at steps ``start_step + k·dT`` for ``k = 0 … n_steps-1``, where
    step 0 is the initial state; simulation stops at the last kept frame.

    Args:
        cfg: NCA hyperparameters.
        n_steps: number of frames to return per trajectory.
        batch: number of independent rules / trajectories.
        generator: RNG for rule weights, the initial state and the per-step
            sampling (the latter only when it shares a device type with
            ``device``; see ``RandomDiscreteNCA.step``).
        device: device the rules are stepped on.

    Returns:
        ``(B, n_steps, H, W)`` uint8 CPU tensor with values in [0, n_states).

    Raises:
        ValueError: if ``n_steps``, ``batch`` or ``cfg.dT`` is < 1, if
            ``cfg.start_step`` is negative, or if ``cfg.n_states`` exceeds
            256 (states would not fit the uint8 output).
    """
    if n_steps < 1:
        raise ValueError(f"n_steps must be >= 1, got {n_steps}")
    if batch < 1:
        raise ValueError(f"batch must be >= 1, got {batch}")
    if cfg.dT < 1:
        raise ValueError(f"cfg.dT must be >= 1, got {cfg.dT}")
    if cfg.start_step < 0:
        raise ValueError(f"cfg.start_step must be >= 0, got {cfg.start_step}")
    if cfg.n_states > 256:
        raise ValueError(
            f"cfg.n_states must be <= 256 to fit uint8 frames, got {cfg.n_states}"
        )

    # Build a stacked rule by sampling one set of params per batch entry.
    # We use B independent NCA modules and stack their forward calls.
    # For small batches (<= 64) this is cleaner than building a "grouped"
    # conv that mixes the rules; correctness > 1.5× wall-time savings.
    rules = [_init_random_rule(cfg, generator).to(device) for _ in range(batch)]

    init_int = _sample_initial_state(cfg, batch, generator).to(device)
    onehot = F.one_hot(init_int, cfg.n_states).permute(0, 3, 1, 2).float()

    first_keep = cfg.start_step
    last_keep = cfg.start_step + (n_steps - 1) * cfg.dT

    frames: list[torch.Tensor] = []
    # To match reference behavior (which uses jax.lax.scan and saves
    # *post-step* states only when in 'video' mode), the initial state is
    # saved iff start_step == 0.
    if first_keep == 0:
        frames.append(init_int)

    for t in range(1, last_keep + 1):
        # Run each rule on its corresponding slice.
        onehot = torch.cat(
            [
                rule.step(onehot[b : b + 1], generator=generator)
                for b, rule in enumerate(rules)
            ],
            dim=0,
        )
        if t >= first_keep and (t - first_keep) % cfg.dT == 0:
            frames.append(onehot.argmax(dim=1))

    # frames is a list of (B, H, W) tensors; stack along time → (T, B, H, W),
    # then move to CPU once rather than once per frame.
    stacked = torch.stack(frames, dim=0).to(torch.uint8)
    return stacked.permute(1, 0, 2, 3).contiguous().cpu()  # (B, T, H, W)


# ---------------------------------------------------------------------------
# Complexity gates
# ---------------------------------------------------------------------------


def gzip_compressibility(states: torch.Tensor, compresslevel: int = 9) -> float:
    """Per-trajectory gzip ratio in percent: ``compressed/raw * 100``.

    Default ``compresslevel=9`` matches the reference repo. Lower ⇒ more
    compressible ⇒ simpler / more predictable dynamics. Returns 0.0 for an
    empty tensor.
    """
    # .cpu() so tensors on another device can be serialized through NumPy.
    raw = states.contiguous().to(torch.uint8).cpu().numpy().tobytes()
    if not raw:
        return 0.0
    compressed = gzip.compress(raw, compresslevel=compresslevel)
    return len(compressed) / len(raw) * 100.0


def rle_compressibility(states: torch.Tensor) -> float:
    """Run-length compression ratio in percent — an O(N) gzip proxy.

    Counts the number of (symbol, run_length) pairs in the row-major
    serialization of the trajectory. Each pair encodes to roughly two bytes
    (one for the symbol, one for the length, capped at 255), so

        rle_size ≈ 2 · num_runs ;  raw_size = num_cells

    This understates gzip's actual compression on highly-structured
    sequences, but the *relative ordering* across rules is preserved well
    enough that the K-complexity band gate is preserved up to a monotone
    re-mapping. Empirically, RLE-band ≈ 0.55 · gzip-band on NCA
    trajectories — set the band threshold proportionally lower.

    Note that runs longer than 255 are counted as a single pair here (no
    splitting), and the result can exceed 100 when runs are short. Returns
    0.0 for an empty tensor.
    """
    flat = states.contiguous().to(torch.uint8).flatten()
    n = flat.numel()
    if n == 0:
        return 0.0
    # Vectorized run-length count: a run starts whenever symbol changes.
    diffs = flat[1:] != flat[:-1]
    num_runs = int(diffs.sum().item()) + 1
    rle_size = 2 * num_runs
    return rle_size / n * 100.0


def trajectory_complexity(states: torch.Tensor, cfg: NCAConfig) -> float:
    """Apply the configured complexity metric ('gzip' or 'rle').

    Raises:
        ValueError: if ``cfg.complexity_metric`` is not a known metric.
    """
    if cfg.complexity_metric == "gzip":
        return gzip_compressibility(states, compresslevel=cfg.gzip_compresslevel)
    if cfg.complexity_metric == "rle":
        return rle_compressibility(states)
    raise ValueError(f"unknown complexity_metric {cfg.complexity_metric!r}")


# ---------------------------------------------------------------------------
# Tokenizers
# ---------------------------------------------------------------------------

# Stable integer hash (FNV-1a-style; the constants are the 32-bit FNV offset
# basis and prime). We use this only for `hash_bucket` tokenization; the hash
# kernel is local to this module so the bin format never depends on Python's
# hash() seed.
_HASH_PRIME_A = 2166136261
_HASH_PRIME_B = 16777619


def _hash_to_buckets(ids: torch.Tensor, n_buckets: int) -> torch.Tensor:
    """Deterministically hash integer ids into [0, n_buckets), FNV-1a style.

    Pure torch — no Python hash, no NumPy. Only the low 32 bits of each id
    (four bytes, least-significant first) are mixed in, so ids that differ
    only above bit 31 collide.

    NOTE(review): the running value is *not* masked to 32 bits between
    rounds, so the chained multiplications exceed the int64 range and rely
    on integer wraparound. The arithmetic is intentionally left as-is:
    masking would change every bucket assignment and therefore the on-disk
    token format.

    Raises:
        ValueError: if ``n_buckets`` is < 1.
    """
    if n_buckets < 1:
        raise ValueError(f"n_buckets must be >= 1, got {n_buckets}")
    h = torch.full_like(ids, _HASH_PRIME_A, dtype=torch.int64)
    h = (h ^ (ids & 0xFF)) * _HASH_PRIME_B
    h = (h ^ ((ids >> 8) & 0xFF)) * _HASH_PRIME_B
    h = (h ^ ((ids >> 16) & 0xFF)) * _HASH_PRIME_B
    h = (h ^ ((ids >> 24) & 0xFF)) * _HASH_PRIME_B
    # Tensor ``%`` takes the sign of the divisor, so the result is
    # non-negative even if ``abs`` leaves the minimum int64 value negative.
    return (h.abs() % n_buckets).to(ids.dtype)


def tokenize_trajectory(states: torch.Tensor, cfg: NCAConfig) -> list[int]:
    """Serialize a trajectory into a token list.

    Layout per timestep: ``grid_start p₁ p₂ … p_K grid_end`` with patches in
    row-major order, where ``grid_start`` / ``grid_end`` are
    ``cfg.grid_start_id`` / ``cfg.grid_end_id`` and each ``p_j`` is either a
    bijective patch id (tokenizer="patch") or a feature-hashed bucket id
    (tokenizer="hash_bucket").

    Args:
        states: ``(T, H, W)`` integer tensor with values in [0, n_states).
        cfg: tokenizer configuration.

    Returns:
        Flat list of ``T · (K + 2)`` token ids, ``K = (H/ps)·(W/ps)``.

    Raises:
        ValueError: on a non-3-D input, a grid not divisible by
            ``patch_size``, out-of-range state values, or an unknown
            tokenizer name.
    """
    if states.dim() != 3:
        raise ValueError(f"expected (T, H, W) states, got shape {tuple(states.shape)}")
    T, H, W = states.shape
    ps = cfg.patch_size
    if H % ps or W % ps:
        raise ValueError(f"grid {H}×{W} not divisible by patch_size={ps}")

    n = cfg.n_states
    states_long = states.to(torch.long)
    if states_long.numel() and (
        states_long.min().item() < 0 or states_long.max().item() >= n
    ):
        raise ValueError(f"state values must lie in [0, {n})")

    # (T, H/ps, ps, W/ps, ps) → (T, H/ps, W/ps, ps, ps) → (T, K, ps²)
    patches = states_long.view(T, H // ps, ps, W // ps, ps).permute(0, 1, 3, 2, 4)
    patches = patches.reshape(T, -1, ps * ps)

    # Base-n positional weights, most-significant cell first. Created on the
    # same device as ``states`` so non-CPU inputs do not hit a device mismatch.
    weights = torch.tensor(
        [n ** (ps * ps - 1 - k) for k in range(ps * ps)],
        dtype=torch.long,
        device=states_long.device,
    )
    ids = (patches * weights).sum(dim=-1)  # (T, num_patches)

    if cfg.tokenizer == "hash_bucket":
        ids = _hash_to_buckets(ids, cfg.hash_buckets)
    elif cfg.tokenizer != "patch":
        raise ValueError(f"unknown tokenizer {cfg.tokenizer!r}")

    g_start = cfg.grid_start_id
    g_end = cfg.grid_end_id
    out: list[int] = []
    for row in ids.tolist():
        out.append(g_start)
        out.extend(row)
        out.append(g_end)
    return out


# ---------------------------------------------------------------------------
# End-to-end sampler
# ---------------------------------------------------------------------------


@torch.no_grad()
def sample_trajectory_tokens(
    cfg: NCAConfig,
    n_steps: int,
    generator: torch.Generator,
    gzip_min: float = 50.0,
    gzip_max: float = 100.0,
    max_attempts: int = 64,
    device: torch.device | str = "cpu",
    batch: int = 1,
) -> tuple[list[int], float] | None:
    """One trajectory passing the complexity-band gate, tokenized.

    With ``batch>1`` we propose ``batch`` rules per attempt and return the
    first one to fall in the requested band — this amortizes the
    per-rollout overhead and is the recommended path for production.

    The ``gzip_min`` / ``gzip_max`` arguments are interpreted via
    ``cfg.complexity_metric`` (gzip percent for "gzip", RLE percent for
    "rle"). RLE bands typically run ~0.55× the gzip values for matched
    qualitative complexity.

    Returns:
        ``(tokens, ratio)`` for the first accepted trajectory, or ``None``
        if no trajectory fell inside ``[gzip_min, gzip_max]`` within
        ``max_attempts`` batched rollouts.
    """
    for _ in range(max_attempts):
        states_batch = rollout_batch(
            cfg,
            n_steps=n_steps,
            batch=batch,
            generator=generator,
            device=device,
        )
        for b in range(batch):
            states = states_batch[b]
            ratio = trajectory_complexity(states, cfg)
            if gzip_min <= ratio <= gzip_max:
                return tokenize_trajectory(states, cfg), ratio
    return None