// Copyright © 2023-2024 Apple Inc.

#pragma once

#include <metal_math>

#include "bf16.h"
#include "defines.h"

typedef half float16_t;

// Work per thread values for different types. The values here are expected to
// match get_work_per_thread in mlx/backend/metal/utils.h
//
// `n` is the number of U elements that fit in 8 bytes (8 for 1-byte types,
// 4 for 2-byte types, 2 for 4-byte types, 1 for 8-byte types). Types larger
// than 8 bytes are rejected at compile time.
template <typename U>
struct WorkPerThread {
  static_assert(sizeof(U) <= 8, "Type too large");
  static constexpr int constant n = 8 / sizeof(U);
};

///////////////////////////////////////////////////////////////////////////////
// Type limits utils
///////////////////////////////////////////////////////////////////////////////

// Extreme values of U.
//   max / min               : largest / smallest representable value
//                             (+/- infinity for the floating-point types).
//   finite_max / finite_min : largest / smallest finite value.
//
// The primary template simply forwards to metal::numeric_limits. Note that
// for floating-point U, numeric_limits<U>::min() is the smallest positive
// normal value, not the most negative one -- which is why half, float and
// bfloat16_t get dedicated specializations below.
template <typename U>
struct Limits {
  static const constant U max = metal::numeric_limits<U>::max();
  static const constant U min = metal::numeric_limits<U>::min();
  static const constant U finite_max = metal::numeric_limits<U>::max();
  static const constant U finite_min = metal::numeric_limits<U>::min();
};

// Integer specialization: every value is finite, so finite_* == max/min.
#define instantiate_default_limit(type)                                      \
  template <>                                                                \
  struct Limits<type> {                                                      \
    static constexpr constant type max = metal::numeric_limits<type>::max(); \
    static constexpr constant type min = metal::numeric_limits<type>::min(); \
    static constexpr constant type finite_max =                              \
        metal::numeric_limits<type>::max();                                  \
    static constexpr constant type finite_min =                              \
        metal::numeric_limits<type>::min();                                  \
  };

instantiate_default_limit(uint8_t);
instantiate_default_limit(uint16_t);
instantiate_default_limit(uint32_t);
instantiate_default_limit(uint64_t);
instantiate_default_limit(int8_t);
instantiate_default_limit(int16_t);
instantiate_default_limit(int32_t);
instantiate_default_limit(int64_t);

// Floating-point specialization: max/min are +/- infinity, and the finite
// bounds are +/- the largest finite magnitude (not numeric_limits::min(),
// which would be the smallest positive normal value).
#define instantiate_float_limit(type)             \
  template <>                                     \
  struct Limits<type> {                           \
    static constexpr constant type max =          \
        metal::numeric_limits<type>::infinity();  \
    static constexpr constant type min =          \
        -metal::numeric_limits<type>::infinity(); \
    static constexpr constant type finite_max =   \
        metal::numeric_limits<type>::max();       \
    static constexpr constant type finite_min =   \
        -metal::numeric_limits<type>::max();      \
  };

instantiate_float_limit(half);
instantiate_float_limit(float);
instantiate_float_limit(bfloat16_t);
template <>
struct Limits<bool> {
  static constexpr constant bool max = true;
  static constexpr constant bool min = false;
  // Every other Limits specialization exposes finite_max/finite_min; provide
  // them for bool too so generic code can use them uniformly. All bool values
  // are finite, so they equal max/min.
  static constexpr constant bool finite_max = true;
  static constexpr constant bool finite_min = false;
};

// complex64_t specialization removed - not needed for BnB kernels

///////////////////////////////////////////////////////////////////////////////
// Indexing utils
///////////////////////////////////////////////////////////////////////////////

#define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")

///////////////////////////////////////////////////////////////////////////////
// Single Array with generic dims

// Maps a linear element index to a memory offset for an `ndim`-dimensional
// array described by `shape` and `strides`. The index is decomposed in
// row-major order (the last dimension varies fastest). The loop stops early
// once the remaining index is zero, since the leading dimensions then
// contribute nothing to the offset.
template <typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc(
    IdxT elem,
    constant const int* shape,
    constant const int64_t* strides,
    int ndim) {
  IdxT loc = 0;
  for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
    loc += (elem % shape[i]) * IdxT(strides[i]);
    elem /= shape[i];
  }
  return loc;
}

