File size: 13,695 Bytes
57f9808
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
"""gary-neuron: an asynchronous Neural Cellular Automaton whose per-cell update
rule is a Mixture-of-Experts. Pure NumPy. Tiny reverse-mode autograd written
from scratch (no torch) so the MoE+NCA gradients are correct-by-construction.

The mesh is a 1-D strip of S cells, one per digit position (reversed: cell 0 is
the least-significant digit). Each async step a random subset of cells update;
every updating cell perceives [left, self, right] and routes that perception
through a router -> top-k experts. Carry ripples low->high across cells over
time, exactly like ripple-carry addition. Trained on reversed-digit addition
(Lee et al. 2023), which is the format that makes exact addition learnable.

This single file is BOTH the training brain (autograd Tensor `T` + model graph)
and the dependency-free inference engine (`forward_np`, used by solve.py).
"""
import numpy as np

f32 = np.float32

# ======================================================================
#  minimal reverse-mode autograd over numpy
# ======================================================================
class T:
    """A node in the autograd tape.

    Slots:
        d:     the node's value, always stored as a float32 numpy array.
        g:     gradient accumulated for `d`; ``None`` until one arrives.
        _bw:   zero-argument closure that pushes `g` into the parents' `g`.
        _prev: tuple of the parent nodes this node was computed from.
        req:   whether ops should accumulate a gradient into this node.
    """
    __slots__ = ("d", "g", "_bw", "_prev", "req")

    def __init__(self, data, req=True, _prev=()):
        self.d = np.asarray(data, dtype=f32)
        self.g = None
        self._bw = _noop            # leaves have nothing to propagate
        self._prev = _prev
        self.req = req

    def _acc(self, grad):
        """Accumulate `grad` into `.g` out-of-place (never mutates an aliased array)."""
        self.g = grad if self.g is None else self.g + grad

    def backward(self):
        """Back-propagate from this node through every ancestor.

        Seeds this node's gradient with ones, then runs each node's backward
        closure in reverse topological order. Gradients of other nodes are not
        reset here, so a node that outlives one call keeps accumulating.

        The ordering is a post-order depth-first walk driven by an explicit
        stack rather than recursion, so deep graphs (many unrolled steps)
        cannot hit Python's recursion limit. The visiting order is the same as
        the straightforward recursive walk over `_prev`.
        """
        topo = []
        seen = {id(self)}
        stack = [(self, iter(self._prev))]
        while stack:
            node, parents = stack[-1]
            for parent in parents:
                if id(parent) not in seen:
                    seen.add(id(parent))
                    # Descend; the partially consumed iterator stays on the
                    # stack so the remaining parents are resumed afterwards.
                    stack.append((parent, iter(parent._prev)))
                    break
            else:
                # All parents handled: emit the node after its ancestors.
                topo.append(node)
                stack.pop()
        self.g = np.ones_like(self.d)
        for v in reversed(topo):
            v._bw()

    # operator sugar
    def __add__(self, o):
        return add(self, o)

    def __mul__(self, o):
        return mul(self, o)

def _noop():
    """Default backward closure for nodes with nothing to propagate."""
    return None

def _unbroadcast(grad, shape):
    """Reduce `grad` to `shape` by summing over the axes numpy broadcasting added.

    Leading axes that `shape` does not have are summed away first; then every
    axis where `shape` has size 1 but `grad` does not is summed with
    ``keepdims=True``.
    """
    extra_leading = grad.ndim - len(shape)
    for _ in range(max(extra_leading, 0)):
        grad = grad.sum(0)
    for axis, size in enumerate(shape):
        if size != 1 or grad.shape[axis] == 1:
            continue
        grad = grad.sum(axis, keepdims=True)
    return grad

def add(a, b):
    """Elementwise (broadcasting) sum of two tape nodes.

    Backward: the upstream gradient goes unchanged to both operands, summed
    back down to each operand's own shape where broadcasting expanded it.
    """
    result = T(a.d + b.d, _prev=(a, b))

    def backward_fn():
        upstream = result.g
        for operand in (a, b):
            if operand.req:
                operand._acc(_unbroadcast(upstream, operand.d.shape))

    result._bw = backward_fn
    return result

