#version 450

#include "types.glsl"
#include "generic_binary_head.glsl"

// Specialization constant selecting the operation performed on elements that
// fall inside the src1 region:
// false for SET, true for ACC
layout(constant_id = 1) const bool ACC = true;

// One-dimensional workgroup of 512 invocations.
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

// Each invocation produces exactly one destination element, `idx`.
//
// The flat position relative to `p.param3` is decomposed into four
// coordinates (i0..i3) using p.nb01 / p.nb02 / p.nb03 as divisors
// (presumably the element strides of the target region -- confirm against
// the host code that fills the push constants). If the coordinates lie within
// the src1 extents (p.ne10..p.ne13), the output is either src0 + src1 (ACC)
// or src1 alone (SET); otherwise src0 is copied through unchanged.
//
// `p`, `data_a`, `data_b`, `data_d`, `get_aoffset`, `get_boffset`,
// `get_doffset` and `src1_idx` are not declared here; they are expected to
// come from the headers included above.
void main() {
    const uint idx = gl_GlobalInvocationID.x;
    if (idx >= p.ne) {
        // Surplus invocation past the end of the destination.
        return;
    }

    const uint offset = p.param3;

    // NOTE: for idx < offset this subtraction wraps around (uint arithmetic).
    // The wrapped coordinates are never used, because `in_region` below
    // rejects that case explicitly.
    const uint src1_i = idx - offset;

    // Peel off the coordinates from the outermost dimension inwards.
    const uint i3   = src1_i / p.nb03;
    const uint rem2 = src1_i - i3 * p.nb03;
    const uint i2   = rem2 / p.nb02;
    const uint rem1 = rem2 - i2 * p.nb02;
    const uint i1   = rem1 / p.nb01;
    const uint i0   = rem1 % p.nb01;

    // Elements located before `offset` can never belong to the region, so
    // test that first instead of relying on the wrapped index happening to
    // fail the extent checks.
    const bool in_region = idx >= offset &&
                           i0 < p.ne10 && i1 < p.ne11 &&
                           i2 < p.ne12 && i3 < p.ne13;

    if (in_region) {
        if (ACC) {
            // Accumulate: dst = src0 + src1
            data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i0, i1, i2, i3)]));
        } else {
            // Set: dst = src1
            data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_b[get_boffset() + src1_idx(i0, i1, i2, i3)]));
        }
    } else {
        // Outside the region: pass src0 through unchanged.
        data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]));
    }
}