| |
|
|
| #pragma once |
|
|
| #include <metal_simdgroup> |
| #include <metal_simdgroup_matrix> |
| #include <metal_stdlib> |
|
|
| #include "gemm/defines.h" |
| #include "gemm/transforms.h" |
| #include "gemm/utils/integral_constant.h" |
|
|
| using namespace metal; |
|
|
| |
| |
| |
|
|
| namespace mlx { |
| namespace steel { |
|
|
// Primary template for the per-thread view of a simdgroup matrix fragment.
//
// It carries no members; its only job is to reject unsupported shapes at
// compile time. The usable implementation is the <T, 8, 8> partial
// specialization.
template <typename T, int kFragRows_, int kFragCols_>
struct BaseMMAFrag {
  static_assert(
      kFragRows_ == 8 && kFragCols_ == 8,
      "Only 8 x 8 fragment matrices are currently supported");
};
|
|
// Per-thread view of an 8 x 8 simdgroup matrix fragment.
//
// Each thread owns kElemsPerFrag = (8 * 8) / 32 = 2 elements of the fragment,
// arranged as kElemRows x kElemCols = 1 x 2, i.e. two consecutive columns of a
// single row. get_coord() reports where the first of those elements sits.
// NOTE(review): the divisor 32 is presumably the simdgroup width - confirm.
//
// In every load / store helper below, "x" is the row direction (indexed by i,
// scaled by str_x) and "y" is the column direction (indexed by j, scaled by
// str_y).
template <typename T>
struct BaseMMAFrag<T, 8, 8> {
  // Fragment shape in elements
  STEEL_CONST int kFragRows = 8;
  STEEL_CONST int kFragCols = 8;

  // Elements of the fragment held by a single thread
  STEEL_CONST int kElemsPerFrag = (kFragRows * kFragCols) / 32;

  // Shape of the per-thread element block
  STEEL_CONST int kElemRows = 1;
  STEEL_CONST int kElemCols = 2;

  static_assert(
      kElemRows * kElemCols == kElemsPerFrag,
      "MMAFrag shape is not consistent with MMAFrag size");

  // Simdgroup-wide matrix type and the per-thread slice of it
  typedef metal::simdgroup_matrix<T, kFragRows, kFragCols> mat_type;
  typedef metal::vec<T, kElemsPerFrag> frag_type;

  // Maps a lane index to the coordinate of the first element that lane owns.
  //
  // Returns short2{column, row}: .x is the column (always one of 0, 2, 4, 6
  // because each lane owns two adjacent columns) and .y is the row (0..7).
  METAL_FUNC static constexpr short2 get_coord(ushort simd_lane_id
                                               [[thread_index_in_simdgroup]]) {
    const short qid = simd_lane_id / 4;
    const short fm = (qid & 4) + ((simd_lane_id / 2) % 4);
    const short fn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
    return short2{fn, fm};
  }

  // Unchecked load of this thread's elements.
  //
  // src must already point at the thread's first element; element (i, j) is
  // read from src[i * str_x + j * str_y] and converted to T.
  template <typename SrcPtrType, typename StrX, typename StrY>
  METAL_FUNC static constexpr void
  load(thread frag_type& dst, SrcPtrType src, StrX str_x, StrY str_y) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kElemCols; j++) {
        dst[i * kElemCols + j] = static_cast<T>(src[i * str_x + j * str_y]);
      }
    }
  }

  // Bounds-checked load of this thread's elements.
  //
  // Element (i, j) sits at logical position (off_x + i, off_y + j) relative
  // to src. It is read from src[(off_x + i) * str_x + (off_y + j) * str_y]
  // when off_x + i < lim_x and off_y + j < lim_y; otherwise dst receives T(0).
  // Only the upper limits are checked.
  template <
      typename SrcPtrType,
      typename StrX,
      typename StrY,
      typename LimX,
      typename LimY,
      typename OffX,
      typename OffY>
  METAL_FUNC static constexpr void load_safe(
      thread frag_type& dst,
      SrcPtrType src,
      StrX str_x,
      StrY str_y,
      LimX lim_x,
      LimY lim_y,
      OffX off_x = Int<0>{},
      OffY off_y = Int<0>{}) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kElemCols; j++) {
        if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
          // The column term must use off_y, matching both the bounds check
          // above and the addressing used by store_safe / store_slice.
          dst[i * kElemCols + j] =
              static_cast<T>(src[(off_x + i) * str_x + (off_y + j) * str_y]);
        } else {
          dst[i * kElemCols + j] = T(0);
        }
      }
    }
  }

  // Unchecked store of this thread's elements.
  //
  // dst must already point at the thread's first element; element (i, j) is
  // converted to the pointee type of dst and written to
  // dst[i * str_x + j * str_y].
  template <typename DstPtrType, typename StrX, typename StrY>
  METAL_FUNC static constexpr void
  store(const thread frag_type& src, DstPtrType dst, StrX str_x, StrY str_y) {
    using U = pointer_element_t<DstPtrType>;

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kElemCols; j++) {
        dst[i * str_x + j * str_y] = static_cast<U>(src[i * kElemCols + j]);
      }
    }
  }

  // Bounds-checked store of this thread's elements.
  //
  // Element (i, j) is written to
  // dst[(off_x + i) * str_x + (off_y + j) * str_y] only when
  // off_x + i < lim_x and off_y + j < lim_y; other elements are skipped.
  template <
      typename DstPtrType,
      typename StrX,
      typename StrY,
      typename LimX,
      typename LimY,
      typename OffX,
      typename OffY>
  METAL_FUNC static constexpr void store_safe(
      const thread frag_type& src,
      DstPtrType dst,
      StrX str_x,
      StrY str_y,
      LimX lim_x,
      LimY lim_y,
      OffX off_x = Int<0>{},
      OffY off_y = Int<0>{}) {
    using U = pointer_element_t<DstPtrType>;

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kElemCols; j++) {
        if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
          dst[(off_x + i) * str_x + (off_y + j) * str_y] =
              static_cast<U>(src[i * kElemCols + j]);
        }
      }
    }
  }

  // Store restricted to a half-open window.
  //
  // Element (i, j) is written only when its logical position
  // (off_x + i, off_y + j) lies in [start_x, stop_x) x [start_y, stop_y).
  template <
      typename DstPtrType,
      typename StrX,
      typename StrY,
      typename StartX,
      typename StopX,
      typename StartY,
      typename StopY,
      typename OffX,
      typename OffY>
  METAL_FUNC static constexpr void store_slice(
      const thread frag_type& src,
      DstPtrType dst,
      StrX str_x,
      StrY str_y,
      StartX start_x,
      StopX stop_x,
      StartY start_y,
      StopY stop_y,
      OffX off_x = Int<0>{},
      OffY off_y = Int<0>{}) {
    using U = pointer_element_t<DstPtrType>;

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kElemCols; j++) {
        if ((off_x + i) < stop_x && (off_x + i) >= start_x &&
            (off_y + j) < stop_y && (off_y + j) >= start_y) {
          dst[(off_x + i) * str_x + (off_y + j) * str_y] =
              static_cast<U>(src[i * kElemCols + j]);
        }
      }
    }
  }

  // Multiply-accumulate on per-thread fragments.
  //
  // Copies A, B and C into the thread elements of temporary simdgroup
  // matrices, runs the matrix overload of mma(), and copies the thread
  // elements of the result back into D.
  METAL_FUNC static constexpr void mma(
      thread frag_type& D,
      thread frag_type& A,
      thread frag_type& B,
      thread frag_type& C) {
    mat_type D_mat;
    mat_type A_mat;
    mat_type B_mat;
    mat_type C_mat;

    reinterpret_cast<thread frag_type&>(A_mat.thread_elements()) = A;
    reinterpret_cast<thread frag_type&>(B_mat.thread_elements()) = B;
    reinterpret_cast<thread frag_type&>(C_mat.thread_elements()) = C;

    mma(D_mat, A_mat, B_mat, C_mat);

    D = reinterpret_cast<thread frag_type&>(D_mat.thread_elements());
  }

  // Multiply-accumulate on simdgroup matrices; forwards to
  // simdgroup_multiply_accumulate(D, A, B, C).
  METAL_FUNC static constexpr void mma(
      thread mat_type& D,
      thread mat_type& A,
      thread mat_type& B,
      thread mat_type& C) {
    simdgroup_multiply_accumulate(D, A, B, C);
  }
};
|
|
// A kTileRows_ x kTileCols_ grid of MMA fragments, holding for each fragment
// the elements owned by the calling thread.
//
// All load / store members walk the grid fragment by fragment. The w_x / w_y
// template arguments scale the distance between neighbouring fragments along
// rows / columns: fragment (i, j) starts at row i * kFragRows * w_x and column
// j * kFragCols * w_y relative to the pointer passed in.
template <
    typename T,
    int kTileRows_,
    int kTileCols_,
    class MMAFrag_ = BaseMMAFrag<T, 8, 8>>