def mul(a, b):
    """Elementwise (broadcasting) product of two tape nodes.

    Backward: each operand receives the upstream gradient times the other
    operand's value, summed back down to its own shape.
    """
    result = T(a.d * b.d, _prev=(a, b))

    def backward_fn():
        upstream = result.g
        if a.req:
            grad_a = upstream * b.d
            a._acc(_unbroadcast(grad_a, a.d.shape))
        if b.req:
            grad_b = upstream * a.d
            b._acc(_unbroadcast(grad_b, b.d.shape))

    result._bw = backward_fn
    return result

def matmul(a, b):
    """Matrix product ``a @ b`` of two tape nodes.

    Backward uses ``g @ b.T`` for `a` and ``a.T @ g`` for `b`, which is the
    matmul gradient for 2-D operands.
    """
    result = T(np.matmul(a.d, b.d), _prev=(a, b))

    def backward_fn():
        upstream = result.g
        if a.req:
            a._acc(np.matmul(upstream, b.d.T))
        if b.req:
            b._acc(np.matmul(a.d.T, upstream))

    result._bw = backward_fn
    return result

def mulc(x, c):
    """Scale a tape node by a constant numpy array `c` (a mask or gate).

    `c` carries no gradient; `x` receives the upstream gradient times `c`.
    No un-broadcasting is performed, so ``x.d * c`` is expected to keep the
    shape of `x`.
    """
    result = T(x.d * c, _prev=(x,))

    def backward_fn():
        if not x.req:
            return
        x._acc(result.g * c)

    result._bw = backward_fn
    return result

def addc(x, c):
    """Add a constant numpy array `c` (no gradient) to a tape node.

    Backward passes the upstream gradient straight through to `x`.
    """
    result = T(x.d + c, _prev=(x,))

    def backward_fn():
        if not x.req:
            return
        x._acc(result.g)

    result._bw = backward_fn
    return result

def relu(a):
    """Rectified linear unit: keeps entries where the input is > 0, zeroes the rest.

    Backward lets the gradient through only where the input was positive.
    """
    positive = a.d > 0
    result = T(a.d * positive, _prev=(a,))

    def backward_fn():
        if not a.req:
            return
        a._acc(result.g * positive)

    result._bw = backward_fn
    return result

def reshape(x, shape):
    """Reshape a tape node to `shape`; backward reshapes the gradient back."""
    source_shape = x.d.shape
    result = T(x.d.reshape(shape), _prev=(x,))

    def backward_fn():
        if not x.req:
            return
        x._acc(result.g.reshape(source_shape))

    result._bw = backward_fn
    return result

def gather(W, idx):
    """Look up rows of W (V,d) by integer idx (N,), giving an (N,d) node.

    Backward scatter-adds the upstream rows into a zero array shaped like W,
    so an index that appears several times accumulates all its gradients.
    """
    result = T(W.d[idx], _prev=(W,))

    def backward_fn():
        if not W.req:
            return
        scattered = np.zeros_like(W.d)
        np.add.at(scattered, idx, result.g)
        W._acc(scattered)

    result._bw = backward_fn
    return result

def concat_last(ts):
    """Concatenate a list of tape nodes along their last axis.

    Backward hands each input the slice of the upstream gradient that
    corresponds to its own span of the last axis.
    """
    result = T(np.concatenate([t.d for t in ts], axis=-1), _prev=tuple(ts))
    widths = [t.d.shape[-1] for t in ts]

    def backward_fn():
        start = 0
        for part, width in zip(ts, widths):
            stop = start + width
            if part.req:
                part._acc(result.g[..., start:stop])
            start = stop

    result._bw = backward_fn
    return result

