File size: 4,793 Bytes
57f9808
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
"""Time-budgeted training burst for gary-neuron, with checkpoint/resume so it
survives short shell timeouts (same pattern as gary-4-petite). Trains the async
NCA+MoE on reversed-digit addition; evaluates true exact-match accuracy with the
dependency-free numpy engine each burst."""
import os, json, time, math, numpy as np
from garyneuron import init_params, forward, forward_np, params_to_np, n_params, default_cfg
from data import make_batch, exact_match, gen_hard

D = os.path.dirname(os.path.abspath(__file__))  # directory containing this script

# Burst / optimiser settings; every one can be overridden from the environment.
_env = os.environ.get
SEC = float(_env("SEC", "35"))          # wall-clock seconds of training per burst
CKPT = _env("CKPT", f"{D}/ckpt.npz")    # checkpoint: resumed from if present, rewritten after each burst
LOG = _env("LOG", f"{D}/train.log")     # one summary line is appended per burst
BS = int(_env("BS", "256"))             # training batch size
LR = float(_env("LR", "3e-3"))          # peak learning rate
LRMIN = LR * 0.05                       # floor of the cosine schedule
WARM = int(_env("WARM", "150"))         # steps of linear warm-up
TMAX = int(_env("TMAX", "8000"))        # step at which the schedule bottoms out at LRMIN

# Model config: library defaults, then optional env overrides (applied in this order).
cfg = default_cfg()
_CFG_OVERRIDES = (
    ("S", int), ("d", int), ("he", int), ("K", int), ("topk", int), ("steps", int),
    ("p_update", float), ("aux", float),
)
for k, cast in _CFG_OVERRIDES:
    if k in os.environ:
        cfg[k] = cast(os.environ[k])
# Digit limit handed to make_batch; defaults to S - 1 (presumably the max operand length — confirm in data.py).
MAXDIG = int(_env("MAXDIG", cfg["S"] - 1))
HARD = float(_env("HARD", "0.0"))       # fraction of batch drawn from carry-heavy hard cases

def batch(bs):
    """Draw one training batch of ``bs`` problems using the module-level ``rng``.

    When ``HARD > 0``, ``int(bs * HARD)`` problems come from ``gen_hard`` and the
    remainder from ``make_batch``. Each of the three returned arrays is the
    regular part followed by the hard part. Otherwise the whole batch comes from
    ``make_batch``. Reads the globals ``cfg``, ``rng``, ``MAXDIG`` and ``HARD``.
    """
    seq_len = cfg["S"]
    if HARD > 0:
        n_hard = int(bs * HARD)
        a_reg, b_reg, y_reg = make_batch(bs - n_hard, seq_len, rng, MAXDIG)
        a_hard, b_hard, y_hard = gen_hard(n_hard, seq_len, rng)
        pairs = ((a_reg, a_hard), (b_reg, b_hard), (y_reg, y_hard))
        return tuple(np.concatenate(pair) for pair in pairs)
    return make_batch(bs, seq_len, rng, MAXDIG)

class Adam:
    """Bias-corrected Adam with decoupled (AdamW-style) weight decay.

    ``P`` maps parameter names to objects exposing ``.d`` (the data array,
    updated in place) and ``.g`` (the gradient, or ``None`` to skip that
    parameter this step). Weight decay is added to the update *after* the
    adaptive scaling, so it is scaled by the learning rate only. It applies to
    names containing ``".W"`` and to ``"Wr"``/``"Wo"``; the original comment
    calls these the matmul weights.
    """

    _EXTRA_DECAYED = ("Wr", "Wo")

    def __init__(self, P, lr, b1=0.9, b2=0.99, wd=1e-4, eps=1e-8):
        self.P = P
        self.lr = lr   # stored only; step() takes the rate explicitly
        self.b1 = b1
        self.b2 = b2
        self.wd = wd
        self.eps = eps
        # First/second moment buffers, shaped and typed like each parameter's data.
        self.m = {name: np.zeros_like(par.d) for name, par in P.items()}
        self.v = {name: np.zeros_like(par.d) for name, par in P.items()}
        self.t = 0     # number of step() calls, used for bias correction

    @classmethod
    def _decays(cls, name):
        """True when parameter ``name`` receives weight decay."""
        return ".W" in name or name in cls._EXTRA_DECAYED

    def step(self, lr):
        """Apply one Adam update with learning rate ``lr`` to every parameter that has a gradient."""
        self.t += 1
        beta1, beta2 = self.b1, self.b2
        corr1 = 1 - beta1 ** self.t
        corr2 = 1 - beta2 ** self.t
        for name, par in self.P.items():
            grad = par.g
            if grad is None:
                continue
            m = self.m[name] = beta1 * self.m[name] + (1 - beta1) * grad
            v = self.v[name] = beta2 * self.v[name] + (1 - beta2) * (grad * grad)
            update = (m / corr1) / (np.sqrt(v / corr2) + self.eps)
            if self._decays(name):
                update = update + self.wd * par.d
            par.d -= lr * update   # in place: keeps the array object and its dtype

def clip(P, maxn=1.0):
    """Clip gradients by global L2 norm, in place.

    Only parameters whose ``.g`` is not ``None`` take part. If the combined
    norm exceeds ``maxn``, every such gradient is multiplied by
    ``maxn / (norm + 1e-6)``. Returns the norm measured *before* clipping.
    """
    live = [p for p in P.values() if p.g is not None]
    norm = math.sqrt(sum(float((p.g * p.g).sum()) for p in live))
    if norm > maxn:   # a NaN norm compares False, so it is returned without rescaling
        scale = maxn / (norm + 1e-6)
        for p in live:
            p.g *= scale
    return norm

