| | |
| |
|
| | #pragma once |
| |
|
| | #include "gemm/defines.h" |
| |
|
| | |
| | |
| | |
| |
|
| | namespace mlx { |
| | namespace steel { |
| |
|
| | template < |
| | typename T, |
| | short BROWS, |
| | short BCOLS, |
| | short dst_ld, |
| | short reduction_dim, |
| | short tgp_size, |
| | short alignment = 1, |
| | short n_reads = (BCOLS * BROWS) / (tgp_size), |
| | short TCOLS = BCOLS / n_reads, |
| | short TROWS = tgp_size / TCOLS> |
| | struct BlockLoader { |
| | STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS; |
| | STEEL_CONST short vec_size = n_reads; |
| |
|
| | |
| | const int src_ld; |
| | const int tile_stride; |
| |
|
| | |
| | const short thread_idx; |
| | const short bi; |
| | const short bj; |
| |
|
| | |
| | threadgroup T* dst; |
| | const device T* src; |
| |
|
| | struct alignas(alignment * sizeof(T)) ReadVector { |
| | uint8_t v[sizeof(T) * vec_size]; |
| | }; |
| |
|
| | |
| | METAL_FUNC BlockLoader( |
| | const device T* src_, |
| | const int src_ld_, |
| | threadgroup T* dst_, |
| | ushort simd_group_id [[simdgroup_index_in_threadgroup]], |
| | ushort simd_lane_id [[thread_index_in_simdgroup]]) |
| | : src_ld(src_ld_), |
| | tile_stride(reduction_dim ? BCOLS : BROWS * src_ld), |
| | thread_idx(simd_group_id * 32 + simd_lane_id), |
| | bi(thread_idx / TCOLS), |
| | bj(vec_size * (thread_idx % TCOLS)), |
| | dst(dst_ + bi * dst_ld + bj), |
| | src(src_ + bi * src_ld + bj) {} |
| |
|
| | |
| | template <typename UnaryOp> |
| | METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const { |
| | STEEL_PRAGMA_UNROLL |
| | for (short i = 0; i < BROWS; i += TROWS) { |
| | STEEL_PRAGMA_UNROLL |
| | for (short j = 0; j < vec_size; j++) { |
| | dst[i * dst_ld + j] = op.apply(dst[i * dst_ld + j]); |
| | } |
| | } |
| | } |
| |
|
| | |
| | METAL_FUNC void load_unsafe() const { |
| | STEEL_PRAGMA_UNROLL |
| | for (short i = 0; i < BROWS; i += TROWS) { |
| | *((threadgroup ReadVector*)(&dst[i * dst_ld])) = |
| | *((const device ReadVector*)(&src[i * src_ld])); |
| | } |
| | } |
| |
|
| | |
| | METAL_FUNC void load_safe(short2 src_tile_dim) const { |
| | src_tile_dim = src_tile_dim - short2(bj, bi); |
| |
|
| | |
| | if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) { |
| | STEEL_PRAGMA_UNROLL |
| | for (short i = 0; i < BROWS; i += TROWS) { |
| | STEEL_PRAGMA_UNROLL |
| | for (short j = 0; j < vec_size; j++) { |
| | dst[i * dst_ld + j] = T(0); |
| | } |
| | } |
| | return; |
| | } |
| |
|
| | |
| | bool tmp_idx[vec_size]; |
| | T tmp_val[vec_size]; |
| |
|
| | STEEL_PRAGMA_UNROLL |
| | for (short i = 0; i < BROWS; i += TROWS) { |
| | |
| | STEEL_PRAGMA_UNROLL |
| | for (short j = 0; j < vec_size; j++) { |
| | tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x); |
| | } |
| |
|
| | |
| | STEEL_PRAGMA_UNROLL |
| | for (short j = 0; j < vec_size; j++) { |
| | tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)]; |
| | } |
| |
|
| | |
| | STEEL_PRAGMA_UNROLL |
| | for (short j = 0; j < vec_size; j++) { |
| | tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0); |
| | } |
| |
|
| | |
| | STEEL_PRAGMA_UNROLL |
| | for (short j = 0; j < vec_size; j++) { |
| | dst[i * dst_ld + j] = tmp_val[j]; |
| | } |
| | } |
| | } |
| |
|
| | |
| | METAL_FUNC void next() { |
| | src += tile_stride; |
| | } |
| | }; |
| |
|
| | } |
| | } |
| |
|