def shift_from_left(x):
    """Left-neighbour view: out[:, i] = x[:, i-1], with out[:, 0] = 0.

    This is how a cell reads the cell at index i-1 (the carry source).
    Backward returns each gradient slice to the position it was read from.
    """
    shifted = np.zeros_like(x.d)
    shifted[:, 1:] = x.d[:, :-1]
    result = T(shifted, _prev=(x,))

    def backward_fn():
        if not x.req:
            return
        grad = np.zeros_like(x.d)
        grad[:, :-1] += result.g[:, 1:]
        x._acc(grad)

    result._bw = backward_fn
    return result

def shift_from_right(x):
    """Right-neighbour view: out[:, i] = x[:, i+1], with out[:, -1] = 0.

    Backward returns each gradient slice to the position it was read from.
    """
    shifted = np.zeros_like(x.d)
    shifted[:, :-1] = x.d[:, 1:]
    result = T(shifted, _prev=(x,))

    def backward_fn():
        if not x.req:
            return
        grad = np.zeros_like(x.d)
        grad[:, 1:] += result.g[:, :-1]
        x._acc(grad)

    result._bw = backward_fn
    return result

def softmax(x):
    """Softmax over the last axis (row max subtracted before exponentiating).

    Backward applies the softmax Jacobian: p * (g - sum(g * p)).
    """
    shifted = x.d - x.d.max(-1, keepdims=True)
    exps = np.exp(shifted)
    probs = exps / exps.sum(-1, keepdims=True)
    result = T(probs, _prev=(x,))

    def backward_fn():
        if not x.req:
            return
        upstream = result.g
        weighted = (upstream * probs).sum(-1, keepdims=True)
        x._acc(probs * (upstream - weighted))

    result._bw = backward_fn
    return result

def col(x, e):
    """Take column `e` of a 2-D node, keeping it as a width-1 column.

    Backward writes the upstream gradient into column `e` of a zero array
    shaped like `x`.
    """
    column = slice(e, e + 1)
    result = T(x.d[:, column], _prev=(x,))

    def backward_fn():
        if not x.req:
            return
        grad = np.zeros_like(x.d)
        grad[:, column] = result.g
        x._acc(grad)

    result._bw = backward_fn
    return result

def layernorm(x, eps=1e-5):
    """Normalise over the last axis to zero mean / unit variance (no affine terms).

    `eps` is added to the variance before the square root. Backward uses the
    closed form inv_std * (g - mean(g) - y * mean(g * y)), with both means
    taken over the last axis.
    """
    mean = x.d.mean(-1, keepdims=True)
    centered = x.d - mean
    variance = (centered * centered).mean(-1, keepdims=True)
    inv_std = 1.0 / np.sqrt(variance + eps)
    normed = centered * inv_std
    result = T(normed, _prev=(x,))

    def backward_fn():
        if not x.req:
            return
        upstream = result.g
        g_mean = upstream.mean(-1, keepdims=True)
        gy_mean = (upstream * normed).mean(-1, keepdims=True)
        x._acc(inv_std * (upstream - g_mean - normed * gy_mean))

    result._bw = backward_fn
    return result

def mean0(x):
    """Average over axis 0: (N,K) -> (K,).

    Backward spreads the upstream gradient, divided by N, over every row.
    """
    n_rows = x.d.shape[0]
    result = T(x.d.mean(0), _prev=(x,))

    def backward_fn():
        if not x.req:
            return
        per_row = result.g / n_rows
        x._acc(np.broadcast_to(per_row, x.d.shape).copy())

    result._bw = backward_fn
    return result

def sumall(x):
    """Sum every element of a node into a scalar node.

    Backward fills an array shaped like `x` with the upstream gradient.
    """
    total = np.asarray(x.d.sum(), dtype=f32)
    result = T(total, _prev=(x,))

    def backward_fn():
        if not x.req:
            return
        x._acc(np.ones_like(x.d) * result.g)

    result._bw = backward_fn
    return result