// Overload taking a 3D position: elem.x indexes the last dimension, elem.y
// the second-to-last, and elem.z is decomposed (row-major) over the remaining
// leading dimensions [0, ndim - 2). It reads strides[ndim - 2]
// unconditionally, so ndim must be at least 2.
template <typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc(
    uint3 elem,
    constant const int* shape,
    constant const int64_t* strides,
    int ndim) {
  IdxT loc =
      elem.x * IdxT(strides[ndim - 1]) + elem.y * IdxT(strides[ndim - 2]);
  for (int d = ndim - 3; d >= 0; --d) {
    loc += (elem.z % shape[d]) * IdxT(strides[d]);
    elem.z /= shape[d];
  }
  return loc;
}

///////////////////////////////////////////////////////////////////////////////
// Single Array with fixed N dims

// 1D offset: elem * stride.
template <typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc_1(uint elem, constant const int64_t& stride) {
  return elem * IdxT(stride);
}

// 2D offset: elem.x pairs with strides[1] (last dim), elem.y with strides[0].
template <typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc_2(uint2 elem, constant const int64_t strides[2]) {
  return elem.x * IdxT(strides[1]) + elem.y * IdxT(strides[0]);
}

// 3D offset: elem.x/y/z pair with strides[2]/[1]/[0] respectively.
template <typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc_3(uint3 elem, constant const int64_t strides[3]) {
  return elem.x * IdxT(strides[2]) + elem.y * IdxT(strides[1]) +
      elem.z * IdxT(strides[0]);
}

///////////////////////////////////////////////////////////////////////////////
// Multiple Arrays with generic dims

// Same decomposition as the uint3 elem_to_loc overload, applied to two arrays
// that share `shape` but have their own strides. Returns (a_offset, b_offset).
// Requires ndim >= 2.
template <typename IdxT = int64_t>
METAL_FUNC vec<IdxT, 2> elem_to_loc_2_nd(
    uint3 elem,
    constant const int* shape,
    constant const int64_t* a_strides,
    constant const int64_t* b_strides,
    int ndim) {
  vec<IdxT, 2> loc = {
      IdxT(
          elem.x * IdxT(a_strides[ndim - 1]) +
          IdxT(elem.y) * IdxT(a_strides[ndim - 2])),
      IdxT(
          elem.x * IdxT(b_strides[ndim - 1]) +
          elem.y * IdxT(b_strides[ndim - 2]))};
  for (int d = ndim - 3; d >= 0; --d) {
    uint l = elem.z % shape[d];
    loc.x += l * IdxT(a_strides[d]);
    loc.y += l * IdxT(b_strides[d]);
    elem.z /= shape[d];
  }
  return loc;
}

// Three-array version of elem_to_loc_2_nd. Returns (a_offset, b_offset,
// c_offset). Requires ndim >= 2.
template <typename IdxT = int64_t>
METAL_FUNC vec<IdxT, 3> elem_to_loc_3_nd(
    uint3 elem,
    constant const int* shape,
    constant const int64_t* a_strides,
    constant const int64_t* b_strides,
    constant const int64_t* c_strides,
    int ndim) {
  vec<IdxT, 3> loc = {
      IdxT(elem.x * IdxT(a_strides[ndim - 1])) +
          IdxT(elem.y * IdxT(a_strides[ndim - 2])),
      IdxT(elem.x * IdxT(b_strides[ndim - 1])) +
          IdxT(elem.y * IdxT(b_strides[ndim - 2])),
      IdxT(elem.x * IdxT(c_strides[ndim - 1])) +
          IdxT(elem.y * IdxT(c_strides[ndim - 2]))};
  for (int d = ndim - 3; d >= 0; --d) {
    uint l = elem.z % shape[d];
    loc.x += l * IdxT(a_strides[d]);
    loc.y += l * IdxT(b_strides[d]);
    loc.z += l * IdxT(c_strides[d]);
    elem.z /= shape[d];
  }
  return loc;
}

///////////////////////////////////////////////////////////////////////////////
// Elem to loc in a loop utils
///////////////////////////////////////////////////////////////////////////////

