| | |
| |
|
| | #pragma once |
| |
|
| | #include <metal_math> |
| |
|
| | #include "bf16.h" |
| | #include "defines.h" |
| |
|
// Alias so that code can refer to the 16-bit `half` type as `float16_t`.
using float16_t = half;
| |
|
| | |
| | |
// Compile-time count of elements of type U that add up to 8 bytes:
// `n` is 8 / sizeof(U). Types wider than 8 bytes are rejected at compile time.
template <typename U>
struct WorkPerThread {
  static_assert(sizeof(U) <= 8, "Type too large");
  static constexpr constant int n = 8 / sizeof(U);
};
| |
|
| | |
| | |
| | |
| |
|
// Generic bounds for a type U, taken from metal::numeric_limits.
// In this fallback the "finite" bounds are the same values as max/min.
template <typename U>
struct Limits {
  // Upper bounds.
  static const constant U max = metal::numeric_limits<U>::max();
  static const constant U finite_max = metal::numeric_limits<U>::max();
  // Lower bounds.
  static const constant U min = metal::numeric_limits<U>::min();
  static const constant U finite_min = metal::numeric_limits<U>::min();
};
| |
|
// Specializes Limits for `type` with every bound read directly from
// metal::numeric_limits<type>: max == finite_max and min == finite_min.
#define instantiate_default_limit(type)                                      \
  template <>                                                                \
  struct Limits<type> {                                                      \
    static constexpr constant type max = metal::numeric_limits<type>::max(); \
    static constexpr constant type finite_max =                              \
        metal::numeric_limits<type>::max();                                  \
    static constexpr constant type min = metal::numeric_limits<type>::min(); \
    static constexpr constant type finite_min =                              \
        metal::numeric_limits<type>::min();                                  \
  };

// Unsigned integer types.
instantiate_default_limit(uint8_t);
instantiate_default_limit(uint16_t);
instantiate_default_limit(uint32_t);
instantiate_default_limit(uint64_t);

// Signed integer types.
instantiate_default_limit(int8_t);
instantiate_default_limit(int16_t);
instantiate_default_limit(int32_t);
instantiate_default_limit(int64_t);
| |
|
// Specializes Limits for a floating-point `type`:
//   max / min               -> +infinity / -infinity
//   finite_max / finite_min -> +numeric_limits::max() / -numeric_limits::max()
#define instantiate_float_limit(type)             \
  template <>                                     \
  struct Limits<type> {                           \
    static constexpr constant type max =          \
        metal::numeric_limits<type>::infinity();  \
    static constexpr constant type finite_max =   \
        metal::numeric_limits<type>::max();       \
    static constexpr constant type min =          \
        -metal::numeric_limits<type>::infinity(); \
    static constexpr constant type finite_min =   \
        -metal::numeric_limits<type>::max();      \
  };

