File size: 5,368 Bytes
57f9808
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
#!/usr/bin/env python3
"""gary-neuron: a ~34 KB asynchronous Neural Cellular Automaton whose per-cell
rule is a Mixture-of-Experts. It adds integers by letting carries ripple across
a 1-D mesh of cells. Pure numpy + stdlib, no deps, no tokenizer.

Usage:
  python solve.py 1234567 + 7654321      # solve a sum
  python solve.py 9999999 1 --show       # visualise the mesh firing + carry ripple
  python solve.py --vote 9 48591 + 9732  # robust inference (ensemble over async orders)
  python solve.py --steps 32 123 + 456   # override the number of mesh update steps
  python solve.py                        # interactive
"""
import json, sys, os, re
import numpy as np

# Directory holding this script; the config and weights are loaded from beside it.
D = os.path.dirname(os.path.abspath(__file__))

# Mesh hyper-parameters. The `with` closes the config file once it is parsed.
with open(os.path.join(D, "config.json"), encoding="utf-8") as _cfg_file:
    C = json.load(_cfg_file)
S, d, K, TOPK = C["S"], C["state_dim"], C["n_experts"], C["topk"]
P_UPDATE = C["p_update"]
STEPS = C.get("recommended_inference_steps", 24)

# Dequantise the int8 weights: every array `k` is paired with a `k.scale` array
# it is multiplied by. Opening the .npz archive as a context manager closes it
# once all arrays have been materialised into W.
with np.load(os.path.join(D, "gary-neuron.int8.npz")) as z:
    W = {k: z[k].astype(np.float32) * z[k + ".scale"] for k in z.files if not k.endswith(".scale")}

def _sm(x):
    e = np.exp(x - x.max(-1, keepdims=True)); return e / e.sum(-1, keepdims=True)

def digits_rev(x, width=None):
    """Return the base-10 digits of non-negative int *x*, least significant first.

    The result is an int64 array of length *width* (default: the mesh width
    ``S``), zero-padded at the high end; digits beyond *width* are dropped.

    Raises:
        ValueError: if *x* is negative. Repeated floor-div/mod would otherwise
            silently produce the digits of ``x mod 10**width`` instead of *x*.
    """
    if x < 0:
        raise ValueError(f"digits_rev expects a non-negative integer, got {x}")
    n = S if width is None else width
    out = np.zeros(n, np.int64)
    for i in range(n):
        x, out[i] = divmod(x, 10)
    return out

def to_int(row):
    """Return the integer whose base-10 digits are *row*, least significant first.

    Inverse of ``digits_rev``: ``row[0]`` is the units digit. Rows of any
    length are accepted. The result is a Python int, so it cannot overflow.
    Evaluation uses Horner's rule starting from the most-significant end.
    """
    value = 0
    for digit in reversed(row):
        value = value * 10 + int(digit)
    return value

