|
|
#version 450 |
|
|
|
|
|
#extension GL_EXT_control_flow_attributes : enable |
|
|
#ifdef COOPMAT2 |
|
|
#extension GL_NV_cooperative_matrix2 : enable |
|
|
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require |
|
|
#extension GL_KHR_memory_scope_semantics : enable |
|
|
#endif |
|
|
|
|
|
#ifdef USE_COLLECTIVES |
|
|
# extension GL_KHR_shader_subgroup_shuffle : enable |
|
|
#endif |
|
|
|
|
|
#include "types.glsl" |
|
|
|
|
|
|
|
|
// Binding 0: kernel data, read-only. Indexed in main() with the nb01/nb02/nb03 strides.
layout(binding = 0) readonly buffer A { A_TYPE knl_data[]; };

// Binding 1: source (input) data, read-only. Indexed in main() with the nb11/nb12/nb13 strides.
layout(binding = 1) readonly buffer B { B_TYPE src_data[]; };

// Binding 2: destination data, write-only. Indexed with the nb1/nb2/nb3 strides.
layout(binding = 2) writeonly buffer D { D_TYPE dst_data[]; };
|
|
|
|
|
// Push constants describing one convolution launch.
// Member order and types define the push-constant layout and must not change.
layout(push_constant) uniform parameter {
    // Problem extents: K = Cout, CRS = Cin*KH*KW, NPQ = N*OH*OW (see the globals below).
    uint32_t Cout;
    uint32_t Cin;
    uint32_t N;

    // Kernel (KW, KH), source (W, H) and destination (OW, OH) spatial extents.
    uint32_t KW;
    uint32_t KH;
    uint32_t W;
    uint32_t H;
    uint32_t OW;
    uint32_t OH;

    // Stride (s), padding (p) and dilation (d); suffix 0 is applied to the
    // W axis and suffix 1 to the H axis when main() maps output to input.
    uint32_t s0;
    uint32_t s1;
    uint32_t p0;
    uint32_t p1;
    uint32_t d0;
    uint32_t d1;

    // Index multipliers for knl_data (KH, then two outer dimensions).
    uint32_t nb01;
    uint32_t nb02;
    uint32_t nb03;

    // Index multipliers for src_data (H, Cin, N).
    uint32_t nb11;
    uint32_t nb12;
    uint32_t nb13;

    // Index multipliers for dst_data (OH, K, N).
    uint32_t nb1;
    uint32_t nb2;
    uint32_t nb3;

    // (mp, L) pairs consumed by fastdiv() for the divisors
    // KW, KW*KH, OW and OW*OH respectively.
    uint32_t KWmp;
    uint32_t KWL;
    uint32_t KWKHmp;
    uint32_t KWKHL;
    uint32_t OWmp;
    uint32_t OWL;
    uint32_t OWOHmp;
    uint32_t OWOHL;
#ifdef TRANSPOSE
    // fastdiv() pairs for the strides s0 and s1 (transposed variant only).
    uint32_t s0mp;
    uint32_t s0L;
    uint32_t s1mp;
    uint32_t s1L;
#endif
} p;
|
|
|
|
|
// Workgroup is one-dimensional; its x size comes from specialization constant 0.
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

// Block tile extents along the K, CRS and NPQ dimensions.
layout(constant_id = 1) const uint BS_K   = 128;
layout(constant_id = 2) const uint BS_CRS = 16;
layout(constant_id = 3) const uint BS_NPQ = 128;

// Per-invocation tile extent along K.
layout(constant_id = 4) const uint TS_K = 8;
// When 1 (and USE_COLLECTIVES is defined) CRS index decomposition is shared via subgroupShuffle.
layout(constant_id = 5) const uint use_collectives = 1;
// Extra elements appended to each row of the shared tiles (see Ash_stride / Bsh_stride).
layout(constant_id = 6) const uint SHMEM_PAD = 4;

