// Copyright © 2024 Apple Inc.

#pragma once

#include "gemm/loader.h"
#include "gemm/mma.h"
#include "gemm/params.h"
#include "gemm/transforms.h"
#include "gemm/utils.h"

using namespace metal;

///////////////////////////////////////////////////////////////////////////////
// GEMM kernel class
///////////////////////////////////////////////////////////////////////////////

namespace mlx {
namespace steel {

// Empty tag type carrying the three alignment flags of one K-loop
// instantiation. GEMMKernel::gemm_loop takes a defaulted instance of it;
// the value itself is never read.
template <bool M_aligned, bool N_aligned, bool K_aligned>
struct LoopAlignment {};

// Tiled GEMM kernel body.
//
// Each threadgroup computes one BM x BN tile of the output D from A and B,
// stepping over K in chunks of BK. Tiles of A and B are staged in threadgroup
// memory (As, Bs) by BlockLoader and multiplied/accumulated by BlockMMA, which
// also stores the finished tile to D.
//
// Template parameters:
//   T, U          - element type of A/B and of D.
//   BM, BN, BK    - threadgroup tile sizes along M, N and K.
//   WM, WN        - simdgroup layout; the threadgroup is expected to contain
//                   WM * WN * 32 threads (see tgp_size).
//   transpose_a/b - whether A / B are stored transposed.
//   MN_aligned    - when true, run() uses unchecked loads/stores, so every
//                   tile must lie fully inside M and N.
//   K_aligned     - when true, no partial K tile is processed after the
//                   gemm_k_iterations_aligned full BK steps.
//   AccumType     - accumulator type forwarded to BlockMMA.
//   Epilogue      - transform forwarded to BlockMMA.
template <
    typename T,
    typename U,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    bool MN_aligned,
    bool K_aligned,
    typename AccumType = typename AccumHelper<T>::accum_type,
    typename Epilogue = TransformNone<U, AccumType>>
struct GEMMKernel {
  // 16 bytes of padding per row of each staged tile (presumably to reduce
  // threadgroup-memory bank conflicts).
  STEEL_CONST short tgp_padding_a = 16 / sizeof(T);
  STEEL_CONST short tgp_padding_b = 16 / sizeof(T);

  // Sizes (in elements, padding included) of the staged A and B tiles.
  STEEL_CONST short tgp_mem_size_a =
      transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);
  STEEL_CONST short tgp_mem_size_b =
      transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);
  STEEL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;

  // Threads per threadgroup: WM x WN simdgroups of 32 lanes.
  STEEL_CONST short tgp_size = WM * WN * 32;

  using loader_a_t = BlockLoader<
      T,
      transpose_a ? BK : BM,
      transpose_a ? BM : BK,
      transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
      !transpose_a,
      tgp_size>;
  using loader_b_t = BlockLoader<
      T,
      transpose_b ? BN : BK,
      transpose_b ? BK : BN,
      transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
      transpose_b,
      tgp_size>;
  using mma_t = BlockMMA<
      T,
      U,
      BM,
      BN,
      BK,
      WM,
      WN,
      transpose_a,
      transpose_b,
      transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
      transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
      AccumType,
      Epilogue>;

  /* K-loop for one output tile.
   *
   * Runs gemm_k_iterations full BK-wide steps and, when !K_aligned_, one
   * final partial step of width lbk. Results accumulate in mma_op.
   *
   * M_aligned / N_aligned select unchecked loads for A / B; when false the
   * loads are bounds-checked against tgp_bm / tgp_bn rows/columns.
   */
  template <bool M_aligned, bool N_aligned, bool K_aligned_>
  static METAL_FUNC void gemm_loop(
      threadgroup T* As [[threadgroup(0)]],
      threadgroup T* Bs [[threadgroup(1)]],
      const int gemm_k_iterations,
      thread loader_a_t& loader_a,
      thread loader_b_t& loader_b,
      thread mma_t& mma_op,
      thread const short& tgp_bm,
      thread const short& tgp_bn,
      thread const short& lbk,
      LoopAlignment<M_aligned, N_aligned, K_aligned_> l = {}) {
    // Appease the compiler
    (void)l;

    short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);

