LJTSG commited on
Commit
faa6fde
·
verified ·
1 Parent(s): 1399c22

Upload shaders/sample.wgsl with huggingface_hub

Browse files
Files changed (1) hide show
  1. shaders/sample.wgsl +93 -0
shaders/sample.wgsl ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // sample.wgsl — temperature-based multinomial token sampler.
2
+ // Single workgroup, full-vocab softmax + sequential CDF walk.
3
+ //
4
+ // Phase 1: parallel max-reduce of (logits / temperature)
5
+ // Phase 2: parallel sum of exp(scaled - max)
6
+ // Phase 3: thread 0 walks CDF with xorshift32 RNG
7
+
8
+ const WG_SIZE: u32 = 256u;
9
+
10
+ struct Params {
11
+ VOCAB: u32,
12
+ inv_temperature: f32,
13
+ rng_seed: u32,
14
+ }
15
+
16
+ @group(0) @binding(0) var<storage, read> logits: array<f32>;
17
+ @group(0) @binding(1) var<storage, read_write> out_token: array<u32>;
18
+
19
+ @group(1) @binding(0) var<uniform> params: Params;
20
+
21
+ var<workgroup> partial: array<f32, 256>;
22
+ var<workgroup> global_max: f32;
23
+ var<workgroup> global_sum: f32;
24
+
25
+ fn xorshift32(x_in: u32) -> u32 {
26
+ var x = x_in;
27
+ x = x ^ (x << 13u);
28
+ x = x ^ (x >> 17u);
29
+ x = x ^ (x << 5u);
30
+ return x;
31
+ }
32
+
33
+ @compute @workgroup_size(WG_SIZE)
34
+ fn main(@builtin(local_invocation_id) lid: vec3<u32>) {
35
+ let tid = lid.x;
36
+
37
+ // Phase 1: parallel max
38
+ var my_max: f32 = -3.402823466e+38;
39
+ var i: u32 = tid;
40
+ loop {
41
+ if (i >= params.VOCAB) { break; }
42
+ my_max = max(my_max, logits[i] * params.inv_temperature);
43
+ i = i + WG_SIZE;
44
+ }
45
+ partial[tid] = my_max;
46
+ workgroupBarrier();
47
+ var off: u32 = WG_SIZE / 2u;
48
+ loop {
49
+ if (off == 0u) { break; }
50
+ if (tid < off) { partial[tid] = max(partial[tid], partial[tid + off]); }
51
+ workgroupBarrier();
52
+ off = off >> 1u;
53
+ }
54
+ if (tid == 0u) { global_max = partial[0u]; }
55
+ workgroupBarrier();
56
+
57
+ // Phase 2: parallel sum of exp
58
+ var my_sum: f32 = 0.0;
59
+ i = tid;
60
+ loop {
61
+ if (i >= params.VOCAB) { break; }
62
+ my_sum = my_sum + exp(logits[i] * params.inv_temperature - global_max);
63
+ i = i + WG_SIZE;
64
+ }
65
+ partial[tid] = my_sum;
66
+ workgroupBarrier();
67
+ off = WG_SIZE / 2u;
68
+ loop {
69
+ if (off == 0u) { break; }
70
+ if (tid < off) { partial[tid] = partial[tid] + partial[tid + off]; }
71
+ workgroupBarrier();
72
+ off = off >> 1u;
73
+ }
74
+ if (tid == 0u) { global_sum = partial[0u]; }
75
+ workgroupBarrier();
76
+
77
+ // Phase 3: thread 0 walks CDF
78
+ if (tid == 0u) {
79
+ let rng = xorshift32(params.rng_seed);
80
+ let threshold = f32(rng) * (1.0 / 4294967296.0) * global_sum;
81
+
82
+ var cum: f32 = 0.0;
83
+ var chosen: u32 = params.VOCAB - 1u;
84
+ var j: u32 = 0u;
85
+ loop {
86
+ if (j >= params.VOCAB) { break; }
87
+ cum = cum + exp(logits[j] * params.inv_temperature - global_max);
88
+ if (cum >= threshold) { chosen = j; break; }
89
+ j = j + 1u;
90
+ }
91
+ out_token[0u] = chosen;
92
+ }
93
+ }