// conv1d_step.wgsl — Mamba autoregressive depthwise conv1d, single-step.
// Ported from gfx1151_inference/shaders/conv1d_step.comp (Vulkan GLSL → WebGPU WGSL).
//
// Falcon-Mamba 7B: H=INTERMEDIATE=8192, K=CONV_KERNEL=4.
//
// Per channel h, maintains a conv state cache of K-1=3 past hidden values.
// state layout (row-major): [H, K-1], oldest at index 0, newest at index K-2.
//
// Step:
//   window           = [state[h,0], state[h,1], state[h,2], x[h]]   (length K)
//   y[h]             = sum_k(window[k] * W[h, k]) + bias[h]
//   state[h, 0..K-3] = state[h, 1..K-2]                             (shift left)
//   state[h, K-2]    = x[h]                                         (append current)
//
// One thread per h. WG_SIZE = 64.

const WG_SIZE: u32 = 64u;

// Convolution kernel width. The entry point below is hand-unrolled for K == 4
// (three cached values plus the current input); changing K requires updating
// the unrolled loads, the multiply-add chain and the state shift as well.
const K: u32 = 4u;

struct Params {
    // Number of channels; invocations with index >= H do nothing.
    H: u32,
}

// NOTE(review): the element type of the buffers below is assumed to be f32 and
// `params` is assumed to be a uniform buffer — confirm against the host-side
// buffer and bind-group-layout creation. Access modes are derived from how
// main() uses each binding (written => read_write, otherwise read).
@group(0) @binding(0) var<storage, read_write> state_buf: array<f32>; // [H, K-1]
@group(0) @binding(1) var<storage, read>       x_buf:     array<f32>; // [H]
@group(0) @binding(2) var<storage, read>       w_buf:     array<f32>; // [H, K]
@group(0) @binding(3) var<storage, read>       b_buf:     array<f32>; // [H]
@group(0) @binding(4) var<storage, read_write> out_buf:   array<f32>; // [H]

@group(1) @binding(0) var<uniform> params: Params;

// Computes one output value per channel h and advances that channel's state.
//
// For h = gid.x (when h < params.H):
//   out_buf[h]   = s0*w0 + s1*w1 + s2*w2 + x_buf[h]*w3 + b_buf[h]
//   state_buf[h] = [s1, s2, x_buf[h]]
// where s0..s2 are the three cached values for h (oldest first) and w0..w3 are
// the K weights for h.
//
// Each invocation reads and writes only its own K-1 state slots, so the
// in-place shift needs no synchronization between invocations.
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let h = gid.x;

    // The dispatch may be rounded up to a multiple of WG_SIZE; skip the excess
    // invocations so they never index past the end of the buffers.
    if (h >= params.H) {
        return;
    }

    let sb = h * (K - 1u); // state base for this h
    let wb = h * K;        // weight base for this h

    // Load everything before writing: the shift below overwrites state in place.
    let s0 = state_buf[sb + 0u]; // oldest of the K-1 past values
    let s1 = state_buf[sb + 1u];
    let s2 = state_buf[sb + 2u]; // most recent of the K-1 past values
    let xh = x_buf[h];

    let w0 = w_buf[wb + 0u];
    let w1 = w_buf[wb + 1u];
    let w2 = w_buf[wb + 2u];
    let w3 = w_buf[wb + 3u];
    let bh = b_buf[h];

    // Dot product of the length-K window with this channel's weights, plus bias.
    out_buf[h] = s0 * w0 + s1 * w1 + s2 * w2 + xh * w3 + bh;

    // Shift left and append current x
    state_buf[sb + 0u] = s1;
    state_buf[sb + 1u] = s2;
    state_buf[sb + 2u] = xh;
}