// cortex-conv / tests / eqprop_lib.js
// Committed by Ex0bit in d035fbd:
//   "initial deployment: cortex-conv ships pre-trained at 96.8% MNIST"
// CPU reference EqPropNet, faithful to fhn.html.
// Used as the ground truth for GPU validation.
// Shape parameter of the fhn cubic; MN_A1 is the derived linear coefficient 1 - 1/MN_ALPHA.
export const MN_ALPHA = 1.08;
export const MN_A1 = 1 - 1 / MN_ALPHA;

/** fhn cubic drive term: MN_A1*u - u^3. */
export const fhnF = (u) => {
  const cube = u * u * u;
  return MN_A1 * u - cube;
};

/** Rate for the fhn mode: u clipped to [0, 1] (NaN and -0 pass through unchanged). */
export const fhnRho = (u) => {
  if (u < 0) return 0;
  if (u > 1) return 1;
  return u;
};

/** Derivative of fhnRho: 1 strictly inside (0, 1), otherwise 0 (also 0 for NaN). */
export const fhnRhoP = (u) => Number(u > 0 && u < 1);

/** Logistic rate centred at 0.5 with slope factor 4, used by the adaptive mode. */
export const sg = (u) => {
  const z = -4 * (u - 0.5);
  return 1 / (1 + Math.exp(z));
};

/** Derivative of sg: 4*s*(1-s) where s = sg(u). */
export const sgp = (u) => {
  const s = sg(u);
  const scaled = 4 * s;
  return scaled * (1 - s);
};
export function rng(seed){ let s=seed>>>0; return ()=>{ s=(Math.imul(s,1664525)+1013904223)>>>0; return s/4294967296; }; }
// matrix utilities (double precision for orth — matches fhn.html)
/**
 * Dense product C = A·B in double precision; all matrices row-major.
 * A is m×k, B is k×n, result is a fresh m×n Float64Array.
 * Falsy entries of A (0, and also NaN) are skipped entirely.
 */
function mm(A, B, m, k, n) {
  const C = new Float64Array(m * n);
  for (let row = 0; row < m; row++) {
    const aBase = row * k;
    const cBase = row * n;
    for (let p = 0; p < k; p++) {
      const coef = A[aBase + p];
      if (!coef) continue; // sparse shortcut: skip zero (and NaN) coefficients
      const bBase = p * n;
      for (let col = 0; col < n; col++) C[cBase + col] += coef * B[bBase + col];
    }
  }
  return C;
}
/** Transpose of a row-major m×n matrix, returned as a fresh n×m Float64Array. */
function tr(A, m, n) {
  const out = new Float64Array(m * n);
  for (let r = 0; r < m; r++) {
    const rowBase = r * n;
    for (let c = 0; c < n; c++) out[c * m + r] = A[rowBase + c];
  }
  return out;
}
/**
 * Approximate inverse principal square root S^{-1/2} of an n×n symmetric
 * positive-definite matrix (row-major), via a fixed number of coupled
 * Newton–Schulz iterations (no convergence test).
 * S is first divided by its Frobenius norm, which for SPD S puts its
 * eigenvalues in (0, 1]; the result is rescaled by 1/sqrt(norm) at the end.
 * @param {ArrayLike<number>} S  n*n entries, row-major.
 * @param {number} n
 * @param {number} [iters=18]
 * @returns {Float64Array} n*n approximation of S^{-1/2}.
 */