struct MMATile {
  using MMAFrag_t = MMAFrag_;
  using elem_type = T;

  // Shape of one fragment and the number of its elements owned per thread
  STEEL_CONST int kFragRows = MMAFrag_t::kFragRows;
  STEEL_CONST int kFragCols = MMAFrag_t::kFragCols;
  STEEL_CONST int kElemsPerFrag = MMAFrag_t::kElemsPerFrag;

  // Tile shape measured in fragments
  STEEL_CONST int kTileRows = kTileRows_;
  STEEL_CONST int kTileCols = kTileCols_;

  // Tile shape measured in elements
  STEEL_CONST int kRows = kTileRows * kFragRows;
  STEEL_CONST int kCols = kTileCols * kFragCols;

  // Fragment count and per-thread element count over the whole tile
  STEEL_CONST int kNumFrags = kTileRows * kTileCols;
  STEEL_CONST int kElemsPerTile = kNumFrags * kElemsPerFrag;

  using mat_type = typename MMAFrag_t::mat_type;
  using frag_type = typename MMAFrag_t::frag_type;

  // Row-major fragment storage: fragment (i, j) is val_frags[i * kTileCols + j]
  frag_type val_frags[kNumFrags] = {frag_type(0)};

  METAL_FUNC MMATile() thread {}

