"""Time-budgeted training burst for gary-neuron, with checkpoint/resume so it
survives short shell timeouts (same pattern as gary-4-petite).

Trains the async NCA+MoE on reversed-digit addition; evaluates true
exact-match accuracy with the dependency-free numpy engine each burst.

Each invocation trains for roughly SEC seconds of wall-clock time, writes a
checkpoint (params, Adam moments, step counter, cfg), then evaluates on a
fixed validation set and appends one summary line to LOG. Re-running the
script resumes from CKPT when that file exists.

Environment variables (defaults in parentheses):
    SEC (35)          wall-clock training budget, in seconds
    CKPT (ckpt.npz)   checkpoint path (default: next to this file)
    LOG (train.log)   append-only log path (default: next to this file)
    BS (256)          batch size
    LR (3e-3)         peak learning rate; the floor LRMIN is 5% of LR
    WARM (150)        linear warmup steps
    TMAX (8000)       step at which the cosine schedule reaches LRMIN
    S, d, he, K, topk, steps (int), p_update, aux (float)
                      model cfg overrides; ignored when resuming, because the
                      checkpoint's saved cfg replaces the whole cfg
    MAXDIG (S - 1)    max-digit argument passed to make_batch for both
                      training and validation
    HARD (0.0)        fraction of each batch drawn from carry-heavy hard cases
"""
import json
import math
import os
import time

import numpy as np

from garyneuron import init_params, forward, forward_np, params_to_np, n_params, default_cfg
from data import make_batch, exact_match, gen_hard

D = os.path.dirname(os.path.abspath(__file__))
SEC = float(os.environ.get("SEC", "35"))
CKPT = os.environ.get("CKPT", f"{D}/ckpt.npz")
LOG = os.environ.get("LOG", f"{D}/train.log")
BS = int(os.environ.get("BS", "256"))
LR = float(os.environ.get("LR", "3e-3"))
LRMIN = LR * 0.05
WARM = int(os.environ.get("WARM", "150"))
TMAX = int(os.environ.get("TMAX", "8000"))

cfg = default_cfg()
for k in ["S", "d", "he", "K", "topk", "steps"]:
    if k in os.environ:
        cfg[k] = int(os.environ[k])
if "p_update" in os.environ:
    cfg["p_update"] = float(os.environ["p_update"])
if "aux" in os.environ:
    cfg["aux"] = float(os.environ["aux"])
HARD = float(os.environ.get("HARD", "0.0"))  # fraction of batch drawn from carry-heavy hard cases


def batch(bs):
    """Draw one training batch of ``bs`` examples using the global ``rng``.

    When HARD > 0, ``int(bs * HARD)`` examples come from ``gen_hard`` and the
    remainder from ``make_batch``; the two parts are concatenated with the hard
    cases last. Returns the ``(a, b, y)`` triple those helpers produce.
    """
    if HARD > 0:
        nh = int(bs * HARD)
        a1, b1, y1 = make_batch(bs - nh, cfg["S"], rng, MAXDIG)
        a2, b2, y2 = gen_hard(nh, cfg["S"], rng)
        return (np.concatenate([a1, a2]), np.concatenate([b1, b2]), np.concatenate([y1, y2]))
    return make_batch(bs, cfg["S"], rng, MAXDIG)


class Adam:
    """Adam with bias correction and decoupled weight decay.

    ``P`` maps parameter names to tensor objects exposing ``.d`` (data array,
    updated in place) and ``.g`` (gradient array or None) -- presumably
    ``garyneuron.T``; confirm. Weight decay applies only to names containing
    ".W" or equal to "Wr"/"Wo".
    """

    def __init__(self, P, lr, b1=0.9, b2=0.99, wd=1e-4, eps=1e-8):
        self.P, self.lr, self.b1, self.b2, self.wd, self.eps = P, lr, b1, b2, wd, eps
        self.m = {k: np.zeros_like(v.d) for k, v in P.items()}  # first moments
        self.v = {k: np.zeros_like(v.d) for k, v in P.items()}  # second moments
        self.t = 0  # number of steps taken; drives bias correction

    def step(self, lr):
        """Apply one update with learning rate ``lr``; params whose ``.g`` is None are skipped."""
        self.t += 1
        b1, b2 = self.b1, self.b2
        bc1 = 1 - b1 ** self.t
        bc2 = 1 - b2 ** self.t
        for k, p in self.P.items():
            g = p.g
            if g is None:
                continue
            self.m[k] = b1 * self.m[k] + (1 - b1) * g
            self.v[k] = b2 * self.v[k] + (1 - b2) * (g * g)
            upd = (self.m[k] / bc1) / (np.sqrt(self.v[k] / bc2) + self.eps)
            if ".W" in k or k in ("Wr", "Wo"):
                upd = upd + self.wd * p.d  # decoupled wd on matmul weights only
            p.d -= lr * upd


