Upload shaders/conv1d_step.wgsl with huggingface_hub
Browse files- shaders/conv1d_step.wgsl +59 -0
shaders/conv1d_step.wgsl
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// conv1d_step.wgsl — Mamba autoregressive depthwise conv1d, single-step.
|
| 2 |
+
// Ported from gfx1151_inference/shaders/conv1d_step.comp (Vulkan GLSL → WebGPU WGSL).
|
| 3 |
+
//
|
| 4 |
+
// Falcon-Mamba 7B: H=INTERMEDIATE=8192, K=CONV_KERNEL=4.
|
| 5 |
+
//
|
| 6 |
+
// Per channel h, maintains a conv state cache of K-1=3 past hidden values.
|
| 7 |
+
// state layout (row-major): [H, K-1], oldest at index 0, newest at index K-2.
|
| 8 |
+
//
|
| 9 |
+
// Step:
|
| 10 |
+
// window = [state[h,0], state[h,1], state[h,2], x[h]] (length K)
|
| 11 |
+
// y[h] = sum_k(window[k] * W[h, k]) + bias[h]
|
| 12 |
+
// state[h, 0..K-3] = state[h, 1..K-2] (shift left)
|
| 13 |
+
// state[h, K-2] = x[h] (append current)
|
| 14 |
+
//
|
| 15 |
+
// One thread per h. WG_SIZE = 64.
|
| 16 |
+
|
| 17 |
+
const WG_SIZE: u32 = 64u; // invocations per workgroup along x; consumed by @workgroup_size on main
|
| 18 |
+
const K: u32 = 4u; // conv kernel length: main uses K weights and K-1 cached state values per channel
|
| 19 |
+
|
| 20 |
+
struct Params { // layout of the uniform block bound as `params`
|
| 21 |
+
H: u32, // channel count; main returns early for invocations with gid.x >= H
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
@group(0) @binding(0) var<storage, read_write> state_buf: array<f32>; // [H, K-1] row-major conv state, oldest value first; read, then shifted in place by main
|
| 25 |
+
@group(0) @binding(1) var<storage, read> x_buf: array<f32>; // [H] current-step input, one value per channel
|
| 26 |
+
@group(0) @binding(2) var<storage, read> w_buf: array<f32>; // [H, K] row-major conv weights; the last column multiplies the current input
|
| 27 |
+
@group(0) @binding(3) var<storage, read> b_buf: array<f32>; // [H] per-channel bias added after the weighted sum
|
| 28 |
+
@group(0) @binding(4) var<storage, read_write> out_buf: array<f32>; // [H] conv output for this step; main writes one element per channel
|
| 29 |
+
|
| 30 |
+
@group(1) @binding(0) var<uniform> params: Params; // uniform block carrying H; lives in its own bind group (group 1)
|
| 31 |
+
|
| 32 |
+
// Single-step depthwise conv1d: one invocation per channel.
//
// Invocation gid.x handles channel h. It forms the length-K window
//   [state[h, 0], ..., state[h, K-2], x[h]]
// writes
//   out[h] = sum_k(window[k] * W[h, k]) + bias[h]
// and then shifts the channel's state left by one slot, storing x[h] in the
// newest slot (index K-2).
//
// The window is walked with a loop bounded by the module constant K instead of
// hand-unrolled taps, so the body stays in sync with the `h * K` and
// `h * (K - 1u)` base offsets for any K >= 1. Products are accumulated oldest
// past value first, then the current input, then the bias.
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let h = gid.x;
    // Invocations at or beyond params.H exit without touching any buffer.
    if (h >= params.H) {
        return;
    }

    let past_len = K - 1u;      // cached past values per channel
    let sb = h * past_len;      // state base for this h
    let wb = h * K;             // weight base for this h
    let xh = x_buf[h];

    var acc: f32 = 0.0;
    for (var k: u32 = 0u; k < past_len; k = k + 1u) {
        let past = state_buf[sb + k];
        acc = acc + past * w_buf[wb + k];
        // Slot k-1 was already read by the previous iteration, so it can take
        // this value now; doing so for every k > 0 is the left shift.
        if (k > 0u) {
            state_buf[sb + k - 1u] = past;
        }
    }

    // The last weight column pairs with the current input.
    acc = acc + xh * w_buf[wb + past_len];
    out_buf[h] = acc + b_buf[h];

    // Append the current input as the newest state entry
    // (there is no state slot to write when K == 1).
    if (past_len > 0u) {
        state_buf[sb + past_len - 1u] = xh;
    }
}
|