Spaces:
Running
Running
| // Multi-hidden-layer WebGPU EqProp trainer. | |
| // sizes = [D, H1, H2, ..., Hk, O] | |
// Modes:
//   0 = adaptive (σ everywhere, u-nudge at output)
//   1 = fhn (clip everywhere, ρ-nudge at output)
//   2 = prism (softplus soft-clip ρ in the relax/grad kernels, ρ-nudge at output)
| // | |
| // Key design: | |
| // * ONE generic compute pipeline `pass_layer` updates any layer using uniforms (ni, no, nxt, ...). | |
| // * 3 phases (free / plus / minus): each has its own per-layer state buffer + per-layer bind group. | |
| // * Uniform buffers are pre-written once per pass (one per (phase, layer) pair), so the encoder | |
| // records dispatches that each pick up the right uniforms. | |
| // * Gradient kernels are also generic — per layer transition, compute outer product reduction over batch. | |
| import { orth as orthCPU } from './eqprop_lib.js'; | |
| // ----- WGSL: relax (one kernel handles any layer) ----- | |
// WGSL source for the relaxation pipelines (both compiled from this one module in _build):
//   * pass_layer  — one explicit Euler step (u += Tau[i] * drive) of a layer's state Uh for
//                   the whole batch. Bottom-up drive comes from Win through W0/b0; when
//                   has_topdown != 0 the layer above (Uo) adds W1ᵀ·ρ(Uo); when has_target != 0
//                   and beta != 0 the state is nudged toward Tgt. Threads form 8×8 tiles over
//                   (gid.x = neuron i, gid.y = sample b).
//   * init_state  — writes 0.1 into the first B*no entries of Uh, using a 2-D dispatch whose
//                   x extent is capped at 65535 workgroups of 64 threads.
// p.mode selects the dynamics: > 1.5 prism, > 0.5 fhn, otherwise adaptive (σ).
// Struct P is 16 four-byte slots (64 bytes); _writeRelaxParams packs it in the same order.
const WGSL_RELAX = `
struct P {
ni : u32, no : u32, nxt : u32, B : u32,
dt : f32, beta : f32, gamma : f32, mode : f32, // mode: 0=adaptive σ, 1=fhn(clip+cubic), 2=prism (soft-clip via softplus)
has_topdown : u32, has_target : u32, noise_scale : f32, iter_seed : u32,
// sEqProp: noise_scale > 0 injects per-iter per-neuron Gaussian-ish noise into drive.
// Bio-faithful (real synapses are stochastic). At test, run M passes & average outputs.
clamp_lo : f32, clamp_hi : f32, _p_t1 : f32, _p_t2 : f32,
// Tier A — pre-σ drive clamp (algorithmic, uniform-driven). Bounds the pre-activation
// value c before σ(c) to prevent saturation runaway. When clamp_hi <= clamp_lo the
// kernel treats it as DISABLED (no-op). Default in caller = clamp_lo=clamp_hi=0 → disabled.
};
@group(0) @binding(0) var<uniform> p : P;
@group(0) @binding(1) var<storage, read> Win : array<f32>;
@group(0) @binding(2) var<storage, read> W0 : array<f32>; // [no x ni]
@group(0) @binding(3) var<storage, read> b0 : array<f32>; // [no]
@group(0) @binding(4) var<storage, read> W1 : array<f32>; // [nxt x no] (top-down)
@group(0) @binding(5) var<storage, read_write> Uh : array<f32>; // [B*no]
@group(0) @binding(6) var<storage, read> Uo : array<f32>; // [B*nxt]
@group(0) @binding(7) var<storage, read> Tgt : array<f32>; // [B*no]
// HPSN: heterogeneous time constants — per-neuron multiplier on drive integration.
// Tau[i] replaces the global p.dt. Constant Tau[i]=p.dt → behavior identical to scalar-dt EqProp.
// Sampled from Uniform[tau_min, tau_max] → diverse temporal scales like real cortical neurons.
@group(0) @binding(8) var<storage, read> Tau : array<f32>; // [no]
const A1 : f32 = 0.07407407407407407;
const PRISM_K : f32 = 10.0; // sharpness; higher k → harder clip
// PCG-style cheap hash → uniform [0,1). Per-thread, per-iter, per-neuron stochasticity.
fn pcg_hash(seed_in: u32) -> u32 {
var state : u32 = seed_in * 747796405u + 2891336453u;
let word : u32 = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u;
return (word >> 22u) ^ word;
}
fn unif_noise(b: u32, i: u32, t: u32) -> f32 {
// Triangular distribution (sum of 2 uniforms - 1) ≈ Gaussian-ish, mean=0, variance=1/6.
let h1 = pcg_hash(b * 65537u + i * 257u + t * 31u);
let h2 = pcg_hash(b * 31337u + i * 1031u + t * 17u + 12345u);
let u1 = f32(h1) / 4294967296.0;
let u2 = f32(h2) / 4294967296.0;
return (u1 + u2) - 1.0; // range [-1, 1], roughly triangular
}
fn sg(u: f32) -> f32 { return 1.0 / (1.0 + exp(-4.0 * (u - 0.5))); }
fn fhn_rho(u: f32) -> f32 { return clamp(u, 0.0, 1.0); }
fn fhn_rho_p(u: f32) -> f32 { return select(0.0, 1.0, u > 0.0 && u < 1.0); }
fn fhn_f(u: f32) -> f32 { return A1 * u - u*u*u; }
// PRISM activation: ρ(u) = (softplus(k·u) - softplus(k·(u-1))) / k
// Smooth approximation of clip(u,0,1). Derivative: σ(k·u) - σ(k·(u-1)).
// "Prism" = splits drive into a smooth-yet-saturating activation with gradient flow on both sides.
fn softplus(x: f32) -> f32 { return select(x + log(1.0 + exp(-x)), log(1.0 + exp(x)), x <= 0.0); }
fn prism_rho(u: f32) -> f32 {
return (softplus(PRISM_K * u) - softplus(PRISM_K * (u - 1.0))) / PRISM_K;
}
fn prism_rho_p(u: f32) -> f32 {
let k = PRISM_K;
return 1.0/(1.0+exp(-k*u)) - 1.0/(1.0+exp(-k*(u-1.0)));
}
fn rho(u: f32) -> f32 {
if (p.mode > 1.5) { return prism_rho(u); }
if (p.mode > 0.5) { return fhn_rho(u); }
return sg(u);
}
@compute @workgroup_size(8, 8) fn pass_layer(@builtin(global_invocation_id) gid: vec3<u32>) {
let b = gid.y; let i = gid.x;
if (b >= p.B || i >= p.no) { return; }
// bottom-up: c = b0[i] + sum_j W0[i,j] * rho(Win[b,j])
var c : f32 = b0[i];
let row0 = i * p.ni;
let xoff = b * p.ni;
for (var j: u32 = 0u; j < p.ni; j = j + 1u) {
c = c + W0[row0 + j] * rho(Win[xoff + j]);
}
// top-down: gamma * sum_k W1[k,i] * rho(Uo[b,k]) (if next layer exists)
if (p.has_topdown != 0u) {
var td : f32 = 0.0;
let uo_off = b * p.nxt;
for (var k: u32 = 0u; k < p.nxt; k = k + 1u) {
td = td + W1[k * p.no + i] * rho(Uo[uo_off + k]);
}
if (p.mode > 0.5) { c = c + td; } else { c = c + p.gamma * td; }
}
// Tier A — pre-σ drive clamp (algorithmic). Active iff clamp_hi > clamp_lo.
if (p.clamp_hi > p.clamp_lo) {
c = clamp(c, p.clamp_lo, p.clamp_hi);
}
let idx = b * p.no + i;
let u_old = Uh[idx];
// sEqProp noise: per-(b, i, iter_seed) triangular noise added to drive. Zero by default.
let noise = select(0.0, p.noise_scale * unif_noise(b, i, p.iter_seed), p.noise_scale > 0.0);
var u_new : f32;
if (p.mode > 1.5) {
// PRISM: u̇ = ρ'(u)·c + (linear pull) ; ρ-nudge for output. Smooth saturating dynamics.
var drive : f32 = prism_rho_p(u_old) * c - 0.1 * (u_old - 0.5) + noise;
if (p.has_target != 0u && p.beta != 0.0) {
drive = drive + p.beta * (Tgt[idx] - prism_rho(u_old));
}
u_new = u_old + Tau[i] * drive;
u_new = clamp(u_new, -0.3, 1.3);
} else if (p.mode > 0.5) {
// FHN
var drive : f32 = fhn_rho_p(u_old) * c + fhn_f(u_old) + noise;
if (p.has_target != 0u && p.beta != 0.0) {
drive = drive + p.beta * (Tgt[idx] - fhn_rho(u_old));
}
u_new = u_old + Tau[i] * drive;
u_new = clamp(u_new, -0.2, 1.2);
} else {
// Adaptive
var drive : f32 = -u_old + sg(c) + noise;
if (p.has_target != 0u && p.beta != 0.0) {
drive = drive + p.beta * (Tgt[idx] - u_old);
}
u_new = u_old + Tau[i] * drive;
}
Uh[idx] = u_new;
}
// 2D dispatch to handle big buffers (B*no can exceed the 65535 per-dim workgroup limit).
@compute @workgroup_size(64) fn init_state(@builtin(global_invocation_id) gid: vec3<u32>) {
let stride = 65535u * 64u; // workgroups_per_X * threads_per_workgroup
let g = gid.y * stride + gid.x;
let n = p.B * p.no;
if (g < n) { Uh[g] = 0.1; }
}
`;
| // ----- WGSL: gradient (one kernel handles any layer transition) ----- | |
// WGSL source for the gradient pipelines (both compiled from this one module in _build).
// For one weight transition with pre-layer size ni and post-layer size no:
//   * grad_W — gW[i*ni + j] = Σ_b R[b]·(ρ(post⁺[b,i])·ρ(pre⁺[b,j]) − ρ(post⁻[b,i])·ρ(pre⁻[b,j])) / two_beta
//   * grad_B — gB[i]        = Σ_b R[b]·(ρ(post⁺[b,i]) − ρ(post⁻[b,i])) / two_beta
// ρ is chosen separately for each side by mode_pre / mode_post through rho_mode():
// 0 = σ, 1 = hard clip, 2 = prism soft-clip, 3 = identity.
// Struct GP is 8 four-byte slots (32 bytes); the fields `_pad` and `c` are not read by either kernel.
const WGSL_GRAD = `
struct GP {
ni : u32, no : u32, _pad : u32, B : u32,
c : f32, two_beta : f32, mode_pre : f32, mode_post : f32, // mode_pre/post: 0=σ, 1=clip, 2=prism (soft-clip), 3=identity
};
@group(0) @binding(0) var<uniform> p : GP;
@group(0) @binding(1) var<storage, read> UpreP : array<f32>; // [B*ni] - "input" layer state, plus phase
@group(0) @binding(2) var<storage, read> UpreM : array<f32>; // [B*ni] - minus
@group(0) @binding(3) var<storage, read> UpostP: array<f32>; // [B*no]
@group(0) @binding(4) var<storage, read> UpostM: array<f32>; // [B*no]
@group(0) @binding(5) var<storage, read> R : array<f32>; // [B]
@group(0) @binding(6) var<storage, read_write> gW : array<f32>; // [no*ni]
@group(0) @binding(7) var<storage, read_write> gB : array<f32>; // [no]
fn sg(u: f32) -> f32 { return 1.0 / (1.0 + exp(-4.0 * (u - 0.5))); }
const PRISM_K2 : f32 = 10.0;
fn softplus2(x: f32) -> f32 { return select(x + log(1.0 + exp(-x)), log(1.0 + exp(x)), x <= 0.0); }
fn prism_rho_g(u: f32) -> f32 { return (softplus2(PRISM_K2*u) - softplus2(PRISM_K2*(u-1.0))) / PRISM_K2; }
fn rho_mode(u: f32, m: f32) -> f32 {
if (m > 2.5) { return u; } // identity (linear)
if (m > 1.5) { return prism_rho_g(u); } // prism soft-clip
if (m > 0.5) { return clamp(u, 0.0, 1.0); } // hard-clip (FHN)
return sg(u); // σ (adaptive)
}
@compute @workgroup_size(8, 8) fn grad_W(@builtin(global_invocation_id) gid: vec3<u32>) {
let i = gid.y; let j = gid.x;
if (i >= p.no || j >= p.ni) { return; }
var acc : f32 = 0.0;
for (var b: u32 = 0u; b < p.B; b = b + 1u) {
let rh = R[b];
let ip = rho_mode(UpostP[b * p.no + i], p.mode_post);
let im = rho_mode(UpostM[b * p.no + i], p.mode_post);
let jp = rho_mode(UpreP[b * p.ni + j], p.mode_pre);
let jm = rho_mode(UpreM[b * p.ni + j], p.mode_pre);
acc = acc + rh * (ip * jp - im * jm);
}
gW[i * p.ni + j] = acc / p.two_beta;
}
@compute @workgroup_size(64) fn grad_B(@builtin(global_invocation_id) gid: vec3<u32>) {
let i = gid.x;
if (i >= p.no) { return; }
var acc : f32 = 0.0;
for (var b: u32 = 0u; b < p.B; b = b + 1u) {
let rh = R[b];
let ip = rho_mode(UpostP[b * p.no + i], p.mode_post);
let im = rho_mode(UpostM[b * p.no + i], p.mode_post);
acc = acc + rh * (ip - im);
}
gB[i] = acc / p.two_beta;
}
`;
| // ----- WGSL: reward + adaptation (depends on output layer state) ----- | |
// WGSL source for the auxiliary pipelines (both compiled from this one module in _build):
//   * compute_reward — one thread per sample b:
//                      loss = Σ_i (rho_out(UoF[b,i]) − Tgt[b,i])²,
//                      R[b] = rmin + (1 − rmin)·min(loss / escale, 1), with escale = 0.4, rmin = 0.1.
//   * adapt_layer    — Up[g], Um[g] ← (1 − c)·Up/Um[g] + c·Uf[g] for g < arrayLength(Uf),
//                      using the same capped 2-D dispatch layout as init_state.
// rho_out follows the same mode selection as rho() in the relax shader, so in prism mode
// the reward is computed on the softplus soft-clip output rather than a hard clip.
// Struct AP is 8 four-byte slots (32 bytes); H_max, n_hidden, _p0 and _p1 are not read.
const WGSL_AUX = `
struct AP {
B : u32, O : u32, H_max : u32, n_hidden : u32,
c : f32, mode : f32, _p0 : f32, _p1 : f32,
};
@group(0) @binding(0) var<uniform> p : AP;
@group(0) @binding(1) var<storage, read> UoF : array<f32>;
@group(0) @binding(2) var<storage, read> Tgt : array<f32>;
@group(0) @binding(3) var<storage, read_write> R : array<f32>;
// adaptation buffers (variable size; we pass single layer at a time via separate bind groups)
@group(0) @binding(4) var<storage, read> Uf : array<f32>;
@group(0) @binding(5) var<storage, read_write> Up : array<f32>;
@group(0) @binding(6) var<storage, read_write> Um : array<f32>;
fn sg(u: f32) -> f32 { return 1.0 / (1.0 + exp(-4.0 * (u - 0.5))); }
// PRISM soft-clip; PRISM_K and prism_rho must stay identical to the relax shader's.
const PRISM_K : f32 = 10.0;
fn softplus(x: f32) -> f32 { return select(x + log(1.0 + exp(-x)), log(1.0 + exp(x)), x <= 0.0); }
fn prism_rho(u: f32) -> f32 { return (softplus(PRISM_K * u) - softplus(PRISM_K * (u - 1.0))) / PRISM_K; }
// Output read-out activation, selected like rho() in the relax shader:
// mode 2 = prism soft-clip, mode 1 = hard clip (FHN), mode 0 = σ (adaptive).
fn rho_out(u: f32) -> f32 {
if (p.mode > 1.5) { return prism_rho(u); }
if (p.mode > 0.5) { return clamp(u, 0.0, 1.0); }
return sg(u);
}
@compute @workgroup_size(64) fn compute_reward(@builtin(global_invocation_id) gid: vec3<u32>) {
let b = gid.x;
if (b >= p.B) { return; }
var loss : f32 = 0.0;
let off = b * p.O;
for (var i: u32 = 0u; i < p.O; i = i + 1u) {
let d = rho_out(UoF[off + i]) - Tgt[off + i];
loss = loss + d * d;
}
let escale : f32 = 0.4;
let rmin : f32 = 0.1;
var r : f32 = loss / escale;
if (r > 1.0) { r = 1.0; }
R[b] = rmin + (1.0 - rmin) * r;
}
// Adjusted Adaptation per layer: Up,Um ← (1-c)*Up + c*Uf. 2D dispatch safe for large buffers.
@compute @workgroup_size(64) fn adapt_layer(@builtin(global_invocation_id) gid: vec3<u32>) {
let stride = 65535u * 64u;
let g = gid.y * stride + gid.x;
if (g >= arrayLength(&Uf)) { return; }
let f = Uf[g];
Up[g] = (1.0 - p.c) * Up[g] + p.c * f;
Um[g] = (1.0 - p.c) * Um[g] + p.c * f;
}
`;
| // ----- JS: trainer class ----- | |
| export async function makeGPUDeep({powerPreference='high-performance'}={}){ | |
| if(!navigator.gpu) throw new Error('no webgpu'); | |
| const adapter = await navigator.gpu.requestAdapter({powerPreference}); | |
| if(!adapter) throw new Error('no adapter'); | |
| const want = {}; | |
| const tryKeys = ['maxStorageBuffersPerShaderStage','maxBufferSize','maxStorageBufferBindingSize', | |
| 'maxComputeInvocationsPerWorkgroup','maxComputeWorkgroupSizeX','maxComputeWorkgroupStorageSize','maxBindGroups']; | |
| for(const k of tryKeys){ const v=adapter.limits[k]; if(typeof v==='number') want[k]=v; } | |
| const dev = await adapter.requestDevice({requiredLimits: want}); | |
| return {adapter, dev, info: adapter.info||{}}; | |
| } | |
// Phase indices for the per-phase arrays (bufU, bgR, bgInit, bufP_relax):
// free (β = 0), plus (+β output nudge) and minus (−β output nudge).
const PHASE_F = 0;
const PHASE_P = 1;
const PHASE_M = 2;
| export class GPUTrainerDeep { | |
| // sizes: [D, H1, H2, ..., Hk, O] — len L+1; L = number of weight matrices = sizes.length-1 | |
| // mode: 'adaptive' | 'fhn' | 'prism' | |
| // driveClampLo, driveClampHi: Tier A — pre-σ drive clamp; ACTIVE iff hi > lo. Default 0,0 = disabled. | |
| constructor({dev, sizes, B, mode='adaptive', gamma=0.6, hpsnTauMin=0, hpsnTauMax=0, hpsnSeed=42, driveClampLo=0, driveClampHi=0}={}){ | |
| this.dev = dev; this.sizes = sizes; | |
| this.L = sizes.length - 1; // number of weight matrices (transitions) | |
| this.B = B; this.O = sizes[sizes.length-1]; | |
| this.mode = mode; | |
| this.modeFlag = (mode==='prism') ? 2.0 : (mode==='fhn' ? 1.0 : 0.0); | |
| this.gamma = gamma; | |
| this.hpsnTauMin = hpsnTauMin; | |
| this.hpsnTauMax = hpsnTauMax; | |
| this.hpsnSeed = hpsnSeed; | |
| this.useHPSN = (hpsnTauMax > hpsnTauMin && hpsnTauMin > 0); | |
| this.driveClampLo = driveClampLo; | |
| this.driveClampHi = driveClampHi; | |
| this._build(); | |
| // Initialize Tau buffers — either constant=0.7 (backward compat) or per-neuron Uniform[hpsnTauMin, hpsnTauMax]. | |
| if(this.useHPSN){ | |
| this.setAllTau(0.7, hpsnTauMin, hpsnTauMax, hpsnSeed); | |
| } else { | |
| this.setAllTau(0.7); | |
| } | |
| } | |
| _F32buf(n, usage){ | |
| if(!Number.isFinite(n) || n <= 0){ | |
| console.error('BAD _F32buf size', {n, sizes:JSON.stringify(this.sizes), sizesArr:this.sizes, B:this.B, L:this.L, S0:this.sizes&&this.sizes[0], S0type:typeof (this.sizes&&this.sizes[0])}); | |
| throw new Error('_F32buf called with non-finite n=' + n + ' sizes=' + JSON.stringify(this.sizes)); | |
| } | |
| return this.dev.createBuffer({size:Math.max(4,n*4), usage}); | |
| } | |
| _build(){ | |
| const dev = this.dev, S = this.sizes, B = this.B, L = this.L; | |
| const RW = GPUBufferUsage.STORAGE|GPUBufferUsage.COPY_SRC|GPUBufferUsage.COPY_DST; | |
| const R = GPUBufferUsage.STORAGE|GPUBufferUsage.COPY_DST|GPUBufferUsage.COPY_SRC; | |
| const UNI = GPUBufferUsage.UNIFORM|GPUBufferUsage.COPY_DST; | |
| const RDS = GPUBufferUsage.COPY_DST|GPUBufferUsage.MAP_READ; | |
| // input + target (shared across phases) | |
| this.bufWin = this._F32buf(B * S[0], R); | |
| this.bufTgt = this._F32buf(B * S[L], R); | |
| // weights & biases (one per transition) | |
| this.bufW = []; this.bufB = []; | |
| for(let l=0; l<L; l++){ | |
| this.bufW.push(this._F32buf(S[l+1]*S[l], R)); | |
| this.bufB.push(this._F32buf(S[l+1], R)); | |
| } | |
| // HPSN: per-layer Tau buffer of size [no]. Default = uniform scalar dt (backward compat). | |
| // User can call setHeterogeneousTau(layer, tauMin, tauMax) to enable HPSN per layer. | |
| this.bufTau = []; | |
| for(let l=0; l<L; l++){ this.bufTau.push(this._F32buf(S[l+1], R)); } | |
| // state buffers: for each of 3 phases, L state buffers (one per non-input layer) | |
| this.bufU = [[],[],[]]; // bufU[phase][l] is layer l+1's state (l=0..L-1, sizes S[1..L]) | |
| for(let phase=0; phase<3; phase++){ | |
| for(let l=1; l<=L; l++){ | |
| this.bufU[phase].push(this._F32buf(B * S[l], RW)); | |
| } | |
| } | |
| // reward + dummies (need separate buffers for read-only vs writable slots to avoid aliasing). | |
| this.bufR = this._F32buf(B, RW); | |
| this.bufDummyR = this._F32buf(4, R); // read-only dummy | |
| this.bufDummyRW1 = this._F32buf(4, RW); // writable dummy slot 1 | |
| this.bufDummyRW2 = this._F32buf(4, RW); // writable dummy slot 2 (different buffer!) | |
| this.bufDummyRW3 = this._F32buf(4, RW); | |
| // gradient buffer (packed: all gW and gB together) | |
| const gSizes = []; let total=0; | |
| for(let l=0; l<L; l++){ gSizes.push({offW:total, sizW:S[l+1]*S[l], offB:total+S[l+1]*S[l], sizB:S[l+1], total:S[l+1]*S[l]+S[l+1]}); total += gSizes[l].total; } | |
| this.gOff = gSizes; this.gTotal = total; | |
| this.bufG = this._F32buf(total, RW); | |
| this.rbG = dev.createBuffer({size: total*4, usage: RDS}); | |
| // readback for output free-phase Uo (for accuracy/loss) | |
| this.rbUoF = dev.createBuffer({size: B * S[L] * 4, usage: RDS}); | |
| // ---- pipelines ---- | |
| // Relax pipeline (generic) | |
| const modR = dev.createShaderModule({code: WGSL_RELAX}); | |
| const sR = (i)=>({binding:i, visibility:GPUShaderStage.COMPUTE, buffer:{type:'read-only-storage'}}); | |
| const sRW= (i)=>({binding:i, visibility:GPUShaderStage.COMPUTE, buffer:{type:'storage'}}); | |
| const uN = (i)=>({binding:i, visibility:GPUShaderStage.COMPUTE, buffer:{type:'uniform'}}); | |
| this.bglR = dev.createBindGroupLayout({entries:[uN(0), sR(1), sR(2), sR(3), sR(4), sRW(5), sR(6), sR(7), sR(8)]}); | |
| this.plR = dev.createPipelineLayout({bindGroupLayouts:[this.bglR]}); | |
| this.pipeRelax = dev.createComputePipeline({layout:this.plR, compute:{module:modR, entryPoint:'pass_layer'}}); | |
| this.pipeInit = dev.createComputePipeline({layout:this.plR, compute:{module:modR, entryPoint:'init_state'}}); | |
| // Grad pipeline (generic) | |
| const modG = dev.createShaderModule({code: WGSL_GRAD}); | |
| this.bglG = dev.createBindGroupLayout({entries:[uN(0), sR(1), sR(2), sR(3), sR(4), sR(5), sRW(6), sRW(7)]}); | |
| this.plG = dev.createPipelineLayout({bindGroupLayouts:[this.bglG]}); | |
| this.pipeGW = dev.createComputePipeline({layout:this.plG, compute:{module:modG, entryPoint:'grad_W'}}); | |
| this.pipeGB = dev.createComputePipeline({layout:this.plG, compute:{module:modG, entryPoint:'grad_B'}}); | |
| // Aux pipeline (reward + adaptation) | |
| const modA = dev.createShaderModule({code: WGSL_AUX}); | |
| this.bglA = dev.createBindGroupLayout({entries:[uN(0), sR(1), sR(2), sRW(3), sR(4), sRW(5), sRW(6)]}); | |
| this.plA = dev.createPipelineLayout({bindGroupLayouts:[this.bglA]}); | |
| this.pipeReward = dev.createComputePipeline({layout:this.plA, compute:{module:modA, entryPoint:'compute_reward'}}); | |
| this.pipeAdapt = dev.createComputePipeline({layout:this.plA, compute:{module:modA, entryPoint:'adapt_layer'}}); | |
| // ---- uniform buffers (one per (phase, layer)) ---- | |
| // Each entry: 48 bytes (12 u32/f32 slots) | |
| this.bufP_relax = []; | |
| for(let phase=0; phase<3; phase++){ | |
| this.bufP_relax.push([]); | |
| for(let l=1; l<=L; l++){ | |
| this.bufP_relax[phase].push(dev.createBuffer({size: 64, usage: UNI})); | |
| } | |
| } | |
| // init uniform (one per layer, beta=0) | |
| this.bufP_init = []; | |
| for(let l=1; l<=L; l++) this.bufP_init.push(dev.createBuffer({size: 64, usage: UNI})); | |
| // Grad uniform (one per layer transition) | |
| this.bufP_grad = []; | |
| for(let l=0; l<L; l++) this.bufP_grad.push(dev.createBuffer({size: 32, usage: UNI})); | |
| // Aux uniforms: reward + per-layer adaptation | |
| this.bufP_rew = dev.createBuffer({size: 32, usage: UNI}); | |
| this.bufP_adapt = []; for(let l=1; l<=L; l++) this.bufP_adapt.push(dev.createBuffer({size: 32, usage: UNI})); | |
| // ---- bind groups ---- | |
| // Relax: per (phase, layer) | |
| this.bgR = [[],[],[]]; | |
| for(let phase=0; phase<3; phase++){ | |
| for(let l=1; l<=L; l++){ | |
| // For layer l (1-indexed): Win = state[phase][l-2] if l>1 else bufWin | |
| // W0 = bufW[l-1], b0 = bufB[l-1] | |
| // W1 = bufW[l] (top-down weights to layer l+1), used if l<L | |
| // Uh = state[phase][l-1] | |
| // Uo = state[phase][l] (next layer), used if l<L | |
| // Tgt = bufTgt (used if l==L) | |
| const Win = (l===1) ? this.bufWin : this.bufU[phase][l-2]; | |
| const W0 = this.bufW[l-1], b0 = this.bufB[l-1]; | |
| const W1 = (l < L) ? this.bufW[l] : this.bufDummyR; | |
| const Uh = this.bufU[phase][l-1]; | |
| const Uo = (l < L) ? this.bufU[phase][l] : this.bufDummyR; | |
| const Tgt = this.bufTgt; | |
| this.bgR[phase].push(dev.createBindGroup({layout: this.bglR, entries:[ | |
| {binding:0, resource:{buffer:this.bufP_relax[phase][l-1]}}, | |
| {binding:1, resource:{buffer:Win}}, | |
| {binding:2, resource:{buffer:W0}}, | |
| {binding:3, resource:{buffer:b0}}, | |
| {binding:4, resource:{buffer:W1}}, | |
| {binding:5, resource:{buffer:Uh}}, | |
| {binding:6, resource:{buffer:Uo}}, | |
| {binding:7, resource:{buffer:Tgt}}, | |
| {binding:8, resource:{buffer:this.bufTau[l-1]}}, | |
| ]})); | |
| } | |
| } | |
| // Init: per (phase, layer) — uses bufP_init (beta=0) | |
| this.bgInit = [[],[],[]]; | |
| for(let phase=0; phase<3; phase++){ | |
| for(let l=1; l<=L; l++){ | |
| const Win = (l===1) ? this.bufWin : this.bufU[phase][l-2]; | |
| const W0 = this.bufW[l-1], b0 = this.bufB[l-1]; | |
| const W1 = (l < L) ? this.bufW[l] : this.bufDummyR; | |
| const Uh = this.bufU[phase][l-1]; | |
| const Uo = (l < L) ? this.bufU[phase][l] : this.bufDummyR; | |
| const Tgt = this.bufTgt; | |
| this.bgInit[phase].push(dev.createBindGroup({layout: this.bglR, entries:[ | |
| {binding:0, resource:{buffer:this.bufP_init[l-1]}}, | |
| {binding:1, resource:{buffer:Win}}, | |
| {binding:2, resource:{buffer:W0}}, | |
| {binding:3, resource:{buffer:b0}}, | |
| {binding:4, resource:{buffer:W1}}, | |
| {binding:5, resource:{buffer:Uh}}, | |
| {binding:6, resource:{buffer:Uo}}, | |
| {binding:7, resource:{buffer:Tgt}}, | |
| {binding:8, resource:{buffer:this.bufTau[l-1]}}, | |
| ]})); | |
| } | |
| } | |
| // Grad: per layer transition. Note: gW and gB need offsets into bufG. | |
| // Approach: instead of bindings to subranges, we make a SEPARATE buffer per layer for gradients (simpler). | |
| // Pack later by reading back. | |
| this.bufGW = []; this.bufGB = []; this.rbGW = []; this.rbGB = []; | |
| for(let l=0; l<L; l++){ | |
| const gw = this._F32buf(S[l+1]*S[l], RW); | |
| const gb = this._F32buf(S[l+1], RW); | |
| this.bufGW.push(gw); this.bufGB.push(gb); | |
| this.rbGW.push(dev.createBuffer({size: S[l+1]*S[l]*4, usage: RDS})); | |
| this.rbGB.push(dev.createBuffer({size: S[l+1]*4, usage: RDS})); | |
| } | |
| // Per layer-transition bind groups for grad: | |
| // Pre layer = state[l-1] (sizes[l]), Post layer = state[l] (sizes[l+1]) | |
| // For l=0 (input transition): Pre = bufWin, Post = state[0] (sizes[1]) | |
| // For l>0: Pre = state[l-1] (sizes[l]), Post = state[l] (sizes[l+1]) | |
| this.bgG = []; | |
| for(let l=0; l<L; l++){ | |
| const UpreP = (l===0) ? this.bufWin : this.bufU[PHASE_P][l-1]; | |
| const UpreM = (l===0) ? this.bufWin : this.bufU[PHASE_M][l-1]; | |
| const UpostP = this.bufU[PHASE_P][l]; | |
| const UpostM = this.bufU[PHASE_M][l]; | |
| this.bgG.push(dev.createBindGroup({layout: this.bglG, entries:[ | |
| {binding:0, resource:{buffer:this.bufP_grad[l]}}, | |
| {binding:1, resource:{buffer:UpreP}}, | |
| {binding:2, resource:{buffer:UpreM}}, | |
| {binding:3, resource:{buffer:UpostP}}, | |
| {binding:4, resource:{buffer:UpostM}}, | |
| {binding:5, resource:{buffer:this.bufR}}, | |
| {binding:6, resource:{buffer:this.bufGW[l]}}, | |
| {binding:7, resource:{buffer:this.bufGB[l]}}, | |
| ]})); | |
| } | |
| // Reward bind group (uses output layer free-phase state) | |
| const Uo_free = this.bufU[PHASE_F][L-1]; | |
| this.bgRew = dev.createBindGroup({layout: this.bglA, entries:[ | |
| {binding:0, resource:{buffer:this.bufP_rew}}, | |
| {binding:1, resource:{buffer:Uo_free}}, | |
| {binding:2, resource:{buffer:this.bufTgt}}, | |
| {binding:3, resource:{buffer:this.bufR}}, | |
| {binding:4, resource:{buffer:this.bufDummyR}}, | |
| {binding:5, resource:{buffer:this.bufDummyRW1}}, | |
| {binding:6, resource:{buffer:this.bufDummyRW2}}, | |
| ]}); | |
| // Adaptation bind groups (per layer): adapt Up,Um toward Uf | |
| this.bgAdapt = []; | |
| for(let l=1; l<=L; l++){ | |
| this.bgAdapt.push(dev.createBindGroup({layout: this.bglA, entries:[ | |
| {binding:0, resource:{buffer:this.bufP_adapt[l-1]}}, | |
| {binding:1, resource:{buffer:this.bufDummyR}}, | |
| {binding:2, resource:{buffer:this.bufDummyR}}, | |
| {binding:3, resource:{buffer:this.bufDummyRW3}}, | |
| {binding:4, resource:{buffer:this.bufU[PHASE_F][l-1]}}, | |
| {binding:5, resource:{buffer:this.bufU[PHASE_P][l-1]}}, | |
| {binding:6, resource:{buffer:this.bufU[PHASE_M][l-1]}}, | |
| ]})); | |
| } | |
| } | |
| _writeRelaxParams(buf, {ni, no, nxt, B, dt, beta, gamma, mode, has_topdown, has_target, noise_scale=0, iter_seed=0, clamp_lo=0, clamp_hi=0}){ | |
| const buf32 = new ArrayBuffer(64); | |
| const u32 = new Uint32Array(buf32); const f32 = new Float32Array(buf32); | |
| u32[0]=ni; u32[1]=no; u32[2]=nxt; u32[3]=B; | |
| f32[4]=dt; f32[5]=beta; f32[6]=gamma; f32[7]=mode; | |
| u32[8]=has_topdown; u32[9]=has_target; | |
| f32[10]=noise_scale; | |
| u32[11]=iter_seed; | |
| f32[12]=clamp_lo; f32[13]=clamp_hi; f32[14]=0; f32[15]=0; | |
| this.dev.queue.writeBuffer(buf, 0, buf32); | |
| } | |
| _writeGradParams(buf, {ni, no, B, two_beta, mode_pre, mode_post}){ | |
| const buf32 = new ArrayBuffer(32); | |
| const u32 = new Uint32Array(buf32); const f32 = new Float32Array(buf32); | |
| u32[0]=ni; u32[1]=no; u32[2]=0; u32[3]=B; | |
| f32[4]=0; f32[5]=two_beta; f32[6]=mode_pre; f32[7]=mode_post; | |
| this.dev.queue.writeBuffer(buf, 0, buf32); | |
| } | |
| _writeAuxParams(buf, {B, O, c, mode}){ | |
| const buf32 = new ArrayBuffer(32); | |
| const u32 = new Uint32Array(buf32); const f32 = new Float32Array(buf32); | |
| u32[0]=B; u32[1]=O; u32[2]=0; u32[3]=0; | |
| f32[4]=c; f32[5]=mode; f32[6]=0; f32[7]=0; | |
| this.dev.queue.writeBuffer(buf, 0, buf32); | |
| } | |
| uploadWeights(W, b){ | |
| const q = this.dev.queue; | |
| for(let l=0; l<this.L; l++){ | |
| q.writeBuffer(this.bufW[l], 0, W[l].buffer, W[l].byteOffset, W[l].byteLength); | |
| q.writeBuffer(this.bufB[l], 0, b[l].buffer, b[l].byteOffset, b[l].byteLength); | |
| } | |
| } | |
| // HPSN: set per-neuron time constants. If tauMax > tauMin > 0, samples Uniform[tauMin, tauMax]. | |
| // Otherwise fills with constant scalarTau (backward-compat with old fixed-dt EqProp). | |
| setTau(layerIdx, scalarTau, tauMin=0, tauMax=0, seed=42){ | |
| const no = this.sizes[layerIdx+1]; | |
| const arr = new Float32Array(no); | |
| if(tauMax > tauMin && tauMin > 0){ | |
| // Deterministic LCG for reproducible per-neuron tau distribution. | |
| let s = (seed>>>0) || 1; | |
| const rng = ()=>{ s = (Math.imul(s, 1664525) + 1013904223) >>> 0; return s/4294967296; }; | |
| for(let i=0;i<no;i++) arr[i] = tauMin + rng()*(tauMax - tauMin); | |
| } else { | |
| arr.fill(scalarTau); | |
| } | |
| this.dev.queue.writeBuffer(this.bufTau[layerIdx], 0, arr.buffer, arr.byteOffset, arr.byteLength); | |
| return arr; // return so caller can inspect distribution | |
| } | |
| // Convenience: set ALL layers to the same (scalar or distribution) tau spec. | |
| setAllTau(scalarTau, tauMin=0, tauMax=0, seed=42){ | |
| for(let l=0; l<this.L; l++) this.setTau(l, scalarTau, tauMin, tauMax, seed + l*1000); | |
| } | |
| uploadInputs(X, T){ | |
| const q = this.dev.queue; | |
| q.writeBuffer(this.bufWin, 0, X.buffer, X.byteOffset, X.byteLength); | |
| q.writeBuffer(this.bufTgt, 0, T.buffer, T.byteOffset, T.byteLength); | |
| } | |
| _initAllPhases(enc){ | |
| const L = this.L; | |
| const MAX_WG_X = 65535; | |
| for(let phase=0; phase<3; phase++){ | |
| for(let l=1; l<=L; l++){ | |
| const n = this.B * this.sizes[l]; | |
| const wgTotal = Math.ceil(n/64); | |
| // 2D dispatch when wgTotal exceeds per-dim limit | |
| const wgX = Math.min(wgTotal, MAX_WG_X); | |
| const wgY = Math.ceil(wgTotal / MAX_WG_X); | |
| const pass = enc.beginComputePass(); | |
| pass.setPipeline(this.pipeInit); | |
| pass.setBindGroup(0, this.bgInit[phase][l-1]); | |
| pass.dispatchWorkgroups(wgX, wgY); | |
| pass.end(); | |
| } | |
| } | |
| } | |
// Each layer update is recorded in its own compute pass, bottom-up, so within one sweep
// layer l reads the state that layer l-1 has just written (a Gauss-Seidel-like ordering).
// NOTE: WebGPU already orders dispatches inside a single pass (each dispatch is its own
// usage scope), so one pass per layer is a conservative choice rather than a requirement.
| _runPhaseRelax(enc, phase, iters){ | |
| const L = this.L, B = this.B; | |
| for(let t=0; t<iters; t++){ | |
| for(let l=1; l<=L; l++){ | |
| const pass = enc.beginComputePass(); | |
| pass.setPipeline(this.pipeRelax); | |
| pass.setBindGroup(0, this.bgR[phase][l-1]); | |
| const no = this.sizes[l]; | |
| pass.dispatchWorkgroups(Math.ceil(no/8), Math.ceil(B/8)); | |
| pass.end(); | |
| } | |
| } | |
| } | |
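// Write all uniform buffers for one pass: relax uniforms per (phase, layer) plus init uniforms.
// noiseScale > 0 enables sEqProp; seedBase is combined with per-phase/per-layer offsets to form
// iter_seed, so passing a different seedBase per call resamples the noise. Values stored by
// setSEqPropNoise take precedence over both arguments.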
| _writeAllUniformsForPass(dt, beta, noiseScale=0, seedBase=0){ | |
| const S=this.sizes, L=this.L, B=this.B, gam=this.gamma, mf=this.modeFlag; | |
| const phaseBetas = [0, +beta, -beta]; // free, plus, minus | |
| const ns = (typeof this.noiseScale === 'number') ? this.noiseScale : noiseScale; | |
| const sb = (typeof this.noiseSeedBase === 'number') ? this.noiseSeedBase : seedBase; | |
| const cLo = this.driveClampLo || 0; | |
| const cHi = this.driveClampHi || 0; | |
| // Relax uniforms (per phase, per layer). iter_seed is incremented per call below. | |
| for(let phase=0; phase<3; phase++){ | |
| for(let l=1; l<=L; l++){ | |
| const isOut = (l === L); | |
| const isHid = !isOut; | |
| const ni = S[l-1], no = S[l], nxt = isHid ? S[l+1] : 0; | |
| const phaseBeta = (isOut) ? phaseBetas[phase] : 0; | |
| this._writeRelaxParams(this.bufP_relax[phase][l-1], { | |
| ni, no, nxt, B, dt, beta: phaseBeta, gamma: gam, mode: mf, | |
| has_topdown: isHid ? 1 : 0, has_target: isOut ? 1 : 0, | |
| noise_scale: ns, | |
| iter_seed: (sb + phase * 7919 + (l-1) * 1009) >>> 0, | |
| clamp_lo: cLo, clamp_hi: cHi, | |
| }); | |
| } | |
| } | |
| for(let l=1; l<=L; l++){ | |
| this._writeRelaxParams(this.bufP_init[l-1], { | |
| ni: S[l-1], no: S[l], nxt: 0, B, dt, beta: 0, gamma: gam, mode: mf, has_topdown: 0, has_target: 0, | |
| noise_scale: 0, iter_seed: 0, // init kernel doesn't use noise | |
| clamp_lo: 0, clamp_hi: 0, // init kernel doesn't run drive — clamp irrelevant | |
| }); | |
| } | |
| } | |
| // Tier A — runtime setter for drive clamp. Pass (0,0) to disable. | |
| setDriveClamp(lo, hi){ | |
| this.driveClampLo = lo; | |
| this.driveClampHi = hi; | |
| } | |
| // sEqProp: set per-pass noise scale and seed base. Call before runFreeAndReadOutputs / runOnePass. | |
| setSEqPropNoise(noiseScale, seedBase){ | |
| this.noiseScale = noiseScale; | |
| this.noiseSeedBase = (seedBase >>> 0) || 0; | |
| } | |
| _runReward(enc){ | |
| this._writeAuxParams(this.bufP_rew, {B: this.B, O: this.O, c: 0, mode: this.modeFlag}); | |
| const pass = enc.beginComputePass(); | |
| pass.setPipeline(this.pipeReward); | |
| pass.setBindGroup(0, this.bgRew); | |
| pass.dispatchWorkgroups(Math.ceil(this.B/64)); | |
| pass.end(); | |
| } | |
| _runAdaptation(enc, adpC, adpSteps){ | |
| if(adpSteps <= 0 || this.mode === 'fhn') return; // skip adaptation in FHN mode | |
| const L = this.L; | |
| for(let l=1; l<=L; l++){ | |
| this._writeAuxParams(this.bufP_adapt[l-1], {B: this.B, O: this.O, c: adpC, mode: this.modeFlag}); | |
| } | |
| const MAX_WG_X = 65535; | |
| for(let a=0; a<adpSteps; a++){ | |
| for(let l=1; l<=L; l++){ | |
| const n = this.B * this.sizes[l]; | |
| const wgTotal = Math.ceil(n/64); | |
| const wgX = Math.min(wgTotal, MAX_WG_X); | |
| const wgY = Math.ceil(wgTotal / MAX_WG_X); | |
| const pass = enc.beginComputePass(); | |
| pass.setPipeline(this.pipeAdapt); | |
| pass.setBindGroup(0, this.bgAdapt[l-1]); | |
| pass.dispatchWorkgroups(wgX, wgY); | |
| pass.end(); | |
| } | |
| } | |
| } | |
| _runGrad(enc, beta){ | |
| const L = this.L, B = this.B; | |
| // Grad uniforms per layer transition | |
| for(let l=0; l<L; l++){ | |
| const ni = this.sizes[l], no = this.sizes[l+1]; | |
| // Determine ρ modes: | |
| // FHN: both pre/post are clip (1). | |
| // Adaptive: σ for both EXCEPT input layer (l=0) where Win is treated with σ (mode_pre=0 always since adaptive). | |
| const mode_pre = this.modeFlag; // 0 for adaptive, 1 for fhn — applies to all states | |
| const mode_post = this.modeFlag; | |
| this._writeGradParams(this.bufP_grad[l], {ni, no, B, two_beta: 2*beta, mode_pre, mode_post}); | |
| } | |
| for(let l=0; l<L; l++){ | |
| const ni = this.sizes[l], no = this.sizes[l+1]; | |
| const pass = enc.beginComputePass(); | |
| pass.setPipeline(this.pipeGW); | |
| pass.setBindGroup(0, this.bgG[l]); | |
| pass.dispatchWorkgroups(Math.ceil(ni/8), Math.ceil(no/8)); | |
| pass.setPipeline(this.pipeGB); | |
| pass.setBindGroup(0, this.bgG[l]); | |
| pass.dispatchWorkgroups(Math.ceil(no/64)); | |
| pass.end(); | |
| } | |
| } | |
| async runFreeAndReadOutputs(iters, dt){ | |
| if(!this.useHPSN){ | |
| if(this._lastTauDt !== dt){ this.setAllTau(dt); this._lastTauDt = dt; } | |
| } | |
| this._writeAllUniformsForPass(dt, 0); // beta=0 → all phases free | |
| const enc = this.dev.createCommandEncoder(); | |
| this._initAllPhases(enc); | |
| this._runPhaseRelax(enc, PHASE_F, iters); | |
| const O = this.O; | |
| enc.copyBufferToBuffer(this.bufU[PHASE_F][this.L-1], 0, this.rbUoF, 0, this.B*O*4); | |
| this.dev.queue.submit([enc.finish()]); | |
| await this.rbUoF.mapAsync(GPUMapMode.READ); | |
| const r = new Float32Array(this.rbUoF.getMappedRange().slice(0)); | |
| this.rbUoF.unmap(); | |
| return r; | |
| } | |
| // Free-phase relax, then read back the activations of an arbitrary internal layer (l in [1..L]). | |
| async runFreeAndReadLayer(iters, dt, layerIdx){ | |
| if(!this.useHPSN){ | |
| if(this._lastTauDt !== dt){ this.setAllTau(dt); this._lastTauDt = dt; } | |
| } | |
| if(layerIdx < 1 || layerIdx > this.L) throw new Error('layerIdx out of range'); | |
| this._writeAllUniformsForPass(dt, 0); | |
| const enc = this.dev.createCommandEncoder(); | |
| this._initAllPhases(enc); | |
| this._runPhaseRelax(enc, PHASE_F, iters); | |
| const size = this.B * this.sizes[layerIdx] * 4; | |
| const rb = this.dev.createBuffer({size, usage: GPUBufferUsage.COPY_DST|GPUBufferUsage.MAP_READ}); | |
| enc.copyBufferToBuffer(this.bufU[PHASE_F][layerIdx-1], 0, rb, 0, size); | |
| this.dev.queue.submit([enc.finish()]); | |
| await rb.mapAsync(GPUMapMode.READ); | |
| const r = new Float32Array(rb.getMappedRange().slice(0)); | |
| rb.unmap(); rb.destroy?.(); | |
| return r; | |
| } | |
| async runOnePassGetGradients({itF=8, itN=5, dt=0.7, beta=0.5, adpC=0.15, adpSteps=3}={}){ | |
| if(this.mode === 'fhn') adpSteps = 0; | |
| // HPSN backward-compat: when not using heterogeneous-τ, refresh Tau to match runtime dt. | |
| // When useHPSN=true, the user-set heterogeneous distribution is preserved (Tau not overwritten). | |
| if(!this.useHPSN){ | |
| if(this._lastTauDt !== dt){ this.setAllTau(dt); this._lastTauDt = dt; } | |
| } | |
| this._writeAllUniformsForPass(dt, beta); | |
| const enc = this.dev.createCommandEncoder(); | |
| this._initAllPhases(enc); | |
| this._runPhaseRelax(enc, PHASE_F, itF); | |
| this._runPhaseRelax(enc, PHASE_P, itN); | |
| this._runPhaseRelax(enc, PHASE_M, itN); | |
| this._runReward(enc); | |
| this._runAdaptation(enc, adpC, adpSteps); | |
| this._runGrad(enc, beta); | |
| // Readback all gradients (separate buffers per layer) + Uo_free | |
| for(let l=0; l<this.L; l++){ | |
| enc.copyBufferToBuffer(this.bufGW[l], 0, this.rbGW[l], 0, this.sizes[l+1]*this.sizes[l]*4); | |
| enc.copyBufferToBuffer(this.bufGB[l], 0, this.rbGB[l], 0, this.sizes[l+1]*4); | |
| } | |
| enc.copyBufferToBuffer(this.bufU[PHASE_F][this.L-1], 0, this.rbUoF, 0, this.B*this.O*4); | |
| this.dev.queue.submit([enc.finish()]); | |
| const maps = [this.rbUoF.mapAsync(GPUMapMode.READ)]; | |
| for(let l=0; l<this.L; l++){ maps.push(this.rbGW[l].mapAsync(GPUMapMode.READ)); maps.push(this.rbGB[l].mapAsync(GPUMapMode.READ)); } | |
| await Promise.all(maps); | |
| const uoF = new Float32Array(this.rbUoF.getMappedRange().slice(0)); | |
| this.rbUoF.unmap(); | |
| const gW = [], gB = []; | |
| for(let l=0; l<this.L; l++){ | |
| gW.push(new Float32Array(this.rbGW[l].getMappedRange().slice(0))); | |
| gB.push(new Float32Array(this.rbGB[l].getMappedRange().slice(0))); | |
| this.rbGW[l].unmap(); this.rbGB[l].unmap(); | |
| } | |
| return {gW, gB, uoF}; | |
| } | |
| destroy(){ | |
| const bufs = [this.bufWin, this.bufTgt, this.bufR, this.bufDummyR, this.bufDummyRW1, this.bufDummyRW2, this.bufDummyRW3, this.bufG, this.rbG, this.rbUoF, this.bufP_rew]; | |
| for(const arr of [this.bufW, this.bufB, this.bufGW, this.bufGB, this.rbGW, this.rbGB, this.bufP_init, this.bufP_grad, this.bufP_adapt, this.bufTau]) bufs.push(...arr); | |
| for(const ph of this.bufU) bufs.push(...ph); | |
| for(const ph of this.bufP_relax) bufs.push(...ph); | |
| for(const v of bufs) if(v && v.destroy) try{ v.destroy(); }catch(e){} | |
| } | |
| } | |
// Multi-layer AdaGO optimizer.
// Per layer it keeps a momentum buffer (mu = 0.9) for W and B. The weight step direction is
// orthCPU(momentum), recomputed when none is cached or when bc % OW_K === 0, and reused otherwise.
// The scalar step size is max(eps, lr·min(‖g‖, γ) / (√v + eps)), where v accumulates min(‖g‖², γ²).
// EqProp gradients are ascent directions, so updates are added (+=).
export class AdaGODeep {
  /**
   * @param {number[]} sizes  layer widths; transition l maps sizes[l] → sizes[l+1]
   * @param {{OW_K?: number}} [opts]  OW_K — refresh period (in batches) of the cached orthogonal direction
   */
  constructor(sizes, {OW_K=8}={}){
    this.sizes = sizes;
    this.L = sizes.length - 1;
    const L = this.L;
    this.MW = Array.from({length: L}, (_, l) => new Float64Array(sizes[l + 1] * sizes[l]));
    this.MB = Array.from({length: L}, (_, l) => new Float64Array(sizes[l + 1]));
    this.vW = new Float64Array(L);
    this.vB = new Float64Array(L);
    this.OW = new Array(L).fill(null);
    this.OW_K = OW_K;
    this.bc = 0; // batch counter, advanced by endBatch()
  }
  /**
   * Updates layer l in place. gW / gB are divided by bs before use.
   * NOTE(review): the bias direction divides the momentum by the CURRENT bias-gradient norm
   * (not the momentum norm); confirm that this is intended.
   */
  step(l, W, B, gW, gB, bs, lr){
    const EPS = 1e-8, MU = 0.9, CLIP = 1.0;
    const fanIn = this.sizes[l], fanOut = this.sizes[l + 1];
    const mW = this.MW[l], mB = this.MB[l];

    // Weights: momentum update while accumulating the squared gradient norm.
    let gSq = 0;
    for(let k = 0; k < W.length; k++){
      const g = gW[k] / bs;
      mW[k] = MU * mW[k] + (1 - MU) * g;
      gSq += g * g;
    }
    const gNorm = Math.sqrt(gSq);
    this.vW[l] += Math.min(gSq, CLIP * CLIP);
    if(!this.OW[l] || this.bc % this.OW_K === 0) this.OW[l] = orthCPU(mW, fanOut, fanIn);
    const dirW = this.OW[l];
    const stepW = Math.max(EPS, lr * Math.min(gNorm, CLIP) / (Math.sqrt(this.vW[l]) + EPS));
    for(let k = 0; k < W.length; k++) W[k] += stepW * dirW[k];

    // Biases: same step-size recipe; direction = momentum / ‖bias gradient‖ (0 if the norm is 0).
    let bSq = 0;
    for(let k = 0; k < B.length; k++){
      const g = gB[k] / bs;
      mB[k] = MU * mB[k] + (1 - MU) * g;
      bSq += g * g;
    }
    const bNorm = Math.sqrt(bSq);
    this.vB[l] += Math.min(bSq, CLIP * CLIP);
    const stepB = Math.max(EPS, lr * Math.min(bNorm, CLIP) / (Math.sqrt(this.vB[l]) + EPS));
    for(let k = 0; k < B.length; k++) B[k] += stepB * (bNorm > 0 ? mB[k] / bNorm : 0);
  }
  endBatch(){ this.bc += 1; }
}
// Adam with optional decoupled weight decay (AdamW). EqProp yields an ascent direction, so the
// update is added (+=).
// Bias correction uses a per-layer update count, this.tL[l]. step(l, …) is invoked once per
// layer per batch. The shared counter this.t therefore advances L times per batch, and using it
// for correction gave every layer after the first too large a timestep, shrinking its early
// updates. this.t is still maintained (total step() calls) for backward compatibility.
export class Adam {
  /**
   * @param {number[]} sizes  layer widths; transition l maps sizes[l] → sizes[l+1]
   * @param {object} [opts]
   * @param {number} [opts.beta1=0.9]      first-moment decay
   * @param {number} [opts.beta2=0.999]    second-moment decay
   * @param {number} [opts.eps=1e-8]       denominator guard
   * @param {number} [opts.weightDecay=0]  AdamW-style decoupled decay, applied to W only (not biases)
   */
  constructor(sizes, {beta1=0.9, beta2=0.999, eps=1e-8, weightDecay=0}={}){
    this.sizes=sizes; this.L=sizes.length-1; this.beta1=beta1; this.beta2=beta2; this.eps=eps;
    this.wd = weightDecay;
    this.mW=[]; this.vW=[]; this.mB=[]; this.vB=[];
    this.t=0;                                        // total step() calls across all layers
    this.tL=new Array(Math.max(0, this.L)).fill(0);  // per-layer update count (bias correction)
    for(let l=0; l<this.L; l++){
      this.mW.push(new Float64Array(sizes[l+1]*sizes[l]));
      this.vW.push(new Float64Array(sizes[l+1]*sizes[l]));
      this.mB.push(new Float64Array(sizes[l+1]));
      this.vB.push(new Float64Array(sizes[l+1]));
    }
  }
  /**
   * One Adam(W) update of layer l, in place.
   * @param {number} l   transition index (0..L-1)
   * @param W, B         weights (sizes[l+1]*sizes[l]) and biases (sizes[l+1]), modified in place
   * @param gW, gB       batch gradients; each entry is divided by bs before use
   * @param {number} bs  batch size
   * @param {number} lr  learning rate
   */
  step(l, W, B, gW, gB, bs, lr){
    this.t++;
    const t = ++this.tL[l];
    const b1=this.beta1, b2=this.beta2, eps=this.eps;
    const bc1 = 1 - Math.pow(b1, t), bc2 = 1 - Math.pow(b2, t);
    const wd = this.wd;
    for(let k=0;k<W.length;k++){
      const g = gW[k]/bs;
      this.mW[l][k] = b1*this.mW[l][k] + (1-b1)*g;
      this.vW[l][k] = b2*this.vW[l][k] + (1-b2)*g*g;
      const m_hat = this.mW[l][k]/bc1, v_hat = this.vW[l][k]/bc2;
      // AdamW: decoupled decay (W *= 1 - lr*wd); EqProp ascent => + m̂/√v̂
      if(wd > 0) W[k] *= (1 - lr * wd);
      W[k] += lr * m_hat / (Math.sqrt(v_hat) + eps);
    }
    for(let k=0;k<B.length;k++){
      const g = gB[k]/bs;
      this.mB[l][k] = b1*this.mB[l][k] + (1-b1)*g;
      this.vB[l][k] = b2*this.vB[l][k] + (1-b2)*g*g;
      const m_hat = this.mB[l][k]/bc1, v_hat = this.vB[l][k]/bc2;
      B[k] += lr * m_hat / (Math.sqrt(v_hat) + eps);
    }
  }
  endBatch(){}
}
// Muon helper (Keller Jordan's MomentUm Orthogonalized by Newton–Schulz).
// muonOrth(M_in, m, n, iters) approximately orthogonalizes the row-major m×n matrix M_in. It
// drives its singular values toward ~1 (with the controlled overshoot of Muon's quintic) and
// keeps the singular vectors:
//   1. Scale by the Frobenius norm, an upper bound on the spectral norm, so all singular values are ≤ 1.
//   2. Work in the tall orientation (rows ≥ cols) so the Gram matrix XᵀX is the smaller one.
//   3. Apply `iters` rounds of X ← aX + bX(XᵀX) + cX(XᵀX)², which equals aX + b(XXᵀ)X + c(XXᵀ)²X,
//      using the coefficients from the original Muon.
// Returns a new Float32Array of length m*n (row-major m×n). M_in is not modified.
function muonOrth(M_in, m, n, iters=5){
  const QA = 3.4445, QB = -4.7750, QC = 2.0315;
  // Row-major matmul A (p×q) · Bm (q×r) → (p×r). Skips falsy (zero / NaN) entries of A.
  const matmul = (A, Bm, p, q, r) => {
    const out = new Float64Array(p * r);
    for(let i = 0; i < p; i++){
      for(let k = 0; k < q; k++){
        const aik = A[i * q + k];
        if(!aik) continue;
        for(let j = 0; j < r; j++) out[i * r + j] += aik * Bm[k * r + j];
      }
    }
    return out;
  };
  // Row-major transpose of a p×q matrix → q×p.
  const transpose = (A, p, q) => {
    const T = new Float64Array(p * q);
    for(let i = 0; i < p; i++) for(let j = 0; j < q; j++) T[j * p + i] = A[i * q + j];
    return T;
  };

  // 1. Copy into float64 and normalize (tiny epsilon avoids 0/0 on an all-zero input).
  let X = Float64Array.from(M_in);
  let sq = 0;
  for(const x of X) sq += x * x;
  const nrm = Math.sqrt(sq) + 1e-30;
  for(let k = 0; k < X.length; k++) X[k] /= nrm;

  // 2. Tall orientation: X is R×C with R ≥ C.
  const transposed = n > m;
  const R = transposed ? n : m, C = transposed ? m : n;
  if(transposed) X = transpose(X, m, n);

  // 3. Quintic Newton–Schulz iterations.
  for(let it = 0; it < iters; it++){
    const G = matmul(transpose(X, R, C), X, C, R, C); // XᵀX      (C×C)
    const G2 = matmul(G, G, C, C, C);                 // (XᵀX)²   (C×C)
    const XG = matmul(X, G, R, C, C);                 // X(XᵀX)   (R×C)
    const XG2 = matmul(X, G2, R, C, C);               // X(XᵀX)²  (R×C)
    const Y = new Float64Array(R * C);
    for(let k = 0; k < R * C; k++) Y[k] = QA * X[k] + QB * XG[k] + QC * XG2[k];
    X = Y;
  }

  // Undo step 2: a transposed X is n×m, and transposing it back yields the m×n result.
  return new Float32Array(transposed ? transpose(X, R, C) : X);
}
// Muon optimizer.
// W keeps a heavy-ball momentum buffer (M ← beta·M + g, g = gW/bs), which is orthogonalized with
// muonOrth and applied with a shape-aware scale lr·√(max(no,ni)/min(no,ni)).
// Optional Nesterov look-ahead (opts.nesterov, off by default) orthogonalizes g + beta·M instead
// of M, as in the original Muon reference code.
// Biases use plain SGD-momentum (the Muon spec routes them to an Adam-like optimizer; kept simple here).
// EqProp gives an ascent direction, so steps are added (+=).
export class Muon {
  /**
   * @param {number[]} sizes  layer widths; transition l maps sizes[l] → sizes[l+1]
   * @param {object} [opts]
   * @param {number} [opts.beta=0.95]        momentum coefficient
   * @param {number} [opts.weightDecay=0]    decoupled decay on W only: W ← W·(1 − lr·wd)
   * @param {number} [opts.iters=5]          Newton–Schulz iterations passed to muonOrth
   * @param {boolean} [opts.nesterov=false]  orthogonalize g + beta·M (Nesterov) instead of M
   */
  constructor(sizes, {beta=0.95, weightDecay=0, iters=5, nesterov=false}={}){
    this.sizes=sizes; this.L=sizes.length-1; this.beta=beta; this.wd=weightDecay; this.iters=iters;
    this.nesterov=!!nesterov;
    this.MW=[]; this.mB=[];
    for(let l=0; l<this.L; l++){
      this.MW.push(new Float64Array(sizes[l+1]*sizes[l]));
      this.mB.push(new Float64Array(sizes[l+1]));
    }
  }
  step(l, W, B, gW, gB, bs, lr){
    const beta=this.beta, wd=this.wd;
    const no = this.sizes[l+1], ni = this.sizes[l];
    const M = this.MW[l];
    // Heavy-ball momentum (not Nesterov unless opts.nesterov is set).
    for(let k=0;k<W.length;k++) M[k] = beta*M[k] + gW[k]/bs;
    let dir = M;
    if(this.nesterov){
      dir = new Float64Array(M.length);
      for(let k=0;k<W.length;k++) dir[k] = gW[k]/bs + beta*M[k];
    }
    // Orthogonalize via the quintic Newton–Schulz iteration.
    const O = muonOrth(dir, no, ni, this.iters);
    const scale = lr * Math.sqrt(Math.max(no, ni) / Math.min(no, ni));
    for(let k=0;k<W.length;k++){
      if(wd>0) W[k] *= (1 - lr*wd);
      W[k] += scale * O[k];
    }
    for(let k=0;k<B.length;k++){
      const g = gB[k]/bs;
      this.mB[l][k] = beta*this.mB[l][k] + g;
      B[k] += lr * this.mB[l][k];
    }
  }
  endBatch(){}
}
// Lion optimizer (sign of interpolated momentum). It needs less memory than Adam (one buffer per
// parameter).
// Update: u = sign(b1·m + (1−b1)·g), applied as W += lr·u (EqProp ascent). Afterwards m is updated
// as m ← b2·m + (1−b2)·g. Decoupled weight decay applies to W only.
// Math.sign is used so that an exactly-zero interpolation yields no step (sign(0) = 0).
// Previously `c >= 0 ? 1 : -1` pushed zero-gradient parameters by +lr on every step. NaN
// gradients now propagate instead of being silently mapped to −1.
export class Lion {
  constructor(sizes, {beta1=0.9, beta2=0.99, weightDecay=0}={}){
    this.sizes=sizes; this.L=sizes.length-1; this.beta1=beta1; this.beta2=beta2; this.wd=weightDecay;
    this.mW=[]; this.mB=[];
    for(let l=0; l<this.L; l++){
      this.mW.push(new Float64Array(sizes[l+1]*sizes[l]));
      this.mB.push(new Float64Array(sizes[l+1]));
    }
  }
  step(l, W, B, gW, gB, bs, lr){
    const b1=this.beta1, b2=this.beta2, wd=this.wd;
    for(let k=0;k<W.length;k++){
      const g = gW[k]/bs;
      const u = Math.sign(b1*this.mW[l][k] + (1-b1)*g);
      if(wd>0) W[k] *= (1 - lr*wd);
      W[k] += lr * u;
      this.mW[l][k] = b2*this.mW[l][k] + (1-b2)*g;
    }
    for(let k=0;k<B.length;k++){
      const g = gB[k]/bs;
      const u = Math.sign(b1*this.mB[l][k] + (1-b1)*g);
      B[k] += lr * u;
      this.mB[l][k] = b2*this.mB[l][k] + (1-b2)*g;
    }
  }
  endBatch(){}
}