  // Resets every fragment to zero.
  METAL_FUNC constexpr void clear() {
    STEEL_PRAGMA_UNROLL
    for (short f = 0; f < kNumFrags; ++f) {
      val_frags[f] = frag_type(0);
    }
  }

  // Mutable access to fragment (i, j).
  METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) {
    const int flat_index = i * kTileCols + j;
    return val_frags[flat_index];
  }

  // Read-only access to fragment (i, j).
  METAL_FUNC constexpr const thread frag_type& frag_at(
      const short i,
      const short j) const {
    const int flat_index = i * kTileCols + j;
    return val_frags[flat_index];
  }

  // Builds a simdgroup matrix whose thread elements are a copy of
  // fragment (i, j).
  METAL_FUNC mat_type mat_at(const short i, const short j) {
    thread frag_type& frag = frag_at(i, j);
    mat_type result;
    STEEL_PRAGMA_UNROLL
    for (short e = 0; e < kElemsPerFrag; ++e) {
      result.thread_elements()[e] = frag[e];
    }
    return result;
  }

  // Flat view of all per-thread elements (kElemsPerTile values).
  METAL_FUNC thread elem_type* elems() {
    return reinterpret_cast<thread elem_type*>(val_frags);
  }

  // Flat read-only view of all per-thread elements.
  METAL_FUNC const thread elem_type* elems() const {
    return reinterpret_cast<const thread elem_type*>(val_frags);
  }

  // Unchecked load from threadgroup memory with compile-time strides
  // str_x (rows) and str_y (columns).
  template <typename U, int w_x, int w_y, int str_x, int str_y>
  METAL_FUNC void load(const threadgroup U* src) {
    STEEL_PRAGMA_UNROLL
    for (short fi = 0; fi < kTileRows; ++fi) {
      STEEL_PRAGMA_UNROLL
      for (short fj = 0; fj < kTileCols; ++fj) {
        const int frag_offset = (fi * kFragRows) * w_x * str_x +
            (fj * kFragCols) * w_y * str_y;
        MMAFrag_t::load(
            frag_at(fi, fj), src + frag_offset, Int<str_x>{}, Int<str_y>{});
      }
    }
  }

  // Unchecked store to threadgroup memory with compile-time strides
  // str_x (rows) and str_y (columns).
  template <typename U, int w_x, int w_y, int str_x, int str_y>
  METAL_FUNC void store(threadgroup U* dst) const {
    STEEL_PRAGMA_UNROLL
    for (short fi = 0; fi < kTileRows; ++fi) {
      STEEL_PRAGMA_UNROLL
      for (short fj = 0; fj < kTileCols; ++fj) {
        const int frag_offset = (fi * kFragRows) * w_x * str_x +
            (fj * kFragCols) * w_y * str_y;
        MMAFrag_t::store(
            frag_at(fi, fj), dst + frag_offset, Int<str_x>{}, Int<str_y>{});
      }
    }
  }

  // Unchecked load from device memory; rows are ld elements apart and
  // columns are contiguous.
  template <typename U, int w_x, int w_y>
  METAL_FUNC void load(const device U* src, const int ld) {
    STEEL_PRAGMA_UNROLL
    for (short fi = 0; fi < kTileRows; ++fi) {
      STEEL_PRAGMA_UNROLL
      for (short fj = 0; fj < kTileCols; ++fj) {
        const int frag_offset =
            (fi * kFragRows) * w_x * ld + (fj * kFragCols) * w_y;
        MMAFrag_t::load(frag_at(fi, fj), src + frag_offset, ld, Int<1>{});
      }
    }
  }

  // Unchecked store to device memory; rows are ld elements apart and
  // columns are contiguous.
  template <typename U, int w_x, int w_y>
  METAL_FUNC void store(device U* dst, const int ld) const {
    STEEL_PRAGMA_UNROLL
    for (short fi = 0; fi < kTileRows; ++fi) {
      STEEL_PRAGMA_UNROLL
      for (short fj = 0; fj < kTileCols; ++fj) {
        const int frag_offset =
            (fi * kFragRows) * w_x * ld + (fj * kFragCols) * w_y;
        MMAFrag_t::store(frag_at(fi, fj), dst + frag_offset, ld, Int<1>{});
      }
    }
  }

  // Bounds-checked load from device memory. src_tile_dims.y bounds the rows
  // and src_tile_dims.x bounds the columns; the fragment helper substitutes
  // zero for elements outside those bounds.
  template <typename U, int w_x, int w_y>
  METAL_FUNC void
  load_safe(const device U* src, const int ld, const short2 src_tile_dims) {
    STEEL_PRAGMA_UNROLL
    for (int fi = 0; fi < kTileRows; ++fi) {
      STEEL_PRAGMA_UNROLL
      for (int fj = 0; fj < kTileCols; ++fj) {
        const int row_offset = (fi * kFragRows) * w_x;
        const int col_offset = (fj * kFragCols) * w_y;
        MMAFrag_t::load_safe(
            frag_at(fi, fj),
            src,
            ld,
            Int<1>{},
            src_tile_dims.y,
            src_tile_dims.x,
            row_offset,
            col_offset);
      }
    }
  }

  // Bounds-checked store to device memory. dst_tile_dims.y bounds the rows
  // and dst_tile_dims.x bounds the columns; elements outside are skipped.
  template <typename U, int w_x, int w_y>
  METAL_FUNC void
  store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const {
    STEEL_PRAGMA_UNROLL
    for (int fi = 0; fi < kTileRows; ++fi) {
      STEEL_PRAGMA_UNROLL
      for (int fj = 0; fj < kTileCols; ++fj) {
        const int row_offset = (fi * kFragRows) * w_x;
        const int col_offset = (fj * kFragCols) * w_y;
        MMAFrag_t::store_safe(
            frag_at(fi, fj),
            dst,
            ld,
            Int<1>{},
            dst_tile_dims.y,
            dst_tile_dims.x,
            row_offset,
            col_offset);
      }
    }
  }

  // Store limited to the half-open window [start, stop); the .y components
  // bound the rows and the .x components bound the columns.
  template <typename U, int w_x, int w_y>
  METAL_FUNC void store_slice(
      device U* dst,
      const int ld,
      const short2 start,
      const short2 stop) const {
    STEEL_PRAGMA_UNROLL
    for (int fi = 0; fi < kTileRows; ++fi) {
      STEEL_PRAGMA_UNROLL
      for (int fj = 0; fj < kTileCols; ++fj) {
        const int row_offset = (fi * kFragRows) * w_x;
        const int col_offset = (fj * kFragCols) * w_y;
        MMAFrag_t::store_slice(
            frag_at(fi, fj),
            dst,
            ld,
            Int<1>{},
            start.y,
            stop.y,
            start.x,
            stop.x,
            row_offset,
            col_offset);
      }
    }
  }
};
|
|
// Tile-level multiply-accumulate: for every output fragment (m, n) and every
// k in [0, K), issues mma(D(m, n), A(m, k), B(k, n), C(m, n)).
//
// Output columns are visited in serpentine order: left to right on even
// fragment rows and right to left on odd ones.
template <typename T, typename U, int M, int N, int K>
METAL_FUNC void tile_matmad(
    thread MMATile<T, M, N>& D,
    thread MMATile<U, M, K>& A,
    thread MMATile<U, K, N>& B,
    thread MMATile<T, M, N>& C) {
  using Frag = typename MMATile<T, M, N>::MMAFrag_t;

  STEEL_PRAGMA_UNROLL
  for (short row = 0; row < M; ++row) {
    STEEL_PRAGMA_UNROLL
    for (short step = 0; step < N; ++step) {
      // Pick the output column for this step of the serpentine walk
      short col = step;
      if (row % 2) {
        col = N - 1 - step;
      }

      STEEL_PRAGMA_UNROLL
      for (short depth = 0; depth < K; ++depth) {
        Frag::mma(
            D.frag_at(row, col),
            A.frag_at(row, depth),
            B.frag_at(depth, col),
            C.frag_at(row, col));
      }
    }
  }
}
|
|
// Block-level matrix multiply-accumulate over threadgroup tiles of A and B,
// with helpers to post-process and write the accumulated result.
//
// Template parameters:
//   T             element type of the threadgroup tiles passed to mma()
//   U             element type of the device matrices C and D
//   BM, BN, BK    block extents along M, N and K
//   WM, WN        simdgroup grid along M and N; a simdgroup index is split
//                 into (index / WN, index % WN)
//   transpose_a,  choose which of the two threadgroup strides of A / B is the
//   transpose_b   leading dimension and which is 1
//   lda_tgp,      leading dimensions of the threadgroup tiles of A and B
//   ldb_tgp
//   AccumType     element type of the accumulators
//   Epilogue      transform applied to each accumulator by the unfused
//                 store_result* members
template <
    typename T,
    typename U,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    short lda_tgp,
    short ldb_tgp,
    typename AccumType = float,
    typename Epilogue = TransformNone<U, AccumType>>
