| | |
| |
|
| | #pragma once |
| |
|
| | #include <metal_stdlib> |
| | #include "gemm/utils/type_traits.h" |
| |
|
| | #pragma METAL internals : enable |
| |
|
| | namespace mlx { |
| | namespace steel { |
| |
|
| | |
| | |
| | |
| |
|
///////////////////////////////////////////////////////////////////////////////
// Integral constant with casting
///////////////////////////////////////////////////////////////////////////////

/// Compile-time constant of type `T` holding the value `v`.
///
/// Exposes the constant as the static member `value` (declared in the Metal
/// `constant` address space), names its own value type and itself through the
/// `value_type` / `type` aliases, and converts implicitly to `value_type` so an
/// instance can be used wherever a plain `T` is expected.
template <typename T, T v>
struct integral_constant {
  using value_type = T;
  using type = integral_constant<T, v>;

  static constexpr constant value_type value = v;

  // Implicit conversion: lets e.g. `Int<4>{}` stand in for the int 4.
  METAL_FUNC constexpr operator value_type() const noexcept {
    return value;
  }
};
| |
|
/// Boolean flavour of `integral_constant`, used as a compile-time true/false
/// tag type.
template <bool Flag>
using bool_constant = integral_constant<bool, Flag>;

// Tag types for the two boolean values.
using true_type = bool_constant<true>;
using false_type = bool_constant<false>;
| |
|
/// Trait: `value` is true when `Ty` is integral according to
/// `metal::is_integral`.
template <typename Ty>
struct is_integral : bool_constant<metal::is_integral<Ty>::value> {};

/// An `integral_constant<ValueT, kValue>` is treated as integral whenever its
/// underlying `ValueT` is, so compile-time constants satisfy the same checks
/// as the plain values they wrap.
template <typename ValueT, ValueT kValue>
struct is_integral<integral_constant<ValueT, kValue>>
    : bool_constant<metal::is_integral<ValueT>::value> {};

/// Shorthand for `is_integral<Ty>::value`.
template <typename Ty>
constexpr constant bool is_integral_v = is_integral<Ty>::value;
| |
|
/// Shorthand for an `int`-valued compile-time constant, e.g. `Int<4>{}`.
template <int N>
using Int = integral_constant<int, N>;
| |
|
| | |
| | |
| | |
| |
|
///////////////////////////////////////////////////////////////////////////////
// Binary operators on integral constants
//
// Each generated operator folds two integral_constants at compile time and
// returns the result wrapped in a new integral_constant.
//
// The result type is taken from the expression `tv op uv` (a prvalue, so its
// decltype is the plain unqualified type, e.g. `int` or `bool`). It must NOT
// be taken from `decltype(res)`: `res` is a constexpr -- hence const --
// variable, so that would yield `const int` / `const bool` and produce e.g.
// `integral_constant<const bool, true>` instead of `true_type`. That made
// `Int<2>{} + Int<3>{}` a different type from `Int<5>` and prevented
// comparison results from matching the `true_type` / `false_type`
// short-circuit overloads.
///////////////////////////////////////////////////////////////////////////////

#define integral_const_binop(__op__, __operator__)          \
  template <typename T, T tv, typename U, U uv>             \
  METAL_FUNC constexpr auto __operator__(                   \
      integral_constant<T, tv>, integral_constant<U, uv>) { \
    using result_type = decltype(tv __op__ uv);             \
    constexpr result_type res = tv __op__ uv;               \
    return integral_constant<result_type, res>{};           \
  }

// Arithmetic
integral_const_binop(+, operator+);
integral_const_binop(-, operator-);
integral_const_binop(*, operator*);
integral_const_binop(/, operator/);

// Comparisons (yield true_type / false_type)
integral_const_binop(==, operator==);
integral_const_binop(!=, operator!=);
integral_const_binop(<, operator<);
integral_const_binop(>, operator>);
integral_const_binop(<=, operator<=);
integral_const_binop(>=, operator>=);

// Logical
integral_const_binop(&&, operator&&);
integral_const_binop(||, operator||);
| |
|
// Short-circuit overloads mixing a compile-time boolean with an operand that
// is not integral (as judged by `is_integral_v`):
//   true_type || x,  x || true_type   -> always true_type
//   false_type && x, x && false_type  -> always false_type
// The enable_if excludes integral operands. Without it, e.g.
// `true_type || true_type` would match both these overloads and the
// integral_constant `operator||` template, which is ambiguous.

template <typename Other, typename = metal::enable_if_t<!is_integral_v<Other>>>
METAL_FUNC constexpr true_type operator||(true_type, Other) {
  return {};
}

template <typename Other, typename = metal::enable_if_t<!is_integral_v<Other>>>
METAL_FUNC constexpr true_type operator||(Other, true_type) {
  return {};
}

template <typename Other, typename = metal::enable_if_t<!is_integral_v<Other>>>
METAL_FUNC constexpr false_type operator&&(false_type, Other) {
  return {};
}

template <typename Other, typename = metal::enable_if_t<!is_integral_v<Other>>>
METAL_FUNC constexpr false_type operator&&(Other, false_type) {
  return {};
}
| |
|
| | |
// Dispatch utilities

/// Lifts a runtime bool into a type: calls `f` exactly once, with
/// `false_type{}` when `v` is false and with `true_type{}` otherwise.
template <typename F>
void dispatch_bool(bool v, F f) {
  if (!v) {
    f(false_type{});
    return;
  }
  f(true_type{});
}
| |
|
/// Compile-time unrolled loop: calls `f(Int<i>{})` for
/// i = start, start + step, start + 2*step, ... while i < stop.
///
/// Each iteration is a separate template instantiation. A non-positive `step`
/// over a non-empty range would therefore recurse until the compiler's
/// instantiation-depth limit. The static_assert reports that misuse directly.
/// It sits inside the `if constexpr` branch, so it is only checked for ranges
/// that actually iterate; an empty range stays a no-op for any `step`.
template <int start, int stop, int step, typename F>
constexpr void const_for_loop(F f) {
  if constexpr (start < stop) {
    static_assert(step > 0, "const_for_loop: step must be positive");
    constexpr auto idx = Int<start>{};
    f(idx);
    const_for_loop<start + step, stop, step, F>(f);
  }
}
| |
|
// The helper macro only exists to stamp out the operator overloads earlier in
// this header; undefine it so it does not leak into files that include it.
#undef integral_const_binop
| |
|
| | |
| | |
| | |
| |
|
///////////////////////////////////////////////////////////////////////////////
// Reduction operators
///////////////////////////////////////////////////////////////////////////////

/// Sum of a single operand: the operand itself (recursion base case).
template <typename Value>
METAL_FUNC constexpr Value sum(Value x) {
  return x;
}

/// Sum of two or more operands, folded right-to-left as
/// `x + (us0 + (us1 + ...))` using each type's `operator+`.
/// When every operand is an integral_constant, the integral_constant
/// `operator+` overloads apply and the result is again an integral_constant.
template <typename Head, typename... Tail>
METAL_FUNC constexpr auto sum(Head x, Tail... us) {
  const auto tail_total = sum(us...);
  return x + tail_total;
}
| |
|
| | } |
| | } |
| |
|
| | #pragma METAL internals : disable |