LJTSG committed on
Commit
b3e0b44
·
verified ·
1 Parent(s): a87c238

Upload shaders/conv1d_step.wgsl with huggingface_hub

Browse files
Files changed (1) hide show
  1. shaders/conv1d_step.wgsl +59 -0
shaders/conv1d_step.wgsl ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // conv1d_step.wgsl — Mamba autoregressive depthwise conv1d, single-step.
2
+ // Ported from gfx1151_inference/shaders/conv1d_step.comp (Vulkan GLSL → WebGPU WGSL).
3
+ //
4
+ // Falcon-Mamba 7B: H=INTERMEDIATE=8192, K=CONV_KERNEL=4.
5
+ //
6
+ // Per channel h, maintains a conv state cache of K-1=3 past hidden values.
7
+ // state layout (row-major): [H, K-1], oldest at index 0, newest at index K-2.
8
+ //
9
+ // Step:
10
+ // window = [state[h,0], state[h,1], state[h,2], x[h]] (length K)
11
+ // y[h] = sum_k(window[k] * W[h, k]) + bias[h]
12
+ // state[h, 0..K-3] = state[h, 1..K-2] (shift left)
13
+ // state[h, K-2] = x[h] (append current)
14
+ //
15
+ // One thread per h. WG_SIZE = 64.
16
+
// ---- Compile-time configuration -------------------------------------------
// K is the conv kernel width (taps per channel); each channel caches the
// K-1 most recent inputs. WG_SIZE is the number of invocations per workgroup.
const K: u32 = 4u;
const WG_SIZE: u32 = 64u;

// ---- Uniform parameters ---------------------------------------------------
struct Params {
    H: u32, // number of channels
}

// ---- Storage buffers (flat f32 arrays, logical shape in brackets) ---------
// Conv state cache, row-major [H, K-1]; within a row, index 0 is the oldest
// cached input and index K-2 the newest. Read and rewritten every step.
@group(0) @binding(0)
var<storage, read_write> state_buf: array<f32>;

// Current input, one value per channel: [H].
@group(0) @binding(1)
var<storage, read> x_buf: array<f32>;

// Depthwise conv weights: [H, K].
@group(0) @binding(2)
var<storage, read> w_buf: array<f32>;

// Per-channel bias: [H].
@group(0) @binding(3)
var<storage, read> b_buf: array<f32>;

// Conv output for this step: [H].
@group(0) @binding(4)
var<storage, read_write> out_buf: array<f32>;

// Uniforms live in their own bind group.
@group(1) @binding(0)
var<uniform> params: Params;
31
+
// Single autoregressive step of the depthwise conv1d; one invocation handles
// one channel h = gid.x.
//
// For its channel, an invocation:
//   1. loads the K-1 cached past inputs from state_buf and the current input
//      x_buf[h];
//   2. writes out_buf[h] = sum_k(window[k] * w_buf[h*K + k]) + b_buf[h], where
//      window = [state[h,0], state[h,1], state[h,2], x[h]];
//   3. shifts the cached window left by one slot and stores x_buf[h] as the
//      newest entry, ready for the next step.
//
// Invocations with gid.x >= params.H return without touching any buffer, so
// the host may round the dispatch up to a whole number of workgroups
// (presumably ceil(H / WG_SIZE) along x — confirm against the host code).
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    // The taps below are unrolled by hand for exactly three cached values plus
    // the current input, while the strides are derived from K. Reject any
    // other K at shader-creation time instead of silently mis-indexing.
    const_assert K == 4u;

    let h = gid.x;
    if (h >= params.H) {
        return;
    }

    let sb = h * (K - 1u); // first state element of channel h
    let wb = h * K;        // first weight element of channel h

    // Convolution window, oldest to newest.
    let s0 = state_buf[sb + 0u]; // oldest cached input
    let s1 = state_buf[sb + 1u];
    let s2 = state_buf[sb + 2u]; // most recent cached input
    let xh = x_buf[h];           // current input

    // Taps are stored in the same oldest-to-newest order as the window.
    let w0 = w_buf[wb + 0u];
    let w1 = w_buf[wb + 1u];
    let w2 = w_buf[wb + 2u];
    let w3 = w_buf[wb + 3u];
    let bh = b_buf[h];

    out_buf[h] = s0 * w0 + s1 * w1 + s2 * w2 + xh * w3 + bh;

    // Shift the cached window left by one and append the current input.
    // All three old values were loaded above, so the write order is free.
    state_buf[sb + 0u] = s1;
    state_buf[sb + 1u] = s2;
    state_buf[sb + 2u] = xh;
}