def cross_entropy(logits, targets):
    """Mean cross-entropy of logits (N,C) node against int targets (N,), as a scalar node.

    The forward value adds 1e-9 inside the log as a guard against log(0).
    Backward is (softmax - one_hot(targets)) / N, scaled by the upstream
    gradient.
    """
    shifted = logits.d - logits.d.max(-1, keepdims=True)
    exps = np.exp(shifted)
    probs = exps / exps.sum(-1, keepdims=True)
    N = shifted.shape[0]
    rows = np.arange(N)
    picked = probs[rows, targets]
    loss = -np.log(picked + 1e-9).mean()
    result = T(np.asarray(loss, dtype=f32), _prev=(logits,))

    def backward_fn():
        if not logits.req:
            return
        grad = probs.copy()
        grad[rows, targets] -= 1.0
        grad /= N
        logits._acc(grad * result.g)

    result._bw = backward_fn
    return result

# ======================================================================
#  helpers that produce CONSTANTS (no grad): routing & async masks
# ======================================================================
def topk_addmask(logits, k):
    """Build an additive routing mask for 2-D `logits`.

    Each row gets 0.0 at its k largest logits and -1e9 everywhere else. When
    k covers every column the mask is all zeros.
    """
    _, n_cols = logits.shape
    if k >= n_cols:
        return np.zeros_like(logits)
    mask = np.full_like(logits, -1e9, dtype=f32)
    # argpartition on the negated logits puts the k largest in the first k slots.
    keep = np.argpartition(-logits, k - 1, axis=1)[:, :k]
    np.put_along_axis(mask, keep, 0.0, axis=1)
    return mask

def async_mask(B, S, rng, p):
    """Draw which cells update this step: shape (B,S,1), 1.0 = update, 0.0 = hold.

    Each entry is 1 where an independent uniform draw from `rng` is below `p`.
    """
    draws = rng.random((B, S, 1))
    return np.less(draws, p).astype(f32)

# ======================================================================
#  model
# ======================================================================
def default_cfg():
    """Return a fresh dict of default hyper-parameters."""
    return {
        "S": 8,            # number of cells (one per digit position)
        "d": 32,           # per-cell state width
        "he": 32,          # hidden width of each expert
        "K": 6,            # number of experts
        "topk": 2,         # experts kept per cell by the router
        "steps": 18,       # async update steps per forward pass
        "p_update": 0.5,   # probability that a cell updates on a given step
        "Vin": 10,         # input vocabulary size (embedding rows)
        "Vout": 10,        # output vocabulary size (readout columns)
        "aux": 0.01,       # weight of the load-balance auxiliary loss
    }

def init_params(cfg, seed=1337):
    """Create the parameter dict (name -> T) for the model described by `cfg`.

    All random draws come from one ``default_rng(seed)`` stream, in a fixed
    order, so the same seed and cfg give the same weights. Biases start at
    zero.
    """
    rng = np.random.default_rng(seed)
    d, he, K, S = cfg["d"], cfg["he"], cfg["K"], cfg["S"]
    Vin, Vout = cfg["Vin"], cfg["Vout"]

    def gaussian(std, shape):
        return T(rng.normal(0, std, shape).astype(f32))

    def zeros(size):
        return T(np.zeros(size, f32))

    P = {}
    P["emb"] = gaussian(0.08, (Vin, d))        # shared digit embedding (a and b added)
    P["posemb"] = gaussian(0.02, (S, d))
    P["Wr"] = gaussian(0.08, (3 * d, K))       # router
    P["br"] = zeros(K)
    # tiny expert output -> near-identity dynamics at init
    expert_out_std = 0.08 / np.sqrt(he)
    for e in range(K):
        prefix = f"e{e}."
        P[prefix + "W1"] = gaussian(0.08, (3 * d, he))
        P[prefix + "b1"] = zeros(he)
        P[prefix + "W2"] = gaussian(expert_out_std, (he, d))
        P[prefix + "b2"] = zeros(d)
    P["Wo"] = gaussian(0.10, (d, Vout))        # readout
    P["bo"] = zeros(Vout)
    return P

def n_params(P):
    """Count the scalar entries across every tensor in the parameter dict."""
    total = 0
    for tensor in P.values():
        total += tensor.d.size
    return int(total)

