File size: 1,999 Bytes
b3e0b44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
// conv1d_step.wgsl — Mamba autoregressive depthwise conv1d, single-step.
// Ported from gfx1151_inference/shaders/conv1d_step.comp (Vulkan GLSL → WebGPU WGSL).
//
// Falcon-Mamba 7B: H=INTERMEDIATE=8192, K=CONV_KERNEL=4.
//
// Per channel h, maintains a conv state cache of K-1=3 past hidden values.
// state layout (row-major): [H, K-1], oldest at index 0, newest at index K-2.
//
// Step:
//   window = [state[h,0], state[h,1], state[h,2], x[h]]   (length K)
//   y[h]   = sum_k(window[k] * W[h, k]) + bias[h]
//   state[h, 0..K-3] = state[h, 1..K-2]    (shift left)
//   state[h, K-2]    = x[h]                (append current)
//
// One thread per h. WG_SIZE = 64.

// Invocations per workgroup along x; consumed by @workgroup_size on `main`.
const WG_SIZE: u32 = 64u;
// Convolution kernel width (taps per channel). `main` spells out exactly
// 4 taps (3 cached values + the current input), so its body must be updated
// together with this constant.
const K: u32 = 4u;

// Uniform parameters read by `main`.
struct Params {
    // Channel count; invocations whose global x index is >= H return early.
    H: u32,
}

// Group 0: flat f32 storage arrays, addressed per channel h by `main`.
@group(0) @binding(0) var<storage, read_write> state_buf: array<f32>;  // [H, K-1] cached past inputs, oldest first; rewritten in place each step
@group(0) @binding(1) var<storage, read>       x_buf: array<f32>;      // [H] current-step input
@group(0) @binding(2) var<storage, read>       w_buf: array<f32>;      // [H, K] per-channel kernel weights, row-major
@group(0) @binding(3) var<storage, read>       b_buf: array<f32>;      // [H] per-channel bias
@group(0) @binding(4) var<storage, read_write> out_buf: array<f32>;    // [H] convolution output for this step

// Group 1: uniform parameters (channel count).
@group(1) @binding(0) var<uniform> params: Params;

// Entry point: advances the depthwise conv by one step for channel gid.x.
// Writes the channel's output to out_buf, then rotates its cached window in
// state_buf. The taps are written out explicitly for K = 4.
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let ch = gid.x;
    // Invocations beyond the last channel have nothing to do.
    if (ch >= params.H) {
        return;
    }

    // Row offsets into the flattened [H, K-1] state and [H, K] weight arrays.
    let state_row = ch * (K - 1u);
    let weight_row = ch * K;

    // Sliding window, oldest to newest: three cached values, then this step's input.
    let past_oldest = state_buf[state_row];
    let past_middle = state_buf[state_row + 1u];
    let past_newest = state_buf[state_row + 2u];
    let current = x_buf[ch];

    // Accumulate tap by tap, oldest first, with the bias added last.
    var acc = past_oldest * w_buf[weight_row];
    acc += past_middle * w_buf[weight_row + 1u];
    acc += past_newest * w_buf[weight_row + 2u];
    acc += current * w_buf[weight_row + 3u];
    acc += b_buf[ch];
    out_buf[ch] = acc;

    // Drop the oldest value, slide the others down one slot, append the input.
    state_buf[state_row] = past_middle;
    state_buf[state_row + 1u] = past_newest;
    state_buf[state_row + 2u] = current;
}