LJTSG committed on
Commit
c3a1bdb
·
verified ·
1 Parent(s): 4cee88d

Upload shaders/embedding.wgsl with huggingface_hub

Browse files
Files changed (1) hide show
  1. shaders/embedding.wgsl +31 -0
shaders/embedding.wgsl ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // embedding.wgsl — token embedding lookup.
2
+ // For each token id, copy table[ids[i]] row into output.
3
+ // One thread per output element (seq_len * hidden_dim threads total).
4
+
5
+ struct Params { // uniform parameters read by the main entry point
6
+ seq_len: u32, // number of token positions; main derives a position t = i / hidden_dim and reads ids[t]
7
+ hidden_dim: u32, // elements per embedding row; used as the row stride for both table and out_buf
8
+ vocab_size: u32, // exclusive upper bound on valid token ids; ids >= vocab_size produce 0.0 in the output
9
+ }
10
+
11
+ @group(0) @binding(0) var<storage, read> ids: array<u32>; // token ids, read-only; indexed by token position
12
+ @group(0) @binding(1) var<storage, read> table: array<f32>; // embedding table, read-only, flat; element (row, d) is at row * hidden_dim + d
13
+ @group(0) @binding(2) var<storage, read_write> out_buf: array<f32>; // output, flat; each invocation writes at most one element
14
+
15
+ @group(1) @binding(0) var<uniform> params: Params; // uniforms are bound in group 1, separate from the storage buffers in group 0
16
+
17
+ @compute @workgroup_size(64) // entry point: 64 invocations per workgroup, each handling one output element
18
+ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
19
+ let i = gid.x; // flat output index; only the x component is used — NOTE(review): assumes a 1-D dispatch covers every element, confirm on the host side
20
+ let total = params.seq_len * params.hidden_dim; // number of output elements
21
+ if (i >= total) { return; } // surplus invocations past the last element do nothing
22
+
23
+ let t = i / params.hidden_dim; // token position (which output row)
24
+ let d = i % params.hidden_dim; // offset within the row
25
+ let token_id = ids[t];
26
+ if (token_id >= params.vocab_size) { // id outside the table: write zero instead of reading table
27
+ out_buf[i] = 0.0;
28
+ return;
29
+ }
30
+ out_buf[i] = table[token_id * params.hidden_dim + d]; // copy element d of the token's table row
31
+ }