File size: 1,612 Bytes
57f9808 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 | """Finite-difference check of the from-scratch autograd. If analytic grads match
numeric grads, the MoE+NCA backward is correct and training can be trusted."""
import numpy as np
from garyneuron import init_params, forward
from data import make_batch
def run(aux=0.0, tol=2e-2, eps=1e-3):
    """Compare analytic gradients against central finite differences.

    Builds a tiny config/model, runs one forward+backward pass to populate
    each parameter's ``.g``, then for up to 6 sampled entries of a fixed list
    of parameters estimates the derivative numerically as
    ``(L(p + eps) - L(p - eps)) / (2 * eps)`` and reports the symmetric
    relative error ``|num - ana| / (|num| + |ana| + 1e-8)``.

    Args:
        aux: stored as ``cfg["aux"]``; the ``__main__`` block labels the
            nonzero case "load-balance aux" (presumably the aux-loss weight —
            TODO confirm in garyneuron).
        tol: the run is reported as PASS when the worst error is below this.
        eps: finite-difference step size.

    Returns:
        The worst (largest) relative error over all checked entries.
    """
    cfg = dict(S=4, d=6, he=6, K=3, topk=2, steps=3, p_update=0.6,
               Vin=10, Vout=10, aux=aux)
    P = init_params(cfg, seed=0)
    A, B, Y = make_batch(5, cfg["S"], np.random.default_rng(1))
    MASK_SEED = 42

    def fwd():
        # A freshly seeded RNG on every call makes any randomness forward()
        # draws identical across the analytic pass and each perturbed pass;
        # otherwise the finite differences would pick up sampling noise.
        return forward(P, A, B, Y, cfg, np.random.default_rng(MASK_SEED),
                       train=True)[0]

    tot = fwd()
    for v in P.values():
        v.g = None
    tot.backward()

    worst = 0.0
    for name in ["Wr", "br", "e0.W2", "e1.W1", "emb", "posemb", "Wo", "bo"]:
        p = P[name]
        size = p.d.size
        # g was reset to None above; if backward() never reached this
        # parameter its true gradient is zero, so compare against zeros
        # rather than crashing on None.reshape.
        grad = np.zeros(size) if p.g is None else np.asarray(p.g).reshape(-1)
        idxs = np.random.default_rng(7).integers(0, size, min(6, size))
        maxerr = 0.0
        for i in idxs:
            # .flat writes through to p.d even when it is non-contiguous;
            # reshape(-1) may return a copy, so perturbing it could silently
            # leave the real parameter unchanged.
            old = float(p.d.flat[i])
            try:
                p.d.flat[i] = old + eps
                lp = float(fwd().d)
                p.d.flat[i] = old - eps
                lm = float(fwd().d)
            finally:
                # Always restore so a failing forward() cannot leave the
                # parameter perturbed for later checks.
                p.d.flat[i] = old
            num = (lp - lm) / (2 * eps)
            ana = float(grad[i])
            err = abs(num - ana) / (abs(num) + abs(ana) + 1e-8)
            maxerr = max(maxerr, err)
        worst = max(worst, maxerr)
        print(f" {name:8s} max-rel-err {maxerr:.2e}")
    print(f"aux={aux} WORST {worst:.2e} -> {'PASS' if worst < tol else 'FAIL'}")
    return worst
if __name__ == "__main__":
    import sys

    print("gradcheck (no aux):")
    w1 = run(0.0)
    print("gradcheck (with load-balance aux):")
    w2 = run(0.02)
    # run() only prints PASS/FAIL; turn the result into an exit status so a
    # CI job or shell script can gate on it (threshold matches run()'s 2e-2).
    sys.exit(0 if max(w1, w2) < 2e-2 else 1)
|