// Copyright © 2024 Apple Inc.

///////////////////////////////////////////////////////////////////////////////
// GEMM parameter structs
///////////////////////////////////////////////////////////////////////////////
namespace mlx {
namespace steel {

// Parameter blocks for the steel GEMM code paths.
//
// Each struct is a plain aggregate of `const` scalars: it is brace-initialized
// once and then only read. Because every member is `const`, the implicit copy
// and move assignment operators are deleted (copy construction still works).
//
// NOTE(review): these structs look like argument blocks filled on one side of
// an API boundary and read on the other (e.g. host code -> kernel). If so, the
// member order and types form a binary layout contract: append new fields
// rather than reordering or retyping existing ones. Confirm against callers.

/// Parameters for a (possibly batched) GEMM.
struct GEMMParams {
  // Problem size; presumably D (M x N) = A (M x K) * B (K x N). Confirm.
  const int M;
  const int N;
  const int K;

  // Presumably the leading dimensions (row strides) of A, B and output D.
  const int lda;
  const int ldb;
  const int ldd;

  // Presumably the number of output tiles along N and M.
  const int tiles_n;
  const int tiles_m;

  // Per-batch strides for A, B and D. Unlike the other fields these are
  // 64-bit, presumably so that batch offsets may exceed INT_MAX.
  const int64_t batch_stride_a;
  const int64_t batch_stride_b;
  const int64_t batch_stride_d;

  // Presumably log2 of the tile-swizzle factor (inferred from the name).
  const int swizzle_log;
  // Presumably the number of whole K-block iterations (K remainder excluded).
  const int gemm_k_iterations_aligned;

  // Number of batch dimensions.
  const int batch_ndim;
};

/// Parameters for a split-K GEMM, where the K dimension is divided into
/// `split_k_partitions` partitions.
///
/// NOTE: the type name is misspelled ("Spilt"). It is kept unchanged for
/// backward compatibility; prefer the `GEMMSplitKParams` alias below.
struct GEMMSpiltKParams {
  // Problem size (see GEMMParams).
  const int M;
  const int N;
  const int K;

  // Presumably the leading dimensions of A, B and the (partial) output C.
  const int lda;
  const int ldb;
  const int ldc;

  // Presumably the number of output tiles along N and M.
  const int tiles_n;
  const int tiles_m;

  // Split-K layout: partition count, presumably the distance between
  // consecutive partial results, and the K extent handled per partition.
  const int split_k_partitions;
  const int split_k_partition_stride;
  const int split_k_partition_size;

  // Presumably the number of whole K-block iterations per partition.
  const int gemm_k_iterations_aligned;
};

// Correctly spelled name for GEMMSpiltKParams. Both names denote the same
// type, so either may be used interchangeably.
using GEMMSplitKParams = GEMMSpiltKParams;

/// Extra parameters for the fused AddMM epilogue, presumably
/// out = alpha * (A @ B) + beta * C. Confirm against the kernel.
struct GEMMAddMMParams {
  // Presumably the leading dimension of C.
  const int ldc;
  // NOTE(review): meaning not evident from this file; presumably the stride
  // of C along the other (column) axis, e.g. 0 for a broadcast C. Confirm.
  const int fdc;
  // Per-batch stride for C (64-bit, like the GEMMParams batch strides).
  const int64_t batch_stride_c;

  // Scale factors for the product and for C, respectively.
  const float alpha;
  const float beta;
};

} // namespace steel
} // namespace mlx