| #include "common.cuh" |
| #include "fattn-common.cuh" |
| #include "fattn-mma-f16.cuh" |
| #include "fattn-tile-f16.cuh" |
| #include "fattn-tile-f32.cuh" |
| #include "fattn-vec-f16.cuh" |
| #include "fattn-vec-f32.cuh" |
| #include "fattn-wmma-f16.cuh" |
| #include "fattn.cuh" |
|
|
// Launches the tensor-core (MMA) FP16 FlashAttention kernel for a fixed number of
// Q columns per CUDA block (cols_per_block). The second template argument is picked
// at runtime from Q->ne[0] (the head size, going by the `_hs` suffix of this function).
// Supported values: 64, 80, 96, 112, 128 and 256; any other value aborts.
template <int cols_per_block>
static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const int64_t head_size = dst->src[0]->ne[0];

    switch (head_size) {
        case  64: ggml_cuda_flash_attn_ext_mma_f16_case< 64, cols_per_block>(ctx, dst); return;
        case  80: ggml_cuda_flash_attn_ext_mma_f16_case< 80, cols_per_block>(ctx, dst); return;
        case  96: ggml_cuda_flash_attn_ext_mma_f16_case< 96, cols_per_block>(ctx, dst); return;
        case 112: ggml_cuda_flash_attn_ext_mma_f16_case<112, cols_per_block>(ctx, dst); return;
        case 128: ggml_cuda_flash_attn_ext_mma_f16_case<128, cols_per_block>(ctx, dst); return;
        case 256: ggml_cuda_flash_attn_ext_mma_f16_case<256, cols_per_block>(ctx, dst); return;
        default:  break;
    }

    // No kernel instantiation exists for this head size.
    GGML_ABORT("fatal error");
}
|
|
// Chooses how many Q columns each CUDA block processes in the MMA FP16 kernel.
// It uses the smallest of 8/16/32 that covers Q->ne[1], otherwise 64, and then
// dispatches on head size.
// NOTE(review): Q->ne[1] is presumably the number of query tokens in the batch — confirm.
static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const int64_t n_q_cols = dst->src[0]->ne[1];

    if (n_q_cols <= 8) {
        ggml_cuda_flash_attn_ext_mma_f16_switch_hs<8>(ctx, dst);
    } else if (n_q_cols <= 16) {
        ggml_cuda_flash_attn_ext_mma_f16_switch_hs<16>(ctx, dst);
    } else if (n_q_cols <= 32) {
        ggml_cuda_flash_attn_ext_mma_f16_switch_hs<32>(ctx, dst);
    } else {
        ggml_cuda_flash_attn_ext_mma_f16_switch_hs<64>(ctx, dst);
    }
}
|
|
// Statement macro: if the op matches head size D and the given K/V tensor types, run the
// matching FP16 vec kernel instantiation and return from the enclosing function.
// Expects `Q`, `K`, `V`, `ctx` and `dst` to be in scope where it is expanded.
#define FATTN_VEC_F16_CASE(D, type_K, type_V)                                     \
    do {                                                                          \
        if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) {      \
            ggml_cuda_flash_attn_ext_vec_f16_case<D, type_K, type_V>(ctx, dst);   \
            return;                                                               \
        }                                                                         \
    } while (0)

// Tries the K types Q4_0, Q4_1, Q5_0, Q5_1, Q8_0 and F16 (in that order) against one V type.
#define FATTN_VEC_F16_CASES_ANY_K(D, type_V)               \
    do {                                                   \
        FATTN_VEC_F16_CASE(D, GGML_TYPE_Q4_0, type_V);     \
        FATTN_VEC_F16_CASE(D, GGML_TYPE_Q4_1, type_V);     \
        FATTN_VEC_F16_CASE(D, GGML_TYPE_Q5_0, type_V);     \
        FATTN_VEC_F16_CASE(D, GGML_TYPE_Q5_1, type_V);     \
        FATTN_VEC_F16_CASE(D, GGML_TYPE_Q8_0, type_V);     \
        FATTN_VEC_F16_CASE(D, GGML_TYPE_F16,  type_V);     \
    } while (0)

