// Source: mamba-webgpu/shaders/conv1d_step.wgsl
// Uploaded by LJTSG: "Upload shaders/conv1d_step.wgsl with huggingface_hub" (commit b3e0b44, verified)
// conv1d_step.wgsl — Mamba autoregressive depthwise conv1d, single-step.
// Ported from gfx1151_inference/shaders/conv1d_step.comp (Vulkan GLSL → WebGPU WGSL).
//
// Falcon-Mamba 7B: H=INTERMEDIATE=8192, K=CONV_KERNEL=4.
//
// Per channel h, maintains a conv state cache of K-1=3 past hidden values.
// state layout (row-major): [H, K-1], oldest at index 0, newest at index K-2.
//
// Step:
// window = [state[h,0], ..., state[h,K-2], x[h]] (length K; for K=4: [state[h,0], state[h,1], state[h,2], x[h]])
// y[h] = sum_k(window[k] * W[h, k]) + bias[h]
// state[h, 0..K-3] = state[h, 1..K-2] (shift left)
// state[h, K-2] = x[h] (append current)
//
// One thread per h. WG_SIZE = 64.
// Invocations per workgroup; used as the @workgroup_size of `main`.
const WG_SIZE = 64u;
// Depthwise conv kernel width: taps per channel. Each channel caches K-1 past inputs.
const K = 4u;
// Uniform parameters for one dispatch of `main`.
struct Params {
    // Number of channels; invocations with gid.x >= H exit without touching any buffer.
    H: u32,
}
// Group 0: per-step tensors, all flat f32 arrays indexed row-major by channel.
@group(0) @binding(0)
var<storage, read_write> state_buf: array<f32>; // [H, K-1] past inputs, rewritten every step
@group(0) @binding(1)
var<storage, read> x_buf: array<f32>;           // [H] current input
@group(0) @binding(2)
var<storage, read> w_buf: array<f32>;           // [H, K] per-channel taps
@group(0) @binding(3)
var<storage, read> b_buf: array<f32>;           // [H] per-channel bias
@group(0) @binding(4)
var<storage, read_write> out_buf: array<f32>;   // [H] conv output

// Group 1: dispatch parameters.
@group(1) @binding(0)
var<uniform> params: Params;
// Advance the depthwise causal conv1d by one autoregressive step for channel h.
//
// One invocation per channel: h = gid.x. Invocations with h >= params.H return
// immediately, so the dispatch may be rounded up to a multiple of WG_SIZE.
//
// Reads:  state_buf[h*(K-1) + 0 .. h*(K-1) + K-2]  (oldest .. newest cached input)
//         x_buf[h], w_buf[h*K + 0 .. h*K + K-1], b_buf[h]
// Writes: out_buf[h] = sum_{k<K-1} state[h,k]*W[h,k] + x[h]*W[h,K-1] + bias[h]
//         state_buf row h shifted left by one slot, with x[h] appended as newest.
//
// All loops are bounded by the compile-time constant K, so the shader stays
// correct if K is changed. The previous hand-unrolled form silently assumed K == 4.
// The accumulation order (oldest tap first, then x, then bias) matches the
// original expression, so results are unchanged for K == 4.
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let h = gid.x;
    if (h >= params.H) {
        return;
    }
    let sb = h * (K - 1u); // start of this channel's K-1 cached inputs
    let wb = h * K;        // start of this channel's K taps
    let xh = x_buf[h];

    // Convolve the cached inputs (oldest first) with taps 0..K-2, and move each
    // one slot toward the front of the row as it is consumed. Slot k is read
    // before slot k-1 is overwritten with it, so a single forward pass performs
    // the shift-left without clobbering unread values. No other invocation
    // touches this row.
    var acc = 0.0;
    for (var k = 0u; k < K - 1u; k += 1u) {
        let s = state_buf[sb + k];
        acc += s * w_buf[wb + k];
        if (k > 0u) {
            state_buf[sb + k - 1u] = s;
        }
    }

    // The current input is the newest element of the window and meets the last tap.
    acc += xh * w_buf[wb + (K - 1u)];
    out_buf[h] = acc + b_buf[h];

    // Append the current input as the newest cached value. (K - 2u) is a
    // const-expression: K < 2 is rejected at shader creation instead of
    // writing outside the row.
    state_buf[sb + (K - 2u)] = xh;
}