gary-neuron / training /garyneuron.py
gary23w's picture
gary-neuron: async NCA + top-2 MoE, 26k params, 99.97%/100% exact-match on 7-digit addition
57f9808 verified
Raw
History Blame Contribute Delete
13.7 kB
"""gary-neuron: an asynchronous Neural Cellular Automaton whose per-cell update
rule is a Mixture-of-Experts. Pure NumPy. Tiny reverse-mode autograd written
from scratch (no torch) so the MoE+NCA gradients are correct-by-construction.
The mesh is a 1-D strip of S cells, one per digit position (reversed: cell 0 is
the least-significant digit). Each async step a random subset of cells update;
every updating cell perceives [left, self, right] and routes that perception
through a router -> top-k experts. Carry ripples low->high across cells over
time, exactly like ripple-carry addition. Trained on reversed-digit addition
(Lee et al. 2023), which is the format that makes exact addition learnable.
This single file is BOTH the training brain (autograd Tensor `T` + model graph)
and the dependency-free inference engine (`forward_np`, used by solve.py).
"""
import numpy as np
f32 = np.float32  # the one dtype every tensor, parameter and gradient uses
# ======================================================================
# minimal reverse-mode autograd over numpy
# ======================================================================
class T:
    """A node in the tape: holds data `.d`, grad `.g`, and a backward closure.

    Attributes:
        d: forward value, stored as a float32 ndarray.
        g: accumulated gradient, or None until some gradient reaches it.
        _bw: closure that pushes `self.g` into the parents in `_prev`.
        _prev: tuple of parent nodes this value was computed from.
        req: when False, ops do not accumulate a gradient into this node.
    """
    __slots__ = ("d", "g", "_bw", "_prev", "req")

    def __init__(self, data, req=True, _prev=()):
        self.d = np.asarray(data, dtype=f32)
        self.g = None
        self._bw = _noop  # leaves have nothing to propagate
        self._prev = _prev
        self.req = req

    def _acc(self, grad):
        """Add `grad` into `.g`; a node used several times sums its grads."""
        self.g = grad if self.g is None else self.g + grad

    def backward(self):
        """Reverse-mode sweep from this node, seeding its gradient with ones.

        Every ancestor's closure is run in reverse topological order, so each
        node's `.g` is complete before it is pushed further up the tape.
        """
        # Iterative post-order DFS. It visits parents in exactly the order a
        # recursive `build(v)` would (so gradients accumulate in the same
        # order), but the tape depth grows linearly with the number of NCA
        # steps and an explicit stack cannot hit Python's recursion limit.
        topo = []
        seen = {id(self)}
        stack = [(self, iter(self._prev))]
        while stack:
            node, parents = stack[-1]
            for p in parents:
                if id(p) not in seen:
                    seen.add(id(p))
                    stack.append((p, iter(p._prev)))
                    break
            else:
                # all parents already emitted -> node is ready
                stack.pop()
                topo.append(node)
        self.g = np.ones_like(self.d)
        for v in reversed(topo):
            # A node that no gradient reached (its consumers skipped it
            # because its req is False) has nothing to push; its closure
            # would otherwise fail on `out.g` being None.
            if v.g is not None:
                v._bw()

    # operator sugar; a non-T operand is lifted to a constant (req=False) leaf
    def __add__(self, o):
        return add(self, o if isinstance(o, T) else T(o, req=False))

    def __mul__(self, o):
        return mul(self, o if isinstance(o, T) else T(o, req=False))


def _noop():
    """Backward closure for leaves: nothing upstream to propagate into."""
def _unbroadcast(grad, shape):
"""Sum `grad` down to `shape` (reverse of numpy broadcasting)."""
while grad.ndim > len(shape):
grad = grad.sum(0)
for i, s in enumerate(shape):
if s == 1 and grad.shape[i] != 1:
grad = grad.sum(i, keepdims=True)
return grad
def add(a, b):
    """Elementwise a + b with numpy broadcasting.

    The upstream gradient flows unchanged to both operands, summed back to
    each operand's own shape.
    """
    result = T(a.d + b.d, _prev=(a, b))

    def _backward():
        upstream = result.g
        for operand in (a, b):
            if operand.req:
                operand._acc(_unbroadcast(upstream, operand.d.shape))

    result._bw = _backward
    return result
