| """gary-neuron: an asynchronous Neural Cellular Automaton whose per-cell update |
| rule is a Mixture-of-Experts. Pure NumPy. Tiny reverse-mode autograd written |
| from scratch (no torch) so the MoE+NCA gradients are correct-by-construction. |
| |
| The mesh is a 1-D strip of S cells, one per digit position (reversed: cell 0 is |
| the least-significant digit). Each async step a random subset of cells update; |
| every updating cell perceives [left, self, right] and routes that perception |
| through a router -> top-k experts. Carry ripples low->high across cells over |
| time, exactly like ripple-carry addition. Trained on reversed-digit addition |
| (Lee et al. 2023), which is the format that makes exact addition learnable. |
| |
| This single file is BOTH the training brain (autograd Tensor `T` + model graph) |
| and the dependency-free inference engine (`forward_np`, used by solve.py). |
| """ |
| import numpy as np |
|
|
# Shared float dtype alias: tensor data, masks and parameters below are stored as float32.
f32 = np.float32
|
|
| |
| |
| |
class T:
    """A node in the tape: holds data `.d`, grad `.g`, and a backward closure.

    Slots:
        d     -- float32 ndarray with the node's value.
        g     -- accumulated gradient w.r.t. `d`; None until something is accumulated.
        _bw   -- zero-argument closure that pushes `g` into the parents' `g`.
        _prev -- tuple of parent nodes this node was computed from.
        req   -- ops only accumulate a gradient into this node when true.
    """
    __slots__ = ("d", "g", "_bw", "_prev", "req")

    def __init__(self, data, req=True, _prev=()):
        self.d = np.asarray(data, dtype=f32)
        self.g = None
        self._bw = _noop  # leaves have nothing to propagate
        self._prev = _prev
        self.req = req

    def _acc(self, grad):
        """Add `grad` into `.g` out-of-place, so arrays shared with other nodes are never mutated."""
        self.g = grad if self.g is None else self.g + grad

    def backward(self):
        """Seed this node's gradient with ones and run every ancestor's backward closure.

        The topological order is produced by an explicit-stack post-order DFS.
        It visits nodes in the same order as the straightforward recursive
        version, but does not consume Python stack frames, so graphs deeper
        than the interpreter's recursion limit (many NCA steps) still work.
        """
        topo, seen = [], {id(self)}
        # Each stack entry is (node, iterator over its not-yet-visited parents).
        stack = [(self, iter(self._prev))]
        while stack:
            node, parents = stack[-1]
            for parent in parents:
                if id(parent) not in seen:
                    seen.add(id(parent))
                    stack.append((parent, iter(parent._prev)))
                    break  # descend first; the iterator resumes here afterwards
            else:
                # All parents are done: emit the node (post-order).
                topo.append(node)
                stack.pop()
        self.g = np.ones_like(self.d)
        for v in reversed(topo):
            v._bw()

    def __add__(self, o):
        return add(self, o)

    def __mul__(self, o):
        return mul(self, o)
|
|
def _noop():
    """Default backward closure for nodes that have nothing to propagate."""
    return None
|
|
def _unbroadcast(grad, shape):
    """Reduce `grad` to `shape` by summing over the axes numpy broadcasting added or stretched."""
    # Leading axes that `shape` does not have at all are summed away one by one.
    for _ in range(grad.ndim - len(shape)):
        grad = grad.sum(0)
    # Axes that were size 1 in `shape` but got stretched are summed, keeping the axis.
    for axis, size in enumerate(shape):
        stretched = size == 1 and grad.shape[axis] != 1
        if stretched:
            grad = grad.sum(axis, keepdims=True)
    return grad
|
|
def add(a, b):
    """Broadcasting elementwise sum of two tape nodes."""
    res = T(a.d + b.d, _prev=(a, b))

    def _backward():
        # The upstream gradient flows unchanged to both operands, reduced to their shapes.
        if a.req:
            a._acc(_unbroadcast(res.g, a.d.shape))
        if b.req:
            b._acc(_unbroadcast(res.g, b.d.shape))

    res._bw = _backward
    return res
