#version 450

#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require

#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require

#extension GL_KHR_shader_subgroup_shuffle : enable

#include "types.glsl"
#include "flash_attn_base.glsl"

// Per-thread slices: D_split threads cooperate on each column, each covering
// HSK/D_split of the K head dimension and HSV/D_split of the V head dimension.
const uint32_t HSK_per_thread = HSK / D_split;
const uint32_t HSV_per_thread = HSV / D_split;

// One pass of the workgroup covers cols_per_iter columns of the Bc-wide
// KV block, so every thread owns cols_per_thread columns in total.
const uint32_t cols_per_iter = WorkGroupSize / D_split;
const uint32_t cols_per_thread = Bc / cols_per_iter;
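
// Illustrative values (not the actual pipeline constants): with
// WorkGroupSize = 128, D_split = 4, Bc = 32 and HSK = HSV = 128, each thread
// covers HSK_per_thread = 32 K elements, the workgroup advances
// cols_per_iter = 32 columns per pass, and cols_per_thread = 1.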

// Each tensor is bound twice: a scalar view and a vec4/f16vec4 alias of the
// same binding, so the inner loops can do vectorized 4-wide loads.
layout (binding = 0) readonly buffer Q {float data_q[];};
layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];};
layout (binding = 1) readonly buffer K {float16_t data_k[];};
layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];};
layout (binding = 2) readonly buffer V {float16_t data_v[];};
layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
layout (binding = 3) readonly buffer M {float16_t data_m[];};

// Store the output when doing grouped query attention.
// Rows index by Q's dim 2 (the heads sharing this KV), and the first N rows are valid.
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
    uint32_t offset = (iq2 + r) * HSV + c;
    data_o[o_offset + offset] = D_TYPE(elem);
    return elem;
}

// Scratch for cross-thread reductions: tmpsh for max/sum statistics,
// tmpshv4 for the vec4 O accumulators.
shared FLOAT_TYPE tmpsh[WorkGroupSize];
shared vec4 tmpshv4[WorkGroupSize];

shared float masksh[Bc][Br];
shared vec4 Qf[Br][HSK / 4];

void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
    init_iq_shmem(gl_WorkGroupSize);
#endif

    init_indices();

    // d_tid selects this thread's slice of the head dimension; col_tid
    // selects which column of the KV block it works on.
    const uint32_t tid = gl_LocalInvocationIndex;
    const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
    const uint32_t col_tid = gl_LocalInvocationIndex / D_split;

    uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;

    // Cooperatively load the Br x HSK tile of Q into shared memory,
    // pre-multiplied by the softmax scale so the dot products below
    // directly produce scaled scores.
    [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
        uint32_t d = (idx + tid) % (HSK / 4);
        uint32_t r = (idx + tid) / (HSK / 4);
        if (r < Br && d < HSK / 4 &&
            i * Br + r < N) {
            Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d]) * p.scale;
        }
    }
    barrier();

    vec4 Of[Br][HSV_per_thread / 4];
    [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
            Of[r][d] = vec4(0.0);
        }
    }

    // Running softmax state per row: L = sum of exponentials, M = running max.
    float Lf[Br], Mf[Br];

    // Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold - M.
    const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);

    [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
        Lf[r] = 0;
        Mf[r] = NEG_FLT_MAX_OVER_2;
    }

    float slope[Br];
    [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
        slope[r] = 1.0;
    }

    // ALiBi: per-row position-bias slopes derived from the head index.
    if (p.max_bias > 0.0f) {
        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
            slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2);
        }
    }

#if BLOCK_SIZE > 1
    uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE;
    uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE;
#else
    uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
    uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
#endif
    uint32_t m_offset = 0;
    if (p.nem2 != 1 || p.nem3 != 1) {
        m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
    }
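
    // Main loop over KV blocks. Each iteration follows the standard
    // online-softmax recurrence (per row r):
    //   M_new = max(M_old, rowmax(S))
    //   P     = e^(S - M_new)
    //   L_new = e^(M_old - M_new) * L_old + rowsum(P)
    //   O_new = e^(M_old - M_new) * O_old + P * V
    // so O stays an unnormalized accumulator until the final divide by L.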

    [[dont_unroll]]
    for (uint32_t j = start_j; j < end_j; ++j) {

        float Sf[Br][cols_per_thread];
        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
            [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
                Sf[r][c] = 0.0;
            }
        }

        // S = Q * K^T: each thread accumulates partial dot products over its
        // slice of the head dimension, dequantizing K on the fly when the
        // source type is block-quantized (BLOCK_SIZE > 1).
        [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
            if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
                continue;
            }
            [[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
#if BLOCK_SIZE > 1
                uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
                uint ib = coord / BLOCK_SIZE;
                uint iqs = (coord % BLOCK_SIZE);
                vec4 K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
#else
                vec4 K_Tf = vec4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
#endif
                [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
                    Sf[r][c] += dot(Qf[r][d * D_split + d_tid], K_Tf);
                }
            }
        }

        // Reduce the partial dot products across the D_split threads that
        // share each column.
        [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
            [[unroll]] for (uint s = D_split / 2; s > 0; s >>= 1) {
                [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
                    Sf[r][c] += subgroupShuffleXor(Sf[r][c], s);
                }
            }
        }
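
        // The XOR butterfly halves the stride each step; e.g. with D_split = 4
        // (an illustrative value), lanes exchange at distance 2 and then 1,
        // after which all four lanes sharing a column hold the full dot product.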

        // Optional logit softcap: squash scores into (-softcap, softcap).
        if (p.logit_softcap != 0.0f) {
            [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
                [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
                    Sf[r][c] = p.logit_softcap * tanh(Sf[r][c]);
                }
            }
        }

        if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
            bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;

            // Cooperatively stage the mask tile in shared memory,
            // zero-filling out-of-bounds entries.
            [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
                uint32_t c = (idx + tid) % Bc;
                uint32_t r = (idx + tid) / Bc;
                if (idx + tid < Bc * Br) {
                    if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
                        masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
                    } else {
                        masksh[c][r] = float(0);
                    }
                }
            }
            barrier();

            // Apply the mask, scaled by the ALiBi slope (1.0 when ALiBi is off).
            [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
                [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
                    float mvf = masksh[c * cols_per_iter + col_tid][r];

                    Sf[r][c] += slope[r]*mvf;
                }
            }
            barrier();
        }

        float rowmaxf[Br], Pf[Br][cols_per_thread], rowsumf[Br], eMf[Br], Moldf[Br];
        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
            rowmaxf[r] = NEG_FLT_MAX_OVER_2;
            [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
                if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
                    continue;
                }
                rowmaxf[r] = max(rowmaxf[r], Sf[r][c]);
            }
            Moldf[r] = Mf[r];

            // M = max(rowmax, Mold)
            // P = e^(S - M)
            // eM = e^(Mold - M)
            Mf[r] = max(rowmaxf[r], Moldf[r]);
            [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
                Pf[r][c] = exp(Sf[r][c] - Mf[r]);
            }
            eMf[r] = exp(Moldf[r] - Mf[r]);

            // Compute the sum across the row of P.
            rowsumf[r] = 0.0;
            [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
                if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
                    continue;
                }
                rowsumf[r] += Pf[r][c];
            }

            Lf[r] = eMf[r]*Lf[r] + rowsumf[r];
        }

        // Rescale the existing O accumulator by e^(Mold - M).
        [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
            [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
                Of[r][d] = eMf[r] * Of[r][d];
            }
        }

        // O += P * V, dequantizing V on the fly when it is block-quantized.
        [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
            if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
                continue;
            }
            [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
#if BLOCK_SIZE > 1
                uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
                uint ib = coord / BLOCK_SIZE;
                uint iqs = (coord % BLOCK_SIZE);
                vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
#else
                vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
#endif
                [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
                    Of[r][d] += Pf[r][c] * Vf;
                }
            }
        }

        barrier();
    }
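
    // At this point every thread holds per-column partial results: Mf/Lf are
    // per-thread softmax statistics and Of is an unnormalized accumulator.
    // They still have to be merged across the column groups of the workgroup
    // before the final divide by L.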

    // Reduce Mf, Lf and Of across the threads of the workgroup.
    [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
        float rowmaxf, eMf;

        tmpsh[tid] = Mf[r];
        // Compute the max across the row.
        barrier();
        [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
            if (tid < s) {
                tmpsh[tid] = max(tmpsh[tid], tmpsh[tid + s]);
            }
            barrier();
        }
        rowmaxf = tmpsh[d_tid];
        barrier();

        float Moldf = Mf[r];

        // M = max(rowmax, Mold)
        // eM = e^(Mold - M)
        Mf[r] = max(rowmaxf, Moldf);
        eMf = exp(Moldf - Mf[r]);

        Lf[r] = eMf*Lf[r];

        tmpsh[tid] = Lf[r];

        // Compute the sum across the row.
        barrier();
        [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
            if (tid < s) {
                tmpsh[tid] = tmpsh[tid] + tmpsh[tid + s];
            }
            barrier();
        }
        Lf[r] = tmpsh[d_tid];
        barrier();

        // Rescale and sum the O accumulators the same way.
        [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {

            Of[r][d] = eMf * Of[r][d];
            tmpshv4[tid] = Of[r][d];

            barrier();
            [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
                if (tid < s) {
                    Of[r][d] += tmpshv4[tid + s];
                    tmpshv4[tid] = Of[r][d];
                }
                barrier();
            }
            Of[r][d] = tmpshv4[d_tid];
            barrier();
        }
    }
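
    // The tree reductions above stop at stride D_split rather than 1 so the
    // D_split head-dimension slices are never mixed: slot t (t < D_split)
    // accumulates only threads with d_tid == t, and each thread reads back
    // tmpsh[d_tid] / tmpshv4[d_tid] to keep its own slice of O.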

    // If there is split_k, then the split_k resolve shader does the final
    // division by L. Store the intermediate O values and per-row L and M values.
    if (p.k_num > 1) {
        uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);

        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
            if (r < N) {
                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
                    [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
                        perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
                    }
                }
            }
        }

        o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
            if (r < N) {
                perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
                perElemOpStoreCol0(r, 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
            }
        }

        return;
    }
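
    // Scratch layout for split_k (as implied by the offsets above): all k_num
    // partial O tiles come first, followed by pairs of rows of length p.ne1
    // holding L then M for each split, which the resolve pass uses to
    // renormalize and combine the partials.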

    // Attention sinks: fold a per-head sink logit into the softmax as one
    // extra virtual column that contributes to L but not to O.
    if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
            float sink = perElemOpGetSink(r, 0u, ACC_TYPE(0), iq2);

            float ms = 1.0f;
            float vs = 1.0f;

            if (sink > Mf[r]) {
                // The sink becomes the new running max: rescale O and L.
                ms = exp(Mf[r] - sink);

                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
                    Of[r][d] *= ms;
                }
            } else {
                vs = exp(sink - Mf[r]);
            }

            Lf[r] = Lf[r]*ms + vs;
        }
    }
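
    // Net effect on the denominator, in either branch:
    //   L = sum_c e^(S_c - M') + e^(sink - M'),  where M' = max(M, sink)
    // The sink adds probability mass without adding anything to O, which
    // uniformly shrinks the attention weights for the row.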

    // Final normalization: divide O by the softmax denominator L.
    float Lfrcp[Br];
    [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
        Lfrcp[r] = 1.0 / Lf[r];
    }

    [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
            Of[r][d] *= Lfrcp[r];
#if defined(ACC_TYPE_MAX)
            // Clamp to the representable range of the accumulator type.
            Of[r][d] = clamp(Of[r][d], -vec4(ACC_TYPE_MAX), vec4(ACC_TYPE_MAX));
#endif
        }
    }

    uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;

    // Store O: the GQA path indexes rows by head via perElemOpGqaStore,
    // the regular path indexes rows by token position.
    if (p.gqa_ratio > 1) {
        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
            if (r < N) {
                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
                    [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
                        perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
                    }
                }
            }
        }
    } else {
        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
            if (i * Br + r < N) {
                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
                    [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
                        data_o[o_offset + iq2 * HSV + (i * Br + r) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
                    }
                }
            }
        }
    }
}