// Full convolutional EqProp trainer (WebGPU). // Architecture: one conv layer (weight-shared kernel) + one dense readout to O classes. // All EqProp, no backprop. Validated vs ConvEqPropNet CPU reference. import { orth as orthCPU } from './eqprop_lib.js'; const WGSL_CONV_RELAX = ` struct CP { B: u32, Cin: u32, Cout: u32, H: u32, W: u32, Hout: u32, Wout: u32, KH: u32, KW: u32, stride: u32, pad: u32, _p0: u32, dt: f32, beta: f32, gamma: f32, _p1: f32, has_topdown: u32, nxt_size: u32, has_target: u32, _p2: u32, }; @group(0) @binding(0) var p : CP; @group(0) @binding(1) var Xin : array; @group(0) @binding(2) var Wt : array; @group(0) @binding(3) var Bs : array; @group(0) @binding(4) var Wnxt: array; @group(0) @binding(5) var Uh : array; @group(0) @binding(6) var Uo : array; fn sg(u: f32) -> f32 { return 1.0 / (1.0 + exp(-4.0 * (u - 0.5))); } fn rho(u: f32) -> f32 { return sg(u); } @compute @workgroup_size(8, 8, 1) fn conv_pass(@builtin(global_invocation_id) gid: vec3) { let xo = gid.x; let yo = gid.y; let bk = gid.z; if (xo >= p.Wout || yo >= p.Hout) { return; } let b = bk / p.Cout; let k = bk % p.Cout; if (b >= p.B) { return; } let img_size = p.Cin * p.H * p.W; let map_size = p.Cout * p.Hout * p.Wout; var c : f32 = Bs[k]; for (var kin: u32 = 0u; kin < p.Cin; kin = kin + 1u) { for (var dy: u32 = 0u; dy < p.KH; dy = dy + 1u) { let iy_s = i32(yo * p.stride + dy) - i32(p.pad); if (iy_s < 0 || iy_s >= i32(p.H)) { continue; } let iy = u32(iy_s); for (var dx: u32 = 0u; dx < p.KW; dx = dx + 1u) { let ix_s = i32(xo * p.stride + dx) - i32(p.pad); if (ix_s < 0 || ix_s >= i32(p.W)) { continue; } let ix = u32(ix_s); let xidx = b * img_size + kin * p.H * p.W + iy * p.W + ix; let widx = ((k * p.Cin + kin) * p.KH + dy) * p.KW + dx; c = c + Wt[widx] * rho(Xin[xidx]); } } } if (p.has_topdown != 0u) { var td : f32 = 0.0; let this_flat = k * p.Hout * p.Wout + yo * p.Wout + xo; for (var n: u32 = 0u; n < p.nxt_size; n = n + 1u) { td = td + Wnxt[n * map_size + this_flat] * rho(Uo[b * p.nxt_size + n]); } c = c + p.gamma * td; } // STABILIZER #4: clip pre-activation drive to a bounded range BEFORE σ. c = clamp(c, -2.5, 3.5); // SYNTHESIS #3: LOCAL RESPONSE NORMALIZATION (LRN) — V1 cross-channel lateral inhibition. // For each (b, k, y, x), divisively normalize σ(c) by the L2 norm of neighbor-channel activities at // the same spatial position. This is biologically faithful (V1 surround inhibition) and breaks the // saturation positive-feedback loop by ensuring no single channel can dominate. // normalized = σ(c) / (k₀ + α · Σ_{k' ∈ [k-n, k+n]} Uh[b,k',y,x]²)^β // We read Uh from PREVIOUS iter (current-iter writes haven't happened yet within this pass). // Applied identically in all 3 phases (free/+β/−β) so the EqProp gradient remains consistent // for the LRN-modified energy. let chan_size = p.Hout * p.Wout; let spatial_off = b * map_size + yo * p.Wout + xo; var sum_sq : f32 = 0.0; let n_half : i32 = 2; let kk_min : i32 = max(0, i32(k) - n_half); let kk_max : i32 = min(i32(p.Cout) - 1, i32(k) + n_half); for (var kk: i32 = kk_min; kk <= kk_max; kk = kk + 1) { let v = Uh[spatial_off + u32(kk) * chan_size]; sum_sq = sum_sq + v * v; } // LRN DISABLED (v13 found it degrades single-conv accuracy; keep code path for re-enable). let sg_c_norm = sg(c); let idx = b * map_size + k * p.Hout * p.Wout + yo * p.Wout + xo; let u_old = Uh[idx]; let drive = -u_old + sg_c_norm; Uh[idx] = u_old + p.dt * drive; } @compute @workgroup_size(64) fn init_state(@builtin(global_invocation_id) gid: vec3) { let stride = 65535u * 64u; let g = gid.y * stride + gid.x; let n = p.B * p.Cout * p.Hout * p.Wout; if (g < n) { Uh[g] = 0.1; } } `; // Dense output pass: u_out[b, i] = -u + σ(b_i + Σ_j W[i,j] σ(conv_flat[b,j])) + β·(tgt-u) const WGSL_DENSE_OUT = ` struct DP { B: u32, Ni: u32, No: u32, _p0: u32, dt: f32, beta: f32, _p1: f32, _p2: f32, has_target: u32, _p3: u32, _p4: u32, _p5: u32, }; @group(0) @binding(0) var p : DP; @group(0) @binding(1) var Xin : array; // conv hidden, [B*Ni] @group(0) @binding(2) var Wt : array; // [No*Ni] @group(0) @binding(3) var Bs : array; // [No] @group(0) @binding(4) var Tgt : array; // [B*No] @group(0) @binding(5) var Uo : array; // [B*No] fn sg(u: f32) -> f32 { return 1.0 / (1.0 + exp(-4.0 * (u - 0.5))); } fn rho(u: f32) -> f32 { return sg(u); } @compute @workgroup_size(64, 1) fn dense_pass(@builtin(global_invocation_id) gid: vec3) { let b = gid.y; let i = gid.x; if (b >= p.B || i >= p.No) { return; } var c : f32 = Bs[i]; for (var j: u32 = 0u; j < p.Ni; j = j + 1u) { c = c + Wt[i * p.Ni + j] * rho(Xin[b * p.Ni + j]); } // STABILIZER #4 (dense output): clamp pre-activation drive too — same reasoning. c = clamp(c, -2.5, 3.5); let idx = b * p.No + i; let u_old = Uo[idx]; var drive : f32 = -u_old + sg(c); if (p.has_target != 0u && p.beta != 0.0) { drive = drive + p.beta * (Tgt[idx] - u_old); } Uo[idx] = u_old + p.dt * drive; } @compute @workgroup_size(64) fn init_state_out(@builtin(global_invocation_id) gid: vec3) { let g = gid.x; let n = p.B * p.No; if (g < n) { Uo[g] = 0.1; } } `; const WGSL_AUX_CONV = ` struct AP { B: u32, O: u32, convFlat: u32, _p0: u32, c_adp: f32, mode: f32, _p1: f32, _p2: f32, }; @group(0) @binding(0) var p : AP; @group(0) @binding(1) var UoF : array; @group(0) @binding(2) var Tgt : array; @group(0) @binding(3) var R : array; @group(0) @binding(4) var Uf : array; @group(0) @binding(5) var Up : array; @group(0) @binding(6) var Um : array; fn sg(u: f32) -> f32 { return 1.0 / (1.0 + exp(-4.0 * (u - 0.5))); } fn rho_out(u: f32) -> f32 { return sg(u); } @compute @workgroup_size(64) fn compute_reward(@builtin(global_invocation_id) gid: vec3) { let b = gid.x; if (b >= p.B) { return; } var loss : f32 = 0.0; let off = b * p.O; for (var i: u32 = 0u; i < p.O; i = i + 1u) { let d = rho_out(UoF[off + i]) - Tgt[off + i]; loss = loss + d * d; } var r : f32 = loss / 0.4; if (r > 1.0) { r = 1.0; } R[b] = 0.1 + 0.9 * r; } @compute @workgroup_size(64) fn adapt_layer(@builtin(global_invocation_id) gid: vec3) { let stride = 65535u * 64u; let g = gid.y * stride + gid.x; if (g >= arrayLength(&Uf)) { return; } let f = Uf[g]; Up[g] = (1.0 - p.c_adp) * Up[g] + p.c_adp * f; Um[g] = (1.0 - p.c_adp) * Um[g] + p.c_adp * f; } `; const WGSL_GRAD_CONV = ` struct CGP { B: u32, Cin: u32, Cout: u32, H: u32, W: u32, Hout: u32, Wout: u32, KH: u32, KW: u32, stride: u32, pad: u32, _p0: u32, two_beta: f32, _p1: f32, _p2: f32, _p3: f32, }; @group(0) @binding(0) var p : CGP; @group(0) @binding(1) var Xp : array; @group(0) @binding(2) var Xm : array; @group(0) @binding(3) var Up : array; @group(0) @binding(4) var Um : array; @group(0) @binding(5) var R : array; @group(0) @binding(6) var gW : array; @group(0) @binding(7) var gB : array; fn sg(u: f32) -> f32 { return 1.0 / (1.0 + exp(-4.0 * (u - 0.5))); } fn rho(u: f32) -> f32 { return sg(u); } @compute @workgroup_size(8, 8, 1) fn grad_W_conv(@builtin(global_invocation_id) gid: vec3) { let dx = gid.x; let dy = gid.y; let kk = gid.z; if (dx >= p.KW || dy >= p.KH) { return; } let kout = kk / p.Cin; let kin = kk % p.Cin; if (kout >= p.Cout) { return; } let img_size = p.Cin * p.H * p.W; let map_size = p.Cout * p.Hout * p.Wout; var acc : f32 = 0.0; for (var b: u32 = 0u; b < p.B; b = b + 1u) { let rb = R[b]; for (var yo: u32 = 0u; yo < p.Hout; yo = yo + 1u) { let iy_s = i32(yo * p.stride + dy) - i32(p.pad); if (iy_s < 0 || iy_s >= i32(p.H)) { continue; } let iy = u32(iy_s); for (var xo: u32 = 0u; xo < p.Wout; xo = xo + 1u) { let ix_s = i32(xo * p.stride + dx) - i32(p.pad); if (ix_s < 0 || ix_s >= i32(p.W)) { continue; } let ix = u32(ix_s); let u_flat = b * map_size + kout * p.Hout * p.Wout + yo * p.Wout + xo; let x_flat = b * img_size + kin * p.H * p.W + iy * p.W + ix; acc = acc + rb * (rho(Up[u_flat]) * rho(Xp[x_flat]) - rho(Um[u_flat]) * rho(Xm[x_flat])); } } } let widx = ((kout * p.Cin + kin) * p.KH + dy) * p.KW + dx; gW[widx] = acc / p.two_beta; } @compute @workgroup_size(64) fn grad_B_conv(@builtin(global_invocation_id) gid: vec3) { let kout = gid.x; if (kout >= p.Cout) { return; } let map_size = p.Cout * p.Hout * p.Wout; var acc : f32 = 0.0; for (var b: u32 = 0u; b < p.B; b = b + 1u) { let rb = R[b]; for (var yo: u32 = 0u; yo < p.Hout; yo = yo + 1u) { for (var xo: u32 = 0u; xo < p.Wout; xo = xo + 1u) { let u_flat = b * map_size + kout * p.Hout * p.Wout + yo * p.Wout + xo; acc = acc + rb * (rho(Up[u_flat]) - rho(Um[u_flat])); } } } gB[kout] = acc / p.two_beta; } `; const WGSL_GRAD_DENSE = ` struct DGP { B: u32, Ni: u32, No: u32, _p0: u32, two_beta: f32, _p1: f32, _p2: f32, _p3: f32, }; @group(0) @binding(0) var p : DGP; @group(0) @binding(1) var Xp : array; // pre-layer plus phase [B*Ni] @group(0) @binding(2) var Xm : array; @group(0) @binding(3) var Up : array; // post-layer plus phase [B*No] @group(0) @binding(4) var Um : array; @group(0) @binding(5) var R : array; @group(0) @binding(6) var gW : array; @group(0) @binding(7) var gB : array; fn sg(u: f32) -> f32 { return 1.0 / (1.0 + exp(-4.0 * (u - 0.5))); } fn rho(u: f32) -> f32 { return sg(u); } @compute @workgroup_size(8, 8) fn grad_W_dense(@builtin(global_invocation_id) gid: vec3) { let i = gid.y; let j = gid.x; if (i >= p.No || j >= p.Ni) { return; } var acc : f32 = 0.0; for (var b: u32 = 0u; b < p.B; b = b + 1u) { let rb = R[b]; let ip = rho(Up[b * p.No + i]); let im = rho(Um[b * p.No + i]); let jp = rho(Xp[b * p.Ni + j]); let jm = rho(Xm[b * p.Ni + j]); acc = acc + rb * (ip * jp - im * jm); } gW[i * p.Ni + j] = acc / p.two_beta; } @compute @workgroup_size(64) fn grad_B_dense(@builtin(global_invocation_id) gid: vec3) { let i = gid.x; if (i >= p.No) { return; } var acc : f32 = 0.0; for (var b: u32 = 0u; b < p.B; b = b + 1u) { let rb = R[b]; acc = acc + rb * (rho(Up[b * p.No + i]) - rho(Um[b * p.No + i])); } gB[i] = acc / p.two_beta; } `; export async function makeGPUConv({powerPreference='high-performance'}={}){ if(!navigator.gpu) throw new Error('no webgpu'); const adapter = await navigator.gpu.requestAdapter({powerPreference}); if(!adapter) throw new Error('no adapter'); const want = {}; for(const k of ['maxStorageBuffersPerShaderStage','maxBufferSize','maxStorageBufferBindingSize', 'maxComputeInvocationsPerWorkgroup','maxComputeWorkgroupSizeX','maxComputeWorkgroupStorageSize','maxBindGroups']){ const v=adapter.limits[k]; if(typeof v==='number') want[k]=v; } const dev = await adapter.requestDevice({requiredLimits: want}); return {adapter, dev, info: adapter.info||{}}; } const PHASE_F = 0, PHASE_P = 1, PHASE_M = 2; export class GPUTrainerConvFull { constructor({dev, convCfg, denseSize, B}){ this.dev = dev; this.cfg = convCfg; this.B = B; this.O = denseSize; this.Hout = Math.floor((convCfg.H + 2*convCfg.pad - convCfg.KH)/convCfg.stride) + 1; this.Wout = Math.floor((convCfg.W + 2*convCfg.pad - convCfg.KW)/convCfg.stride) + 1; this.convFlat = convCfg.Cout * this.Hout * this.Wout; this._build(); } _F32(n, usage){ return this.dev.createBuffer({size: Math.max(4, n*4), usage}); } _build(){ const dev = this.dev, B = this.B, c = this.cfg, O = this.O; const RW = GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST; const R = GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC; const UNI= GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST; const RDS= GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ; const imgSize = c.Cin * c.H * c.W; const cw_n = c.Cout * c.Cin * c.KH * c.KW; // Input + targets (shared across phases) this.bufXin = this._F32(B * imgSize, R); this.bufTgt = this._F32(B * O, R); // Weights this.bufWconv = this._F32(cw_n, R); this.bufBconv = this._F32(c.Cout, R); this.bufWdense = this._F32(O * this.convFlat, R); this.bufBdense = this._F32(O, R); // Per-phase states this.bufUconv = []; this.bufUout = []; for(let phase=0; phase<3; phase++){ this.bufUconv.push(this._F32(B * this.convFlat, RW)); this.bufUout.push(this._F32(B * O, RW)); } // Reward + dummies (need separate dummy buffers for distinct read_write slots to avoid aliasing) this.bufR = this._F32(B, RW); this.bufDummyR = this._F32(4, R); this.bufDummyRW1 = this._F32(4, RW); this.bufDummyRW2 = this._F32(4, RW); this.bufDummyRW3 = this._F32(4, RW); // Gradient buffers this.bufGWconv = this._F32(cw_n, RW); this.bufGBconv = this._F32(c.Cout, RW); this.bufGWdense= this._F32(O * this.convFlat, RW); this.bufGBdense= this._F32(O, RW); // Readback this.rbUoF = dev.createBuffer({size: B*O*4, usage: RDS}); this.rbUconvF = dev.createBuffer({size: B*this.convFlat*4, usage: RDS}); this.rbGWconv = dev.createBuffer({size: cw_n*4, usage: RDS}); this.rbGBconv = dev.createBuffer({size: c.Cout*4, usage: RDS}); this.rbGWdense = dev.createBuffer({size: O*this.convFlat*4, usage: RDS}); this.rbGBdense = dev.createBuffer({size: O*4, usage: RDS}); // Uniform buffers: one per phase for conv (CP=80 bytes), one per phase for dense (DP=48 bytes), one for grads (CGP=64, DGP=32), aux (AP=32) this.bufP_conv = []; this.bufP_dense = []; this.bufP_init_conv = []; this.bufP_init_dense = []; for(let phase=0; phase<3; phase++){ this.bufP_conv.push(dev.createBuffer({size:80, usage:UNI})); this.bufP_dense.push(dev.createBuffer({size:48, usage:UNI})); this.bufP_init_conv.push(dev.createBuffer({size:80, usage:UNI})); this.bufP_init_dense.push(dev.createBuffer({size:48, usage:UNI})); } this.bufP_grad_conv = dev.createBuffer({size:64, usage:UNI}); this.bufP_grad_dense = dev.createBuffer({size:32, usage:UNI}); this.bufP_rew = dev.createBuffer({size:32, usage:UNI}); this.bufP_adapt_conv = dev.createBuffer({size:32, usage:UNI}); this.bufP_adapt_out = dev.createBuffer({size:32, usage:UNI}); // Pipelines const sR = (i)=>({binding:i, visibility:GPUShaderStage.COMPUTE, buffer:{type:'read-only-storage'}}); const sRW = (i)=>({binding:i, visibility:GPUShaderStage.COMPUTE, buffer:{type:'storage'}}); const uN = (i)=>({binding:i, visibility:GPUShaderStage.COMPUTE, buffer:{type:'uniform'}}); // ---- Conv relax pipeline ---- const modConv = dev.createShaderModule({code: WGSL_CONV_RELAX}); this.bglConv = dev.createBindGroupLayout({entries:[uN(0), sR(1), sR(2), sR(3), sR(4), sRW(5), sR(6)]}); this.plConv = dev.createPipelineLayout({bindGroupLayouts:[this.bglConv]}); this.pipeConv = dev.createComputePipeline({layout:this.plConv, compute:{module:modConv, entryPoint:'conv_pass'}}); this.pipeInitConv = dev.createComputePipeline({layout:this.plConv, compute:{module:modConv, entryPoint:'init_state'}}); // ---- Dense relax pipeline ---- const modDense = dev.createShaderModule({code: WGSL_DENSE_OUT}); this.bglDense = dev.createBindGroupLayout({entries:[uN(0), sR(1), sR(2), sR(3), sR(4), sRW(5)]}); this.plDense = dev.createPipelineLayout({bindGroupLayouts:[this.bglDense]}); this.pipeDense = dev.createComputePipeline({layout:this.plDense, compute:{module:modDense, entryPoint:'dense_pass'}}); this.pipeInitDense= dev.createComputePipeline({layout:this.plDense, compute:{module:modDense, entryPoint:'init_state_out'}}); // ---- Aux (reward + adapt) pipeline ---- const modAux = dev.createShaderModule({code: WGSL_AUX_CONV}); this.bglAux = dev.createBindGroupLayout({entries:[uN(0), sR(1), sR(2), sRW(3), sR(4), sRW(5), sRW(6)]}); this.plAux = dev.createPipelineLayout({bindGroupLayouts:[this.bglAux]}); this.pipeReward = dev.createComputePipeline({layout:this.plAux, compute:{module:modAux, entryPoint:'compute_reward'}}); this.pipeAdapt = dev.createComputePipeline({layout:this.plAux, compute:{module:modAux, entryPoint:'adapt_layer'}}); // ---- Grad conv pipeline ---- const modGC = dev.createShaderModule({code: WGSL_GRAD_CONV}); this.bglGC = dev.createBindGroupLayout({entries:[uN(0), sR(1), sR(2), sR(3), sR(4), sR(5), sRW(6), sRW(7)]}); this.plGC = dev.createPipelineLayout({bindGroupLayouts:[this.bglGC]}); this.pipeGWconv = dev.createComputePipeline({layout:this.plGC, compute:{module:modGC, entryPoint:'grad_W_conv'}}); this.pipeGBconv = dev.createComputePipeline({layout:this.plGC, compute:{module:modGC, entryPoint:'grad_B_conv'}}); // ---- Grad dense pipeline ---- const modGD = dev.createShaderModule({code: WGSL_GRAD_DENSE}); this.bglGD = dev.createBindGroupLayout({entries:[uN(0), sR(1), sR(2), sR(3), sR(4), sR(5), sRW(6), sRW(7)]}); this.plGD = dev.createPipelineLayout({bindGroupLayouts:[this.bglGD]}); this.pipeGWdense = dev.createComputePipeline({layout:this.plGD, compute:{module:modGD, entryPoint:'grad_W_dense'}}); this.pipeGBdense = dev.createComputePipeline({layout:this.plGD, compute:{module:modGD, entryPoint:'grad_B_dense'}}); // ---- Bind groups ---- // Conv per phase (relax + init each use same layout, different uniform buffer) this.bgConv = [], this.bgInitConv = []; for(let phase=0; phase<3; phase++){ this.bgConv.push(dev.createBindGroup({layout:this.bglConv, entries:[ {binding:0, resource:{buffer:this.bufP_conv[phase]}}, {binding:1, resource:{buffer:this.bufXin}}, {binding:2, resource:{buffer:this.bufWconv}}, {binding:3, resource:{buffer:this.bufBconv}}, {binding:4, resource:{buffer:this.bufWdense}}, {binding:5, resource:{buffer:this.bufUconv[phase]}}, {binding:6, resource:{buffer:this.bufUout[phase]}}, ]})); this.bgInitConv.push(dev.createBindGroup({layout:this.bglConv, entries:[ {binding:0, resource:{buffer:this.bufP_init_conv[phase]}}, {binding:1, resource:{buffer:this.bufXin}}, {binding:2, resource:{buffer:this.bufWconv}}, {binding:3, resource:{buffer:this.bufBconv}}, {binding:4, resource:{buffer:this.bufWdense}}, {binding:5, resource:{buffer:this.bufUconv[phase]}}, {binding:6, resource:{buffer:this.bufUout[phase]}}, ]})); } // Dense per phase this.bgDense = [], this.bgInitDense = []; for(let phase=0; phase<3; phase++){ this.bgDense.push(dev.createBindGroup({layout:this.bglDense, entries:[ {binding:0, resource:{buffer:this.bufP_dense[phase]}}, {binding:1, resource:{buffer:this.bufUconv[phase]}}, // input = conv hidden {binding:2, resource:{buffer:this.bufWdense}}, {binding:3, resource:{buffer:this.bufBdense}}, {binding:4, resource:{buffer:this.bufTgt}}, {binding:5, resource:{buffer:this.bufUout[phase]}}, ]})); this.bgInitDense.push(dev.createBindGroup({layout:this.bglDense, entries:[ {binding:0, resource:{buffer:this.bufP_init_dense[phase]}}, {binding:1, resource:{buffer:this.bufUconv[phase]}}, {binding:2, resource:{buffer:this.bufWdense}}, {binding:3, resource:{buffer:this.bufBdense}}, {binding:4, resource:{buffer:this.bufTgt}}, {binding:5, resource:{buffer:this.bufUout[phase]}}, ]})); } // Aux: reward (uses Uo_free) + adapt (conv, out) this.bgRew = dev.createBindGroup({layout:this.bglAux, entries:[ {binding:0, resource:{buffer:this.bufP_rew}}, {binding:1, resource:{buffer:this.bufUout[PHASE_F]}}, {binding:2, resource:{buffer:this.bufTgt}}, {binding:3, resource:{buffer:this.bufR}}, {binding:4, resource:{buffer:this.bufDummyR}}, {binding:5, resource:{buffer:this.bufDummyRW1}}, {binding:6, resource:{buffer:this.bufDummyRW2}}, ]}); this.bgAdaptConv = dev.createBindGroup({layout:this.bglAux, entries:[ {binding:0, resource:{buffer:this.bufP_adapt_conv}}, {binding:1, resource:{buffer:this.bufDummyR}}, {binding:2, resource:{buffer:this.bufDummyR}}, {binding:3, resource:{buffer:this.bufDummyRW3}}, {binding:4, resource:{buffer:this.bufUconv[PHASE_F]}}, {binding:5, resource:{buffer:this.bufUconv[PHASE_P]}}, {binding:6, resource:{buffer:this.bufUconv[PHASE_M]}}, ]}); this.bgAdaptOut = dev.createBindGroup({layout:this.bglAux, entries:[ {binding:0, resource:{buffer:this.bufP_adapt_out}}, {binding:1, resource:{buffer:this.bufDummyR}}, {binding:2, resource:{buffer:this.bufDummyR}}, {binding:3, resource:{buffer:this.bufDummyRW3}}, {binding:4, resource:{buffer:this.bufUout[PHASE_F]}}, {binding:5, resource:{buffer:this.bufUout[PHASE_P]}}, {binding:6, resource:{buffer:this.bufUout[PHASE_M]}}, ]}); // Grad this.bgGC = dev.createBindGroup({layout:this.bglGC, entries:[ {binding:0, resource:{buffer:this.bufP_grad_conv}}, {binding:1, resource:{buffer:this.bufXin}}, // Xp (conv input) {binding:2, resource:{buffer:this.bufXin}}, // Xm (same input — no augmentation in two phases) {binding:3, resource:{buffer:this.bufUconv[PHASE_P]}}, {binding:4, resource:{buffer:this.bufUconv[PHASE_M]}}, {binding:5, resource:{buffer:this.bufR}}, {binding:6, resource:{buffer:this.bufGWconv}}, {binding:7, resource:{buffer:this.bufGBconv}}, ]}); this.bgGD = dev.createBindGroup({layout:this.bglGD, entries:[ {binding:0, resource:{buffer:this.bufP_grad_dense}}, {binding:1, resource:{buffer:this.bufUconv[PHASE_P]}}, {binding:2, resource:{buffer:this.bufUconv[PHASE_M]}}, {binding:3, resource:{buffer:this.bufUout[PHASE_P]}}, {binding:4, resource:{buffer:this.bufUout[PHASE_M]}}, {binding:5, resource:{buffer:this.bufR}}, {binding:6, resource:{buffer:this.bufGWdense}}, {binding:7, resource:{buffer:this.bufGBdense}}, ]}); } _writeConvParams(buf, fields){ const u32 = new Uint32Array(20); const f32 = new Float32Array(u32.buffer); u32[0]=this.B; u32[1]=this.cfg.Cin; u32[2]=this.cfg.Cout; u32[3]=this.cfg.H; u32[4]=this.cfg.W; u32[5]=this.Hout; u32[6]=this.Wout; u32[7]=this.cfg.KH; u32[8]=this.cfg.KW; u32[9]=this.cfg.stride; u32[10]=this.cfg.pad; u32[11]=0; f32[12]=fields.dt||0; f32[13]=fields.beta||0; f32[14]=fields.gamma||0; f32[15]=0; u32[16]=fields.has_topdown||0; u32[17]=this.O; u32[18]=fields.has_target||0; u32[19]=0; this.dev.queue.writeBuffer(buf, 0, u32.buffer); } _writeDenseParams(buf, fields){ const u32 = new Uint32Array(12); const f32 = new Float32Array(u32.buffer); u32[0]=this.B; u32[1]=this.convFlat; u32[2]=this.O; u32[3]=0; f32[4]=fields.dt||0; f32[5]=fields.beta||0; f32[6]=0; f32[7]=0; u32[8]=fields.has_target||0; u32[9]=0; u32[10]=0; u32[11]=0; this.dev.queue.writeBuffer(buf, 0, u32.buffer); } _writeGradConvParams(two_beta){ const u32 = new Uint32Array(16); const f32 = new Float32Array(u32.buffer); u32[0]=this.B; u32[1]=this.cfg.Cin; u32[2]=this.cfg.Cout; u32[3]=this.cfg.H; u32[4]=this.cfg.W; u32[5]=this.Hout; u32[6]=this.Wout; u32[7]=this.cfg.KH; u32[8]=this.cfg.KW; u32[9]=this.cfg.stride; u32[10]=this.cfg.pad; u32[11]=0; f32[12]=two_beta; f32[13]=0; f32[14]=0; f32[15]=0; this.dev.queue.writeBuffer(this.bufP_grad_conv, 0, u32.buffer); } _writeGradDenseParams(two_beta){ const u32 = new Uint32Array(8); const f32 = new Float32Array(u32.buffer); u32[0]=this.B; u32[1]=this.convFlat; u32[2]=this.O; u32[3]=0; f32[4]=two_beta; f32[5]=0; f32[6]=0; f32[7]=0; this.dev.queue.writeBuffer(this.bufP_grad_dense, 0, u32.buffer); } _writeAuxParams(buf, fields){ const u32 = new Uint32Array(8); const f32 = new Float32Array(u32.buffer); u32[0]=this.B; u32[1]=this.O; u32[2]=this.convFlat; u32[3]=0; f32[4]=fields.c_adp||0; f32[5]=fields.mode||0; f32[6]=0; f32[7]=0; this.dev.queue.writeBuffer(buf, 0, u32.buffer); } uploadWeights(Wconv, bconv, Wdense, bdense){ const q = this.dev.queue; q.writeBuffer(this.bufWconv, 0, Wconv.buffer, Wconv.byteOffset, Wconv.byteLength); q.writeBuffer(this.bufBconv, 0, bconv.buffer, bconv.byteOffset, bconv.byteLength); q.writeBuffer(this.bufWdense, 0, Wdense.buffer, Wdense.byteOffset, Wdense.byteLength); q.writeBuffer(this.bufBdense, 0, bdense.buffer, bdense.byteOffset, bdense.byteLength); } uploadInputs(X, T){ const q = this.dev.queue; q.writeBuffer(this.bufXin, 0, X.buffer, X.byteOffset, X.byteLength); q.writeBuffer(this.bufTgt, 0, T.buffer, T.byteOffset, T.byteLength); } _writeAllPhaseUniforms(dt, beta, gamma){ // free phase: beta=0 this._writeConvParams (this.bufP_conv[0], {dt, beta:0, gamma, has_topdown:1, has_target:0}); this._writeConvParams (this.bufP_conv[1], {dt, beta:0, gamma, has_topdown:1, has_target:0}); this._writeConvParams (this.bufP_conv[2], {dt, beta:0, gamma, has_topdown:1, has_target:0}); this._writeDenseParams(this.bufP_dense[0], {dt, beta:0, has_target:0}); this._writeDenseParams(this.bufP_dense[1], {dt, beta:+beta, has_target:1}); this._writeDenseParams(this.bufP_dense[2], {dt, beta:-beta, has_target:1}); // init uniforms (beta=0, same as relax) for(let phase=0; phase<3; phase++){ this._writeConvParams(this.bufP_init_conv[phase], {dt, beta:0, gamma:0, has_topdown:0, has_target:0}); this._writeDenseParams(this.bufP_init_dense[phase],{dt, beta:0, has_target:0}); } } _initAllPhases(enc){ const MAX_WG_X = 65535; for(let phase=0; phase<3; phase++){ // conv init const nc = this.B * this.convFlat; const wgC = Math.ceil(nc/64); let p = enc.beginComputePass(); p.setPipeline(this.pipeInitConv); p.setBindGroup(0, this.bgInitConv[phase]); p.dispatchWorkgroups(Math.min(wgC, MAX_WG_X), Math.ceil(wgC/MAX_WG_X)); p.end(); // dense init const no = this.B * this.O; p = enc.beginComputePass(); p.setPipeline(this.pipeInitDense); p.setBindGroup(0, this.bgInitDense[phase]); p.dispatchWorkgroups(Math.ceil(no/64)); p.end(); } } _runPhaseRelax(enc, phase, iters){ const B = this.B; for(let t=0; t