| |
| |
| |
| |
| |
| |
|
|
// Invocations per workgroup. Also sizes the `partial` reduction scratch array,
// so the two can never drift apart. The tree reductions in `main` halve an
// offset starting at WG_SIZE / 2, so this must be a power of two.
const WG_SIZE: u32 = 256u;

// Uniform parameters for the sampler.
// NOTE(review): member order and types define the uniform buffer layout the
// host writes — keep in sync with the host-side struct.
struct Params {
    // Number of logits read from `logits` (indices [0, VOCAB)).
    VOCAB: u32,
    // Multiplies every logit before the softmax (presumably 1 / temperature — confirm with host).
    inv_temperature: f32,
    // Seed for the single random draw made per dispatch.
    rng_seed: u32,
}

// Input logits; `main` reads indices [0, params.VOCAB).
@group(0) @binding(0) var<storage, read> logits: array<f32>;
// Output; `main` writes the sampled index to element 0 only.
@group(0) @binding(1) var<storage, read_write> out_token: array<u32>;

@group(1) @binding(0) var<uniform> params: Params;

// Per-invocation scratch for the workgroup max/sum reductions (one slot per invocation).
var<workgroup> partial: array<f32, WG_SIZE>;
// Results of the two reductions, broadcast to the whole workgroup.
var<workgroup> global_max: f32;
var<workgroup> global_sum: f32;
|
|
// Advance a 32-bit xorshift state by one step, using the shift triple
// (<< 13, >> 17, << 5) applied in that order.
// 0 is a fixed point: xorshift32(0u) == 0u.
fn xorshift32(x_in: u32) -> u32 {
    let after_left13 = x_in ^ (x_in << 13u);
    let after_right17 = after_left13 ^ (after_left13 >> 17u);
    return after_right17 ^ (after_right17 << 5u);
}
|
|
// Integer hash (PCG output permutation) used to whiten the user seed before it
// becomes a uniform variate. A raw single xorshift step maps seed 0 to 0 and
// small seeds to small outputs (seed 1 -> u ~ 6.3e-5), which would bias the
// draw toward low token indices if the host passes small or sequential seeds.
fn pcg_hash(v: u32) -> u32 {
    let state = v * 747796405u + 2891336453u;
    let word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u;
    return (word >> 22u) ^ word;
}

// Sample one index from softmax(logits[0..VOCAB) * inv_temperature) and write
// it to out_token[0].
//
// Phases:
//   1. workgroup max of the scaled logits (for a numerically stable softmax);
//   2. workgroup sum of exp(scaled - max), the softmax normaliser;
//   3. invocation 0 draws u in [0, 1) and walks the unnormalised CDF until it
//      exceeds u * sum.
// NOTE(review): only local_invocation_id is read, so this assumes a single
// workgroup is dispatched; extra workgroups would recompute and write the same
// result to out_token[0] — confirm the host dispatches (1, 1, 1).
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(local_invocation_id) lid: vec3<u32>) {
    let tid = lid.x;

    // ---- Phase 1: maximum scaled logit --------------------------------------
    // Start from -FLT_MAX (not -inf) so an all -inf input still yields a finite max.
    var my_max: f32 = -3.402823466e+38;
    for (var i = tid; i < params.VOCAB; i += WG_SIZE) {
        my_max = max(my_max, logits[i] * params.inv_temperature);
    }
    partial[tid] = my_max;
    workgroupBarrier();
    // Tree reduction over `partial`; `off` is uniform, so the barrier is legal.
    for (var off = WG_SIZE / 2u; off > 0u; off >>= 1u) {
        if (tid < off) {
            partial[tid] = max(partial[tid], partial[tid + off]);
        }
        workgroupBarrier();
    }
    if (tid == 0u) { global_max = partial[0u]; }
    workgroupBarrier();

    // ---- Phase 2: softmax normaliser ----------------------------------------
    var my_sum: f32 = 0.0;
    for (var i = tid; i < params.VOCAB; i += WG_SIZE) {
        my_sum += exp(logits[i] * params.inv_temperature - global_max);
    }
    partial[tid] = my_sum;
    workgroupBarrier();
    for (var off = WG_SIZE / 2u; off > 0u; off >>= 1u) {
        if (tid < off) {
            partial[tid] = partial[tid] + partial[tid + off];
        }
        workgroupBarrier();
    }
    if (tid == 0u) { global_sum = partial[0u]; }
    workgroupBarrier();

    // ---- Phase 3: inverse-CDF draw (single invocation) -----------------------
    if (tid == 0u) {
        let rng = pcg_hash(params.rng_seed);
        // Keep the top 24 bits: every value is exact in f32 and u < 1.0.
        // Converting all 32 bits can round up to 2^32 and produce u == 1.0.
        let u = f32(rng >> 8u) * (1.0 / 16777216.0);
        let threshold = u * global_sum;

        // `chosen` tracks the last token with positive weight. It is the answer
        // when cum first exceeds threshold, and the fallback if float rounding
        // (this serial sum runs in a different order than the tree sum) keeps
        // cum from ever exceeding it. It stays 0 if every weight is 0 or VOCAB == 0.
        var chosen: u32 = 0u;
        var cum: f32 = 0.0;
        for (var j = 0u; j < params.VOCAB; j++) {
            let w = exp(logits[j] * params.inv_temperature - global_max);
            if (w > 0.0) { chosen = j; }
            cum += w;
            // Strict '>' so a zero-weight token is never selected, even when threshold == 0.
            if (cum > threshold) { break; }
        }
        out_token[0u] = chosen;
    }
}
|
|