// cortex-conv / tests / gpu_lib_deep.js
// Ex0bit's picture
// initial deployment: cortex-conv ships pre-trained at 96.8% MNIST
// d035fbd
// Multi-hidden-layer WebGPU EqProp trainer.
// sizes = [D, H1, H2, ..., Hk, O]
// Modes:
// 0 = adaptive (σ everywhere, u-nudge at output)
// 1 = fhn (clip everywhere, ρ-nudge at output)
// 2 = prism (smooth soft-clip via softplus everywhere, ρ-nudge at output)
//
// Key design:
// * ONE generic compute pipeline `pass_layer` updates any layer using uniforms (ni, no, nxt, ...).
// * 3 phases (free / plus / minus): each has its own per-layer state buffer + per-layer bind group.
// * Uniform buffers are pre-written once per pass (one per (phase, layer) pair), so the encoder
// records dispatches that each pick up the right uniforms.
// * Gradient kernels are also generic — per layer transition, compute outer product reduction over batch.
import { orth as orthCPU } from './eqprop_lib.js';
// ----- WGSL: relax (one kernel handles any layer) -----
// Shader source for the relaxation step. It exposes two compute entry points:
//   * pass_layer  (workgroup 8x8): one invocation per (neuron i = gid.x, batch row b = gid.y).
//     Computes the bottom-up drive from Win/W0/b0, optionally adds the top-down drive from
//     W1/Uo, optionally clamps the drive, then integrates Uh[b, i] by one step.
//   * init_state  (workgroup 64): sets every element of Uh (B*no values) to 0.1.
// The uniform struct P occupies 16 32-bit slots (64 bytes).
// Notes grounded in the shader text below:
//   * The integration step size is Tau[i]; the uniform field p.dt is declared but never read.
//   * For mode > 0.5 (fhn, prism) the top-down term is added unscaled; p.gamma only scales it
//     in adaptive mode.
//   * The noise term is a function of (b, i, p.iter_seed) only, so it changes between
//     dispatches only if the host rewrites iter_seed.
//   * u_new is clamped to [-0.3, 1.3] in prism mode and [-0.2, 1.2] in fhn mode; adaptive mode
//     does not clamp the state.
const WGSL_RELAX = `
struct P {
ni : u32, no : u32, nxt : u32, B : u32,
dt : f32, beta : f32, gamma : f32, mode : f32, // mode: 0=adaptive σ, 1=fhn(clip+cubic), 2=prism (soft-clip via softplus)
has_topdown : u32, has_target : u32, noise_scale : f32, iter_seed : u32,
// sEqProp: noise_scale > 0 injects per-iter per-neuron Gaussian-ish noise into drive.
// Bio-faithful (real synapses are stochastic). At test, run M passes & average outputs.
clamp_lo : f32, clamp_hi : f32, _p_t1 : f32, _p_t2 : f32,
// Tier A — pre-σ drive clamp (algorithmic, uniform-driven). Bounds the pre-activation
// value c before σ(c) to prevent saturation runaway. When clamp_hi <= clamp_lo the
// kernel treats it as DISABLED (no-op). Default in caller = clamp_lo=clamp_hi=0 → disabled.
};
@group(0) @binding(0) var<uniform> p : P;
@group(0) @binding(1) var<storage, read> Win : array<f32>;
@group(0) @binding(2) var<storage, read> W0 : array<f32>; // [no x ni]
@group(0) @binding(3) var<storage, read> b0 : array<f32>; // [no]
@group(0) @binding(4) var<storage, read> W1 : array<f32>; // [nxt x no] (top-down)
@group(0) @binding(5) var<storage, read_write> Uh : array<f32>; // [B*no]
@group(0) @binding(6) var<storage, read> Uo : array<f32>; // [B*nxt]
@group(0) @binding(7) var<storage, read> Tgt : array<f32>; // [B*no]
// HPSN: heterogeneous time constants — per-neuron multiplier on drive integration.
// Tau[i] replaces the global p.dt. Constant Tau[i]=p.dt → behavior identical to scalar-dt EqProp.
// Sampled from Uniform[tau_min, tau_max] → diverse temporal scales like real cortical neurons.
@group(0) @binding(8) var<storage, read> Tau : array<f32>; // [no]
const A1 : f32 = 0.07407407407407407;
const PRISM_K : f32 = 10.0; // sharpness; higher k → harder clip
// PCG-style cheap hash → uniform [0,1). Per-thread, per-iter, per-neuron stochasticity.
fn pcg_hash(seed_in: u32) -> u32 {
var state : u32 = seed_in * 747796405u + 2891336453u;
let word : u32 = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u;
return (word >> 22u) ^ word;
}
fn unif_noise(b: u32, i: u32, t: u32) -> f32 {
// Triangular distribution (sum of 2 uniforms - 1) ≈ Gaussian-ish, mean=0, variance=1/6.
let h1 = pcg_hash(b * 65537u + i * 257u + t * 31u);
let h2 = pcg_hash(b * 31337u + i * 1031u + t * 17u + 12345u);
let u1 = f32(h1) / 4294967296.0;
let u2 = f32(h2) / 4294967296.0;
return (u1 + u2) - 1.0; // range [-1, 1], roughly triangular
}
fn sg(u: f32) -> f32 { return 1.0 / (1.0 + exp(-4.0 * (u - 0.5))); }
fn fhn_rho(u: f32) -> f32 { return clamp(u, 0.0, 1.0); }
fn fhn_rho_p(u: f32) -> f32 { return select(0.0, 1.0, u > 0.0 && u < 1.0); }
fn fhn_f(u: f32) -> f32 { return A1 * u - u*u*u; }
// PRISM activation: ρ(u) = (softplus(k·u) - softplus(k·(u-1))) / k
// Smooth approximation of clip(u,0,1). Derivative: σ(k·u) - σ(k·(u-1)).
// "Prism" = splits drive into a smooth-yet-saturating activation with gradient flow on both sides.
fn softplus(x: f32) -> f32 { return select(x + log(1.0 + exp(-x)), log(1.0 + exp(x)), x <= 0.0); }
fn prism_rho(u: f32) -> f32 {
return (softplus(PRISM_K * u) - softplus(PRISM_K * (u - 1.0))) / PRISM_K;
}
fn prism_rho_p(u: f32) -> f32 {
let k = PRISM_K;
return 1.0/(1.0+exp(-k*u)) - 1.0/(1.0+exp(-k*(u-1.0)));
}
fn rho(u: f32) -> f32 {
if (p.mode > 1.5) { return prism_rho(u); }
if (p.mode > 0.5) { return fhn_rho(u); }
return sg(u);
}
@compute @workgroup_size(8, 8) fn pass_layer(@builtin(global_invocation_id) gid: vec3<u32>) {
let b = gid.y; let i = gid.x;
if (b >= p.B || i >= p.no) { return; }
// bottom-up: c = b0[i] + sum_j W0[i,j] * rho(Win[b,j])
var c : f32 = b0[i];
let row0 = i * p.ni;
let xoff = b * p.ni;
for (var j: u32 = 0u; j < p.ni; j = j + 1u) {
c = c + W0[row0 + j] * rho(Win[xoff + j]);
}
// top-down: gamma * sum_k W1[k,i] * rho(Uo[b,k]) (if next layer exists)
if (p.has_topdown != 0u) {
var td : f32 = 0.0;
let uo_off = b * p.nxt;
for (var k: u32 = 0u; k < p.nxt; k = k + 1u) {
td = td + W1[k * p.no + i] * rho(Uo[uo_off + k]);
}
if (p.mode > 0.5) { c = c + td; } else { c = c + p.gamma * td; }
}
// Tier A — pre-σ drive clamp (algorithmic). Active iff clamp_hi > clamp_lo.
if (p.clamp_hi > p.clamp_lo) {
c = clamp(c, p.clamp_lo, p.clamp_hi);
}
let idx = b * p.no + i;
let u_old = Uh[idx];
// sEqProp noise: per-(b, i, iter_seed) triangular noise added to drive. Zero by default.
let noise = select(0.0, p.noise_scale * unif_noise(b, i, p.iter_seed), p.noise_scale > 0.0);
var u_new : f32;
if (p.mode > 1.5) {
// PRISM: u̇ = ρ'(u)·c + (linear pull) ; ρ-nudge for output. Smooth saturating dynamics.
var drive : f32 = prism_rho_p(u_old) * c - 0.1 * (u_old - 0.5) + noise;
if (p.has_target != 0u && p.beta != 0.0) {
drive = drive + p.beta * (Tgt[idx] - prism_rho(u_old));
}
u_new = u_old + Tau[i] * drive;
u_new = clamp(u_new, -0.3, 1.3);
} else if (p.mode > 0.5) {
// FHN
var drive : f32 = fhn_rho_p(u_old) * c + fhn_f(u_old) + noise;
if (p.has_target != 0u && p.beta != 0.0) {
drive = drive + p.beta * (Tgt[idx] - fhn_rho(u_old));
}
u_new = u_old + Tau[i] * drive;
u_new = clamp(u_new, -0.2, 1.2);
} else {
// Adaptive
var drive : f32 = -u_old + sg(c) + noise;
if (p.has_target != 0u && p.beta != 0.0) {
drive = drive + p.beta * (Tgt[idx] - u_old);
}
u_new = u_old + Tau[i] * drive;
}
Uh[idx] = u_new;
}
// 2D dispatch to handle big buffers (B*no can exceed the 65535 per-dim workgroup limit).
@compute @workgroup_size(64) fn init_state(@builtin(global_invocation_id) gid: vec3<u32>) {
let stride = 65535u * 64u; // workgroups_per_X * threads_per_workgroup
let g = gid.y * stride + gid.x;
let n = p.B * p.no;
if (g < n) { Uh[g] = 0.1; }
}
`;
// ----- WGSL: gradient (one kernel handles any layer transition) -----
// Shader source for the contrastive gradient estimate. Entry points:
//   * grad_W (workgroup 8x8): one invocation per weight (i = gid.y output row, j = gid.x input col).
//     gW[i*ni + j] = (1 / two_beta) * sum_b R[b] * (ρ(post+)·ρ(pre+) − ρ(post−)·ρ(pre−)).
//   * grad_B (workgroup 64): one invocation per output neuron i.
//     gB[i] = (1 / two_beta) * sum_b R[b] * (ρ(post+) − ρ(post−)).
// Both kernels overwrite (not accumulate into) their output element.
// Mode encoding actually implemented by rho_mode: m <= 0.5 → σ, (0.5, 1.5] → hard clip,
// (1.5, 2.5] → prism soft-clip, > 2.5 → identity. The inline struct comment that says
// "2=identity" predates the prism branch and is stale.
// two_beta is used as a divisor: a host-side beta of 0 yields non-finite gradients.
// The uniform fields `c` and `_pad` are declared but not read by either kernel.
const WGSL_GRAD = `
struct GP {
ni : u32, no : u32, _pad : u32, B : u32,
c : f32, two_beta : f32, mode_pre : f32, mode_post : f32, // mode_pre/post: 0=σ, 1=clip, 2=identity
};
@group(0) @binding(0) var<uniform> p : GP;
@group(0) @binding(1) var<storage, read> UpreP : array<f32>; // [B*ni] - "input" layer state, plus phase
@group(0) @binding(2) var<storage, read> UpreM : array<f32>; // [B*ni] - minus
@group(0) @binding(3) var<storage, read> UpostP: array<f32>; // [B*no]
@group(0) @binding(4) var<storage, read> UpostM: array<f32>; // [B*no]
@group(0) @binding(5) var<storage, read> R : array<f32>; // [B]
@group(0) @binding(6) var<storage, read_write> gW : array<f32>; // [no*ni]
@group(0) @binding(7) var<storage, read_write> gB : array<f32>; // [no]
fn sg(u: f32) -> f32 { return 1.0 / (1.0 + exp(-4.0 * (u - 0.5))); }
const PRISM_K2 : f32 = 10.0;
fn softplus2(x: f32) -> f32 { return select(x + log(1.0 + exp(-x)), log(1.0 + exp(x)), x <= 0.0); }
fn prism_rho_g(u: f32) -> f32 { return (softplus2(PRISM_K2*u) - softplus2(PRISM_K2*(u-1.0))) / PRISM_K2; }
fn rho_mode(u: f32, m: f32) -> f32 {
if (m > 2.5) { return u; } // identity (linear)
if (m > 1.5) { return prism_rho_g(u); } // prism soft-clip
if (m > 0.5) { return clamp(u, 0.0, 1.0); } // hard-clip (FHN)
return sg(u); // σ (adaptive)
}
@compute @workgroup_size(8, 8) fn grad_W(@builtin(global_invocation_id) gid: vec3<u32>) {
let i = gid.y; let j = gid.x;
if (i >= p.no || j >= p.ni) { return; }
var acc : f32 = 0.0;
for (var b: u32 = 0u; b < p.B; b = b + 1u) {
let rh = R[b];
let ip = rho_mode(UpostP[b * p.no + i], p.mode_post);
let im = rho_mode(UpostM[b * p.no + i], p.mode_post);
let jp = rho_mode(UpreP[b * p.ni + j], p.mode_pre);
let jm = rho_mode(UpreM[b * p.ni + j], p.mode_pre);
acc = acc + rh * (ip * jp - im * jm);
}
gW[i * p.ni + j] = acc / p.two_beta;
}
@compute @workgroup_size(64) fn grad_B(@builtin(global_invocation_id) gid: vec3<u32>) {
let i = gid.x;
if (i >= p.no) { return; }
var acc : f32 = 0.0;
for (var b: u32 = 0u; b < p.B; b = b + 1u) {
let rh = R[b];
let ip = rho_mode(UpostP[b * p.no + i], p.mode_post);
let im = rho_mode(UpostM[b * p.no + i], p.mode_post);
acc = acc + rh * (ip - im);
}
gB[i] = acc / p.two_beta;
}
`;
// ----- WGSL: reward + adaptation (depends on output layer state) -----
// Shader source for the two auxiliary kernels:
//   * compute_reward (workgroup 64): one invocation per batch row b. Computes the squared
//     error between ρ(free-phase output) and the target, scales it by 1/0.4, caps it at 1 and
//     maps it into [0.1, 1.0]; the result is written to R[b].
//   * adapt_layer (workgroup 64, 2D-indexed): blends the plus/minus states toward the free
//     state, Up,Um ← (1-c)*Up + c*Uf, for every element of Uf.
// The output activation must match the one used by the relax and gradient kernels, so
// rho_out dispatches on all three modes: σ (adaptive), hard clip (fhn) and the softplus
// soft-clip (prism). Previously prism (mode 2) fell through to the hard-clip branch, so the
// reward was measured with a different activation than the one being trained.
const WGSL_AUX = `
struct AP {
B : u32, O : u32, H_max : u32, n_hidden : u32,
c : f32, mode : f32, _p0 : f32, _p1 : f32,
};
@group(0) @binding(0) var<uniform> p : AP;
@group(0) @binding(1) var<storage, read> UoF : array<f32>;
@group(0) @binding(2) var<storage, read> Tgt : array<f32>;
@group(0) @binding(3) var<storage, read_write> R : array<f32>;
// adaptation buffers (variable size; we pass single layer at a time via separate bind groups)
@group(0) @binding(4) var<storage, read> Uf : array<f32>;
@group(0) @binding(5) var<storage, read_write> Up : array<f32>;
@group(0) @binding(6) var<storage, read_write> Um : array<f32>;
fn sg(u: f32) -> f32 { return 1.0 / (1.0 + exp(-4.0 * (u - 0.5))); }
// PRISM soft-clip, same formula and sharpness (k = 10) as the relax and gradient kernels.
const PRISM_K_OUT : f32 = 10.0;
fn softplus_out(x: f32) -> f32 { return select(x + log(1.0 + exp(-x)), log(1.0 + exp(x)), x <= 0.0); }
fn prism_rho_out(u: f32) -> f32 {
return (softplus_out(PRISM_K_OUT * u) - softplus_out(PRISM_K_OUT * (u - 1.0))) / PRISM_K_OUT;
}
fn rho_out(u: f32) -> f32 {
if (p.mode > 1.5) { return prism_rho_out(u); }
if (p.mode > 0.5) { return clamp(u, 0.0, 1.0); }
return sg(u);
}
@compute @workgroup_size(64) fn compute_reward(@builtin(global_invocation_id) gid: vec3<u32>) {
let b = gid.x;
if (b >= p.B) { return; }
var loss : f32 = 0.0;
let off = b * p.O;
for (var i: u32 = 0u; i < p.O; i = i + 1u) {
let d = rho_out(UoF[off + i]) - Tgt[off + i];
loss = loss + d * d;
}
let escale : f32 = 0.4;
let rmin : f32 = 0.1;
var r : f32 = loss / escale;
if (r > 1.0) { r = 1.0; }
R[b] = rmin + (1.0 - rmin) * r;
}
// Adjusted Adaptation per layer: Up,Um ← (1-c)*Up + c*Uf. 2D dispatch safe for large buffers.
@compute @workgroup_size(64) fn adapt_layer(@builtin(global_invocation_id) gid: vec3<u32>) {
let stride = 65535u * 64u;
let g = gid.y * stride + gid.x;
if (g >= arrayLength(&Uf)) { return; }
let f = Uf[g];
Up[g] = (1.0 - p.c) * Up[g] + p.c * f;
Um[g] = (1.0 - p.c) * Um[g] + p.c * f;
}
`;
// ----- JS: trainer class -----
/**
 * Acquires a WebGPU adapter and device, requesting the adapter's own maximum for a fixed
 * set of limits (storage buffers per stage, buffer sizes, workgroup sizes, bind groups).
 *
 * @param {object} [opts]
 * @param {string} [opts.powerPreference='high-performance'] Forwarded to `requestAdapter`.
 * @returns {Promise<{adapter: object, dev: object, info: object}>} The adapter, the device and
 *   `adapter.info` (or `{}` when the adapter exposes none).
 * @throws {Error} 'no webgpu' when `navigator.gpu` is unavailable — including runtimes that
 *   have no `navigator` global at all; 'no adapter' when no adapter is returned.
 */
