| | |
| |
|
| | #pragma once |
| |
|
| | #include "gemm/loader.h" |
| | #include "gemm/mma.h" |
| | #include "gemm/params.h" |
| | #include "gemm/transforms.h" |
| | #include "gemm/utils.h" |
| |
|
| | using namespace metal; |
| |
|
| | |
| | |
| | |
| |
|
| | namespace mlx { |
| | namespace steel { |
| |
|
| | template <bool M_aligned, bool N_aligned, bool K_aligned> |
| | struct LoopAlignment {}; |
| |
|
| | template < |
| | typename T, |
| | typename U, |
| | int BM, |
| | int BN, |
| | int BK, |
| | int WM, |
| | int WN, |
| | bool transpose_a, |
| | bool transpose_b, |
| | bool MN_aligned, |
| | bool K_aligned, |
| | typename AccumType = typename AccumHelper<T>::accum_type, |
| | typename Epilogue = TransformNone<U, AccumType>> |
| | struct GEMMKernel { |
| | STEEL_CONST short tgp_padding_a = 16 / sizeof(T); |
| | STEEL_CONST short tgp_padding_b = 16 / sizeof(T); |
| | STEEL_CONST short tgp_mem_size_a = |
| | transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a); |
| | STEEL_CONST short tgp_mem_size_b = |
| | transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b); |
| | STEEL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b; |
| |
|
| | STEEL_CONST short tgp_size = WM * WN * 32; |
| |
|
| | using loader_a_t = BlockLoader< |
| | T, |
| | transpose_a ? BK : BM, |
| | transpose_a ? BM : BK, |
| | transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a, |
| | !transpose_a, |
| | tgp_size>; |
| | using loader_b_t = BlockLoader< |
| | T, |
| | transpose_b ? BN : BK, |
| | transpose_b ? BK : BN, |
| | transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b, |
| | transpose_b, |
| | tgp_size>; |
| | using mma_t = BlockMMA< |
| | T, |
| | U, |
| | BM, |
| | BN, |
| | BK, |
| | WM, |
| | WN, |
| | transpose_a, |
| | transpose_b, |
| | transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a, |
| | transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b, |
| | AccumType, |
| | Epilogue>; |
| |
|
| | |
| | template <bool M_aligned, bool N_aligned, bool K_aligned_> |
| | static METAL_FUNC void gemm_loop( |
| | threadgroup T* As [[threadgroup(0)]], |
| | threadgroup T* Bs [[threadgroup(1)]], |
| | const int gemm_k_iterations, |
| | thread loader_a_t& loader_a, |
| | thread loader_b_t& loader_b, |
| | thread mma_t& mma_op, |
| | thread const short& tgp_bm, |
| | thread const short& tgp_bn, |
| | thread const short& lbk, |
| | LoopAlignment<M_aligned, N_aligned, K_aligned_> l = {}) { |
| | |
| | (void)l; |
| |
|
| | short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm); |
| |
|
| | short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK); |
| |
|
| | for (int k = 0; k < gemm_k_iterations; k++) { |
| | threadgroup_barrier(mem_flags::mem_threadgroup); |
| | |
| | if (M_aligned) { |
| | loader_a.load_unsafe(); |
| | } else { |
| | loader_a.load_safe(tile_dims_A); |
| | } |
| |
|
| | if (N_aligned) { |
| | loader_b.load_unsafe(); |
| | } else { |
| | loader_b.load_safe(tile_dims_B); |
| | } |
| |
|
| | threadgroup_barrier(mem_flags::mem_threadgroup); |
| |
|
| | |
| | mma_op.mma(As, Bs); |
| |
|
| | |
| | loader_a.next(); |
| | loader_b.next(); |
| | } |
| |
|
| | if (!K_aligned_) { |
| | threadgroup_barrier(mem_flags::mem_threadgroup); |
| |
|
| | short2 tile_dims_A_last = |
| | transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm); |
| | short2 tile_dims_B_last = |
| | transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk); |
| |
|
| | loader_a.load_safe(tile_dims_A_last); |
| | loader_b.load_safe(tile_dims_B_last); |
| |
|
| | threadgroup_barrier(mem_flags::mem_threadgroup); |
| |
|
| | mma_op.mma(As, Bs); |
| | } |
| | } |
| |
|
| | |
| | static METAL_FUNC void run( |
| | const device T* A [[buffer(0)]], |
| | const device T* B [[buffer(1)]], |
| | device U* D [[buffer(2)]], |
| | const constant GEMMParams* params [[buffer(3)]], |
| | threadgroup T* As [[threadgroup(0)]], |
| | threadgroup T* Bs [[threadgroup(1)]], |
| | uint simd_lane_id [[thread_index_in_simdgroup]], |
| | uint simd_group_id [[simdgroup_index_in_threadgroup]], |
| | uint3 tid [[threadgroup_position_in_grid]], |
| | uint3 lid [[thread_position_in_threadgroup]]) { |
| | |
| | (void)lid; |
| |
|
| | const int tid_y = ((tid.y) << params->swizzle_log) + |
| | ((tid.x) & ((1 << params->swizzle_log) - 1)); |
| | const int tid_x = (tid.x) >> params->swizzle_log; |
| |
|
| | if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { |
| | return; |
| | } |
| |
|
| | threadgroup_barrier(mem_flags::mem_none); |
| |
|
| | |
| | const int c_row = tid_y * BM; |
| | const int c_col = tid_x * BN; |
| | const size_t c_row_long = size_t(c_row); |
| | const size_t c_col_long = size_t(c_col); |
| |
|
| | A += transpose_a ? c_row_long : c_row_long * params->lda; |
| | B += transpose_b ? c_col_long * params->ldb : c_col_long; |
| | D += c_row_long * params->ldd + c_col_long; |
| |
|
| | |
| | thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); |
| | thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); |
| |
|
| | |
| | thread mma_t mma_op(simd_group_id, simd_lane_id); |
| |
|
| | int gemm_k_iterations = params->gemm_k_iterations_aligned; |
| |
|
| | |
| | |
| | if (MN_aligned) { |
| | for (int k = 0; k < gemm_k_iterations; k++) { |
| | threadgroup_barrier(mem_flags::mem_threadgroup); |
| | |
| | loader_a.load_unsafe(); |
| | loader_b.load_unsafe(); |
| |
|
| | threadgroup_barrier(mem_flags::mem_threadgroup); |
| |
|
| | |
| | mma_op.mma(As, Bs); |
| |
|
| | |
| | loader_a.next(); |
| | loader_b.next(); |
| | } |
| |
|
| | threadgroup_barrier(mem_flags::mem_none); |
| |
|
| | |
| | if (!K_aligned) { |
| | int lbk = params->K - params->gemm_k_iterations_aligned * BK; |
| | short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM); |
| | short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk); |
| |
|
| | loader_a.load_safe(tile_dims_A); |
| | loader_b.load_safe(tile_dims_B); |
| |
|
| | threadgroup_barrier(mem_flags::mem_threadgroup); |
| |
|
| | mma_op.mma(As, Bs); |
| | } |
| |
|
| | |
| | mma_op.store_result(D, params->ldd); |
| | return; |
| |
|
| | } |
| | |
| | |
| | else { |
| | short tgp_bm = min(BM, params->M - c_row); |
| | short tgp_bn = min(BN, params->N - c_col); |
| | short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK; |
| |
|
| | if (tgp_bm == BM && tgp_bn == BN) { |
| | gemm_loop<true, true, K_aligned>( |
| | As, |
| | Bs, |
| | gemm_k_iterations, |
| | loader_a, |
| | loader_b, |
| | mma_op, |
| | tgp_bm, |
| | tgp_bn, |
| | leftover_bk); |
| |
|
| | mma_op.store_result(D, params->ldd); |
| | return; |
| |
|
| | } else if (tgp_bn == BN) { |
| | gemm_loop<false, true, K_aligned>( |
| | As, |
| | Bs, |
| | gemm_k_iterations, |
| | loader_a, |
| | loader_b, |
| | mma_op, |
| | tgp_bm, |
| | tgp_bn, |
| | leftover_bk); |
| |
|
| | mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); |
| | return; |
| |
|
| | } else if (tgp_bm == BM) { |
| | gemm_loop<true, false, K_aligned>( |
| | As, |
| | Bs, |
| | gemm_k_iterations, |
| | loader_a, |
| | loader_b, |
| | mma_op, |
| | tgp_bm, |
| | tgp_bn, |
| | leftover_bk); |
| |
|
| | mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); |
| | return; |
| |
|
| | } else { |
| | gemm_loop<false, false, K_aligned>( |
| | As, |
| | Bs, |
| | gemm_k_iterations, |
| | loader_a, |
| | loader_b, |
| | mma_op, |
| | tgp_bm, |
| | tgp_bn, |
| | leftover_bk); |
| |
|
| | mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); |
| | return; |
| | } |
| | } |
| | } |
| | }; |
| |
|
| | } |
| | } |