// sample.wgsl — temperature-based multinomial token sampler.
// Single workgroup, full-vocab softmax + sequential CDF walk.
//
// Phase 1: parallel max-reduce of (logits / temperature)
// Phase 2: parallel sum of exp(scaled - max)
// Phase 3: thread 0 walks CDF with xorshift32 RNG
// Invocations per workgroup. `main` also uses it as the stride of the
// per-invocation vocabulary loops and as the starting span of the halving
// reductions, so it must be a power of two and must match the length of the
// `partial` scratch array.
const WG_SIZE: u32 = 256u;
// Uniform parameters for one sampling call.
struct Params {
// Number of entries of `logits` that take part in sampling; `main` reads
// indices [0, VOCAB).
VOCAB: u32,
// Factor every logit is multiplied by before the softmax
// (presumably 1 / temperature, per the name — confirm against the host code).
inv_temperature: f32,
// Seed passed through one xorshift32 step to produce the random draw.
rng_seed: u32,
}
// Input logits (read-only); `main` reads indices [0, params.VOCAB).
@group(0) @binding(0) var<storage, read> logits: array<f32>;
// Output buffer; `main` writes the sampled token index to element 0 only.
@group(0) @binding(1) var<storage, read_write> out_token: array<u32>;
// Sampling parameters; note they live in bind group 1, not group 0.
@group(1) @binding(0) var<uniform> params: Params;
// Scratch slots for the tree reductions in `main`, one per invocation.
// Sized from WG_SIZE so the array cannot drift out of sync with the
// workgroup size and the reduction offsets derived from it.
var<workgroup> partial: array<f32, WG_SIZE>;
// Maximum scaled logit, published by invocation 0 after the phase-1 reduction.
var<workgroup> global_max: f32;
// Sum of exp(scaled logit - global_max), published by invocation 0 after the
// phase-2 reduction.
var<workgroup> global_sum: f32;
// One step of the 32-bit xorshift generator with shift triple (13, 17, 5).
// Takes the current state and returns the next one. An input of 0 yields 0,
// so callers that need a useful value must pass a non-zero state.
fn xorshift32(x_in: u32) -> u32 {
    var state: u32 = x_in;
    state ^= state << 13u;
    state ^= state >> 17u;
    state ^= state << 5u;
    return state;
}
// Entry point: draws one token index from softmax(logits * inv_temperature)
// and writes it to out_token[0].
//
// All WG_SIZE invocations cooperate on the two reductions (phases 1 and 2);
// invocation 0 alone then walks the cumulative distribution (phase 3). The
// shader never looks at the workgroup id, so every dispatched workgroup would
// repeat the same work and write the same output slot.
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(local_invocation_id) lid: vec3<u32>) {
    let tid = lid.x;
    let vocab = params.VOCAB;
    let inv_temp = params.inv_temperature;

    // Phase 1: parallel max of the scaled logits.
    // Start from the most negative finite f32 so any finite logit replaces it.
    var my_max: f32 = -3.402823466e+38;
    for (var i: u32 = tid; i < vocab; i += WG_SIZE) {
        my_max = max(my_max, logits[i] * inv_temp);
    }
    partial[tid] = my_max;
    workgroupBarrier();
    // Halving tree reduction; the barrier sits outside the `tid < off` branch
    // so every invocation reaches it.
    for (var off: u32 = WG_SIZE / 2u; off > 0u; off >>= 1u) {
        if (tid < off) {
            partial[tid] = max(partial[tid], partial[tid + off]);
        }
        workgroupBarrier();
    }
    if (tid == 0u) {
        global_max = partial[0u];
    }
    // Publishes global_max and guarantees partial[0] has been consumed before
    // phase 2 overwrites the scratch array.
    workgroupBarrier();

    // Phase 2: parallel sum of exp(scaled - max). Subtracting the maximum
    // keeps every exponent <= 0, so exp cannot overflow.
    var my_sum: f32 = 0.0;
    for (var i: u32 = tid; i < vocab; i += WG_SIZE) {
        my_sum += exp(logits[i] * inv_temp - global_max);
    }
    partial[tid] = my_sum;
    workgroupBarrier();
    for (var off: u32 = WG_SIZE / 2u; off > 0u; off >>= 1u) {
        if (tid < off) {
            partial[tid] = partial[tid] + partial[tid + off];
        }
        workgroupBarrier();
    }
    if (tid == 0u) {
        global_sum = partial[0u];
    }
    workgroupBarrier();

    // Phase 3: invocation 0 walks the CDF sequentially.
    if (tid == 0u) {
        // 0 is a fixed point of xorshift32 (it would always produce a draw of
        // 0), so swap it for an arbitrary non-zero constant.
        let seed = select(params.rng_seed, 0x9E3779B9u, params.rng_seed == 0u);
        let rng = xorshift32(seed);
        // Use only the top 24 bits: each such integer is exactly representable
        // in f32, so u lies in [0, 1). Converting all 32 bits can round up to
        // 2^32 and yield exactly 1.0.
        let u = f32(rng >> 8u) * (1.0 / 16777216.0);
        let threshold = u * global_sum;

        var cum: f32 = 0.0;
        // Fallback when no token carries any weight. If VOCAB is 0 the
        // subtraction wraps and 0xFFFFFFFF is written.
        var chosen: u32 = vocab - 1u;
        for (var j: u32 = 0u; j < vocab; j += 1u) {
            let weight = exp(logits[j] * inv_temp - global_max);
            // Remember the latest token that has mass. If the sequential sum
            // rounds below the tree-reduced global_sum and the walk falls off
            // the end, this is the token the draw landed on.
            if (weight > 0.0) {
                chosen = j;
            }
            cum += weight;
            // Strict comparison: token j owns the half-open interval
            // [cum_before, cum_after), so a zero-weight token is never picked.
            // Crossing the threshold implies weight > 0, hence chosen == j.
            if (cum > threshold) {
                break;
            }
        }
        out_token[0u] = chosen;
    }
}
|