def clip(P, maxn=1.0):
    """Scale all gradients in place so their global L2 norm is at most ``maxn``.

    Returns the global gradient norm measured before clipping.
    """
    tot = math.sqrt(sum(float((p.g * p.g).sum()) for p in P.values() if p.g is not None))
    if tot > maxn:
        s = maxn / (tot + 1e-6)
        for p in P.values():
            if p.g is not None:
                p.g *= s
    return tot


def lr_at(s):
    """Learning rate for step ``s``.

    Linear warmup from 0 to LR over WARM steps (so step 0 uses lr 0), then a
    cosine decay from LR to LRMIN reached at TMAX, then constant LRMIN.
    """
    if s < WARM:
        return LR * s / WARM
    if s >= TMAX:
        return LRMIN
    r = (s - WARM) / (TMAX - WARM)
    return LRMIN + 0.5 * (LR - LRMIN) * (1 + math.cos(math.pi * r))


# ---- resume or init ----
if os.path.exists(CKPT):
    from garyneuron import T
    # Every array is copied out, so the archive can be closed immediately.
    with np.load(CKPT, allow_pickle=True) as z:
        P = {k[2:]: T(z[k].copy()) for k in z.files if k.startswith("P/")}
        # The saved cfg replaces the default/env cfg entirely.
        cfg = json.loads(str(z["cfg"]))
        opt = Adam(P, LR)
        opt.t = int(z["t"])
        for k in P:
            opt.m[k] = z["m/" + k].copy()
            opt.v[k] = z["v/" + k].copy()
        step = int(z["step"])
    rng = np.random.default_rng(1000 + step)  # fresh, reproducible stream per burst
else:
    P = init_params(cfg, seed=1337)
    opt = Adam(P, LR)
    step = 0
    rng = np.random.default_rng(0)
NP = n_params(P)

# Resolved only after resume, so its default follows the checkpoint's S
# instead of the default/env cfg; otherwise the "fixed" validation set could
# change between bursts.
MAXDIG = int(os.environ.get("MAXDIG", cfg["S"] - 1))


def save():
    """Write params, Adam moments, step counter and cfg to CKPT atomically.

    The archive goes to a temporary file first and is then moved over CKPT,
    so a process killed mid-write leaves the previous checkpoint intact.
    Writing through a file handle also keeps numpy from altering CKPT's
    extension.
    """
    d = {"P/" + k: v.d for k, v in P.items()}
    d.update({"m/" + k: opt.m[k] for k in P})
    d.update({"v/" + k: opt.v[k] for k in P})
    d["step"] = step
    d["t"] = opt.t
    d["cfg"] = json.dumps(cfg)
    tmp = CKPT + ".tmp"
    with open(tmp, "wb") as fh:
        np.savez(fh, **d)
    os.replace(tmp, CKPT)


def evaluate(n=2000):
    """Evaluate on a fixed validation set using the numpy inference path.

    Returns ``(exact_match(pred, vy), mean element-wise accuracy)``.
    """
    Wnp = params_to_np(P)
    va, vb, vy = make_batch(n, cfg["S"], np.random.default_rng(987654), MAXDIG)  # fixed val set
    pred = forward_np(Wnp, va, vb, cfg, np.random.default_rng(321))
    return exact_match(pred, vy), float((pred == vy).mean())


# ---- train ----
t0 = time.time()
losses = []
nst = 0
info = {}  # stays empty if the budget runs out before the first step
while time.time() - t0 < SEC:
    A, Bb, Y = batch(BS)
    for v in P.values():
        v.g = None
    tot, info = forward(P, A, Bb, Y, cfg, rng, train=True)
    tot.backward()
    clip(P, 1.0)
    opt.step(lr_at(step))
    step += 1
    nst += 1
    losses.append(info["loss"])
dt = time.time() - t0

# Checkpoint before evaluating, so a timeout during evaluation cannot lose
# this burst's training.
save()
em, da = evaluate()
load = info.get("load")
loadstr = (" | load " + ",".join(f"{x:.2f}" for x in load)) if load is not None else ""
mean_loss = float(np.mean(losses[-50:])) if losses else float("nan")
msg = (f"step {step:5d} | loss {mean_loss:.4f} | exact {em*100:6.2f}% | "
       f"digit {da*100:6.2f}% | lr {lr_at(step):.2e} | {nst}st/{dt:.0f}s | "
       f"n={NP} | S={cfg['S']} maxdig={MAXDIG} K={cfg['K']} top{cfg['topk']} "
       f"steps={cfg['steps']} p={cfg['p_update']}{loadstr}")
print(msg)
with open(LOG, "a") as fh:
    fh.write(msg + "\n")