export async function makeGPUDeep({powerPreference='high-performance'}={}){
  // Go through globalThis so a missing `navigator` global yields the documented error
  // rather than a ReferenceError/TypeError.
  const gpu = globalThis.navigator?.gpu;
  if(!gpu) throw new Error('no webgpu');
  const adapter = await gpu.requestAdapter({powerPreference});
  if(!adapter) throw new Error('no adapter');
  const limitKeys = [
    'maxStorageBuffersPerShaderStage', 'maxBufferSize', 'maxStorageBufferBindingSize',
    'maxComputeInvocationsPerWorkgroup', 'maxComputeWorkgroupSizeX',
    'maxComputeWorkgroupStorageSize', 'maxBindGroups',
  ];
  // Only forward limits the adapter actually reports as numbers.
  const requiredLimits = {};
  for(const key of limitKeys){
    const value = adapter.limits[key];
    if(typeof value === 'number') requiredLimits[key] = value;
  }
  const dev = await adapter.requestDevice({requiredLimits});
  return {adapter, dev, info: adapter.info||{}};
}
// Phase indices used to address the per-phase arrays (state buffers, relax uniforms, bind groups).
const PHASE_F = 0; // free phase (beta = 0)
const PHASE_P = 1; // plus phase (+beta nudge)
const PHASE_M = 2; // minus phase (-beta nudge)
export class GPUTrainerDeep {
// sizes: [D, H1, H2, ..., Hk, O] — len L+1; L = number of weight matrices = sizes.length-1
// mode: 'adaptive' | 'fhn' | 'prism'
// driveClampLo, driveClampHi: Tier A — pre-σ drive clamp; ACTIVE iff hi > lo. Default 0,0 = disabled.
constructor({dev, sizes, B, mode='adaptive', gamma=0.6, hpsnTauMin=0, hpsnTauMax=0, hpsnSeed=42, driveClampLo=0, driveClampHi=0}={}){
this.dev = dev; this.sizes = sizes;
this.L = sizes.length - 1; // number of weight matrices (transitions)
this.B = B; this.O = sizes[sizes.length-1];
this.mode = mode;
this.modeFlag = (mode==='prism') ? 2.0 : (mode==='fhn' ? 1.0 : 0.0);
this.gamma = gamma;
this.hpsnTauMin = hpsnTauMin;
this.hpsnTauMax = hpsnTauMax;
this.hpsnSeed = hpsnSeed;
this.useHPSN = (hpsnTauMax > hpsnTauMin && hpsnTauMin > 0);
this.driveClampLo = driveClampLo;
this.driveClampHi = driveClampHi;
this._build();
// Initialize Tau buffers — either constant=0.7 (backward compat) or per-neuron Uniform[hpsnTauMin, hpsnTauMax].
if(this.useHPSN){
this.setAllTau(0.7, hpsnTauMin, hpsnTauMax, hpsnSeed);
} else {
this.setAllTau(0.7);
}
}
_F32buf(n, usage){
if(!Number.isFinite(n) || n <= 0){
console.error('BAD _F32buf size', {n, sizes:JSON.stringify(this.sizes), sizesArr:this.sizes, B:this.B, L:this.L, S0:this.sizes&&this.sizes[0], S0type:typeof (this.sizes&&this.sizes[0])});
throw new Error('_F32buf called with non-finite n=' + n + ' sizes=' + JSON.stringify(this.sizes));
}
return this.dev.createBuffer({size:Math.max(4,n*4), usage});
}
// Allocates every GPU resource the trainer needs, in this order: data/weight/Tau/state buffers,
// reward and dummy buffers, gradient and readback buffers, the three shader modules with their
// bind-group layouts and compute pipelines, the uniform buffers, and finally all bind groups.
// Reads this.dev, this.sizes, this.B and this.L; stores everything it creates on `this`.
// Takes no arguments and returns nothing.
_build(){
const dev = this.dev, S = this.sizes, B = this.B, L = this.L;
// NOTE: RW and R below resolve to the same flag set (STORAGE | COPY_SRC | COPY_DST); the two
// names only document intent (shader-writable vs. shader-read-only binding).
const RW = GPUBufferUsage.STORAGE|GPUBufferUsage.COPY_SRC|GPUBufferUsage.COPY_DST;
const R = GPUBufferUsage.STORAGE|GPUBufferUsage.COPY_DST|GPUBufferUsage.COPY_SRC;
const UNI = GPUBufferUsage.UNIFORM|GPUBufferUsage.COPY_DST;
// Readback buffers: copy destination that can be mapped for reading on the host.
const RDS = GPUBufferUsage.COPY_DST|GPUBufferUsage.MAP_READ;
// input [B x S[0]] + target [B x S[L]] (shared across phases)
this.bufWin = this._F32buf(B * S[0], R);
this.bufTgt = this._F32buf(B * S[L], R);
// weights [S[l+1] x S[l]] & biases [S[l+1]] (one pair per transition l)
this.bufW = []; this.bufB = [];
for(let l=0; l<L; l++){
this.bufW.push(this._F32buf(S[l+1]*S[l], R));
this.bufB.push(this._F32buf(S[l+1], R));
}
// HPSN: per-layer Tau buffer of size [no]. Contents are written later by setTau / setAllTau
// (this method only allocates; no method named setHeterogeneousTau exists in this class).
this.bufTau = [];
for(let l=0; l<L; l++){ this.bufTau.push(this._F32buf(S[l+1], R)); }
// state buffers: for each of 3 phases, L state buffers (one per non-input layer)
this.bufU = [[],[],[]]; // bufU[phase][l] is layer l+1's state (l=0..L-1, sizes S[1..L])
for(let phase=0; phase<3; phase++){
for(let l=1; l<=L; l++){
this.bufU[phase].push(this._F32buf(B * S[l], RW));
}
}
// reward + dummies (need separate buffers for read-only vs writable slots to avoid aliasing).
// Dummies fill bindings that a given kernel declares but does not use.
this.bufR = this._F32buf(B, RW);
this.bufDummyR = this._F32buf(4, R); // read-only dummy
this.bufDummyRW1 = this._F32buf(4, RW); // writable dummy slot 1
this.bufDummyRW2 = this._F32buf(4, RW); // writable dummy slot 2 (different buffer!)
this.bufDummyRW3 = this._F32buf(4, RW);
// gradient buffer (packed: all gW and gB together) plus the offset table describing it.
// NOTE(review): within this method bufG and rbG are allocated but never bound or copied;
// the gradient kernels write to the per-layer bufGW/bufGB created further down. Confirm
// whether anything outside this class still uses the packed buffer before removing it.
const gSizes = []; let total=0;
for(let l=0; l<L; l++){ gSizes.push({offW:total, sizW:S[l+1]*S[l], offB:total+S[l+1]*S[l], sizB:S[l+1], total:S[l+1]*S[l]+S[l+1]}); total += gSizes[l].total; }
this.gOff = gSizes; this.gTotal = total;
this.bufG = this._F32buf(total, RW);
this.rbG = dev.createBuffer({size: total*4, usage: RDS});
// readback for output free-phase Uo (for accuracy/loss)
this.rbUoF = dev.createBuffer({size: B * S[L] * 4, usage: RDS});
// ---- pipelines ----
// Relax pipeline (generic): both entry points share one layout with 9 bindings.
const modR = dev.createShaderModule({code: WGSL_RELAX});
// Bind-group-layout entry builders: read-only storage, read-write storage, uniform.
const sR = (i)=>({binding:i, visibility:GPUShaderStage.COMPUTE, buffer:{type:'read-only-storage'}});
const sRW= (i)=>({binding:i, visibility:GPUShaderStage.COMPUTE, buffer:{type:'storage'}});
const uN = (i)=>({binding:i, visibility:GPUShaderStage.COMPUTE, buffer:{type:'uniform'}});
this.bglR = dev.createBindGroupLayout({entries:[uN(0), sR(1), sR(2), sR(3), sR(4), sRW(5), sR(6), sR(7), sR(8)]});
this.plR = dev.createPipelineLayout({bindGroupLayouts:[this.bglR]});
this.pipeRelax = dev.createComputePipeline({layout:this.plR, compute:{module:modR, entryPoint:'pass_layer'}});
this.pipeInit = dev.createComputePipeline({layout:this.plR, compute:{module:modR, entryPoint:'init_state'}});
// Grad pipeline (generic): 8 bindings, the last two (gW, gB) writable.
const modG = dev.createShaderModule({code: WGSL_GRAD});
this.bglG = dev.createBindGroupLayout({entries:[uN(0), sR(1), sR(2), sR(3), sR(4), sR(5), sRW(6), sRW(7)]});
this.plG = dev.createPipelineLayout({bindGroupLayouts:[this.bglG]});
this.pipeGW = dev.createComputePipeline({layout:this.plG, compute:{module:modG, entryPoint:'grad_W'}});
this.pipeGB = dev.createComputePipeline({layout:this.plG, compute:{module:modG, entryPoint:'grad_B'}});
// Aux pipeline (reward + adaptation): 7 bindings; R, Up and Um are writable.
const modA = dev.createShaderModule({code: WGSL_AUX});
this.bglA = dev.createBindGroupLayout({entries:[uN(0), sR(1), sR(2), sRW(3), sR(4), sRW(5), sRW(6)]});
this.plA = dev.createPipelineLayout({bindGroupLayouts:[this.bglA]});
this.pipeReward = dev.createComputePipeline({layout:this.plA, compute:{module:modA, entryPoint:'compute_reward'}});
this.pipeAdapt = dev.createComputePipeline({layout:this.plA, compute:{module:modA, entryPoint:'adapt_layer'}});
// ---- uniform buffers (one per (phase, layer)) ----
// Each relax entry: 64 bytes (16 u32/f32 slots), matching what _writeRelaxParams uploads.
this.bufP_relax = [];
for(let phase=0; phase<3; phase++){
this.bufP_relax.push([]);
for(let l=1; l<=L; l++){
this.bufP_relax[phase].push(dev.createBuffer({size: 64, usage: UNI}));
}
}
// init uniform (one per layer, beta=0), same 64-byte layout as the relax uniforms
this.bufP_init = [];
for(let l=1; l<=L; l++) this.bufP_init.push(dev.createBuffer({size: 64, usage: UNI}));
// Grad uniform (one per layer transition), 32 bytes (8 slots)
this.bufP_grad = [];
for(let l=0; l<L; l++) this.bufP_grad.push(dev.createBuffer({size: 32, usage: UNI}));
// Aux uniforms: reward + per-layer adaptation, 32 bytes (8 slots) each
this.bufP_rew = dev.createBuffer({size: 32, usage: UNI});
this.bufP_adapt = []; for(let l=1; l<=L; l++) this.bufP_adapt.push(dev.createBuffer({size: 32, usage: UNI}));
// ---- bind groups ----
// Relax: per (phase, layer)
this.bgR = [[],[],[]];
for(let phase=0; phase<3; phase++){
for(let l=1; l<=L; l++){
// For layer l (1-indexed): Win = state[phase][l-2] if l>1 else bufWin
// W0 = bufW[l-1], b0 = bufB[l-1]
// W1 = bufW[l] (top-down weights to layer l+1), used if l<L
// Uh = state[phase][l-1]
// Uo = state[phase][l] (next layer), used if l<L
// Tgt = bufTgt (used if l==L)
// The output layer has no next layer, so W1 and Uo fall back to the read-only dummy.
const Win = (l===1) ? this.bufWin : this.bufU[phase][l-2];
const W0 = this.bufW[l-1], b0 = this.bufB[l-1];
const W1 = (l < L) ? this.bufW[l] : this.bufDummyR;
const Uh = this.bufU[phase][l-1];
const Uo = (l < L) ? this.bufU[phase][l] : this.bufDummyR;
const Tgt = this.bufTgt;
this.bgR[phase].push(dev.createBindGroup({layout: this.bglR, entries:[
{binding:0, resource:{buffer:this.bufP_relax[phase][l-1]}},
{binding:1, resource:{buffer:Win}},
{binding:2, resource:{buffer:W0}},
{binding:3, resource:{buffer:b0}},
{binding:4, resource:{buffer:W1}},
{binding:5, resource:{buffer:Uh}},
{binding:6, resource:{buffer:Uo}},
{binding:7, resource:{buffer:Tgt}},
{binding:8, resource:{buffer:this.bufTau[l-1]}},
]}));
}
}
// Init: per (phase, layer) — same resources as the relax groups except that binding 0 is
// the per-layer bufP_init uniform (beta=0) shared by all three phases.
this.bgInit = [[],[],[]];
for(let phase=0; phase<3; phase++){
for(let l=1; l<=L; l++){
const Win = (l===1) ? this.bufWin : this.bufU[phase][l-2];
const W0 = this.bufW[l-1], b0 = this.bufB[l-1];
const W1 = (l < L) ? this.bufW[l] : this.bufDummyR;
const Uh = this.bufU[phase][l-1];
const Uo = (l < L) ? this.bufU[phase][l] : this.bufDummyR;
const Tgt = this.bufTgt;
this.bgInit[phase].push(dev.createBindGroup({layout: this.bglR, entries:[
{binding:0, resource:{buffer:this.bufP_init[l-1]}},
{binding:1, resource:{buffer:Win}},
{binding:2, resource:{buffer:W0}},
{binding:3, resource:{buffer:b0}},
{binding:4, resource:{buffer:W1}},
{binding:5, resource:{buffer:Uh}},
{binding:6, resource:{buffer:Uo}},
{binding:7, resource:{buffer:Tgt}},
{binding:8, resource:{buffer:this.bufTau[l-1]}},
]}));
}
}
// Grad: per layer transition. Instead of binding sub-ranges of the packed bufG, each
// transition gets its own gW/gB buffer plus matching mappable readback buffers (simpler).
this.bufGW = []; this.bufGB = []; this.rbGW = []; this.rbGB = [];
for(let l=0; l<L; l++){
const gw = this._F32buf(S[l+1]*S[l], RW);
const gb = this._F32buf(S[l+1], RW);
this.bufGW.push(gw); this.bufGB.push(gb);
this.rbGW.push(dev.createBuffer({size: S[l+1]*S[l]*4, usage: RDS}));
this.rbGB.push(dev.createBuffer({size: S[l+1]*4, usage: RDS}));
}
// Per layer-transition bind groups for grad:
// Pre layer = state[l-1] (sizes[l]), Post layer = state[l] (sizes[l+1])
// For l=0 (input transition): Pre = bufWin for BOTH plus and minus, Post = state[0] (sizes[1])
// For l>0: Pre = state[l-1] (sizes[l]), Post = state[l] (sizes[l+1])
this.bgG = [];
for(let l=0; l<L; l++){
const UpreP = (l===0) ? this.bufWin : this.bufU[PHASE_P][l-1];
const UpreM = (l===0) ? this.bufWin : this.bufU[PHASE_M][l-1];
const UpostP = this.bufU[PHASE_P][l];
const UpostM = this.bufU[PHASE_M][l];
this.bgG.push(dev.createBindGroup({layout: this.bglG, entries:[
{binding:0, resource:{buffer:this.bufP_grad[l]}},
{binding:1, resource:{buffer:UpreP}},
{binding:2, resource:{buffer:UpreM}},
{binding:3, resource:{buffer:UpostP}},
{binding:4, resource:{buffer:UpostM}},
{binding:5, resource:{buffer:this.bufR}},
{binding:6, resource:{buffer:this.bufGW[l]}},
{binding:7, resource:{buffer:this.bufGB[l]}},
]}));
}
// Reward bind group (uses output layer free-phase state); the adaptation slots 4..6 are
// unused by compute_reward and get dummies.
const Uo_free = this.bufU[PHASE_F][L-1];
this.bgRew = dev.createBindGroup({layout: this.bglA, entries:[
{binding:0, resource:{buffer:this.bufP_rew}},
{binding:1, resource:{buffer:Uo_free}},
{binding:2, resource:{buffer:this.bufTgt}},
{binding:3, resource:{buffer:this.bufR}},
{binding:4, resource:{buffer:this.bufDummyR}},
{binding:5, resource:{buffer:this.bufDummyRW1}},
{binding:6, resource:{buffer:this.bufDummyRW2}},
]});
// Adaptation bind groups (per layer): adapt Up,Um toward Uf; the reward slots 1..3 are
// unused by adapt_layer and get dummies.
this.bgAdapt = [];
for(let l=1; l<=L; l++){
this.bgAdapt.push(dev.createBindGroup({layout: this.bglA, entries:[
{binding:0, resource:{buffer:this.bufP_adapt[l-1]}},
{binding:1, resource:{buffer:this.bufDummyR}},
{binding:2, resource:{buffer:this.bufDummyR}},
{binding:3, resource:{buffer:this.bufDummyRW3}},
{binding:4, resource:{buffer:this.bufU[PHASE_F][l-1]}},
{binding:5, resource:{buffer:this.bufU[PHASE_P][l-1]}},
{binding:6, resource:{buffer:this.bufU[PHASE_M][l-1]}},
]}));
}
}
_writeRelaxParams(buf, {ni, no, nxt, B, dt, beta, gamma, mode, has_topdown, has_target, noise_scale=0, iter_seed=0, clamp_lo=0, clamp_hi=0}){
const buf32 = new ArrayBuffer(64);
const u32 = new Uint32Array(buf32); const f32 = new Float32Array(buf32);
u32[0]=ni; u32[1]=no; u32[2]=nxt; u32[3]=B;
f32[4]=dt; f32[5]=beta; f32[6]=gamma; f32[7]=mode;
u32[8]=has_topdown; u32[9]=has_target;
f32[10]=noise_scale;
u32[11]=iter_seed;
f32[12]=clamp_lo; f32[13]=clamp_hi; f32[14]=0; f32[15]=0;
this.dev.queue.writeBuffer(buf, 0, buf32);
}
_writeGradParams(buf, {ni, no, B, two_beta, mode_pre, mode_post}){
const buf32 = new ArrayBuffer(32);
const u32 = new Uint32Array(buf32); const f32 = new Float32Array(buf32);
u32[0]=ni; u32[1]=no; u32[2]=0; u32[3]=B;
f32[4]=0; f32[5]=two_beta; f32[6]=mode_pre; f32[7]=mode_post;
this.dev.queue.writeBuffer(buf, 0, buf32);
}
_writeAuxParams(buf, {B, O, c, mode}){
const buf32 = new ArrayBuffer(32);
const u32 = new Uint32Array(buf32); const f32 = new Float32Array(buf32);
u32[0]=B; u32[1]=O; u32[2]=0; u32[3]=0;
f32[4]=c; f32[5]=mode; f32[6]=0; f32[7]=0;
this.dev.queue.writeBuffer(buf, 0, buf32);
}
uploadWeights(W, b){
const q = this.dev.queue;
for(let l=0; l<this.L; l++){
q.writeBuffer(this.bufW[l], 0, W[l].buffer, W[l].byteOffset, W[l].byteLength);
q.writeBuffer(this.bufB[l], 0, b[l].buffer, b[l].byteOffset, b[l].byteLength);
}
}
// HPSN: set per-neuron time constants. If tauMax > tauMin > 0, samples Uniform[tauMin, tauMax].
// Otherwise fills with constant scalarTau (backward-compat with old fixed-dt EqProp).
setTau(layerIdx, scalarTau, tauMin=0, tauMax=0, seed=42){
const no = this.sizes[layerIdx+1];
const arr = new Float32Array(no);
if(tauMax > tauMin && tauMin > 0){
// Deterministic LCG for reproducible per-neuron tau distribution.
let s = (seed>>>0) || 1;
const rng = ()=>{ s = (Math.imul(s, 1664525) + 1013904223) >>> 0; return s/4294967296; };
for(let i=0;i<no;i++) arr[i] = tauMin + rng()*(tauMax - tauMin);
} else {
arr.fill(scalarTau);
}
this.dev.queue.writeBuffer(this.bufTau[layerIdx], 0, arr.buffer, arr.byteOffset, arr.byteLength);
return arr; // return so caller can inspect distribution
}
// Convenience: set ALL layers to the same (scalar or distribution) tau spec.
setAllTau(scalarTau, tauMin=0, tauMax=0, seed=42){
for(let l=0; l<this.L; l++) this.setTau(l, scalarTau, tauMin, tauMax, seed + l*1000);
}
uploadInputs(X, T){
const q = this.dev.queue;
q.writeBuffer(this.bufWin, 0, X.buffer, X.byteOffset, X.byteLength);
q.writeBuffer(this.bufTgt, 0, T.buffer, T.byteOffset, T.byteLength);
}
_initAllPhases(enc){
const L = this.L;
const MAX_WG_X = 65535;
for(let phase=0; phase<3; phase++){
for(let l=1; l<=L; l++){
const n = this.B * this.sizes[l];
const wgTotal = Math.ceil(n/64);
// 2D dispatch when wgTotal exceeds per-dim limit
const wgX = Math.min(wgTotal, MAX_WG_X);
const wgY = Math.ceil(wgTotal / MAX_WG_X);
const pass = enc.beginComputePass();
pass.setPipeline(this.pipeInit);
pass.setBindGroup(0, this.bgInit[phase][l-1]);
pass.dispatchWorkgroups(wgX, wgY);
pass.end();
}
}
}
// CRITICAL: each layer update must be in its OWN compute pass so the GPU sees
// a barrier between writes to layer l's state and reads of that state by layer l+1.
// WebGPU has no implicit synchronization between dispatches within a single pass.
_runPhaseRelax(enc, phase, iters){
const L = this.L, B = this.B;
for(let t=0; t<iters; t++){
for(let l=1; l<=L; l++){
const pass = enc.beginComputePass();
pass.setPipeline(this.pipeRelax);
pass.setBindGroup(0, this.bgR[phase][l-1]);
const no = this.sizes[l];
pass.dispatchWorkgroups(Math.ceil(no/8), Math.ceil(B/8));
pass.end();
}
}
}
// Write all uniform buffers for the pass.
// noiseScale > 0 enables sEqProp; seedBase is added to iteration counter for per-call variation.
_writeAllUniformsForPass(dt, beta, noiseScale=0, seedBase=0){
const S=this.sizes, L=this.L, B=this.B, gam=this.gamma, mf=this.modeFlag;
const phaseBetas = [0, +beta, -beta]; // free, plus, minus
const ns = (typeof this.noiseScale === 'number') ? this.noiseScale : noiseScale;
const sb = (typeof this.noiseSeedBase === 'number') ? this.noiseSeedBase : seedBase;
const cLo = this.driveClampLo || 0;
const cHi = this.driveClampHi || 0;
// Relax uniforms (per phase, per layer). iter_seed is incremented per call below.
for(let phase=0; phase<3; phase++){
for(let l=1; l<=L; l++){
const isOut = (l === L);
const isHid = !isOut;
const ni = S[l-1], no = S[l], nxt = isHid ? S[l+1] : 0;
const phaseBeta = (isOut) ? phaseBetas[phase] : 0;
this._writeRelaxParams(this.bufP_relax[phase][l-1], {
ni, no, nxt, B, dt, beta: phaseBeta, gamma: gam, mode: mf,
has_topdown: isHid ? 1 : 0, has_target: isOut ? 1 : 0,
noise_scale: ns,
iter_seed: (sb + phase * 7919 + (l-1) * 1009) >>> 0,
clamp_lo: cLo, clamp_hi: cHi,
});
}
}
for(let l=1; l<=L; l++){
this._writeRelaxParams(this.bufP_init[l-1], {
ni: S[l-1], no: S[l], nxt: 0, B, dt, beta: 0, gamma: gam, mode: mf, has_topdown: 0, has_target: 0,
noise_scale: 0, iter_seed: 0, // init kernel doesn't use noise
clamp_lo: 0, clamp_hi: 0, // init kernel doesn't run drive — clamp irrelevant
});
}
}
// Tier A — runtime setter for drive clamp. Pass (0,0) to disable.
setDriveClamp(lo, hi){
this.driveClampLo = lo;
this.driveClampHi = hi;
}
// sEqProp: set per-pass noise scale and seed base. Call before runFreeAndReadOutputs / runOnePass.
setSEqPropNoise(noiseScale, seedBase){
this.noiseScale = noiseScale;
this.noiseSeedBase = (seedBase >>> 0) || 0;
}
_runReward(enc){
this._writeAuxParams(this.bufP_rew, {B: this.B, O: this.O, c: 0, mode: this.modeFlag});
const pass = enc.beginComputePass();
pass.setPipeline(this.pipeReward);
pass.setBindGroup(0, this.bgRew);
pass.dispatchWorkgroups(Math.ceil(this.B/64));
pass.end();
}
_runAdaptation(enc, adpC, adpSteps){
if(adpSteps <= 0 || this.mode === 'fhn') return; // skip adaptation in FHN mode
const L = this.L;
for(let l=1; l<=L; l++){
this._writeAuxParams(this.bufP_adapt[l-1], {B: this.B, O: this.O, c: adpC, mode: this.modeFlag});
}
const MAX_WG_X = 65535;
for(let a=0; a<adpSteps; a++){
for(let l=1; l<=L; l++){
const n = this.B * this.sizes[l];
const wgTotal = Math.ceil(n/64);
const wgX = Math.min(wgTotal, MAX_WG_X);
const wgY = Math.ceil(wgTotal / MAX_WG_X);
const pass = enc.beginComputePass();
pass.setPipeline(this.pipeAdapt);
pass.setBindGroup(0, this.bgAdapt[l-1]);
pass.dispatchWorkgroups(wgX, wgY);
pass.end();
}
}
}
_runGrad(enc, beta){
const L = this.L, B = this.B;
// Grad uniforms per layer transition
for(let l=0; l<L; l++){
const ni = this.sizes[l], no = this.sizes[l+1];
// Determine ρ modes:
// FHN: both pre/post are clip (1).
// Adaptive: σ for both EXCEPT input layer (l=0) where Win is treated with σ (mode_pre=0 always since adaptive).
const mode_pre = this.modeFlag; // 0 for adaptive, 1 for fhn — applies to all states
const mode_post = this.modeFlag;
this._writeGradParams(this.bufP_grad[l], {ni, no, B, two_beta: 2*beta, mode_pre, mode_post});
}
for(let l=0; l<L; l++){
const ni = this.sizes[l], no = this.sizes[l+1];
const pass = enc.beginComputePass();
pass.setPipeline(this.pipeGW);
pass.setBindGroup(0, this.bgG[l]);
pass.dispatchWorkgroups(Math.ceil(ni/8), Math.ceil(no/8));
pass.setPipeline(this.pipeGB);
pass.setBindGroup(0, this.bgG[l]);
pass.dispatchWorkgroups(Math.ceil(no/64));
pass.end();
}
}
async runFreeAndReadOutputs(iters, dt){
if(!this.useHPSN){
if(this._lastTauDt !== dt){ this.setAllTau(dt); this._lastTauDt = dt; }
}
this._writeAllUniformsForPass(dt, 0); // beta=0 → all phases free
const enc = this.dev.createCommandEncoder();
this._initAllPhases(enc);
this._runPhaseRelax(enc, PHASE_F, iters);
const O = this.O;
enc.copyBufferToBuffer(this.bufU[PHASE_F][this.L-1], 0, this.rbUoF, 0, this.B*O*4);
this.dev.queue.submit([enc.finish()]);
await this.rbUoF.mapAsync(GPUMapMode.READ);
const r = new Float32Array(this.rbUoF.getMappedRange().slice(0));
this.rbUoF.unmap();
return r;
}
// Free-phase relax, then read back the activations of an arbitrary internal layer (l in [1..L]).
async runFreeAndReadLayer(iters, dt, layerIdx){
if(!this.useHPSN){
if(this._lastTauDt !== dt){ this.setAllTau(dt); this._lastTauDt = dt; }
}
if(layerIdx < 1 || layerIdx > this.L) throw new Error('layerIdx out of range');
this._writeAllUniformsForPass(dt, 0);
const enc = this.dev.createCommandEncoder();
this._initAllPhases(enc);
this._runPhaseRelax(enc, PHASE_F, iters);
const size = this.B * this.sizes[layerIdx] * 4;
const rb = this.dev.createBuffer({size, usage: GPUBufferUsage.COPY_DST|GPUBufferUsage.MAP_READ});
enc.copyBufferToBuffer(this.bufU[PHASE_F][layerIdx-1], 0, rb, 0, size);
this.dev.queue.submit([enc.finish()]);
await rb.mapAsync(GPUMapMode.READ);
const r = new Float32Array(rb.getMappedRange().slice(0));
rb.unmap(); rb.destroy?.();
return r;
}
/**
 * Record, submit and read back one full training pass on the GPU.
 *
 * Encodes, in order: phase init, three relax phases (PHASE_F with `itF` iterations, then
 * PHASE_P and PHASE_M with `itN` iterations each), the reward step, the adaptation step and
 * the gradient step. It then copies every per-layer gradient buffer plus the PHASE_F state of
 * the last layer into the readback buffers, submits the work and waits for all maps.
 *
 * @param {object} [opts]
 * @param {number} [opts.itF=8]      Relax iterations for PHASE_F.
 * @param {number} [opts.itN=5]      Relax iterations for PHASE_P and for PHASE_M.
 * @param {number} [opts.dt=0.7]     Forwarded to `_writeAllUniformsForPass`; also pushed to
 *                                   `setAllTau` when `useHPSN` is falsy.
 * @param {number} [opts.beta=0.5]   Forwarded to `_writeAllUniformsForPass` and `_runGrad`
 *                                   (presumably the nudging strength; confirm against those helpers).
 * @param {number} [opts.adpC=0.15]  Forwarded to `_runAdaptation`.
 * @param {number} [opts.adpSteps=3] Forwarded to `_runAdaptation`; forced to 0 when `mode === 'fhn'`.
 * @returns {Promise<{gW: Float32Array[], gB: Float32Array[], uoF: Float32Array}>}
 *   CPU-side copies of the weight gradients (one entry per layer), the bias gradients
 *   (one entry per layer) and the PHASE_F state of the last layer.
 *
 * NOTE(review): the readback buffers (`rbGW`, `rbGB`, `rbUoF`) are shared instance members, so
 * calls presumably must not overlap; confirm with callers.
 */
async runOnePassGetGradients({itF=8, itN=5, dt=0.7, beta=0.5, adpC=0.15, adpSteps=3}={}){
// In 'fhn' mode the adaptation step is always run with zero steps.
if(this.mode === 'fhn') adpSteps = 0;
// HPSN backward-compat: when not using heterogeneous-τ, refresh Tau to match runtime dt.
// When useHPSN=true, the user-set heterogeneous distribution is preserved (Tau not overwritten).
if(!this.useHPSN){
// _lastTauDt caches the dt last pushed, so Tau is only rewritten when dt changes.
if(this._lastTauDt !== dt){ this.setAllTau(dt); this._lastTauDt = dt; }
}
// Uniforms are written before any dispatch is recorded.
this._writeAllUniformsForPass(dt, beta);
// Everything below is recorded into a single command encoder and submitted once.
const enc = this.dev.createCommandEncoder();
this._initAllPhases(enc);
this._runPhaseRelax(enc, PHASE_F, itF);
this._runPhaseRelax(enc, PHASE_P, itN);
this._runPhaseRelax(enc, PHASE_M, itN);
this._runReward(enc);
this._runAdaptation(enc, adpC, adpSteps);
this._runGrad(enc, beta);
// Readback all gradients (separate buffers per layer) + Uo_free
for(let l=0; l<this.L; l++){
// Copy sizes are in bytes: 4 bytes per f32 element.
enc.copyBufferToBuffer(this.bufGW[l], 0, this.rbGW[l], 0, this.sizes[l+1]*this.sizes[l]*4);
enc.copyBufferToBuffer(this.bufGB[l], 0, this.rbGB[l], 0, this.sizes[l+1]*4);
}
// PHASE_F state of the last layer (index L-1), B*O floats.
enc.copyBufferToBuffer(this.bufU[PHASE_F][this.L-1], 0, this.rbUoF, 0, this.B*this.O*4);
this.dev.queue.submit([enc.finish()]);
// Start all map requests together and wait for them as a group.
const maps = [this.rbUoF.mapAsync(GPUMapMode.READ)];
for(let l=0; l<this.L; l++){ maps.push(this.rbGW[l].mapAsync(GPUMapMode.READ)); maps.push(this.rbGB[l].mapAsync(GPUMapMode.READ)); }
await Promise.all(maps);
// slice(0) copies the mapped bytes so the result remains usable after unmap().
const uoF = new Float32Array(this.rbUoF.getMappedRange().slice(0));
this.rbUoF.unmap();
const gW = [], gB = [];
for(let l=0; l<this.L; l++){
gW.push(new Float32Array(this.rbGW[l].getMappedRange().slice(0)));
gB.push(new Float32Array(this.rbGB[l].getMappedRange().slice(0)));
// Unmap right after copying so the readback buffers can be reused by the next pass.
this.rbGW[l].unmap(); this.rbGB[l].unmap();
}
return {gW, gB, uoF};
}
destroy(){
const bufs = [this.bufWin, this.bufTgt, this.bufR, this.bufDummyR, this.bufDummyRW1, this.bufDummyRW2, this.bufDummyRW3, this.bufG, this.rbG, this.rbUoF, this.bufP_rew];
for(const arr of [this.bufW, this.bufB, this.bufGW, this.bufGB, this.rbGW, this.rbGB, this.bufP_init, this.bufP_grad, this.bufP_adapt, this.bufTau]) bufs.push(...arr);
for(const ph of this.bufU) bufs.push(...ph);
for(const ph of this.bufP_relax) bufs.push(...ph);
for(const v of bufs) if(v && v.destroy) try{ v.destroy(); }catch(e){}
}
}
// Multi-layer AdaGO optimizer.
/**
 * AdaGO-style optimizer state for a stack of dense layers.
 *
 * `sizes` is `[D, H1, ..., O]`; layer `l` maps `sizes[l]` inputs to `sizes[l+1]` outputs.
 * Per layer it keeps a momentum buffer for weights (`MW`) and biases (`MB`), a running
 * sum of clipped squared gradient norms (`vW`, `vB`) and a cached orthogonalized
 * weight direction (`OW`) that is refreshed whenever the batch counter `bc` is a
 * multiple of `OW_K`.
 */
