File size: 2,638 Bytes
faa6fde
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
// sample.wgsl — temperature-based multinomial token sampler.
// Single workgroup, full-vocab softmax + sequential CDF walk.
//
// Phase 1: parallel max-reduce of (logits / temperature)
// Phase 2: parallel sum of exp(scaled - max)
// Phase 3: thread 0 walks CDF with xorshift32 RNG

// Invocations per workgroup. The tree reductions in `main` halve their stride
// every round, so this value must stay a power of two.
const WG_SIZE = 256u;

// Uniform parameters supplied by the host. Field order and types define the
// buffer layout (three 4-byte scalars), so they must match the host-side struct.
struct Params {
    VOCAB: u32,            // number of logits the shader iterates over
    inv_temperature: f32,  // 1 / temperature; every logit is multiplied by this
    rng_seed: u32,         // state fed to a single xorshift32 step in phase 3
}

// Input logits; only the first params.VOCAB entries are read.
@group(0) @binding(0) var<storage, read>       logits: array<f32>;
// Output: the sampled token id is written to element 0.
@group(0) @binding(1) var<storage, read_write> out_token: array<u32>;

@group(1) @binding(0) var<uniform> params: Params;

// Scratch for the tree reductions: one slot per invocation. Sized from WG_SIZE
// so it cannot fall out of sync with @workgroup_size.
var<workgroup> partial: array<f32, WG_SIZE>;
// Reduction results, written by invocation 0 and read by the whole workgroup.
var<workgroup> global_max: f32;
var<workgroup> global_sum: f32;

// Advances a 32-bit xorshift state by one step (shift triple 13, 17, 5).
// Bits shifted past bit 31 are discarded, as with any u32 shift.
// Note: 0 maps to 0, so a zero state never advances.
fn xorshift32(x_in: u32) -> u32 {
    let s1 = x_in ^ (x_in << 13u);
    let s2 = s1 ^ (s1 >> 17u);
    return s2 ^ (s2 << 5u);
}

// Samples one token id from softmax(logits * inv_temperature) and writes it
// to out_token[0].
//
// Dispatch exactly one workgroup: the reductions only combine values within a
// workgroup, and only invocation 0 writes the result.
// NOTE(review): assumes params.VOCAB > 0 and that `logits` holds at least VOCAB
// entries; neither is checked here. With VOCAB == 0, token 0 is written.
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(local_invocation_id) lid: vec3<u32>) {
    let tid = lid.x;
    let vocab = params.VOCAB;
    let inv_t = params.inv_temperature;

    // Phase 1: each invocation takes the max over a stride-WG_SIZE slice of the
    // scaled logits, then a tree reduction in `partial` combines the partials.
    // The tree halves `off` each round, so WG_SIZE must be a power of two.
    var my_max: f32 = -3.402823466e+38; // most negative finite f32
    var i: u32 = tid;
    loop {
        if (i >= vocab) { break; }
        my_max = max(my_max, logits[i] * inv_t);
        i = i + WG_SIZE;
    }
    partial[tid] = my_max;
    workgroupBarrier();
    var off: u32 = WG_SIZE / 2u;
    loop {
        if (off == 0u) { break; }
        if (tid < off) { partial[tid] = max(partial[tid], partial[tid + off]); }
        workgroupBarrier();
        off = off >> 1u;
    }
    if (tid == 0u) { global_max = partial[0u]; }
    workgroupBarrier();
    // Subtracting the max keeps exp() arguments <= 0 (for finite logits), so
    // the exponentials cannot overflow.
    let m = global_max;

    // Phase 2: same strided pattern and tree reduction, summing exp(scaled - max).
    var my_sum: f32 = 0.0;
    i = tid;
    loop {
        if (i >= vocab) { break; }
        my_sum = my_sum + exp(logits[i] * inv_t - m);
        i = i + WG_SIZE;
    }
    partial[tid] = my_sum;
    workgroupBarrier();
    off = WG_SIZE / 2u;
    loop {
        if (off == 0u) { break; }
        if (tid < off) { partial[tid] = partial[tid] + partial[tid + off]; }
        workgroupBarrier();
        off = off >> 1u;
    }
    if (tid == 0u) { global_sum = partial[0u]; }
    workgroupBarrier();

    // Phase 3: invocation 0 draws u in [0, 1) and walks the CDF (inverse
    // transform sampling).
    if (tid == 0u) {
        // xorshift32 maps 0 to 0, so a zero seed would always yield u == 0.
        // Substitute a fixed non-zero state in that case.
        let seed = select(params.rng_seed, 0x9E3779B9u, params.rng_seed == 0u);
        let rng = xorshift32(seed);
        // Use the top 24 bits: exactly representable in f32, so u <= 1 - 2^-24.
        // f32(rng) / 2^32 could round up to exactly 1.0.
        let u = f32(rng >> 8u) * (1.0 / 16777216.0);
        let threshold = u * global_sum;

        // Choose the first token whose running weight strictly exceeds the
        // threshold. Zero-weight tokens (e.g. exp underflow or -inf logits) are
        // skipped, so they can never be chosen, even when threshold == 0.
        // `chosen` tracks the last positive-weight token seen. If f32 rounding
        // leaves the sequential sum short of the tree-reduced global_sum, that
        // token is the fallback. If no token has positive weight, token 0 is
        // emitted.
        var cum: f32 = 0.0;
        var chosen: u32 = 0u;
        for (var j: u32 = 0u; j < vocab; j = j + 1u) {
            let w = exp(logits[j] * inv_t - m);
            if (w > 0.0) {
                chosen = j;
                cum = cum + w;
                if (cum > threshold) { break; }
            }
        }
        out_token[0u] = chosen;
    }
}