// Copyright © 2024 Apple Inc.

#pragma once

///////////////////////////////////////////////////////////////////////////////
// GEMM param classes
///////////////////////////////////////////////////////////////////////////////

// NOTE(review): this header uses int64_t but includes nothing; it relies on
// the includer to provide it (e.g. <cstdint> on the host, or the Metal
// standard library) — confirm for every translation unit that includes it.
//
// NOTE(review): these are plain aggregates that look like they are filled in
// on one side and read on another, so field order, types and count presumably
// form a binary layout contract — do not reorder or insert fields without
// updating every producer and consumer.

namespace mlx {
namespace steel {

/// Parameters for a (possibly batched) GEMM launch.
///
/// Field meanings below are inferred from names and are assumptions to
/// confirm against the kernels that consume this struct.
struct GEMMParams {
  const int M; // rows of the output (assumed)
  const int N; // columns of the output (assumed)
  const int K; // reduction dimension (assumed)

  const int lda; // leading dimension of A (assumed)
  const int ldb; // leading dimension of B (assumed)
  const int ldd; // leading dimension of D, the output (assumed)

  const int tiles_n; // number of tiles along N (assumed)
  const int tiles_m; // number of tiles along M (assumed)

  // Per-batch element strides; 64-bit so large batched tensors do not overflow.
  const int64_t batch_stride_a;
  const int64_t batch_stride_b;
  const int64_t batch_stride_d;

  const int swizzle_log; // presumably log2 of the tile swizzle factor
  const int gemm_k_iterations_aligned; // K-loop trip count — semantics TODO confirm

  const int batch_ndim; // number of batch dimensions (assumed)
};

/// Parameters for a split-K GEMM launch.
///
/// The type name is misspelled ("Spilt"); it is kept unchanged so existing
/// users keep compiling. Prefer the `GEMMSplitKParams` alias in new code.
struct GEMMSpiltKParams {
  const int M; // rows of the output (assumed)
  const int N; // columns of the output (assumed)
  const int K; // reduction dimension (assumed)

  const int lda; // leading dimension of A (assumed)
  const int ldb; // leading dimension of B (assumed)
  const int ldc; // leading dimension of the output/partials (assumed)

  const int tiles_n; // number of tiles along N (assumed)
  const int tiles_m; // number of tiles along M (assumed)

  const int split_k_partitions; // number of K partitions (assumed)
  const int split_k_partition_stride; // stride between partition outputs (assumed)
  const int split_k_partition_size; // K extent handled per partition (assumed)

  const int gemm_k_iterations_aligned; // K-loop trip count — semantics TODO confirm
};

/// Correctly spelled alias for `GEMMSpiltKParams` (same type, same layout).
using GEMMSplitKParams = GEMMSpiltKParams;

/// Extra parameters for an "addmm"-style GEMM.
///
/// Presumably used as D = alpha * (A @ B) + beta * C — TODO confirm.
struct GEMMAddMMParams {
  const int ldc; // leading dimension of C (assumed)
  const int fdc; // NOTE(review): meaning unclear from here (C's inner/"fast" stride?) — confirm

  const int64_t batch_stride_c; // per-batch element stride of C (assumed)

  const float alpha; // scale on the matrix product (assumed)
  const float beta; // scale on C (assumed)
};

} // namespace steel
} // namespace mlx