export class AdaGODeep {
  /**
   * @param {number[]} sizes            Layer widths, input first.
   * @param {object}  [opts]
   * @param {number}  [opts.OW_K=8]     Refresh period (in batches) of the cached orthogonalized direction.
   */
  constructor(sizes, {OW_K=8}={}){
    this.sizes = sizes;
    this.L = sizes.length - 1;
    this.MW = [];
    this.MB = [];
    this.vW = new Float64Array(this.L);
    this.vB = new Float64Array(this.L);
    this.OW = new Array(this.L).fill(null);
    this.OW_K = OW_K;
    this.bc = 0;
    for(let layer = 0; layer < this.L; layer++){
      const fanOut = sizes[layer+1];
      this.MW.push(new Float64Array(fanOut * sizes[layer]));
      this.MB.push(new Float64Array(fanOut));
    }
  }

  /**
   * Apply one in-place ascent step to the weights `W` and biases `B` of layer `l`.
   *
   * @param {number} l    Layer index.
   * @param {Float32Array|Float64Array|number[]} W   Weights, updated in place.
   * @param {Float32Array|Float64Array|number[]} B   Biases, updated in place.
   * @param {ArrayLike<number>} gW  Summed weight gradient (divided by `bs` here).
   * @param {ArrayLike<number>} gB  Summed bias gradient (divided by `bs` here).
   * @param {number} bs   Divisor applied to every gradient entry.
   * @param {number} lr   Learning rate.
   */
  step(l, W, B, gW, gB, bs, lr){
    const EPS = 1e-8, MOMENTUM = 0.9, CLIP = 1.0;
    const fanIn = this.sizes[l], fanOut = this.sizes[l+1];

    // ---- weights: momentum + squared norm of the averaged gradient ----
    let sqNormW = 0;
    for(let k = 0; k < W.length; k++){
      const g = gW[k]/bs;
      this.MW[l][k] = MOMENTUM*this.MW[l][k] + (1-MOMENTUM)*g;
      sqNormW += g*g;
    }
    const normW = Math.sqrt(sqNormW);
    this.vW[l] += Math.min(sqNormW, CLIP*CLIP);

    // Refresh the orthogonalized direction only when missing or on every OW_K-th batch.
    const needsRefresh = !this.OW[l] || this.bc % this.OW_K === 0;
    if(needsRefresh){
      this.OW[l] = orthCPU(this.MW[l], fanOut, fanIn);
    }
    const direction = this.OW[l];

    // Step size: clipped gradient norm over the accumulated norm, floored at EPS.
    const stepW = Math.max(EPS, lr*Math.min(normW, CLIP)/(Math.sqrt(this.vW[l])+EPS));
    for(let k = 0; k < W.length; k++){
      W[k] += stepW*direction[k];
    }

    // ---- biases: same scheme, direction is the momentum divided by the gradient norm ----
    let sqNormB = 0;
    for(let k = 0; k < B.length; k++){
      const g = gB[k]/bs;
      this.MB[l][k] = MOMENTUM*this.MB[l][k] + (1-MOMENTUM)*g;
      sqNormB += g*g;
    }
    const normB = Math.sqrt(sqNormB);
    this.vB[l] += Math.min(sqNormB, CLIP*CLIP);
    const stepB = Math.max(EPS, lr*Math.min(normB, CLIP)/(Math.sqrt(this.vB[l])+EPS));
    for(let k = 0; k < B.length; k++){
      // A zero gradient norm yields a zero direction rather than a division by zero.
      B[k] += stepB*(normB > 0 ? this.MB[l][k]/normB : 0);
    }
  }