// FP16 vec FlashAttention: picks the template instantiation that matches Q->ne[0] and the
// K/V tensor types. The set of available combinations depends on GGML_CUDA_FA_ALL_QUANTS.
// Anything unmatched ends in on_no_fattn_vec_case(Q->ne[0]).
static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * Q = dst->src[0];
    const ggml_tensor * K = dst->src[1];
    const ggml_tensor * V = dst->src[2];

#ifdef GGML_CUDA_FA_ALL_QUANTS
    // Head size 64: F16 K with every V type.
    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0);
    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1);
    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0);
    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1);
    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0);
    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16 );

    // Head size 128: full cross product of K and V types, grouped by V type.
    FATTN_VEC_F16_CASES_ANY_K(128, GGML_TYPE_Q4_0);
    FATTN_VEC_F16_CASES_ANY_K(128, GGML_TYPE_Q4_1);
    FATTN_VEC_F16_CASES_ANY_K(128, GGML_TYPE_Q5_0);
    FATTN_VEC_F16_CASES_ANY_K(128, GGML_TYPE_Q5_1);
    FATTN_VEC_F16_CASES_ANY_K(128, GGML_TYPE_Q8_0);
    FATTN_VEC_F16_CASES_ANY_K(128, GGML_TYPE_F16);

    // Head size 256: F16 K/V only.
    FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16);
#else
    // Reduced set: matching Q4_0 or Q8_0 K/V at head size 128, plus F16 K/V at 64/128/256.
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);

    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16);
    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16);
    FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16);
#endif

    on_no_fattn_vec_case(Q->ne[0]);
}

#undef FATTN_VEC_F16_CASES_ANY_K
|
|
// Statement macro: if the op matches head size D and the given K/V tensor types, run the
// matching FP32 vec kernel instantiation and return from the enclosing function.
// Expects `Q`, `K`, `V`, `ctx` and `dst` to be in scope where it is expanded.
#define FATTN_VEC_F32_CASE(D, type_K, type_V)                                     \
    do {                                                                          \
        if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) {      \
            ggml_cuda_flash_attn_ext_vec_f32_case<D, type_K, type_V>(ctx, dst);   \
            return;                                                               \
        }                                                                         \
    } while (0)

// Tries the K types Q4_0, Q4_1, Q5_0, Q5_1, Q8_0 and F16 (in that order) against one V type.
#define FATTN_VEC_F32_CASES_ANY_K(D, type_V)               \
    do {                                                   \
        FATTN_VEC_F32_CASE(D, GGML_TYPE_Q4_0, type_V);     \
        FATTN_VEC_F32_CASE(D, GGML_TYPE_Q4_1, type_V);     \
        FATTN_VEC_F32_CASE(D, GGML_TYPE_Q5_0, type_V);     \
        FATTN_VEC_F32_CASE(D, GGML_TYPE_Q5_1, type_V);     \
        FATTN_VEC_F32_CASE(D, GGML_TYPE_Q8_0, type_V);     \
        FATTN_VEC_F32_CASE(D, GGML_TYPE_F16,  type_V);     \
    } while (0)

// FP32 vec FlashAttention: picks the template instantiation that matches Q->ne[0] and the
// K/V tensor types. The set of available combinations depends on GGML_CUDA_FA_ALL_QUANTS.
// Anything unmatched ends in on_no_fattn_vec_case(Q->ne[0]).
static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * Q = dst->src[0];
    const ggml_tensor * K = dst->src[1];
    const ggml_tensor * V = dst->src[2];

