// sample.wgsl — temperature-based multinomial token sampler.
// Single workgroup, full-vocab softmax + sequential CDF walk.
//
//   Phase 1: parallel max-reduce of (logits / temperature)
//   Phase 2: parallel sum of exp(scaled - max)
//   Phase 3: thread 0 walks CDF with xorshift32 RNG
//
// Only the local invocation id is used, so every workgroup dispatched would
// repeat the same work and write the same out_token[0]; dispatch exactly one.

// Threads per workgroup. The tree reductions below halve the active range on
// every step, so this is assumed to be a power of two.
const WG_SIZE: u32 = 256u;

// Sampler parameters supplied by the host.
struct Params {
    // Number of entries of `logits` to sample from.
    VOCAB: u32,
    // Multiplier applied to every logit (1 / temperature per the header).
    inv_temperature: f32,
    // Seed for the single xorshift32 step that produces the uniform draw.
    rng_seed: u32,
}

// NOTE(review): the address spaces / access modes on the three bindings below
// were reconstructed from how the variables are used (f32 reads, one u32
// write, a small parameter struct) — confirm against the host-side bind group
// layout.
@group(0) @binding(0) var<storage, read> logits: array<f32>;
@group(0) @binding(1) var<storage, read_write> out_token: array<u32>;
@group(1) @binding(0) var<uniform> params: Params;

// Per-thread scratch for the tree reductions, plus the two broadcast results.
var<workgroup> partial: array<f32, WG_SIZE>;
var<workgroup> global_max: f32;
var<workgroup> global_sum: f32;

// One step of the 32-bit xorshift generator (shift triple 13, 17, 5).
// Zero is a fixed point: xorshift32(0u) == 0u.
fn xorshift32(x_in: u32) -> u32 {
    var x = x_in;
    x = x ^ (x << 13u);
    x = x ^ (x >> 17u);
    x = x ^ (x << 5u);
    return x;
}

// Temperature-scaled logit for token `idx`.
fn scaled_logit(idx: u32) -> f32 {
    return logits[idx] * params.inv_temperature;
}

// Samples one token index from softmax(logits * inv_temperature) and writes it
// to out_token[0].
//
// Edge cases:
//   * rng_seed == 0 is remapped to a non-zero constant, because a zero seed
//     would otherwise always produce threshold 0 and therefore token 0.
//   * If rounding keeps the running sum below the threshold for the whole
//     walk, the last token with non-zero probability is chosen.
//   * VOCAB == 0 writes 0xFFFFFFFF (VOCAB - 1u wraps); there is no valid token.
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(local_invocation_id) lid: vec3<u32>) {
    let tid = lid.x;

    // Phase 1: parallel max. Each thread scans a strided slice of the
    // vocabulary, starting from (approximately) the lowest finite f32.
    var my_max: f32 = -3.402823466e+38;
    for (var i: u32 = tid; i < params.VOCAB; i = i + WG_SIZE) {
        my_max = max(my_max, scaled_logit(i));
    }
    partial[tid] = my_max;
    workgroupBarrier();

    // Tree reduction; the barrier is executed by every thread on every step.
    for (var off: u32 = WG_SIZE / 2u; off > 0u; off = off >> 1u) {
        if (tid < off) {
            partial[tid] = max(partial[tid], partial[tid + off]);
        }
        workgroupBarrier();
    }
    if (tid == 0u) {
        global_max = partial[0u];
    }
    workgroupBarrier();

    // Phase 2: parallel sum of exp. Subtracting the maximum keeps every
    // exponent <= 0, so exp() cannot overflow.
    var my_sum: f32 = 0.0;
    for (var i: u32 = tid; i < params.VOCAB; i = i + WG_SIZE) {
        my_sum = my_sum + exp(scaled_logit(i) - global_max);
    }
    partial[tid] = my_sum;
    workgroupBarrier();

    for (var off: u32 = WG_SIZE / 2u; off > 0u; off = off >> 1u) {
        if (tid < off) {
            partial[tid] = partial[tid] + partial[tid + off];
        }
        workgroupBarrier();
    }
    if (tid == 0u) {
        global_sum = partial[0u];
    }
    workgroupBarrier();

    // Phase 3: thread 0 walks the (unnormalised) CDF.
    if (tid == 0u) {
        // Zero is a fixed point of xorshift32; substitute a non-zero seed so
        // the draw is not pinned to threshold 0.
        let seed = select(params.rng_seed, 0x9E3779B9u, params.rng_seed == 0u);
        let rng = xorshift32(seed);

        // f32(rng) can round up to 2^32, so the uniform factor may reach 1.0
        // and the threshold may equal global_sum exactly.
        let threshold = f32(rng) * (1.0 / 4294967296.0) * global_sum;

        var cum: f32 = 0.0;
        // Fallback when the walk never reaches the threshold; updated below to
        // the most recent token with non-zero probability.
        var chosen: u32 = params.VOCAB - 1u;
        for (var j: u32 = 0u; j < params.VOCAB; j = j + 1u) {
            let p = exp(scaled_logit(j) - global_max);
            cum = cum + p;
            if (cum >= threshold) {
                chosen = j;
                break;
            }
            if (p > 0.0) {
                // The sequential sum here can end slightly below the
                // tree-reduced global_sum; never fall back to a token whose
                // probability is zero.
                chosen = j;
            }
        }
        out_token[0u] = chosen;
    }
}