def forward(P, A, B, Y, cfg, rng, train=True, collect=False):
    """Build the autograd graph for one batch and return ``(total_loss_T, info)``.

    Args:
        P: parameter dict (name -> T) with the keys created by `init_params`.
        A, B: int arrays (Bn,S) used to index the shared embedding table.
        Y: int array (Bn,S) of target classes for the readout.
        cfg: config dict; reads "d", "K", "topk", "steps", "p_update", "aux".
        rng: numpy Generator used to draw the async update mask each step.
        train: kept for API compatibility. It has no effect: the update
            probability is ``cfg["p_update"]`` whatever its value.
        collect: when true, ``info["pred"]`` holds the argmax class grid (Bn,S).

    Returns:
        (total, info). `total` is a scalar T: the mean cross-entropy, plus the
        weighted load-balance term when ``cfg["aux"] > 0``. `info` always has
        "loss", and has "aux"/"load" when the auxiliary term is active and
        "pred" when `collect` is set.
    """
    Bn, S = A.shape
    d, K, topk, steps = cfg["d"], cfg["K"], cfg["topk"], cfg["steps"]
    p_up = cfg["p_update"]
    use_aux = cfg["aux"] > 0
    N = Bn * S
    Wr, br, Wo, bo = P["Wr"], P["br"], P["Wo"], P["bo"]

    ha = reshape(gather(P["emb"], A.reshape(-1)), (Bn, S, d))
    hb = reshape(gather(P["emb"], B.reshape(-1)), (Bn, S, d))
    H = add(add(ha, hb), P["posemb"])                          # (Bn,S,d) broadcast posemb
    router_probs_accum = []
    load_counts = np.zeros(K, dtype=f32)

    for _ in range(steps):
        Hl = shift_from_left(H)
        Hr = shift_from_right(H)
        perc = layernorm(concat_last([Hl, H, Hr]))            # (Bn,S,3d)
        pf = reshape(perc, (N, 3 * d))
        rl = add(matmul(pf, Wr), br)                          # (N,K) router logits
        M = topk_addmask(rl.d, topk)                          # constant top-k mask
        gate = softmax(addc(rl, M))                           # (N,K) -> only top-k nonzero
        if use_aux:
            router_probs_accum.append(softmax(rl))            # soft probs for load-balance aux
            load_counts += np.bincount(rl.d.argmax(1), minlength=K).astype(f32)
        # Dense mixture: every expert is evaluated on every cell and weighted
        # by its gate column (zero outside the top-k).
        mix = None
        for e in range(K):
            h1 = relu(add(matmul(pf, P[f"e{e}.W1"]), P[f"e{e}.b1"]))
            oe = add(matmul(h1, P[f"e{e}.W2"]), P[f"e{e}.b2"])  # (N,d)
            ge = mul(col(gate, e), oe)                          # gate broadcast (N,1)*(N,d)
            mix = ge if mix is None else add(mix, ge)
        um = async_mask(Bn, S, rng, p_up)
        H = add(H, mulc(reshape(mix, (Bn, S, d)), um))         # masked residual update

    Hf = reshape(H, (N, d))
    logits = add(matmul(Hf, Wo), bo)                          # (N,Vout)
    loss = cross_entropy(logits, Y.reshape(-1))
    total = loss
    info = {"loss": float(loss.d)}
    if use_aux and router_probs_accum:
        Pbar = mean0(concat_rows(router_probs_accum))         # (K,) mean soft prob
        f = load_counts / max(load_counts.sum(), 1.0)         # fraction routed (top-1), constant
        aux = mulc(sumall(mulc(Pbar, f * K)), f32(cfg["aux"]))
        total = add(loss, aux)
        info["aux"] = float(aux.d)
        info["load"] = f
    if collect:
        info["pred"] = logits.d.reshape(Bn, S, -1).argmax(-1)
    return total, info

def concat_rows(ts):
    """Stack a list of (N,K) nodes along axis 0 into one (sum N, K) node.

    Backward hands each input the block of upstream-gradient rows that came
    from it.
    """
    result = T(np.concatenate([t.d for t in ts], axis=0), _prev=tuple(ts))
    heights = [t.d.shape[0] for t in ts]

    def backward_fn():
        start = 0
        for part, height in zip(ts, heights):
            stop = start + height
            if part.req:
                part._acc(result.g[start:stop])
            start = stop

    result._bw = backward_fn
    return result

