File size: 933 Bytes
c3a1bdb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
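// embedding.wgsl — token embedding lookup.
// For each token position t, copy row table[ids[t]] (hidden_dim floats) into row t of out_buf.
// One invocation per output element: dispatch at least seq_len * hidden_dim invocations
// (surplus invocations exit early).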

// Scalar parameters for the embedding lookup, bound as a uniform.
// Three consecutive u32 fields (12 bytes, 4-byte aligned under WGSL layout rules);
// the host-side buffer must write them in exactly this order.
// NOTE(review): confirm the host sizes/pads the uniform buffer to match.
struct Params {
    seq_len: u32,     // number of token ids to look up (rows written to out_buf)
    hidden_dim: u32,  // floats per embedding row (row stride of table and out_buf)
    vocab_size: u32,  // exclusive upper bound on valid token ids; ids >= this yield a zero row
}

// Group 0: data buffers.
// Token ids, one u32 per sequence position; main reads ids[0 .. seq_len).
@group(0) @binding(0) var<storage, read>       ids: array<u32>;
// Embedding table, flattened row-major: row r occupies [r * hidden_dim, (r + 1) * hidden_dim).
@group(0) @binding(1) var<storage, read>       table: array<f32>;
// Output, flattened row-major: seq_len rows of hidden_dim floats each.
@group(0) @binding(2) var<storage, read_write> out_buf: array<f32>;

// Group 1: scalar parameters (see Params).
@group(1) @binding(0) var<uniform> params: Params;

// Invocations per workgroup along x. The host must use the same value when it
// computes how many workgroups to dispatch.
const WG_SIZE: u32 = 64u;

// Entry point: each invocation writes exactly one element of out_buf.
//
// The flat output index is rebuilt from all three dispatch dimensions. This lets
// the host spread a large job over y (and z) once seq_len * hidden_dim exceeds
// what a 1-D dispatch can cover (maxComputeWorkgroupsPerDimension * WG_SIZE).
// For a 1-D dispatch gid.y == gid.z == 0, so the index reduces to gid.x, which is
// the previous behaviour.
@compute @workgroup_size(WG_SIZE)
fn main(
    @builtin(global_invocation_id) gid: vec3<u32>,
    @builtin(num_workgroups) num_wg: vec3<u32>,
) {
    // Invocations per dispatch row along x.
    let row_width = num_wg.x * WG_SIZE;
    let i = (gid.z * num_wg.y + gid.y) * row_width + gid.x;

    let hidden = params.hidden_dim;
    let total = params.seq_len * hidden;
    // Surplus invocations (tail of the last workgroup, or 2-D/3-D padding) have no element.
    // This also covers hidden_dim == 0, so the divisions below never see a zero divisor.
    if (i >= total) { return; }

    let t = i / hidden;   // token position in the sequence
    let d = i % hidden;   // column within the embedding row
    let token_id = ids[t];
    // Out-of-vocabulary ids produce a zero row instead of indexing table
    // (table is assumed to hold vocab_size * hidden_dim floats — confirm on host).
    if (token_id >= params.vocab_size) {
        out_buf[i] = 0.0;
        return;
    }
    out_buf[i] = table[token_id * hidden + d];
}