File size: 3,756 Bytes
57f9808
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
"""Rigorous, adversarial benchmark for gary-neuron. Reports TRUE exact-match on a
large held-out test set the model never trained on, stress-tests the hardest
long-carry ripples, checks robustness to the random async update order, and
shows the NCA 'train short / run longer' property. Usage:
    python benchmark.py main     # held-out + adversarial + by-length
    python benchmark.py sweep    # inference-steps and update-prob sweeps
"""
import os, sys, json, numpy as np
from data import make_batch, exact_match, digits_rev, to_int
from garyneuron import forward_np

D = os.path.dirname(os.path.abspath(__file__))
# Checkpoint to evaluate; override with the CKPT environment variable.
CKPT = os.environ.get("CKPT", os.path.join(D, "final.npz"))
mode = sys.argv[1] if len(sys.argv) > 1 else "main"
# Fail loudly on a typo instead of loading the checkpoint and printing nothing.
if mode not in ("main", "sweep"):
    sys.exit(f"unknown mode {mode!r}; expected 'main' or 'sweep'")
# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can
# execute arbitrary code -- only point CKPT at checkpoints you trust.
# The context manager closes the .npz archive; indexing z[k] reads each array
# into memory, so W and cfg remain valid after the file is closed.
with np.load(CKPT, allow_pickle=True) as z:
    # Parameters are stored under "P/<name>"; strip the prefix for forward_np.
    W = {k[2:]: z[k] for k in z.files if k.startswith("P/")}
    cfg = json.loads(str(z["cfg"]))
    step = int(z["step"])
S = cfg["S"]
print(f"# gary-neuron benchmark | step {step} | trained steps={cfg['steps']} p={cfg['p_update']} | {mode}")

def C(**kw):
    """Return a shallow copy of the loaded ``cfg`` with ``kw`` overriding its keys.

    The module-level ``cfg`` itself is left untouched, so each benchmark
    section can vary settings such as ``steps`` or ``p_update`` independently.
    """
    return {**cfg, **kw}

def grids(pairs):
    """Encode (a, b) integer pairs as digit grids via ``digits_rev(., S)``.

    Returns three arrays, one row per pair: the first operands, the second
    operands, and the target sums ``a + b``.
    """
    def _encode(select):
        # One full pass over `pairs` per output array, in A, B, Y order.
        return np.array([digits_rev(select(a, b), S) for a, b in pairs])

    operand_a = _encode(lambda a, b: a)
    operand_b = _encode(lambda a, b: b)
    target = _encode(lambda a, b: a + b)
    return operand_a, operand_b, target

if mode == "main":
    # ---- 1. large held-out test, robustness to async update order ----
    N_TEST, N_ORDERS = 10000, 8
    # NOTE(review): the trailing 7 is presumably the max operand digit count
    # (matches the "≤7-digit" label) -- confirm against make_batch.
    A, B, Y = make_batch(N_TEST, S, np.random.default_rng(20260611), 7)
    # Same inputs every time; only the RNG handed to forward_np changes, so the
    # spread across seeds measures sensitivity to the random update order.
    ems = [exact_match(forward_np(W, A, B, cfg, np.random.default_rng(s)), Y) for s in range(N_ORDERS)]
    print(f"\n[held-out {N_TEST // 1000}k, ≤7-digit] exact-match across {N_ORDERS} random async orders:")
    print(f"   mean {np.mean(ems)*100:.3f}%  min {np.min(ems)*100:.3f}%  max {np.max(ems)*100:.3f}%  std {np.std(ems)*100:.4f}")

    # ---- 2. adversarial maximal-carry ripples (run a bit longer: steps=28) ----
    hard = []
    for L in range(1, 9):
        hard.append((10**L - 1, 1))            # 99..9 + 1   -> full-length carry
        hard.append((10**L - 1, 10**L - 1))    # 99..9 + 99..9
    # Extra 7-digit pairs whose sum is exactly 10_000_000, so a carry ripples
    # through every digit. (9999999, 1) and (9999999, 9999999) are already
    # produced by the loop at L=7 and are not repeated here: duplicates would
    # inflate the reported case count.
    hard += [(5555555, 4444445), (1234567, 8765433), (9090909, 909091), (7777777, 2222223)]
    # NOTE(review): L=8 yields 9-digit sums; this assumes S >= 9 so digits_rev
    # can represent them without truncation -- confirm.
    HA, HB, HY = grids(hard)
    pred = forward_np(W, HA, HB, C(steps=28), np.random.default_rng(0))
    ok = (pred == HY).all(1)  # a case is correct only if every digit matches
    print(f"\n[adversarial max-carry, {len(hard)} cases @ steps=28] {int(ok.sum())}/{len(hard)} correct")
    for (a, b), o, p in zip(hard, ok, pred):
        if not o:
            print(f"   MISS {a} + {b} = {a+b}  got {to_int(p[None])[0]}")

    # ---- 3. accuracy by operand length ----
    # For L > 1 both operands have exactly L digits; for L == 1 zero is allowed too.
    print("\n[exact-match by max operand length @ steps=24]")
    for L in range(1, 8):
        rng = np.random.default_rng(500 + L)
        lo = 10**(L-1) if L > 1 else 0
        a = rng.integers(lo, 10**L, 4000); b = rng.integers(lo, 10**L, 4000)
        A, B, Y = grids(list(zip(a.tolist(), b.tolist())))
        em = exact_match(forward_np(W, A, B, C(steps=24), np.random.default_rng(0)), Y)
        print(f"   len {L}: {em*100:6.3f}%")

if mode == "sweep":
    # One fixed evaluation batch shared by both sweeps, so results are comparable.
    A, B, Y = make_batch(10000, S, np.random.default_rng(424242), 7)
    n_seeds = 3  # async-order RNG seeds averaged per setting
    # Report the checkpoint's own training step count rather than a hard-coded
    # value, so the label stays correct for any checkpoint passed via CKPT.
    print(f"\n[inference async-steps sweep]  (trained at {cfg['steps']})")
    for st in [8, 10, 12, 16, 20, 24, 28, 32]:
        ems = [exact_match(forward_np(W, A, B, C(steps=st), np.random.default_rng(s)), Y) for s in range(n_seeds)]
        print(f"   steps={st:2d}: exact {np.mean(ems)*100:6.3f}%  (±{np.std(ems)*100:.3f})")
    print("\n[update-probability sweep @ steps=28]   p=1.0 is fully synchronous")
    for p in [0.25, 0.5, 0.75, 1.0]:
        ems = [exact_match(forward_np(W, A, B, C(steps=28, p_update=p), np.random.default_rng(s)), Y) for s in range(n_seeds)]
        print(f"   p_update={p}: exact {np.mean(ems)*100:6.3f}%  (±{np.std(ems)*100:.3f})")