  /** Advance the batch counter that drives the direction-refresh schedule. */
  endBatch(){
    this.bc++;
  }
}
// Adam (with optional weight decay → AdamW). EqProp gives an ascent direction, so we += step.
/**
 * Adam optimizer with optional decoupled weight decay (AdamW), applied as an ASCENT step
 * (`+=`) because the EqProp gradient estimate points uphill.
 *
 * `step` is called once per layer, so the bias-correction step count is tracked per layer
 * in `tLayer`. A single shared counter incremented on every call would run ahead of the
 * number of updates a given layer's moments have actually received, which under-corrects
 * the bias and shrinks the step of every layer after the first.
 */
export class Adam {
  /**
   * @param {number[]} sizes  Layer widths `[D, H1, ..., O]`.
   * @param {object} [opts]
   * @param {number} [opts.beta1=0.9]      Decay of the first-moment estimate.
   * @param {number} [opts.beta2=0.999]    Decay of the second-moment estimate.
   * @param {number} [opts.eps=1e-8]       Added to the denominator for numerical stability.
   * @param {number} [opts.weightDecay=0]  Decoupled weight decay (weights only, not biases).
   */
  constructor(sizes, {beta1=0.9, beta2=0.999, eps=1e-8, weightDecay=0}={}){
    this.sizes=sizes; this.L=sizes.length-1; this.beta1=beta1; this.beta2=beta2; this.eps=eps;
    this.wd = weightDecay; // AdamW-style decoupled weight decay (applied to W only, not bias)
    this.mW=[]; this.vW=[]; this.mB=[]; this.vB=[];
    // `t` is the largest per-layer step count; `tLayer[l]` is the count used for layer l.
    this.t=0;
    this.tLayer = new Float64Array(this.L);
    for(let l=0; l<this.L; l++){
      this.mW.push(new Float64Array(sizes[l+1]*sizes[l]));
      this.vW.push(new Float64Array(sizes[l+1]*sizes[l]));
      this.mB.push(new Float64Array(sizes[l+1]));
      this.vB.push(new Float64Array(sizes[l+1]));
    }
  }

