Spaces:
Running
Running
// CPU reference EqPropNet, faithful to fhn.html.
// Used as the ground truth for GPU validation.
// Constants and nonlinearities for the 'fhn' neuron mode.
export const MN_ALPHA = 1.08;
// Linear gain used by fhnF.
export const MN_A1 = 1 - 1 / MN_ALPHA;

/** Cubic drift term of an 'fhn' unit: MN_A1*u - u^3. */
export const fhnF = (u) => MN_A1 * u - u * u * u;

/** Hard clamp of u into [0, 1]; values already inside (and NaN) pass through unchanged. */
export const fhnRho = (u) => {
  if (u < 0) return 0;
  if (u > 1) return 1;
  return u;
};

/** Derivative of fhnRho: 1 strictly inside (0, 1), otherwise 0. */
export const fhnRhoP = (u) => {
  const inside = u > 0 && u < 1;
  return inside ? 1 : 0;
};
/** Logistic sigmoid with slope 4, centred at u = 0.5. */
export const sg = (u) => {
  const z = -4 * (u - 0.5);
  return 1 / (1 + Math.exp(z));
};

/** Derivative of sg with respect to u, expressed through s = sg(u): 4*s*(1-s). */
export const sgp = (u) => {
  const s = sg(u);
  return 4 * s * (1 - s);
};
/**
 * Deterministic 32-bit linear congruential generator.
 * @param {number} seed - coerced to an unsigned 32-bit integer.
 * @returns {() => number} generator yielding values in [0, 1).
 */
export function rng(seed) {
  let state = seed >>> 0;
  const next = () => {
    // state = (state * 1664525 + 1013904223) mod 2^32, computed without precision loss.
    state = (Math.imul(state, 1664525) + 1013904223) >>> 0;
    return state / 4294967296;
  };
  return next;
}
// matrix utilities (double precision for orth — matches fhn.html)

/**
 * Dense row-major product C = A (m x k) * B (k x n), accumulated in Float64.
 * Entries of A that are falsy (0, and also NaN) are skipped entirely, as in the reference.
 * @returns {Float64Array} C of length m*n.
 */
function mm(A, B, m, k, n) {
  const C = new Float64Array(m * n);
  for (let row = 0; row < m; row++) {
    const cOff = row * n;
    for (let inner = 0; inner < k; inner++) {
      const scale = A[row * k + inner];
      if (!scale) continue;
      const bOff = inner * n;
      for (let col = 0; col < n; col++) C[cOff + col] += scale * B[bOff + col];
    }
  }
  return C;
}
/** Transpose of a row-major m x n matrix, returned as a Float64Array (n x m). */
function tr(A, m, n) {
  const T = new Float64Array(m * n);
  for (let col = 0; col < n; col++) {
    const outRow = col * m;
    for (let row = 0; row < m; row++) T[outRow + row] = A[row * n + col];
  }
  return T;
}
/**
 * Approximate S^{-1/2} for a symmetric positive-definite n x n matrix S (row-major)
 * via the coupled Newton–Schulz iteration:
 *   T = 3I - Z*Y;  Y <- Y*T/2;  Z <- T*Z/2,
 * started from Y = S/||S||_F and Z = I. The result is rescaled by 1/sqrt(||S||_F).
 * @param {ArrayLike<number>} S - n*n entries.
 * @param {number} n - matrix order.
 * @param {number} [iters=18] - fixed number of iterations (no convergence test).
 * @returns {Float64Array} n*n approximation of S^{-1/2}.
 */
