Spaces:
Running
Running
| // Full convolutional EqProp trainer (WebGPU). | |
| // Architecture: one conv layer (weight-shared kernel) + one dense readout to O classes. | |
| // All EqProp, no backprop. Validated vs ConvEqPropNet CPU reference. | |
| import { orth as orthCPU } from './eqprop_lib.js'; | |
// Conv-layer relaxation shader.
//   conv_pass : one explicit-Euler step of the conv hidden state for every (b, k, yo, xo), with
//               gid.z = b * Cout + k:
//                 c  = Bs[k] + sum_{kin,dy,dx} Wt[k,kin,dy,dx] * rho(Xin[b,kin,iy,ix])   (out-of-range
//                      input pixels are skipped)
//                 c += gamma * sum_n Wnxt[n, k,yo,xo] * rho(Uo[b,n])                   (if has_topdown)
//                 Uh += dt * (-Uh + sg(clamp(c, -2.5, 3.5)))
//   init_state: sets every conv state to 0.1; the flat index is gid.y * (65535 * 64) + gid.x so the
//               kernel can be dispatched as a 2-D grid when it needs more than 65535 workgroups.
// The CP fields `beta` and `has_target` are not read by either entry point (the layout is kept so
// the host-side uniform writer stays 80 bytes).
const WGSL_CONV_RELAX = `
struct CP {
  B: u32, Cin: u32, Cout: u32, H: u32,
  W: u32, Hout: u32, Wout: u32, KH: u32,
  KW: u32, stride: u32, pad: u32, _p0: u32,
  dt: f32, beta: f32, gamma: f32, _p1: f32,
  has_topdown: u32, nxt_size: u32, has_target: u32, _p2: u32,
};
@group(0) @binding(0) var<uniform> p : CP;
@group(0) @binding(1) var<storage, read> Xin : array<f32>;
@group(0) @binding(2) var<storage, read> Wt : array<f32>;
@group(0) @binding(3) var<storage, read> Bs : array<f32>;
@group(0) @binding(4) var<storage, read> Wnxt: array<f32>;
@group(0) @binding(5) var<storage, read_write> Uh : array<f32>;
@group(0) @binding(6) var<storage, read> Uo : array<f32>;
fn sg(u: f32) -> f32 { return 1.0 / (1.0 + exp(-4.0 * (u - 0.5))); }
fn rho(u: f32) -> f32 { return sg(u); }
@compute @workgroup_size(8, 8, 1) fn conv_pass(@builtin(global_invocation_id) gid: vec3<u32>) {
  let xo = gid.x; let yo = gid.y; let bk = gid.z;
  if (xo >= p.Wout || yo >= p.Hout) { return; }
  let b = bk / p.Cout; let k = bk % p.Cout;
  if (b >= p.B) { return; }
  let img_size = p.Cin * p.H * p.W;
  let map_size = p.Cout * p.Hout * p.Wout;
  // Offset of (k, yo, xo) inside one sample's conv map.
  let this_flat = k * p.Hout * p.Wout + yo * p.Wout + xo;
  var c : f32 = Bs[k];
  for (var kin: u32 = 0u; kin < p.Cin; kin = kin + 1u) {
    for (var dy: u32 = 0u; dy < p.KH; dy = dy + 1u) {
      let iy_s = i32(yo * p.stride + dy) - i32(p.pad);
      if (iy_s < 0 || iy_s >= i32(p.H)) { continue; }
      let iy = u32(iy_s);
      for (var dx: u32 = 0u; dx < p.KW; dx = dx + 1u) {
        let ix_s = i32(xo * p.stride + dx) - i32(p.pad);
        if (ix_s < 0 || ix_s >= i32(p.W)) { continue; }
        let ix = u32(ix_s);
        let xidx = b * img_size + kin * p.H * p.W + iy * p.W + ix;
        let widx = ((k * p.Cin + kin) * p.KH + dy) * p.KW + dx;
        c = c + Wt[widx] * rho(Xin[xidx]);
      }
    }
  }
  if (p.has_topdown != 0u) {
    var td : f32 = 0.0;
    for (var n: u32 = 0u; n < p.nxt_size; n = n + 1u) {
      td = td + Wnxt[n * map_size + this_flat] * rho(Uo[b * p.nxt_size + n]);
    }
    c = c + p.gamma * td;
  }
  // STABILIZER #4: clip pre-activation drive to a bounded range BEFORE sigma.
  c = clamp(c, -2.5, 3.5);
  // Cross-channel local response normalization is intentionally not applied (it degraded
  // single-conv accuracy). Re-enabling it inside this pass would read neighbouring channels of Uh
  // that other invocations of the same dispatch are writing without synchronization, so the values
  // read could be from either iteration; it needs a separate pass or a double-buffered state.
  let idx = b * map_size + this_flat;
  let u_old = Uh[idx];
  let drive = -u_old + sg(c);
  Uh[idx] = u_old + p.dt * drive;
}
@compute @workgroup_size(64) fn init_state(@builtin(global_invocation_id) gid: vec3<u32>) {
  let stride = 65535u * 64u;
  let g = gid.y * stride + gid.x;
  let n = p.B * p.Cout * p.Hout * p.Wout;
  if (g < n) { Uh[g] = 0.1; }
}
`;
// Dense readout relaxation shader.
//   dense_pass    : one explicit-Euler step of the output state for each (b = gid.y, i = gid.x):
//                     c = clamp(Bs[i] + sum_j Wt[i*Ni + j] * rho(Xin[b*Ni + j]), -2.5, 3.5)
//                     Uo[b,i] += dt * (-Uo[b,i] + sg(c) + beta * (Tgt[b,i] - Uo[b,i]))
//                   The beta (nudge) term is only added when has_target != 0 and beta != 0.
//   init_state_out: sets every output state to 0.1; it indexes with gid.x only (1-D dispatch).
const WGSL_DENSE_OUT = `
struct DP {
B: u32, Ni: u32, No: u32, _p0: u32,
dt: f32, beta: f32, _p1: f32, _p2: f32,
has_target: u32, _p3: u32, _p4: u32, _p5: u32,
};
@group(0) @binding(0) var<uniform> p : DP;
@group(0) @binding(1) var<storage, read> Xin : array<f32>; // conv hidden, [B*Ni]
@group(0) @binding(2) var<storage, read> Wt : array<f32>; // [No*Ni]
@group(0) @binding(3) var<storage, read> Bs : array<f32>; // [No]
@group(0) @binding(4) var<storage, read> Tgt : array<f32>; // [B*No]
@group(0) @binding(5) var<storage, read_write> Uo : array<f32>; // [B*No]
fn sg(u: f32) -> f32 { return 1.0 / (1.0 + exp(-4.0 * (u - 0.5))); }
fn rho(u: f32) -> f32 { return sg(u); }
@compute @workgroup_size(64, 1) fn dense_pass(@builtin(global_invocation_id) gid: vec3<u32>) {
let b = gid.y; let i = gid.x;
if (b >= p.B || i >= p.No) { return; }
var c : f32 = Bs[i];
for (var j: u32 = 0u; j < p.Ni; j = j + 1u) {
c = c + Wt[i * p.Ni + j] * rho(Xin[b * p.Ni + j]);
}
// STABILIZER #4 (dense output): clamp pre-activation drive too — same reasoning.
c = clamp(c, -2.5, 3.5);
let idx = b * p.No + i;
let u_old = Uo[idx];
var drive : f32 = -u_old + sg(c);
if (p.has_target != 0u && p.beta != 0.0) {
drive = drive + p.beta * (Tgt[idx] - u_old);
}
Uo[idx] = u_old + p.dt * drive;
}
@compute @workgroup_size(64) fn init_state_out(@builtin(global_invocation_id) gid: vec3<u32>) {
let g = gid.x; let n = p.B * p.No;
if (g < n) { Uo[g] = 0.1; }
}
`;
// Auxiliary shader: per-sample reward and nudged-state adaptation.
//   compute_reward: for each sample b (gid.x),
//                     loss = sum_i (sg(UoF[b,i]) - Tgt[b,i])^2
//                     R[b] = 0.1 + 0.9 * min(loss / 0.4, 1)      -> R is in [0.1, 1.0]
//                   so samples with a larger free-phase error get a larger weight.
//   adapt_layer   : for every element g < arrayLength(Uf), blends both nudged states toward the
//                   free state: Up = (1 - c_adp) * Up + c_adp * Uf, and likewise for Um.
//                   The flat index is gid.y * (65535 * 64) + gid.x (2-D dispatch for large states).
//                   NOTE(review): only Uf's length is checked; Up and Um are assumed to be at least
//                   as long (the trainer in this file allocates all three with the same size).
//   The AP fields `convFlat` and `mode` are not read by either entry point.
const WGSL_AUX_CONV = `
struct AP {
B: u32, O: u32, convFlat: u32, _p0: u32,
c_adp: f32, mode: f32, _p1: f32, _p2: f32,
};
@group(0) @binding(0) var<uniform> p : AP;
@group(0) @binding(1) var<storage, read> UoF : array<f32>;
@group(0) @binding(2) var<storage, read> Tgt : array<f32>;
@group(0) @binding(3) var<storage, read_write> R : array<f32>;
@group(0) @binding(4) var<storage, read> Uf : array<f32>;
@group(0) @binding(5) var<storage, read_write> Up : array<f32>;
@group(0) @binding(6) var<storage, read_write> Um : array<f32>;
fn sg(u: f32) -> f32 { return 1.0 / (1.0 + exp(-4.0 * (u - 0.5))); }
fn rho_out(u: f32) -> f32 { return sg(u); }
@compute @workgroup_size(64) fn compute_reward(@builtin(global_invocation_id) gid: vec3<u32>) {
let b = gid.x;
if (b >= p.B) { return; }
var loss : f32 = 0.0;
let off = b * p.O;
for (var i: u32 = 0u; i < p.O; i = i + 1u) {
let d = rho_out(UoF[off + i]) - Tgt[off + i];
loss = loss + d * d;
}
var r : f32 = loss / 0.4;
if (r > 1.0) { r = 1.0; }
R[b] = 0.1 + 0.9 * r;
}
@compute @workgroup_size(64) fn adapt_layer(@builtin(global_invocation_id) gid: vec3<u32>) {
let stride = 65535u * 64u;
let g = gid.y * stride + gid.x;
if (g >= arrayLength(&Uf)) { return; }
let f = Uf[g];
Up[g] = (1.0 - p.c_adp) * Up[g] + p.c_adp * f;
Um[g] = (1.0 - p.c_adp) * Um[g] + p.c_adp * f;
}
`;
// Conv-layer EqProp gradient shader (symmetric +beta / -beta nudging, reward-weighted).
//   grad_W_conv: one invocation per kernel tap (dx = gid.x, dy = gid.y) and channel pair
//                gid.z = kout * Cin + kin:
//                  gW[kout,kin,dy,dx] = sum_{b,yo,xo} R[b] * (rho(Up)*rho(Xp) - rho(Um)*rho(Xm)) / two_beta
//                summed over output positions whose tap lands on an in-range input pixel
//                (same index arithmetic as conv_pass in the relaxation shader).
//   grad_B_conv: one invocation per output channel:
//                  gB[kout] = sum_{b,yo,xo} R[b] * (rho(Up) - rho(Um)) / two_beta
//   Both kernels overwrite their outputs rather than accumulating into them.
const WGSL_GRAD_CONV = `
struct CGP {
B: u32, Cin: u32, Cout: u32, H: u32,
W: u32, Hout: u32, Wout: u32, KH: u32,
KW: u32, stride: u32, pad: u32, _p0: u32,
two_beta: f32, _p1: f32, _p2: f32, _p3: f32,
};
@group(0) @binding(0) var<uniform> p : CGP;
@group(0) @binding(1) var<storage, read> Xp : array<f32>;
@group(0) @binding(2) var<storage, read> Xm : array<f32>;
@group(0) @binding(3) var<storage, read> Up : array<f32>;
@group(0) @binding(4) var<storage, read> Um : array<f32>;
@group(0) @binding(5) var<storage, read> R : array<f32>;
@group(0) @binding(6) var<storage, read_write> gW : array<f32>;
@group(0) @binding(7) var<storage, read_write> gB : array<f32>;
fn sg(u: f32) -> f32 { return 1.0 / (1.0 + exp(-4.0 * (u - 0.5))); }
fn rho(u: f32) -> f32 { return sg(u); }
@compute @workgroup_size(8, 8, 1) fn grad_W_conv(@builtin(global_invocation_id) gid: vec3<u32>) {
let dx = gid.x; let dy = gid.y; let kk = gid.z;
if (dx >= p.KW || dy >= p.KH) { return; }
let kout = kk / p.Cin; let kin = kk % p.Cin;
if (kout >= p.Cout) { return; }
let img_size = p.Cin * p.H * p.W;
let map_size = p.Cout * p.Hout * p.Wout;
var acc : f32 = 0.0;
for (var b: u32 = 0u; b < p.B; b = b + 1u) {
let rb = R[b];
for (var yo: u32 = 0u; yo < p.Hout; yo = yo + 1u) {
let iy_s = i32(yo * p.stride + dy) - i32(p.pad);
if (iy_s < 0 || iy_s >= i32(p.H)) { continue; }
let iy = u32(iy_s);
for (var xo: u32 = 0u; xo < p.Wout; xo = xo + 1u) {
let ix_s = i32(xo * p.stride + dx) - i32(p.pad);
if (ix_s < 0 || ix_s >= i32(p.W)) { continue; }
let ix = u32(ix_s);
let u_flat = b * map_size + kout * p.Hout * p.Wout + yo * p.Wout + xo;
let x_flat = b * img_size + kin * p.H * p.W + iy * p.W + ix;
acc = acc + rb * (rho(Up[u_flat]) * rho(Xp[x_flat]) - rho(Um[u_flat]) * rho(Xm[x_flat]));
}
}
}
let widx = ((kout * p.Cin + kin) * p.KH + dy) * p.KW + dx;
gW[widx] = acc / p.two_beta;
}
@compute @workgroup_size(64) fn grad_B_conv(@builtin(global_invocation_id) gid: vec3<u32>) {
let kout = gid.x;
if (kout >= p.Cout) { return; }
let map_size = p.Cout * p.Hout * p.Wout;
var acc : f32 = 0.0;
for (var b: u32 = 0u; b < p.B; b = b + 1u) {
let rb = R[b];
for (var yo: u32 = 0u; yo < p.Hout; yo = yo + 1u) {
for (var xo: u32 = 0u; xo < p.Wout; xo = xo + 1u) {
let u_flat = b * map_size + kout * p.Hout * p.Wout + yo * p.Wout + xo;
acc = acc + rb * (rho(Up[u_flat]) - rho(Um[u_flat]));
}
}
}
gB[kout] = acc / p.two_beta;
}
`;
// Dense-readout EqProp gradient shader (symmetric +beta / -beta nudging, reward-weighted).
//   grad_W_dense: input unit j = gid.x, output unit i = gid.y:
//                   gW[i*Ni + j] = sum_b R[b] * (rho(Up[b,i])*rho(Xp[b,j]) - rho(Um[b,i])*rho(Xm[b,j])) / two_beta
//   grad_B_dense: gB[i] = sum_b R[b] * (rho(Up[b,i]) - rho(Um[b,i])) / two_beta
//   Outputs are overwritten; gW uses the same [No*Ni] row-major layout as Wt in the dense
//   relaxation shader.
const WGSL_GRAD_DENSE = `
struct DGP {
B: u32, Ni: u32, No: u32, _p0: u32,
two_beta: f32, _p1: f32, _p2: f32, _p3: f32,
};
@group(0) @binding(0) var<uniform> p : DGP;
@group(0) @binding(1) var<storage, read> Xp : array<f32>; // pre-layer plus phase [B*Ni]
@group(0) @binding(2) var<storage, read> Xm : array<f32>;
@group(0) @binding(3) var<storage, read> Up : array<f32>; // post-layer plus phase [B*No]
@group(0) @binding(4) var<storage, read> Um : array<f32>;
@group(0) @binding(5) var<storage, read> R : array<f32>;
@group(0) @binding(6) var<storage, read_write> gW : array<f32>;
@group(0) @binding(7) var<storage, read_write> gB : array<f32>;
fn sg(u: f32) -> f32 { return 1.0 / (1.0 + exp(-4.0 * (u - 0.5))); }
fn rho(u: f32) -> f32 { return sg(u); }
@compute @workgroup_size(8, 8) fn grad_W_dense(@builtin(global_invocation_id) gid: vec3<u32>) {
let i = gid.y; let j = gid.x;
if (i >= p.No || j >= p.Ni) { return; }
var acc : f32 = 0.0;
for (var b: u32 = 0u; b < p.B; b = b + 1u) {
let rb = R[b];
let ip = rho(Up[b * p.No + i]); let im = rho(Um[b * p.No + i]);
let jp = rho(Xp[b * p.Ni + j]); let jm = rho(Xm[b * p.Ni + j]);
acc = acc + rb * (ip * jp - im * jm);
}
gW[i * p.Ni + j] = acc / p.two_beta;
}
@compute @workgroup_size(64) fn grad_B_dense(@builtin(global_invocation_id) gid: vec3<u32>) {
let i = gid.x;
if (i >= p.No) { return; }
var acc : f32 = 0.0;
for (var b: u32 = 0u; b < p.B; b = b + 1u) {
let rb = R[b];
acc = acc + rb * (rho(Up[b * p.No + i]) - rho(Um[b * p.No + i]));
}
gB[i] = acc / p.two_beta;
}
`;
/**
 * Request a WebGPU adapter and a device whose storage/buffer/workgroup limits are raised to the
 * adapter's maxima (only limits the adapter reports as numbers are requested).
 *
 * @param {{powerPreference?: string}} [opts]  Forwarded to `navigator.gpu.requestAdapter`.
 * @returns {Promise<{adapter: object, dev: object, info: object}>}  `info` falls back to `{}`
 *   when the adapter exposes no `info`.
 * @throws {Error} 'no webgpu' when WebGPU is unavailable (including runtimes that have no
 *   `navigator` global at all); 'no adapter' when no adapter is returned.
 */