# ======================================================================
#  dependency-free inference (also used by solve.py).  numpy only, no tape.
# ======================================================================
def _sm(x):
    """Plain-numpy softmax over the last axis (row max subtracted first)."""
    shifted = x - x.max(-1, keepdims=True)
    exps = np.exp(shifted)
    return exps / exps.sum(-1, keepdims=True)

def forward_np(W, A, B, cfg, rng, trace=False):
    """Run the mesh with plain numpy weights `W` (dict of arrays), no tape.

    Args:
        W: dict of weight arrays under the same keys as the training params.
        A, B: int arrays (Bn,S) used to index the embedding table.
        cfg: config dict; reads "d", "K", "topk", "steps", "p_update".
        rng: numpy Generator used to draw the async update mask each step.
        trace: when true, also return a per-step record for visualisation.

    Returns:
        The predicted class grid (Bn,S). With `trace`, a tuple
        ``(pred, frames)`` where each frame is a dict with "pred", "updated",
        "expert" and "fired" for that step.

    An expert is evaluated only on the cells that give it a non-zero gate and
    is skipped entirely when no cell does (true sparsity).
    """
    Bn, S = A.shape
    d, K, topk, steps, p_up = cfg["d"], cfg["K"], cfg["topk"], cfg["steps"], cfg["p_update"]
    emb, pos = W["emb"], W["posemb"]
    H = emb[A] + emb[B] + pos[None]                            # (Bn,S,d)
    frames = []
    for _ in range(steps):
        Hl = np.zeros_like(H); Hl[:, 1:] = H[:, :-1]
        Hr = np.zeros_like(H); Hr[:, :-1] = H[:, 1:]
        perc = np.concatenate([Hl, H, Hr], axis=-1)
        mu = perc.mean(-1, keepdims=True); v = perc.var(-1, keepdims=True)
        perc = (perc - mu) / np.sqrt(v + 1e-5)
        pf = perc.reshape(Bn*S, 3*d)
        rl = pf @ W["Wr"] + W["br"]
        # top-k selection
        if topk >= K:
            # Every expert is selected. Same guard as the training-path mask
            # helper; without it argpartition raises when topk > K.
            M = np.zeros_like(rl)
        else:
            idx = np.argpartition(-rl, topk-1, axis=1)[:, :topk]
            M = np.full_like(rl, -1e9); np.put_along_axis(M, idx, 0.0, axis=1)
        gate = _sm(rl + M)
        mix = np.zeros((Bn*S, d), f32)
        fired = np.zeros((Bn*S, K), f32)
        for e in range(K):
            ge = gate[:, e]
            act = ge > 0
            if not act.any():
                continue
            h1 = np.maximum(pf[act] @ W[f"e{e}.W1"] + W[f"e{e}.b1"], 0)
            oe = h1 @ W[f"e{e}.W2"] + W[f"e{e}.b2"]
            mix[act] += ge[act, None] * oe
            fired[act, e] = 1.0
        um = (rng.random((Bn, S, 1)) < p_up).astype(f32)
        H = H + um * mix.reshape(Bn, S, d)
        if trace:
            logit = H.reshape(Bn*S, d) @ W["Wo"] + W["bo"]
            frames.append(dict(
                pred=logit.reshape(Bn, S, -1).argmax(-1),
                updated=um[..., 0].astype(int),
                expert=(gate.argmax(1)).reshape(Bn, S),
                fired=fired.reshape(Bn, S, K)))
    logits = H.reshape(Bn*S, d) @ W["Wo"] + W["bo"]
    pred = logits.reshape(Bn, S, -1).argmax(-1)
    return (pred, frames) if trace else pred

def params_to_np(P):
    """Snapshot a parameter dict as plain arrays: name -> copy of the tensor's `.d`."""
    weights = {}
    for name, tensor in P.items():
        weights[name] = tensor.d.copy()
    return weights