function invSqrtSPD(S, n, iters = 18) {
  // Frobenius norm, plus a tiny guard against division by zero.
  let sumSq = 0;
  for (const v of S) sumSq += v * v;
  const nrm = Math.sqrt(sumSq) + 1e-30;

  let Y = Float64Array.from(S, (v) => v / nrm);
  let Z = new Float64Array(n * n);
  for (let d = 0; d < n; d++) Z[d * (n + 1)] = 1;

  const halve = (M) => {
    for (let q = 0; q < M.length; q++) M[q] *= 0.5;
    return M;
  };

  for (let step = 0; step < iters; step++) {
    const T = mm(Z, Y, n, n, n).map((v) => -v);
    for (let d = 0; d < n; d++) T[d * (n + 1)] += 3;
    // Both updates read the previous Y and Z.
    const nextY = halve(mm(Y, T, n, n, n));
    const nextZ = halve(mm(T, Z, n, n, n));
    Y = nextY;
    Z = nextZ;
  }

  const root = Math.sqrt(nrm);
  for (let q = 0; q < Z.length; q++) Z[q] /= root;
  return Z;
}
/**
 * Orthogonalise a row-major m x n matrix G, using the smaller Gram matrix:
 *   tall/square (m >= n): G (GᵀG)^{-1/2}
 *   wide (m < n):         (GGᵀ)^{-1/2} G
 * @returns {Float64Array} m*n result.
 */
export function orth(G, m, n) {
  const Gt = tr(G, m, n);
  if (m < n) {
    const rowGram = mm(G, Gt, m, n, m);
    return mm(invSqrtSPD(rowGram, m), G, m, m, n);
  }
  const colGram = mm(Gt, G, n, m, n);
  return mm(G, invSqrtSPD(colGram, n), m, n, n);
}
/**
 * Layered network relaxed on the CPU and trained with a two-phase (+beta / -beta)
 * contrastive rule. It mirrors fhn.html and is the reference for GPU validation,
 * so the order of floating-point operations is kept deliberate throughout.
 *
 * sizes: [D, H, ..., O] layer widths (sizes[0] is the input layer).
 * mode:  'adaptive' (sigmoid units, optional adaptation variable) | 'fhn' (clamped cubic units).
 * opt:   'sgd' | 'adagrad' | 'adago'.
 *
 * W[l] is row-major (sizes[l+1] x sizes[l]): W[l][i*ni + j] connects unit j of layer l to
 * unit i of layer l+1. Top-down input to layer l reuses W[l] transposed.
 */
export class EqPropNet {
  // sizes:[D,H,...,O] mode:'adaptive'|'fhn' opt:'sgd'|'adagrad'|'adago'
  constructor(sizes, seed = 7, mode = 'adaptive', opt = 'adago') {
    this.sizes = sizes;
    this.L = sizes.length;
    this.mode = mode;
    this.opt = opt;
    // A seed of 0 (after uint32 coercion) is mapped to 1.
    const r = rng(seed >>> 0 || 1);
    this.W = [];
    this.b = [];
    for (let l = 0; l < this.L - 1; l++) {
      const ni = sizes[l], no = sizes[l + 1], sc = Math.sqrt(2 / (ni + no));
      // Uniform weights in [-sc, sc); biases start at zero.
      const w = new Float32Array(no * ni);
      for (let k = 0; k < w.length; k++) w[k] = (r() * 2 - 1) * sc;
      this.W.push(w);
      this.b.push(new Float32Array(no));
    }
    // Optimizer state (Float64 accumulators).
    this.GW = this.W.map((w) => new Float64Array(w.length)); // adagrad: sum of squared W grads
    this.GB = this.b.map((b) => new Float64Array(b.length)); // adagrad: sum of squared b grads
    this.MW = this.W.map((w) => new Float64Array(w.length)); // adago: W gradient momentum
    this.MB = this.b.map((b) => new Float64Array(b.length)); // adago: b gradient momentum
    this.vW = new Float64Array(this.L - 1); // adago: accumulated clipped squared grad norm (W)
    this.vB = new Float64Array(this.L - 1); // adago: same for b
    this.OW = new Array(this.L - 1).fill(null); // adago: cached orth(MW[l])
    this.OW_K = 4; // adago: recompute OW every OW_K apply() calls
    this.bc = 0; // apply() call counter
    this.gamma = 0.6; // top-down feedback gain (adaptive mode only)
    this.betaN = (mode === 'fhn') ? 0.9 : 0.5; // nudging strength of the +/- phases
    this.adpC = 0.15; // adaptive mode: blend factor pulling nudged states toward the free state
    this.adpSteps = 3; // adaptive mode: number of blend applications
    this.tauInv = 0.0; // adaptive mode: adaptation-variable rate (0 keeps w at 0)
    this.rmin = 0.1; // minimum loss-dependent gradient scale
    this.escale = 0.4; // loss at which the gradient scale saturates at 1
  }