// Tracks the memory offset of a strided array while stepping through its
// elements in row-major order. The offset is updated incrementally by strides
// rather than recomputed from a linear index, except in the DIM == 1 General
// fallback below.
//
// Each level owns axis `dim - 1`. next() bumps that axis' index and adds its
// stride. When the index reaches shape[dim - 1], it resets to 0, the nested
// looper (covering axes [0, dim - 1)) is advanced, and the offset is taken
// from it. DIM is the compile-time nesting depth; the DIM == 1
// specializations terminate the recursion. General selects whether that
// terminal level may still cover several axes (true) or exactly one (false).
template <int DIM, typename OffsetT = size_t, bool General = true>
struct LoopedElemToLoc {
  int dim;
  LoopedElemToLoc<DIM - 1, OffsetT, General> inner_looper;
  OffsetT offset{0};
  int index{0};

  LoopedElemToLoc(int dim) : dim(dim), inner_looper(dim - 1) {}

  // Advance by one element. No-op when dim == 0.
  void next(const constant int* shape, const constant int64_t* strides) {
    if (dim == 0) {
      return;
    }
    index++;
    offset += OffsetT(strides[dim - 1]);
    if (index >= shape[dim - 1]) {
      index = 0;
      inner_looper.next(shape, strides);
      offset = inner_looper.offset;
    }
  }

  // Advance by n elements. No-op when dim == 0. Assumes n >= 0 (TODO confirm
  // with callers).
  void next(int n, const constant int* shape, const constant int64_t* strides) {
    if (dim == 0) {
      return;
    }
    index += n;
    offset += n * OffsetT(strides[dim - 1]);
    if (index >= shape[dim - 1]) {
      int extra = index - shape[dim - 1];
      if (extra >= shape[dim - 1]) {
        // Passed more than one full extent of this axis: carry all wraps into
        // the nested looper with a single call.
        inner_looper.next(1 + extra / shape[dim - 1], shape, strides);
        extra = extra % shape[dim - 1];
      } else {
        inner_looper.next(shape, strides);
      }
      index = 0;
      offset = inner_looper.offset;
      // Re-apply the leftover steps on this axis. Here extra < shape[dim - 1],
      // so this call does not wrap again.
      if (extra > 0) {
        next(extra, shape, strides);
      }
    }
  }

  OffsetT location() {
    return offset;
  }
};

// Terminal level for General == true. If it still spans several axes
// (dim > 1), the offset is recomputed from the running linear index with
// elem_to_loc. Otherwise it steps along strides[0].
template <typename OffsetT>
struct LoopedElemToLoc<1, OffsetT, true> {
  int dim;
  OffsetT offset{0};
  uint index{0};

  LoopedElemToLoc(int dim) : dim(dim) {}

  void next(const constant int* shape, const constant int64_t* strides) {
    index++;
    if (dim > 1) {
      offset = elem_to_loc<OffsetT>(index, shape, strides, dim);
    } else {
      offset += OffsetT(strides[0]);
    }
  }

  void next(int n, const constant int* shape, const constant int64_t* strides) {
    index += n;
    if (dim > 1) {
      offset = elem_to_loc<OffsetT>(index, shape, strides, dim);
    } else {
      offset = index * OffsetT(strides[0]);
    }
  }

  OffsetT location() {
    return offset;
  }
};

// Terminal level for General == false: a single axis, stepping along
// strides[0] only. shape is ignored.
template <typename OffsetT>
struct LoopedElemToLoc<1, OffsetT, false> {
  OffsetT offset{0};

  LoopedElemToLoc(int) {}

  void next(const constant int*, const constant int64_t* strides) {
    offset += OffsetT(strides[0]);
  }

  void next(int n, const constant int*, const constant int64_t* strides) {
    offset += n * OffsetT(strides[0]);
  }

  OffsetT location() {
    return offset;
  }
};

///////////////////////////////////////////////////////////////////////////////
// Calculation utils
///////////////////////////////////////////////////////////////////////////////

/** Compute ceil((float)N/(float)M) with integer arithmetic.
 *  Exact when N >= 0 and M > 0. N + M - 1 may overflow T near its maximum. */
template <typename T, typename U>
inline T ceildiv(T N, U M) {
  return (N + M - 1) / M;
}