#ifdef GGML_CUDA_FA_ALL_QUANTS
    // Head size 64: F16 K with every V type.
    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0);
    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1);
    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0);
    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1);
    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0);
    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16);

    // Head size 128: full cross product of K and V types, grouped by V type.
    FATTN_VEC_F32_CASES_ANY_K(128, GGML_TYPE_Q4_0);
    FATTN_VEC_F32_CASES_ANY_K(128, GGML_TYPE_Q4_1);
    FATTN_VEC_F32_CASES_ANY_K(128, GGML_TYPE_Q5_0);
    FATTN_VEC_F32_CASES_ANY_K(128, GGML_TYPE_Q5_1);
    FATTN_VEC_F32_CASES_ANY_K(128, GGML_TYPE_Q8_0);
    FATTN_VEC_F32_CASES_ANY_K(128, GGML_TYPE_F16);

    // Head size 256: F16 K/V only.
    FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16);
#else
    // Reduced set: matching Q4_0 or Q8_0 K/V at head size 128, plus F16 K/V at 64/128/256.
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);

    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16);
    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16);
    FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16);
#endif

    on_no_fattn_vec_case(Q->ne[0]);
}

#undef FATTN_VEC_F32_CASES_ANY_K
|
|
// Top-level CUDA entry point for FlashAttention. dst is the KQV output; dst->src[0] is Q.
// Picks one kernel family (vec / tile / WMMA / MMA, FP16 or FP32) based on:
//   - cc:       compute capability of the current device (AMD devices compare
//               >= GGML_CUDA_CC_OFFSET_AMD),
//   - prec:     precision requested for the op (GGML_PREC_DEFAULT permits FP16 math),
//   - Q->ne[0]: head size,
//   - Q->ne[1]: number of Q columns
//               (NOTE(review): presumably the number of query tokens in the batch — confirm).
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * KQV = dst;
    const ggml_tensor * Q   = dst->src[0];

    ggml_cuda_set_device(ctx.device);
    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
    const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);

    // AMD: always use the vec kernels. The FP16 variant is used only when default precision
    // is allowed and the device has fast FP16.
    // NOTE(review): presumably because the tile kernels perform poorly on AMD — confirm.
    if (cc >= GGML_CUDA_CC_OFFSET_AMD) {
        if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
        } else {
            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
        }
        return;
    }

    // No fast FP16 arithmetic: FP32 kernels only. Small batches and head size 256 use the
    // vec kernel; everything else uses the tile kernel.
    if (!fast_fp16_available(cc)) {
        if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
        } else {
            ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
        }
        return;
    }

    // Fast FP16 but no FP16 tensor cores: vec for small batches, tile otherwise. The
    // FP16/FP32 variant follows the requested precision.
    // Head size 256 is routed to the vec kernels here as well, mirroring the branch above.
    // The vec dispatchers have explicit D=256 (F16 K/V) instantiations, while the tile
    // path is not used for that head size.
    // NOTE(review): assumes tile_f16, like tile_f32, has no D=256 path — confirm in
    // fattn-tile-f16.cu.
    if (!fp16_mma_available(cc)) {
        const bool use_vec = Q->ne[1] <= 8 || Q->ne[0] == 256;
        if (prec == GGML_PREC_DEFAULT) {
            if (use_vec) {
                ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
            } else {
                ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
            }
        } else {
            if (use_vec) {
                ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
            } else {
                ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
            }
        }
        return;
    }

    // Single Q column with a head size that is a multiple of 2*WARP_SIZE: use a vec kernel.
    // The FP32 vec kernel is only taken for head sizes up to 128. Larger heads at FP32
    // precision fall through to the tensor-core kernels below.
    if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
        if (prec == GGML_PREC_DEFAULT) {
            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
            return;
        } else if (Q->ne[0] <= 128) {
            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
            return;
        }
    }

    // Volta uses the WMMA kernels; other devices with FP16 MMA support use mma_f16.
    // NOTE(review): presumably because the MMA kernels require Turing or newer — confirm.
    if (cc == GGML_CUDA_CC_VOLTA) {
        ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
        return;
    }

    ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
}
|
|