  /**
   * Relax the state for `iters` explicit-Euler steps of size `dt`.
   * Each step sweeps layers bottom-up in place, so layer l sees the already-updated
   * layer l-1 and the previous-step layer l+1. If beta !== 0 and `target` is given,
   * the output layer is nudged toward `target`.
   * @returns {Float32Array[]} per-layer states (u[0] holds a copy of x).
   */
  relax(x, iters, dt, beta = 0, target = null) {
    const sz = this.sizes, L = this.L, ad = (this.mode === 'adaptive');
    const rho = ad ? sg : fhnRho, rhop = ad ? sgp : fhnRhoP;
    const u = sz.map((n) => new Float32Array(n));
    u[0].set(x);
    for (let l = 1; l < L; l++) u[l].fill(0.1);
    const w = ad ? sz.map((n) => new Float32Array(n)) : null;
    // Per-layer Float64 scratch for rho(u): the cached values are bit-identical to
    // calling rho() inline, so results match the un-hoisted reference exactly.
    const rhoBuf = sz.map((n) => new Float64Array(n));
    const fillRho = (dst, src) => {
      for (let j = 0; j < src.length; j++) dst[j] = rho(src[j]);
      return dst;
    };
    for (let t = 0; t < iters; t++) {
      for (let l = 1; l < L; l++) {
        const Wlm = this.W[l - 1], blm = this.b[l - 1], ni = sz[l - 1], no = sz[l];
        const hasAbove = l < L - 1;
        const Wl = hasAbove ? this.W[l] : null, nip1 = hasAbove ? sz[l + 1] : 0;
        const ul = u[l], wl = ad ? w[l] : null;
        // u[l-1] and u[l+1] are not modified while layer l updates, so their
        // activations are computed once per layer instead of once per neuron.
        const rLo = fillRho(rhoBuf[l - 1], u[l - 1]);
        const rHi = hasAbove ? fillRho(rhoBuf[l + 1], u[l + 1]) : null;
        for (let i = 0; i < no; i++) {
          let c = blm[i];
          for (let j = 0; j < ni; j++) c += Wlm[i * ni + j] * rLo[j];
          if (Wl) {
            // Top-down input through the transpose of W[l].
            let td = 0;
            for (let k = 0; k < nip1; k++) td += Wl[k * no + i] * rHi[k];
            c += ad ? this.gamma * td : td;
          }
          let drive;
          if (ad) {
            drive = -ul[i] + sg(c) - wl[i];
            if (beta !== 0 && l === L - 1 && target) drive += beta * (target[i] - ul[i]);
            ul[i] += dt * drive;
            wl[i] += dt * this.tauInv * (sg(ul[i]) - wl[i]);
          } else {
            drive = rhop(ul[i]) * c + fhnF(ul[i]);
            if (beta !== 0 && l === L - 1 && target) drive += beta * (target[i] - rho(ul[i]));
            // State is clamped to [-0.2, 1.2] in fhn mode.
            let nu = ul[i] + dt * drive;
            if (nu < -0.2) nu = -0.2;
            else if (nu > 1.2) nu = 1.2;
            ul[i] = nu;
          }
        }
      }
    }
    return u;
  }

  /** Free-phase output-layer state for input x. */
  outputs(x, iters, dt) { return this.relax(x, iters, dt, 0, null)[this.L - 1]; }

  /** Index of the largest free-phase output state (first index wins ties). */
  predict(x, iters, dt) {
    const o = this.outputs(x, iters, dt);
    let bi = 0, bv = -1e9;
    for (let i = 0; i < o.length; i++) if (o[i] > bv) { bv = o[i]; bi = i; }
    return bi;
  }