export async function makeGPUConv({powerPreference = 'high-performance'} = {}) {
  // `typeof` guard: reading an undeclared `navigator` would throw a ReferenceError instead of
  // the intended, recognisable 'no webgpu' error.
  if (typeof navigator === 'undefined' || !navigator.gpu) throw new Error('no webgpu');
  const adapter = await navigator.gpu.requestAdapter({powerPreference});
  if (!adapter) throw new Error('no adapter');
  const limitNames = [
    'maxStorageBuffersPerShaderStage', 'maxBufferSize', 'maxStorageBufferBindingSize',
    'maxComputeInvocationsPerWorkgroup', 'maxComputeWorkgroupSizeX',
    'maxComputeWorkgroupStorageSize', 'maxBindGroups',
  ];
  const want = {};
  for (const name of limitNames) {
    const v = adapter.limits[name];
    if (typeof v === 'number') want[name] = v;
  }
  const dev = await adapter.requestDevice({requiredLimits: want});
  return {adapter, dev, info: adapter.info || {}};
}
// Phase indices shared by every per-phase array (states, uniforms, bind groups).
const PHASE_F = 0; // free phase: no target nudge
const PHASE_P = 1; // nudged phase with +beta toward the target
const PHASE_M = 2; // nudged phase with -beta
/**
 * Full convolutional EqProp trainer on WebGPU.
 *
 * Network: one conv layer (weight-shared kernel described by `convCfg`) whose hidden state feeds a
 * dense readout of `denseSize` units. Three relaxations each own a copy of the states:
 * free (beta = 0), nudged with +beta and nudged with -beta toward the targets.
 * `runOnePassGetGradients` returns the resulting gradients; this class never applies them itself --
 * callers push new parameters with `uploadWeights`.
 *
 * Every GPU buffer is created through `_createBuffer`, so `destroy()` releases all of them.
 */