|
|
def mul(a, b):
    """Broadcasting elementwise product of two tape nodes."""
    res = T(a.d * b.d, _prev=(a, b))

    def _backward():
        # Product rule: each operand receives the upstream grad times the other operand.
        if a.req:
            a._acc(_unbroadcast(res.g * b.d, a.d.shape))
        if b.req:
            b._acc(_unbroadcast(res.g * a.d, b.d.shape))

    res._bw = _backward
    return res
|
|
def matmul(a, b):
    """Matrix product ``a @ b`` of two tape nodes (gradients use plain `.T` transposes)."""
    res = T(a.d @ b.d, _prev=(a, b))

    def _backward():
        if a.req:
            a._acc(res.g @ b.d.T)
        if b.req:
            b._acc(a.d.T @ res.g)

    res._bw = _backward
    return res
|
|
def mulc(x, c):
    """Scale node `x` by a constant `c` (mask / gate) that receives no gradient."""
    res = T(x.d * c, _prev=(x,))

    def _backward():
        if x.req:
            x._acc(res.g * c)

    res._bw = _backward
    return res
|
|
def addc(x, c):
    """Add a constant `c` (no gradient) to node `x`; the upstream gradient passes through as-is."""
    res = T(x.d + c, _prev=(x,))

    def _backward():
        if x.req:
            x._acc(res.g)

    res._bw = _backward
    return res
|
|
def relu(a):
    """Rectified linear unit: keeps strictly positive entries, zeroes the rest."""
    positive = a.d > 0
    res = T(a.d * positive, _prev=(a,))

    def _backward():
        # The gradient only passes where the input was strictly positive.
        if a.req:
            a._acc(res.g * positive)

    res._bw = _backward
    return res
|
|
def reshape(x, shape):
    """View node `x` under a new `shape`; the gradient is reshaped back to x's shape."""
    res = T(x.d.reshape(shape), _prev=(x,))

    def _backward():
        if x.req:
            x._acc(res.g.reshape(x.d.shape))

    res._bw = _backward
    return res
|
|
def gather(W, idx):
    """Embedding lookup: pick rows of W (V,d) by int idx (N,) -> (N,d)."""
    res = T(W.d[idx], _prev=(W,))

    def _backward():
        if not W.req:
            return
        # Scatter-add so that repeated indices accumulate instead of overwriting.
        table_grad = np.zeros_like(W.d)
        np.add.at(table_grad, idx, res.g)
        W._acc(table_grad)

    res._bw = _backward
    return res
|
|
def concat_last(ts):
    """Concatenate a list of nodes along their last axis."""
    joined = T(np.concatenate([t.d for t in ts], axis=-1), _prev=tuple(ts))
    widths = [t.d.shape[-1] for t in ts]

    def _backward():
        # Hand each input the slice of the gradient that covers its columns.
        start = 0
        for part, width in zip(ts, widths):
            stop = start + width
            if part.req:
                part._acc(joined.g[..., start:stop])
            start = stop

    joined._bw = _backward
    return joined
|
|
def shift_from_left(x):
    """Left-neighbour view along axis 1: out[:,i]=x[:,i-1], out[:,0]=0. (the carry source)"""
    shifted = np.zeros_like(x.d)
    shifted[:, 1:] = x.d[:, :-1]
    res = T(shifted, _prev=(x,))

    def _backward():
        if not x.req:
            return
        # Position i fed output i+1, so the gradient moves one slot back.
        back = np.zeros_like(x.d)
        back[:, :-1] += res.g[:, 1:]
        x._acc(back)

    res._bw = _backward
    return res
