Spaces:
Running
Running
| // Tier D — multi-layer convolutional EqProp trainer (end-to-end, not greedy). | |
| // Architecture: N conv layers followed by one dense readout. All algorithmic — N is determined | |
| // by the length of the convCfgs array passed to the constructor. | |
| // | |
| // Critical design: | |
| // * Each conv layer has its own W, b, U-state-per-phase buffer set. | |
| // * Each conv layer receives top-down feedback from the next layer: | |
| // - For NON-last conv (layer l<N-1): top-down comes from conv layer l+1 via TRANSPOSED-CONV | |
| // (deconv) with that layer's kernel. WITHOUT this top-down, the +β and -β perturbations of | |
| // the output dense layer would never propagate back to lower conv layers → those layers | |
| // would receive identical states in both phases → gradient = (σ(u+)σ(x+) - σ(u-)σ(x-))/2β | |
| // would be exactly zero → no learning → the architecture would BE greedy by accident. | |
| // - For the LAST conv (layer l=N-1): top-down comes from the dense readout (same as single-conv v03). | |
| // * Dense layer gets the ±β target nudge in +/- phases (canonical EqProp output nudging). | |
| // * Gradient per layer is computed exactly like single-conv: outer product accumulator over batch | |
| // of σ(u_pre_+)σ(u_post_+) − σ(u_pre_-)σ(u_post_-), divided by 2β, with reward-weighting. | |
| // | |
| // Conv-transpose constraint (v1): | |
| // The top-down WGSL kernel currently supports the conv-transpose for ANY (kH, kW, stride, pad) | |
| // by iterating kernel offsets and recovering the inverse spatial mapping algorithmically. | |
| // When stride > 1 the conv-transpose covers strictly fewer positions per kernel offset (some | |
| // (iy, ix) have no preimage in (yo, xo)) — this is handled by `if(integer && in_range)` checks. | |
| import { orth as orthCPU } from './eqprop_lib.js'; | |
// Phase indices (0, 1, 2). Presumably free / plus (+β) / minus (−β), used to index the
// per-phase buffer arrays — NOTE(review): confirm against the call sites.
const PHASE_F = 0;
const PHASE_P = 1;
const PHASE_M = 2;
// WGSL: conv-layer relaxation pass (bottom-up drive plus optional top-down feedback).
// Deterministic mode writes u_state[b, k, y, x] = u_old + Tau[k] * (-u_old + rho(c)),
//   where c = bias + Σ kernel·rho(input)
//           + gamma * top-down        (if has_topdown_type != 0)
//           + skip_gamma * skip-drive (if has_skip != 0),
//   optionally clamped to [clamp_lo, clamp_hi] when clamp_hi > clamp_lo.
// Spike mode (spike_mode != 0) instead writes the running mean of Bernoulli(rho(c)) samples.
// has_topdown_type: 0=none, 1=dense-next (Wnxt is [O × this_flat]), 2=conv-next (Wnxt is conv kernel [Cnxt × Cthis × KHnxt × KWnxt])
const WGSL_CONV_RELAX_MULTI = `
// Uniform block: 36 four-byte scalars = 144 bytes. Field order is the host-side write layout.
struct CP {
  B: u32, Cin: u32, Cout: u32, H: u32,
  W: u32, Hout: u32, Wout: u32, KH: u32,
  KW: u32, stride: u32, pad: u32, _p0: u32,
  dt: f32, beta_unused: f32, gamma: f32, mode: f32,
  has_topdown_type: u32, nxt_O: u32, nxt_KH: u32, nxt_KW: u32,
  nxt_stride: u32, nxt_pad: u32, nxt_Cnxt: u32, _p2: u32,
  clamp_lo: f32, clamp_hi: f32, triangle_offset: f32, triangle_power: f32,
  // MSMEN-MVT: stochastic spike-sampling mode (subset of Tempered Markov Energy Network)
  //   spike_mode > 0: at each relax iter, sample s ~ Bernoulli(rho(c)) using iter_seed-derived PCG hash;
  //   update u via running mean of spikes so the dense readout sees a fair estimate.
  //   For inference-time M-sample ensembling, caller sets a different iter_seed_base per sample.
  spike_mode: u32, iter_index: u32, iter_seed_base: u32, _p3: u32,
  // SI-5: dense → conv skip connection. When has_skip=1, conv layer reads an
  //   ADDITIONAL top-down from the LAST DENSE LAYER via a learnable W_skip[skip_O × this_flat].
  //   Bypasses γ^L attenuation in deep stacks. skip_gamma controls its strength independently.
  has_skip: u32, skip_O: u32, skip_gamma: f32, _p4: u32,
};
@group(0) @binding(0) var<uniform> p : CP;
@group(0) @binding(1) var<storage, read> Xin : array<f32>; // [B*Cin*H*W] input map
@group(0) @binding(2) var<storage, read> Wt : array<f32>; // [Cout*Cin*KH*KW]
@group(0) @binding(3) var<storage, read> Bs : array<f32>; // [Cout]
@group(0) @binding(4) var<storage, read> Wnxt: array<f32>; // top-down weights (dense or conv kernel)
@group(0) @binding(5) var<storage, read_write> Uh : array<f32>; // [B*Cout*Hout*Wout]
@group(0) @binding(6) var<storage, read> Unxt: array<f32>; // [B*nxt_O] dense or [B*Cnxt*Hnxt*Wnxt] conv
@group(0) @binding(7) var<storage, read> Tau : array<f32>; // [Cout] per-channel τ (HPSN); broadcast across spatial
@group(0) @binding(8) var<storage, read> Wskip: array<f32>; // SI-5 [skip_O × this_flat] dense→conv skip W
@group(0) @binding(9) var<storage, read> Uskip: array<f32>; // SI-5 [B × skip_O] last dense's state

// Activations supported (mode flag):
//   0 = adaptive σ (default, baseline)
//   1 = fhn clip      ρ(u) = clamp(u, 0, 1)
//   2 = prism         softplus smooth approximation of clip with bilateral gradient
//   3 = triangle      Krotov ρ(u) = max(0, u - triangle_offset)^triangle_power
//       — offset is set externally (algorithmic; commonly the per-layer mean)
//       — power=1 gives RePU; power=2 gives RePU²
//
// Tau is per-output-channel time constant; replaces global p.dt. Constant Tau[k]=p.dt → identical
// to scalar-dt behavior (used for backward-compat default).
const PRISM_K : f32 = 10.0;

// Shifted, steepened logistic: centre 0.5, slope factor 4.
fn sg(u: f32) -> f32 { return 1.0 / (1.0 + exp(-4.0 * (u - 0.5))); }

// Numerically stable softplus log(1 + e^x).
// Written as max(x, 0) + log(1 + e^{-|x|}) so exp() never receives a positive argument.
// A select()-based form evaluates BOTH operands, so exp(x) would overflow for large x even
// though the overflowing operand is discarded; WGSL leaves overflowing float math unspecified.
// For x <= 0 this is 0 + log(1 + e^x); for x > 0 it is x + log(1 + e^{-x}).
fn softplus_safe(x: f32) -> f32 { return max(x, 0.0) + log(1.0 + exp(-abs(x))); }

// Smooth clip to [0, 1]: difference of two softplus ramps, sharpness PRISM_K.
fn prism_rho_c(u: f32) -> f32 { return (softplus_safe(PRISM_K * u) - softplus_safe(PRISM_K * (u - 1.0))) / PRISM_K; }

// max(0, u - off)^pwr with exact fast paths for pwr = 1 and pwr = 2.
fn triangle_rho_c(u: f32, off: f32, pwr: f32) -> f32 {
  let z = u - off;
  if (z <= 0.0) { return 0.0; }
  if (pwr == 1.0) { return z; }
  if (pwr == 2.0) { return z * z; }
  return pow(z, pwr);
}

fn rho(u: f32) -> f32 {
  // mode dispatch: 0 sigma, 1 clip, 2 prism, 3 triangle.
  // p.mode is uniform; all branches compile, one path runs per thread.
  if (p.mode > 2.5) { return triangle_rho_c(u, p.triangle_offset, p.triangle_power); }
  if (p.mode > 1.5) { return prism_rho_c(u); }
  if (p.mode > 0.5) { return clamp(u, 0.0, 1.0); }
  return sg(u);
}

// MSMEN-MVT: PCG-style integer hash. Deterministic for a given seed.
fn pcg_u32(seed_in: u32) -> u32 {
  var state : u32 = seed_in * 747796405u + 2891336453u;
  let word : u32 = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u;
  return (word >> 22u) ^ word;
}

// Uniform sample in [0, 1) keyed on (batch, neuron, iter, seed_base).
fn pcg_unit(b: u32, i: u32, t: u32, base: u32) -> f32 {
  // Compose per-(batch, neuron, iter, seed_base) — independent samples across all axes.
  let s = b * 1000003u + i * 2654435761u + t * 374761393u + base * 2246822519u;
  // Keep only the top 24 bits: every such value is exactly representable in f32 and the
  // maximum is (2^24 - 1) / 2^24 < 1. Converting the full 32-bit hash and dividing by 2^32
  // rounds values near the top of the range up to exactly 1.0, breaking the half-open range.
  return f32(pcg_u32(s) >> 8u) / 16777216.0;
}

// One invocation per output element: gid.x → xo, gid.y → yo, gid.z → b * Cout + k.
@compute @workgroup_size(8, 8, 1) fn conv_pass(@builtin(global_invocation_id) gid: vec3<u32>) {
  let xo = gid.x; let yo = gid.y; let bk = gid.z;
  if (xo >= p.Wout || yo >= p.Hout) { return; }
  let b = bk / p.Cout; let k = bk % p.Cout;
  if (b >= p.B) { return; }
  let img_size = p.Cin * p.H * p.W;
  let map_size = p.Cout * p.Hout * p.Wout;

  // Bottom-up: bias + Σ kernel · rho(input); taps that fall outside the input are skipped.
  var c : f32 = Bs[k];
  for (var kin: u32 = 0u; kin < p.Cin; kin = kin + 1u) {
    for (var dy: u32 = 0u; dy < p.KH; dy = dy + 1u) {
      let iy_s = i32(yo * p.stride + dy) - i32(p.pad);
      if (iy_s < 0 || iy_s >= i32(p.H)) { continue; }
      let iy = u32(iy_s);
      for (var dx: u32 = 0u; dx < p.KW; dx = dx + 1u) {
        let ix_s = i32(xo * p.stride + dx) - i32(p.pad);
        if (ix_s < 0 || ix_s >= i32(p.W)) { continue; }
        let ix = u32(ix_s);
        let xidx = b * img_size + kin * p.H * p.W + iy * p.W + ix;
        let widx = ((k * p.Cin + kin) * p.KH + dy) * p.KW + dx;
        c = c + Wt[widx] * rho(Xin[xidx]);
      }
    }
  }

  // Top-down
  if (p.has_topdown_type == 1u) {
    // Dense next: Wnxt is [O × map_size], Unxt is [B × O]
    var td : f32 = 0.0;
    let this_flat = k * p.Hout * p.Wout + yo * p.Wout + xo;
    for (var n: u32 = 0u; n < p.nxt_O; n = n + 1u) {
      td = td + Wnxt[n * map_size + this_flat] * rho(Unxt[b * p.nxt_O + n]);
    }
    c = c + p.gamma * td;
  } else if (p.has_topdown_type == 2u) {
    // Conv next: TRANSPOSED CONV (deconv).
    // Next layer output u[b, k_nxt, yo_nxt, xo_nxt] receives contribution from
    // THIS layer position (k, yo, xo) via kernel offset (dy_nxt, dx_nxt) when
    //   yo_nxt * nxt_stride + dy_nxt - nxt_pad == yo
    //   xo_nxt * nxt_stride + dx_nxt - nxt_pad == xo
    // So for this position, the top-down sum reads back ALL next-layer outputs that read FROM here.
    // Iterate kernel offsets; for each, compute the next-layer position that would have used this one.
    var td : f32 = 0.0;
    // Compute spatial shape of next conv layer based on this layer's output shape:
    //   Hnxt = floor((Hout + 2*nxt_pad - nxt_KH)/nxt_stride) + 1
    let Hnxt : u32 = (p.Hout + 2u*p.nxt_pad - p.nxt_KH) / p.nxt_stride + 1u;
    let Wnxt_s : u32 = (p.Wout + 2u*p.nxt_pad - p.nxt_KW) / p.nxt_stride + 1u;
    let nxt_map_size = p.nxt_Cnxt * Hnxt * Wnxt_s;
    for (var k_nxt: u32 = 0u; k_nxt < p.nxt_Cnxt; k_nxt = k_nxt + 1u) {
      for (var dy_nxt: u32 = 0u; dy_nxt < p.nxt_KH; dy_nxt = dy_nxt + 1u) {
        // yo_nxt_s = (yo + nxt_pad - dy_nxt). Must be divisible by nxt_stride and in [0, Hnxt).
        let yo_nxt_s = i32(yo) + i32(p.nxt_pad) - i32(dy_nxt);
        if (yo_nxt_s < 0) { continue; }
        let yo_nxt_u = u32(yo_nxt_s);
        if (yo_nxt_u % p.nxt_stride != 0u) { continue; }
        let yo_nxt = yo_nxt_u / p.nxt_stride;
        if (yo_nxt >= Hnxt) { continue; }
        for (var dx_nxt: u32 = 0u; dx_nxt < p.nxt_KW; dx_nxt = dx_nxt + 1u) {
          let xo_nxt_s = i32(xo) + i32(p.nxt_pad) - i32(dx_nxt);
          if (xo_nxt_s < 0) { continue; }
          let xo_nxt_u = u32(xo_nxt_s);
          if (xo_nxt_u % p.nxt_stride != 0u) { continue; }
          let xo_nxt = xo_nxt_u / p.nxt_stride;
          if (xo_nxt >= Wnxt_s) { continue; }
          // Kernel weight: W[k_nxt, this_kin=k, dy_nxt, dx_nxt]
          let widx_nxt = ((k_nxt * p.Cout + k) * p.nxt_KH + dy_nxt) * p.nxt_KW + dx_nxt;
          let uidx_nxt = b * nxt_map_size + k_nxt * Hnxt * Wnxt_s + yo_nxt * Wnxt_s + xo_nxt;
          td = td + Wnxt[widx_nxt] * rho(Unxt[uidx_nxt]);
        }
      }
    }
    c = c + p.gamma * td;
  }

  // SI-5: dense → conv skip top-down (in ADDITION to existing chain top-down).
  if (p.has_skip != 0u) {
    var td_skip : f32 = 0.0;
    let this_flat = k * p.Hout * p.Wout + yo * p.Wout + xo;
    for (var n: u32 = 0u; n < p.skip_O; n = n + 1u) {
      td_skip = td_skip + Wskip[n * map_size + this_flat] * rho(Uskip[b * p.skip_O + n]);
    }
    c = c + p.skip_gamma * td_skip;
  }

  // Tier A — pre-activation drive clamp (active iff clamp_hi > clamp_lo)
  if (p.clamp_hi > p.clamp_lo) { c = clamp(c, p.clamp_lo, p.clamp_hi); }

  let idx = b * map_size + k * p.Hout * p.Wout + yo * p.Wout + xo;
  let u_old = Uh[idx];
  let p_spike = rho(c);
  if (p.spike_mode != 0u) {
    // MSMEN-MVT: stochastic spike sampling. The state becomes the running mean of binary
    // spikes drawn with probability rho(c); the per-iteration variance decorrelates samples
    // (M-sample ensemble at inference).
    // n = iter_index + 1 (avoid /0 on first iter)
    let s_t = select(0.0, 1.0, pcg_unit(b, k * p.Hout * p.Wout + yo * p.Wout + xo, p.iter_index, p.iter_seed_base) < p_spike);
    let n = f32(p.iter_index + 1u);
    Uh[idx] = (1.0 - 1.0/n) * u_old + (1.0/n) * s_t;
  } else {
    // Deterministic leaky-integrator update with per-channel step Tau[k] — v07 default behavior.
    let drive = -u_old + p_spike;
    Uh[idx] = u_old + Tau[k] * drive;
  }
}

// Resets every state element to 0.1. The flat index is gid.y * (65535 * 64) + gid.x,
// so the dispatch may spill into the y dimension when x alone cannot cover the state.
@compute @workgroup_size(64) fn init_state(@builtin(global_invocation_id) gid: vec3<u32>) {
  let stride = 65535u * 64u;
  let g = gid.y * stride + gid.x;
  let n = p.B * p.Cout * p.Hout * p.Wout;
  if (g < n) { Uh[g] = 0.1; }
}
`;
| // Dense layer relax — supports BOTH: | |
| // * Output dense (last in dense chain): has_target=1, gets +/-β nudge in plus/minus phase | |
| // * Hidden dense (Tier E — heterogeneous trainer): has_topdown=1, reads next-dense via Wnxt | |
| // Wnxt layout: [Nnxt x No]; if has_topdown=0 the binding can be a dummy buffer. | |
const WGSL_DENSE_OUT_MULTI = `
// Uniform block for one dense layer: 12 four-byte scalars.
struct DP {
  B: u32, Ni: u32, No: u32, Nnxt: u32,
  dt: f32, beta: f32, gamma: f32, _p2: f32,
  has_target: u32, has_topdown: u32, _p4: u32, _p5: u32,
};
@group(0) @binding(0) var<uniform> p : DP;
@group(0) @binding(1) var<storage, read> Xin : array<f32>; // [B*Ni]
@group(0) @binding(2) var<storage, read> Wt : array<f32>; // [No*Ni]
@group(0) @binding(3) var<storage, read> Bs : array<f32>; // [No]
@group(0) @binding(4) var<storage, read> Wnxt: array<f32>; // [Nnxt*No] top-down weights (dummy if has_topdown=0)
@group(0) @binding(5) var<storage, read> Unxt: array<f32>; // [B*Nnxt] next-layer state (dummy if has_topdown=0)
@group(0) @binding(6) var<storage, read> Tgt : array<f32>; // [B*No]
@group(0) @binding(7) var<storage, read_write> Uo : array<f32>; // [B*No]

// Shifted, steepened logistic (centre 0.5, slope factor 4); the only activation used here.
fn sg(u: f32) -> f32 { return 1.0 / (1.0 + exp(-4.0 * (u - 0.5))); }
fn rho(u: f32) -> f32 { return sg(u); }

// One relaxation step for one (sample, unit) pair: gid.y selects the sample, gid.x the unit.
@compute @workgroup_size(64, 1) fn dense_pass(@builtin(global_invocation_id) gid: vec3<u32>) {
  let sample = gid.y;
  let unit = gid.x;
  if (sample >= p.B || unit >= p.No) { return; }

  // Bottom-up drive: bias plus weighted activations of the input row.
  let w_row = unit * p.Ni;
  let x_row = sample * p.Ni;
  var c : f32 = Bs[unit];
  for (var j: u32 = 0u; j < p.Ni; j++) {
    c += Wt[w_row + j] * rho(Xin[x_row + j]);
  }

  // Optional feedback from the next dense layer, read through that layer's weights transposed.
  if (p.has_topdown != 0u) {
    let nxt_row = sample * p.Nnxt;
    var feedback : f32 = 0.0;
    for (var k: u32 = 0u; k < p.Nnxt; k++) {
      feedback += Wnxt[k * p.No + unit] * rho(Unxt[nxt_row + k]);
    }
    c += p.gamma * feedback;
  }

  let slot = sample * p.No + unit;
  let prev = Uo[slot];
  var drive : f32 = -prev + sg(c);
  // Target nudge: pulls the state toward Tgt with strength beta (the sign of beta sets the direction).
  if (p.has_target != 0u && p.beta != 0.0) {
    drive += p.beta * (Tgt[slot] - prev);
  }
  Uo[slot] = prev + p.dt * drive;
}

// Resets every state element of this layer to 0.1.
@compute @workgroup_size(64) fn init_state_out(@builtin(global_invocation_id) gid: vec3<u32>) {
  let flat = gid.x;
  let count = p.B * p.No;
  if (flat < count) { Uo[flat] = 0.1; }
}
`;
| // Gradient kernels per layer (conv & dense) — identical to single-conv lib. | |
const WGSL_GRAD_CONV_MULTI = `
// Uniform block for the conv gradient kernels: 16 four-byte scalars.
struct CGP {
  B: u32, Cin: u32, Cout: u32, H: u32,
  W: u32, Hout: u32, Wout: u32, KH: u32,
  KW: u32, stride: u32, pad: u32, _p0: u32,
  two_beta: f32, _p1: f32, _p2: f32, _p3: f32,
};
@group(0) @binding(0) var<uniform> p : CGP;
@group(0) @binding(1) var<storage, read> Xp : array<f32>;
@group(0) @binding(2) var<storage, read> Xm : array<f32>;
@group(0) @binding(3) var<storage, read> Up : array<f32>;
@group(0) @binding(4) var<storage, read> Um : array<f32>;
@group(0) @binding(5) var<storage, read> R : array<f32>;
@group(0) @binding(6) var<storage, read_write> gW : array<f32>;
@group(0) @binding(7) var<storage, read_write> gB : array<f32>;

fn sg(u: f32) -> f32 { return 1.0 / (1.0 + exp(-4.0 * (u - 0.5))); }
fn rho(u: f32) -> f32 { return sg(u); }

// Kernel-weight gradient. One invocation per weight: gid.x → kernel column, gid.y → kernel row,
// gid.z → kout * Cin + kin. Sums R[b] * (rho(Up)·rho(Xp) − rho(Um)·rho(Xm)) over the batch and
// every output position whose input tap lies inside the image, then divides by two_beta.
@compute @workgroup_size(8, 8, 1) fn grad_W_conv(@builtin(global_invocation_id) gid: vec3<u32>) {
  let kx = gid.x;
  let ky = gid.y;
  if (kx >= p.KW || ky >= p.KH) { return; }
  let kout = gid.z / p.Cin;
  let kin = gid.z % p.Cin;
  if (kout >= p.Cout) { return; }

  let in_plane = p.H * p.W;
  let out_plane = p.Hout * p.Wout;
  let img_size = p.Cin * in_plane;
  let map_size = p.Cout * out_plane;

  var acc : f32 = 0.0;
  for (var b: u32 = 0u; b < p.B; b++) {
    let reward = R[b];
    let u_base = b * map_size + kout * out_plane;
    let x_base = b * img_size + kin * in_plane;
    for (var yo: u32 = 0u; yo < p.Hout; yo++) {
      let iy = i32(yo * p.stride + ky) - i32(p.pad);
      if (iy >= 0 && iy < i32(p.H)) {
        for (var xo: u32 = 0u; xo < p.Wout; xo++) {
          let ix = i32(xo * p.stride + kx) - i32(p.pad);
          if (ix >= 0 && ix < i32(p.W)) {
            let u_flat = u_base + yo * p.Wout + xo;
            let x_flat = x_base + u32(iy) * p.W + u32(ix);
            acc += reward * (rho(Up[u_flat]) * rho(Xp[x_flat]) - rho(Um[u_flat]) * rho(Xm[x_flat]));
          }
        }
      }
    }
  }
  gW[((kout * p.Cin + kin) * p.KH + ky) * p.KW + kx] = acc / p.two_beta;
}

// Bias gradient. One invocation per output channel: sums R[b] * (rho(Up) − rho(Um)) over
// the batch and all spatial positions of that channel, then divides by two_beta.
@compute @workgroup_size(64) fn grad_B_conv(@builtin(global_invocation_id) gid: vec3<u32>) {
  let kout = gid.x;
  if (kout >= p.Cout) { return; }
  let out_plane = p.Hout * p.Wout;
  let map_size = p.Cout * out_plane;
  var acc : f32 = 0.0;
  for (var b: u32 = 0u; b < p.B; b++) {
    let reward = R[b];
    let u_base = b * map_size + kout * out_plane;
    for (var yo: u32 = 0u; yo < p.Hout; yo++) {
      for (var xo: u32 = 0u; xo < p.Wout; xo++) {
        let u_flat = u_base + yo * p.Wout + xo;
        acc += reward * (rho(Up[u_flat]) - rho(Um[u_flat]));
      }
    }
  }
  gB[kout] = acc / p.two_beta;
}
`;
const WGSL_GRAD_DENSE_MULTI = `
// Uniform block for the dense gradient kernels: 8 four-byte scalars.
struct DGP {
  B: u32, Ni: u32, No: u32, _p0: u32,
  two_beta: f32, _p1: f32, _p2: f32, _p3: f32,
};
@group(0) @binding(0) var<uniform> p : DGP;
@group(0) @binding(1) var<storage, read> Xp : array<f32>;
@group(0) @binding(2) var<storage, read> Xm : array<f32>;
@group(0) @binding(3) var<storage, read> Up : array<f32>;
@group(0) @binding(4) var<storage, read> Um : array<f32>;
@group(0) @binding(5) var<storage, read> R : array<f32>;
@group(0) @binding(6) var<storage, read_write> gW : array<f32>;
@group(0) @binding(7) var<storage, read_write> gB : array<f32>;

fn sg(u: f32) -> f32 { return 1.0 / (1.0 + exp(-4.0 * (u - 0.5))); }
fn rho(u: f32) -> f32 { return sg(u); }

// Weight gradient. One invocation per weight: gid.y → output unit, gid.x → input unit.
// Sums R[b] * (rho(Up)·rho(Xp) − rho(Um)·rho(Xm)) over the batch, then divides by two_beta.
@compute @workgroup_size(8, 8) fn grad_W_dense(@builtin(global_invocation_id) gid: vec3<u32>) {
  let out_unit = gid.y;
  let in_unit = gid.x;
  if (out_unit >= p.No || in_unit >= p.Ni) { return; }
  var acc : f32 = 0.0;
  for (var b: u32 = 0u; b < p.B; b++) {
    let reward = R[b];
    let u_at = b * p.No + out_unit;
    let x_at = b * p.Ni + in_unit;
    let post_plus = rho(Up[u_at]);
    let post_minus = rho(Um[u_at]);
    let pre_plus = rho(Xp[x_at]);
    let pre_minus = rho(Xm[x_at]);
    acc += reward * (post_plus * pre_plus - post_minus * pre_minus);
  }
  gW[out_unit * p.Ni + in_unit] = acc / p.two_beta;
}

// Bias gradient. One invocation per output unit: sums R[b] * (rho(Up) − rho(Um)) over the batch,
// then divides by two_beta.
@compute @workgroup_size(64) fn grad_B_dense(@builtin(global_invocation_id) gid: vec3<u32>) {
  let out_unit = gid.x;
  if (out_unit >= p.No) { return; }
  var acc : f32 = 0.0;
  for (var b: u32 = 0u; b < p.B; b++) {
    let reward = R[b];
    let u_at = b * p.No + out_unit;
    acc += reward * (rho(Up[u_at]) - rho(Um[u_at]));
  }
  gB[out_unit] = acc / p.two_beta;
}
`;
| // Reward computation — identical to v03. | |
const WGSL_AUX_MULTI = `
// Uniform block for the reward kernel: 8 four-byte scalars, only B and O are read.
struct AP {
  B: u32, O: u32, _p0: u32, _p1: u32,
  _p2: f32, _p3: f32, _p4: f32, _p5: f32,
};
@group(0) @binding(0) var<uniform> p : AP;
@group(0) @binding(1) var<storage, read> UoF : array<f32>;
@group(0) @binding(2) var<storage, read> Tgt : array<f32>;
@group(0) @binding(3) var<storage, read_write> R : array<f32>;

fn sg(u: f32) -> f32 { return 1.0 / (1.0 + exp(-4.0 * (u - 0.5))); }

// Per-sample weight R[b] = 0.1 + 0.9 * min-capped(squared error / 0.4).
// The squared error is taken between sg(UoF) and Tgt over the O outputs of sample b,
// so R[b] ranges from 0.1 (zero error) to 1.0 (squared error >= 0.4).
@compute @workgroup_size(64) fn compute_reward(@builtin(global_invocation_id) gid: vec3<u32>) {
  let sample = gid.x;
  if (sample >= p.B) { return; }
  let base = sample * p.O;
  var sq_err : f32 = 0.0;
  for (var o: u32 = 0u; o < p.O; o++) {
    let diff = sg(UoF[base + o]) - Tgt[base + o];
    sq_err += diff * diff;
  }
  let scaled = sq_err / 0.4;
  // select(f, t, cond) yields t when cond holds: cap at 1.0 only when scaled exceeds it.
  let capped = select(scaled, 1.0, scaled > 1.0);
  R[sample] = 0.1 + 0.9 * capped;
}
`;
| export class GPUTrainerConvMulti { | |
| // convCfgs: array of {Cin, Cout, KH, KW, stride, pad, H, W} — first entry's Cin/H/W is input image, | |
| // subsequent entries' Cin/H/W must equal previous entry's Cout/Hout/Wout (validated in constructor). | |
| // denseSize OR denseSizes: | |
| // - denseSize (scalar): number of OUTPUT classes O (backward-compat with v03 single-dense) | |
| // - denseSizes (array): [hiddenSize1, hiddenSize2, ..., O] — Tier E heterogeneous trainer | |
| // conv stack feeds first dense; each dense feeds the next via top-down. | |
| // Last dense receives the ±β target nudge. | |
| // B: batch size | |
| constructor({dev, convCfgs, denseSize, denseSizes, B}){ | |
| this.dev = dev; | |
| if(!Array.isArray(convCfgs) || convCfgs.length < 1) throw new Error('convCfgs must be non-empty array'); | |
| // Resolve denseSizes (Tier E): if scalar denseSize passed, wrap as single-element array. | |
| if(denseSizes !== undefined){ | |
| if(!Array.isArray(denseSizes) || denseSizes.length < 1) throw new Error('denseSizes must be non-empty array'); | |
| this.denseSizes = denseSizes.slice(); | |
| } else if(denseSize !== undefined){ | |
| this.denseSizes = [denseSize]; | |
| } else { | |
| throw new Error('must pass denseSize (scalar) or denseSizes (array)'); | |
| } | |
| this.D = this.denseSizes.length; // number of dense layers | |
| this.O = this.denseSizes[this.D-1]; // output classes = last dense size | |
| this.cfgs = convCfgs.map(c => ({...c})); // shallow-copy entries | |
| this.N = this.cfgs.length; | |
| this.B = B; | |
| // Compute per-layer Hout/Wout and verify chain consistency | |
| let prevC = null, prevH = null, prevW = null; | |
| for(let l=0; l<this.N; l++){ | |
| const c = this.cfgs[l]; | |
| if(l === 0){ | |
| if(!c.Cin || !c.H || !c.W) throw new Error(`convCfg[0] must specify Cin,H,W`); | |
| } else { | |
| if(c.Cin !== prevC) throw new Error(`convCfg[${l}].Cin (${c.Cin}) must equal convCfg[${l-1}].Cout (${prevC})`); | |
| c.H = prevH; c.W = prevW; | |
| } | |
| c.Hout = Math.floor((c.H + 2*c.pad - c.KH) / c.stride) + 1; | |
| c.Wout = Math.floor((c.W + 2*c.pad - c.KW) / c.stride) + 1; | |
| if(c.Hout < 1 || c.Wout < 1) throw new Error(`convCfg[${l}] produces invalid Hout=${c.Hout} Wout=${c.Wout}`); | |
| c.convFlat = c.Cout * c.Hout * c.Wout; | |
| prevC = c.Cout; prevH = c.Hout; prevW = c.Wout; | |
| } | |
| this.lastFlat = this.cfgs[this.N-1].convFlat; | |
| this._build(); | |
| } | |
| _F32(n, usage){ return this.dev.createBuffer({size: Math.max(4, n*4), usage}); } | |
| _build(){ | |
| const dev = this.dev, B = this.B, N = this.N, O = this.O; | |
| const RW = GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST; | |
| const R = GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC; | |
| const UNI= GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST; | |
| const RDS= GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ; | |
| // ---- buffers ---- | |
| const c0 = this.cfgs[0]; | |
| this.bufXin = this._F32(B * c0.Cin * c0.H * c0.W, R); | |
| this.bufTgt = this._F32(B * O, R); | |
| // Per-conv-layer: W, b, state per phase | |
| this.bufWconv = []; this.bufBconv = []; | |
| this.bufUconv = [[],[],[]]; | |
| this.bufGWconv = []; this.bufGBconv = []; | |
| this.rbGWconv = []; this.rbGBconv = []; | |
| // SI-5: per-conv-layer skip W (always allocated, dummy-tiny when unused). | |
| // Sized [O × this_layer's convFlat]; bound to a dummy 4-byte buffer when has_skip is off. | |
| this.bufWskip = []; | |
| this.bufGWskip = []; | |
| this.rbGWskip = []; | |
| this.skipEnabled = new Array(N).fill(false); | |
| this.bufDummySkip = this._F32(4, R); // 4-byte dummy for layers without skip | |
| // Tier H — per-layer per-channel τ buffer. Default Tau[k] = dt (set later via _writeAllPhaseUniforms). | |
| this.bufTau = []; | |
| for(let l=0; l<N; l++){ | |
| const c = this.cfgs[l]; | |
| const nw = c.Cout * c.Cin * c.KH * c.KW; | |
| this.bufWconv.push(this._F32(nw, R)); | |
| this.bufBconv.push(this._F32(c.Cout, R)); | |
| for(let p=0; p<3; p++) this.bufUconv[p].push(this._F32(B * c.convFlat, RW)); | |
| this.bufGWconv.push(this._F32(nw, RW)); | |
| this.bufGBconv.push(this._F32(c.Cout, RW)); | |
| this.rbGWconv.push(dev.createBuffer({size: nw*4, usage: RDS})); | |
| this.rbGBconv.push(dev.createBuffer({size: c.Cout*4, usage: RDS})); | |
| this.bufTau.push(this._F32(c.Cout, R)); // [Cout] tau-per-channel | |
| // SI-5: allocate skip W of size O × convFlat for each conv layer (in case user enables it later). | |
| // Last conv reads dense top-down already through standard chain — skip on last is redundant but allowed. | |
| const skipSize = this.denseSizes[this.D-1] * c.convFlat; | |
| this.bufWskip.push(this._F32(skipSize, R)); | |
| this.bufGWskip.push(this._F32(skipSize, RW)); | |
| this.rbGWskip.push(dev.createBuffer({size: skipSize * 4, usage: RDS})); | |
| } | |
| // Initialize all Tau to default 0.7 (will be overwritten by setAllTau or _writeAllPhaseUniforms) | |
| for(let l=0; l<N; l++){ | |
| const arr = new Float32Array(this.cfgs[l].Cout); arr.fill(0.7); | |
| this.dev.queue.writeBuffer(this.bufTau[l], 0, arr.buffer, arr.byteOffset, arr.byteLength); | |
| } | |
| // Tier F/G — per-layer activation mode + triangle params. Default mode=0 (σ-adaptive). | |
| this.modeFlag = 0; // 0=σ, 1=clip, 2=prism, 3=triangle | |
| this.triangleOffset = 0; | |
| this.trianglePower = 1; | |
| this.driveClampLo = 0; | |
| this.driveClampHi = 0; | |
| // Autoresearch v06 finding: top-down fan-in normalization. | |
| // When gammaTdNorm=true, per-layer γ is auto-scaled by sqrt(N_top_ref / N_top_layer) | |
| // so that effective top-down magnitude matches v03's working regime regardless of depth. | |
| // N_top_ref = 10 (v03's dense output dim, the regime where γ=0.6 was validated). | |
| // Reproduces v06's hand-tuned γ=0.1 for 2-conv conv0 automatically (0.6·√(10/288) ≈ 0.111). | |
| this.gammaTdNorm = false; | |
| this.gammaTdRef = 10; | |
| // HPSN state (per layer): {scalarTau, tauMin, tauMax, seed}; null = default scalar mode. | |
| this.tauSpec = new Array(N).fill(null); | |
| // Dense layers (Tier E: array). Per-layer W, b, U-state-per-phase, grad buffers. | |
| // Dense layer d has input Ni and output No: | |
| // Ni for d=0: lastFlat (conv stack output) | |
| // Ni for d>0: denseSizes[d-1] | |
| // No: denseSizes[d] | |
| this.bufWdense = []; this.bufBdense = []; | |
| this.bufGWdense = []; this.bufGBdense = []; | |
| this.rbGWdense = []; this.rbGBdense = []; | |
| this.bufUout = [[],[],[]]; // bufUout[phase][d] is dense layer d's state | |
| for(let d=0; d<this.D; d++){ | |
| const Ni = (d===0) ? this.lastFlat : this.denseSizes[d-1]; | |
| const No = this.denseSizes[d]; | |
| this.bufWdense.push(this._F32(No * Ni, R)); | |
| this.bufBdense.push(this._F32(No, R)); | |
| this.bufGWdense.push(this._F32(No * Ni, RW)); | |
| this.bufGBdense.push(this._F32(No, RW)); | |
| this.rbGWdense.push(dev.createBuffer({size: No*Ni*4, usage: RDS})); | |
| this.rbGBdense.push(dev.createBuffer({size: No*4, usage: RDS})); | |
| for(let p=0; p<3; p++) this.bufUout[p].push(this._F32(B * No, RW)); | |
| } | |
| // Readback for the output dense (last layer)'s free-phase state | |
| this.rbUoF = dev.createBuffer({size: B*this.O*4, usage: RDS}); | |
| // Reward | |
| this.bufR = this._F32(B, RW); | |
| this.bufDummyR = this._F32(4, R); | |
| // Uniforms — sizes: | |
// Conv relax CP = 144 bytes (36 four-byte slots)
| // Dense relax DP = 48 bytes (12 slots) | |
| // Conv grad CGP = 64 bytes | |
| // Dense grad DGP = 32 bytes | |
| // Aux AP = 32 bytes | |
| this.bufP_conv = []; // [layer][phase] — relax uniforms | |
| this.bufP_init_conv = []; // [layer][phase] — init uniforms (no top-down) | |
| for(let l=0; l<N; l++){ | |
| this.bufP_conv.push([]); | |
| this.bufP_init_conv.push([]); | |
| for(let p=0; p<3; p++){ | |
| this.bufP_conv[l].push(dev.createBuffer({size:144, usage:UNI})); | |
| this.bufP_init_conv[l].push(dev.createBuffer({size:144, usage:UNI})); | |
| } | |
| } | |
| // Per-dense-layer per-phase uniforms (new layout: 64 bytes = 16 slots with topdown fields) | |
| this.bufP_dense = []; // [phase][d] | |
| this.bufP_init_dense = []; // [phase][d] | |
| for(let p=0; p<3; p++){ | |
| this.bufP_dense.push([]); | |
| this.bufP_init_dense.push([]); | |
| for(let d=0; d<this.D; d++){ | |
| this.bufP_dense[p].push(dev.createBuffer({size:64, usage:UNI})); | |
| this.bufP_init_dense[p].push(dev.createBuffer({size:64, usage:UNI})); | |
| } | |
| } | |
| this.bufP_grad_conv = []; for(let l=0; l<N; l++) this.bufP_grad_conv.push(dev.createBuffer({size:64, usage:UNI})); | |
| this.bufP_grad_dense = []; for(let d=0; d<this.D; d++) this.bufP_grad_dense.push(dev.createBuffer({size:32, usage:UNI})); | |
| // SI-5 skip W gradient uniforms — one per conv layer (sized like dense grad uniform). | |
| this.bufP_grad_skip = []; for(let l=0; l<N; l++) this.bufP_grad_skip.push(dev.createBuffer({size:32, usage:UNI})); | |
| this.bufP_rew = dev.createBuffer({size:32, usage:UNI}); | |
| // ---- pipelines ---- | |
| const sR = (i)=>({binding:i, visibility:GPUShaderStage.COMPUTE, buffer:{type:'read-only-storage'}}); | |
| const sRW = (i)=>({binding:i, visibility:GPUShaderStage.COMPUTE, buffer:{type:'storage'}}); | |
| const uN = (i)=>({binding:i, visibility:GPUShaderStage.COMPUTE, buffer:{type:'uniform'}}); | |
| const modConv = dev.createShaderModule({code: WGSL_CONV_RELAX_MULTI}); | |
| this.bglConv = dev.createBindGroupLayout({entries:[uN(0), sR(1), sR(2), sR(3), sR(4), sRW(5), sR(6), sR(7), sR(8), sR(9)]}); | |
| this.plConv = dev.createPipelineLayout({bindGroupLayouts:[this.bglConv]}); | |
| this.pipeConv = dev.createComputePipeline({layout:this.plConv, compute:{module:modConv, entryPoint:'conv_pass'}}); | |
| this.pipeInitConv = dev.createComputePipeline({layout:this.plConv, compute:{module:modConv, entryPoint:'init_state'}}); | |
| const modDense = dev.createShaderModule({code: WGSL_DENSE_OUT_MULTI}); | |
| this.bglDense = dev.createBindGroupLayout({entries:[uN(0), sR(1), sR(2), sR(3), sR(4), sR(5), sR(6), sRW(7)]}); | |
| this.plDense = dev.createPipelineLayout({bindGroupLayouts:[this.bglDense]}); | |
| this.pipeDense = dev.createComputePipeline({layout:this.plDense, compute:{module:modDense, entryPoint:'dense_pass'}}); | |
| this.pipeInitDense = dev.createComputePipeline({layout:this.plDense, compute:{module:modDense, entryPoint:'init_state_out'}}); | |
| const modGC = dev.createShaderModule({code: WGSL_GRAD_CONV_MULTI}); | |
| this.bglGC = dev.createBindGroupLayout({entries:[uN(0), sR(1), sR(2), sR(3), sR(4), sR(5), sRW(6), sRW(7)]}); | |
| this.plGC = dev.createPipelineLayout({bindGroupLayouts:[this.bglGC]}); | |
| this.pipeGWconv = dev.createComputePipeline({layout:this.plGC, compute:{module:modGC, entryPoint:'grad_W_conv'}}); | |
| this.pipeGBconv = dev.createComputePipeline({layout:this.plGC, compute:{module:modGC, entryPoint:'grad_B_conv'}}); | |
| const modGD = dev.createShaderModule({code: WGSL_GRAD_DENSE_MULTI}); | |
| this.bglGD = dev.createBindGroupLayout({entries:[uN(0), sR(1), sR(2), sR(3), sR(4), sR(5), sRW(6), sRW(7)]}); | |
| this.plGD = dev.createPipelineLayout({bindGroupLayouts:[this.bglGD]}); | |
| this.pipeGWdense = dev.createComputePipeline({layout:this.plGD, compute:{module:modGD, entryPoint:'grad_W_dense'}}); | |
| this.pipeGBdense = dev.createComputePipeline({layout:this.plGD, compute:{module:modGD, entryPoint:'grad_B_dense'}}); | |
| // SI-5: skip W gradient uses the SAME grad_W_dense kernel — outer product of (last dense state) × (conv hidden). | |
| // Bind: Xp = conv_l hidden plus phase, Xm = conv_l hidden minus phase, Up = dense_last plus, Um = dense_last minus. | |
| // Output: gW with shape [denseSizes[D-1] × convFlat_l] — matches Wskip[l] layout. | |
| const modAux = dev.createShaderModule({code: WGSL_AUX_MULTI}); | |
| this.bglAux = dev.createBindGroupLayout({entries:[uN(0), sR(1), sR(2), sRW(3)]}); | |
| this.plAux = dev.createPipelineLayout({bindGroupLayouts:[this.bglAux]}); | |
| this.pipeReward = dev.createComputePipeline({layout:this.plAux, compute:{module:modAux, entryPoint:'compute_reward'}}); | |
| // ---- bind groups ---- | |
| // Conv per (layer, phase). Inputs depend on layer index. | |
| // layer 0: Xin = bufXin | |
| // layer l>0: Xin = bufUconv[phase][l-1] (previous layer's U-state, post-σ via rho()) | |
| // Wnxt: bufWdense if last conv (top-down type=1), else bufWconv[l+1] (type=2) | |
| // Unxt: bufUout[phase] if last conv, else bufUconv[phase][l+1] | |
| this.bgConv = []; this.bgInitConv = []; | |
| for(let l=0; l<N; l++){ | |
| this.bgConv.push([]); this.bgInitConv.push([]); | |
| for(let p=0; p<3; p++){ | |
| const isLast = (l === N-1); | |
| const Xin = (l === 0) ? this.bufXin : this.bufUconv[p][l-1]; | |
| // last conv reads top-down from FIRST dense layer (dense[0]), not last | |
| const Wnxt = isLast ? this.bufWdense[0] : this.bufWconv[l+1]; | |
| const Unxt = isLast ? this.bufUout[p][0] : this.bufUconv[p][l+1]; | |
| // SI-5 skip: conv reads from LAST DENSE state via its own Wskip[l]. | |
| // Uskip = bufUout[phase][D-1] for all conv layers (the last dense state). | |
| const Uskip = this.bufUout[p][this.D-1]; | |
| this.bgConv[l].push(dev.createBindGroup({layout:this.bglConv, entries:[ | |
| {binding:0, resource:{buffer:this.bufP_conv[l][p]}}, | |
| {binding:1, resource:{buffer:Xin}}, | |
| {binding:2, resource:{buffer:this.bufWconv[l]}}, | |
| {binding:3, resource:{buffer:this.bufBconv[l]}}, | |
| {binding:4, resource:{buffer:Wnxt}}, | |
| {binding:5, resource:{buffer:this.bufUconv[p][l]}}, | |
| {binding:6, resource:{buffer:Unxt}}, | |
| {binding:7, resource:{buffer:this.bufTau[l]}}, | |
| {binding:8, resource:{buffer:this.bufWskip[l]}}, | |
| {binding:9, resource:{buffer:Uskip}}, | |
| ]})); | |
| this.bgInitConv[l].push(dev.createBindGroup({layout:this.bglConv, entries:[ | |
| {binding:0, resource:{buffer:this.bufP_init_conv[l][p]}}, | |
| {binding:1, resource:{buffer:Xin}}, | |
| {binding:2, resource:{buffer:this.bufWconv[l]}}, | |
| {binding:3, resource:{buffer:this.bufBconv[l]}}, | |
| {binding:4, resource:{buffer:Wnxt}}, | |
| {binding:5, resource:{buffer:this.bufUconv[p][l]}}, | |
| {binding:6, resource:{buffer:Unxt}}, | |
| {binding:7, resource:{buffer:this.bufTau[l]}}, | |
| {binding:8, resource:{buffer:this.bufWskip[l]}}, | |
| {binding:9, resource:{buffer:Uskip}}, | |
| ]})); | |
| } | |
| } | |
| // Dense bind groups per (phase, dense-layer). New 8-binding layout. | |
| this.bgDense = [[],[],[]]; // bgDense[phase][d] | |
| this.bgInitDense = [[],[],[]]; | |
| for(let p=0; p<3; p++){ | |
| for(let d=0; d<this.D; d++){ | |
| const isLastD = (d === this.D-1); | |
| // input: layer d=0 reads last conv hidden, d>0 reads previous dense state | |
| const Xin = (d === 0) ? this.bufUconv[p][N-1] : this.bufUout[p][d-1]; | |
| // top-down: hidden dense (not last) reads next dense's W and U; last has none (uses target nudge) | |
| const Wnxt = isLastD ? this.bufDummyR : this.bufWdense[d+1]; | |
| const Unxt = isLastD ? this.bufDummyR : this.bufUout[p][d+1]; | |
| this.bgDense[p].push(dev.createBindGroup({layout:this.bglDense, entries:[ | |
| {binding:0, resource:{buffer:this.bufP_dense[p][d]}}, | |
| {binding:1, resource:{buffer:Xin}}, | |
| {binding:2, resource:{buffer:this.bufWdense[d]}}, | |
| {binding:3, resource:{buffer:this.bufBdense[d]}}, | |
| {binding:4, resource:{buffer:Wnxt}}, | |
| {binding:5, resource:{buffer:Unxt}}, | |
| {binding:6, resource:{buffer:this.bufTgt}}, | |
| {binding:7, resource:{buffer:this.bufUout[p][d]}}, | |
| ]})); | |
| this.bgInitDense[p].push(dev.createBindGroup({layout:this.bglDense, entries:[ | |
| {binding:0, resource:{buffer:this.bufP_init_dense[p][d]}}, | |
| {binding:1, resource:{buffer:Xin}}, | |
| {binding:2, resource:{buffer:this.bufWdense[d]}}, | |
| {binding:3, resource:{buffer:this.bufBdense[d]}}, | |
| {binding:4, resource:{buffer:Wnxt}}, | |
| {binding:5, resource:{buffer:Unxt}}, | |
| {binding:6, resource:{buffer:this.bufTgt}}, | |
| {binding:7, resource:{buffer:this.bufUout[p][d]}}, | |
| ]})); | |
| } | |
| } | |
| // Grad bind groups per conv layer: | |
| // pre_p/pre_m = layer's INPUT (Xin if l=0, else previous layer's U-plus/minus phase) | |
| // post_p/post_m = THIS layer's U-plus/minus phase | |
| this.bgGC = []; | |
| for(let l=0; l<N; l++){ | |
| const preP = (l === 0) ? this.bufXin : this.bufUconv[PHASE_P][l-1]; | |
| const preM = (l === 0) ? this.bufXin : this.bufUconv[PHASE_M][l-1]; | |
| this.bgGC.push(dev.createBindGroup({layout:this.bglGC, entries:[ | |
| {binding:0, resource:{buffer:this.bufP_grad_conv[l]}}, | |
| {binding:1, resource:{buffer:preP}}, | |
| {binding:2, resource:{buffer:preM}}, | |
| {binding:3, resource:{buffer:this.bufUconv[PHASE_P][l]}}, | |
| {binding:4, resource:{buffer:this.bufUconv[PHASE_M][l]}}, | |
| {binding:5, resource:{buffer:this.bufR}}, | |
| {binding:6, resource:{buffer:this.bufGWconv[l]}}, | |
| {binding:7, resource:{buffer:this.bufGBconv[l]}}, | |
| ]})); | |
| } | |
| // Dense grad bind groups: one per dense layer. | |
| // layer 0: pre = last conv hidden (P/M phases), post = dense[0] (P/M) | |
| // layer d>0: pre = dense[d-1] (P/M), post = dense[d] (P/M) | |
| this.bgGD = []; | |
| for(let d=0; d<this.D; d++){ | |
| const preP = (d === 0) ? this.bufUconv[PHASE_P][N-1] : this.bufUout[PHASE_P][d-1]; | |
| const preM = (d === 0) ? this.bufUconv[PHASE_M][N-1] : this.bufUout[PHASE_M][d-1]; | |
| this.bgGD.push(dev.createBindGroup({layout:this.bglGD, entries:[ | |
| {binding:0, resource:{buffer:this.bufP_grad_dense[d]}}, | |
| {binding:1, resource:{buffer:preP}}, | |
| {binding:2, resource:{buffer:preM}}, | |
| {binding:3, resource:{buffer:this.bufUout[PHASE_P][d]}}, | |
| {binding:4, resource:{buffer:this.bufUout[PHASE_M][d]}}, | |
| {binding:5, resource:{buffer:this.bufR}}, | |
| {binding:6, resource:{buffer:this.bufGWdense[d]}}, | |
| {binding:7, resource:{buffer:this.bufGBdense[d]}}, | |
| ]})); | |
| } | |
| // SI-5: skip-W gradient bind groups. Pre = conv_l hidden, Post = last dense state. | |
| // We need a tiny dummy bias-grad buffer since the dense grad kernel expects gB binding. | |
| this.bufDummySkipB = this._F32(this.denseSizes[this.D-1], RW); | |
| this.bgGSkip = []; | |
| for(let l=0; l<N; l++){ | |
| this.bgGSkip.push(dev.createBindGroup({layout:this.bglGD, entries:[ | |
| {binding:0, resource:{buffer:this.bufP_grad_skip[l]}}, | |
| {binding:1, resource:{buffer:this.bufUconv[PHASE_P][l]}}, // Xp = conv_l hidden plus | |
| {binding:2, resource:{buffer:this.bufUconv[PHASE_M][l]}}, // Xm = conv_l hidden minus | |
| {binding:3, resource:{buffer:this.bufUout[PHASE_P][this.D-1]}}, // Up = last dense plus | |
| {binding:4, resource:{buffer:this.bufUout[PHASE_M][this.D-1]}}, // Um = last dense minus | |
| {binding:5, resource:{buffer:this.bufR}}, | |
| {binding:6, resource:{buffer:this.bufGWskip[l]}}, | |
| {binding:7, resource:{buffer:this.bufDummySkipB}}, // unused — skip has no bias for now | |
| ]})); | |
| } | |
| // Aux (reward only) — reads LAST dense state for loss | |
| this.bgRew = dev.createBindGroup({layout:this.bglAux, entries:[ | |
| {binding:0, resource:{buffer:this.bufP_rew}}, | |
| {binding:1, resource:{buffer:this.bufUout[PHASE_F][this.D-1]}}, | |
| {binding:2, resource:{buffer:this.bufTgt}}, | |
| {binding:3, resource:{buffer:this.bufR}}, | |
| ]}); | |
| } | |
| // Compute the effective γ for layer l given its top-down type. | |
| // Three modes (controlled by this.gammaSchedule): | |
| // 'flat' (default): γ_eff = gamma_base everywhere | |
| // 'fanin' : γ_eff = gamma_base * sqrt(gammaTdRef / N_top_l) | |
| // (matches v03 regime; falsified for 2-conv conv1 in v06 sweep) | |
| // 'invDepth' (SI-1) : γ_eff = gamma_base * sqrt(L / (l+1)) | |
| // Deeper convs get LARGER γ to compensate β-nudge attenuation. | |
| // For 3-conv L=3: γ_0=γ_base·√3=1.73·g_base, γ_1=γ_base·√1.5=1.22·g, γ_2=γ_base | |
| _gammaFor(l, gamma_base, has_topdown_type){ | |
| if(has_topdown_type === 0) return gamma_base; | |
| const sched = this.gammaSchedule || (this.gammaTdNorm ? 'fanin' : 'flat'); | |
| if(sched === 'flat') return gamma_base; | |
| if(sched === 'fanin'){ | |
| let N_top; | |
| if(has_topdown_type === 1){ | |
| N_top = this.denseSizes[0]; | |
| } else { | |
| const nc = this.cfgs[l+1]; | |
| N_top = nc.Cout * nc.KH * nc.KW; | |
| } | |
| return gamma_base * Math.sqrt(this.gammaTdRef / Math.max(1, N_top)); | |
| } | |
| if(sched === 'invDepth'){ | |
| // l ∈ [0, N-1]. Deeper layers (small l) get larger γ. Last conv (l=N-1) gets γ_base. | |
| return gamma_base * Math.sqrt(this.N / Math.max(1, l + 1)); | |
| } | |
| return gamma_base; | |
| } | |
| setGammaSchedule(sched){ this.gammaSchedule = sched; } // 'flat' | 'fanin' | 'invDepth' | |
| _writeConvParams(buf, l, {dt, gamma, has_topdown_type, iter_index=0}){ | |
| const c = this.cfgs[l]; | |
| const u32 = new Uint32Array(36); const f32 = new Float32Array(u32.buffer); | |
| u32[0]=this.B; u32[1]=c.Cin; u32[2]=c.Cout; u32[3]=c.H; | |
| u32[4]=c.W; u32[5]=c.Hout; u32[6]=c.Wout; u32[7]=c.KH; | |
| u32[8]=c.KW; u32[9]=c.stride; u32[10]=c.pad; u32[11]=0; | |
| // v06: optional per-layer γ scaling for top-down fan-in normalization | |
| const gamma_eff = this._gammaFor(l, gamma, has_topdown_type); | |
| f32[12]=dt; f32[13]=0; f32[14]=gamma_eff; f32[15]=this.modeFlag; | |
| let nxt_O=0, nxt_KH=0, nxt_KW=0, nxt_stride=0, nxt_pad=0, nxt_Cnxt=0; | |
| if(has_topdown_type === 1){ | |
| nxt_O = this.O; | |
| } else if(has_topdown_type === 2){ | |
| const nc = this.cfgs[l+1]; | |
| nxt_O = 0; nxt_KH = nc.KH; nxt_KW = nc.KW; nxt_stride = nc.stride; nxt_pad = nc.pad; | |
| nxt_Cnxt = nc.Cout; | |
| } | |
| u32[16]=has_topdown_type; u32[17]=nxt_O; u32[18]=nxt_KH; u32[19]=nxt_KW; | |
| u32[20]=nxt_stride; u32[21]=nxt_pad; u32[22]=nxt_Cnxt; u32[23]=0; | |
| f32[24]=this.driveClampLo; f32[25]=this.driveClampHi; | |
| f32[26]=this.triangleOffset; f32[27]=this.trianglePower; | |
| // MSMEN-MVT spike-sampling state | |
| u32[28] = this.spikeMode || 0; | |
| u32[29] = iter_index >>> 0; | |
| u32[30] = this.iterSeedBase >>> 0; | |
| f32[31] = 0; | |
| // SI-5 skip connection state (active iff this layer has skip enabled AND not the last layer) | |
| const has_skip = (this.skipEnabled && this.skipEnabled[l]) ? 1 : 0; | |
| u32[32] = has_skip; | |
| u32[33] = has_skip ? this.denseSizes[this.D-1] : 0; | |
| f32[34] = (this.skipGamma !== undefined) ? this.skipGamma : 0.1; | |
| f32[35] = 0; | |
| this.dev.queue.writeBuffer(buf, 0, u32.buffer); | |
| } | |
| // SI-5: enable skip connection for layer l (typically the FIRST conv, l=0). | |
| // Wskip[l] of shape [denseSizes[D-1] × this_layer's convFlat] is auto-allocated; | |
| // initialize via uploadSkipWeights(l, Float32Array). | |
| setSkipEnabled(l, enabled){ this.skipEnabled[l] = !!enabled; } | |
| setSkipGamma(g){ this.skipGamma = g; } | |
| uploadSkipWeights(l, Wskip){ | |
| this.dev.queue.writeBuffer(this.bufWskip[l], 0, Wskip.buffer, Wskip.byteOffset, Wskip.byteLength); | |
| } | |
| // MSMEN-MVT: enable spike-sampling mode and set the base seed for the next forward. | |
| // setSpikeMode(true, seedBase) makes the conv hidden update stochastic; deterministic if false. | |
| setSpikeMode(enabled, seedBase=0){ | |
| this.spikeMode = enabled ? 1 : 0; | |
| this.iterSeedBase = seedBase >>> 0; | |
| } | |
| // Public setter for autoresearch sweeps | |
| setGammaTdNorm(enabled, ref=10){ this.gammaTdNorm = !!enabled; this.gammaTdRef = ref; } | |
| // New 16-slot layout: includes Nnxt + gamma + has_topdown for dense layers (Tier E) | |
| _writeDenseParams(buf, {Ni, No, Nnxt, dt, beta, gamma, has_target, has_topdown}){ | |
| const u32 = new Uint32Array(16); const f32 = new Float32Array(u32.buffer); | |
| u32[0]=this.B; u32[1]=Ni; u32[2]=No; u32[3]=Nnxt; | |
| f32[4]=dt; f32[5]=beta; f32[6]=gamma; f32[7]=0; | |
| u32[8]=has_target; u32[9]=has_topdown; u32[10]=0; u32[11]=0; | |
| f32[12]=0; f32[13]=0; f32[14]=0; f32[15]=0; | |
| this.dev.queue.writeBuffer(buf, 0, u32.buffer); | |
| } | |
| _writeGradConvParams(l, two_beta){ | |
| const c = this.cfgs[l]; | |
| const u32 = new Uint32Array(16); const f32 = new Float32Array(u32.buffer); | |
| u32[0]=this.B; u32[1]=c.Cin; u32[2]=c.Cout; u32[3]=c.H; | |
| u32[4]=c.W; u32[5]=c.Hout; u32[6]=c.Wout; u32[7]=c.KH; | |
| u32[8]=c.KW; u32[9]=c.stride; u32[10]=c.pad; u32[11]=0; | |
| f32[12]=two_beta; f32[13]=0; f32[14]=0; f32[15]=0; | |
| this.dev.queue.writeBuffer(this.bufP_grad_conv[l], 0, u32.buffer); | |
| } | |
| _writeGradDenseParams(d, two_beta){ | |
| const Ni = (d===0) ? this.lastFlat : this.denseSizes[d-1]; | |
| const No = this.denseSizes[d]; | |
| const u32 = new Uint32Array(8); const f32 = new Float32Array(u32.buffer); | |
| u32[0]=this.B; u32[1]=Ni; u32[2]=No; u32[3]=0; | |
| f32[4]=two_beta; f32[5]=0; f32[6]=0; f32[7]=0; | |
| this.dev.queue.writeBuffer(this.bufP_grad_dense[d], 0, u32.buffer); | |
| } | |
| _writeAuxParams(){ | |
| const u32 = new Uint32Array(8); const f32 = new Float32Array(u32.buffer); | |
| u32[0]=this.B; u32[1]=this.O; u32[2]=0; u32[3]=0; | |
| f32[4]=0; f32[5]=0; f32[6]=0; f32[7]=0; | |
| this.dev.queue.writeBuffer(this.bufP_rew, 0, u32.buffer); | |
| } | |
| // Tier F/G — set the activation mode. mode ∈ {'adaptive','clip','prism','triangle'}. | |
| // For 'triangle' also pass {offset, power}. All algorithmic; no hardcoded magic constants. | |
| setMode(mode, opts={}){ | |
| const map = {'adaptive':0, 'clip':1, 'prism':2, 'triangle':3}; | |
| if(!(mode in map)) throw new Error(`setMode: unknown mode "${mode}". Use one of ${Object.keys(map)}`); | |
| this.modeFlag = map[mode]; | |
| if(mode === 'triangle'){ | |
| this.triangleOffset = (opts.offset !== undefined) ? opts.offset : 0; | |
| this.trianglePower = (opts.power !== undefined) ? opts.power : 1; | |
| } | |
| } | |
| setDriveClamp(lo, hi){ this.driveClampLo = lo; this.driveClampHi = hi; } | |
| // Tier H — per-layer τ. scalarTau=null + tauMin>0 + tauMax>tauMin → Uniform sample. | |
| setTau(layerIdx, scalarTau, tauMin=0, tauMax=0, seed=42){ | |
| const c = this.cfgs[layerIdx]; | |
| const arr = new Float32Array(c.Cout); | |
| if(tauMax > tauMin && tauMin > 0){ | |
| let s = (seed>>>0) || 1; | |
| const rng = ()=>{ s = (Math.imul(s, 1664525) + 1013904223) >>> 0; return s/4294967296; }; | |
| for(let i=0; i<c.Cout; i++) arr[i] = tauMin + rng() * (tauMax - tauMin); | |
| this.tauSpec[layerIdx] = {mode:'hpsn', tauMin, tauMax, seed}; | |
| } else { | |
| arr.fill(scalarTau); | |
| this.tauSpec[layerIdx] = {mode:'scalar', scalar:scalarTau}; | |
| } | |
| this.dev.queue.writeBuffer(this.bufTau[layerIdx], 0, arr.buffer, arr.byteOffset, arr.byteLength); | |
| return arr; | |
| } | |
| setAllTau(scalarTau, tauMin=0, tauMax=0, seed=42){ | |
| for(let l=0; l<this.N; l++) this.setTau(l, scalarTau, tauMin, tauMax, seed + l*1000); | |
| } | |
| // Wconv: array of Float32Array (one per conv layer). | |
| // Wdense/bdense: accepts EITHER a single Float32Array (D=1 backward-compat) or array of length D (Tier E). | |
| uploadWeights(Wconv, bconv, Wdense, bdense){ | |
| if(!Array.isArray(Wconv) || Wconv.length !== this.N) throw new Error(`uploadWeights: Wconv must be array of length ${this.N}`); | |
| const q = this.dev.queue; | |
| for(let l=0; l<this.N; l++){ | |
| q.writeBuffer(this.bufWconv[l], 0, Wconv[l].buffer, Wconv[l].byteOffset, Wconv[l].byteLength); | |
| q.writeBuffer(this.bufBconv[l], 0, bconv[l].buffer, bconv[l].byteOffset, bconv[l].byteLength); | |
| } | |
| // Normalize dense args: scalar → [scalar] for D=1 (backward-compat). | |
| const WdArr = Array.isArray(Wdense) ? Wdense : [Wdense]; | |
| const BdArr = Array.isArray(bdense) ? bdense : [bdense]; | |
| if(WdArr.length !== this.D) throw new Error(`uploadWeights: Wdense array length ${WdArr.length} != D=${this.D}`); | |
| for(let d=0; d<this.D; d++){ | |
| q.writeBuffer(this.bufWdense[d], 0, WdArr[d].buffer, WdArr[d].byteOffset, WdArr[d].byteLength); | |
| q.writeBuffer(this.bufBdense[d], 0, BdArr[d].buffer, BdArr[d].byteOffset, BdArr[d].byteLength); | |
| } | |
| } | |
| uploadInputs(X, T){ | |
| const q = this.dev.queue; | |
| q.writeBuffer(this.bufXin, 0, X.buffer, X.byteOffset, X.byteLength); | |
| q.writeBuffer(this.bufTgt, 0, T.buffer, T.byteOffset, T.byteLength); | |
| } | |
| _writeAllPhaseUniforms(dt, beta, gamma){ | |
| // Save for per-iter rewrites in spike-sampling mode. | |
| this._lastDt = dt; this._lastGamma = gamma; | |
| // Tier H — when layer is in default scalar-τ mode, ensure Tau[k] == dt (so τ tracks dt every call). | |
| // HPSN mode (user set explicit min/max) is preserved (Tau buffer untouched). | |
| for(let l=0; l<this.N; l++){ | |
| const spec = this.tauSpec[l]; | |
| if(!spec || spec.mode === 'scalar'){ | |
| if(this._lastTauDt !== dt || !spec){ | |
| const arr = new Float32Array(this.cfgs[l].Cout); arr.fill(dt); | |
| this.dev.queue.writeBuffer(this.bufTau[l], 0, arr.buffer, arr.byteOffset, arr.byteLength); | |
| this.tauSpec[l] = {mode:'scalar', scalar:dt}; | |
| } | |
| } | |
| } | |
| this._lastTauDt = dt; | |
| for(let l=0; l<this.N; l++){ | |
| // top-down type: 2 (conv-next) if not last, 1 (dense-next) if last | |
| const tdType = (l < this.N-1) ? 2 : 1; | |
| for(let p=0; p<3; p++){ | |
| this._writeConvParams(this.bufP_conv[l][p], l, {dt, gamma, has_topdown_type: tdType}); | |
| // init: no top-down (we re-initialize state to 0.1 before relax) | |
| this._writeConvParams(this.bufP_init_conv[l][p], l, {dt, gamma, has_topdown_type: 0}); | |
| } | |
| } | |
| // For each (phase, dense layer): | |
| // has_target: 1 only for the LAST dense layer | |
| // has_topdown: 1 if NOT the last (reads from next dense) | |
| // beta: ±beta for last in plus/minus, else 0 | |
| const phaseBetas = [0, +beta, -beta]; | |
| for(let p=0; p<3; p++){ | |
| for(let d=0; d<this.D; d++){ | |
| const isLastD = (d === this.D-1); | |
| const Ni = (d===0) ? this.lastFlat : this.denseSizes[d-1]; | |
| const No = this.denseSizes[d]; | |
| const Nnxt = isLastD ? 0 : this.denseSizes[d+1]; | |
| const phaseBeta = isLastD ? phaseBetas[p] : 0; | |
| this._writeDenseParams(this.bufP_dense[p][d], { | |
| Ni, No, Nnxt, dt, beta: phaseBeta, gamma, | |
| has_target: isLastD ? 1 : 0, has_topdown: isLastD ? 0 : 1, | |
| }); | |
| this._writeDenseParams(this.bufP_init_dense[p][d], { | |
| Ni, No, Nnxt, dt, beta: 0, gamma: 0, | |
| has_target: 0, has_topdown: 0, | |
| }); | |
| } | |
| } | |
| } | |
| _initAllPhases(enc){ | |
| const MAX_WG_X = 65535; | |
| for(let p=0; p<3; p++){ | |
| for(let l=0; l<this.N; l++){ | |
| const c = this.cfgs[l]; | |
| const n = this.B * c.convFlat; | |
| const wg = Math.ceil(n/64); | |
| const pass = enc.beginComputePass(); | |
| pass.setPipeline(this.pipeInitConv); pass.setBindGroup(0, this.bgInitConv[l][p]); | |
| pass.dispatchWorkgroups(Math.min(wg, MAX_WG_X), Math.ceil(wg/MAX_WG_X)); | |
| pass.end(); | |
| } | |
| for(let d=0; d<this.D; d++){ | |
| const no = this.B * this.denseSizes[d]; | |
| const pass = enc.beginComputePass(); | |
| pass.setPipeline(this.pipeInitDense); pass.setBindGroup(0, this.bgInitDense[p][d]); | |
| pass.dispatchWorkgroups(Math.ceil(no/64)); | |
| pass.end(); | |
| } | |
| } | |
| } | |
| _runPhaseRelax(enc, phase, iters){ | |
| for(let t=0; t<iters; t++){ | |
| // Rewrite conv uniforms each iter to bump iter_index (used by MSMEN-MVT PCG seed). | |
| // No-op cost when spike_mode == 0 — but the write still happens; small. | |
| if(this.spikeMode){ | |
| for(let l=0; l<this.N; l++){ | |
| const tdType = (l < this.N-1) ? 2 : 1; | |
| this._writeConvParams(this.bufP_conv[l][phase], l, {dt: this._lastDt, gamma: this._lastGamma, has_topdown_type: tdType, iter_index: t}); | |
| } | |
| } | |
| for(let l=0; l<this.N; l++){ | |
| const c = this.cfgs[l]; | |
| const pass = enc.beginComputePass(); | |
| pass.setPipeline(this.pipeConv); pass.setBindGroup(0, this.bgConv[l][phase]); | |
| pass.dispatchWorkgroups(Math.ceil(c.Wout/8), Math.ceil(c.Hout/8), this.B * c.Cout); | |
| pass.end(); | |
| } | |
| for(let d=0; d<this.D; d++){ | |
| const No = this.denseSizes[d]; | |
| const pass = enc.beginComputePass(); | |
| pass.setPipeline(this.pipeDense); pass.setBindGroup(0, this.bgDense[phase][d]); | |
| pass.dispatchWorkgroups(Math.ceil(No/64), this.B); | |
| pass.end(); | |
| } | |
| } | |
| } | |
| _runReward(enc){ | |
| this._writeAuxParams(); | |
| const pass = enc.beginComputePass(); | |
| pass.setPipeline(this.pipeReward); pass.setBindGroup(0, this.bgRew); | |
| pass.dispatchWorkgroups(Math.ceil(this.B/64)); | |
| pass.end(); | |
| } | |
| _runGrad(enc, beta){ | |
| for(let l=0; l<this.N; l++){ | |
| const c = this.cfgs[l]; | |
| this._writeGradConvParams(l, 2*beta); | |
| const pass = enc.beginComputePass(); | |
| pass.setPipeline(this.pipeGWconv); pass.setBindGroup(0, this.bgGC[l]); | |
| pass.dispatchWorkgroups(Math.ceil(c.KW/8), Math.ceil(c.KH/8), c.Cout * c.Cin); | |
| pass.setPipeline(this.pipeGBconv); pass.setBindGroup(0, this.bgGC[l]); | |
| pass.dispatchWorkgroups(Math.ceil(c.Cout/64)); | |
| pass.end(); | |
| } | |
| // SI-5: skip-W gradient per conv layer (only meaningful when skipEnabled[l]). | |
| // Always run (small cost vs whole forward); the optimizer side gates by skipEnabled. | |
| for(let l=0; l<this.N; l++){ | |
| const c = this.cfgs[l]; | |
| const Ni = c.convFlat; // conv hidden flat | |
| const No = this.denseSizes[this.D-1]; | |
| // Write grad-skip uniform (same 8-slot dense-grad layout): B, Ni, No, _, two_beta | |
| const u32 = new Uint32Array(8); const f32 = new Float32Array(u32.buffer); | |
| u32[0]=this.B; u32[1]=Ni; u32[2]=No; u32[3]=0; | |
| f32[4]=2*beta; f32[5]=0; f32[6]=0; f32[7]=0; | |
| this.dev.queue.writeBuffer(this.bufP_grad_skip[l], 0, u32.buffer); | |
| const pass = enc.beginComputePass(); | |
| pass.setPipeline(this.pipeGWdense); pass.setBindGroup(0, this.bgGSkip[l]); | |
| pass.dispatchWorkgroups(Math.ceil(Ni/8), Math.ceil(No/8)); | |
| pass.end(); | |
| } | |
| // Per-dense-layer gradient | |
| for(let d=0; d<this.D; d++){ | |
| this._writeGradDenseParams(d, 2*beta); | |
| const Ni = (d===0) ? this.lastFlat : this.denseSizes[d-1]; | |
| const No = this.denseSizes[d]; | |
| const pass = enc.beginComputePass(); | |
| pass.setPipeline(this.pipeGWdense); pass.setBindGroup(0, this.bgGD[d]); | |
| pass.dispatchWorkgroups(Math.ceil(Ni/8), Math.ceil(No/8)); | |
| pass.setPipeline(this.pipeGBdense); pass.setBindGroup(0, this.bgGD[d]); | |
| pass.dispatchWorkgroups(Math.ceil(No/64)); | |
| pass.end(); | |
| } | |
| } | |
| async runFreeAndReadOutputs(iters, dt, gamma=0.6){ | |
| this._writeAllPhaseUniforms(dt, 0, gamma); | |
| const enc = this.dev.createCommandEncoder(); | |
| this._initAllPhases(enc); | |
| this._runPhaseRelax(enc, PHASE_F, iters); | |
| enc.copyBufferToBuffer(this.bufUout[PHASE_F][this.D-1], 0, this.rbUoF, 0, this.B*this.O*4); | |
| this.dev.queue.submit([enc.finish()]); | |
| await this.rbUoF.mapAsync(GPUMapMode.READ); | |
| const r = new Float32Array(this.rbUoF.getMappedRange().slice(0)); | |
| this.rbUoF.unmap(); | |
| return r; | |
| } | |
| async runOnePassGetGradients({itF=8, itN=5, dt=0.7, beta=0.5, gamma=0.6}={}){ | |
| this._writeAllPhaseUniforms(dt, beta, gamma); | |
| const enc = this.dev.createCommandEncoder(); | |
| this._initAllPhases(enc); | |
| this._runPhaseRelax(enc, PHASE_F, itF); | |
| this._runPhaseRelax(enc, PHASE_P, itN); | |
| this._runPhaseRelax(enc, PHASE_M, itN); | |
| this._runReward(enc); | |
| this._runGrad(enc, beta); | |
| // Readback gradients per conv layer + dense | |
| const reads = []; | |
| for(let l=0; l<this.N; l++){ | |
| const c = this.cfgs[l]; | |
| enc.copyBufferToBuffer(this.bufGWconv[l], 0, this.rbGWconv[l], 0, c.Cout*c.Cin*c.KH*c.KW*4); | |
| enc.copyBufferToBuffer(this.bufGBconv[l], 0, this.rbGBconv[l], 0, c.Cout*4); | |
| } | |
| for(let d=0; d<this.D; d++){ | |
| const Ni = (d===0) ? this.lastFlat : this.denseSizes[d-1]; | |
| const No = this.denseSizes[d]; | |
| enc.copyBufferToBuffer(this.bufGWdense[d], 0, this.rbGWdense[d], 0, No*Ni*4); | |
| enc.copyBufferToBuffer(this.bufGBdense[d], 0, this.rbGBdense[d], 0, No*4); | |
| } | |
| // SI-5: readback skip W gradients per conv layer | |
| const O_last = this.denseSizes[this.D-1]; | |
| for(let l=0; l<this.N; l++){ | |
| const sz = O_last * this.cfgs[l].convFlat * 4; | |
| enc.copyBufferToBuffer(this.bufGWskip[l], 0, this.rbGWskip[l], 0, sz); | |
| } | |
| enc.copyBufferToBuffer(this.bufUout[PHASE_F][this.D-1], 0, this.rbUoF, 0, this.B*this.O*4); | |
| this.dev.queue.submit([enc.finish()]); | |
| const maps = [this.rbUoF.mapAsync(GPUMapMode.READ)]; | |
| for(let l=0; l<this.N; l++) maps.push(this.rbGWskip[l].mapAsync(GPUMapMode.READ)); | |
| for(let l=0; l<this.N; l++){ | |
| maps.push(this.rbGWconv[l].mapAsync(GPUMapMode.READ)); | |
| maps.push(this.rbGBconv[l].mapAsync(GPUMapMode.READ)); | |
| } | |
| for(let d=0; d<this.D; d++){ | |
| maps.push(this.rbGWdense[d].mapAsync(GPUMapMode.READ)); | |
| maps.push(this.rbGBdense[d].mapAsync(GPUMapMode.READ)); | |
| } | |
| await Promise.all(maps); | |
| const gWconv = [], gBconv = []; | |
| for(let l=0; l<this.N; l++){ | |
| gWconv.push(new Float32Array(this.rbGWconv[l].getMappedRange().slice(0))); | |
| gBconv.push(new Float32Array(this.rbGBconv[l].getMappedRange().slice(0))); | |
| this.rbGWconv[l].unmap(); this.rbGBconv[l].unmap(); | |
| } | |
| const gWdenseArr = [], gBdenseArr = []; | |
| for(let d=0; d<this.D; d++){ | |
| gWdenseArr.push(new Float32Array(this.rbGWdense[d].getMappedRange().slice(0))); | |
| gBdenseArr.push(new Float32Array(this.rbGBdense[d].getMappedRange().slice(0))); | |
| this.rbGWdense[d].unmap(); this.rbGBdense[d].unmap(); | |
| } | |
| const uoF = new Float32Array(this.rbUoF.getMappedRange().slice(0)); | |
| this.rbUoF.unmap(); | |
| // SI-5: skip W gradients per conv layer | |
| const gWskip = []; | |
| for(let l=0; l<this.N; l++){ | |
| gWskip.push(new Float32Array(this.rbGWskip[l].getMappedRange().slice(0))); | |
| this.rbGWskip[l].unmap(); | |
| } | |
| // Backward-compat: when D=1, expose gWdense/gBdense as scalars; always also expose arrays. | |
| const gWdense = (this.D === 1) ? gWdenseArr[0] : gWdenseArr; | |
| const gBdense = (this.D === 1) ? gBdenseArr[0] : gBdenseArr; | |
| return {gWconv, gBconv, gWdense, gBdense, gWdenseArr, gBdenseArr, gWskip, uoF}; | |
| } | |
| destroy(){ | |
| const bufs = [this.bufXin, this.bufTgt, this.rbUoF, this.bufR, this.bufDummyR, this.bufP_rew]; | |
| for(const a of [this.bufWconv, this.bufBconv, this.bufGWconv, this.bufGBconv, this.rbGWconv, this.rbGBconv, | |
| this.bufP_grad_conv, this.bufTau, | |
| this.bufWdense, this.bufBdense, this.bufGWdense, this.bufGBdense, | |
| this.rbGWdense, this.rbGBdense, this.bufP_grad_dense]) bufs.push(...a); | |
| for(const ph of this.bufUconv) bufs.push(...ph); | |
| for(const ph of this.bufUout) bufs.push(...ph); | |
| for(const l of this.bufP_conv) bufs.push(...l); | |
| for(const l of this.bufP_init_conv) bufs.push(...l); | |
| for(const ph of this.bufP_dense) bufs.push(...ph); | |
| for(const ph of this.bufP_init_dense) bufs.push(...ph); | |
| for(const v of bufs) if(v && v.destroy) try{ v.destroy(); }catch(e){} | |
| } | |
| } | |