def lr_at(s):
    """Learning rate for optimizer step ``s``.

    Rises linearly from 0 to ``LR`` over the first ``WARM`` steps. It then
    follows a half-cosine from ``LR`` down to ``LRMIN`` until step ``TMAX``,
    and stays at ``LRMIN`` afterwards.
    """
    if s < WARM:
        return LR * s / WARM
    if s < TMAX:
        progress = (s - WARM) / (TMAX - WARM)   # 0 at end of warm-up, -> 1 at TMAX
        cosine = 1 + math.cos(math.pi * progress)
        return LRMIN + 0.5 * (LR - LRMIN) * cosine
    return LRMIN

# ---- resume or init ----
# Resume from CKPT when it exists (params, Adam moments and step count, cfg,
# global step); otherwise initialise fresh parameters. On resume the
# checkpoint's cfg replaces the env-derived one, so architecture overrides in
# the environment are ignored for a resumed run.
if os.path.exists(CKPT):
    from garyneuron import T
    # The `with` closes the .npz archive once everything has been copied out.
    # NOTE(review): save() writes parameter/moment arrays, two ints and a JSON
    # string, so allow_pickle=True looks unnecessary; kept for older checkpoints.
    with np.load(CKPT, allow_pickle=True) as z:
        P = {k[2:]: T(z[k].copy()) for k in z.files if k.startswith("P/")}
        cfg = json.loads(str(z["cfg"]))
        opt = Adam(P, LR); opt.t = int(z["t"])
        for k in P:
            opt.m[k] = z["m/" + k].copy(); opt.v[k] = z["v/" + k].copy()
        step = int(z["step"])
    # MAXDIG's default was computed from the pre-resume cfg; re-derive it from
    # the restored S so sampled problems match the restored model.
    if "MAXDIG" not in os.environ:
        MAXDIG = int(cfg["S"]) - 1
    rng = np.random.default_rng(1000 + step)   # seed depends on step, so each burst draws a new stream
else:
    P = init_params(cfg, seed=1337)
    opt = Adam(P, LR); step = 0; rng = np.random.default_rng(0)

NP = n_params(P)

def save():
    """Write params, Adam moments, ``step``, ``opt.t`` and ``cfg`` to ``CKPT``.

    Keys are ``"P/<name>"``, ``"m/<name>"`` and ``"v/<name>"``, plus ``"step"``,
    ``"t"`` and ``"cfg"`` (JSON string). The archive is first written to
    ``CKPT + ".tmp"`` and then moved over ``CKPT`` with ``os.replace``. A burst
    killed mid-write therefore leaves the previous checkpoint intact rather than
    a truncated file. Writing through an open file handle also stops
    ``np.savez`` from appending ``.npz`` to the name. The file thus lands
    exactly at ``CKPT`` (the path the resume code checks) whatever its suffix.
    """
    d = {"P/" + k: v.d for k, v in P.items()}
    d.update({"m/" + k: opt.m[k] for k in P})
    d.update({"v/" + k: opt.v[k] for k in P})
    d["step"] = step
    d["t"] = opt.t
    d["cfg"] = json.dumps(cfg)
    tmp = CKPT + ".tmp"
    with open(tmp, "wb") as f:
        np.savez(f, **d)
    os.replace(tmp, CKPT)

def evaluate(n=2000):
    """Score the current parameters with the numpy engine on a fixed validation set.

    The ``n`` validation problems and the rng handed to ``forward_np`` use fixed
    seeds, so every burst is scored on the same data. Returns
    ``(exact_match(pred, vy), fraction of element-wise matches)``. The second
    value is the one logged as "digit" accuracy.
    """
    weights_np = params_to_np(P)
    val_rng = np.random.default_rng(987654)                     # fixed val set
    va, vb, vy = make_batch(n, cfg["S"], val_rng, MAXDIG)
    pred = forward_np(weights_np, va, vb, cfg, np.random.default_rng(321))
    exact = exact_match(pred, vy)
    return exact, float((pred == vy).mean())

# ---- train ----
# Step the optimizer until SEC seconds of wall-clock time have passed, then
# checkpoint, evaluate on the fixed validation set, and report one line.
t0 = time.time(); losses = []; nst = 0
info = {}   # stays empty if the budget allows zero steps (e.g. SEC <= 0)
while time.time() - t0 < SEC:
    A, Bb, Y = batch(BS)
    for v in P.values(): v.g = None      # reset gradients before this step's backward pass
    tot, info = forward(P, A, Bb, Y, cfg, rng, train=True)
    tot.backward(); clip(P, 1.0); opt.step(lr_at(step))
    step += 1; nst += 1; losses.append(info["loss"])
dt = time.time() - t0
save()
em, da = evaluate()
load = info.get("load")
loadstr = (" | load " + ",".join(f"{x:.2f}" for x in load)) if load is not None else ""
recent = np.mean(losses[-50:]) if losses else float("nan")   # mean of the last <=50 losses
# The lr shown is lr_at(step) after the final increment: the rate the next step would use.
msg = (f"step {step:5d} | loss {recent:.4f} | exact {em*100:6.2f}% | "
       f"digit {da*100:6.2f}% | lr {lr_at(step):.2e} | {nst}st/{dt:.0f}s | "
       f"n={NP} | S={cfg['S']} maxdig={MAXDIG} K={cfg['K']} top{cfg['topk']} "
       f"steps={cfg['steps']} p={cfg['p_update']}{loadstr}")
print(msg)
with open(LOG, "a") as log_file:   # close (and flush) the log deterministically
    log_file.write(msg + "\n")