| """Finite-difference check of the from-scratch autograd. If analytic grads match |
| numeric grads, the MoE+NCA backward is correct and training can be trusted.""" |
| import numpy as np |
| from garyneuron import init_params, forward |
| from data import make_batch |
|
|
def run(aux=0.0):
    """Compare analytic gradients with central finite differences.

    Builds a tiny fixed configuration, runs one forward/backward pass to
    obtain the analytic gradients, then perturbs a handful of sampled entries
    of selected parameters by +/-eps and re-evaluates the loss. One line is
    printed per parameter, followed by a PASS/FAIL summary.

    Args:
        aux: Value stored in the model config under the ``"aux"`` key.

    Returns:
        The largest relative error over all sampled entries. A non-finite
        comparison (NaN/inf loss or gradient) is reported as ``inf`` so it
        can never be mistaken for a pass.
    """
    cfg = dict(S=4, d=6, he=6, K=3, topk=2, steps=3, p_update=0.6,
               Vin=10, Vout=10, aux=aux)
    P = init_params(cfg, seed=0)
    A, B, Y = make_batch(5, cfg["S"], np.random.default_rng(1))
    MASK_SEED = 42
    eps = 1e-3

    def fwd():
        # A fresh generator with the same seed on every call, so whatever
        # randomness forward() draws is identical across evaluations
        # (presumably the stochastic update mask -- confirm in garyneuron).
        return forward(P, A, B, Y, cfg, np.random.default_rng(MASK_SEED), train=True)[0]

    tot = fwd()
    # Discard any gradients already stored on the parameters before backprop.
    for v in P.values():
        v.g = None
    tot.backward()

    worst = 0.0
    for name in ["Wr", "br", "e0.W2", "e1.W1", "emb", "posemb", "Wo", "bo"]:
        p = P[name]
        # Index through .flat: assignment always writes into the array itself,
        # whereas reshape(-1) may hand back a copy for non-contiguous storage.
        values = p.d.flat
        grads = p.g.flat
        n = p.d.size
        # Sample distinct entries; drawing with replacement would re-check the
        # same entry several times on small parameters.
        idxs = np.random.default_rng(7).choice(n, size=min(6, n), replace=False)
        maxerr = 0.0
        for i in idxs:
            old = float(values[i])
            try:
                values[i] = old + eps
                lp = fwd().d
                values[i] = old - eps
                lm = fwd().d
            finally:
                # Always restore the entry, even if a forward pass raised.
                values[i] = old
            num = (lp - lm) / (2 * eps)
            ana = float(grads[i])
            err = abs(num - ana) / (abs(num) + abs(ana) + 1e-8)
            # max(x, nan) returns x, so a NaN error would silently vanish and
            # the check would pass; treat any non-finite error as a failure.
            if not np.isfinite(err):
                err = float("inf")
            maxerr = max(maxerr, err)
        worst = max(worst, maxerr)
        print(f" {name:8s} max-rel-err {maxerr:.2e}")
    print(f"aux={aux} WORST {worst:.2e} -> {'PASS' if worst < 2e-2 else 'FAIL'}")
    return worst
|
|
if __name__ == "__main__":
    # First pass: config built with aux=0.0.
    print("gradcheck (no aux):")
    w1 = run(0.0)

    # Second pass: same check with aux=0.02 in the config.
    print("gradcheck (with load-balance aux):")
    w2 = run(0.02)
|
|