File size: 1,612 Bytes
57f9808
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
"""Finite-difference check of the from-scratch autograd. If analytic grads match
numeric grads, the MoE+NCA backward is correct and training can be trusted."""
import numpy as np
from garyneuron import init_params, forward
from data import make_batch

def run(aux=0.0, *, eps=1e-3, tol=2e-2, n_check=6):
    """Compare analytic gradients against central finite differences.

    Builds a tiny model config and runs one forward/backward pass to obtain
    the analytic gradients. Then, for a few distinct entries of each checked
    parameter, it perturbs the entry by +/-eps, recomputes the loss, and
    compares the central-difference slope with the analytic value.

    Args:
        aux: stored as ``cfg["aux"]`` (the ``__main__`` driver labels a
            non-zero value as the load-balance auxiliary term).
        eps: finite-difference step size.
        tol: a run PASSes when the worst relative error is below this.
        n_check: number of distinct entries sampled per parameter.

    Returns:
        The worst relative error over all checked entries. It is NaN if any
        check produced NaN, which prints as FAIL.

    Assumes each ``P[name]`` has NumPy arrays ``.d`` (value) and ``.g``
    (gradient, or None), and that ``forward(...)[0]`` is a scalar loss node
    with ``.d`` and ``.backward()``. TODO confirm against garyneuron.
    """
    cfg = dict(S=4, d=6, he=6, K=3, topk=2, steps=3, p_update=0.6,
               Vin=10, Vout=10, aux=aux)
    P = init_params(cfg, seed=0)
    A, B, Y = make_batch(5, cfg["S"], np.random.default_rng(1))
    MASK_SEED = 42

    def fwd():
        # A fresh generator with the same seed on every call keeps any random
        # masking inside forward() identical across the base and perturbed
        # passes, so the loss is a deterministic function of the parameters.
        return forward(P, A, B, Y, cfg, np.random.default_rng(MASK_SEED), train=True)[0]

    def loss():
        return float(fwd().d)

    tot = fwd()
    for v in P.values():
        v.g = None  # clear stale gradients before the single backward pass
    tot.backward()

    worst = 0.0
    for name in ["Wr", "br", "e0.W2", "e1.W1", "emb", "posemb", "Wo", "bo"]:
        p = P[name]
        size = p.d.size
        # A parameter the backward pass never reached may still have g=None.
        # Treat that as an all-zero analytic gradient so the check reports a
        # mismatch instead of crashing.
        grad = np.zeros(size) if p.g is None else np.asarray(p.g).reshape(-1)
        # Sample without replacement so n_check *distinct* entries are checked.
        idxs = np.random.default_rng(7).choice(size, size=min(n_check, size), replace=False)
        maxerr = 0.0
        for i in idxs:
            # p.d.flat writes through to p.d even when the array is not
            # contiguous; reshape(-1) would silently perturb a copy.
            old = float(p.d.flat[i])
            try:
                p.d.flat[i] = old + eps
                lp = loss()
                p.d.flat[i] = old - eps
                lm = loss()
            finally:
                p.d.flat[i] = old  # always restore, even if forward() raises
            num = (lp - lm) / (2 * eps)
            ana = float(grad[i])
            err = abs(num - ana) / (abs(num) + abs(ana) + 1e-8)
            # np.maximum propagates NaN; builtin max(0.0, nan) would drop it
            # and report a false PASS.
            maxerr = float(np.maximum(maxerr, err))
        worst = float(np.maximum(worst, maxerr))
        print(f"  {name:8s} max-rel-err {maxerr:.2e}")
    print(f"aux={aux}  WORST {worst:.2e}  ->  {'PASS' if worst < tol else 'FAIL'}")
    return worst

if __name__ == "__main__":
    print("gradcheck (no aux):")
    w1 = run(0.0)
    print("gradcheck (with load-balance aux):")
    w2 = run(0.02)
    # Exit non-zero when either configuration fails so CI / shell scripts can
    # detect a broken backward pass. 2e-2 matches run()'s default PASS
    # threshold. The `not (... < ...)` form also treats a NaN error as failure.
    if not (w1 < 2e-2 and w2 < 2e-2):
        raise SystemExit(1)