|
|
def shift_from_right(x):
    """Right-neighbour view along axis 1: out[:,i]=x[:,i+1], out[:,-1]=0."""
    shifted = np.zeros_like(x.d)
    shifted[:, :-1] = x.d[:, 1:]
    res = T(shifted, _prev=(x,))

    def _backward():
        if not x.req:
            return
        # Position i fed output i-1, so the gradient moves one slot forward.
        back = np.zeros_like(x.d)
        back[:, 1:] += res.g[:, :-1]
        x._acc(back)

    res._bw = _backward
    return res
|
|
def softmax(x):
    """Softmax over the last axis (row max subtracted before exponentiating)."""
    expd = np.exp(x.d - x.d.max(-1, keepdims=True))
    probs = expd / expd.sum(-1, keepdims=True)
    res = T(probs, _prev=(x,))

    def _backward():
        if not x.req:
            return
        upstream = res.g
        # Jacobian-vector product of softmax: p * (g - <g, p>).
        inner = (upstream * probs).sum(-1, keepdims=True)
        x._acc(probs * (upstream - inner))

    res._bw = _backward
    return res
|
|
def col(x, e):
    """Select column `e` of a 2-D node, keeping it as an (N,1) column."""
    res = T(x.d[:, e:e + 1], _prev=(x,))

    def _backward():
        if not x.req:
            return
        # Only the selected column receives gradient; the rest stays zero.
        full = np.zeros_like(x.d)
        full[:, e:e + 1] = res.g
        x._acc(full)

    res._bw = _backward
    return res
|
|
def layernorm(x, eps=1e-5):
    """Normalise the last axis to zero mean and unit variance (no learned scale or shift)."""
    mean = x.d.mean(-1, keepdims=True)
    centered = x.d - mean
    variance = (centered * centered).mean(-1, keepdims=True)
    rstd = 1.0 / np.sqrt(variance + eps)
    normed = centered * rstd
    res = T(normed, _prev=(x,))

    def _backward():
        if not x.req:
            return
        upstream = res.g
        g_mean = upstream.mean(-1, keepdims=True)
        gy_mean = (upstream * normed).mean(-1, keepdims=True)
        # Remove the components along the mean and along the normalised output.
        x._acc(rstd * (upstream - g_mean - normed * gy_mean))

    res._bw = _backward
    return res
|
|
def mean0(x):
    """Average over axis 0: (N,K)->(K,)"""
    rows = x.d.shape[0]
    res = T(x.d.mean(0), _prev=(x,))

    def _backward():
        if x.req:
            # Every row contributed 1/N, so each row receives the same share.
            share = res.g / rows
            x._acc(np.broadcast_to(share, x.d.shape).copy())

    res._bw = _backward
    return res
|
|
def sumall(x):
    """Sum every element of `x` into a scalar node."""
    total = np.asarray(x.d.sum(), dtype=f32)
    res = T(total, _prev=(x,))

    def _backward():
        if x.req:
            # The scalar upstream gradient is spread to every input element.
            x._acc(np.ones_like(x.d) * res.g)

    res._bw = _backward
    return res
|
|
def cross_entropy(logits, targets):
    """Mean cross-entropy of logits (N,C) node against int targets (N,) -> scalar node."""
    shifted = logits.d - logits.d.max(-1, keepdims=True)
    expd = np.exp(shifted)
    probs = expd / expd.sum(-1, keepdims=True)
    n_rows = shifted.shape[0]
    rows = np.arange(n_rows)
    # The 1e-9 keeps log() finite if a target probability underflows to zero.
    nll = -np.log(probs[rows, targets] + 1e-9).mean()
    res = T(np.asarray(nll, dtype=f32), _prev=(logits,))

    def _backward():
        # d(mean CE)/d(logits) = (softmax - onehot) / N
        grad = probs.copy()
        grad[rows, targets] -= 1.0
        grad /= n_rows
        if logits.req:
            logits._acc(grad * res.g)

    res._bw = _backward
    return res