// Accurate log(1 + x) for small |x|. Forming 1 + x rounds away low-order bits
// of x; multiplying log(xp1) by x / (xp1 - 1) compensates for that rounding
// (see the Goldberg reference).
// https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202
inline float log1p(float x) {
  float xp1 = 1.0f + x;
  if (xp1 == Limits<float>::max) {
    // 1 + x is +infinity (Limits<float>::max is +inf). Return it directly,
    // since the formula below would give inf / inf = NaN.
    return Limits<float>::max;
  }
  if (xp1 == 1.0f) {
    // |x| is so small that 1 + x rounds to 1; log1p(x) ~= x here.
    return x;
  }
  return x * (metal::log(xp1) / (xp1 - 1.0f));
}

// bfloat16_t variant of log1p: 1 + x and the log are computed in float, and
// the result is converted back to bfloat16_t.
inline bfloat16_t log1p(bfloat16_t x) {
  float xp1 = 1.0f + static_cast<float>(x);
  if (xp1 == Limits<float>::max) {
    return Limits<bfloat16_t>::max;
  }
  if (xp1 == 1.0f) {
    return x;
  }
  return bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f)));
}

///////////////////////////////////////////////////////////////////////////////
// SIMD shuffle ops
///////////////////////////////////////////////////////////////////////////////

// 64-bit overloads: the value is reinterpreted as uint2 (two 32-bit halves)
// with as_type, shuffled as that vector, and reinterpreted back.
// bool overloads: the value is widened to uint32_t, shuffled, and converted
// back to bool on return. The inner call is qualified with metal:: so that it
// does not depend on `using namespace metal` being in scope. Unqualified, a
// uint32_t argument would be ambiguous among the local uint64_t / int64_t /
// bool overloads.

inline uint64_t simd_shuffle_down(uint64_t data, uint16_t delta) {
  return as_type<uint64_t>(
      metal::simd_shuffle_down(as_type<uint2>(data), delta));
}

inline int64_t simd_shuffle_down(int64_t data, uint16_t delta) {
  return as_type<int64_t>(
      metal::simd_shuffle_down(as_type<uint2>(data), delta));
}

inline bool simd_shuffle_down(bool data, uint16_t delta) {
  return metal::simd_shuffle_down(static_cast<uint32_t>(data), delta);
}

inline uint64_t simd_shuffle_up(uint64_t data, uint16_t delta) {
  return as_type<uint64_t>(
      metal::simd_shuffle_up(as_type<uint2>(data), delta));
}

inline int64_t simd_shuffle_up(int64_t data, uint16_t delta) {
  return as_type<int64_t>(metal::simd_shuffle_up(as_type<uint2>(data), delta));
}

inline bool simd_shuffle_up(bool data, uint16_t delta) {
  return metal::simd_shuffle_up(static_cast<uint32_t>(data), delta);
}

inline uint64_t
simd_shuffle_and_fill_up(uint64_t data, uint64_t filling, uint16_t delta) {
  return as_type<uint64_t>(metal::simd_shuffle_and_fill_up(
      as_type<uint2>(data), as_type<uint2>(filling), delta));
}

inline int64_t
simd_shuffle_and_fill_up(int64_t data, int64_t filling, uint16_t delta) {
  return as_type<int64_t>(metal::simd_shuffle_and_fill_up(
      as_type<uint2>(data), as_type<uint2>(filling), delta));
}

inline bool simd_shuffle_and_fill_up(bool data, bool filling, uint16_t delta) {
  return metal::simd_shuffle_and_fill_up(
      static_cast<uint32_t>(data), static_cast<uint32_t>(filling), delta);
}

inline uint64_t simd_shuffle(uint64_t data, uint16_t lane) {
  return as_type<uint64_t>(metal::simd_shuffle(as_type<uint2>(data), lane));
}

inline int64_t simd_shuffle(int64_t data, uint16_t lane) {
  return as_type<int64_t>(metal::simd_shuffle(as_type<uint2>(data), lane));
}

inline bool simd_shuffle(bool data, uint16_t lane) {
  return metal::simd_shuffle(static_cast<uint32_t>(data), lane);
}

// std::conditional is not included with Metal.
// ConditionalType<condition, T, U>::type is T when condition is true, else U.
template <bool condition, typename T, typename U>
struct ConditionalType {
  using type = U;
};

template <typename T, typename U>
struct ConditionalType<true, T, U> {
  using type = T;
};