// Copyright © 2024 Apple Inc.

#pragma once

#include "gemm/defines.h"

///////////////////////////////////////////////////////////////////////////////
// Loading helper
///////////////////////////////////////////////////////////////////////////////

namespace mlx {
namespace steel {

template <
    typename T,
    short BROWS,
    short BCOLS,
    short dst_ld,
    short reduction_dim,
    short tgp_size,
    short alignment = 1,
    short n_reads = (BCOLS * BROWS) / (tgp_size),
    short TCOLS = BCOLS / n_reads,
    short TROWS = tgp_size / TCOLS>
struct BlockLoader {
  STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
  STEEL_CONST short vec_size = n_reads;

  // Leading dimension for src
  const int src_ld;

  // Elements to advance src by on each call to next()
  const int tile_stride;

  // Thread location indices
  const short thread_idx;
  const short bi;
  const short bj;

  // threadgroup and device memory
  threadgroup T* dst;
  const device T* src;

  // Pack vec_size contiguous elements so each thread issues a single aligned
  // vector load per row
  struct alignas(alignment * sizeof(T)) ReadVector {
    uint8_t v[sizeof(T) * vec_size];
  };

  /* Constructor */
  METAL_FUNC BlockLoader(
      const device T* src_,
      const int src_ld_,
      threadgroup T* dst_,
      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
      ushort simd_lane_id [[thread_index_in_simdgroup]])
      : src_ld(src_ld_),
        tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
        thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / TCOLS),
        bj(vec_size * (thread_idx % TCOLS)),
        dst(dst_ + bi * dst_ld + bj),
        src(src_ + bi * src_ld + bj) {}

  /* Apply operation to threadgroup without bound checking */
  template <typename UnaryOp>
  METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < BROWS; i += TROWS) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        dst[i * dst_ld + j] = op.apply(dst[i * dst_ld + j]);
      }
    }
  }

  /* Load from device memory into threadgroup memory - without bound checking */
  METAL_FUNC void load_unsafe() const {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < BROWS; i += TROWS) {
      *((threadgroup ReadVector*)(&dst[i * dst_ld])) =
          *((const device ReadVector*)(&src[i * src_ld]));
    }
  }

  /* Load from device memory into threadgroup memory - with bound checking */
  METAL_FUNC void load_safe(short2 src_tile_dim) const {
    src_tile_dim = src_tile_dim - short2(bj, bi);

    // Skip loading if thread has no valid reads
    if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < BROWS; i += TROWS) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < vec_size; j++) {
          dst[i * dst_ld + j] = T(0);
        }
      }
      return;
    }

    // Use fast thread memory for bound checks
    bool tmp_idx[vec_size];
    T tmp_val[vec_size];

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < BROWS; i += TROWS) {
      // Make sure tmp_idx only contains valid indices
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
      }

      // Read valid indices into tmp_val; invalid lanes read offset 0 so the
      // access stays in bounds
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
      }

      // Zero out unneeded values
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
      }

      // Copy values to threadgroup memory
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < vec_size; j++) {
        dst[i * dst_ld + j] = tmp_val[j];
      }
    }
  }

  /* Iteration helper */
  METAL_FUNC void next() {
    src += tile_stride;
  }
};

} // namespace steel
} // namespace mlx
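
///////////////////////////////////////////////////////////////////////////////
// Usage sketch
///////////////////////////////////////////////////////////////////////////////
//
// A minimal, illustrative sketch of how a GEMM kernel might stage tiles of an
// operand A with this loader; it is not part of the original header. The tile
// sizes (BM = 32, BK = 16), threadgroup size, and the identifiers A, lda, K,
// simd_group_id, and simd_lane_id are assumptions for illustration, not
// definitions from this file.
//
//   using loader_a_t = mlx::steel::BlockLoader<
//       float,
//       32,   // BROWS: tile rows (BM)
//       16,   // BCOLS: tile cols along the reduction dim (BK)
//       16,   // dst_ld: leading dim of the threadgroup tile (unpadded here)
//       1,    // reduction_dim: advance src along columns each next()
//       128>; // tgp_size: threads per threadgroup
//
//   threadgroup float As[32 * 16];
//   loader_a_t loader_a(A, lda, As, simd_group_id, simd_lane_id);
//
//   for (int k = 0; k < K; k += 16) {
//     loader_a.load_unsafe(); // use load_safe(...) on ragged edge tiles
//     threadgroup_barrier(mem_flags::mem_threadgroup);
//     // ... simdgroup multiply-accumulate on As ...
//     loader_a.next(); // advance src by tile_stride (= BCOLS here)
//   }
//
// With these parameters, n_reads = (16 * 32) / 128 = 4, so each thread copies
// one 4-element vector per row it owns; real kernels typically pad dst_ld to
// reduce threadgroup memory bank conflicts.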