"""Finite-difference check of the from-scratch autograd. If analytic grads match numeric grads, the MoE+NCA backward is correct and training can be trusted.""" import numpy as np from garyneuron import init_params, forward from data import make_batch def run(aux=0.0): cfg = dict(S=4, d=6, he=6, K=3, topk=2, steps=3, p_update=0.6, Vin=10, Vout=10, aux=aux) P = init_params(cfg, seed=0) A, B, Y = make_batch(5, cfg["S"], np.random.default_rng(1)) MASK_SEED = 42 def fwd(): return forward(P, A, B, Y, cfg, np.random.default_rng(MASK_SEED), train=True)[0] tot = fwd() for v in P.values(): v.g = None tot.backward() worst = 0.0 for name in ["Wr", "br", "e0.W2", "e1.W1", "emb", "posemb", "Wo", "bo"]: p = P[name]; flat = p.d.reshape(-1); gf = p.g.reshape(-1) idxs = np.random.default_rng(7).integers(0, flat.size, min(6, flat.size)) maxerr = 0.0 for i in idxs: old = float(flat[i]); eps = 1e-3 flat[i] = old + eps; lp = fwd().d flat[i] = old - eps; lm = fwd().d flat[i] = old num = (lp - lm) / (2 * eps); ana = float(gf[i]) err = abs(num - ana) / (abs(num) + abs(ana) + 1e-8) maxerr = max(maxerr, err) worst = max(worst, maxerr) print(f" {name:8s} max-rel-err {maxerr:.2e}") print(f"aux={aux} WORST {worst:.2e} -> {'PASS' if worst < 2e-2 else 'FAIL'}") return worst if __name__ == "__main__": print("gradcheck (no aux):"); w1 = run(0.0) print("gradcheck (with load-balance aux):"); w2 = run(0.02)