  /**
   * Accumulate the contrastive gradient for one sample into gW / gB (added in place).
   * Runs a free phase, then +betaN and -betaN nudged phases. The update is scaled by
   * r = rmin + (1 - rmin) * min(1, loss / escale).
   * @param {number} label - index of the one-hot target (assumed in range of the output layer).
   * @returns {number} squared error between rho(free output) and the one-hot target.
   */
  accum(x, label, iters, nIters, dt, gW, gB) {
    const L = this.L, sz = this.sizes, ad = (this.mode === 'adaptive');
    const rho = ad ? sg : fhnRho;
    const tgt = new Float32Array(sz[L - 1]);
    tgt[label] = 1;
    const bN = this.betaN;
    const uf = this.relax(x, iters, dt, 0, null);
    const o = uf[L - 1];
    let loss = 0;
    for (let i = 0; i < o.length; i++) { const d = rho(o[i]) - tgt[i]; loss += d * d; }
    const r = this.rmin + (1 - this.rmin) * Math.min(1, loss / this.escale);
    const up = this.relax(x, nIters, dt, +bN, tgt);
    const um = this.relax(x, nIters, dt, -bN, tgt);
    if (ad) {
      // Pull both nudged states part-way back toward the free state.
      const c = this.adpC;
      for (let a = 0; a < this.adpSteps; a++) for (let l = 1; l < L; l++) for (let i = 0; i < sz[l]; i++) {
        up[l][i] = (1 - c) * up[l][i] + c * uf[l][i];
        um[l][i] = (1 - c) * um[l][i] + c * uf[l][i];
      }
    }
    // Float64 copies of rho(.) are bit-identical to inline calls.
    const rhoOf = (src) => Float64Array.from(src, (v) => rho(v));
    for (let l = 0; l < L - 1; l++) {
      const ni = sz[l], no = sz[l + 1];
      const rip = up[l + 1], rim = um[l + 1];
      // Presynaptic activations in both phases, hoisted out of the neuron loop.
      const preP = rhoOf(up[l]), preM = rhoOf(um[l]);
      const gWl = gW[l], gBl = gB[l];
      for (let i = 0; i < no; i++) {
        const a = rho(rip[i]), cc = rho(rim[i]);
        gBl[i] += r * (a - cc) / (2 * bN);
        const off = i * ni;
        for (let j = 0; j < ni; j++) gWl[off + j] += r * (a * preP[j] - cc * preM[j]) / (2 * bN);
      }
    }
    return loss;
  }