|
|
| |
| |
| |
def topk_addmask(logits, k):
    """Additive mask: 0.0 on the top-k logits per row, -1e9 elsewhere.

    Args:
        logits: 2-D array (N, K) of router logits.
        k: number of entries to keep per row; must be >= 1.

    Returns:
        float32 array with the shape of `logits`. When ``k >= K`` nothing is
        masked and the result is all zeros.

    Raises:
        ValueError: if ``k < 1``. An all -1e9 mask would not disable routing;
            softmax would turn it into a uniform gate over every expert.
    """
    if k < 1:
        raise ValueError(f"topk_addmask needs k >= 1, got {k}")
    _, K = logits.shape
    if k >= K:
        return np.zeros(logits.shape, dtype=f32)
    # argpartition on the negated logits puts the k largest in the first k slots
    # (in no particular order, which is all the mask needs).
    idx = np.argpartition(-logits, k - 1, axis=1)[:, :k]
    M = np.full(logits.shape, -1e9, dtype=f32)
    np.put_along_axis(M, idx, 0.0, axis=1)
    return M
|
|
def async_mask(B, S, rng, p):
    """Bernoulli(p) update mask of shape (B,S,1): 1.0 where a cell updates this step, else 0.0."""
    draws = rng.random((B, S, 1))
    fires = draws < p
    return fires.astype(f32)
|
|
| |
| |
| |
def default_cfg():
    """Return a fresh dict holding the default hyper-parameters."""
    return {
        "S": 8,
        "d": 32,
        "he": 32,
        "K": 6,
        "topk": 2,
        "steps": 18,
        "p_update": 0.5,
        "Vin": 10,
        "Vout": 10,
        "aux": 0.01,
    }
|
|
def init_params(cfg, seed=1337):
    """Create the parameter dict (name -> T) from a seeded generator.

    Weight matrices are drawn from zero-mean normals, biases start at zero.
    The draw order (emb, posemb, Wr, then W1/W2 per expert, then Wo) is part
    of the contract: changing it changes the weights for a given seed.
    """
    rng = np.random.default_rng(seed)
    d, he, K, S = cfg["d"], cfg["he"], cfg["K"], cfg["S"]
    Vin, Vout = cfg["Vin"], cfg["Vout"]

    def gauss(std, shape):
        return T(rng.normal(0, std, shape).astype(f32))

    def zeros(n):
        return T(np.zeros(n, f32))

    P = {
        "emb": gauss(0.08, (Vin, d)),
        "posemb": gauss(0.02, (S, d)),
        "Wr": gauss(0.08, (3 * d, K)),
        "br": zeros(K),
    }
    # Expert output layers get a std scaled down by sqrt(he).
    eo = 0.08 / np.sqrt(he)
    for e in range(K):
        P[f"e{e}.W1"] = gauss(0.08, (3 * d, he))
        P[f"e{e}.b1"] = zeros(he)
        P[f"e{e}.W2"] = gauss(eo, (he, d))
        P[f"e{e}.b2"] = zeros(d)
    P["Wo"] = gauss(0.10, (d, Vout))
    P["bo"] = zeros(Vout)
    return P
|
|
def n_params(P):
    """Count the scalar entries across every parameter's `.d` array."""
    total = 0
    for tensor in P.values():
        total += tensor.d.size
    return int(total)