// Linear invocation index inside the workgroup, and the workgroup size.
uint32_t tid = gl_LocalInvocationID.x;
const uint32_t WG_SIZE = gl_WorkGroupSize.x;
|
|
|
|
|
// Number of blocks of `block_size` needed to cover `work_size` items
// (integer division rounded up). `block_size` must be non-zero.
uint splitWork(uint work_size, uint block_size) {
    const uint padded = work_size + block_size - 1;
    return padded / block_size;
}
|
|
|
|
|
// Matrix-multiply view of the convolution:
//   K   - rows of the result (one per output channel),
//   CRS - reduction length (Cin * KH * KW),
//   NPQ - columns of the result (N * OH * OW).
uint32_t K   = p.Cout;
uint32_t CRS = p.Cin * p.KH * p.KW;
uint32_t NPQ = p.N * p.OH * p.OW;

// Total number of result elements.
uint32_t n_elems_out = K * NPQ;

// Number of BS_CRS-wide steps needed to walk the reduction dimension.
uint32_t NB_CRS = splitWork(CRS, BS_CRS);
|
|
|
|
|
// Element type of the shared tiles: half precision on the cooperative-matrix
// path, single precision otherwise.
#if defined(COOPMAT2)
#define SHMEM_TYPE float16_t
#else
#define SHMEM_TYPE float
#endif

// Row pitch of each shared tile, padded by SHMEM_PAD elements.
const uint32_t Ash_stride = BS_CRS + SHMEM_PAD;
const uint32_t Bsh_stride = BS_NPQ + SHMEM_PAD;

// Payload element counts of the tiles (padding excluded).
const uint32_t Ash_numel = BS_K * BS_CRS;
const uint32_t Bsh_numel = BS_CRS * BS_NPQ;

// Allocated element counts of the tiles (padding included).
const uint32_t Ash_len = BS_K * Ash_stride;
const uint32_t Bsh_len = BS_CRS * Bsh_stride;

// Ash holds a BS_K x BS_CRS tile of A, Bsh a BS_CRS x BS_NPQ tile of B.
shared SHMEM_TYPE Ash[Ash_len];
shared SHMEM_TYPE Bsh[Bsh_len];
|
|
|
|
|
|
|
|
// Per-invocation tile extent along NPQ, chosen so that WG_SIZE invocations
// with TS_K x TS_NPQ tiles cover the BS_K x BS_NPQ block tile.
const uint32_t TS_NPQ = BS_K * BS_NPQ / WG_SIZE / TS_K;

// Number of invocation tiles per block tile along K and along NPQ.
const uint32_t NT_K   = BS_K / TS_K;
const uint32_t NT_NPQ = BS_NPQ / TS_NPQ;

// Block tile handled by this workgroup: x selects the K block, y the NPQ block.
uint32_t B_idx_K   = gl_WorkGroupID.x;
uint32_t B_idx_NPQ = gl_WorkGroupID.y;

// Position of this invocation's tile inside the block tile (row along K, column along NPQ).
uint32_t T_y = tid / NT_NPQ;
uint32_t T_x = tid % NT_NPQ;

// Row/column this invocation fills when loading Ash; the workgroup covers
// ArpWg rows of Ash per pass of the load loop in main().
uint32_t Ar = tid / BS_CRS;
uint32_t Ac = tid % BS_CRS;
const uint32_t ArpWg = WG_SIZE / BS_CRS;

