| | |
| | |
| |
|
| | |
| | #include "utils.h" |
| | #include "gemm/gemm.h" |
| | #include "quantized_utils.h" |
| | #include "bnb_quantized.h" |
| |
|
| | |
| | |
| | |
| |
|
// Explicitly instantiates the kernel template `name` for one
// (type, blocksize, quant_type) combination and exports that specialization
// through [[host_name]] under the name
//
//   <name>_<type>_bs_<blocksize>_qt_<quant_type>
//
// e.g. instantiate_bnb_kernel(bnb_qmv, half, 64, 1) exports
// "bnb_qmv_half_bs_64_qt_1". The pieces are joined by adjacent string-literal
// concatenation of the `#`-stringized arguments. Arguments are not
// macro-expanded before `#` at this level, so the exported name reflects the
// literal tokens passed in.
// NOTE(review): host code presumably looks kernels up by this exact string;
// keep the format in sync with the dispatch side.
#define instantiate_bnb_kernel(name, type, blocksize, quant_type) \
  template [[host_name(                                           \
      #name "_" #type "_bs_" #blocksize "_qt_" #quant_type)]]     \
  [[kernel]] decltype(name<type, blocksize, quant_type>)          \
  name<type, blocksize, quant_type>;
| |
|
| | |
| |
|
// Instantiates all four bnb kernel templates for a single
// (type, blocksize, quant_type) combination:
//   bnb_quantize_blockwise, bnb_dequantize_blockwise, bnb_qmv, bnb_qmm_t.
// Each specialization is exported under the name built by
// instantiate_bnb_kernel.
// NOTE(review): a new kernel template (presumably declared in
// bnb_quantized.h) must also be listed here, or no specialization of it is
// compiled and the host will not find it.
#define instantiate_bnb_all_kernels(type, blocksize, quant_type)              \
  instantiate_bnb_kernel(bnb_quantize_blockwise, type, blocksize, quant_type)   \
  instantiate_bnb_kernel(bnb_dequantize_blockwise, type, blocksize, quant_type) \
  instantiate_bnb_kernel(bnb_qmv, type, blocksize, quant_type)                  \
  instantiate_bnb_kernel(bnb_qmm_t, type, blocksize, quant_type)
| |
|
| | |
| |
|
// Instantiates every bnb kernel for both quantization codes, 1 and 2, at the
// given element type and block size.
// NOTE(review): 1 and 2 presumably correspond to bitsandbytes' FP4 and NF4
// 4-bit formats -- confirm against bnb_quantized.h. The codes are stringized
// into the exported kernel names ("_qt_1", "_qt_2"). Replacing the integer
// literals with enumerators would therefore change those names.
#define instantiate_bnb_quant_types(type, blocksize) \
  instantiate_bnb_all_kernels(type, blocksize, 1)    \
  instantiate_bnb_all_kernels(type, blocksize, 2)
| |
|
| | |
| |
|
// Instantiates every bnb kernel, for both quantization codes, at each
// supported block size (64, 128, 256, 512) for the given element type.
// Only these block sizes get compiled specializations. A kernel requested
// with any other block size has no matching exported name.
#define instantiate_bnb_blocksizes(type) \
  instantiate_bnb_quant_types(type, 64)  \
  instantiate_bnb_quant_types(type, 128) \
  instantiate_bnb_quant_types(type, 256) \
  instantiate_bnb_quant_types(type, 512)
| |
|
| | |
| |
|
// Emit the full kernel matrix for every supported element type:
// 3 element types x 4 block sizes x 2 quantization codes x 4 kernels
// = 96 exported specializations. No trailing semicolons are needed here;
// each expansion already ends every instantiation with ';'.
instantiate_bnb_blocksizes(half)
instantiate_bnb_blocksizes(bfloat16_t)
instantiate_bnb_blocksizes(float)
| |
|
| | |
| |
|