instantiate_float_limit(half);
instantiate_float_limit(float);
instantiate_float_limit(bfloat16_t);
| |
|
// Limits specialization for bool: `true` is the largest value and `false`
// the smallest.
template <>
struct Limits<bool> {
  static constexpr constant bool max = true;
  static constexpr constant bool min = false;
  // bool has no infinities, so the finite bounds are the same as max/min.
  // They are provided so that Limits<bool> exposes the same four members as
  // every other Limits specialization and generic code can use finite_*.
  static constexpr constant bool finite_max = true;
  static constexpr constant bool finite_min = false;
};
| |
|
| | |
| |
|
| | |
| | |
| | |
| |
|
// Expands to the clang loop pragma requesting full unrolling; intended to be
// written immediately before the loop it applies to.
#define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
| |
|
| | |
| | |
| |
|
// Converts a flat element index into a strided offset.
//
// Dimensions are peeled from the last to the first: the remainder of `elem`
// by the dimension's extent selects the coordinate, which is scaled by that
// dimension's stride; `elem` is then divided by the extent. The walk stops as
// soon as `elem` reaches zero or all `ndim` dimensions have been consumed.
template <typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc(
    IdxT elem,
    constant const int* shape,
    constant const int64_t* strides,
    int ndim) {
  IdxT offset = 0;
  int axis = ndim - 1;
  while (axis >= 0 && elem > 0) {
    offset += (elem % shape[axis]) * IdxT(strides[axis]);
    elem /= shape[axis];
    --axis;
  }
  return offset;
}
| |
|
| | |
// Converts a 3-component element index into a strided offset.
//
// elem.x is scaled by the stride of the last dimension and elem.y by the
// stride of the second-to-last one; both are read unconditionally, so `ndim`
// must be at least 2. elem.z is treated as a flat index over the remaining
// leading dimensions and is decomposed from dimension ndim - 3 down to 0.
template <typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc(
    uint3 elem,
    constant const int* shape,
    constant const int64_t* strides,
    int ndim) {
  IdxT offset =
      elem.x * IdxT(strides[ndim - 1]) + elem.y * IdxT(strides[ndim - 2]);
  uint remaining = elem.z;
  int axis = ndim - 3;
  while (axis >= 0) {
    offset += (remaining % shape[axis]) * IdxT(strides[axis]);
    remaining /= shape[axis];
    --axis;
  }
  return offset;
}
| |
|
| | |
| | |
| |
|
// One-dimensional case: the offset is simply the element index times the
// single stride.
template <typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc_1(uint elem, constant const int64_t& stride) {
  const IdxT step = IdxT(stride);
  return elem * step;
}
| |
|
// Two-dimensional case: elem.x pairs with strides[1] and elem.y with
// strides[0].
template <typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc_2(uint2 elem, constant const int64_t strides[2]) {
  const IdxT stride_x = IdxT(strides[1]);
  const IdxT stride_y = IdxT(strides[0]);
  return elem.x * stride_x + elem.y * stride_y;
}
| |
|
// Three-dimensional case: elem.x pairs with strides[2], elem.y with
// strides[1] and elem.z with strides[0].
template <typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc_3(uint3 elem, constant const int64_t strides[3]) {
  const IdxT stride_x = IdxT(strides[2]);
  const IdxT stride_y = IdxT(strides[1]);
  const IdxT stride_z = IdxT(strides[0]);
  return elem.x * stride_x + elem.y * stride_y + elem.z * stride_z;
}
| |
|
| | |
| | |
| |
|
// Computes strided offsets for two stride sets that share one shape.
//
// The result's .x is the offset under `a_strides` and .y the offset under
// `b_strides`. elem.x / elem.y address the last / second-to-last dimension
// (so `ndim` must be at least 2); elem.z is a flat index over the remaining
// leading dimensions, decomposed from dimension ndim - 3 down to 0.
template <typename IdxT = int64_t>
METAL_FUNC vec<IdxT, 2> elem_to_loc_2_nd(
    uint3 elem,
    constant const int* shape,
    constant const int64_t* a_strides,
    constant const int64_t* b_strides,
    int ndim) {
  const int last = ndim - 1;
  const int second_last = ndim - 2;
  vec<IdxT, 2> loc = {
      IdxT(
          elem.x * IdxT(a_strides[last]) +
          IdxT(elem.y) * IdxT(a_strides[second_last])),
      IdxT(
          elem.x * IdxT(b_strides[last]) +
          elem.y * IdxT(b_strides[second_last]))};

  // Peel the leading dimensions out of the flat z index, innermost first.
  uint remaining = elem.z;
  int axis = ndim - 3;
  while (axis >= 0) {
    const uint coord = remaining % shape[axis];
    loc.x += coord * IdxT(a_strides[axis]);
    loc.y += coord * IdxT(b_strides[axis]);
    remaining /= shape[axis];
    --axis;
  }
  return loc;
}
| |
|
// Computes strided offsets for three stride sets that share one shape.
//
// The result's .x / .y / .z are the offsets under `a_strides` / `b_strides` /
// `c_strides`. elem.x / elem.y address the last / second-to-last dimension
// (so `ndim` must be at least 2); elem.z is a flat index over the remaining
// leading dimensions, decomposed from dimension ndim - 3 down to 0.
template <typename IdxT = int64_t>
METAL_FUNC vec<IdxT, 3> elem_to_loc_3_nd(
    uint3 elem,
    constant const int* shape,
    constant const int64_t* a_strides,
    constant const int64_t* b_strides,
    constant const int64_t* c_strides,
    int ndim) {
  const int last = ndim - 1;
  const int second_last = ndim - 2;
  vec<IdxT, 3> loc = {
      IdxT(elem.x * IdxT(a_strides[last])) +
          IdxT(elem.y * IdxT(a_strides[second_last])),
      IdxT(elem.x * IdxT(b_strides[last])) +
          IdxT(elem.y * IdxT(b_strides[second_last])),
      IdxT(elem.x * IdxT(c_strides[last])) +
          IdxT(elem.y * IdxT(c_strides[second_last]))};

  // Peel the leading dimensions out of the flat z index, innermost first.
  uint remaining = elem.z;
  int axis = ndim - 3;
  while (axis >= 0) {
    const uint coord = remaining % shape[axis];
    loc.x += coord * IdxT(a_strides[axis]);
    loc.y += coord * IdxT(b_strides[axis]);
    loc.z += coord * IdxT(c_strides[axis]);
    remaining /= shape[axis];
    --axis;
  }
  return loc;
}
| |
|
| | |
| | |
| | |
| |
|
// Incrementally tracks the strided offset of an element while stepping
// through it, instead of recomputing the offset from scratch on every step.
//
// This level handles dimension `dim - 1` (`index` is the coordinate along it)
// and delegates all earlier dimensions to `inner_looper`, an instance of the
// same template with DIM - 1. When `index` runs past shape[dim - 1], it is
// reset, the inner looper is advanced, and `offset` restarts from the inner
// looper's offset.
template <int DIM, typename OffsetT = size_t, bool General = true>
struct LoopedElemToLoc {
  int dim;
  LoopedElemToLoc<DIM - 1, OffsetT, General> inner_looper;
  OffsetT offset{0};
  int index{0};

