| """Time-budgeted training burst for gary-neuron, with checkpoint/resume so it |
| survives short shell timeouts (same pattern as gary-4-petite). Trains the async |
| NCA+MoE on reversed-digit addition; evaluates true exact-match accuracy with the |
| dependency-free numpy engine each burst.""" |
| import os, json, time, math, numpy as np |
| from garyneuron import init_params, forward, forward_np, params_to_np, n_params, default_cfg |
| from data import make_batch, exact_match, gen_hard |
|
|
# --- Run configuration; every value below can be overridden via the environment ---

# Directory holding this script; default location of the checkpoint and log.
D = os.path.dirname(os.path.abspath(__file__))

# Wall-clock budget (seconds) for one training burst.
SEC = float(os.getenv("SEC", "35"))

# Checkpoint and log file paths.
CKPT = os.getenv("CKPT", f"{D}/ckpt.npz")
LOG = os.getenv("LOG", f"{D}/train.log")

# Batch size.
BS = int(os.getenv("BS", "256"))

# Peak learning rate and the floor it decays to (5% of the peak).
LR = float(os.getenv("LR", "3e-3"))
LRMIN = 0.05 * LR

# Warmup length and the step at which the schedule reaches its floor.
WARM = int(os.getenv("WARM", "150"))
TMAX = int(os.getenv("TMAX", "8000"))
|
|
# Model configuration: library defaults, then per-key environment overrides.
cfg = default_cfg()

# Integer-valued settings.
for k in ("S", "d", "he", "K", "topk", "steps"):
    raw = os.environ.get(k)
    if raw is not None:
        cfg[k] = int(raw)

# Float-valued settings.
for k in ("p_update", "aux"):
    raw = os.environ.get(k)
    if raw is not None:
        cfg[k] = float(raw)

# Digit-count argument passed to make_batch; defaults to S - 1.
MAXDIG = int(os.environ["MAXDIG"]) if "MAXDIG" in os.environ else int(cfg["S"] - 1)

# Fraction of each training batch drawn from gen_hard (0 disables it).
HARD = float(os.environ.get("HARD", "0.0"))
|
|
def batch(bs):
    """Draw one training batch of `bs` examples.

    When HARD is not positive the whole batch comes from make_batch.
    Otherwise int(bs * HARD) examples come from gen_hard and the rest from
    make_batch; the two parts are concatenated with the make_batch part
    first. Both generators consume the module-level `rng`.

    Returns the three arrays as a tuple, in the order the generators
    return them.
    """
    S = cfg["S"]
    if not HARD > 0:
        return make_batch(bs, S, rng, MAXDIG)

    n_hard = int(bs * HARD)
    # Regular examples are drawn first so the rng stream order is fixed.
    a_reg, b_reg, y_reg = make_batch(bs - n_hard, S, rng, MAXDIG)
    a_hard, b_hard, y_hard = gen_hard(n_hard, S, rng)
    a_all = np.concatenate([a_reg, a_hard])
    b_all = np.concatenate([b_reg, b_hard])
    y_all = np.concatenate([y_reg, y_hard])
    return (a_all, b_all, y_all)
|
|
class Adam:
    """Adam optimizer over a dict of parameters, with weight decay.

    Every value in `P` must expose `.d` (the data array, updated in place)
    and `.g` (the gradient array, or None when no gradient was produced).

    Args:
        P: mapping from parameter name to parameter object.
        lr: default learning rate, used by `step()` when none is passed.
        b1, b2: decay rates of the first and second moment estimates.
        wd: weight-decay coefficient.
        eps: constant added to the denominator of the update.
    """

    def __init__(self, P, lr, b1=0.9, b2=0.99, wd=1e-4, eps=1e-8):
        self.P, self.lr, self.b1, self.b2, self.wd, self.eps = P, lr, b1, b2, wd, eps
        # First and second moment buffers, one per parameter.
        self.m = {k: np.zeros_like(v.d) for k, v in P.items()}
        self.v = {k: np.zeros_like(v.d) for k, v in P.items()}
        # Number of step() calls so far; drives the bias correction.
        self.t = 0

    def step(self, lr=None):
        """Apply one update to every parameter that has a gradient.

        Args:
            lr: learning rate for this step; falls back to the rate given
                at construction when omitted.
        """
        if lr is None:
            lr = self.lr
        self.t += 1
        b1, b2 = self.b1, self.b2
        bc1 = 1 - b1 ** self.t
        bc2 = 1 - b2 ** self.t
        for k, p in self.P.items():
            g = p.g
            if g is None:
                # No gradient: leave this parameter and its moments untouched.
                continue
            self.m[k] = b1 * self.m[k] + (1 - b1) * g
            self.v[k] = b2 * self.v[k] + (1 - b2) * (g * g)
            upd = (self.m[k] / bc1) / (np.sqrt(self.v[k] / bc2) + self.eps)
            # Decay is added after the Adam normalisation, and only for
            # keys containing ".W" or named "Wr" / "Wo".
            if ".W" in k or k in ("Wr", "Wo"):
                upd = upd + self.wd * p.d
            p.d -= lr * upd
|
|
def clip(P, maxn=1.0):
    """Rescale all gradients in place so their global L2 norm is at most `maxn`.

    Parameters whose `.g` is None are ignored. Returns the global norm
    measured before any rescaling.
    """
    with_grad = [p for p in P.values() if p.g is not None]

    sq_sum = 0
    for p in with_grad:
        sq_sum = sq_sum + float((p.g * p.g).sum())
    norm = math.sqrt(sq_sum)

    if norm > maxn:
        # The small constant keeps the scale finite and just under maxn / norm.
        scale = maxn / (norm + 1e-6)
        for p in with_grad:
            p.g *= scale
    return norm