export class GPUTrainerConvFull {
  /**
   * @param {object} opts
   * @param {object} opts.dev        GPU device that owns every buffer and pipeline.
   * @param {object} opts.convCfg    Conv geometry: {Cin, Cout, H, W, KH, KW, stride, pad}.
   * @param {number} opts.denseSize  Number of readout units (O).
   * @param {number} opts.B          Batch size; every buffer and uniform is sized for it.
   */
  constructor({dev, convCfg, denseSize, B}) {
    this.dev = dev;
    this.cfg = convCfg;
    this.B = B;
    this.O = denseSize;
    // Output spatial size of a strided convolution with `pad` border (out-of-range taps are skipped).
    this.Hout = Math.floor((convCfg.H + 2 * convCfg.pad - convCfg.KH) / convCfg.stride) + 1;
    this.Wout = Math.floor((convCfg.W + 2 * convCfg.pad - convCfg.KW) / convCfg.stride) + 1;
    this.convFlat = convCfg.Cout * this.Hout * this.Wout;
    this._ownedBuffers = [];
    this._build();
  }

  /** Create a GPU buffer and record it so that destroy() releases it. */
  _createBuffer(desc) {
    const buf = this.dev.createBuffer(desc);
    if (!this._ownedBuffers) this._ownedBuffers = [];
    this._ownedBuffers.push(buf);
    return buf;
  }

