"""gary-neuron: an asynchronous Neural Cellular Automaton whose per-cell update rule is a Mixture-of-Experts. Pure NumPy. Tiny reverse-mode autograd written from scratch (no torch) so the MoE+NCA gradients are correct-by-construction. The mesh is a 1-D strip of S cells, one per digit position (reversed: cell 0 is the least-significant digit). Each async step a random subset of cells update; every updating cell perceives [left, self, right] and routes that perception through a router -> top-k experts. Carry ripples low->high across cells over time, exactly like ripple-carry addition. Trained on reversed-digit addition (Lee et al. 2023), which is the format that makes exact addition learnable. This single file is BOTH the training brain (autograd Tensor `T` + model graph) and the dependency-free inference engine (`forward_np`, used by solve.py). """ import numpy as np f32 = np.float32 # ====================================================================== # minimal reverse-mode autograd over numpy # ====================================================================== class T: """A node in the tape: holds data `.d`, grad `.g`, and a backward closure.""" __slots__ = ("d", "g", "_bw", "_prev", "req") def __init__(self, data, req=True, _prev=()): self.d = np.asarray(data, dtype=f32) self.g = None self._bw = _noop self._prev = _prev self.req = req def _acc(self, grad): self.g = grad if self.g is None else self.g + grad def backward(self): topo, seen = [], set() def build(v): if id(v) in seen: return seen.add(id(v)) for p in v._prev: build(p) topo.append(v) build(self) self.g = np.ones_like(self.d) for v in reversed(topo): v._bw() # operator sugar def __add__(self, o): return add(self, o) def __mul__(self, o): return mul(self, o) def _noop(): pass def _unbroadcast(grad, shape): """Sum `grad` down to `shape` (reverse of numpy broadcasting).""" while grad.ndim > len(shape): grad = grad.sum(0) for i, s in enumerate(shape): if s == 1 and grad.shape[i] != 1: grad = 
grad.sum(i, keepdims=True) return grad def add(a, b): out = T(a.d + b.d, _prev=(a, b)) def bw(): if a.req: a._acc(_unbroadcast(out.g, a.d.shape)) if b.req: b._acc(_unbroadcast(out.g, b.d.shape)) out._bw = bw; return out def mul(a, b): out = T(a.d * b.d, _prev=(a, b)) def bw(): if a.req: a._acc(_unbroadcast(out.g * b.d, a.d.shape)) if b.req: b._acc(_unbroadcast(out.g * a.d, b.d.shape)) out._bw = bw; return out def matmul(a, b): out = T(a.d @ b.d, _prev=(a, b)) def bw(): if a.req: a._acc(out.g @ b.d.T) if b.req: b._acc(a.d.T @ out.g) out._bw = bw; return out def mulc(x, c): """Multiply by a constant numpy array (mask / gate that carries no grad).""" out = T(x.d * c, _prev=(x,)) def bw(): if x.req: x._acc(out.g * c) out._bw = bw; return out def addc(x, c): out = T(x.d + c, _prev=(x,)) def bw(): if x.req: x._acc(out.g) out._bw = bw; return out def relu(a): m = (a.d > 0) out = T(a.d * m, _prev=(a,)) def bw(): if a.req: a._acc(out.g * m) out._bw = bw; return out def reshape(x, shape): out = T(x.d.reshape(shape), _prev=(x,)) def bw(): if x.req: x._acc(out.g.reshape(x.d.shape)) out._bw = bw; return out def gather(W, idx): """Rows of W (V,d) selected by int idx (N,) -> (N,d).""" out = T(W.d[idx], _prev=(W,)) def bw(): if W.req: g = np.zeros_like(W.d); np.add.at(g, idx, out.g); W._acc(g) out._bw = bw; return out def concat_last(ts): out = T(np.concatenate([t.d for t in ts], axis=-1), _prev=tuple(ts)) sizes = [t.d.shape[-1] for t in ts] def bw(): i = 0 for t, s in zip(ts, sizes): if t.req: t._acc(out.g[..., i:i+s]) i += s out._bw = bw; return out def shift_from_left(x): """neighbour i-1: out[:,i]=x[:,i-1], out[:,0]=0. 
(the carry source)""" d = np.zeros_like(x.d); d[:, 1:] = x.d[:, :-1] out = T(d, _prev=(x,)) def bw(): if x.req: g = np.zeros_like(x.d); g[:, :-1] += out.g[:, 1:]; x._acc(g) out._bw = bw; return out def shift_from_right(x): """neighbour i+1: out[:,i]=x[:,i+1], out[:,-1]=0.""" d = np.zeros_like(x.d); d[:, :-1] = x.d[:, 1:] out = T(d, _prev=(x,)) def bw(): if x.req: g = np.zeros_like(x.d); g[:, 1:] += out.g[:, :-1]; x._acc(g) out._bw = bw; return out def softmax(x): e = np.exp(x.d - x.d.max(-1, keepdims=True)); p = e / e.sum(-1, keepdims=True) out = T(p, _prev=(x,)) def bw(): if x.req: g = out.g; dot = (g * p).sum(-1, keepdims=True) x._acc(p * (g - dot)) out._bw = bw; return out def col(x, e): out = T(x.d[:, e:e+1], _prev=(x,)) def bw(): if x.req: g = np.zeros_like(x.d); g[:, e:e+1] = out.g; x._acc(g) out._bw = bw; return out def layernorm(x, eps=1e-5): mu = x.d.mean(-1, keepdims=True); xc = x.d - mu var = (xc * xc).mean(-1, keepdims=True); inv = 1.0 / np.sqrt(var + eps) y = xc * inv out = T(y, _prev=(x,)) def bw(): if x.req: g = out.g dx = inv * (g - g.mean(-1, keepdims=True) - y * (g * y).mean(-1, keepdims=True)) x._acc(dx) out._bw = bw; return out def mean0(x): """mean over axis 0: (N,K)->(K,)""" N = x.d.shape[0] out = T(x.d.mean(0), _prev=(x,)) def bw(): if x.req: x._acc(np.broadcast_to(out.g / N, x.d.shape).copy()) out._bw = bw; return out def sumall(x): out = T(np.asarray(x.d.sum(), dtype=f32), _prev=(x,)) def bw(): if x.req: x._acc(np.ones_like(x.d) * out.g) out._bw = bw; return out def cross_entropy(logits, targets): """logits (N,C) T, targets (N,) int -> scalar mean CE.""" z = logits.d - logits.d.max(-1, keepdims=True) e = np.exp(z); p = e / e.sum(-1, keepdims=True) N = z.shape[0] loss = -np.log(p[np.arange(N), targets] + 1e-9).mean() out = T(np.asarray(loss, dtype=f32), _prev=(logits,)) def bw(): g = p.copy(); g[np.arange(N), targets] -= 1.0; g /= N if logits.req: logits._acc(g * out.g) out._bw = bw; return out # 
# ======================================================================
# helpers that produce CONSTANTS (no grad): routing & async masks
# ======================================================================