  LoopedElemToLoc(int dim) : dim(dim), inner_looper(dim - 1) {}

  // Advances by exactly one element.
  void next(const constant int* shape, const constant int64_t* strides) {
    if (dim == 0) {
      return;
    }
    const int extent = shape[dim - 1];
    ++index;
    offset += OffsetT(strides[dim - 1]);
    if (index < extent) {
      return;
    }
    // Wrapped around this dimension: carry into the inner dimensions.
    index = 0;
    inner_looper.next(shape, strides);
    offset = inner_looper.offset;
  }

  // Advances by `n` elements.
  void next(int n, const constant int* shape, const constant int64_t* strides) {
    if (dim == 0) {
      return;
    }
    const int extent = shape[dim - 1];
    index += n;
    offset += n * OffsetT(strides[dim - 1]);
    if (index < extent) {
      return;
    }

    // `overshoot` is how far past the end of this dimension we stepped.
    int overshoot = index - extent;
    if (overshoot < extent) {
      // Exactly one wrap of this dimension.
      inner_looper.next(shape, strides);
    } else {
      // Several wraps: carry them all into the inner looper at once and keep
      // only the remainder for this dimension.
      inner_looper.next(1 + overshoot / extent, shape, strides);
      overshoot = overshoot % extent;
    }
    index = 0;
    offset = inner_looper.offset;
    if (overshoot > 0) {
      // Re-apply the leftover steps along this dimension.
      next(overshoot, shape, strides);
    }
  }

  // Current offset.
  OffsetT location() {
    return offset;
  }
};
| |
|
// Innermost level of the looper for the General == true case.
//
// `index` counts the elements stepped so far. When this level was constructed
// with dim > 1, the offset is recomputed from `index` with elem_to_loc over
// `dim` dimensions; otherwise only strides[0] is used.
template <typename OffsetT>
struct LoopedElemToLoc<1, OffsetT, true> {
  int dim;
  OffsetT offset{0};
  uint index{0};