  /** Buffer for n f32 values (never smaller than 4 bytes). */
  _F32(n, usage) { return this._createBuffer({size: Math.max(4, n * 4), usage}); }

  /** Allocate all buffers, compile the five shader modules and build every bind group. */
  _build() {
    const dev = this.dev, B = this.B, c = this.cfg, O = this.O, flat = this.convFlat;
    // Every storage buffer can be bound, uploaded to and copied out of.
    const STORAGE = GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST;
    const UNIFORM = GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST;
    const READBACK = GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ;
    const imgSize = c.Cin * c.H * c.W;
    const cwN = c.Cout * c.Cin * c.KH * c.KW;
    const makeReadback = (bytes) => this._createBuffer({size: bytes, usage: READBACK});
    const makeUniform = (bytes) => this._createBuffer({size: bytes, usage: UNIFORM});

    // Inputs and targets, shared by all phases.
    this.bufXin = this._F32(B * imgSize, STORAGE);
    this.bufTgt = this._F32(B * O, STORAGE);
    // Parameters.
    this.bufWconv = this._F32(cwN, STORAGE);
    this.bufBconv = this._F32(c.Cout, STORAGE);
    this.bufWdense = this._F32(O * flat, STORAGE);
    this.bufBdense = this._F32(O, STORAGE);
    // Per-phase neuron states, indexed by PHASE_F / PHASE_P / PHASE_M.
    this.bufUconv = [];
    this.bufUout = [];
    for (let phase = 0; phase < 3; phase++) {
      this.bufUconv.push(this._F32(B * flat, STORAGE));
      this.bufUout.push(this._F32(B * O, STORAGE));
    }
    // Per-sample reward plus placeholders for unused aux bindings. Each read_write slot gets its
    // own dummy so no writable binding aliases another binding of the same bind group.
    this.bufR = this._F32(B, STORAGE);
    this.bufDummyR = this._F32(4, STORAGE);
    this.bufDummyRW1 = this._F32(4, STORAGE);
    this.bufDummyRW2 = this._F32(4, STORAGE);
    this.bufDummyRW3 = this._F32(4, STORAGE);
    // Gradients.
    this.bufGWconv = this._F32(cwN, STORAGE);
    this.bufGBconv = this._F32(c.Cout, STORAGE);
    this.bufGWdense = this._F32(O * flat, STORAGE);
    this.bufGBdense = this._F32(O, STORAGE);
    // Mappable readback copies.
    this.rbUoF = makeReadback(B * O * 4);
    this.rbUconvF = makeReadback(B * flat * 4);
    this.rbGWconv = makeReadback(cwN * 4);
    this.rbGBconv = makeReadback(c.Cout * 4);
    this.rbGWdense = makeReadback(O * flat * 4);
    this.rbGBdense = makeReadback(O * 4);
    // Uniforms sized to the WGSL structs: CP = 80 B, DP = 48 B, CGP = 64 B, DGP = 32 B, AP = 32 B.
    this.bufP_conv = [];
    this.bufP_dense = [];
    this.bufP_init_conv = [];
    this.bufP_init_dense = [];
    for (let phase = 0; phase < 3; phase++) {
      this.bufP_conv.push(makeUniform(80));
      this.bufP_dense.push(makeUniform(48));
      this.bufP_init_conv.push(makeUniform(80));
      this.bufP_init_dense.push(makeUniform(48));
    }
    this.bufP_grad_conv = makeUniform(64);
    this.bufP_grad_dense = makeUniform(32);
    this.bufP_rew = makeUniform(32);
    this.bufP_adapt_conv = makeUniform(32);
    this.bufP_adapt_out = makeUniform(32);

    // Layout whose binding i has buffer type kinds[i].
    const U = 'uniform', RO = 'read-only-storage', RW = 'storage';
    const makeBGL = (...kinds) => dev.createBindGroupLayout({
      entries: kinds.map((type, binding) => ({binding, visibility: GPUShaderStage.COMPUTE, buffer: {type}})),
    });
    const makePipe = (pl, module, entryPoint) => dev.createComputePipeline({layout: pl, compute: {module, entryPoint}});

    // Conv relaxation + init.
    const modConv = dev.createShaderModule({code: WGSL_CONV_RELAX});
    this.bglConv = makeBGL(U, RO, RO, RO, RO, RW, RO);
    this.plConv = dev.createPipelineLayout({bindGroupLayouts: [this.bglConv]});
    this.pipeConv = makePipe(this.plConv, modConv, 'conv_pass');
    this.pipeInitConv = makePipe(this.plConv, modConv, 'init_state');
    // Dense relaxation + init.
    const modDense = dev.createShaderModule({code: WGSL_DENSE_OUT});
    this.bglDense = makeBGL(U, RO, RO, RO, RO, RW);
    this.plDense = dev.createPipelineLayout({bindGroupLayouts: [this.bglDense]});
    this.pipeDense = makePipe(this.plDense, modDense, 'dense_pass');
    this.pipeInitDense = makePipe(this.plDense, modDense, 'init_state_out');
    // Reward + adaptation.
    const modAux = dev.createShaderModule({code: WGSL_AUX_CONV});
    this.bglAux = makeBGL(U, RO, RO, RW, RO, RW, RW);
    this.plAux = dev.createPipelineLayout({bindGroupLayouts: [this.bglAux]});
    this.pipeReward = makePipe(this.plAux, modAux, 'compute_reward');
    this.pipeAdapt = makePipe(this.plAux, modAux, 'adapt_layer');
    // Conv gradients.
    const modGC = dev.createShaderModule({code: WGSL_GRAD_CONV});
    this.bglGC = makeBGL(U, RO, RO, RO, RO, RO, RW, RW);
    this.plGC = dev.createPipelineLayout({bindGroupLayouts: [this.bglGC]});
    this.pipeGWconv = makePipe(this.plGC, modGC, 'grad_W_conv');
    this.pipeGBconv = makePipe(this.plGC, modGC, 'grad_B_conv');
    // Dense gradients.
    const modGD = dev.createShaderModule({code: WGSL_GRAD_DENSE});
    this.bglGD = makeBGL(U, RO, RO, RO, RO, RO, RW, RW);
    this.plGD = dev.createPipelineLayout({bindGroupLayouts: [this.bglGD]});
    this.pipeGWdense = makePipe(this.plGD, modGD, 'grad_W_dense');
    this.pipeGBdense = makePipe(this.plGD, modGD, 'grad_B_dense');

    // Bind group whose binding i is buffers[i] (every group here binds 0..n-1 in order).
    const makeGroup = (bgl, ...buffers) => dev.createBindGroup({
      layout: bgl,
      entries: buffers.map((buffer, binding) => ({binding, resource: {buffer}})),
    });
    // Conv per phase: relax and init share the layout and differ only in the uniform buffer.
    this.bgConv = [];
    this.bgInitConv = [];
    for (let phase = 0; phase < 3; phase++) {
      const uc = this.bufUconv[phase], uo = this.bufUout[phase];
      this.bgConv.push(makeGroup(this.bglConv, this.bufP_conv[phase], this.bufXin, this.bufWconv,
        this.bufBconv, this.bufWdense, uc, uo));
      this.bgInitConv.push(makeGroup(this.bglConv, this.bufP_init_conv[phase], this.bufXin, this.bufWconv,
        this.bufBconv, this.bufWdense, uc, uo));
    }
    // Dense per phase: its input is that phase's conv hidden state.
    this.bgDense = [];
    this.bgInitDense = [];
    for (let phase = 0; phase < 3; phase++) {
      const uc = this.bufUconv[phase], uo = this.bufUout[phase];
      this.bgDense.push(makeGroup(this.bglDense, this.bufP_dense[phase], uc, this.bufWdense,
        this.bufBdense, this.bufTgt, uo));
      this.bgInitDense.push(makeGroup(this.bglDense, this.bufP_init_dense[phase], uc, this.bufWdense,
        this.bufBdense, this.bufTgt, uo));
    }
    // Reward reads the free-phase outputs; adaptation blends nudged states toward free ones.
    this.bgRew = makeGroup(this.bglAux, this.bufP_rew, this.bufUout[PHASE_F], this.bufTgt, this.bufR,
      this.bufDummyR, this.bufDummyRW1, this.bufDummyRW2);
    this.bgAdaptConv = makeGroup(this.bglAux, this.bufP_adapt_conv, this.bufDummyR, this.bufDummyR,
      this.bufDummyRW3, this.bufUconv[PHASE_F], this.bufUconv[PHASE_P], this.bufUconv[PHASE_M]);
    this.bgAdaptOut = makeGroup(this.bglAux, this.bufP_adapt_out, this.bufDummyR, this.bufDummyR,
      this.bufDummyRW3, this.bufUout[PHASE_F], this.bufUout[PHASE_P], this.bufUout[PHASE_M]);
    // Gradients: the conv layer sees the same input in both nudged phases, so Xp = Xm = Xin.
    this.bgGC = makeGroup(this.bglGC, this.bufP_grad_conv, this.bufXin, this.bufXin,
      this.bufUconv[PHASE_P], this.bufUconv[PHASE_M], this.bufR, this.bufGWconv, this.bufGBconv);
    this.bgGD = makeGroup(this.bglGD, this.bufP_grad_dense, this.bufUconv[PHASE_P], this.bufUconv[PHASE_M],
      this.bufUout[PHASE_P], this.bufUout[PHASE_M], this.bufR, this.bufGWdense, this.bufGBdense);
  }