function invSqrtSPD(S, n, iters = 18) {
  let sumSq = 0;
  for (const v of S) sumSq += v * v;
  // Tiny floor keeps an all-zero S from dividing by zero.
  const fro = Math.sqrt(sumSq) + 1e-30;
  let Y = Float64Array.from(S, (v) => v / fro); // tends to (S/fro)^{1/2}
  let Z = new Float64Array(n * n); // tends to (S/fro)^{-1/2}
  for (let d = 0; d < n; d++) Z[d * n + d] = 1;
  for (let step = 0; step < iters; step++) {
    // T = 3I - Z·Y, built from the previous Y and Z.
    const T = mm(Z, Y, n, n, n).map((v) => -v);
    for (let d = 0; d < n; d++) T[d * n + d] += 3;
    Y = mm(Y, T, n, n, n).map((v) => v * 0.5);
    Z = mm(T, Z, n, n, n).map((v) => v * 0.5);
  }
  const root = Math.sqrt(fro);
  return Z.map((v) => v / root);
}
/**
 * Orthogonal (polar) factor of an m×n row-major matrix G:
 *   m >= n: G·(GᵀG)^{-1/2}      m < n: (GGᵀ)^{-1/2}·G
 * i.e. G with its singular values pushed toward 1. The result is approximate,
 * since invSqrtSPD runs a fixed number of iterations.
 * @returns {Float64Array} m*n row-major.
 */