def mesh(A, B, steps=STEPS, p=P_UPDATE, seed=0, trace=False):
    """Run the asynchronous Mixture-of-Experts cellular automaton on A and B.

    Callers in this file pass ``digits_rev(x)[None]``, i.e. integer digit arrays
    of shape (1, S) with the least-significant digit at index 0. The batch size
    Bn is read from ``A.shape[0]``, and the state must reshape to (Bn * S, d).

    Each of the ``steps`` iterations does the following:
      1. Builds every cell's perception from [state at i-1, own state, state at
         i+1], zero-padded at the strip ends. It is normalised per cell to zero
         mean and unit variance, with no learned scale or shift.
      2. Scores the experts with the router ``Wr``/``br``. Only the TOPK best
         logits are kept (the others are pushed to -1e9) and softmaxed into
         gate weights.
      3. Accumulates the gate-weighted outputs of the two-layer ReLU experts
         ``e{k}.W1/b1/W2/b2``, running each expert only on the rows it gates.
      4. Adds that update residually to each cell independently with
         probability ``p``. The firing mask comes from an RNG seeded by ``seed``.

    Args:
        A, B: integer digit arrays used to index ``W["emb"]``.
        steps: number of update iterations.
        p: per-cell, per-step firing probability.
        seed: seed for ``np.random.default_rng``; fixes the firing schedule.
        trace: if True, also return per-step diagnostics.

    Returns:
        ``pred``, the (Bn, S) per-cell argmax of the readout ``Wo``/``bo``
        (callers read it as digits). When ``trace`` is True, returns
        ``(pred, frames)``. Each frame is recorded after that step's update and
        holds the current readout argmax, the int firing mask, and the index
        of each cell's largest gate weight, each of shape (Bn, S).
    """
    Bn = A.shape[0]
    rng = np.random.default_rng(seed)
    # Initial cell state: embeddings of both operand digits plus a positional
    # embedding broadcast over the batch axis.
    H = W["emb"][A] + W["emb"][B] + W["posemb"][None]
    frames = []
    for t in range(steps):
        # Hl[:, i] holds H[:, i-1] (one place lower in the least-significant-first
        # layout) and Hr[:, i] holds H[:, i+1]. Edge cells see zeros.
        Hl = np.zeros_like(H); Hl[:, 1:] = H[:, :-1]
        Hr = np.zeros_like(H); Hr[:, :-1] = H[:, 1:]
        perc = np.concatenate([Hl, H, Hr], -1)
        perc = (perc - perc.mean(-1, keepdims=True)) / np.sqrt(perc.var(-1, keepdims=True) + 1e-5)
        pf = perc.reshape(Bn * S, 3 * d)
        rl = pf @ W["Wr"] + W["br"]
        # Top-k routing: mask every non-selected expert's logit with -1e9. After
        # the softmax its weight underflows to 0, so `ge > 0` below picks out
        # exactly the routed experts.
        idx = np.argpartition(-rl, TOPK - 1, axis=1)[:, :TOPK]
        M = np.full_like(rl, -1e9); np.put_along_axis(M, idx, 0.0, axis=1)
        gate = _sm(rl + M)
        mix = np.zeros((Bn * S, d), np.float32)
        for e in range(K):
            ge = gate[:, e]; act = ge > 0
            if act.any():
                h1 = np.maximum(pf[act] @ W[f"e{e}.W1"] + W[f"e{e}.b1"], 0)
                mix[act] += ge[act, None] * (h1 @ W[f"e{e}.W2"] + W[f"e{e}.b2"])
        # Asynchronous update: a fresh Bernoulli(p) mask per cell and step. The
        # update `mix` is computed for every cell, but only cells that fire apply it.
        um = (rng.random((Bn, S, 1)) < p).astype(np.float32)
        H = H + um * mix.reshape(Bn, S, d)
        if trace:
            lg = H.reshape(Bn * S, d) @ W["Wo"] + W["bo"]
            frames.append((lg.reshape(Bn, S, -1).argmax(-1),
                           um[..., 0].astype(int), gate.argmax(1).reshape(Bn, S)))
    logits = H.reshape(Bn * S, d) @ W["Wo"] + W["bo"]
    pred = logits.reshape(Bn, S, -1).argmax(-1)
    return (pred, frames) if trace else pred

def solve(a, b, vote=1, steps=STEPS):
    """Add *a* and *b* with the mesh and return the predicted sum as an int.

    With ``vote <= 1`` a single run (seed 0) decides. Otherwise the mesh is run
    once per seed ``0 .. vote-1``, i.e. with different asynchronous firing
    orders. Each digit place then independently takes its most frequent
    predicted digit among 0-9. Ties go to the smaller digit, because argmax
    returns the first maximum. The voted result may therefore mix digits from
    different runs.
    """
    A, B = digits_rev(a)[None], digits_rev(b)[None]
    if vote > 1:
        runs = np.stack([mesh(A, B, steps=steps, seed=seed)[0] for seed in range(vote)])  # (vote, S)
        tally = np.stack([np.sum(runs == digit, axis=0) for digit in range(10)])            # (10, S)
        return to_int(np.argmax(tally, axis=0))
    return to_int(mesh(A, B, steps=steps, seed=0)[0])