def topk_addmask(logits, k):
    """Additive mask: 0.0 on the top-k logits per row, -1e9 elsewhere.

    logits: (N, K) array. If k >= K every entry is kept (all-zero mask);
    otherwise exactly k entries per row are 0.0 (ties resolved by
    `np.argpartition`).

    Raises:
        ValueError: if k < 1 (selecting no expert has no meaning for a gate).
    """
    _, K = logits.shape
    if k < 1:
        raise ValueError(f"topk must be >= 1, got {k}")
    if k >= K:
        return np.zeros_like(logits)
    idx = np.argpartition(-logits, k - 1, axis=1)[:, :k]
    M = np.full_like(logits, -1e9, dtype=f32)
    np.put_along_axis(M, idx, 0.0, axis=1)
    return M


def async_mask(B, S, rng, p):
    """1 where a cell updates this step, else 0. shape (B,S,1).

    Each cell independently updates with probability `p` (drawn from `rng`).
    """
    return (rng.random((B, S, 1)) < p).astype(f32)


# ======================================================================
# model
# ======================================================================


def default_cfg():
    """Default hyper-parameters.

    S: cells (digit positions); d: per-cell state width; he: expert hidden
    width; K: number of experts; topk: experts mixed per cell; steps: async
    update steps; p_update: per-cell update probability per step; Vin/Vout:
    input/output vocabulary sizes; aux: load-balance loss weight (0 disables).
    """
    return dict(S=8, d=32, he=32, K=6, topk=2, steps=18, p_update=0.5,
                Vin=10, Vout=10, aux=0.01)