  /** First 12 u32 words shared by the CP and CGP uniforms: batch size and conv geometry. */
  _convGeometryWords() {
    const c = this.cfg;
    return [this.B, c.Cin, c.Cout, c.H, c.W, this.Hout, this.Wout, c.KH, c.KW, c.stride, c.pad, 0];
  }

  /** Write a CP uniform (80 bytes, field order of `struct CP`); missing fields default to 0. */
  _writeConvParams(buf, fields) {
    const u32 = new Uint32Array(20);
    const f32 = new Float32Array(u32.buffer);
    u32.set(this._convGeometryWords(), 0);
    f32.set([fields.dt || 0, fields.beta || 0, fields.gamma || 0, 0], 12);
    // nxt_size (word 17) is the dense layer width used for top-down feedback.
    u32.set([fields.has_topdown || 0, this.O, fields.has_target || 0, 0], 16);
    this.dev.queue.writeBuffer(buf, 0, u32.buffer);
  }

  /** Write a DP uniform (48 bytes, field order of `struct DP`); missing fields default to 0. */
  _writeDenseParams(buf, fields) {
    const u32 = new Uint32Array(12);
    const f32 = new Float32Array(u32.buffer);
    u32.set([this.B, this.convFlat, this.O, 0], 0);
    f32.set([fields.dt || 0, fields.beta || 0, 0, 0], 4);
    u32.set([fields.has_target || 0, 0, 0, 0], 8);
    this.dev.queue.writeBuffer(buf, 0, u32.buffer);
  }