  /**
   * Apply one in-place Adam ascent step to layer `l`.
   *
   * @param {number} l    Layer index.
   * @param {Float32Array|Float64Array|number[]} W   Weights, updated in place.
   * @param {Float32Array|Float64Array|number[]} B   Biases, updated in place.
   * @param {ArrayLike<number>} gW  Summed weight gradient (divided by `bs` here).
   * @param {ArrayLike<number>} gB  Summed bias gradient (divided by `bs` here).
   * @param {number} bs   Divisor applied to every gradient entry.
   * @param {number} lr   Learning rate.
   */
  step(l, W, B, gW, gB, bs, lr){
    // Bias correction must use the number of updates THIS layer's moments have seen.
    const t = this.tLayer[l] + 1;
    this.tLayer[l] = t;
    if(t > this.t) this.t = t;
    const b1=this.beta1, b2=this.beta2, eps=this.eps;
    const bc1 = 1 - Math.pow(b1, t), bc2 = 1 - Math.pow(b2, t);
    const wd = this.wd;
    const mW = this.mW[l], vW = this.vW[l];
    for(let k=0;k<W.length;k++){
      const g = gW[k]/bs;
      mW[k] = b1*mW[k] + (1-b1)*g;
      vW[k] = b2*vW[k] + (1-b2)*g*g;
      const m_hat = mW[k]/bc1, v_hat = vW[k]/bc2;
      // AdamW: decoupled decay (W *= 1 - lr*wd) ; EqProp ascent => + m̂/√v̂
      if(wd > 0) W[k] *= (1 - lr * wd);
      W[k] += lr * m_hat / (Math.sqrt(v_hat) + eps);
    }
    const mB = this.mB[l], vB = this.vB[l];
    for(let k=0;k<B.length;k++){
      const g = gB[k]/bs;
      mB[k] = b1*mB[k] + (1-b1)*g;
      vB[k] = b2*vB[k] + (1-b2)*g*g;
      const m_hat = mB[k]/bc1, v_hat = vB[k]/bc2;
      B[k] += lr * m_hat / (Math.sqrt(v_hat) + eps);
    }
  }

