// Copyright © 2024 Apple Inc.

#pragma once

#include <metal_stdlib>

#include "gemm/utils/type_traits.h"

#pragma METAL internals : enable

namespace mlx {
namespace steel {

///////////////////////////////////////////////////////////////////////////////
// Integral constant with casting
///////////////////////////////////////////////////////////////////////////////

// A value `v` of type `T` carried in the type itself. The implicit conversion
// to `value_type` lets an integral_constant be used wherever a plain `T` is
// expected, while the binary operators below keep arithmetic and comparisons
// between two integral_constants at compile time.
template <typename T, T v>
struct integral_constant {
  static constexpr constant T value = v;
  using value_type = T;
  using type = integral_constant;

  METAL_FUNC constexpr operator value_type() const noexcept {
    return value;
  }

  // METAL_FUNC constexpr value_type operator()() const noexcept {
  //   return value;
  // }
};

template <bool B>
using bool_constant = integral_constant<bool, B>;
using true_type = bool_constant<true>;
using false_type = bool_constant<false>;

// is_integral<T> is true for integral types (as reported by
// metal::is_integral) and also for integral_constant<T, v> wrappers whose
// underlying T is integral.
template <class T>
struct is_integral : bool_constant<metal::is_integral<T>::value> {};

template <class T, T v>
struct is_integral<integral_constant<T, v>>
    : bool_constant<metal::is_integral<T>::value> {};

template <typename T>
constexpr constant bool is_integral_v = is_integral<T>::value;

// Shorthand for a compile-time int, e.g. Int<4>{}.
template <int val>
using Int = integral_constant<int, val>;

///////////////////////////////////////////////////////////////////////////////
// Binary Operators on Integral constants
///////////////////////////////////////////////////////////////////////////////

// Defines OPERATOR(integral_constant<T, tv>, integral_constant<U, uv>): the
// expression `tv OP uv` is evaluated at compile time and returned wrapped in a
// new integral_constant.
//
// The result's value type is decltype(tv OP uv), i.e. the plain type of the
// expression (`int` for Int + Int, `bool` for comparisons). It must not be
// decltype(res): `res` is constexpr and therefore const, which would produce
// integral_constant<const int, N> (a different type from Int<N>) and
// integral_constant<const bool, B> instead of true_type / false_type.
//
// Macro parameters avoid double-underscore names, which are reserved.
#define integral_const_binop(OP, OPERATOR)                  \
  template <typename T, T tv, typename U, U uv>             \
  METAL_FUNC constexpr auto OPERATOR(                       \
      integral_constant<T, tv>, integral_constant<U, uv>) { \
    constexpr auto res = tv OP uv;                          \
    return integral_constant<decltype(tv OP uv), res>{};    \
  }

integral_const_binop(+, operator+);
integral_const_binop(-, operator-);
integral_const_binop(*, operator*);
integral_const_binop(/, operator/);

integral_const_binop(==, operator==);
integral_const_binop(!=, operator!=);
integral_const_binop(<, operator<);
integral_const_binop(>, operator>);
integral_const_binop(<=, operator<=);
integral_const_binop(>=, operator>=);

integral_const_binop(&&, operator&&);
integral_const_binop(||, operator||);

// `true_type || x` and `x || true_type` yield true_type for any other operand
// x, including a runtime value. x is still evaluated: overloaded operators do
// not short-circuit. T == true_type is excluded so that
// `true_type || true_type` is not ambiguous between these two overloads; that
// case goes to the integral_constant operator|| defined above.
template <
    typename T,
    typename = metal::enable_if_t<!metal::is_same<T, true_type>::value>>
METAL_FUNC constexpr auto operator||(true_type, T) {
  return true_type{};
}

template <
    typename T,
    typename = metal::enable_if_t<!metal::is_same<T, true_type>::value>>
METAL_FUNC constexpr auto operator||(T, true_type) {
  return true_type{};
}

// `false_type && x` and `x && false_type` yield false_type for any other
// operand x (again evaluated, not short-circuited). T == false_type is
// excluded to avoid ambiguity between the two overloads.
template <
    typename T,
    typename = metal::enable_if_t<!metal::is_same<T, false_type>::value>>
METAL_FUNC constexpr auto operator&&(false_type, T) {
  return false_type{};
}

template <
    typename T,
    typename = metal::enable_if_t<!metal::is_same<T, false_type>::value>>
METAL_FUNC constexpr auto operator&&(T, false_type) {
  return false_type{};
}

// Dispatch utilities

// Calls f(true_type{}) when v is true and f(false_type{}) otherwise. This turns
// a runtime bool into a compile-time constant inside f. f must therefore be
// callable with both true_type and false_type (e.g. a generic lambda).
template <typename F>
void dispatch_bool(bool v, F f) {
  if (v) {
    f(true_type{});
  } else {
    f(false_type{});
  }
}

// Calls f(Int<i>{}) for i = start, start + step, start + 2 * step, ... while
// i < stop. The loop is unrolled at compile time by recursive instantiation.
template <int start, int stop, int step, typename F>
constexpr void const_for_loop(F f) {
  if constexpr (start < stop) {
    // A non-positive step never reaches `stop` and would recurse until the
    // compiler's instantiation-depth limit; reject it with a clear message.
    // It lives inside this branch so an empty range accepts any step.
    static_assert(step > 0, "const_for_loop requires step > 0");
    constexpr auto idx = Int<start>{};
    f(idx);
    const_for_loop<start + step, stop, step, F>(f);
  }
}

#undef integral_const_binop

///////////////////////////////////////////////////////////////////////////////
// Reduction operators
///////////////////////////////////////////////////////////////////////////////

// sum(x0, x1, ..., xn) == x0 + (x1 + (... + xn)).
// When every argument is an integral_constant, each `+` resolves to the
// integral_constant operator+ above, so the result stays a compile-time
// constant.
template <typename T>
METAL_FUNC constexpr T sum(T x) {
  return x;
}

template <typename T, typename... Us>
METAL_FUNC constexpr auto sum(T x, Us... us) {
  return x + sum(us...);
}

} // namespace steel
} // namespace mlx

#pragma METAL internals : disable