| | |
| |
|
| | #pragma once |
| |
|
| | #include "gemm/utils.h" |
| |
|
| | |
| | |
| | |
| |
|
| | namespace mlx { |
| | namespace steel { |
| |
|
| | template <typename OutT, typename InT> |
| | struct TransformNone { |
| | static METAL_FUNC OutT apply(InT x) { |
| | return static_cast<OutT>(x); |
| | } |
| |
|
| | static METAL_FUNC OutT apply(InT x, OutT) { |
| | return static_cast<OutT>(x); |
| | } |
| | }; |
| |
|
| | template <typename OutT, typename InT> |
| | struct TransformAdd { |
| | TransformAdd(const float, const float) {} |
| |
|
| | static METAL_FUNC OutT apply(InT x) { |
| | return static_cast<OutT>(x); |
| | } |
| |
|
| | static METAL_FUNC OutT apply(InT x, OutT c) { |
| | return static_cast<OutT>(x) + c; |
| | } |
| | }; |
| |
|
| | template <typename OutT, typename InT> |
| | struct TransformAxpby { |
| | const float alpha; |
| | const float beta; |
| |
|
| | TransformAxpby(const float alpha_, const float beta_) |
| | : alpha(alpha_), beta(beta_) {} |
| |
|
| | static METAL_FUNC OutT apply(InT x) { |
| | return static_cast<OutT>(x); |
| | } |
| |
|
| | METAL_FUNC OutT apply(InT x, OutT c) const { |
| | return static_cast<OutT>( |
| | x * static_cast<InT>(alpha) + (static_cast<OutT>(beta) * c)); |
| | } |
| | }; |
| |
|
/// Maps an element type T to the type used for accumulation.
/// Every T currently maps to float; T itself is not inspected.
template <typename T>
struct AccumHelper {
  using accum_type = float;
};
| |
|
| | struct BlockSwizzle { |
| | static METAL_FUNC int2 |
| | swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) { |
| | const int tid_x = (tid.x) >> swizzle_log; |
| | const int tid_y = |
| | ((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1)); |
| | return int2(tid_x, tid_y); |
| | } |
| | }; |
| |
|
| | } |
| | } |