// Row/column this invocation fills when loading Bsh; the workgroup covers
// BrpWg rows of Bsh per pass of the load loop in main().
uint32_t Br = tid / BS_NPQ;
uint32_t Bc = tid % BS_NPQ;
const uint32_t BrpWg = WG_SIZE / BS_NPQ;
|
|
|
|
|
|
|
|
// Multiply-and-shift replacement for an integer division.
// Returns (high32(n * mp) + n) >> L. The (mp, L) pair comes from the push
// constants; callers in this file use the result as n / d for the divisor d
// the pair was derived from (presumably precomputed on the host - confirm there).
uint fastdiv(uint n, uint mp, uint L) {
    uint hi;
    uint lo;  // low 32 bits of the product; not needed for the result
    umulExtended(n, mp, hi, lo);
    return (n + hi) >> L;
}
|
|
|
|
|
#ifdef COOPMAT2 |
|
|
#define ACC_TYPE float16_t |
|
|
|
|
|
// Per-element callback handed to coopMatPerElementNV in main(): writes one
// accumulator element to dst_data.
//   r, c  - row / column of the element inside the BS_K x BS_NPQ block tile
//   elem  - accumulator value
// Returns elem unchanged. Elements outside the K x NPQ result are not stored.
ACC_TYPE perElemOpStore(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem)
{
    const uint32_t row = B_idx_K * BS_K + r;
    const uint32_t col = B_idx_NPQ * BS_NPQ + c;

    // Split the flat NPQ column into (N, OH, OW).
    const uint32_t n_i    = fastdiv(col, p.OWOHmp, p.OWOHL);
    const uint32_t in_img = col - n_i * p.OH * p.OW;
    const uint32_t oh_i   = fastdiv(in_img, p.OWmp, p.OWL);
    const uint32_t ow_i   = in_img - oh_i * p.OW;

    const uint32_t out_pos = ow_i + oh_i * p.nb1 + row * p.nb2 + n_i * p.nb3;

    // Skip the padding region of partially filled edge tiles.
    if (row >= K || col >= NPQ) {
        return elem;
    }
    dst_data[out_pos] = D_TYPE(elem);
    return elem;
}
|
|
#endif |
|
|
|
|
|
// Entry point. Each workgroup produces one BS_K x BS_NPQ tile of the K x NPQ
// result by walking the CRS (reduction) dimension in NB_CRS steps of BS_CRS:
//   1. load a BS_K x BS_CRS tile of the kernel into Ash,
//   2. load a BS_CRS x BS_NPQ tile of the (implicitly unfolded) input into Bsh,
//   3. multiply-accumulate the two tiles,
// and finally scatters the accumulated tile to dst_data.
void main() {
#ifdef COOPMAT2
    coopmat<ACC_TYPE, gl_ScopeWorkgroup, BS_K, BS_NPQ, gl_MatrixUseAccumulator> matC;
    matC = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BS_K, BS_NPQ, gl_MatrixUseAccumulator>(0.0);
#else
    // Per-invocation accumulator tile.
    float regC[TS_K][TS_NPQ];
    for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
        for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
            regC[T_ly][T_lx] = 0.0;
        }
    }