|
|
def forward(P, A, B, Y, cfg, rng, train=True, collect=False):
    """Graph forward. A,B,Y int arrays (Bn,S). Returns (total_loss_T, info).

    Args:
        P: dict of parameter nodes; reads emb, posemb, Wr, br, Wo, bo and
            e{i}.W1 / b1 / W2 / b2 for each of the K experts.
        A, B: int arrays (Bn, S) used as row indices into ``P["emb"]``.
        Y: int array (Bn, S) of target classes for the cross-entropy.
        cfg: config dict; reads d, K, topk, steps, p_update and aux.
        rng: random source handed to ``async_mask`` (one draw per step).
        train: accepted for interface compatibility; currently has no effect.
        collect: when true, ``info["pred"]`` holds the argmax predictions (Bn, S).

    Returns:
        ``(total, info)``. ``total`` is the scalar loss node: the cross-entropy,
        plus the weighted router aux term when ``cfg["aux"] > 0``. ``info``
        always has "loss"; "aux" and "load" are added with the aux term.
    """
    Bn, S = A.shape
    d, K, topk, steps = cfg["d"], cfg["K"], cfg["topk"], cfg["steps"]
    # NOTE(review): the previous expression picked cfg["p_update"] in both the
    # train and eval branches, so `train` never changed anything. Confirm whether
    # evaluation was meant to use a different update probability.
    p_up = cfg["p_update"]
    use_aux = cfg["aux"] > 0
    N = Bn * S
    Wr, br, Wo, bo = P["Wr"], P["br"], P["Wo"], P["bo"]
    # Expert parameters do not change across steps; look them up once.
    experts = [
        (P[f"e{e}.W1"], P[f"e{e}.b1"], P[f"e{e}.W2"], P[f"e{e}.b2"])
        for e in range(K)
    ]

    # Initial cell state: embeddings of both digits plus a positional embedding.
    ha = reshape(gather(P["emb"], A.reshape(-1)), (Bn, S, d))
    hb = reshape(gather(P["emb"], B.reshape(-1)), (Bn, S, d))
    H = add(add(ha, hb), P["posemb"])
    router_probs_accum = []
    load_counts = np.zeros(K, dtype=f32)

    for _ in range(steps):
        # Perception: normalised [left, self, right] state for every cell.
        Hl = shift_from_left(H)
        Hr = shift_from_right(H)
        perc = layernorm(concat_last([Hl, H, Hr]))
        pf = reshape(perc, (N, 3 * d))
        # Router: the softmax over masked logits leaves weight only on the top-k experts.
        rl = add(matmul(pf, Wr), br)
        M = topk_addmask(rl.d, topk)
        gate = softmax(addc(rl, M))
        if use_aux:
            # Unmasked router probabilities and top-1 counts feed the aux term.
            router_probs_accum.append(softmax(rl))
            load_counts += np.bincount(rl.d.argmax(1), minlength=K).astype(f32)
        # Every expert is evaluated on every cell; the gate zeroes the non-selected ones.
        mix = None
        for e, (W1, b1, W2, b2) in enumerate(experts):
            h1 = relu(add(matmul(pf, W1), b1))
            oe = add(matmul(h1, W2), b2)
            ge = mul(col(gate, e), oe)
            mix = ge if mix is None else add(mix, ge)
        # Asynchronous residual update: only the cells drawn in `um` change.
        um = async_mask(Bn, S, rng, p_up)
        H = add(H, mulc(reshape(mix, (Bn, S, d)), um))

    # Readout and loss.
    Hf = reshape(H, (N, d))
    logits = add(matmul(Hf, Wo), bo)
    loss = cross_entropy(logits, Y.reshape(-1))
    total = loss
    info = {"loss": float(loss.d)}
    if use_aux and router_probs_accum:
        # aux = coeff * K * sum_e(f_e * Pbar_e), with f the top-1 load fraction
        # and Pbar the mean router probability over all cells and steps.
        Pbar = mean0(concat_rows(router_probs_accum))
        f = load_counts / max(load_counts.sum(), 1.0)
        aux = mulc(sumall(mulc(Pbar, f * K)), f32(cfg["aux"]))
        total = add(loss, aux)
        info["aux"] = float(aux.d)
        info["load"] = f
    if collect:
        info["pred"] = logits.d.reshape(Bn, S, -1).argmax(-1)
    return total, info
|
|
def concat_rows(ts):
    """Concatenate a list of (N,K) nodes along axis 0 -> (sum N, K) node."""
    stacked = T(np.concatenate([t.d for t in ts], axis=0), _prev=tuple(ts))
    heights = [t.d.shape[0] for t in ts]

    def _backward():
        # Hand each input the block of gradient rows it contributed.
        start = 0
        for part, height in zip(ts, heights):
            stop = start + height
            if part.req:
                part._acc(stacked.g[start:stop])
            start = stop

    stacked._bw = _backward
    return stacked
