cortex-conv / tests /gpu_lib_conv_full.js
Ex0bit's picture
initial deployment: cortex-conv ships pre-trained at 96.8% MNIST
d035fbd
// Full convolutional EqProp trainer (WebGPU).
// Architecture: one conv layer (weight-shared kernel) + one dense readout to O classes.
// All EqProp, no backprop. Validated vs ConvEqPropNet CPU reference.
import { orth as orthCPU } from './eqprop_lib.js';
// WGSL for the conv hidden layer. Bind group 0: uniform CP (80 bytes), 1 Xin, 2 Wt (conv kernel),
// 3 Bs (conv bias), 4 Wnxt (next-layer weights), 5 Uh (conv state, read_write), 6 Uo (next-layer state).
//   conv_pass  — one invocation per (xo, yo, b*Cout + k) = (gid.x, gid.y, gid.z); workgroup 8x8x1.
//                c = Bs[k] + Σ_{kin,dy,dx} Wt·ρ(Xin) over the receptive field; taps that fall outside
//                the image (padding) are skipped. When has_topdown != 0 it adds
//                gamma·Σ_n Wnxt[n·map_size + this]·ρ(Uo[b·nxt_size + n]). c is clamped to [-2.5, 3.5],
//                then one Euler step is taken: Uh ← Uh + dt·(σ(c) − Uh).
//                CP.beta and CP.has_target are declared but not read by this shader.
//   init_state — writes 0.1 into every Uh element (B·Cout·Hout·Wout of them). The flat index is
//                gid.y·(65535·64) + gid.x, so a launch that spills into y must use exactly 65535
//                workgroups of 64 along x.
// NOTE(review): the LRN loop in conv_pass accumulates sum_sq, but the result is never used (LRN is
// disabled). It also reads neighbour-channel Uh entries that other invocations of the same dispatch
// may already have overwritten, so the in-shader claim that those reads see the previous iteration
// is not guaranteed. This is harmless while sum_sq is unused; revisit it before re-enabling LRN.
const WGSL_CONV_RELAX = `
struct CP {
B: u32, Cin: u32, Cout: u32, H: u32,
W: u32, Hout: u32, Wout: u32, KH: u32,
KW: u32, stride: u32, pad: u32, _p0: u32,
dt: f32, beta: f32, gamma: f32, _p1: f32,
has_topdown: u32, nxt_size: u32, has_target: u32, _p2: u32,
};
@group(0) @binding(0) var<uniform> p : CP;
@group(0) @binding(1) var<storage, read> Xin : array<f32>;
@group(0) @binding(2) var<storage, read> Wt : array<f32>;
@group(0) @binding(3) var<storage, read> Bs : array<f32>;
@group(0) @binding(4) var<storage, read> Wnxt: array<f32>;
@group(0) @binding(5) var<storage, read_write> Uh : array<f32>;
@group(0) @binding(6) var<storage, read> Uo : array<f32>;
fn sg(u: f32) -> f32 { return 1.0 / (1.0 + exp(-4.0 * (u - 0.5))); }
fn rho(u: f32) -> f32 { return sg(u); }
@compute @workgroup_size(8, 8, 1) fn conv_pass(@builtin(global_invocation_id) gid: vec3<u32>) {
let xo = gid.x; let yo = gid.y; let bk = gid.z;
if (xo >= p.Wout || yo >= p.Hout) { return; }
let b = bk / p.Cout; let k = bk % p.Cout;
if (b >= p.B) { return; }
let img_size = p.Cin * p.H * p.W;
let map_size = p.Cout * p.Hout * p.Wout;
var c : f32 = Bs[k];
for (var kin: u32 = 0u; kin < p.Cin; kin = kin + 1u) {
for (var dy: u32 = 0u; dy < p.KH; dy = dy + 1u) {
let iy_s = i32(yo * p.stride + dy) - i32(p.pad);
if (iy_s < 0 || iy_s >= i32(p.H)) { continue; }
let iy = u32(iy_s);
for (var dx: u32 = 0u; dx < p.KW; dx = dx + 1u) {
let ix_s = i32(xo * p.stride + dx) - i32(p.pad);
if (ix_s < 0 || ix_s >= i32(p.W)) { continue; }
let ix = u32(ix_s);
let xidx = b * img_size + kin * p.H * p.W + iy * p.W + ix;
let widx = ((k * p.Cin + kin) * p.KH + dy) * p.KW + dx;
c = c + Wt[widx] * rho(Xin[xidx]);
}
}
}
if (p.has_topdown != 0u) {
var td : f32 = 0.0;
let this_flat = k * p.Hout * p.Wout + yo * p.Wout + xo;
for (var n: u32 = 0u; n < p.nxt_size; n = n + 1u) {
td = td + Wnxt[n * map_size + this_flat] * rho(Uo[b * p.nxt_size + n]);
}
c = c + p.gamma * td;
}
// STABILIZER #4: clip pre-activation drive to a bounded range BEFORE σ.
c = clamp(c, -2.5, 3.5);
// SYNTHESIS #3: LOCAL RESPONSE NORMALIZATION (LRN) — V1 cross-channel lateral inhibition.
// For each (b, k, y, x), divisively normalize σ(c) by the L2 norm of neighbor-channel activities at
// the same spatial position. This is biologically faithful (V1 surround inhibition) and breaks the
// saturation positive-feedback loop by ensuring no single channel can dominate.
// normalized = σ(c) / (k₀ + α · Σ_{k' ∈ [k-n, k+n]} Uh[b,k',y,x]²)^β
// We read Uh from PREVIOUS iter (current-iter writes haven't happened yet within this pass).
// Applied identically in all 3 phases (free/+β/−β) so the EqProp gradient remains consistent
// for the LRN-modified energy.
let chan_size = p.Hout * p.Wout;
let spatial_off = b * map_size + yo * p.Wout + xo;
var sum_sq : f32 = 0.0;
let n_half : i32 = 2;
let kk_min : i32 = max(0, i32(k) - n_half);
let kk_max : i32 = min(i32(p.Cout) - 1, i32(k) + n_half);
for (var kk: i32 = kk_min; kk <= kk_max; kk = kk + 1) {
let v = Uh[spatial_off + u32(kk) * chan_size];
sum_sq = sum_sq + v * v;
}
// LRN DISABLED (v13 found it degrades single-conv accuracy; keep code path for re-enable).
let sg_c_norm = sg(c);
let idx = b * map_size + k * p.Hout * p.Wout + yo * p.Wout + xo;
let u_old = Uh[idx];
let drive = -u_old + sg_c_norm;
Uh[idx] = u_old + p.dt * drive;
}
@compute @workgroup_size(64) fn init_state(@builtin(global_invocation_id) gid: vec3<u32>) {
let stride = 65535u * 64u;
let g = gid.y * stride + gid.x;
let n = p.B * p.Cout * p.Hout * p.Wout;
if (g < n) { Uh[g] = 0.1; }
}
`;
// WGSL for the dense readout layer. Bind group 0: uniform DP (48 bytes), 1 Xin (conv hidden state,
// [B*Ni]), 2 Wt ([No*Ni]), 3 Bs ([No]), 4 Tgt ([B*No]), 5 Uo (output state, read_write, [B*No]).
//   dense_pass     — one invocation per (i, b) = (gid.x, gid.y); workgroup 64x1. It computes
//                    c = clamp(Bs[i] + Σ_j Wt[i, j]·ρ(Xin[b, j]), -2.5, 3.5)
//                    drive = −u + σ(c), plus β·(Tgt − u) when has_target != 0 and β != 0
//                    and then steps u ← u + dt·drive (this is one Euler step, not a closed-form u_out).
//   init_state_out — writes 0.1 into every Uo element (B·No of them) using a 1-D launch over gid.x.
const WGSL_DENSE_OUT = `
struct DP {
B: u32, Ni: u32, No: u32, _p0: u32,
dt: f32, beta: f32, _p1: f32, _p2: f32,
has_target: u32, _p3: u32, _p4: u32, _p5: u32,
};
@group(0) @binding(0) var<uniform> p : DP;
@group(0) @binding(1) var<storage, read> Xin : array<f32>; // conv hidden, [B*Ni]
@group(0) @binding(2) var<storage, read> Wt : array<f32>; // [No*Ni]
@group(0) @binding(3) var<storage, read> Bs : array<f32>; // [No]
@group(0) @binding(4) var<storage, read> Tgt : array<f32>; // [B*No]
@group(0) @binding(5) var<storage, read_write> Uo : array<f32>; // [B*No]
fn sg(u: f32) -> f32 { return 1.0 / (1.0 + exp(-4.0 * (u - 0.5))); }
fn rho(u: f32) -> f32 { return sg(u); }
@compute @workgroup_size(64, 1) fn dense_pass(@builtin(global_invocation_id) gid: vec3<u32>) {
let b = gid.y; let i = gid.x;
if (b >= p.B || i >= p.No) { return; }
var c : f32 = Bs[i];
for (var j: u32 = 0u; j < p.Ni; j = j + 1u) {
c = c + Wt[i * p.Ni + j] * rho(Xin[b * p.Ni + j]);
}
// STABILIZER #4 (dense output): clamp pre-activation drive too — same reasoning.
c = clamp(c, -2.5, 3.5);
let idx = b * p.No + i;
let u_old = Uo[idx];
var drive : f32 = -u_old + sg(c);
if (p.has_target != 0u && p.beta != 0.0) {
drive = drive + p.beta * (Tgt[idx] - u_old);
}
Uo[idx] = u_old + p.dt * drive;
}
@compute @workgroup_size(64) fn init_state_out(@builtin(global_invocation_id) gid: vec3<u32>) {
let g = gid.x; let n = p.B * p.No;
if (g < n) { Uo[g] = 0.1; }
}
`;
// WGSL for the per-sample reward and the nudged-state adaptation. Bind group 0: uniform AP (32 bytes),
// 1 UoF (free-phase outputs), 2 Tgt, 3 R (reward, read_write), 4 Uf (free state), 5 Up, 6 Um
// (nudged states, read_write). Each entry point uses only some of these bindings.
//   compute_reward — one invocation per sample b. It computes loss = Σ_i (σ(UoF[b, i]) − Tgt[b, i])²
//                    and R[b] = 0.1 + 0.9·min(loss / 0.4, 1). The loss is a sum of squares, so
//                    R[b] lies in [0.1, 1.0].
//   adapt_layer    — element-wise blend of both nudged states toward the free state:
//                    Up ← (1 − c_adp)·Up + c_adp·Uf, and Um likewise. It is bounded by arrayLength(Uf)
//                    and uses the same flat index gid.y·(65535·64) + gid.x as init_state.
// AP.convFlat and AP.mode are declared but not read by either entry point.
const WGSL_AUX_CONV = `
struct AP {
B: u32, O: u32, convFlat: u32, _p0: u32,
c_adp: f32, mode: f32, _p1: f32, _p2: f32,
};
@group(0) @binding(0) var<uniform> p : AP;
@group(0) @binding(1) var<storage, read> UoF : array<f32>;
@group(0) @binding(2) var<storage, read> Tgt : array<f32>;
@group(0) @binding(3) var<storage, read_write> R : array<f32>;
@group(0) @binding(4) var<storage, read> Uf : array<f32>;
@group(0) @binding(5) var<storage, read_write> Up : array<f32>;
@group(0) @binding(6) var<storage, read_write> Um : array<f32>;
fn sg(u: f32) -> f32 { return 1.0 / (1.0 + exp(-4.0 * (u - 0.5))); }
fn rho_out(u: f32) -> f32 { return sg(u); }
@compute @workgroup_size(64) fn compute_reward(@builtin(global_invocation_id) gid: vec3<u32>) {
let b = gid.x;
if (b >= p.B) { return; }
var loss : f32 = 0.0;
let off = b * p.O;
for (var i: u32 = 0u; i < p.O; i = i + 1u) {
let d = rho_out(UoF[off + i]) - Tgt[off + i];
loss = loss + d * d;
}
var r : f32 = loss / 0.4;
if (r > 1.0) { r = 1.0; }
R[b] = 0.1 + 0.9 * r;
}
@compute @workgroup_size(64) fn adapt_layer(@builtin(global_invocation_id) gid: vec3<u32>) {
let stride = 65535u * 64u;
let g = gid.y * stride + gid.x;
if (g >= arrayLength(&Uf)) { return; }
let f = Uf[g];
Up[g] = (1.0 - p.c_adp) * Up[g] + p.c_adp * f;
Um[g] = (1.0 - p.c_adp) * Um[g] + p.c_adp * f;
}
`;
// WGSL for the conv-layer EqProp gradients. Bind group 0: uniform CGP (64 bytes), 1 Xp / 2 Xm (conv
// input in the +/− phases), 3 Up / 4 Um (conv state in the +/− phases), 5 R (per-sample reward),
// 6 gW and 7 gB (outputs).
//   grad_W_conv — one invocation per kernel tap (dx, dy, kout*Cin + kin) = (gid.x, gid.y, gid.z);
//                 workgroup 8x8x1. It computes
//                 gW = Σ_b R[b]·Σ_{yo,xo} [ρ(Up)·ρ(Xp) − ρ(Um)·ρ(Xm)] / two_beta
//                 and skips out-of-image (padding) taps, using the same rule as the forward conv.
//   grad_B_conv — one invocation per output channel. It computes
//                 gB = Σ_b R[b]·Σ_{yo,xo} [ρ(Up) − ρ(Um)] / two_beta.
// Both entry points overwrite gW/gB rather than accumulating into them. two_beta must be non-zero.
const WGSL_GRAD_CONV = `
struct CGP {
B: u32, Cin: u32, Cout: u32, H: u32,
W: u32, Hout: u32, Wout: u32, KH: u32,
KW: u32, stride: u32, pad: u32, _p0: u32,
two_beta: f32, _p1: f32, _p2: f32, _p3: f32,
};
@group(0) @binding(0) var<uniform> p : CGP;
@group(0) @binding(1) var<storage, read> Xp : array<f32>;
@group(0) @binding(2) var<storage, read> Xm : array<f32>;
@group(0) @binding(3) var<storage, read> Up : array<f32>;
@group(0) @binding(4) var<storage, read> Um : array<f32>;
@group(0) @binding(5) var<storage, read> R : array<f32>;
@group(0) @binding(6) var<storage, read_write> gW : array<f32>;
@group(0) @binding(7) var<storage, read_write> gB : array<f32>;
fn sg(u: f32) -> f32 { return 1.0 / (1.0 + exp(-4.0 * (u - 0.5))); }
fn rho(u: f32) -> f32 { return sg(u); }
@compute @workgroup_size(8, 8, 1) fn grad_W_conv(@builtin(global_invocation_id) gid: vec3<u32>) {
let dx = gid.x; let dy = gid.y; let kk = gid.z;
if (dx >= p.KW || dy >= p.KH) { return; }
let kout = kk / p.Cin; let kin = kk % p.Cin;
if (kout >= p.Cout) { return; }
let img_size = p.Cin * p.H * p.W;
let map_size = p.Cout * p.Hout * p.Wout;
var acc : f32 = 0.0;
for (var b: u32 = 0u; b < p.B; b = b + 1u) {
let rb = R[b];
for (var yo: u32 = 0u; yo < p.Hout; yo = yo + 1u) {
let iy_s = i32(yo * p.stride + dy) - i32(p.pad);
if (iy_s < 0 || iy_s >= i32(p.H)) { continue; }
let iy = u32(iy_s);
for (var xo: u32 = 0u; xo < p.Wout; xo = xo + 1u) {
let ix_s = i32(xo * p.stride + dx) - i32(p.pad);
if (ix_s < 0 || ix_s >= i32(p.W)) { continue; }
let ix = u32(ix_s);
let u_flat = b * map_size + kout * p.Hout * p.Wout + yo * p.Wout + xo;
let x_flat = b * img_size + kin * p.H * p.W + iy * p.W + ix;
acc = acc + rb * (rho(Up[u_flat]) * rho(Xp[x_flat]) - rho(Um[u_flat]) * rho(Xm[x_flat]));
}
}
}
let widx = ((kout * p.Cin + kin) * p.KH + dy) * p.KW + dx;
gW[widx] = acc / p.two_beta;
}
@compute @workgroup_size(64) fn grad_B_conv(@builtin(global_invocation_id) gid: vec3<u32>) {
let kout = gid.x;
if (kout >= p.Cout) { return; }
let map_size = p.Cout * p.Hout * p.Wout;
var acc : f32 = 0.0;
for (var b: u32 = 0u; b < p.B; b = b + 1u) {
let rb = R[b];
for (var yo: u32 = 0u; yo < p.Hout; yo = yo + 1u) {
for (var xo: u32 = 0u; xo < p.Wout; xo = xo + 1u) {
let u_flat = b * map_size + kout * p.Hout * p.Wout + yo * p.Wout + xo;
acc = acc + rb * (rho(Up[u_flat]) - rho(Um[u_flat]));
}
}
}
gB[kout] = acc / p.two_beta;
}
`;
// WGSL for the dense-layer EqProp gradients. Bind group 0: uniform DGP (32 bytes), 1 Xp / 2 Xm
// (pre-layer state in the +/− phases, [B*Ni]), 3 Up / 4 Um (post-layer state in the +/− phases, [B*No]),
// 5 R (per-sample reward), 6 gW ([No*Ni]) and 7 gB ([No]).
//   grad_W_dense — one invocation per (j, i) = (gid.x, gid.y); workgroup 8x8. It computes
//                  gW[i, j] = Σ_b R[b]·[ρ(Up[b, i])·ρ(Xp[b, j]) − ρ(Um[b, i])·ρ(Xm[b, j])] / two_beta.
//   grad_B_dense — one invocation per output i. It computes
//                  gB[i] = Σ_b R[b]·[ρ(Up[b, i]) − ρ(Um[b, i])] / two_beta.
// Both entry points overwrite gW/gB rather than accumulating into them. two_beta must be non-zero.
const WGSL_GRAD_DENSE = `
struct DGP {
B: u32, Ni: u32, No: u32, _p0: u32,
two_beta: f32, _p1: f32, _p2: f32, _p3: f32,
};
@group(0) @binding(0) var<uniform> p : DGP;
@group(0) @binding(1) var<storage, read> Xp : array<f32>; // pre-layer plus phase [B*Ni]
@group(0) @binding(2) var<storage, read> Xm : array<f32>;
@group(0) @binding(3) var<storage, read> Up : array<f32>; // post-layer plus phase [B*No]
@group(0) @binding(4) var<storage, read> Um : array<f32>;
@group(0) @binding(5) var<storage, read> R : array<f32>;
@group(0) @binding(6) var<storage, read_write> gW : array<f32>;
@group(0) @binding(7) var<storage, read_write> gB : array<f32>;
fn sg(u: f32) -> f32 { return 1.0 / (1.0 + exp(-4.0 * (u - 0.5))); }
fn rho(u: f32) -> f32 { return sg(u); }
@compute @workgroup_size(8, 8) fn grad_W_dense(@builtin(global_invocation_id) gid: vec3<u32>) {
let i = gid.y; let j = gid.x;
if (i >= p.No || j >= p.Ni) { return; }
var acc : f32 = 0.0;
for (var b: u32 = 0u; b < p.B; b = b + 1u) {
let rb = R[b];
let ip = rho(Up[b * p.No + i]); let im = rho(Um[b * p.No + i]);
let jp = rho(Xp[b * p.Ni + j]); let jm = rho(Xm[b * p.Ni + j]);
acc = acc + rb * (ip * jp - im * jm);
}
gW[i * p.Ni + j] = acc / p.two_beta;
}
@compute @workgroup_size(64) fn grad_B_dense(@builtin(global_invocation_id) gid: vec3<u32>) {
let i = gid.x;
if (i >= p.No) { return; }
var acc : f32 = 0.0;
for (var b: u32 = 0u; b < p.B; b = b + 1u) {
let rb = R[b];
acc = acc + rb * (rho(Up[b * p.No + i]) - rho(Um[b * p.No + i]));
}
gB[i] = acc / p.two_beta;
}
`;
/**
 * Acquire a WebGPU adapter and device for the conv trainer.
 *
 * The device is requested with the adapter's own values for a fixed set of limits: storage
 * buffers per stage, buffer and binding sizes, workgroup sizes, workgroup storage and bind
 * groups. This lets the trainer use whatever the hardware supports instead of the spec
 * defaults. Any limit the adapter does not report as a number is omitted from the request.
 *
 * @param {{powerPreference?: string}} [options] - `powerPreference` is forwarded to
 *   `navigator.gpu.requestAdapter` (default 'high-performance').
 * @returns {Promise<{adapter: object, dev: object, info: object}>} `info` is `adapter.info`, or `{}`
 *   when the adapter does not expose it.
 * @throws {Error} 'no webgpu' when `navigator.gpu` is missing, and 'no adapter' when no adapter is returned.
 */
