File size: 1,894 Bytes
20347e1 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 | // bitsandbytes MPS Metal kernels - template instantiations
// Instantiates kernel variants for all (type, blocksize, quant_type) combos.
// clang-format off
#include "utils.h"
#include "gemm/gemm.h"
#include "quantized_utils.h"
#include "bnb_quantized.h"
// ============================================================================
// Instantiation macros
// ============================================================================
#define instantiate_bnb_kernel(name, type, blocksize, quant_type) \
template [[host_name( \
#name "_" #type "_bs_" #blocksize "_qt_" #quant_type \
)]] [[kernel]] decltype(name<type, blocksize, quant_type>) \
name<type, blocksize, quant_type>;
// ---- Instantiate all kernel types for a given (type, blocksize, quant_type) ----
#define instantiate_bnb_all_kernels(type, blocksize, quant_type) \
instantiate_bnb_kernel(bnb_quantize_blockwise, type, blocksize, quant_type) \
instantiate_bnb_kernel(bnb_dequantize_blockwise, type, blocksize, quant_type) \
instantiate_bnb_kernel(bnb_qmv, type, blocksize, quant_type) \
instantiate_bnb_kernel(bnb_qmm_t, type, blocksize, quant_type)
// ---- Instantiate for all quant types (FP4=1, NF4=2) ----
#define instantiate_bnb_quant_types(type, blocksize) \
instantiate_bnb_all_kernels(type, blocksize, 1) \
instantiate_bnb_all_kernels(type, blocksize, 2)
// ---- Instantiate for all blocksizes ----
#define instantiate_bnb_blocksizes(type) \
instantiate_bnb_quant_types(type, 64) \
instantiate_bnb_quant_types(type, 128) \
instantiate_bnb_quant_types(type, 256) \
instantiate_bnb_quant_types(type, 512)
// ---- Instantiate for all scalar types ----
instantiate_bnb_blocksizes(half)
instantiate_bnb_blocksizes(bfloat16_t)
instantiate_bnb_blocksizes(float)
// clang-format on
|