// Copyright © 2023 Apple Inc.

#pragma once

// MTL_CONST expands to the `constant` qualifier when this header is compiled
// with __METAL__ or MLX_METAL_JIT defined, and to nothing otherwise, so the
// declarations below can be shared between both kinds of translation unit.
#if defined __METAL__ || defined MLX_METAL_JIT
#define MTL_CONST constant
#else
#define MTL_CONST
#endif

// Compile-time sizing constants. Their consumers are not visible in this
// header; judging by the names they parameterize the reduce, softmax and RMS
// kernels -- NOTE(review): confirm against the kernels that use them.
static MTL_CONST constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4;
static MTL_CONST constexpr int REDUCE_N_READS = 4;
static MTL_CONST constexpr int REDUCE_N_WRITES = 4;
static MTL_CONST constexpr int SOFTMAX_N_READS = 4;
static MTL_CONST constexpr int RMS_N_READS = 4;
static MTL_CONST constexpr int RMS_LOOPED_LIMIT = 4096;

// Instantiate a templated kernel.
// `name` is placed in the host_name attribute, `func` is the kernel template,
// and any extra args are used as its template parameters:
// e.g. instantiate_kernel(binary_int, binary, a, b) ->
// template [[host_name(binary_int)]] [[kernel]]
//     decltype(binary<a, b>) binary<a, b>;
// The expansion already ends with a semicolon.
#define instantiate_kernel(name, func, ...) \
  template [[host_name(                     \
      name)]] [[kernel]] decltype(func<__VA_ARGS__>) func<__VA_ARGS__>;