export async function makeGPUConv({ powerPreference = 'high-performance' } = {}) {
  const gpu = navigator.gpu;
  if (!gpu) {
    throw new Error('no webgpu');
  }

  const adapter = await gpu.requestAdapter({ powerPreference });
  if (!adapter) {
    throw new Error('no adapter');
  }

  const limitNames = [
    'maxStorageBuffersPerShaderStage',
    'maxBufferSize',
    'maxStorageBufferBindingSize',
    'maxComputeInvocationsPerWorkgroup',
    'maxComputeWorkgroupSizeX',
    'maxComputeWorkgroupStorageSize',
    'maxBindGroups',
  ];
  const requiredLimits = {};
  limitNames.forEach((name) => {
    const value = adapter.limits[name];
    if (typeof value === 'number') {
      requiredLimits[name] = value;
    }
  });

  const dev = await adapter.requestDevice({ requiredLimits });
  return { adapter, dev, info: adapter.info || {} };
}
// Indices of the three relaxation phases. Each phase owns its own state, uniform and bind-group slots.
const PHASE_F = 0; // free phase: no target nudge on the output layer
const PHASE_P = 1; // positively nudged phase: output pulled toward the target with +beta
const PHASE_M = 2; // negatively nudged phase: output nudged with -beta
/**
 * WebGPU trainer for a network made of one conv layer and one dense readout, trained with
 * Equilibrium Propagation (EqProp).
 *
 * A training step runs three relaxation phases on the same minibatch: free (F), positively
 * nudged (P) and negatively nudged (M). Each phase has its own conv and output state buffers.
 * The step then computes a per-sample reward from the free-phase outputs, blends the P/M states
 * toward F ("adapt"), and derives reward-weighted gradients from the P/M states. All work for one
 * call is encoded into a single command buffer, and results come back as Float32Arrays through
 * MAP_READ staging buffers.
 *
 * Flat layouts, as indexed by the shaders:
 *   input [B, Cin, H, W]            conv state [B, Cout, Hout, Wout]   (convFlat = Cout·Hout·Wout)
 *   conv W [Cout, Cin, KH, KW]      conv b [Cout]
 *   dense W [O, convFlat]           dense b [O]                        output state [B, O]
 *
 * The readback buffers (e.g. rbUoF) are shared by the run* methods, so await each run* call
 * before starting the next one.
 */