def init_params(cfg, seed=1337):
    """Create the trainable parameter dict {name: T} for `cfg`, seeded by `seed`."""
    rng = np.random.default_rng(seed)
    d, he, K, S = cfg["d"], cfg["he"], cfg["K"], cfg["S"]
    Vin, Vout = cfg["Vin"], cfg["Vout"]
    P = {}
    # shared digit embedding (a and b added)
    P["emb"] = T(rng.normal(0, 0.08, (Vin, d)).astype(f32))
    P["posemb"] = T(rng.normal(0, 0.02, (S, d)).astype(f32))
    # router
    P["Wr"] = T(rng.normal(0, 0.08, (3 * d, K)).astype(f32))
    P["br"] = T(np.zeros(K, f32))
    # tiny expert output -> near-identity dynamics at init
    eo = 0.08 / np.sqrt(he)
    for e in range(K):
        P[f"e{e}.W1"] = T(rng.normal(0, 0.08, (3 * d, he)).astype(f32))
        P[f"e{e}.b1"] = T(np.zeros(he, f32))
        P[f"e{e}.W2"] = T(rng.normal(0, eo, (he, d)).astype(f32))
        P[f"e{e}.b2"] = T(np.zeros(d, f32))
    # readout
    P["Wo"] = T(rng.normal(0, 0.10, (d, Vout)).astype(f32))
    P["bo"] = T(np.zeros(Vout, f32))
    return P


def n_params(P):
    """Total number of scalar parameters in P."""
    return int(sum(v.d.size for v in P.values()))


def forward(P, A, B, Y, cfg, rng, train=True, collect=False):
    """Graph forward. A,B,Y int arrays (Bn,S). Returns (total_loss_T, info).

    A and B index the shared digit embedding (their embeddings are summed);
    Y are the target digits. `rng` drives the async update masks. `train` is
    accepted for API compatibility but does not change the computation (the
    same p_update is used either way). With `collect`, info["pred"] holds
    the argmax digit per cell.
    """
    Bn, S = A.shape
    d, K, topk, steps = cfg["d"], cfg["K"], cfg["topk"], cfg["steps"]
    p_up = cfg["p_update"]
    N = Bn * S
    Wr, br, Wo, bo = P["Wr"], P["br"], P["Wo"], P["bo"]

    ha = reshape(gather(P["emb"], A.reshape(-1)), (Bn, S, d))
    hb = reshape(gather(P["emb"], B.reshape(-1)), (Bn, S, d))
    H = add(add(ha, hb), P["posemb"])  # (Bn,S,d) broadcast posemb

    router_probs_accum = []
    load_counts = np.zeros(K, dtype=f32)
    for _ in range(steps):
        Hl = shift_from_left(H)
        Hr = shift_from_right(H)
        perc = layernorm(concat_last([Hl, H, Hr]))  # (Bn,S,3d)
        pf = reshape(perc, (N, 3 * d))

        rl = add(matmul(pf, Wr), br)  # (N,K) router logits
        M = topk_addmask(rl.d, topk)  # constant top-k mask
        gate = softmax(addc(rl, M))  # (N,K) -> only top-k nonzero
        if cfg["aux"] > 0:
            router_probs_accum.append(softmax(rl))  # soft probs for load-balance aux
            load_counts += np.bincount(rl.d.argmax(1), minlength=K).astype(f32)

        mix = None
        for e in range(K):
            h1 = relu(add(matmul(pf, P[f"e{e}.W1"]), P[f"e{e}.b1"]))
            oe = add(matmul(h1, P[f"e{e}.W2"]), P[f"e{e}.b2"])  # (N,d)
            ge = mul(col(gate, e), oe)  # gate broadcast (N,1)*(N,d)
            mix = ge if mix is None else add(mix, ge)

        um = async_mask(Bn, S, rng, p_up)
        H = add(H, mulc(reshape(mix, (Bn, S, d)), um))  # masked residual update

    Hf = reshape(H, (N, d))
    logits = add(matmul(Hf, Wo), bo)  # (N,Vout)
    loss = cross_entropy(logits, Y.reshape(-1))
    total = loss
    info = {"loss": float(loss.d)}

    if cfg["aux"] > 0 and router_probs_accum:
        Pbar = mean0(concat_rows(router_probs_accum))  # (K,) mean soft prob
        f = load_counts / max(load_counts.sum(), 1.0)  # fraction routed (top-1), constant
        aux = mulc(sumall(mulc(Pbar, f * K)), f32(cfg["aux"]))
        total = add(loss, aux)
        info["aux"] = float(aux.d)
        info["load"] = f
    if collect:
        info["pred"] = logits.d.reshape(Bn, S, -1).argmax(-1)
    return total, info