  LoopedElemToLoc(int dim) : dim(dim) {}

  // Advances by exactly one element.
  void next(const constant int* shape, const constant int64_t* strides) {
    ++index;
    if (dim <= 1) {
      offset += OffsetT(strides[0]);
    } else {
      offset = elem_to_loc<OffsetT>(index, shape, strides, dim);
    }
  }

  // Advances by `n` elements.
  void next(int n, const constant int* shape, const constant int64_t* strides) {
    index += n;
    if (dim <= 1) {
      offset = index * OffsetT(strides[0]);
    } else {
      offset = elem_to_loc<OffsetT>(index, shape, strides, dim);
    }
  }

  // Current offset.
  OffsetT location() {
    return offset;
  }
};
| |
|
// Innermost level of the looper for the General == false case: the shape and
// the constructor argument are ignored and the offset only ever accumulates
// multiples of strides[0].
template <typename OffsetT>
struct LoopedElemToLoc<1, OffsetT, false> {
  OffsetT offset{0};

  LoopedElemToLoc(int) {}

  // Advances by exactly one element.
  void next(const constant int*, const constant int64_t* strides) {
    const OffsetT step = OffsetT(strides[0]);
    offset += step;
  }

  // Advances by `n` elements.
  void next(int n, const constant int*, const constant int64_t* strides) {
    const OffsetT step = OffsetT(strides[0]);
    offset += n * step;
  }

  // Current offset.
  OffsetT location() {
    return offset;
  }
};
| |
|
| | |
| | |
| | |
| |
|
| | |
// Division rounded up, computed as (N + M - 1) / M and returned as T.
// The intermediate sum is formed in the usual promoted type of N and M, so it
// can overflow when N is close to that type's maximum.
template <typename T, typename U>
inline T ceildiv(T N, U M) {
  const auto bumped = N + M - 1;
  return bumped / M;
}
| |
|
| | |
// log(1 + x) for float, written to stay accurate when x is tiny.
//
// - If 1 + x equals Limits<float>::max, that value is returned as-is.
// - If 1 + x rounds to exactly 1, x itself is returned.
// - Otherwise log(1 + x) is rescaled by x / ((1 + x) - 1), which compensates
//   for the rounding error made when forming 1 + x.
inline float log1p(float x) {
  const float xp1 = 1.0f + x;
  if (xp1 == Limits<float>::max) {
    return Limits<float>::max;
  }
  if (xp1 == 1.0f) {
    return x;
  }

  const float correction = metal::log(xp1) / (xp1 - 1.0f);
  return x * correction;
}
| |
|
// log(1 + x) for bfloat16_t; the sum 1 + x and the logarithm are evaluated
// in float.
//
// - If 1 + x equals Limits<float>::max, Limits<bfloat16_t>::max is returned.
// - If 1 + x rounds to exactly 1 in float, x itself is returned.
// - Otherwise log(1 + x) is rescaled by x / ((1 + x) - 1) and the result is
//   converted back to bfloat16_t.
inline bfloat16_t log1p(bfloat16_t x) {
  const float xp1 = 1.0f + static_cast<float>(x);
  if (xp1 == Limits<float>::max) {
    return Limits<bfloat16_t>::max;
  }
  if (xp1 == 1.0f) {
    return x;
  }

  const float correction = metal::log(xp1) / (xp1 - 1.0f);
  return bfloat16_t(x * correction);
}
| |
|
| | |
| | |
| | |
| |
|
// simd_shuffle_down overloads for 64-bit integers and bool.
//
// The 64-bit variants reinterpret the value as a uint2 with as_type, shuffle
// that through metal::simd_shuffle_down, and reinterpret the result back.
inline uint64_t simd_shuffle_down(uint64_t data, uint16_t delta) {
  const uint2 halves = as_type<uint2>(data);
  return as_type<uint64_t>(metal::simd_shuffle_down(halves, delta));
}

inline int64_t simd_shuffle_down(int64_t data, uint16_t delta) {
  const uint2 halves = as_type<uint2>(data);
  return as_type<int64_t>(metal::simd_shuffle_down(halves, delta));
}

