File size: 933 Bytes
// embedding.wgsl — token embedding lookup.
// For each token position t, copies row table[ids[t]] into row t of the output.
// Token ids that are >= vocab_size produce a row of zeros instead.
// One thread per output element (seq_len * hidden_dim threads in total).
// Uniform parameters for the lookup; `main` reads them through `params`.
struct Params {
// Number of token positions processed; `main` reads ids[t] for t < seq_len.
seq_len: u32,
// Number of f32 elements per row, used as the row stride for both
// `table` and `out_buf`.
hidden_dim: u32,
// Exclusive upper bound on valid token ids; ids >= vocab_size make `main`
// write 0.0 instead of reading `table`.
vocab_size: u32,
}
// Token ids, one u32 per token position; indexed as ids[t] by `main`.
@group(0) @binding(0) var<storage, read> ids: array<u32>;
// Flat embedding table; `main` reads table[token_id * hidden_dim + d].
@group(0) @binding(1) var<storage, read> table: array<f32>;
// Flat output; `main` writes out_buf[i] where i = t * hidden_dim + d.
@group(0) @binding(2) var<storage, read_write> out_buf: array<f32>;
// Uniform parameters; note these are bound in group 1, separate from the
// storage buffers in group 0.
@group(1) @binding(0) var<uniform> params: Params;
// Compute entry point: each invocation produces exactly one element of
// `out_buf`, identified by its flat index gid.x.
//
// The flat index is split into (row, col) = (token position, offset within
// the embedding row). Invocations past seq_len * hidden_dim do nothing.
// If the token id at that position is not below vocab_size, the element is
// written as 0.0 and `table` is not read.
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let width = params.hidden_dim;
    let flat = gid.x;

    // Surplus invocations in the last workgroup have nothing to write.
    // This guard also returns for every invocation when width is 0, so the
    // divide and modulo below never see a zero divisor.
    if (flat >= params.seq_len * width) {
        return;
    }

    let row = flat / width;
    let col = flat % width;
    let tok = ids[row];

    // Default to zero; only touch the table for in-vocabulary token ids.
    var value: f32 = 0.0;
    if (tok < params.vocab_size) {
        value = table[tok * width + col];
    }
    out_buf[flat] = value;
}
|