// embedding.wgsl — token embedding lookup.
//
// For each token position t, copy row table[ids[t]] into row t of out_buf.
// One invocation per output element (seq_len * hidden_dim invocations total).
//
// Bindings:
//   group(0) binding(0) ids     — one u32 token id per sequence position.
//   group(0) binding(1) table   — embedding table, row-major, hidden_dim f32 per row.
//   group(0) binding(2) out_buf — output, row-major, hidden_dim f32 per position.
//   group(1) binding(0) params  — sizes (uniform).
//
// Dispatch: ceil(seq_len * hidden_dim / WORKGROUP_SIZE) workgroups in total.
// A plain 1-D dispatch works as before. When that count exceeds the
// per-dimension workgroup limit (maxComputeWorkgroupsPerDimension, 65535 by
// default in WebGPU), the host may spread it over x and y; the linear index
// below folds gid.y back in.
//
// NOTE(review): index arithmetic is u32 and wraps silently, so
// seq_len * hidden_dim and vocab_size * hidden_dim must each stay below 2^32.

struct Params {
    seq_len: u32,
    hidden_dim: u32,
    vocab_size: u32,
}

@group(0) @binding(0) var<storage, read> ids: array<u32>;
@group(0) @binding(1) var<storage, read> table: array<f32>;
@group(0) @binding(2) var<storage, read_write> out_buf: array<f32>;
@group(1) @binding(0) var<uniform> params: Params;

// Must match the host-side workgroup count computation; also used to
// linearize a 2-D dispatch below.
const WORKGROUP_SIZE: u32 = 64u;

@compute @workgroup_size(WORKGROUP_SIZE)
fn main(
    @builtin(global_invocation_id) gid: vec3<u32>,
    @builtin(num_workgroups) num_wg: vec3<u32>,
) {
    // Row-major flatten of a (possibly) 2-D dispatch; reduces to gid.x when
    // the host dispatches along x only (num_wg.y == 1, gid.y == 0).
    let i = gid.y * (num_wg.x * WORKGROUP_SIZE) + gid.x;
    let total = params.seq_len * params.hidden_dim;
    // Tail invocations of the last workgroup (and everything when
    // hidden_dim == 0) exit here, which also avoids a division by zero below.
    if (i >= total) {
        return;
    }

    let t = i / params.hidden_dim;   // sequence position
    let d = i % params.hidden_dim;   // column within the embedding row
    let token_id = ids[t];

    // Out-of-range ids yield a zero row instead of reading outside the table.
    if (token_id >= params.vocab_size) {
        out_buf[i] = 0.0;
        return;
    }
    out_buf[i] = table[token_id * params.hidden_dim + d];
}