// The bool variant widens to uint32_t for the shuffle; the shuffled value is
// converted back to bool on return.
inline bool simd_shuffle_down(bool data, uint16_t delta) {
  const uint32_t widened = static_cast<uint32_t>(data);
  return simd_shuffle_down(widened, delta);
}
| |
|
// simd_shuffle_up overloads for 64-bit integers and bool.
//
// The 64-bit variants reinterpret the value as a uint2 with as_type, shuffle
// that through metal::simd_shuffle_up, and reinterpret the result back.
inline uint64_t simd_shuffle_up(uint64_t data, uint16_t delta) {
  const uint2 halves = as_type<uint2>(data);
  return as_type<uint64_t>(metal::simd_shuffle_up(halves, delta));
}

inline int64_t simd_shuffle_up(int64_t data, uint16_t delta) {
  const uint2 halves = as_type<uint2>(data);
  return as_type<int64_t>(metal::simd_shuffle_up(halves, delta));
}

// The bool variant widens to uint32_t for the shuffle; the shuffled value is
// converted back to bool on return.
inline bool simd_shuffle_up(bool data, uint16_t delta) {
  const uint32_t widened = static_cast<uint32_t>(data);
  return simd_shuffle_up(widened, delta);
}
| |
|
// simd_shuffle_and_fill_up overloads for 64-bit integers and bool.
//
// The 64-bit variants reinterpret both `data` and `filling` as uint2 with
// as_type, pass them to metal::simd_shuffle_and_fill_up, and reinterpret the
// result back.
inline uint64_t
simd_shuffle_and_fill_up(uint64_t data, uint64_t filling, uint16_t delta) {
  const uint2 data_halves = as_type<uint2>(data);
  const uint2 fill_halves = as_type<uint2>(filling);
  return as_type<uint64_t>(
      metal::simd_shuffle_and_fill_up(data_halves, fill_halves, delta));
}

inline int64_t
simd_shuffle_and_fill_up(int64_t data, int64_t filling, uint16_t delta) {
  const uint2 data_halves = as_type<uint2>(data);
  const uint2 fill_halves = as_type<uint2>(filling);
  return as_type<int64_t>(
      metal::simd_shuffle_and_fill_up(data_halves, fill_halves, delta));
}

// The bool variant widens both arguments to uint32_t for the shuffle; the
// shuffled value is converted back to bool on return.
inline bool simd_shuffle_and_fill_up(bool data, bool filling, uint16_t delta) {
  const uint32_t wide_data = static_cast<uint32_t>(data);
  const uint32_t wide_fill = static_cast<uint32_t>(filling);
  return simd_shuffle_and_fill_up(wide_data, wide_fill, delta);
}
| |
|
// simd_shuffle overloads for 64-bit integers and bool.
//
// The 64-bit variants reinterpret the value as a uint2 with as_type, shuffle
// that through metal::simd_shuffle, and reinterpret the result back.
inline uint64_t simd_shuffle(uint64_t data, uint16_t lane) {
  const uint2 halves = as_type<uint2>(data);
  return as_type<uint64_t>(metal::simd_shuffle(halves, lane));
}

inline int64_t simd_shuffle(int64_t data, uint16_t lane) {
  const uint2 halves = as_type<uint2>(data);
  return as_type<int64_t>(metal::simd_shuffle(halves, lane));
}

// The bool variant widens to uint32_t for the shuffle; the shuffled value is
// converted back to bool on return.
inline bool simd_shuffle(bool data, uint16_t lane) {
  const uint32_t widened = static_cast<uint32_t>(data);
  return simd_shuffle(widened, lane);
}
| |
|
| | |
// Compile-time type selector: `ConditionalType<c, T, U>::type` is T when `c`
// is true and U when `c` is false.
template <bool condition, typename T, typename U>
struct ConditionalType {
  using type = T;
};

// `false` selects the second alternative.
template <typename T, typename U>
struct ConditionalType<false, T, U> {
  using type = U;
};