export function orth(G, m, n) {
  const Gt = tr(G, m, n);
  if (m < n) {
    // Wide matrix: normalise through the smaller m×m Gram matrix GGᵀ.
    const gram = mm(G, Gt, m, n, m);
    return mm(invSqrtSPD(gram, m), G, m, m, n);
  }
  // Tall or square matrix: normalise through the n×n Gram matrix GᵀG.
  const gram = mm(Gt, G, n, m, n);
  return mm(G, invSqrtSPD(gram, n), m, n, n);
}
// Layered equilibrium-propagation network: the CPU reference model.
// Parameters: W[l] is a Float32Array of sizes[l+1]*sizes[l] entries, row-major with
// row i = unit i of layer l+1 and column j = unit j of layer l; b[l] has sizes[l+1] entries.
// Training protocol visible here: accum() adds one sample's gradient estimate into
// caller-owned buffers (newGradBuffers() allocates zeroed ones), and apply() consumes the
// sums. Neither method clears those buffers; callers must reset them between batches.
export class EqPropNet {
// sizes: [D,H,...,O] layer widths, input first and output last.
// mode:  'adaptive' -> sg/sgp rates plus a per-unit adaptation state in relax();
//        any other value (e.g. 'fhn') -> fhnRho/fhnRhoP/fhnF dynamics.
// opt:   'sgd' | 'adagrad'; any other value (default 'adago') takes the
//        orthogonalised-momentum branch of apply().
// seed:  weight-init seed for rng(); a value whose seed>>>0 is 0 is replaced by 1.
constructor(sizes, seed=7, mode='adaptive', opt='adago'){
this.sizes=sizes; this.L=sizes.length; this.mode=mode; this.opt=opt;
const r=rng(seed>>>0||1); this.W=[]; this.b=[];
// Weights uniform in [-sc, sc) with sc = sqrt(2/(fan_in+fan_out)), stored as float32; biases start at 0.
for(let l=0;l<this.L-1;l++){
const ni=sizes[l], no=sizes[l+1], sc=Math.sqrt(2/(ni+no));
const w=new Float32Array(no*ni); for(let k=0;k<w.length;k++) w[k]=(r()*2-1)*sc;
this.W.push(w); this.b.push(new Float32Array(no));
}
// Optimizer state (double precision), one entry per weight layer:
//   GW/GB - per-element sums of squared gradients ('adagrad').
//   MW/MB - momentum EMAs of the gradient ('adago'; mu=0.9 in apply()).
//   vW/vB - running per-layer sums of min(||g||^2, gam^2) ('adago').
//   OW    - cached orth(MW[l]) step direction, refreshed every OW_K apply() calls.
//   bc    - apply() call counter.
this.GW=this.W.map(w=>new Float64Array(w.length));
this.GB=this.b.map(b=>new Float64Array(b.length));
this.MW=this.W.map(w=>new Float64Array(w.length));
this.MB=this.b.map(b=>new Float64Array(b.length));
this.vW=new Float64Array(this.L-1);
this.vB=new Float64Array(this.L-1);
this.OW=new Array(this.L-1).fill(null);
this.OW_K=4; this.bc=0;
// gamma: scale on the top-down feedback term in adaptive mode (other modes use 1).
// betaN: nudging strength of accum()'s +/- phases (0.9 when mode==='fhn', otherwise 0.5).
this.gamma=0.6; this.betaN=(mode==='fhn')?0.9:0.5;
// adpC/adpSteps: in adaptive mode accum() blends the nudged states toward the free state
// adpSteps times with weight adpC. tauInv: rate of the adaptation state w in relax()
// (0.0 keeps w at 0).
this.adpC=0.15; this.adpSteps=3; this.tauInv=0.0;
// rmin/escale: accum() scales each sample's gradient by r = rmin + (1-rmin)*min(1, loss/escale).
this.rmin=0.1; this.escale=0.4;
}
// Relax the network for `iters` explicit-Euler steps of size dt and return u, an array of
// per-layer Float32Array states. u[0] is clamped to x; every other layer starts at 0.1
// (each call starts from this fresh state).
// Each step sweeps layers 1..L-1 in order, updating in place, so layer l reads layer l-1
// as already updated in this step and layer l+1 as left by the previous step.
// When beta !== 0 and target is given, the output layer is nudged toward target.
relax(x, iters, dt, beta=0, target=null){
const sz=this.sizes, L=this.L, ad=(this.mode==='adaptive');
const rho=ad?sg:fhnRho, rhop=ad?sgp:fhnRhoP;
const u=sz.map(n=>new Float32Array(n)); u[0].set(x); for(let l=1;l<L;l++) u[l].fill(0.1);
// Adaptation state (adaptive mode only), all zeros initially; w[0] is allocated but never read.
const w=ad? sz.map(n=>new Float32Array(n)) : null;
for(let t=0;t<iters;t++){
for(let l=1;l<L;l++){
const Wlm=this.W[l-1], blm=this.b[l-1], ni=sz[l-1], no=sz[l];
const Wl=(l<L-1)?this.W[l]:null, nip1=(l<L-1)?sz[l+1]:0;
const ulm=u[l-1], ul=u[l], ulp=(l<L-1)?u[l+1]:null, wl=ad?w[l]:null;
for(let i=0;i<no;i++){
// Net input c = bias + bottom-up W[l-1]·rho(u[l-1]) + top-down W[l]^T·rho(u[l+1])
// (no top-down term at the output layer; scaled by gamma in adaptive mode).
let c=blm[i]; for(let j=0;j<ni;j++) c+=Wlm[i*ni+j]*rho(ulm[j]);
if(Wl){ let td=0; for(let k=0;k<nip1;k++) td+=Wl[k*no+i]*rho(ulp[k]); c+= ad? this.gamma*td : td; }
let drive;
// Adaptive: leaky integration toward sg(c) minus adaptation w; the nudge uses target - u.
// w then moves toward sg(u) (with the freshly updated u) at rate tauInv.
if(ad){ drive = -ul[i] + sg(c) - wl[i];
if(beta!==0 && l===L-1 && target) drive += beta*(target[i]-ul[i]);
ul[i] += dt*drive; wl[i] += dt*this.tauInv*(sg(ul[i])-wl[i]); }
// Other modes: drive = rho'(u)*c + fhnF(u); the nudge uses target - rho(u);
// the new state is clipped to [-0.2, 1.2].
else{ drive = rhop(ul[i])*c + fhnF(ul[i]);
if(beta!==0 && l===L-1 && target) drive += beta*(target[i]-rho(ul[i]));
let nu=ul[i]+dt*drive; if(nu<-0.2)nu=-0.2; else if(nu>1.2)nu=1.2; ul[i]=nu; }
}
}
}
return u;
}
// Output-layer state after a free (unnudged) relaxation.
outputs(x,iters,dt){ return this.relax(x,iters,dt,0,null)[this.L-1]; }
// Index of the largest output state; the first index wins ties.
predict(x,iters,dt){ const o=this.outputs(x,iters,dt); let bi=0,bv=-1e9; for(let i=0;i<o.length;i++) if(o[i]>bv){bv=o[i];bi=i;} return bi; }
// Add one sample's EqProp gradient estimate into gW/gB (which are not cleared here) and
// return the free-phase loss.
//   1. Free phase (beta=0, `iters` steps): loss = sum_i (rho(o_i) - onehot_i)^2 and a
//      weight r = rmin + (1-rmin)*min(1, loss/escale), i.e. in [rmin, 1] for finite loss.
//   2. Two nudged phases (+betaN and -betaN, `nIters` steps each), each relaxed from the
//      fresh initial state rather than from the free-phase state.
//   3. Adaptive mode only: both nudged states are blended toward the free state.
//   4. Symmetric difference: gW += r*(rho_i^+ rho_j^+ - rho_i^- rho_j^-)/(2*betaN), gB likewise.
// NOTE(review): assumes 0 <= label < output width; an out-of-range label is silently ignored
// by the typed-array write, giving an all-zero target.
accum(x,label,iters,nIters,dt,gW,gB){
const L=this.L, sz=this.sizes, ad=(this.mode==='adaptive'); const rho=ad?sg:fhnRho;
const tgt=new Float32Array(sz[L-1]); tgt[label]=1; const bN=this.betaN;
const uf=this.relax(x,iters,dt,0,null);
const o=uf[L-1]; let loss=0; for(let i=0;i<o.length;i++){ const d=rho(o[i])-tgt[i]; loss+=d*d; }
const r=this.rmin+(1-this.rmin)*Math.min(1, loss/this.escale);
const up=this.relax(x,nIters,dt,+bN,tgt);
const um=this.relax(x,nIters,dt,-bN,tgt);
if(ad){ const c=this.adpC; for(let a=0;a<this.adpSteps;a++) for(let l=1;l<L;l++) for(let i=0;i<sz[l];i++){
up[l][i]=(1-c)*up[l][i]+c*uf[l][i]; um[l][i]=(1-c)*um[l][i]+c*uf[l][i];
}}
for(let l=0;l<L-1;l++){
const ni=sz[l], no=sz[l+1]; const rip=up[l+1], rim=um[l+1], rjp=up[l], rjm=um[l];
for(let i=0;i<no;i++){
const a=rho(rip[i]), cc=rho(rim[i]);
gB[l][i]+= r*(a-cc)/(2*bN);
for(let j=0;j<ni;j++) gW[l][i*ni+j]+= r*(a*rho(rjp[j])-cc*rho(rjm[j]))/(2*bN);
}
}
return loss;
}
// Consume summed gradients gW/gB (as produced by accum) for a batch of bs samples with
// learning rate lr. Every branch adds its step to W/b (W += step), following the sign of
// accum()'s estimate. gW/gB are only read, never cleared.
apply(gW,gB,bs,lr){
// eps: floor on step sizes and denominator guard; mu: momentum EMA coefficient;
// gam: cap applied to the per-layer gradient norm.
const eps=1e-8, mu=0.9, gam=1.0; this.bc++;
for(let l=0;l<this.L-1;l++){
const ni=this.sizes[l], no=this.sizes[l+1], W=this.W[l], B=this.b[l];
if(this.opt==='sgd'){
for(let k=0;k<W.length;k++) W[k]+=lr*gW[l][k]/bs;
for(let k=0;k<B.length;k++) B[k]+=lr*gB[l][k]/bs;
} else if(this.opt==='adagrad'){
for(let k=0;k<W.length;k++){ const g=gW[l][k]/bs; this.GW[l][k]+=g*g; W[k]+=lr*g/(Math.sqrt(this.GW[l][k])+1e-6); }
for(let k=0;k<B.length;k++){ const g=gB[l][k]/bs; this.GB[l][k]+=g*g; B[k]+=lr*g/(Math.sqrt(this.GB[l][k])+1e-6); }
} else { // adago
// Weights: EMA momentum MW; step direction = orth(MW) (singular values ~1), computed on the
// first call and whenever bc is a multiple of OW_K, otherwise the cached factor is reused.
// Step size alpha = lr*min(||g||,gam)/(sqrt(vW)+eps), floored at eps.
let gn2=0; for(let k=0;k<W.length;k++){ const g=gW[l][k]/bs; this.MW[l][k]=mu*this.MW[l][k]+(1-mu)*g; gn2+=g*g; }
const gn=Math.sqrt(gn2); this.vW[l]+=Math.min(gn2,gam*gam);
if(!this.OW[l] || this.bc%this.OW_K===0) this.OW[l]=orth(this.MW[l],no,ni);
const O=this.OW[l];
const alpha=Math.max(eps, lr*Math.min(gn,gam)/(Math.sqrt(this.vW[l])+eps));
for(let k=0;k<W.length;k++) W[k]+= alpha*O[k];
// Bias: same step-size rule; the direction is MB divided by the current bias-gradient norm bn.
// NOTE(review): bn is not ||MB||, so the bias direction need not be unit-length; left as-is
// because this file is the validation reference.
let bn2=0; for(let k=0;k<B.length;k++){ const g=gB[l][k]/bs; this.MB[l][k]=mu*this.MB[l][k]+(1-mu)*g; bn2+=g*g; }
const bn=Math.sqrt(bn2); this.vB[l]+=Math.min(bn2,gam*gam);
const ba=Math.max(eps, lr*Math.min(bn,gam)/(Math.sqrt(this.vB[l])+eps));
for(let k=0;k<B.length;k++) B[k]+= ba*(bn>0? this.MB[l][k]/bn : 0);
}
}
}
// Zeroed [gW, gB] Float64Array accumulators shaped like W and b, for accum()/apply().
newGradBuffers(){ return [this.W.map(w=>new Float64Array(w.length)), this.b.map(b=>new Float64Array(b.length))]; }
}
// MNIST loader (browser side: fetches the JSON pack served from disk).
/**
 * Load a dataset pack: JSON whose byte arrays are base64 strings.
 * Fields read: R (image side length), ntr/nte (sample counts), xtr/ytr/xte/yte
 * (base64 bytes) and optional Cin (channel count, default 1).
 * @param {string|URL} url
 * @returns {Promise<{Cin:number, R:number, D:number, ntr:number, nte:number,
 *   xtr:Uint8Array, ytr:Uint8Array, xte:Uint8Array, yte:Uint8Array}>}
 * @throws {Error} when the HTTP response is not ok, or a base64 field is missing/not a string.
 */