export class GPUTrainerConvFull {
  /**
   * @param {object} args
   * @param {object} args.dev - GPUDevice that owns every buffer and pipeline created here.
   * @param {{Cin:number, Cout:number, H:number, W:number, KH:number, KW:number, stride:number, pad:number}} args.convCfg
   *   Conv layer geometry.
   * @param {number} args.denseSize - number of output units O.
   * @param {number} args.B - minibatch size. Every buffer is sized for exactly B samples.
   */
  constructor({ dev, convCfg, denseSize, B }) {
    this.dev = dev;
    this.cfg = convCfg;
    this.B = B;
    this.O = denseSize;
    // Standard convolution output extent: floor((in + 2·pad − kernel) / stride) + 1.
    this.Hout = Math.floor((convCfg.H + 2 * convCfg.pad - convCfg.KH) / convCfg.stride) + 1;
    this.Wout = Math.floor((convCfg.W + 2 * convCfg.pad - convCfg.KW) / convCfg.stride) + 1;
    this.convFlat = convCfg.Cout * this.Hout * this.Wout;
    this._build();
  }

  /** Create an f32 buffer for `n` elements. The buffer is never smaller than 4 bytes, so empty tensors still get one. */
  _F32(n, usage) {
    return this.dev.createBuffer({ size: Math.max(4, n * 4), usage });
  }