|
|
| |
| |
| |
def _sm(x):
    """Plain-numpy softmax over the last axis (row max subtracted first)."""
    shifted = x - x.max(-1, keepdims=True)
    expd = np.exp(shifted)
    return expd / expd.sum(-1, keepdims=True)
|
|
def forward_np(W, A, B, cfg, rng, trace=False):
    """Run the mesh with plain numpy weights `W` (dict of arrays).

    Args:
        W: dict of weight arrays; reads emb, posemb, Wr, br, Wo, bo and
            e{i}.W1 / b1 / W2 / b2 for each of the K experts.
        A, B: int arrays (Bn, S) used as row indices into ``W["emb"]``.
        cfg: config dict; reads d, K, topk, steps and p_update.
        rng: random source for the per-step async update mask.
        trace: when true, also return a per-step record for visualisation.

    Returns:
        The predicted digit grid (Bn, S); with ``trace`` a tuple
        ``(pred, frames)`` where each frame has "pred", "updated", "expert"
        and "fired".

    An expert is evaluated only on the rows whose gate for it is non-zero, so
    with ``topk < K`` only the selected experts run (true sparsity). The top-k
    mask comes from ``topk_addmask``, the same helper the training graph uses,
    so ``topk >= K`` means dense routing here too instead of an error.
    """
    Bn, S = A.shape
    d, K, topk, steps, p_up = cfg["d"], cfg["K"], cfg["topk"], cfg["steps"], cfg["p_update"]
    N = Bn * S
    emb, pos = W["emb"], W["posemb"]
    Wo, bo = W["Wo"], W["bo"]
    H = emb[A] + emb[B] + pos[None]
    frames = []
    for _ in range(steps):
        # Perception: [left, self, right] with zero padding at the strip ends.
        Hl = np.zeros_like(H)
        Hl[:, 1:] = H[:, :-1]
        Hr = np.zeros_like(H)
        Hr[:, :-1] = H[:, 1:]
        perc = np.concatenate([Hl, H, Hr], axis=-1)
        mu = perc.mean(-1, keepdims=True)
        v = perc.var(-1, keepdims=True)
        perc = (perc - mu) / np.sqrt(v + 1e-5)
        pf = perc.reshape(N, 3 * d)
        # Router gate: softmax over the logits with everything but the top-k masked out.
        rl = pf @ W["Wr"] + W["br"]
        gate = _sm(rl + topk_addmask(rl, topk))
        mix = np.zeros((N, d), f32)
        fired = np.zeros((N, K), f32)
        for e in range(K):
            ge = gate[:, e]
            act = ge > 0
            if not act.any():
                continue  # no cell routed to this expert this step
            h1 = np.maximum(pf[act] @ W[f"e{e}.W1"] + W[f"e{e}.b1"], 0)
            oe = h1 @ W[f"e{e}.W2"] + W[f"e{e}.b2"]
            mix[act] += ge[act, None] * oe
            fired[act, e] = 1.0
        # Asynchronous residual update: only the cells drawn in `um` change.
        um = async_mask(Bn, S, rng, p_up)
        H = H + um * mix.reshape(Bn, S, d)
        if trace:
            logit = H.reshape(N, d) @ Wo + bo
            frames.append(dict(
                pred=logit.reshape(Bn, S, -1).argmax(-1),
                updated=um[..., 0].astype(int),
                expert=(gate.argmax(1)).reshape(Bn, S),
                fired=fired.reshape(Bn, S, K)))
    logits = H.reshape(N, d) @ Wo + bo
    pred = logits.reshape(Bn, S, -1).argmax(-1)
    return (pred, frames) if trace else pred
|
|
def params_to_np(P):
    """Snapshot the parameters as a dict of independent numpy copies (name -> array)."""
    arrays = {}
    for name, tensor in P.items():
        arrays[name] = tensor.d.copy()
    return arrays
|
|