File size: 3,805 Bytes
20347e1 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 | // Copyright © 2024 Apple Inc.
#pragma once
#include "gemm/defines.h"
///////////////////////////////////////////////////////////////////////////////
// Loading helper
///////////////////////////////////////////////////////////////////////////////
namespace mlx {
namespace steel {
template <
typename T,
short BROWS,
short BCOLS,
short dst_ld,
short reduction_dim,
short tgp_size,
short alignment = 1,
short n_reads = (BCOLS * BROWS) / (tgp_size),
short TCOLS = BCOLS / n_reads,
short TROWS = tgp_size / TCOLS>
struct BlockLoader {
STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
STEEL_CONST short vec_size = n_reads;
// Leading dimension for src
const int src_ld;
const int tile_stride;
// Thread location indices
const short thread_idx;
const short bi;
const short bj;
// threadgroup and device memory
threadgroup T* dst;
const device T* src;
struct alignas(alignment * sizeof(T)) ReadVector {
uint8_t v[sizeof(T) * vec_size];
};
/* Constructor */
METAL_FUNC BlockLoader(
const device T* src_,
const int src_ld_,
threadgroup T* dst_,
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
ushort simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(src_ld_),
tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / TCOLS),
bj(vec_size * (thread_idx % TCOLS)),
dst(dst_ + bi * dst_ld + bj),
src(src_ + bi * src_ld + bj) {}
/* Apply operation to threadgroup without bound checking */
template <typename UnaryOp>
METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = op.apply(dst[i * dst_ld + j]);
}
}
}
/* Load from device memory into threadgroup memory - without bound checking */
METAL_FUNC void load_unsafe() const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
*((threadgroup ReadVector*)(&dst[i * dst_ld])) =
*((const device ReadVector*)(&src[i * src_ld]));
}
}
/* Load from device memory into threadgroup memory - with bound checking */
METAL_FUNC void load_safe(short2 src_tile_dim) const {
src_tile_dim = src_tile_dim - short2(bj, bi);
// Skip loading if thread has no valid reads
if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = T(0);
}
}
return;
}
// Use fast thread memory for bound checks
bool tmp_idx[vec_size];
T tmp_val[vec_size];
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
// Make sure tmp_idx only contains valid indices
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
}
// Read valid indices into tmp_val
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
}
// Zero out unneeded values
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
}
// Copy values to threadgroup memory
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = tmp_val[j];
}
}
}
/* Iteration helper */
METAL_FUNC void next() {
src += tile_stride;
}
};
} // namespace steel
} // namespace mlx
|