  /**
   * Apply accumulated gradients (summed over `bs` samples) with learning rate `lr`.
   * The update is added (W += ...), following the sign convention of accum().
   */
  apply(gW, gB, bs, lr) {
    const eps = 1e-8, mu = 0.9, gam = 1.0;
    this.bc++;
    for (let l = 0; l < this.L - 1; l++) {
      const ni = this.sizes[l], no = this.sizes[l + 1], W = this.W[l], B = this.b[l];
      if (this.opt === 'sgd') {
        for (let k = 0; k < W.length; k++) W[k] += lr * gW[l][k] / bs;
        for (let k = 0; k < B.length; k++) B[k] += lr * gB[l][k] / bs;
      } else if (this.opt === 'adagrad') {
        for (let k = 0; k < W.length; k++) {
          const g = gW[l][k] / bs;
          this.GW[l][k] += g * g;
          W[k] += lr * g / (Math.sqrt(this.GW[l][k]) + 1e-6);
        }
        for (let k = 0; k < B.length; k++) {
          const g = gB[l][k] / bs;
          this.GB[l][k] += g * g;
          B[k] += lr * g / (Math.sqrt(this.GB[l][k]) + 1e-6);
        }
      } else { // adago
        // Weights: step along orth(momentum), with a norm-adaptive step size.
        let gn2 = 0;
        for (let k = 0; k < W.length; k++) {
          const g = gW[l][k] / bs;
          this.MW[l][k] = mu * this.MW[l][k] + (1 - mu) * g;
          gn2 += g * g;
        }
        const gn = Math.sqrt(gn2);
        this.vW[l] += Math.min(gn2, gam * gam);
        if (!this.OW[l] || this.bc % this.OW_K === 0) this.OW[l] = orth(this.MW[l], no, ni);
        const O = this.OW[l];
        const alpha = Math.max(eps, lr * Math.min(gn, gam) / (Math.sqrt(this.vW[l]) + eps));
        for (let k = 0; k < W.length; k++) W[k] += alpha * O[k];
        // Biases: same step-size rule.
        // NOTE(review): MB is divided by the norm of the *current* gradient (bn), not by
        // ||MB||; kept as-is to match the reference — confirm this is intended.
        let bn2 = 0;
        for (let k = 0; k < B.length; k++) {
          const g = gB[l][k] / bs;
          this.MB[l][k] = mu * this.MB[l][k] + (1 - mu) * g;
          bn2 += g * g;
        }
        const bn = Math.sqrt(bn2);
        this.vB[l] += Math.min(bn2, gam * gam);
        const ba = Math.max(eps, lr * Math.min(bn, gam) / (Math.sqrt(this.vB[l]) + eps));
        for (let k = 0; k < B.length; k++) B[k] += ba * (bn > 0 ? this.MB[l][k] / bn : 0);
      }
    }
  }

  /** Fresh zeroed [gW, gB] Float64 buffers shaped like W and b, for use with accum(). */
  newGradBuffers() {
    return [this.W.map((w) => new Float64Array(w.length)), this.b.map((b) => new Float64Array(b.length))];
  }
}
// MNIST loader (browser side: fetches the JSON pack served from disk).
/**
 * Fetch and decode a JSON dataset pack whose pixel/label arrays are base64-encoded bytes.
 * @param {string} url - location of the JSON pack.
 * @returns {Promise<{Cin:number,R:number,D:number,ntr:number,nte:number,
 *   xtr:Uint8Array,ytr:Uint8Array,xte:Uint8Array,yte:Uint8Array}>}
 * @throws {Error} if the HTTP response is not OK (instead of failing later with an
 *   opaque JSON parse error or undefined fields).
 */
export async function loadMnist(url) {
  const res = await fetch(url);
  if (!res.ok) throw new Error(`loadMnist: failed to fetch ${url} (HTTP ${res.status})`);
  const j = await res.json();
  // base64 string -> raw bytes.
  const dec = (b64) => {
    const bin = atob(b64);
    const u = new Uint8Array(bin.length);
    for (let i = 0; i < bin.length; i++) u[i] = bin.charCodeAt(i);
    return u;
  };
  // Cin defaults to 1 (single-channel, back-compat with MNIST/Fashion).
  // For RGB (CIFAR) packs, Cin=3. Layout is Cin-major: [R-plane (R*R) || G-plane || B-plane] per sample.
  const Cin = j.Cin || 1;
  return {
    Cin, R: j.R, D: Cin * j.R * j.R, ntr: j.ntr, nte: j.nte,
    xtr: dec(j.xtr), ytr: dec(j.ytr), xte: dec(j.xte), yte: dec(j.yte),
  };
}
/**
 * Extract sample `i` as a Float32Array of D values scaled from bytes to [0, 1].
 * @param {{D:number, xtr:ArrayLike<number>, xte:ArrayLike<number>}} MN - pack from loadMnist.
 * @param {string} set - 'tr' selects the training pixels; any other value selects the test pixels.
 * @param {number} i - sample index.
 */
export function imgF32(MN, set, i) {
  const { D } = MN;
  const pixels = set === 'tr' ? MN.xtr : MN.xte;
  const base = i * D;
  const out = new Float32Array(D);
  for (let p = 0; p < D; p++) out[p] = pixels[base + p] / 255;
  return out;
}