  /** Write the CGP uniform (64 bytes) used by both conv gradient kernels. */
  _writeGradConvParams(two_beta) {
    const u32 = new Uint32Array(16);
    const f32 = new Float32Array(u32.buffer);
    u32.set(this._convGeometryWords(), 0);
    f32.set([two_beta, 0, 0, 0], 12);
    this.dev.queue.writeBuffer(this.bufP_grad_conv, 0, u32.buffer);
  }

  /** Write the DGP uniform (32 bytes) used by both dense gradient kernels. */
  _writeGradDenseParams(two_beta) {
    const u32 = new Uint32Array(8);
    const f32 = new Float32Array(u32.buffer);
    u32.set([this.B, this.convFlat, this.O, 0], 0);
    f32.set([two_beta, 0, 0, 0], 4);
    this.dev.queue.writeBuffer(this.bufP_grad_dense, 0, u32.buffer);
  }

  /** Write an AP uniform (32 bytes) for the reward/adapt kernels; missing fields default to 0. */
  _writeAuxParams(buf, fields) {
    const u32 = new Uint32Array(8);
    const f32 = new Float32Array(u32.buffer);
    u32.set([this.B, this.O, this.convFlat, 0], 0);
    f32.set([fields.c_adp || 0, fields.mode || 0, 0, 0], 4);
    this.dev.queue.writeBuffer(buf, 0, u32.buffer);
  }

  /**
   * Upload several typed arrays, each at offset 0 of its buffer.
   * All sizes are validated before anything is written.
   * @param {Array<[object, ArrayBufferView, string]>} items  [buffer, data, name] triples.
   * @throws {RangeError} if any array is larger than its destination buffer.
   */
  _uploadAll(items) {
    for (const [buf, arr, what] of items) {
      if (arr.byteLength > buf.size) {
        throw new RangeError(`${what}: ${arr.byteLength} bytes exceed GPU buffer size ${buf.size}`);
      }
    }
    const q = this.dev.queue;
    for (const [buf, arr] of items) q.writeBuffer(buf, 0, arr.buffer, arr.byteOffset, arr.byteLength);
  }