|
|
def lr_at(s):
    """Return the learning rate for global step `s`.

    Below WARM the rate rises linearly from 0 to LR. From TMAX onward it
    is LRMIN. In between it follows a half-cosine from LR down to LRMIN.
    """
    if s < WARM:
        # Linear warmup.
        return LR * s / WARM
    if s >= TMAX:
        # Schedule exhausted: hold the floor.
        return LRMIN
    progress = (s - WARM) / (TMAX - WARM)
    cosine = 1 + math.cos(math.pi * progress)
    return LRMIN + 0.5 * (LR - LRMIN) * cosine
|
|
| |
# --- Restore training state from CKPT when it exists, otherwise start fresh ---
if os.path.exists(CKPT):
    from garyneuron import T

    # NOTE(review): allow_pickle=True lets a crafted checkpoint run arbitrary
    # code on load. Only point CKPT at files this script wrote; consider
    # dropping the flag if no stored array needs pickling.
    # The context manager closes the archive so the file is not held open
    # while save() later rewrites it.
    with np.load(CKPT, allow_pickle=True) as z:
        # Parameters are stored under "P/<name>"; strip the prefix.
        P = {k[2:]: T(z[k].copy()) for k in z.files if k.startswith("P/")}
        # The checkpoint's cfg replaces the env-derived one.
        cfg = json.loads(str(z["cfg"]))
        opt = Adam(P, LR)
        opt.t = int(z["t"])
        for k in P:
            opt.m[k] = z["m/" + k].copy()
            opt.v[k] = z["v/" + k].copy()
        step = int(z["step"])

    # MAXDIG's default depends on cfg["S"]; recompute it from the restored
    # cfg unless the user pinned it explicitly.
    if "MAXDIG" not in os.environ:
        MAXDIG = int(cfg["S"] - 1)

    # Seed from the step so each resumed burst draws a different stream.
    rng = np.random.default_rng(1000 + step)
else:
    P = init_params(cfg, seed=1337)
    opt = Adam(P, LR)
    step = 0
    rng = np.random.default_rng(0)

# Parameter count, reported in the log line.
NP = n_params(P)
|
|
def save():
    """Write the full training state to CKPT atomically.

    The archive holds the parameters under "P/<name>", the Adam moments
    under "m/<name>" and "v/<name>", the counters "step" and "t", and the
    cfg as a JSON string under "cfg".

    The data is written to a temporary file next to CKPT and then moved
    into place, so an interrupted save never leaves a truncated
    checkpoint. Writing through a file object also stores the archive
    under exactly the CKPT path, whatever its suffix.
    """
    d = {"P/" + k: v.d for k, v in P.items()}
    d.update({"m/" + k: opt.m[k] for k in P})
    d.update({"v/" + k: opt.v[k] for k in P})
    d["step"] = step
    d["t"] = opt.t
    d["cfg"] = json.dumps(cfg)

    tmp = CKPT + ".tmp"
    with open(tmp, "wb") as f:
        np.savez(f, **d)
    # Same-directory rename: readers see the old or the new file, never a partial one.
    os.replace(tmp, CKPT)
|
|
def evaluate(n=2000):
    """Score the current parameters on a fixed validation batch.

    The batch and the forward-pass rng both use constant seeds, so every
    call evaluates on the same `n` examples.

    Returns:
        A pair: the value of exact_match(pred, target), and the mean of
        the elementwise comparison pred == target as a float.
    """
    weights = params_to_np(P)
    data_rng = np.random.default_rng(987654)
    va, vb, vy = make_batch(n, cfg["S"], data_rng, MAXDIG)
    fwd_rng = np.random.default_rng(321)
    pred = forward_np(weights, va, vb, cfg, fwd_rng)
    exact = exact_match(pred, vy)
    per_element = float((pred == vy).mean())
    return exact, per_element
|
|
| |
# --- Training burst: step until the SEC budget is spent, then save, evaluate, log ---
t0 = time.time()
losses = []
nst = 0
info = {}  # stays empty when the budget allows no training step
while time.time() - t0 < SEC:
    A, Bb, Y = batch(BS)
    # Clear gradients left over from the previous step.
    for v in P.values():
        v.g = None
    tot, info = forward(P, A, Bb, Y, cfg, rng, train=True)
    tot.backward()
    clip(P, 1.0)
    opt.step(lr_at(step))
    step += 1
    nst += 1
    losses.append(info["loss"])
dt = time.time() - t0

# Persist before evaluating so the training progress is already on disk.
save()
em, da = evaluate()

load = info.get("load")
loadstr = (" | load " + ",".join(f"{x:.2f}" for x in load)) if load is not None else ""
# Mean of the last 50 losses; NaN when no step ran in this burst.
recent_loss = np.mean(losses[-50:]) if losses else float("nan")
msg = (f"step {step:5d} | loss {recent_loss:.4f} | exact {em*100:6.2f}% | "
       f"digit {da*100:6.2f}% | lr {lr_at(step):.2e} | {nst}st/{dt:.0f}s | "
       f"n={NP} | S={cfg['S']} maxdig={MAXDIG} K={cfg['K']} top{cfg['topk']} "
       f"steps={cfg['steps']} p={cfg['p_update']}{loadstr}")
print(msg)
with open(LOG, "a") as log_file:
    log_file.write(msg + "\n")
|
|