def mul(a, b):
    """Elementwise a * b with broadcasting (product rule: d/da = b, d/db = a).

    Each operand's gradient is summed back to that operand's shape.
    """
    product = T(a.d * b.d, _prev=(a, b))

    def _backward():
        for target, other in ((a, b), (b, a)):
            if target.req:
                target._acc(_unbroadcast(product.g * other.d, target.d.shape))

    product._bw = _backward
    return product
def matmul(a, b):
    """Matrix product a @ b; the backward (which uses `.T`) is for 2-D operands.

    dA = dOut @ B^T and dB = A^T @ dOut.
    """
    product = T(a.d @ b.d, _prev=(a, b))

    def _backward():
        upstream = product.g
        if a.req:
            a._acc(upstream @ b.d.T)
        if b.req:
            b._acc(a.d.T @ upstream)

    product._bw = _backward
    return product
def mulc(x, c):
    """Multiply by a constant numpy array (mask / gate that carries no grad).

    `c` is a constant, so no gradient flows into it. If `c` broadcasts `x` up
    to a larger shape, the gradient is summed back down to `x`'s shape (as
    `mul` does) so `x.g` always matches `x.d`. When the shapes already agree
    the reduction is a no-op.
    """
    out = T(x.d * c, _prev=(x,))

    def bw():
        if x.req:
            x._acc(_unbroadcast(out.g * c, x.d.shape))

    out._bw = bw
    return out
def addc(x, c):
    """Add a constant numpy array `c` (no grad flows into it).

    If `c` broadcasts `x` up to a larger shape, the gradient is summed back
    down to `x`'s shape (as `add` does) so `x.g` always matches `x.d`. When
    the shapes already agree the reduction is a no-op.
    """
    out = T(x.d + c, _prev=(x,))

    def bw():
        if x.req:
            x._acc(_unbroadcast(out.g, x.d.shape))

    out._bw = bw
    return out
def relu(a):
    """Rectifier: keeps positive entries, zeroes the rest.

    The gradient passes only where the input was strictly positive.
    """
    positive = a.d > 0
    result = T(a.d * positive, _prev=(a,))

    def _backward():
        if not a.req:
            return
        a._acc(result.g * positive)

    result._bw = _backward
    return result
def reshape(x, shape):
    """Give `x` a new shape; the gradient is reshaped back to `x`'s shape."""
    result = T(x.d.reshape(shape), _prev=(x,))

    def _backward():
        if not x.req:
            return
        x._acc(result.g.reshape(x.d.shape))

    result._bw = _backward
    return result
def gather(W, idx):
    """Rows of W (V,d) selected by int idx (N,) -> (N,d).

    The backward scatters the row gradients with `np.add.at`, so an index
    that appears several times accumulates all of its contributions.
    """
    rows = T(W.d[idx], _prev=(W,))

    def _backward():
        if not W.req:
            return
        dW = np.zeros_like(W.d)
        np.add.at(dW, idx, rows.g)
        W._acc(dW)

    rows._bw = _backward
    return rows
def concat_last(ts):
    """Concatenate tensors along the last axis.

    The backward hands each input its own slice of the upstream gradient.
    """
    joined = T(np.concatenate([t.d for t in ts], axis=-1), _prev=tuple(ts))
    widths = [t.d.shape[-1] for t in ts]

    def _backward():
        start = 0
        for piece, width in zip(ts, widths):
            stop = start + width
            if piece.req:
                piece._acc(joined.g[..., start:stop])
            start = stop

    joined._bw = _backward
    return joined
def shift_from_left(x):
    """neighbour i-1: out[:,i]=x[:,i-1], out[:,0]=0. (the carry source)

    Backward: the gradient at out[:, i] belongs to x[:, i-1]. The last cell
    of x feeds nothing, so its gradient is zero.
    """
    shifted = np.zeros_like(x.d)
    shifted[:, 1:] = x.d[:, :-1]
    result = T(shifted, _prev=(x,))

    def _backward():
        if not x.req:
            return
        grad = np.zeros_like(x.d)
        grad[:, :-1] += result.g[:, 1:]
        x._acc(grad)

    result._bw = _backward
    return result
def shift_from_right(x):
    """neighbour i+1: out[:,i]=x[:,i+1], out[:,-1]=0.

    Backward: the gradient at out[:, i] belongs to x[:, i+1]. The first cell
    of x feeds nothing, so its gradient is zero.
    """
    shifted = np.zeros_like(x.d)
    shifted[:, :-1] = x.d[:, 1:]
    result = T(shifted, _prev=(x,))

    def _backward():
        if not x.req:
            return
        grad = np.zeros_like(x.d)
        grad[:, 1:] += result.g[:, :-1]
        x._acc(grad)

    result._bw = _backward
    return result