  /** Upload conv and dense parameters (typed arrays laid out as in the shaders). */
  uploadWeights(Wconv, bconv, Wdense, bdense) {
    this._uploadAll([
      [this.bufWconv, Wconv, 'Wconv'],
      [this.bufBconv, bconv, 'bconv'],
      [this.bufWdense, Wdense, 'Wdense'],
      [this.bufBdense, bdense, 'bdense'],
    ]);
  }

  /** Upload a batch of inputs X ([B*Cin*H*W]) and targets T ([B*O]). */
  uploadInputs(X, T) {
    this._uploadAll([[this.bufXin, X, 'X'], [this.bufTgt, T, 'T']]);
  }

  /**
   * Write every per-phase uniform: all conv passes use top-down feedback with gamma; the dense
   * passes get beta = 0 / +beta / -beta for PHASE_F / PHASE_P / PHASE_M (targets only when nudged).
   * Init uniforms carry no feedback and no nudge.
   */
  _writeAllPhaseUniforms(dt, beta, gamma) {
    const denseBeta = [0, +beta, -beta];
    for (let phase = 0; phase < 3; phase++) {
      this._writeConvParams(this.bufP_conv[phase], {dt, beta: 0, gamma, has_topdown: 1, has_target: 0});
    }
    for (let phase = 0; phase < 3; phase++) {
      this._writeDenseParams(this.bufP_dense[phase],
        {dt, beta: denseBeta[phase], has_target: phase === PHASE_F ? 0 : 1});
    }
    for (let phase = 0; phase < 3; phase++) {
      this._writeConvParams(this.bufP_init_conv[phase], {dt, beta: 0, gamma: 0, has_topdown: 0, has_target: 0});
      this._writeDenseParams(this.bufP_init_dense[phase], {dt, beta: 0, has_target: 0});
    }
  }

  /**
   * Dispatch `wg` workgroups of a flat kernel, spilling into the y dimension beyond 65535.
   * init_state and adapt_layer rebuild the flat index as gid.y * (65535 * 64) + gid.x, so x must
   * be exactly 65535 whenever more than one row is used.
   */
  _dispatchFlat(pass, wg) {
    const MAX_WG_X = 65535;
    pass.dispatchWorkgroups(Math.min(wg, MAX_WG_X), Math.ceil(wg / MAX_WG_X));
  }

  /** Reset the conv and dense states of all three phases to 0.1. */
  _initAllPhases(enc) {
    const convWG = Math.ceil(this.B * this.convFlat / 64);
    const outWG = Math.ceil(this.B * this.O / 64);
    for (let phase = 0; phase < 3; phase++) {
      let pass = enc.beginComputePass();
      pass.setPipeline(this.pipeInitConv);
      pass.setBindGroup(0, this.bgInitConv[phase]);
      this._dispatchFlat(pass, convWG);
      pass.end();
      pass = enc.beginComputePass();
      pass.setPipeline(this.pipeInitDense);
      pass.setBindGroup(0, this.bgInitDense[phase]);
      pass.dispatchWorkgroups(outWG); // init_state_out indexes with gid.x only
      pass.end();
    }
  }

  /**
   * Record `iters` relaxation steps for one phase. Each step runs the conv pass and then the dense
   * pass, so the dense pass sees the conv state produced in the same step.
   */
  _runPhaseRelax(enc, phase, iters) {
    const B = this.B;
    for (let t = 0; t < iters; t++) {
      let pass = enc.beginComputePass();
      pass.setPipeline(this.pipeConv);
      pass.setBindGroup(0, this.bgConv[phase]);
      pass.dispatchWorkgroups(Math.ceil(this.Wout / 8), Math.ceil(this.Hout / 8), B * this.cfg.Cout);
      pass.end();
      pass = enc.beginComputePass();
      pass.setPipeline(this.pipeDense);
      pass.setBindGroup(0, this.bgDense[phase]);
      pass.dispatchWorkgroups(Math.ceil(this.O / 64), B);
      pass.end();
    }
  }

  /** Record the per-sample reward computation from the free-phase outputs into bufR. */
  _runReward(enc) {
    this._writeAuxParams(this.bufP_rew, {c_adp: 0, mode: 0});
    const pass = enc.beginComputePass();
    pass.setPipeline(this.pipeReward);
    pass.setBindGroup(0, this.bgRew);
    pass.dispatchWorkgroups(Math.ceil(this.B / 64));
    pass.end();
  }

  /** Record `adpSteps` blends of the nudged conv and output states toward the free ones (rate adpC). */
  _runAdapt(enc, adpC, adpSteps) {
    if (adpSteps <= 0) return;
    this._writeAuxParams(this.bufP_adapt_conv, {c_adp: adpC, mode: 0});
    this._writeAuxParams(this.bufP_adapt_out, {c_adp: adpC, mode: 0});
    const convWG = Math.ceil(this.B * this.convFlat / 64);
    const outWG = Math.ceil(this.B * this.O / 64);
    for (let a = 0; a < adpSteps; a++) {
      let pass = enc.beginComputePass();
      pass.setPipeline(this.pipeAdapt);
      pass.setBindGroup(0, this.bgAdaptConv);
      this._dispatchFlat(pass, convWG);
      pass.end();
      pass = enc.beginComputePass();
      pass.setPipeline(this.pipeAdapt);
      pass.setBindGroup(0, this.bgAdaptOut);
      this._dispatchFlat(pass, outWG);
      pass.end();
    }
  }

