| // conv1d_step.wgsl — Mamba autoregressive depthwise conv1d, single-step. | |
| // Ported from gfx1151_inference/shaders/conv1d_step.comp (Vulkan GLSL → WebGPU WGSL). | |
| // | |
| // Falcon-Mamba 7B: H=INTERMEDIATE=8192, K=CONV_KERNEL=4. | |
| // | |
| // Per channel h, maintains a conv state cache of K-1=3 past hidden values. | |
| // state layout (row-major): [H, K-1], oldest at index 0, newest at index K-2. | |
| // | |
| // Step: | |
| // window = [state[h,0], state[h,1], state[h,2], x[h]] (length K) | |
| // y[h] = sum_k(window[k] * W[h, k]) + bias[h] | |
| // state[h, 0..K-3] = state[h, 1..K-2] (shift left) | |
| // state[h, K-2] = x[h] (append current) | |
| // | |
| // One thread per h. WG_SIZE = 64. | |
| // Workgroup width named in the header ("One thread per h. WG_SIZE = 64."). | |
| // NOTE(review): not referenced anywhere in the visible code; presumably meant for the entry | |
| // point's @workgroup_size attribute — confirm. | |
| const WG_SIZE: u32 = 64u; | |
| // Convolution kernel width. main() uses K-1 as the per-channel stride into state_buf and K as | |
| // the per-channel stride into w_buf. main() also hard-codes exactly four taps, so changing K | |
| // requires updating main() as well. | |
| const K: u32 = 4u; | |
| // Parameters supplied through the `params` uniform. | |
| struct Params { | |
| // Channel count; main() returns early for any invocation whose index is >= H. | |
| H: u32, | |
| } | |
| // Resources accessed by main(). Bracketed shapes are the row-major logical layouts. | |
| // NOTE(review): no @group/@binding attributes are visible on these declarations, and WGSL | |
| // requires them on resource variables. They may have been lost in extraction — confirm the | |
| // indices against the host-side bind group layout before relying on this file. | |
| // Conv state cache: read at the start of main() and rewritten (shifted by one) at the end. | |
| var<storage, read_write> state_buf: array<f32>; // [H, K-1] | |
| // Current-step input, one value per channel. | |
| var<storage, read> x_buf: array<f32>; // [H] | |
| // Per-channel kernel weights; tap k of channel h is at index h*K + k. | |
| var<storage, read> w_buf: array<f32>; // [H, K] | |
| // Per-channel bias added to the weighted sum. | |
| var<storage, read> b_buf: array<f32>; // [H] | |
| // Output of the step; main() writes exactly one element per in-range channel. | |
| var<storage, read_write> out_buf: array<f32>; // [H] | |
| // Holds H, the channel count used for the bounds guard in main(). | |
| var<uniform> params: Params; | |
| // Compute entry point: advances the depthwise conv1d by one step for channel h = gid.x. | |
| // | |
| // Each in-range invocation: | |
| //   1. loads its K-1 cached values from state_buf and the current input x_buf[h]; | |
| //   2. writes sum_k(window[k] * w_buf[h*K + k]) + b_buf[h] to out_buf[h], where | |
| //      window = [cached oldest .. cached newest, x_buf[h]]; | |
| //   3. shifts its cache row left by one and stores x_buf[h] in the newest slot. | |
| // Invocations with h >= params.H return without touching any buffer. | |
| // | |
| // All state reads happen before the state writes, and an invocation only touches the | |
| // state_buf row starting at h*(K-1), so no barrier is used here. | |
| // The four taps are unrolled by hand; this body must be updated if K changes. | |
| @compute @workgroup_size(WG_SIZE) | |
| fn main(@builtin(global_invocation_id) gid: vec3<u32>) { | |
|     let h = gid.x; | |
|     // Guard the tail of the dispatch grid beyond the last channel. | |
|     if (h >= params.H) { | |
|         return; | |
|     } | |
|  | |
|     let sb = h * (K - 1u); // state base for this h | |
|     let wb = h * K;        // weight base for this h | |
|  | |
|     // Cached history, oldest first. | |
|     let s0 = state_buf[sb + 0u]; // oldest of the K-1 past values | |
|     let s1 = state_buf[sb + 1u]; | |
|     let s2 = state_buf[sb + 2u]; // most recent of the K-1 past values | |
|     let xh = x_buf[h]; | |
|  | |
|     let w0 = w_buf[wb + 0u]; | |
|     let w1 = w_buf[wb + 1u]; | |
|     let w2 = w_buf[wb + 2u]; | |
|     let w3 = w_buf[wb + 3u]; | |
|     let bh = b_buf[h]; | |
|  | |
|     // Oldest cached value pairs with tap 0; the current input pairs with the last tap. | |
|     out_buf[h] = s0 * w0 + s1 * w1 + s2 * w2 + xh * w3 + bh; | |
|  | |
|     // Shift left and append current x | |
|     state_buf[sb + 0u] = s1; | |
|     state_buf[sb + 1u] = s2; | |
|     state_buf[sb + 2u] = xh; | |
| } | |