struct BlockMMA {
  // Edge length of one MMA fragment
  STEEL_CONST short kFragSize = 8;
  using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;

  // Row distance between consecutive fragments handled by one simdgroup
  STEEL_CONST short TM_stride = kFragSize * WM;
  // Column distance between consecutive fragments handled by one simdgroup
  STEEL_CONST short TN_stride = kFragSize * WN;

  // Fragments per simdgroup along M
  STEEL_CONST short TM = BM / (kFragSize * WM);
  // Fragments per simdgroup along N
  STEEL_CONST short TN = BN / (kFragSize * WN);

  // Threadgroup strides of A along M and K
  STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp;
  STEEL_CONST short A_str_k = transpose_a ? lda_tgp : 1;

  // Threadgroup strides of B along K and N
  STEEL_CONST short B_str_k = transpose_b ? 1 : ldb_tgp;
  STEEL_CONST short B_str_n = transpose_b ? ldb_tgp : 1;

  // Pointer advance per K step of kFragSize
  STEEL_CONST short tile_stride_a = kFragSize * A_str_k;
  STEEL_CONST short tile_stride_b = kFragSize * B_str_k;

  // Per-thread fragment tiles: a column of A, a row of B, and the accumulator
  MMATile<AccumType, TM, 1, MMAFrag_acc_t> Atile;
  MMATile<AccumType, 1, TN, MMAFrag_acc_t> Btile;
  MMATile<AccumType, TM, TN, MMAFrag_acc_t> Ctile;