def softmax(x):
    """Softmax over the last axis, shifted by the row max for stability.

    The backward is the Jacobian-vector product p * (g - sum(g * p)).
    """
    shifted = x.d - x.d.max(-1, keepdims=True)
    expd = np.exp(shifted)
    probs = expd / expd.sum(-1, keepdims=True)
    result = T(probs, _prev=(x,))

    def _backward():
        if not x.req:
            return
        upstream = result.g
        inner = (upstream * probs).sum(-1, keepdims=True)
        x._acc(probs * (upstream - inner))

    result._bw = _backward
    return result
def col(x, e):
    """Select column `e` of a 2-D tensor while keeping it 2-D: (N,K) -> (N,1).

    The backward places the gradient back into that column; other columns
    get zero.
    """
    sel = slice(e, e + 1)
    picked = T(x.d[:, sel], _prev=(x,))

    def _backward():
        if not x.req:
            return
        grad = np.zeros_like(x.d)
        grad[:, sel] = picked.g
        x._acc(grad)

    picked._bw = _backward
    return picked
def layernorm(x, eps=1e-5):
    """Parameter-free layer norm over the last axis.

    Forward: y = (x - mean) / sqrt(var + eps). The backward uses the closed
    form dx = inv_std * (g - mean(g) - y * mean(g * y)).
    """
    centered = x.d - x.d.mean(-1, keepdims=True)
    variance = (centered * centered).mean(-1, keepdims=True)
    inv_std = 1.0 / np.sqrt(variance + eps)
    normed = centered * inv_std
    result = T(normed, _prev=(x,))

    def _backward():
        if not x.req:
            return
        g = result.g
        g_mean = g.mean(-1, keepdims=True)
        gy_mean = (g * normed).mean(-1, keepdims=True)
        x._acc(inv_std * (g - g_mean - normed * gy_mean))

    result._bw = _backward
    return result
def mean0(x):
    """mean over axis 0: (N,K)->(K,)

    Each of the N rows receives upstream / N in the backward.
    """
    rows = x.d.shape[0]
    result = T(x.d.mean(axis=0), _prev=(x,))

    def _backward():
        if not x.req:
            return
        share = result.g / rows
        x._acc(np.broadcast_to(share, x.d.shape).copy())

    result._bw = _backward
    return result
def sumall(x):
    """Sum every element into a 0-d float32 tensor.

    The backward broadcasts the upstream scalar to every input element.
    """
    total = T(np.asarray(x.d.sum(), dtype=f32), _prev=(x,))

    def _backward():
        if not x.req:
            return
        x._acc(np.ones_like(x.d) * total.g)

    total._bw = _backward
    return total
def cross_entropy(logits, targets):
    """logits (N,C) T, targets (N,) int -> scalar mean CE.

    The loss is evaluated in log-sum-exp form,
        log p_t = z_t - log(sum_c exp z_c),   z = logits - rowmax,
    rather than as log(p_t + 1e-9). The epsilon form saturated at
    -log(1e-9) ~= 20.7 for confidently wrong rows. That made the reported
    loss stop reflecting how wrong those rows were while their gradient
    (p - onehot) stayed at full strength. The gradient is unchanged:
    (softmax(logits) - onehot(targets)) / N.
    """
    N = logits.d.shape[0]
    rows = np.arange(N)
    z = logits.d - logits.d.max(-1, keepdims=True)
    e = np.exp(z)
    denom = e.sum(-1, keepdims=True)  # >= 1 because the row max gives exp(0)
    p = e / denom
    log_p_target = z[rows, targets] - np.log(denom[:, 0])
    loss = -log_p_target.mean()
    out = T(np.asarray(loss, dtype=f32), _prev=(logits,))

    def bw():
        if logits.req:
            g = p.copy()
            g[rows, targets] -= 1.0
            g /= N
            logits._acc(g * out.g)

    out._bw = bw
    return out
