Upload shaders/sample.wgsl with huggingface_hub
Browse files- shaders/sample.wgsl +93 -0
shaders/sample.wgsl
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// sample.wgsl — temperature-based multinomial token sampler.
|
| 2 |
+
// Single workgroup, full-vocab softmax + sequential CDF walk.
|
| 3 |
+
//
|
| 4 |
+
// Phase 1: parallel max-reduce of (logits / temperature)
|
| 5 |
+
// Phase 2: parallel sum of exp(scaled - max)
|
| 6 |
+
// Phase 3: thread 0 walks CDF with xorshift32 RNG
|
| 7 |
+
|
| 8 |
+
const WG_SIZE: u32 = 256u;
|
| 9 |
+
|
| 10 |
+
struct Params {
|
| 11 |
+
VOCAB: u32,
|
| 12 |
+
inv_temperature: f32,
|
| 13 |
+
rng_seed: u32,
|
| 14 |
+
}
|
| 15 |
+
|
| 16 |
+
@group(0) @binding(0) var<storage, read> logits: array<f32>;
|
| 17 |
+
@group(0) @binding(1) var<storage, read_write> out_token: array<u32>;
|
| 18 |
+
|
| 19 |
+
@group(1) @binding(0) var<uniform> params: Params;
|
| 20 |
+
|
| 21 |
+
var<workgroup> partial: array<f32, 256>;
|
| 22 |
+
var<workgroup> global_max: f32;
|
| 23 |
+
var<workgroup> global_sum: f32;
|
| 24 |
+
|
| 25 |
+
fn xorshift32(x_in: u32) -> u32 {
|
| 26 |
+
var x = x_in;
|
| 27 |
+
x = x ^ (x << 13u);
|
| 28 |
+
x = x ^ (x >> 17u);
|
| 29 |
+
x = x ^ (x << 5u);
|
| 30 |
+
return x;
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
@compute @workgroup_size(WG_SIZE)
|
| 34 |
+
fn main(@builtin(local_invocation_id) lid: vec3<u32>) {
|
| 35 |
+
let tid = lid.x;
|
| 36 |
+
|
| 37 |
+
// Phase 1: parallel max
|
| 38 |
+
var my_max: f32 = -3.402823466e+38;
|
| 39 |
+
var i: u32 = tid;
|
| 40 |
+
loop {
|
| 41 |
+
if (i >= params.VOCAB) { break; }
|
| 42 |
+
my_max = max(my_max, logits[i] * params.inv_temperature);
|
| 43 |
+
i = i + WG_SIZE;
|
| 44 |
+
}
|
| 45 |
+
partial[tid] = my_max;
|
| 46 |
+
workgroupBarrier();
|
| 47 |
+
var off: u32 = WG_SIZE / 2u;
|
| 48 |
+
loop {
|
| 49 |
+
if (off == 0u) { break; }
|
| 50 |
+
if (tid < off) { partial[tid] = max(partial[tid], partial[tid + off]); }
|
| 51 |
+
workgroupBarrier();
|
| 52 |
+
off = off >> 1u;
|
| 53 |
+
}
|
| 54 |
+
if (tid == 0u) { global_max = partial[0u]; }
|
| 55 |
+
workgroupBarrier();
|
| 56 |
+
|
| 57 |
+
// Phase 2: parallel sum of exp
|
| 58 |
+
var my_sum: f32 = 0.0;
|
| 59 |
+
i = tid;
|
| 60 |
+
loop {
|
| 61 |
+
if (i >= params.VOCAB) { break; }
|
| 62 |
+
my_sum = my_sum + exp(logits[i] * params.inv_temperature - global_max);
|
| 63 |
+
i = i + WG_SIZE;
|
| 64 |
+
}
|
| 65 |
+
partial[tid] = my_sum;
|
| 66 |
+
workgroupBarrier();
|
| 67 |
+
off = WG_SIZE / 2u;
|
| 68 |
+
loop {
|
| 69 |
+
if (off == 0u) { break; }
|
| 70 |
+
if (tid < off) { partial[tid] = partial[tid] + partial[tid + off]; }
|
| 71 |
+
workgroupBarrier();
|
| 72 |
+
off = off >> 1u;
|
| 73 |
+
}
|
| 74 |
+
if (tid == 0u) { global_sum = partial[0u]; }
|
| 75 |
+
workgroupBarrier();
|
| 76 |
+
|
| 77 |
+
// Phase 3: thread 0 walks CDF
|
| 78 |
+
if (tid == 0u) {
|
| 79 |
+
let rng = xorshift32(params.rng_seed);
|
| 80 |
+
let threshold = f32(rng) * (1.0 / 4294967296.0) * global_sum;
|
| 81 |
+
|
| 82 |
+
var cum: f32 = 0.0;
|
| 83 |
+
var chosen: u32 = params.VOCAB - 1u;
|
| 84 |
+
var j: u32 = 0u;
|
| 85 |
+
loop {
|
| 86 |
+
if (j >= params.VOCAB) { break; }
|
| 87 |
+
cum = cum + exp(logits[j] * params.inv_temperature - global_max);
|
| 88 |
+
if (cum >= threshold) { chosen = j; break; }
|
| 89 |
+
j = j + 1u;
|
| 90 |
+
}
|
| 91 |
+
out_token[0u] = chosen;
|
| 92 |
+
}
|
| 93 |
+
}
|