  /** Record the conv and dense gradient kernels; gradients are divided by 2*beta. */
  _runGrad(enc, beta) {
    this._writeGradConvParams(2 * beta);
    this._writeGradDenseParams(2 * beta);
    const c = this.cfg;
    let pass = enc.beginComputePass();
    pass.setPipeline(this.pipeGWconv);
    pass.setBindGroup(0, this.bgGC);
    pass.dispatchWorkgroups(Math.ceil(c.KW / 8), Math.ceil(c.KH / 8), c.Cout * c.Cin);
    pass.setPipeline(this.pipeGBconv);
    pass.setBindGroup(0, this.bgGC);
    pass.dispatchWorkgroups(Math.ceil(c.Cout / 64));
    pass.end();
    pass = enc.beginComputePass();
    pass.setPipeline(this.pipeGWdense);
    pass.setBindGroup(0, this.bgGD);
    pass.dispatchWorkgroups(Math.ceil(this.convFlat / 8), Math.ceil(this.O / 8));
    pass.setPipeline(this.pipeGBdense);
    pass.setBindGroup(0, this.bgGD);
    pass.dispatchWorkgroups(Math.ceil(this.O / 64));
    pass.end();
  }

  /** Map a readback buffer, copy it into a fresh Float32Array and always unmap it again. */
  async _readF32(buf) {
    await buf.mapAsync(GPUMapMode.READ);
    try {
      return new Float32Array(buf.getMappedRange().slice(0));
    } finally {
      buf.unmap();
    }
  }

  /**
   * Run the free phase from the initial state and return the raw output states ([B*O]).
   * @param {number} iters  relaxation steps.
   * @param {number} dt     Euler step size.
   * @param {number} [gamma=0.6]  top-down feedback strength.
   */
  async runFreeAndReadOutputs(iters, dt, gamma = 0.6) {
    this._writeAllPhaseUniforms(dt, 0, gamma);
    const enc = this.dev.createCommandEncoder();
    this._initAllPhases(enc);
    this._runPhaseRelax(enc, PHASE_F, iters);
    enc.copyBufferToBuffer(this.bufUout[PHASE_F], 0, this.rbUoF, 0, this.B * this.O * 4);
    this.dev.queue.submit([enc.finish()]);
    return this._readF32(this.rbUoF);
  }

  /**
   * Greedy layer-wise EqProp helper: run the free phase and return the conv hidden state after the
   * firing-rate function sg(u) = 1 / (1 + exp(-4(u - 0.5))) (the same function as the shaders),
   * laid out as [B * Cout * Hout * Wout], so it can be fed to a second conv stage as input.
   */
  async runFreeAndReadConvHidden(iters, dt, gamma = 0.6) {
    this._writeAllPhaseUniforms(dt, 0, gamma);
    const enc = this.dev.createCommandEncoder();
    this._initAllPhases(enc);
    this._runPhaseRelax(enc, PHASE_F, iters);
    enc.copyBufferToBuffer(this.bufUconv[PHASE_F], 0, this.rbUconvF, 0, this.B * this.convFlat * 4);
    this.dev.queue.submit([enc.finish()]);
    const raw = await this._readF32(this.rbUconvF);
    return raw.map((u) => 1 / (1 + Math.exp(-4 * (u - 0.5))));
  }

  /**
   * One training step on the GPU:
   * 1. Initialize all phases.
   * 2. Relax free (itF steps), +beta (itN) and -beta (itN).
   * 3. Compute the per-sample reward from the free outputs.
   * 4. Blend the nudged states toward the free ones adpSteps times.
   * 5. Compute the reward-weighted gradients (P - M) / (2 * beta).
   *
   * NOTE(review): the nudged phases start from the 0.1 initial state, not from the free-phase
   * fixed point; confirm this is intended. beta = 0 makes two_beta = 0 in the gradient shaders,
   * so the returned gradients are non-finite.
   * @returns {Promise<{gWconv: Float32Array, gBconv: Float32Array, gWdense: Float32Array,
   *   gBdense: Float32Array, uoF: Float32Array}>}
   */
  async runOnePassGetGradients({itF = 8, itN = 5, dt = 0.7, beta = 0.5, gamma = 0.6, adpC = 0.15, adpSteps = 3} = {}) {
    this._writeAllPhaseUniforms(dt, beta, gamma);
    const c = this.cfg;
    const enc = this.dev.createCommandEncoder();
    this._initAllPhases(enc);
    this._runPhaseRelax(enc, PHASE_F, itF);
    this._runPhaseRelax(enc, PHASE_P, itN);
    this._runPhaseRelax(enc, PHASE_M, itN);
    this._runReward(enc);
    this._runAdapt(enc, adpC, adpSteps);
    this._runGrad(enc, beta);
    enc.copyBufferToBuffer(this.bufGWconv, 0, this.rbGWconv, 0, c.Cout * c.Cin * c.KH * c.KW * 4);
    enc.copyBufferToBuffer(this.bufGBconv, 0, this.rbGBconv, 0, c.Cout * 4);
    enc.copyBufferToBuffer(this.bufGWdense, 0, this.rbGWdense, 0, this.O * this.convFlat * 4);
    enc.copyBufferToBuffer(this.bufGBdense, 0, this.rbGBdense, 0, this.O * 4);
    enc.copyBufferToBuffer(this.bufUout[PHASE_F], 0, this.rbUoF, 0, this.B * this.O * 4);
    this.dev.queue.submit([enc.finish()]);
    const [gWconv, gBconv, gWdense, gBdense, uoF] = await Promise.all([
      this._readF32(this.rbGWconv),
      this._readF32(this.rbGBconv),
      this._readF32(this.rbGWdense),
      this._readF32(this.rbGBdense),
      this._readF32(this.rbUoF),
    ]);
    return {gWconv, gBconv, gWdense, gBdense, uoF};
  }

  /**
   * Release every GPU buffer this trainer created (including all readback buffers).
   * Errors from individual destroy() calls are ignored.
   */
  destroy() {
    for (const buf of this._ownedBuffers || []) {
      if (buf && buf.destroy) {
        try { buf.destroy(); } catch (e) { /* best-effort release */ }
      }
    }
  }
}