# ======================================================================
# helpers that produce CONSTANTS (no grad): routing & async masks
# ======================================================================
def topk_addmask(logits, k):
    """Additive mask: 0.0 on the top-k logits per row, -1e9 elsewhere.

    Args:
        logits: (N, K) numpy array of router logits (a constant, no grad).
        k: number of experts kept per row; must be >= 1. When k >= K every
           expert is kept and the mask is all zeros.

    Returns:
        float32 (N, K) array to be added to the logits before the softmax.

    Raises:
        ValueError: if k < 1. k == 0 would silently mask every expert, and a
            negative k would slice the wrong number of columns.
    """
    k = int(k)
    if k < 1:
        raise ValueError(f"top-k needs k >= 1, got {k}")
    _, K = logits.shape
    if k >= K:
        # Same float32 dtype as the masked branch below.
        return np.zeros_like(logits, dtype=np.float32)
    idx = np.argpartition(-logits, k - 1, axis=1)[:, :k]
    M = np.full_like(logits, -1e9, dtype=np.float32)
    np.put_along_axis(M, idx, 0.0, axis=1)
    return M
def async_mask(B, S, rng, p):
    """1 where a cell updates this step, else 0. shape (B,S,1).

    Each cell fires independently with probability `p`, using one uniform
    draw per cell from `rng`. The trailing size-1 axis lets the mask
    broadcast over the cell-state channels.
    """
    uniform = rng.random((B, S, 1))
    return np.where(uniform < p, 1, 0).astype(f32)
# ======================================================================
# model
# ======================================================================
def default_cfg():
    """Return a fresh dict with the default mesh hyper-parameters.

    S: number of cells in the strip; d: cell-state width; he: expert hidden
    width; K: number of experts; topk: experts kept per cell; steps: async
    update steps; p_update: per-cell update probability per step; Vin/Vout:
    input/output digit vocabulary sizes; aux: load-balancing loss weight.
    """
    cfg = {
        "S": 8,
        "d": 32,
        "he": 32,
        "K": 6,
        "topk": 2,
        "steps": 18,
        "p_update": 0.5,
        "Vin": 10,
        "Vout": 10,
        "aux": 0.01,
    }
    return cfg
def init_params(cfg, seed=1337):
    """Build the trainable parameter dict {name: T} for config `cfg`.

    The seeded generator is drawn in a fixed order: emb, posemb, Wr, then
    W1/W2 for each expert, then Wo. The same seed therefore always yields
    the same weights. All biases start at zero. Expert output weights use
    std 0.08 / sqrt(he), so each step's residual update starts tiny
    (near-identity dynamics at init).
    """
    rng = np.random.default_rng(seed)
    d, he, K, S = (cfg[key] for key in ("d", "he", "K", "S"))
    Vin, Vout = cfg["Vin"], cfg["Vout"]

    def normal(std, shape):
        return T(rng.normal(0, std, shape).astype(f32))

    def zeros(n):
        return T(np.zeros(n, f32))

    params = {
        "emb": normal(0.08, (Vin, d)),      # shared digit embedding (a and b added)
        "posemb": normal(0.02, (S, d)),
        "Wr": normal(0.08, (3 * d, K)),     # router
        "br": zeros(K),
    }
    expert_out_std = 0.08 / np.sqrt(he)
    for e in range(K):
        prefix = f"e{e}."
        params[prefix + "W1"] = normal(0.08, (3 * d, he))
        params[prefix + "b1"] = zeros(he)
        params[prefix + "W2"] = normal(expert_out_std, (he, d))
        params[prefix + "b2"] = zeros(d)
    params["Wo"] = normal(0.10, (d, Vout))  # readout
    params["bo"] = zeros(Vout)
    return params
def n_params(P):
    """Count the scalar parameters across every tensor in `P` (as an int)."""
    total = 0
    for tensor in P.values():
        total += tensor.d.size
    return int(total)
