// Copyright © 2024 Apple Inc. #pragma once #include "gemm/utils.h" /////////////////////////////////////////////////////////////////////////////// // Transforms and Epilogues /////////////////////////////////////////////////////////////////////////////// namespace mlx { namespace steel { template struct TransformNone { static METAL_FUNC OutT apply(InT x) { return static_cast(x); } static METAL_FUNC OutT apply(InT x, OutT) { return static_cast(x); } }; template struct TransformAdd { TransformAdd(const float, const float) {} static METAL_FUNC OutT apply(InT x) { return static_cast(x); } static METAL_FUNC OutT apply(InT x, OutT c) { return static_cast(x) + c; } }; template struct TransformAxpby { const float alpha; const float beta; TransformAxpby(const float alpha_, const float beta_) : alpha(alpha_), beta(beta_) {} static METAL_FUNC OutT apply(InT x) { return static_cast(x); } METAL_FUNC OutT apply(InT x, OutT c) const { return static_cast( x * static_cast(alpha) + (static_cast(beta) * c)); } }; template struct AccumHelper { typedef float accum_type; }; struct BlockSwizzle { static METAL_FUNC int2 swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) { const int tid_x = (tid.x) >> swizzle_log; const int tid_y = ((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1)); return int2(tid_x, tid_y); } }; } // namespace steel } // namespace mlx