  /** No per-batch bookkeeping is needed; present for interface parity with the other optimizers. */
  endBatch(){}
}
// Muon optimizer (Keller Jordan's MomentUm Orthogonalized by Newton-schulz).
// Steps in sign-of-singular-values direction. Per-step Newton-Schulz quintic on momentum.
// Uses shape-aware scaling: step = lr * sqrt(max(m,n)/min(m,n)) * Orth(M).
// Coefficients from K. Jordan's original Muon: a, b, c chosen so f(x)=ax+bx³+cx⁵ pushes
// singular values toward 1 (with controlled overshoot).
/**
 * Approximately orthogonalize a row-major `m × n` matrix with Muon's quintic
 * Newton-Schulz iteration.
 *
 * The input is first divided by its Frobenius norm (an upper bound on the spectral norm),
 * then `iters` rounds of  X ← a·X + b·X(XᵀX) + c·X(XᵀX)²  are applied. The Gram matrix
 * XᵀX is always formed on the smaller side: a wide input (n > m) is transposed on entry
 * and transposed back on exit.
 *
 * @param {ArrayLike<number>} M_in  Row-major matrix entries, length `m*n`. Not modified.
 * @param {number} m                Row count.
 * @param {number} n                Column count.
 * @param {number} [iters=5]        Number of Newton-Schulz iterations.
 * @returns {Float32Array}          Row-major `m × n` result.
 */