def forward(P, A, B, Y, cfg, rng, train=True, collect=False):
    """Graph forward. A,B,Y int arrays (Bn,S). Returns (total_loss_T, info).

    The two operand digits in each cell are embedded and summed. The mesh
    then runs cfg["steps"] asynchronous steps; each step perceives
    [left, self, right], layer-norms it, routes it through a top-k router,
    mixes gated expert MLPs and applies a masked residual update. Every cell
    is finally read out to digit logits and scored against Y with
    cross-entropy.

    When cfg["aux"] > 0, a load-balancing term
    cfg["aux"] * K * sum_i f_i * Pbar_i is added. Here f_i is the (constant)
    fraction of argmax routings to expert i and Pbar_i is the mean soft
    router probability, both taken over all steps and cells.

    Args:
        P: parameter dict (as built by `init_params`).
        A, B: operand digits, int arrays (Bn, S).
        Y: target digits, int array (Bn, S).
        cfg: config dict (see `default_cfg`).
        rng: random generator with `.random(shape)` driving the async masks.
        train: accepted for API compatibility. It does not change the
            computation: the same p_update is used in both modes, as in
            `forward_np`.
        collect: if True, also store argmax predictions (Bn, S) in
            info["pred"].

    Returns:
        (total, info). `total` is the scalar T to call `.backward()` on.
        `info` holds "loss", plus "aux" and "load" when the aux term is on.
    """
    Bn, S = A.shape
    d, K, topk, steps = cfg["d"], cfg["K"], cfg["topk"], cfg["steps"]
    # One async update rate for training and evaluation alike.
    p_up = cfg["p_update"]
    N = Bn * S
    Wr, br, Wo, bo = P["Wr"], P["br"], P["Wo"], P["bo"]

    # Initial cell state: shared digit embedding of both operands plus a
    # per-position embedding (posemb (S,d) broadcasts over the batch).
    ha = reshape(gather(P["emb"], A.reshape(-1)), (Bn, S, d))
    hb = reshape(gather(P["emb"], B.reshape(-1)), (Bn, S, d))
    H = add(add(ha, hb), P["posemb"])  # (Bn,S,d)

    router_probs_accum = []
    load_counts = np.zeros(K, dtype=f32)
    for _ in range(steps):
        # Perceive [left, self, right] (zero-padded at the strip ends).
        Hl = shift_from_left(H)
        Hr = shift_from_right(H)
        perc = layernorm(concat_last([Hl, H, Hr]))  # (Bn,S,3d)
        pf = reshape(perc, (N, 3*d))

        # Router. The top-k mask is a constant, so gradients reach the
        # router only through the softmax over the kept logits.
        rl = add(matmul(pf, Wr), br)  # (N,K) router logits
        M = topk_addmask(rl.d, topk)
        gate = softmax(addc(rl, M))  # (N,K) -> only top-k nonzero
        if cfg["aux"] > 0:
            router_probs_accum.append(softmax(rl))  # soft probs for load-balance aux
            load_counts += np.bincount(rl.d.argmax(1), minlength=K).astype(f32)

        # All experts run densely here; the ones not selected have a ~0
        # gate and contribute nothing.
        mix = None
        for e in range(K):
            h1 = relu(add(matmul(pf, P[f"e{e}.W1"]), P[f"e{e}.b1"]))
            oe = add(matmul(h1, P[f"e{e}.W2"]), P[f"e{e}.b2"])  # (N,d)
            ge = mul(col(gate, e), oe)  # gate broadcast (N,1)*(N,d)
            mix = ge if mix is None else add(mix, ge)

        # Async residual update: only cells whose mask is 1 change.
        um = async_mask(Bn, S, rng, p_up)
        H = add(H, mulc(reshape(mix, (Bn, S, d)), um))

    Hf = reshape(H, (N, d))
    logits = add(matmul(Hf, Wo), bo)  # (N,Vout)
    loss = cross_entropy(logits, Y.reshape(-1))
    total = loss
    info = {"loss": float(loss.d)}
    if cfg["aux"] > 0 and router_probs_accum:
        Pbar = mean0(concat_rows(router_probs_accum))  # (K,) mean soft prob
        f = load_counts / max(load_counts.sum(), 1.0)  # fraction routed (top-1), constant
        aux = mulc(sumall(mulc(Pbar, f * K)), f32(cfg["aux"]))
        total = add(loss, aux)
        info["aux"] = float(aux.d)
        info["load"] = f
    if collect:
        info["pred"] = logits.d.reshape(Bn, S, -1).argmax(-1)
    return total, info
def concat_rows(ts):
    """stack a list of (N,K) T along axis 0 -> (sum N, K) T.

    The backward gives each input back its own block of rows of the upstream
    gradient.
    """
    stacked = T(np.concatenate([t.d for t in ts], axis=0), _prev=tuple(ts))
    heights = [t.d.shape[0] for t in ts]

    def _backward():
        top = 0
        for piece, height in zip(ts, heights):
            if piece.req:
                piece._acc(stacked.g[top:top + height])
            top += height

    stacked._bw = _backward
    return stacked