    short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);

    for (int k = 0; k < gemm_k_iterations; k++) {
      // Previous iteration's mma must finish reading As/Bs before reload.
      threadgroup_barrier(mem_flags::mem_threadgroup);
      // Load elements into threadgroup
      if (M_aligned) {
        loader_a.load_unsafe();
      } else {
        loader_a.load_safe(tile_dims_A);
      }

      if (N_aligned) {
        loader_b.load_unsafe();
      } else {
        loader_b.load_safe(tile_dims_B);
      }

      // Tiles must be fully written before any simdgroup reads them.
      threadgroup_barrier(mem_flags::mem_threadgroup);

      // Multiply and accumulate threadgroup elements
      mma_op.mma(As, Bs);

      // Prepare for next iteration
      loader_a.next();
      loader_b.next();
    }

    // Partial K tile of width lbk.
    if (!K_aligned_) {
      threadgroup_barrier(mem_flags::mem_threadgroup);

      short2 tile_dims_A_last =
          transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
      short2 tile_dims_B_last =
          transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);

      loader_a.load_safe(tile_dims_A_last);
      loader_b.load_safe(tile_dims_B_last);

      threadgroup_barrier(mem_flags::mem_threadgroup);

      mma_op.mma(As, Bs);
    }
  }

  /* Main kernel function.
   *
   * Maps this threadgroup to one BM x BN output tile, runs the K-loop and
   * stores the tile to D. Threadgroups whose tile index falls outside
   * params->tiles_m x params->tiles_n return immediately.
   */
  static METAL_FUNC void run(
      const device T* A [[buffer(0)]],
      const device T* B [[buffer(1)]],
      device U* D [[buffer(2)]],
      const constant GEMMParams* params [[buffer(3)]],
      threadgroup T* As [[threadgroup(0)]],
      threadgroup T* Bs [[threadgroup(1)]],
      uint simd_lane_id [[thread_index_in_simdgroup]],
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint3 tid [[threadgroup_position_in_grid]],
      uint3 lid [[thread_position_in_threadgroup]]) {
    // Pacifying compiler
    (void)lid;

    // Undo the launch-grid swizzle: each group of 2^swizzle_log consecutive
    // tid.x values maps to consecutive tile rows sharing one tile column.
    const int tid_y = ((tid.y) << params->swizzle_log) +
        ((tid.x) & ((1 << params->swizzle_log) - 1));
    const int tid_x = (tid.x) >> params->swizzle_log;

    if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
      return;
    }

    threadgroup_barrier(mem_flags::mem_none);

    // Find block in A, B, C
    const int c_row = tid_y * BM;
    const int c_col = tid_x * BN;
    // Widen before multiplying by leading dimensions to avoid int overflow.
    const size_t c_row_long = size_t(c_row);
    const size_t c_col_long = size_t(c_col);

    A += transpose_a ? c_row_long : c_row_long * params->lda;
    B += transpose_b ? c_col_long * params->ldb : c_col_long;
    D += c_row_long * params->ldd + c_col_long;

    // Prepare threadgroup loading operations
    thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
    thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);

    // Prepare threadgroup mma operation
    thread mma_t mma_op(simd_group_id, simd_lane_id);

    int gemm_k_iterations = params->gemm_k_iterations_aligned;

    ///////////////////////////////////////////////////////////////////////////////
    // MNK aligned loop
    if (MN_aligned) {
      for (int k = 0; k < gemm_k_iterations; k++) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        // Load elements into threadgroup
        loader_a.load_unsafe();
        loader_b.load_unsafe();

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate threadgroup elements
        mma_op.mma(As, Bs);

        // Prepare for next iteration
        loader_a.next();
        loader_b.next();
      }

      // Every simdgroup must finish reading the last full tile from As/Bs
      // before the K tail below overwrites it. mem_threadgroup (as used in
      // gemm_loop's tail) orders those threadgroup-memory accesses across
      // the barrier; mem_none would only synchronize execution.
      threadgroup_barrier(mem_flags::mem_threadgroup);

      // Loop tail
      if (!K_aligned) {
        int lbk = params->K - params->gemm_k_iterations_aligned * BK;
        short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
        short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);

        loader_a.load_safe(tile_dims_A);
        loader_b.load_safe(tile_dims_B);

        threadgroup_barrier(mem_flags::mem_threadgroup);

        mma_op.mma(As, Bs);
      }

      // Store results to device memory
      mma_op.store_result(D, params->ldd);
      return;

    }
    ///////////////////////////////////////////////////////////////////////////////
    // MN unaligned loop
    else { // Loop over K - unaligned case
      // Valid extent of this tile (smaller than BM/BN on the M/N edges).
      short tgp_bm = min(BM, params->M - c_row);
      short tgp_bn = min(BN, params->N - c_col);
      short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK;

      // Dispatch to the gemm_loop specialization matching which dimensions
      // of this particular tile are full, so interior tiles keep unchecked
      // loads even when the overall problem is unaligned.
      if (tgp_bm == BM && tgp_bn == BN) {
        gemm_loop<true, true, K_aligned>(
            As,
            Bs,
            gemm_k_iterations,
            loader_a,
            loader_b,
            mma_op,
            tgp_bm,
            tgp_bn,
            leftover_bk);

        mma_op.store_result(D, params->ldd);
        return;

      } else if (tgp_bn == BN) {
        gemm_loop<false, true, K_aligned>(
            As,
            Bs,
            gemm_k_iterations,
            loader_a,
            loader_b,
            mma_op,
            tgp_bm,
            tgp_bn,
            leftover_bk);

        mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
        return;

      } else if (tgp_bm == BM) {
        gemm_loop<true, false, K_aligned>(
            As,
            Bs,
            gemm_k_iterations,
            loader_a,
            loader_b,
            mma_op,
            tgp_bm,
            tgp_bn,
            leftover_bk);

        mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
        return;

      } else {
        gemm_loop<false, false, K_aligned>(
            As,
            Bs,
            gemm_k_iterations,
            loader_a,
            loader_b,
            mma_op,
            tgp_bm,
            tgp_bn,
            leftover_bk);

        mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
        return;
      }
    }
  }
};

} // namespace steel
} // namespace mlx