mamba-webgpu / shaders /embedding.wgsl
LJTSG's picture
Upload shaders/embedding.wgsl with huggingface_hub
c3a1bdb verified
raw
history blame contribute delete
933 Bytes
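// embedding.wgsl — token embedding lookup.
// For each position t < seq_len, copies row ids[t] of `table` into row t of
// `out_buf`; ids >= vocab_size produce an all-zero output row.
// Intended launch: one invocation per output element (seq_len * hidden_dim
// invocations in total, workgroup size 64).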
// Per-dispatch sizes, read by the shader from the uniform binding
// @group(1) @binding(0).
// Layout: three u32 fields at byte offsets 0, 4 and 8 (12 bytes total).
// The host must write them in this order, so do not reorder the fields.
struct Params {
// Number of token ids to look up; also the number of output rows.
seq_len: u32,
// Embedding width: the length of each row in `table` and in `out_buf`.
hidden_dim: u32,
// Upper bound (exclusive) on valid token ids; ids at or above it produce
// a zero output row instead of reading `table`.
vocab_size: u32,
}
// Token ids; ids[t] selects the table row copied into output row t.
// Presumably holds at least seq_len entries — TODO confirm against the host.
@group(0) @binding(0) var<storage, read> ids: array<u32>;
// Embedding table, flattened row-major: row r occupies elements
// [r * hidden_dim, (r + 1) * hidden_dim).
@group(0) @binding(1) var<storage, read> table: array<f32>;
// Output, flattened row-major: seq_len rows of hidden_dim floats.
@group(0) @binding(2) var<storage, read_write> out_buf: array<f32>;
// Sizes for this dispatch (seq_len, hidden_dim, vocab_size).
@group(1) @binding(0) var<uniform> params: Params;
// Workgroup width of `main`. It is also used in the stride computation below,
// so the attribute and the index math cannot drift apart.
const EMBED_WG_SIZE: u32 = 64u;

// Entry point. For every t < seq_len and d < hidden_dim it writes
//   out_buf[t * hidden_dim + d] = table[ids[t] * hidden_dim + d]
// and writes 0.0 instead when ids[t] >= vocab_size.
//
// Each invocation starts at its linearised global index. It then advances by
// the total number of invocations in the dispatch (a grid-stride loop).
//
// A 1-D dispatch of ceil(seq_len * hidden_dim / 64) workgroups therefore
// behaves exactly like a one-element-per-invocation kernel. The kernel also
// remains correct when the host has to launch fewer workgroups, or fold the
// dispatch into y/z. That is needed once the element count / 64 exceeds
// maxComputeWorkgroupsPerDimension (65535 by default, i.e. ~4.19M elements).
@compute @workgroup_size(EMBED_WG_SIZE)
fn main(
    @builtin(global_invocation_id) gid: vec3<u32>,
    @builtin(num_workgroups) num_wg: vec3<u32>,
) {
    let dim = params.hidden_dim;
    // If hidden_dim == 0, total == 0: the loop below never runs, so `i / dim`
    // is never evaluated with a zero divisor.
    let total = params.seq_len * dim;

    // Invocations per x-row, per xy-plane, and in the whole dispatch.
    let row = num_wg.x * EMBED_WG_SIZE;
    let plane = row * num_wg.y;
    let stride = plane * num_wg.z;

    for (var i = gid.z * plane + gid.y * row + gid.x; i < total; i += stride) {
        let t = i / dim;   // output row (sequence position)
        let d = i % dim;   // column within the row
        let token_id = ids[t];
        if (token_id < params.vocab_size) {
            out_buf[i] = table[token_id * dim + d];
        } else {
            // Out-of-vocabulary id: emit zeros rather than reading past the table.
            out_buf[i] = 0.0;
        }
        // Exit before `i += stride` could wrap around u32 and revisit low indices.
        if (stride >= total - i) {
            break;
        }
    }
}