def show(a, b, steps=STEPS):
    """Print a step-by-step trace of the mesh computing ``a + b``.

    Runs ``mesh`` once (seed 0, ``trace=True``). For every step it prints the
    digits the readout currently emits and which cells fired, showing each
    firing cell's highest-weighted expert. Columns run most-significant place
    first. The header label and the per-step label are padded to one width,
    and every column is as wide as its widest entry. This keeps the digits
    under their place-value header, including meshes with more than 10 cells
    or experts.
    """
    A = digits_rev(a)[None]; B = digits_rev(b)[None]
    pred, frames = mesh(A, B, steps=steps, seed=0, trace=True)
    places = range(S - 1, -1, -1)   # most-significant cell first, like a written number
    cw = len(str(S - 1))            # width of one digit / place-index column
    ew = len(str(K - 1))            # width of one expert-id column
    head_label = "digit place (10^):"
    lw = len(head_label)            # shared label width keeps both lines aligned
    hdr = "  ".join(str(i).rjust(cw) for i in places)
    rule = "  " + "-" * (4 * S + 20)
    print(f"\n  {a} + {b}   (mesh = {S} cells, {K} experts, top-{TOPK}, async p={P_UPDATE}, {steps} steps)")
    print(f"  {head_label}   {hdr}")
    print(rule)
    for t, (pr, upd, ex) in enumerate(frames):
        dig = "  ".join(str(pr[0, i]).rjust(cw) for i in places)
        fire = "  ".join((str(ex[0, i]) if upd[0, i] else "·").rjust(ew) for i in places)
        label = f"step {t:2d} digits:".ljust(lw)
        print(f"  {label}   {dig}   |  fired(expert#): {fire}   = {to_int(pr[0])}")
    print(rule)
    ans = to_int(pred[0]); truth = a + b
    print(f"  => {a} + {b} = {ans}   {'OK' if ans == truth else 'X (truth ' + str(truth) + ')'}")
    print("  '·' = cell did not fire this step; digits settle as the carry ripples low->high.\n")

def parse(s):
    """Return the first two runs of decimal digits in *s* as a pair of ints.

    Every other character (operators, signs, '#', whitespace) is ignored, so
    ``"12 + 34"``, ``"12 34"`` and ``"12-34"`` all give ``(12, 34)``. Returns
    None when *s* holds fewer than two numbers.
    """
    found = re.findall(r"\d+", s)
    if len(found) >= 2:
        first, second = found[0], found[1]
        return int(first), int(second)
    return None

def _take_int_option(args, flag):
    """Remove ``flag VALUE`` from *args* in place and return ``int(VALUE)``.

    Returns None when *flag* is absent. If the flag has no value, or the value
    is not an integer, exits with an error message on stderr (status 1)
    instead of an IndexError/ValueError traceback.
    """
    if flag not in args:
        return None
    i = args.index(flag)
    if i + 1 >= len(args):
        sys.exit(f"{flag} needs an integer value, e.g.  {flag} 9")
    try:
        value = int(args[i + 1])
    except ValueError:
        sys.exit(f"{flag} expects an integer, got {args[i + 1]!r}")
    del args[i:i + 2]
    return value


def _warn_if_too_wide(a, b):
    """Print a note when ``a + b`` needs more digits than the mesh has cells."""
    if a + b >= 10 ** S:
        print(f"(sum exceeds {S} digits - this mesh has {S} cells; train a wider strip for bigger sums)")


if __name__ == "__main__":
    # Strip the flags from argv first; whatever remains is the sum to solve.
    args = sys.argv[1:]
    doshow = "--show" in args
    if doshow:
        args.remove("--show")
    vote = _take_int_option(args, "--vote")    # None -> mode-specific default below
    steps = _take_int_option(args, "--steps")
    if steps is None:
        steps = STEPS
    if args:
        pr = parse(" ".join(args))
        if not pr:
            # Usage error: report it on stderr with a non-zero exit status.
            sys.exit("give me two non-negative integers, e.g.  python solve.py 123 + 456")
        a, b = pr
        _warn_if_too_wide(a, b)
        if doshow:
            show(a, b, steps=steps)
        else:
            print(solve(a, b, vote=1 if vote is None else vote, steps=steps))
    else:
        print("gary-neuron (~34 KB). Type a sum like '1234 + 5678'. Add '#' for the mesh view. ctrl-c to exit.")
        # Interactive mode honours --steps, --vote and --show too. Without
        # --vote it keeps its 9-way ensemble default.
        while True:
            try:
                msg = input("\nsum: ")
            except (EOFError, KeyboardInterrupt):
                break
            pr = parse(msg)
            if not pr:
                continue
            _warn_if_too_wide(*pr)
            if doshow or "#" in msg:
                show(*pr, steps=steps)
            else:
                print(f"  = {solve(*pr, vote=9 if vote is None else vote, steps=steps)}")