#endif

    for (uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++) {
        // Decompose this invocation's CRS column of A into (Cin, KH, KW).
        uint32_t CRS_idx_a;
        uint32_t Cin_idx_a;
        uint32_t KH_idx_a;
        uint32_t KW_idx_a;

#ifdef USE_COLLECTIVES
        uint32_t cached_CRS_idx;
        uint32_t cached_Cin_idx;
        uint32_t cached_KH_idx;
        uint32_t cached_KW_idx;
        if (use_collectives == 1) {
            // Each subgroup lane decomposes one CRS index; the other lanes
            // fetch the decomposition they need with subgroupShuffle.
            cached_CRS_idx = B_idx_CRS * BS_CRS + gl_SubgroupInvocationID;
            cached_Cin_idx = fastdiv(cached_CRS_idx, p.KWKHmp, p.KWKHL);
            uint32_t cached_CRS_remainder = (cached_CRS_idx - cached_Cin_idx * p.KW * p.KH);
            cached_KH_idx = fastdiv(cached_CRS_remainder, p.KWmp, p.KWL);
            cached_KW_idx = cached_CRS_remainder - cached_KH_idx * p.KW;

            CRS_idx_a = subgroupShuffle(cached_CRS_idx, Ac);
            Cin_idx_a = subgroupShuffle(cached_Cin_idx, Ac);
            KH_idx_a  = subgroupShuffle(cached_KH_idx, Ac);
            KW_idx_a  = subgroupShuffle(cached_KW_idx, Ac);
        } else {
            CRS_idx_a = B_idx_CRS * BS_CRS + Ac;
            Cin_idx_a = fastdiv(CRS_idx_a, p.KWKHmp, p.KWKHL);
            uint32_t CRS_remainder_a = CRS_idx_a - Cin_idx_a * p.KW * p.KH;
            KH_idx_a = fastdiv(CRS_remainder_a, p.KWmp, p.KWL);
            KW_idx_a = CRS_remainder_a - KH_idx_a * p.KW;
        }
#else
        CRS_idx_a = B_idx_CRS * BS_CRS + Ac;
        Cin_idx_a = fastdiv(CRS_idx_a, p.KWKHmp, p.KWKHL);
        // The remainder must be declared here: nothing else declares it when
        // USE_COLLECTIVES is undefined.
        uint32_t CRS_remainder_a = CRS_idx_a - Cin_idx_a * p.KW * p.KH;
        KH_idx_a = fastdiv(CRS_remainder_a, p.KWmp, p.KWL);
        KW_idx_a = CRS_remainder_a - KH_idx_a * p.KW;
#endif

        // Load the A (kernel) tile: the workgroup fills ArpWg rows of Ash per pass.
        for (uint32_t r_offset = 0; r_offset < BS_K; r_offset += ArpWg) {
            uint32_t B_ly = r_offset + Ar;
            uint32_t B_lx = Ac;
            uint32_t K_idx = B_idx_K * BS_K + B_ly;
            // The index is clamped so the read stays inside the buffer even for
            // out-of-range (K_idx, CRS_idx_a); such values are zeroed below.
#ifdef TRANSPOSE
            uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + K_idx * p.nb02 + Cin_idx_a * p.nb03, K * CRS - 1);
#else
            uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + Cin_idx_a * p.nb02 + K_idx * p.nb03, K * CRS - 1);
#endif
            float val = knl_data[knl_idx];
            if (K_idx >= K || CRS_idx_a >= CRS) {
                val = 0.0;
            }
            Ash[B_ly * Ash_stride + B_lx] = SHMEM_TYPE(val);
        }

        // Load the B (input) tile: the workgroup fills BrpWg rows of Bsh per pass.
        UNROLL for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) {
            uint32_t B_ly = r_offset + Br;
            uint32_t B_lx = Bc;

            // Split the flat NPQ column into (N, OH, OW).
            uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + B_lx;
            uint32_t N_idx = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL);
            uint32_t NPQ_remainder = NPQ_idx - N_idx * p.OH * p.OW;
            uint32_t OH_idx = fastdiv(NPQ_remainder, p.OWmp, p.OWL);
            uint32_t OW_idx = NPQ_remainder - OH_idx * p.OW;

            // Split the CRS row of B into (Cin, KH, KW).
            uint32_t CRS_idx_b;
            uint32_t Cin_idx_b;
            uint32_t KH_idx_b;
            uint32_t KW_idx_b;
#ifdef USE_COLLECTIVES
            if (use_collectives == 1) {
                CRS_idx_b = subgroupShuffle(cached_CRS_idx, r_offset + Br);
                Cin_idx_b = subgroupShuffle(cached_Cin_idx, r_offset + Br);
                KH_idx_b  = subgroupShuffle(cached_KH_idx, r_offset + Br);
                KW_idx_b  = subgroupShuffle(cached_KW_idx, r_offset + Br);
            } else {
                CRS_idx_b = B_idx_CRS * BS_CRS + B_ly;
                Cin_idx_b = fastdiv(CRS_idx_b, p.KWKHmp, p.KWKHL);
                uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH;
                KH_idx_b = fastdiv(CRS_remainder, p.KWmp, p.KWL);
                KW_idx_b = CRS_remainder - KH_idx_b * p.KW;
            }
#else
            CRS_idx_b = B_idx_CRS * BS_CRS + B_ly;
            Cin_idx_b = fastdiv(CRS_idx_b, p.KWKHmp, p.KWKHL);
            uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH;
            KH_idx_b = fastdiv(CRS_remainder, p.KWmp, p.KWL);
            KW_idx_b = CRS_remainder - KH_idx_b * p.KW;