  /** Allocate every buffer and create all pipelines and bind groups. Called once from the constructor. */
  _build() {
    const dev = this.dev;
    const B = this.B;
    const c = this.cfg;
    const O = this.O;
    const RW = GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST;
    // Same flags as RW. The separate name only marks buffers that the shaders bind read-only.
    const R = GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC;
    const UNI = GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST;
    const RDS = GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ;
    const imgSize = c.Cin * c.H * c.W;
    const cw_n = c.Cout * c.Cin * c.KH * c.KW;

    // Inputs and targets, shared by all phases (filled by uploadInputs).
    this.bufXin = this._F32(B * imgSize, R);
    this.bufTgt = this._F32(B * O, R);

    // Parameters (filled by uploadWeights).
    this.bufWconv = this._F32(cw_n, R);
    this.bufBconv = this._F32(c.Cout, R);
    this.bufWdense = this._F32(O * this.convFlat, R);
    this.bufBdense = this._F32(O, R);

    // Per-phase relaxation states, indexed by PHASE_F / PHASE_P / PHASE_M.
    this.bufUconv = [];
    this.bufUout = [];
    for (let phase = 0; phase < 3; phase++) {
      this.bufUconv.push(this._F32(B * this.convFlat, RW));
      this.bufUout.push(this._F32(B * O, RW));
    }

    // Per-sample reward, plus placeholder buffers for bindings that an entry point does not use.
    // Each writable placeholder slot gets its own buffer so that no buffer is bound writable twice.
    this.bufR = this._F32(B, RW);
    this.bufDummyR = this._F32(4, R);
    this.bufDummyRW1 = this._F32(4, RW);
    this.bufDummyRW2 = this._F32(4, RW);
    this.bufDummyRW3 = this._F32(4, RW);

    // Gradient outputs.
    this.bufGWconv = this._F32(cw_n, RW);
    this.bufGBconv = this._F32(c.Cout, RW);
    this.bufGWdense = this._F32(O * this.convFlat, RW);
    this.bufGBdense = this._F32(O, RW);

    // MAP_READ staging buffers for host readback.
    this.rbUoF = dev.createBuffer({ size: B * O * 4, usage: RDS });
    this.rbUconvF = dev.createBuffer({ size: B * this.convFlat * 4, usage: RDS });
    this.rbGWconv = dev.createBuffer({ size: cw_n * 4, usage: RDS });
    this.rbGBconv = dev.createBuffer({ size: c.Cout * 4, usage: RDS });
    this.rbGWdense = dev.createBuffer({ size: O * this.convFlat * 4, usage: RDS });
    this.rbGBdense = dev.createBuffer({ size: O * 4, usage: RDS });

    // Uniform buffers, sized to the WGSL structs: CP = 80 B (conv), DP = 48 B (dense),
    // CGP = 64 B (conv grads), DGP = 32 B (dense grads), AP = 32 B (reward/adapt).
    // Relax and init passes get separate per-phase buffers because they carry different values.
    this.bufP_conv = [];
    this.bufP_dense = [];
    this.bufP_init_conv = [];
    this.bufP_init_dense = [];
    for (let phase = 0; phase < 3; phase++) {
      this.bufP_conv.push(dev.createBuffer({ size: 80, usage: UNI }));
      this.bufP_dense.push(dev.createBuffer({ size: 48, usage: UNI }));
      this.bufP_init_conv.push(dev.createBuffer({ size: 80, usage: UNI }));
      this.bufP_init_dense.push(dev.createBuffer({ size: 48, usage: UNI }));
    }
    this.bufP_grad_conv = dev.createBuffer({ size: 64, usage: UNI });
    this.bufP_grad_dense = dev.createBuffer({ size: 32, usage: UNI });
    this.bufP_rew = dev.createBuffer({ size: 32, usage: UNI });
    this.bufP_adapt_conv = dev.createBuffer({ size: 32, usage: UNI });
    this.bufP_adapt_out = dev.createBuffer({ size: 32, usage: UNI });

    // Layout-entry helpers for the compute stage: uniform, read-only storage and read-write storage.
    const uN = (binding) => ({ binding, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'uniform' } });
    const sR = (binding) => ({ binding, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'read-only-storage' } });
    const sRW = (binding) => ({ binding, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'storage' } });
    const pipeline = (layout, module, entryPoint) =>
      dev.createComputePipeline({ layout, compute: { module, entryPoint } });
    // Bind buffers[i] to binding i of `layout`.
    const bindAll = (layout, buffers) =>
      dev.createBindGroup({ layout, entries: buffers.map((buffer, binding) => ({ binding, resource: { buffer } })) });

    // ---- Conv relax + init ----
    const modConv = dev.createShaderModule({ code: WGSL_CONV_RELAX });
    this.bglConv = dev.createBindGroupLayout({ entries: [uN(0), sR(1), sR(2), sR(3), sR(4), sRW(5), sR(6)] });
    this.plConv = dev.createPipelineLayout({ bindGroupLayouts: [this.bglConv] });
    this.pipeConv = pipeline(this.plConv, modConv, 'conv_pass');
    this.pipeInitConv = pipeline(this.plConv, modConv, 'init_state');

    // ---- Dense relax + init ----
    const modDense = dev.createShaderModule({ code: WGSL_DENSE_OUT });
    this.bglDense = dev.createBindGroupLayout({ entries: [uN(0), sR(1), sR(2), sR(3), sR(4), sRW(5)] });
    this.plDense = dev.createPipelineLayout({ bindGroupLayouts: [this.bglDense] });
    this.pipeDense = pipeline(this.plDense, modDense, 'dense_pass');
    this.pipeInitDense = pipeline(this.plDense, modDense, 'init_state_out');

    // ---- Aux (reward + adapt) ----
    const modAux = dev.createShaderModule({ code: WGSL_AUX_CONV });
    this.bglAux = dev.createBindGroupLayout({ entries: [uN(0), sR(1), sR(2), sRW(3), sR(4), sRW(5), sRW(6)] });
    this.plAux = dev.createPipelineLayout({ bindGroupLayouts: [this.bglAux] });
    this.pipeReward = pipeline(this.plAux, modAux, 'compute_reward');
    this.pipeAdapt = pipeline(this.plAux, modAux, 'adapt_layer');

    // ---- Conv gradients ----
    const modGC = dev.createShaderModule({ code: WGSL_GRAD_CONV });
    this.bglGC = dev.createBindGroupLayout({ entries: [uN(0), sR(1), sR(2), sR(3), sR(4), sR(5), sRW(6), sRW(7)] });
    this.plGC = dev.createPipelineLayout({ bindGroupLayouts: [this.bglGC] });
    this.pipeGWconv = pipeline(this.plGC, modGC, 'grad_W_conv');
    this.pipeGBconv = pipeline(this.plGC, modGC, 'grad_B_conv');

    // ---- Dense gradients ----
    const modGD = dev.createShaderModule({ code: WGSL_GRAD_DENSE });
    this.bglGD = dev.createBindGroupLayout({ entries: [uN(0), sR(1), sR(2), sR(3), sR(4), sR(5), sRW(6), sRW(7)] });
    this.plGD = dev.createPipelineLayout({ bindGroupLayouts: [this.bglGD] });
    this.pipeGWdense = pipeline(this.plGD, modGD, 'grad_W_dense');
    this.pipeGBdense = pipeline(this.plGD, modGD, 'grad_B_dense');

    // ---- Bind groups ----
    // Conv: one relax and one init bind group per phase (same layout, different uniform buffer).
    this.bgConv = [];
    this.bgInitConv = [];
    for (let phase = 0; phase < 3; phase++) {
      const convRest = [this.bufXin, this.bufWconv, this.bufBconv, this.bufWdense, this.bufUconv[phase], this.bufUout[phase]];
      this.bgConv.push(bindAll(this.bglConv, [this.bufP_conv[phase], ...convRest]));
      this.bgInitConv.push(bindAll(this.bglConv, [this.bufP_init_conv[phase], ...convRest]));
    }
    // Dense: the layer input (binding 1) is the conv hidden state of the same phase.
    this.bgDense = [];
    this.bgInitDense = [];
    for (let phase = 0; phase < 3; phase++) {
      const denseRest = [this.bufUconv[phase], this.bufWdense, this.bufBdense, this.bufTgt, this.bufUout[phase]];
      this.bgDense.push(bindAll(this.bglDense, [this.bufP_dense[phase], ...denseRest]));
      this.bgInitDense.push(bindAll(this.bglDense, [this.bufP_init_dense[phase], ...denseRest]));
    }
    // Reward reads the free-phase outputs. Bindings 4-6 are unused placeholders.
    this.bgRew = bindAll(this.bglAux, [
      this.bufP_rew, this.bufUout[PHASE_F], this.bufTgt, this.bufR,
      this.bufDummyR, this.bufDummyRW1, this.bufDummyRW2,
    ]);
    // Adapt blends the P/M states toward F. Bindings 1-3 are unused placeholders.
    this.bgAdaptConv = bindAll(this.bglAux, [
      this.bufP_adapt_conv, this.bufDummyR, this.bufDummyR, this.bufDummyRW3,
      this.bufUconv[PHASE_F], this.bufUconv[PHASE_P], this.bufUconv[PHASE_M],
    ]);
    this.bgAdaptOut = bindAll(this.bglAux, [
      this.bufP_adapt_out, this.bufDummyR, this.bufDummyR, this.bufDummyRW3,
      this.bufUout[PHASE_F], this.bufUout[PHASE_P], this.bufUout[PHASE_M],
    ]);
    // Conv grads: the input image is the same in both nudged phases, so Xp and Xm are both bufXin.
    this.bgGC = bindAll(this.bglGC, [
      this.bufP_grad_conv, this.bufXin, this.bufXin,
      this.bufUconv[PHASE_P], this.bufUconv[PHASE_M], this.bufR, this.bufGWconv, this.bufGBconv,
    ]);
    // Dense grads: the pre-layer is the conv states and the post-layer is the output states of the nudged phases.
    this.bgGD = bindAll(this.bglGD, [
      this.bufP_grad_dense, this.bufUconv[PHASE_P], this.bufUconv[PHASE_M],
      this.bufUout[PHASE_P], this.bufUout[PHASE_M], this.bufR, this.bufGWdense, this.bufGBdense,
    ]);
  }

  /** Write the 80-byte CP uniform used by the conv relax/init passes. nxt_size is always this.O. */
  _writeConvParams(buf, fields) {
    const u32 = new Uint32Array(20);
    const f32 = new Float32Array(u32.buffer);
    u32[0] = this.B; u32[1] = this.cfg.Cin; u32[2] = this.cfg.Cout; u32[3] = this.cfg.H;
    u32[4] = this.cfg.W; u32[5] = this.Hout; u32[6] = this.Wout; u32[7] = this.cfg.KH;
    u32[8] = this.cfg.KW; u32[9] = this.cfg.stride; u32[10] = this.cfg.pad; u32[11] = 0;
    f32[12] = fields.dt || 0; f32[13] = fields.beta || 0; f32[14] = fields.gamma || 0; f32[15] = 0;
    u32[16] = fields.has_topdown || 0; u32[17] = this.O; u32[18] = fields.has_target || 0; u32[19] = 0;
    this.dev.queue.writeBuffer(buf, 0, u32.buffer);
  }

  /** Write the 48-byte DP uniform used by the dense relax/init passes (Ni = convFlat, No = O). */
  _writeDenseParams(buf, fields) {
    const u32 = new Uint32Array(12);
    const f32 = new Float32Array(u32.buffer);
    u32[0] = this.B; u32[1] = this.convFlat; u32[2] = this.O; u32[3] = 0;
    f32[4] = fields.dt || 0; f32[5] = fields.beta || 0; f32[6] = 0; f32[7] = 0;
    u32[8] = fields.has_target || 0; u32[9] = 0; u32[10] = 0; u32[11] = 0;
    this.dev.queue.writeBuffer(buf, 0, u32.buffer);
  }

  /** Write the 64-byte CGP uniform for the conv gradient kernels. */
  _writeGradConvParams(two_beta) {
    const u32 = new Uint32Array(16);
    const f32 = new Float32Array(u32.buffer);
    u32[0] = this.B; u32[1] = this.cfg.Cin; u32[2] = this.cfg.Cout; u32[3] = this.cfg.H;
    u32[4] = this.cfg.W; u32[5] = this.Hout; u32[6] = this.Wout; u32[7] = this.cfg.KH;
    u32[8] = this.cfg.KW; u32[9] = this.cfg.stride; u32[10] = this.cfg.pad; u32[11] = 0;
    f32[12] = two_beta; f32[13] = 0; f32[14] = 0; f32[15] = 0;
    this.dev.queue.writeBuffer(this.bufP_grad_conv, 0, u32.buffer);
  }

  /** Write the 32-byte DGP uniform for the dense gradient kernels. */
  _writeGradDenseParams(two_beta) {
    const u32 = new Uint32Array(8);
    const f32 = new Float32Array(u32.buffer);
    u32[0] = this.B; u32[1] = this.convFlat; u32[2] = this.O; u32[3] = 0;
    f32[4] = two_beta; f32[5] = 0; f32[6] = 0; f32[7] = 0;
    this.dev.queue.writeBuffer(this.bufP_grad_dense, 0, u32.buffer);
  }

  /** Write a 32-byte AP uniform (reward/adapt) into `buf`. */
  _writeAuxParams(buf, fields) {
    const u32 = new Uint32Array(8);
    const f32 = new Float32Array(u32.buffer);
    u32[0] = this.B; u32[1] = this.O; u32[2] = this.convFlat; u32[3] = 0;
    f32[4] = fields.c_adp || 0; f32[5] = fields.mode || 0; f32[6] = 0; f32[7] = 0;
    this.dev.queue.writeBuffer(buf, 0, u32.buffer);
  }

  /**
   * Upload conv and dense parameters. Each argument is a typed array, and only its own view
   * (byteOffset/byteLength) is copied, so subarray views are fine.
   */
  uploadWeights(Wconv, bconv, Wdense, bdense) {
    const q = this.dev.queue;
    q.writeBuffer(this.bufWconv, 0, Wconv.buffer, Wconv.byteOffset, Wconv.byteLength);
    q.writeBuffer(this.bufBconv, 0, bconv.buffer, bconv.byteOffset, bconv.byteLength);
    q.writeBuffer(this.bufWdense, 0, Wdense.buffer, Wdense.byteOffset, Wdense.byteLength);
    q.writeBuffer(this.bufBdense, 0, bdense.buffer, bdense.byteOffset, bdense.byteLength);
  }

  /** Upload one minibatch of inputs X ([B·Cin·H·W]) and targets T ([B·O]). Only each array's view is copied. */
  uploadInputs(X, T) {
    const q = this.dev.queue;
    q.writeBuffer(this.bufXin, 0, X.buffer, X.byteOffset, X.byteLength);
    q.writeBuffer(this.bufTgt, 0, T.buffer, T.byteOffset, T.byteLength);
  }

  /**
   * Write the relax and init uniforms for all three phases.
   * The conv passes never see beta (the conv shader does not read it) and always use top-down feedback
   * scaled by gamma. The dense passes use beta 0 (free), +beta (P) and −beta (M), with the target
   * term enabled only in the nudged phases. Init passes use no top-down term and no target.
   */
  _writeAllPhaseUniforms(dt, beta, gamma) {
    const denseBeta = [0, +beta, -beta];
    for (let phase = 0; phase < 3; phase++) {
      this._writeConvParams(this.bufP_conv[phase], { dt, beta: 0, gamma, has_topdown: 1, has_target: 0 });
      this._writeDenseParams(this.bufP_dense[phase], {
        dt,
        beta: denseBeta[phase],
        has_target: phase === PHASE_F ? 0 : 1,
      });
      this._writeConvParams(this.bufP_init_conv[phase], { dt, beta: 0, gamma: 0, has_topdown: 0, has_target: 0 });
      this._writeDenseParams(this.bufP_init_dense[phase], { dt, beta: 0, has_target: 0 });
    }
  }

  /** Encode one compute pass that contains a single dispatch. */
  _encodeDispatch(enc, pipeline, bindGroup, x, y = 1, z = 1) {
    const pass = enc.beginComputePass();
    pass.setPipeline(pipeline);
    pass.setBindGroup(0, bindGroup);
    pass.dispatchWorkgroups(x, y, z);
    pass.end();
  }

  /**
   * Encode a 1-D launch of `workgroups` 64-wide workgroups. Once the count exceeds 65535 per
   * dimension, it is folded into a 2-D grid. The kernels launched this way (init_state,
   * adapt_layer) rebuild the flat index as gid.y·(65535·64) + gid.x and bounds-check it, so
   * x must be capped at exactly 65535.
   */
  _encodeLinearDispatch(enc, pipeline, bindGroup, workgroups) {
    const MAX_WG_X = 65535;
    this._encodeDispatch(enc, pipeline, bindGroup, Math.min(workgroups, MAX_WG_X), Math.ceil(workgroups / MAX_WG_X));
  }

  /** Reset the conv and output states of every phase to 0.1. */
  _initAllPhases(enc) {
    const convWorkgroups = Math.ceil((this.B * this.convFlat) / 64);
    const outWorkgroups = Math.ceil((this.B * this.O) / 64);
    for (let phase = 0; phase < 3; phase++) {
      this._encodeLinearDispatch(enc, this.pipeInitConv, this.bgInitConv[phase], convWorkgroups);
      this._encodeDispatch(enc, this.pipeInitDense, this.bgInitDense[phase], outWorkgroups);
    }
  }

  /** Encode `iters` relaxation steps of one phase. Each step runs a conv pass followed by a dense pass. */
  _runPhaseRelax(enc, phase, iters) {
    const convGrid = [Math.ceil(this.Wout / 8), Math.ceil(this.Hout / 8), this.B * this.cfg.Cout];
    for (let t = 0; t < iters; t++) {
      this._encodeDispatch(enc, this.pipeConv, this.bgConv[phase], ...convGrid);
      this._encodeDispatch(enc, this.pipeDense, this.bgDense[phase], Math.ceil(this.O / 64), this.B);
    }
  }

  /** Encode the per-sample reward computation from the free-phase outputs into bufR. */
  _runReward(enc) {
    this._writeAuxParams(this.bufP_rew, { c_adp: 0, mode: 0 });
    this._encodeDispatch(enc, this.pipeReward, this.bgRew, Math.ceil(this.B / 64));
  }

  /** Encode `adpSteps` blends of the P/M states (conv and output) toward the F state with rate adpC. */
  _runAdapt(enc, adpC, adpSteps) {
    if (adpSteps <= 0) return;
    this._writeAuxParams(this.bufP_adapt_conv, { c_adp: adpC, mode: 0 });
    this._writeAuxParams(this.bufP_adapt_out, { c_adp: adpC, mode: 0 });
    const convWorkgroups = Math.ceil((this.B * this.convFlat) / 64);
    const outWorkgroups = Math.ceil((this.B * this.O) / 64);
    for (let step = 0; step < adpSteps; step++) {
      this._encodeLinearDispatch(enc, this.pipeAdapt, this.bgAdaptConv, convWorkgroups);
      this._encodeLinearDispatch(enc, this.pipeAdapt, this.bgAdaptOut, outWorkgroups);
    }
  }

  /**
   * Encode the conv and dense gradient kernels.
   * beta must be non-zero because the shaders divide by 2·beta.
   */
  _runGrad(enc, beta) {
    this._writeGradConvParams(2 * beta);
    this._writeGradDenseParams(2 * beta);

    let pass = enc.beginComputePass();
    pass.setPipeline(this.pipeGWconv);
    pass.setBindGroup(0, this.bgGC);
    pass.dispatchWorkgroups(Math.ceil(this.cfg.KW / 8), Math.ceil(this.cfg.KH / 8), this.cfg.Cout * this.cfg.Cin);
    pass.setPipeline(this.pipeGBconv);
    pass.setBindGroup(0, this.bgGC);
    pass.dispatchWorkgroups(Math.ceil(this.cfg.Cout / 64));
    pass.end();

    pass = enc.beginComputePass();
    pass.setPipeline(this.pipeGWdense);
    pass.setBindGroup(0, this.bgGD);
    pass.dispatchWorkgroups(Math.ceil(this.convFlat / 8), Math.ceil(this.O / 8));
    pass.setPipeline(this.pipeGBdense);
    pass.setBindGroup(0, this.bgGD);
    pass.dispatchWorkgroups(Math.ceil(this.O / 64));
    pass.end();
  }

  /** Map a MAP_READ buffer, copy its contents into a new Float32Array, and unmap the buffer. */
  async _mapRead(buf) {
    await buf.mapAsync(GPUMapMode.READ);
    const out = new Float32Array(buf.getMappedRange().slice(0));
    buf.unmap();
    return out;
  }

  /**
   * Run the free phase only and return the raw output states u (σ is not applied), length B·O.
   * All three phases are still initialized.
   */
  async runFreeAndReadOutputs(iters, dt, gamma = 0.6) {
    this._writeAllPhaseUniforms(dt, 0, gamma);
    const enc = this.dev.createCommandEncoder();
    this._initAllPhases(enc);
    this._runPhaseRelax(enc, PHASE_F, iters);
    enc.copyBufferToBuffer(this.bufUout[PHASE_F], 0, this.rbUoF, 0, this.B * this.O * 4);
    this.dev.queue.submit([enc.finish()]);
    return this._mapRead(this.rbUoF);
  }

  /**
   * Helper for greedy layer-wise EqProp. Runs the free phase and returns the conv hidden
   * activity σ(u) = 1 / (1 + exp(−4(u − 0.5))), the same firing-rate function the shaders use.
   * The result can be fed into a later conv stage as input.
   * @returns {Promise<Float32Array>} shape [B·Cout·Hout·Wout], values in (0, 1).
   */
  async runFreeAndReadConvHidden(iters, dt, gamma = 0.6) {
    this._writeAllPhaseUniforms(dt, 0, gamma);
    const enc = this.dev.createCommandEncoder();
    this._initAllPhases(enc);
    this._runPhaseRelax(enc, PHASE_F, iters);
    enc.copyBufferToBuffer(this.bufUconv[PHASE_F], 0, this.rbUconvF, 0, this.B * this.convFlat * 4);
    this.dev.queue.submit([enc.finish()]);
    const raw = await this._mapRead(this.rbUconvF);
    return raw.map((u) => 1 / (1 + Math.exp(-4 * (u - 0.5))));
  }

  /**
   * Run one full EqProp step on the uploaded minibatch and read back the gradients.
   *
   * Order of work:
   *  1. Initialize all phases to 0.1.
   *  2. Relax F for itF steps, then P and M for itN steps each. Each nudged phase starts from its
   *     own initialized state, not from the F state.
   *  3. Compute the reward from the F outputs.
   *  4. Run adpSteps adapt blends of P/M toward F.
   *  5. Compute the reward-weighted gradients (ρ⁺ρ⁺ − ρ⁻ρ⁻)/(2·beta). They are not negated;
   *     the caller decides how to apply them.
   *
   * @returns {Promise<{gWconv: Float32Array, gBconv: Float32Array, gWdense: Float32Array,
   *   gBdense: Float32Array, uoF: Float32Array}>} `uoF` holds the raw free-phase output states.
   */
  async runOnePassGetGradients({ itF = 8, itN = 5, dt = 0.7, beta = 0.5, gamma = 0.6, adpC = 0.15, adpSteps = 3 } = {}) {
    this._writeAllPhaseUniforms(dt, beta, gamma);
    const enc = this.dev.createCommandEncoder();
    this._initAllPhases(enc);
    this._runPhaseRelax(enc, PHASE_F, itF);
    this._runPhaseRelax(enc, PHASE_P, itN);
    this._runPhaseRelax(enc, PHASE_M, itN);
    this._runReward(enc);
    this._runAdapt(enc, adpC, adpSteps);
    this._runGrad(enc, beta);

    const convWeightCount = this.cfg.Cout * this.cfg.Cin * this.cfg.KH * this.cfg.KW;
    enc.copyBufferToBuffer(this.bufGWconv, 0, this.rbGWconv, 0, convWeightCount * 4);
    enc.copyBufferToBuffer(this.bufGBconv, 0, this.rbGBconv, 0, this.cfg.Cout * 4);
    enc.copyBufferToBuffer(this.bufGWdense, 0, this.rbGWdense, 0, this.O * this.convFlat * 4);
    enc.copyBufferToBuffer(this.bufGBdense, 0, this.rbGBdense, 0, this.O * 4);
    enc.copyBufferToBuffer(this.bufUout[PHASE_F], 0, this.rbUoF, 0, this.B * this.O * 4);
    this.dev.queue.submit([enc.finish()]);

    const [gWconv, gBconv, gWdense, gBdense, uoF] = await Promise.all([
      this._mapRead(this.rbGWconv),
      this._mapRead(this.rbGBconv),
      this._mapRead(this.rbGWdense),
      this._mapRead(this.rbGBdense),
      this._mapRead(this.rbUoF),
    ]);
    return { gWconv, gBconv, gWdense, gBdense, uoF };
  }

  /**
   * Release every GPU buffer this trainer created, including the rbUconvF readback buffer.
   * Pipelines, layouts and bind groups are not explicitly released here. This is best-effort:
   * an error from one destroy() call is ignored so the remaining buffers are still released.
   * Do not use the trainer afterwards.
   */
  destroy() {
    const bufs = [
      this.bufXin, this.bufTgt, this.bufWconv, this.bufBconv, this.bufWdense, this.bufBdense,
      this.bufR, this.bufDummyR, this.bufDummyRW1, this.bufDummyRW2, this.bufDummyRW3,
      this.bufGWconv, this.bufGBconv, this.bufGWdense, this.bufGBdense,
      this.rbUoF, this.rbUconvF, this.rbGWconv, this.rbGBconv, this.rbGWdense, this.rbGBdense,
      this.bufP_grad_conv, this.bufP_grad_dense, this.bufP_rew, this.bufP_adapt_conv, this.bufP_adapt_out,
    ];
    const perPhase = [
      this.bufUconv, this.bufUout, this.bufP_conv, this.bufP_dense, this.bufP_init_conv, this.bufP_init_dense,
    ];
    for (const arr of perPhase) bufs.push(...arr);
    for (const buf of bufs) {
      if (buf && buf.destroy) {
        try {
          buf.destroy();
        } catch (_err) {
          // Deliberately ignored: keep releasing the remaining buffers.
        }
      }
    }
  }
}