File size: 2,062 Bytes
4d35814
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#version 450

#extension GL_EXT_shader_16bit_storage : require
#if ADD_RMS
#extension GL_KHR_shader_subgroup_arithmetic : enable
#extension GL_KHR_shader_subgroup_basic : enable
#endif

#include "types.glsl"
#include "generic_binary_head.glsl"

// Invocations per workgroup (local_size_x below). main() relies on
// num_threads * num_iter matching the 512-element chunk used by get_idx().
const uint num_threads = 256;

// Output for the optional ADD_RMS reduction: main() writes one sum-of-squares
// value per workgroup, at index orig_idx / (num_iter * num_threads).
layout (binding = 3, std430) buffer PartialBuf {float partial_sums[];};

layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;

#if ADD_RMS
// Per-subgroup partial sums of squares used by the cross-subgroup reduction.
// Sized for the worst case of one invocation per subgroup.
// XXX TODO this could be sized based on number of subgroups, but that's not considered a constant
shared FLOAT_TYPE sumsh[num_threads];
#endif

// Element-wise add: each invocation handles up to num_iter elements, spaced
// num_threads apart, writing D_TYPE(a + b) to data_d.
// When compiled with ADD_RMS and p.param3 != 0, it also reduces the sum of
// squares of the written values across the whole workgroup. It stores that sum
// into partial_sums, one entry per workgroup.
// NOTE(review): the partial sums are presumably consumed by a later RMS-norm
// pass; confirm against the host-side dispatch.
void main() {
    uint idx = get_idx();
    uint orig_idx = idx;

    // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
    const uint num_iter = 2;

    // Accumulated even without ADD_RMS; the compiler drops it when unused.
    FLOAT_TYPE sum_sq = 0;

    [[unroll]] for (uint i = 0; i < num_iter; ++i) {
        // idx only grows, so once past p.ne every later iteration is skipped too.
        if (idx >= p.ne) {
            continue;
        }
        uint i00, i01, i02, i03;
        get_indices(idx, i00, i01, i02, i03);

        FLOAT_TYPE sum = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]);
        sum_sq += sum*sum;

        data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(sum);

        idx += num_threads;
    }

#if ADD_RMS
    // p.param3 comes from push constants, so this branch is uniform and every
    // invocation reaches the barrier() calls below.
    if (p.param3 != 0) {
        // Reduce within each subgroup first; lane 0 then publishes its total.
        sum_sq = subgroupAdd(sum_sq);
        if (gl_SubgroupInvocationID == 0) {
            sumsh[gl_SubgroupID] = sum_sq;
        }
        barrier();

        // Tree-reduce the per-subgroup totals held in shared memory.
        // - Use the built-in gl_NumSubgroups instead of num_threads / gl_SubgroupSize,
        //   which assumes every subgroup is full and exactly gl_SubgroupSize wide.
        // - Round the split point up, so a non-power-of-two count does not drop an
        //   entry. For power-of-two counts this matches a plain halving tree.
        // - Each step writes only sumsh[0, live_count - step_size) and reads only
        //   sumsh[step_size, live_count); those ranges never overlap.
        // - The trip count depends only on gl_NumSubgroups, which is uniform across
        //   the workgroup.
        for (uint live_count = gl_NumSubgroups; live_count > 1; ) {
            uint step_size = (live_count + 1) / 2;
            if (gl_SubgroupInvocationID == 0 && gl_SubgroupID + step_size < live_count) {
                sum_sq += sumsh[gl_SubgroupID + step_size];
                sumsh[gl_SubgroupID] = sum_sq;
            }
            barrier();
            live_count = step_size;
        }

        // Lane 0 of subgroup 0 now holds the workgroup total.
        if (gl_SubgroupID == 0 && gl_SubgroupInvocationID == 0) {
            partial_sums[orig_idx / (num_iter * num_threads)] = sum_sq;
        }
    }
#endif
}