# ======================================================================
# dependency-free inference (also used by solve.py). numpy only, no tape.
# ======================================================================
def _sm(x):
e = np.exp(x - x.max(-1, keepdims=True)); return e / e.sum(-1, keepdims=True)
def forward_np(W, A, B, cfg, rng, trace=False):
    """Run the mesh with plain numpy weights `W` (dict of arrays). Returns
    predicted digit grid (Bn,S) and, if trace, a per-step record for viz.
    At inference top-k experts are the only ones evaluated (true sparsity).

    Args:
        W: dict of arrays keyed like the training params ("emb", "posemb",
           "Wr", "br", "e{i}.W1", "e{i}.b1", "e{i}.W2", "e{i}.b2", "Wo", "bo").
        A, B: operand digits, int arrays (Bn, S).
        cfg: config dict; d, K, topk, steps and p_update are read.
        rng: random generator; one `rng.random((Bn, S, 1))` call per step.
        trace: if True, also return one dict per step holding "pred" (Bn,S),
           "updated" (Bn,S) 0/1, "expert" (Bn,S) argmax gate and "fired"
           (Bn,S,K) 0/1 per evaluated expert.

    Returns:
        pred, or (pred, frames) when trace is True.

    Raises:
        ValueError: if cfg["topk"] < 1.
    """
    Bn, S = A.shape
    d, K, topk, steps, p_up = cfg["d"], cfg["K"], cfg["topk"], cfg["steps"], cfg["p_update"]
    if topk < 1:
        raise ValueError(f"top-k needs k >= 1, got {topk}")
    N = Bn * S
    emb, pos = W["emb"], W["posemb"]
    H = emb[A] + emb[B] + pos[None]  # (Bn,S,d)
    frames = []
    for _ in range(steps):
        # Perceive [left, self, right] with zero padding at the strip ends,
        # then a parameter-free layernorm with eps=1e-5.
        Hl = np.zeros_like(H)
        Hl[:, 1:] = H[:, :-1]
        Hr = np.zeros_like(H)
        Hr[:, :-1] = H[:, 1:]
        perc = np.concatenate([Hl, H, Hr], axis=-1)
        mu = perc.mean(-1, keepdims=True)
        v = perc.var(-1, keepdims=True)
        perc = (perc - mu) / np.sqrt(v + 1e-5)
        pf = perc.reshape(N, 3*d)
        rl = pf @ W["Wr"] + W["br"]

        # Top-k routing mask: 0 on kept experts, -1e9 elsewhere. With
        # topk >= K every expert is kept, matching the training-side mask;
        # np.argpartition would reject a kth >= K.
        if topk >= K:
            M = np.zeros_like(rl)
        else:
            idx = np.argpartition(-rl, topk - 1, axis=1)[:, :topk]
            M = np.full_like(rl, -1e9)
            np.put_along_axis(M, idx, 0.0, axis=1)
        gate = _sm(rl + M)

        # Sparse mixture: expert e only runs on rows whose gate for it is > 0.
        mix = np.zeros((N, d), f32)
        fired = np.zeros((N, K), f32)
        for e in range(K):
            ge = gate[:, e]
            act = ge > 0
            if not act.any():
                continue
            h1 = np.maximum(pf[act] @ W[f"e{e}.W1"] + W[f"e{e}.b1"], 0)
            oe = h1 @ W[f"e{e}.W2"] + W[f"e{e}.b2"]
            mix[act] += ge[act, None] * oe
            fired[act, e] = 1.0

        # Async residual update: each cell fires with probability p_up.
        um = (rng.random((Bn, S, 1)) < p_up).astype(f32)
        H = H + um * mix.reshape(Bn, S, d)
        if trace:
            logit = H.reshape(N, d) @ W["Wo"] + W["bo"]
            frames.append(dict(
                pred=logit.reshape(Bn, S, -1).argmax(-1),
                updated=um[..., 0].astype(int),
                expert=gate.argmax(1).reshape(Bn, S),
                fired=fired.reshape(Bn, S, K)))
    logits = H.reshape(N, d) @ W["Wo"] + W["bo"]
    pred = logits.reshape(Bn, S, -1).argmax(-1)
    return (pred, frames) if trace else pred
def params_to_np(P):
    """Snapshot tape parameters as a new {name: ndarray} dict.

    Keys keep their order, and each value is an independent copy of the
    tensor's `.d`, so later in-place changes to `P` do not leak into the
    snapshot.
    """
    snapshot = {}
    for name, tensor in P.items():
        snapshot[name] = tensor.d.copy()
    return snapshot