// embedding.wgsl — token embedding lookup.
// For each sequence position t, copy row ids[t] of the embedding table into the output.
// One thread per output element (seq_len * hidden_dim threads total).
// Dimensions supplied by the host for one embedding lookup.
struct Params {
    // Number of sequence positions; the shader reads ids[0 .. seq_len).
    seq_len: u32,
    // Number of f32 values per table row and per output row.
    hidden_dim: u32,
    // Token ids >= vocab_size are treated as invalid and produce 0.0 output.
    vocab_size: u32,
}
// NOTE(review): no @group/@binding attributes are visible on these declarations
// here; WGSL requires them on storage and uniform variables — confirm the
// indices against the host-side bind group layout.

// Token id for each sequence position (read-only).
var<storage, read> ids: array<u32>;

// Embedding table, flattened: the row for token k starts at k * hidden_dim (read-only).
var<storage, read> table: array<f32>;

// Output, flattened: the row for position t starts at t * hidden_dim.
var<storage, read_write> out_buf: array<f32>;

// Lookup dimensions (see Params).
var<uniform> params: Params;
// Embedding gather: each invocation writes exactly one element of out_buf.
// gid.x is the flat output index; it decomposes into a sequence position
// (flat / hidden_dim) and a column within the row (flat % hidden_dim).
// NOTE(review): no @compute / @workgroup_size / @builtin attributes are visible
// on this function here — confirm against the original shader.
fn main(gid: vec3<u32>) {
    let flat = gid.x;
    let width = params.hidden_dim;

    // Invocations past the last output element do nothing. This guard also
    // returns before the divisions below when hidden_dim is 0.
    if (flat >= params.seq_len * width) {
        return;
    }

    let row = flat / width;
    let col = flat % width;
    let tok = ids[row];

    // Ids outside the vocabulary yield 0.0 instead of indexing past the table.
    var value: f32 = 0.0;
    if (tok < params.vocab_size) {
        value = table[tok * width + col];
    }
    out_buf[flat] = value;
}