export async function loadMnist(url){
  const res = await fetch(url);
  // fetch() resolves (does not reject) on HTTP errors; without this check a 404 page
  // surfaces later as an opaque JSON parse error that never names the URL.
  if (!res.ok) {
    throw new Error(`loadMnist: fetching ${url} failed with HTTP ${res.status} ${res.statusText}`);
  }
  const j = await res.json();
  // atob() coerces its argument to a string: a missing field throws an unrelated encoding
  // error, and null decodes "null" into 3 garbage bytes. Fail loudly with the field name.
  for (const key of ['xtr', 'ytr', 'xte', 'yte']) {
    if (typeof j?.[key] !== 'string') {
      throw new Error(`loadMnist: pack at ${url} is missing base64 field "${key}"`);
    }
  }
  const dec = (b64)=>{ const bin=atob(b64); const u=new Uint8Array(bin.length); for(let i=0;i<bin.length;i++) u[i]=bin.charCodeAt(i); return u; };
  // Cin defaults to 1 (single-channel, back-compat with MNIST/Fashion).
  // For RGB (CIFAR) packs, Cin=3. Layout is Cin-major: [R-plane (R*R) || G-plane || B-plane] per sample.
  const Cin = j.Cin || 1;
  return {Cin, R:j.R, D:Cin*j.R*j.R, ntr:j.ntr, nte:j.nte, xtr:dec(j.xtr), ytr:dec(j.ytr), xte:dec(j.xte), yte:dec(j.yte)};
}
/**
 * Sample i of the training split (set === 'tr') or, for any other `set`, the test split,
 * as a Float32Array of MN.D values, each byte divided by 255.
 */
export function imgF32(MN, set, i){
  const { D } = MN;
  const pixels = set === 'tr' ? MN.xtr : MN.xte;
  const start = i * D;
  return Float32Array.from({ length: D }, (_, p) => pixels[start + p] / 255);
}