function muonOrth(M_in, m, n, iters=5){
  // Quintic coefficients from K. Jordan's Muon.
  const ca = 3.4445, cb = -4.7750, cc = 2.0315;

  // Row-major product A(p×q)·B(q×r) → (p×r); zero entries of A are skipped.
  function mm(A, B, p, q, r){
    const out = new Float64Array(p*r);
    for(let i=0;i<p;i++) for(let k=0;k<q;k++){
      const aa = A[i*q+k];
      if(aa) for(let j=0;j<r;j++) out[i*r+j] += aa*B[k*r+j];
    }
    return out;
  }
  // Row-major transpose of A(p×q) → (q×p).
  function transpose(A, p, q){
    const T = new Float64Array(p*q);
    for(let i=0;i<p;i++) for(let j=0;j<q;j++) T[j*p+i] = A[i*q+j];
    return T;
  }

  // Work on a Float64 copy so the caller's buffer is untouched.
  let X = new Float64Array(M_in.length);
  for(let k=0;k<M_in.length;k++) X[k] = M_in[k];

  // Normalize so spectral norm ≤ 1; the tiny offset keeps an all-zero input at zero.
  let nrm = 0;
  for(const x of X) nrm += x*x;
  nrm = Math.sqrt(nrm) + 1e-30;
  for(let k=0;k<X.length;k++) X[k] /= nrm;

  // Ensure rows ≥ cols so the Gram matrix XᵀX (C×C) is the smaller one.
  const transp = (n > m);
  let R = m, C = n;
  if(transp){
    X = transpose(X, m, n);
    R = n; C = m;
  }

  for(let it=0; it<iters; it++){
    const Xt = transpose(X, R, C);
    const G = mm(Xt, X, C, R, C);     // XᵀX      (C×C)
    const G2 = mm(G, G, C, C, C);     // (XᵀX)²   (C×C)
    const XG = mm(X, G, R, C, C);     // X(XᵀX)   (R×C)
    const XG2 = mm(X, G2, R, C, C);   // X(XᵀX)²  (R×C)
    const Y = new Float64Array(R*C);
    for(let k=0;k<R*C;k++) Y[k] = ca*X[k] + cb*XG[k] + cc*XG2[k];
    X = Y;
  }

  // Undo the entry transpose so the result is m×n again.
  return new Float32Array(transp ? transpose(X, R, C) : X);
}
/**
 * Muon-style optimizer: the weight step is the orthogonalized momentum, applied as an
 * ASCENT step because the EqProp gradient estimate points uphill.
 *
 * Weights use a plain accumulated momentum `M ← beta·M + g`, which is passed through
 * `muonOrth` and scaled by `lr · sqrt(max(no, ni) / min(no, ni))`. Biases use the same
 * momentum rule with a direct `lr · m` step. Optional decoupled weight decay is applied to
 * weights only.
 */