def concat_rows(ts):
    """stack a list of (N,K) T along axis 0 -> (sum N, K) T."""
    out = T(np.concatenate([t.d for t in ts], axis=0), _prev=tuple(ts))
    sizes = [t.d.shape[0] for t in ts]

    def bw():
        i = 0
        for t, s in zip(ts, sizes):
            if t.req:
                t._acc(out.g[i:i + s])
            i += s  # advance even past no-grad inputs so later slices line up

    out._bw = bw
    return out


# ======================================================================
# dependency-free inference (also used by solve.py). numpy only, no tape.
# ======================================================================


def _sm(x):
    """Softmax over the last axis for plain numpy arrays."""
    e = np.exp(x - x.max(-1, keepdims=True))
    return e / e.sum(-1, keepdims=True)


def _readout(W, H):
    """Readout logits (Bn*S, Vout) for a state grid H of shape (Bn,S,d)."""
    Bn, S, d = H.shape
    return H.reshape(Bn * S, d) @ W["Wo"] + W["bo"]


def forward_np(W, A, B, cfg, rng, trace=False):
    """Run the mesh with plain numpy weights `W` (dict of arrays).

    Returns predicted digit grid (Bn,S) and, if trace, a per-step record for
    viz. At inference only the routed (top-k) experts are evaluated (true
    sparsity). Routing uses the same `topk_addmask` rule as training, so any
    topk >= K simply keeps every expert.
    """
    Bn, S = A.shape
    d, K, topk, steps, p_up = (cfg["d"], cfg["K"], cfg["topk"],
                               cfg["steps"], cfg["p_update"])
    emb, pos = W["emb"], W["posemb"]
    H = emb[A] + emb[B] + pos[None]  # (Bn,S,d)
    frames = []
    for _ in range(steps):
        # perceive [left, self, right], zero-padded at the strip ends
        Hl = np.zeros_like(H)
        Hl[:, 1:] = H[:, :-1]
        Hr = np.zeros_like(H)
        Hr[:, :-1] = H[:, 1:]
        perc = np.concatenate([Hl, H, Hr], axis=-1)
        mu = perc.mean(-1, keepdims=True)
        v = perc.var(-1, keepdims=True)
        perc = (perc - mu) / np.sqrt(v + 1e-5)  # same normalization as `layernorm`
        pf = perc.reshape(Bn * S, 3 * d)

        rl = pf @ W["Wr"] + W["br"]
        M = topk_addmask(rl, topk)
        gate = _sm(rl + M)

        mix = np.zeros((Bn * S, d), f32)
        fired = np.zeros((Bn * S, K), f32)
        for e in range(K):
            act = M[:, e] == 0  # cells routed to expert e
            if not act.any():
                continue
            ge = gate[:, e]
            h1 = np.maximum(pf[act] @ W[f"e{e}.W1"] + W[f"e{e}.b1"], 0)
            oe = h1 @ W[f"e{e}.W2"] + W[f"e{e}.b2"]
            mix[act] += ge[act, None] * oe
            fired[act, e] = 1.0

        um = (rng.random((Bn, S, 1)) < p_up).astype(f32)
        H = H + um * mix.reshape(Bn, S, d)
        if trace:
            frames.append(dict(
                pred=_readout(W, H).reshape(Bn, S, -1).argmax(-1),
                updated=um[..., 0].astype(int),
                expert=gate.argmax(1).reshape(Bn, S),
                fired=fired.reshape(Bn, S, K)))

    pred = _readout(W, H).reshape(Bn, S, -1).argmax(-1)
    return (pred, frames) if trace else pred


def params_to_np(P):
    """Copy every parameter's data out of the autograd wrappers."""
    return {k: v.d.copy() for k, v in P.items()}