#endif

            // Map the output position and kernel tap to an input position.
            // The arithmetic is unsigned, so a position left of / above the
            // input wraps to a large value and is rejected by the
            // `>= p.H` / `>= p.W` checks below.
#ifdef TRANSPOSE
            uint32_t H_idx_x_s1 = OH_idx - KH_idx_b * p.d1 + p.p1;
            uint32_t W_idx_x_s0 = OW_idx - KW_idx_b * p.d0 + p.p0;
            uint32_t H_idx = fastdiv(H_idx_x_s1, p.s1mp, p.s1L);
            uint32_t W_idx = fastdiv(W_idx_x_s0, p.s0mp, p.s0L);
#else
            uint32_t H_idx = OH_idx * p.s1 + KH_idx_b * p.d1 - p.p1;
            uint32_t W_idx = OW_idx * p.s0 + KW_idx_b * p.d0 - p.p0;
#endif
            // Clamp so the read stays inside the buffer; invalid positions are zeroed below.
            uint32_t src_idx =
                min(max(W_idx + H_idx * p.nb11 + Cin_idx_b * p.nb12 + N_idx * p.nb13, 0), p.Cin * p.N * p.W * p.H - 1);
            float val = src_data[src_idx];
            if (CRS_idx_b >= CRS || NPQ_idx >= NPQ
                || H_idx >= p.H || W_idx >= p.W
#ifdef TRANSPOSE
                // Only positions that are exact multiples of the stride contribute.
                || (H_idx_x_s1 - H_idx * p.s1 != 0) || (W_idx_x_s0 - W_idx * p.s0 != 0)
#endif
                ) {
                val = 0.0;
            }
            Bsh[B_ly * Bsh_stride + B_lx] = SHMEM_TYPE(val);
        }

        // Both tiles must be completely written before any invocation reads them.
        barrier();

#ifdef COOPMAT2
        coopmat<float16_t, gl_ScopeWorkgroup, BS_K, BS_CRS, gl_MatrixUseA> matA;
        coopmat<float16_t, gl_ScopeWorkgroup, BS_CRS, BS_NPQ, gl_MatrixUseB> matB;

        coopMatLoad(matA, Ash, 0, Ash_stride, gl_CooperativeMatrixLayoutRowMajor);
        coopMatLoad(matB, Bsh, 0, Bsh_stride, gl_CooperativeMatrixLayoutRowMajor);
        matC = coopMatMulAdd(matA, matB, matC);
#else
        if (T_y * TS_K < K) {
            UNROLL for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) {
                // Stage one column of the A tile and one row of the B tile,
                // then accumulate their outer product.
                float regA[TS_K];
                float regB[TS_NPQ];
                for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
                    regA[T_ly] = Ash[(T_y * TS_K + T_ly) * Ash_stride + CRS_lidx];
                }
                for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
                    regB[T_lx] = Bsh[CRS_lidx * Bsh_stride + T_x * TS_NPQ + T_lx];
                }
                for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
                    for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
                        regC[T_ly][T_lx] = fma(regA[T_ly], regB[T_lx], regC[T_ly][T_lx]);
                    }
                }
            }
        }
#endif
        // Do not overwrite the tiles for the next step while they are still being read.
        barrier();
    }

    // Write the accumulated tile to dst_data.
#ifdef COOPMAT2
    coopMatPerElementNV(matC, matC, perElemOpStore);
#else
    if (T_y * TS_K < K) {
        for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
            for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
                uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly;
                uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx;
                uint32_t N_idx = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL);
                uint32_t OH_idx = fastdiv(NPQ_idx - N_idx * p.OH * p.OW, p.OWmp, p.OWL);
                uint32_t OW_idx = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW;
                uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3;
                // Skip the padding region of partially filled edge tiles.
                if (K_idx < K && NPQ_idx < NPQ) {
                    dst_data[dst_idx] = regC[T_ly][T_lx];
                }
            }
        }
    }
#endif
}
|
|
|