  // Row / column of this thread's first element within the block
  short sm;
  short sn;

  // Offsets of this thread's first element within the A / B threadgroup tiles
  short As_offset;
  short Bs_offset;

  // Derives the per-thread coordinates and threadgroup offsets from the
  // simdgroup index and the lane index.
  METAL_FUNC BlockMMA(
      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
      ushort simd_lane_id [[thread_index_in_simdgroup]]) {
    // Origin of this simdgroup's first fragment within the block
    short tm = kFragSize * (simd_group_id / WN);
    short tn = kFragSize * (simd_group_id % WN);

    // Position of this lane inside a fragment (.x is column, .y is row)
    short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
    sm = simd_coord.y;
    sn = simd_coord.x;

    // For A the lane column indexes K; for B the lane row indexes K
    As_offset = (tm + sm) * A_str_m + (sn)*A_str_k;
    Bs_offset = (sm)*B_str_k + (tn + sn) * B_str_n;

    // From here on sm / sn are block-relative coordinates
    sm += tm;
    sn += tn;
  }

  // Accumulates As x Bs into Ctile, stepping through BK in chunks of
  // kFragSize.
  METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
    // Move to this thread's first element
    As += As_offset;
    Bs += Bs_offset;

    STEEL_PRAGMA_UNROLL
    for (short kk = 0; kk < BK; kk += kFragSize) {
      // NOTE(review): the mem_none barriers carry no memory fence; they
      // presumably only order execution across the simdgroup - confirm.
      simdgroup_barrier(mem_flags::mem_none);

      Atile.template load<T, WM, 1, A_str_m, A_str_k>(As);

      simdgroup_barrier(mem_flags::mem_none);

      Btile.template load<T, 1, WN, B_str_k, B_str_n>(Bs);

      simdgroup_barrier(mem_flags::mem_none);

      tile_matmad(Ctile, Atile, Btile, Ctile);

      // Advance to the next K chunk
      As += tile_stride_a;
      Bs += tile_stride_b;
    }
  }

  // Applies Epilogue to every accumulator in place, then stores the whole
  // tile to D (row stride ldd) without bounds checks.
  METAL_FUNC void store_result(device U* D, const int ldd) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
      Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
    }

    // Move to this thread's first element
    D += sm * ldd + sn;

    Ctile.template store<U, WM, WN>(D, ldd);
  }

  // Applies Epilogue in place, then stores only the elements whose block
  // coordinates fall inside [start, stop) (.x is column, .y is row).
  METAL_FUNC void
  store_result_slice(device U* D, const int ldd, short2 start, short2 stop) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
      Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
    }

    // Re-express the window relative to this thread's first element
    D += sm * ldd + sn;
    start -= short2(sn, sm);
    stop -= short2(sn, sm);

    // Nothing to write when the window ends at or before this thread's origin
    if (stop.y <= 0 || stop.x <= 0) {
      return;
    }

    Ctile.template store_slice<U, WM, WN>(D, ldd, start, stop);
  }

  // Applies Epilogue in place, then stores only the elements inside
  // dst_tile_dims (.x is the column count, .y is the row count).
  METAL_FUNC void
  store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
      Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
    }

    // Re-express the bounds relative to this thread's first element
    D += sm * ldd + sn;
    dst_tile_dims -= short2(sn, sm);

    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
      return;

    Ctile.template store_safe<U, WM, WN>(D, ldd, dst_tile_dims);
  }

  // Replaces every accumulator x with epilogue_op.apply(x).
  template <typename UnaryEpilogue>
  METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
      Ctile.elems()[i] = epilogue_op.apply(Ctile.elems()[i]);
    }
  }

  // Replaces every accumulator x with epilogue_op.apply(x, c), where c is the
  // matching element of C (row stride ldc, column stride fdc). No bounds
  // checks are performed.
  template <typename BinaryEpilogue>
  METAL_FUNC void apply_epilogue(
      const device U* C,
      const int ldc,
      const int fdc,
      thread const BinaryEpilogue& epilogue_op) {
    // Move to this thread's first element
    C += (sm)*ldc + (sn)*fdc;

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < TM; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < TN; j++) {
        // Accumulator fragment and the offset of its counterpart in C
        thread auto& accum = Ctile.frag_at(i, j);
        int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;

        STEEL_PRAGMA_UNROLL
        for (short k = 0; k < decltype(Ctile)::kElemsPerFrag; k++) {
          accum[k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
        }
      }
    }
  }

  // Bounds-checked binary epilogue. Elements of C outside dst_tile_dims
  // (.x is the column count, .y is the row count) are not read; zero is used
  // in their place.
  template <typename BinaryEpilogue>
  METAL_FUNC void apply_epilogue_safe(
      const device U* C,
      const int ldc,
      const int fdc,
      short2 dst_tile_dims,
      thread const BinaryEpilogue& epilogue_op) {
    // Re-express the bounds relative to this thread's first element
    C += (sm)*ldc + (sn)*fdc;
    dst_tile_dims -= short2(sn, sm);

    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
      return;

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < TM; i++) {
      // Rows past the valid extent must not be read from C. The fused
      // store_result_safe applies the same row guard before touching C.
      const bool row_in_bounds = (i * TM_stride) < dst_tile_dims.y;

      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < TN; j++) {
        // Accumulator fragment and the offset of its counterpart in C
        thread auto& accum = Ctile.frag_at(i, j);
        int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;

        constexpr short kelems = decltype(Ctile)::kElemsPerFrag;

        // Stage C with a zero default for out-of-bounds elements
        U c_elems[kelems] = {0};

        STEEL_PRAGMA_UNROLL
        for (short k = 0; k < kelems; k++) {
          if (row_in_bounds && (j * TN_stride + k) < dst_tile_dims.x) {
            c_elems[k] = C[offset_c + k * fdc];
          }
        }

        STEEL_PRAGMA_UNROLL
        for (short k = 0; k < kelems; k++) {
          accum[k] = epilogue_op.apply(accum[k], c_elems[k]);
        }
      }
    }
  }

  // Fused epilogue and store: writes epilogue_op.apply(accumulator, c) to D
  // for every element, leaving the accumulators untouched. No bounds checks
  // are performed. D uses row stride ldd with contiguous columns; C uses row
  // stride ldc and column stride fdc.
  METAL_FUNC void store_result(
      device U* D,
      const int ldd,
      const device U* C,
      const int ldc,
      const int fdc,
      thread const Epilogue& epilogue_op) const {
    // Move to this thread's first element
    C += (sm)*ldc + (sn)*fdc;
    D += (sm)*ldd + sn;

    constexpr short kelems = decltype(Ctile)::kElemsPerFrag;

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < TM; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < TN; j++) {
        // Accumulator fragment and its offsets in C and D
        thread const auto& accum = Ctile.frag_at(i, j);
        int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
        int offset_d = (i * TM_stride) * ldd + (j * TN_stride);

        STEEL_PRAGMA_UNROLL
        for (short k = 0; k < kelems; k++) {
          D[offset_d + k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
        }
      }
    }
  }

  // Bounds-checked fused epilogue and store: as the fused store_result, but
  // only elements inside dst_tile_dims (.x is the column count, .y is the row
  // count) are read from C and written to D.
  METAL_FUNC void store_result_safe(
      device U* D,
      const int ldd,
      const device U* C,
      const int ldc,
      const int fdc,
      short2 dst_tile_dims,
      thread const Epilogue& epilogue_op) const {
    // Re-express the bounds relative to this thread's first element
    C += (sm)*ldc + (sn)*fdc;
    D += (sm)*ldd + sn;
    dst_tile_dims -= short2(sn, sm);

    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
      return;

    constexpr short kelems = decltype(Ctile)::kElemsPerFrag;

    STEEL_PRAGMA_UNROLL
    for (int i = 0; i < TM; i++) {
      // Skip fragment rows that start past the valid row count
      if (i * TM_stride < dst_tile_dims.y) {
        STEEL_PRAGMA_UNROLL
        for (int j = 0; j < TN; j++) {
          // Accumulator fragment and its offsets in C and D
          thread const auto& accum = Ctile.frag_at(i, j);
          int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
          int offset_d = (i * TM_stride) * ldd + (j * TN_stride);

          STEEL_PRAGMA_UNROLL
          for (short k = 0; k < kelems; k++) {
            if ((j * TN_stride + k) < dst_tile_dims.x) {
              D[offset_d + k] =
                  epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
            }
          }
        }
      }
    }
  }
};
|
|
| } |
| } |
|
|