export class Muon {
  /**
   * @param {number[]} sizes  Layer widths `[D, H1, ..., O]`.
   * @param {object} [opts]
   * @param {number} [opts.beta=0.95]      Momentum decay.
   * @param {number} [opts.weightDecay=0]  Decoupled weight decay for weights.
   * @param {number} [opts.iters=5]        Newton-Schulz iterations passed to `muonOrth`.
   */
  constructor(sizes, {beta=0.95, weightDecay=0, iters=5}={}){
    this.sizes = sizes;
    this.L = sizes.length - 1;
    this.beta = beta;
    this.wd = weightDecay;
    this.iters = iters;
    this.MW = [];
    this.mB = [];
    for(let layer = 0; layer < this.L; layer++){
      const fanOut = sizes[layer+1];
      this.MW.push(new Float64Array(fanOut * sizes[layer]));
      this.mB.push(new Float64Array(fanOut));
    }
  }

  /**
   * Apply one in-place ascent step to layer `l`.
   *
   * @param {number} l    Layer index.
   * @param {Float32Array|Float64Array|number[]} W   Weights, updated in place.
   * @param {Float32Array|Float64Array|number[]} B   Biases, updated in place.
   * @param {ArrayLike<number>} gW  Summed weight gradient (divided by `bs` here).
   * @param {ArrayLike<number>} gB  Summed bias gradient (divided by `bs` here).
   * @param {number} bs   Divisor applied to every gradient entry.
   * @param {number} lr   Learning rate.
   */
  step(l, W, B, gW, gB, bs, lr){
    const momentum = this.beta;
    const decay = this.wd;
    const rows = this.sizes[l+1];
    const cols = this.sizes[l];

    // Accumulate the averaged gradient into the weight momentum.
    for(let k = 0; k < W.length; k++){
      const g = gW[k]/bs;
      this.MW[l][k] = momentum*this.MW[l][k] + g;
    }

    // Orthogonalized momentum is the update direction.
    const direction = muonOrth(this.MW[l], rows, cols, this.iters);

    // Shape-aware step size: lr · sqrt(max/min).
    const stepSize = lr * Math.sqrt(Math.max(rows, cols) / Math.min(rows, cols));

    for(let k = 0; k < W.length; k++){
      if(decay > 0){
        W[k] *= (1 - lr*decay);
      }
      W[k] += stepSize * direction[k];
    }

    // Biases: momentum accumulation followed by a direct lr-scaled step.
    for(let k = 0; k < B.length; k++){
      const g = gB[k]/bs;
      this.mB[l][k] = momentum*this.mB[l][k] + g;
      B[k] += lr * this.mB[l][k];
    }
  }

  /** No per-batch bookkeeping; present for interface parity with the other optimizers. */
  endBatch(){}
}
// Lion optimizer (sign of momentum). Often outperforms Adam for some tasks, less memory.
/**
 * Lion optimizer (sign of an interpolated momentum), applied as an ASCENT step because the
 * EqProp gradient estimate points uphill.
 *
 * Per parameter: `c = beta1·m + (1-beta1)·g`, step `lr · sign(c)`, then
 * `m ← beta2·m + (1-beta2)·g`. `sign(0)` is 0, so a parameter with zero gradient and zero
 * momentum is left alone instead of drifting by `+lr` on every step. A NaN gradient
 * propagates into the parameter (sign(NaN) is NaN) rather than being silently applied as
 * a `-lr` step. Optional decoupled weight decay is applied to weights only.
 */
export class Lion {
  /**
   * @param {number[]} sizes  Layer widths `[D, H1, ..., O]`.
   * @param {object} [opts]
   * @param {number} [opts.beta1=0.9]      Interpolation factor for the update direction.
   * @param {number} [opts.beta2=0.99]     Decay of the stored momentum.
   * @param {number} [opts.weightDecay=0]  Decoupled weight decay for weights.
   */
  constructor(sizes, {beta1=0.9, beta2=0.99, weightDecay=0}={}){
    this.sizes=sizes; this.L=sizes.length-1; this.beta1=beta1; this.beta2=beta2; this.wd=weightDecay;
    this.mW=[]; this.mB=[];
    for(let l=0; l<this.L; l++){
      this.mW.push(new Float64Array(sizes[l+1]*sizes[l]));
      this.mB.push(new Float64Array(sizes[l+1]));
    }
  }

  /**
   * Apply one in-place Lion ascent step to layer `l`.
   *
   * @param {number} l    Layer index.
   * @param {Float32Array|Float64Array|number[]} W   Weights, updated in place.
   * @param {Float32Array|Float64Array|number[]} B   Biases, updated in place.
   * @param {ArrayLike<number>} gW  Summed weight gradient (divided by `bs` here).
   * @param {ArrayLike<number>} gB  Summed bias gradient (divided by `bs` here).
   * @param {number} bs   Divisor applied to every gradient entry.
   * @param {number} lr   Learning rate.
   */
  step(l, W, B, gW, gB, bs, lr){
    const b1=this.beta1, b2=this.beta2, wd=this.wd;
    for(let k=0;k<W.length;k++){
      const g = gW[k]/bs;
      // update direction: sign(b1*m + (1-b1)*g), with sign(0) = 0
      const u = Math.sign(b1*this.mW[l][k] + (1-b1)*g);
      if(wd>0) W[k] *= (1 - lr*wd);
      W[k] += lr * u;
      // momentum update with b2 (after the step, so the step uses the previous momentum)
      this.mW[l][k] = b2*this.mW[l][k] + (1-b2)*g;
    }
    for(let k=0;k<B.length;k++){
      const g = gB[k]/bs;
      const u = Math.sign(b1*this.mB[l][k] + (1-b1)*g);
      B[k] += lr * u;
      this.mB[l][k] = b2*this.mB[l][k] + (1-b2)*g;
    }
  }

  /** No per-batch bookkeeping; present for interface parity with the other optimizers. */
  endBatch(){}
}