File size: 1,894 Bytes
20347e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
// bitsandbytes MPS Metal kernels - template instantiations
// Instantiates kernel variants for all (type, blocksize, quant_type) combos.

// clang-format off
#include "utils.h"
#include "gemm/gemm.h"
#include "quantized_utils.h"
#include "bnb_quantized.h"

// ============================================================================
// Instantiation macros
// ============================================================================

// Emits one explicit instantiation of the kernel template `fn<T, BS, QT>` and
// tags it with a host-visible name assembled by stringizing the arguments:
//     "<fn>_<T>_bs_<BS>_qt_<QT>"      e.g. "bnb_qmv_half_bs_64_qt_1"
// The parameter names are chosen so they cannot collide with the `kernel` and
// `host_name` attribute tokens that appear verbatim in the expansion.
#define instantiate_bnb_kernel(fn, T, BS, QT)                \
  template [[host_name(#fn "_" #T "_bs_" #BS "_qt_" #QT)]]   \
  [[kernel]] decltype(fn<T, BS, QT>) fn<T, BS, QT>;

// ---- Kernel fan-out for a single (T, BS, QT) configuration ----
// Instantiates each of the four kernel templates (blockwise quantize,
// blockwise dequantize, qmv, qmm_t) for the given scalar type, blocksize and
// quant type.

#define instantiate_bnb_all_kernels(T, BS, QT)                  \
  instantiate_bnb_kernel(bnb_quantize_blockwise, T, BS, QT)     \
  instantiate_bnb_kernel(bnb_dequantize_blockwise, T, BS, QT)   \
  instantiate_bnb_kernel(bnb_qmv, T, BS, QT)                    \
  instantiate_bnb_kernel(bnb_qmm_t, T, BS, QT)

// ---- Quant-type fan-out (FP4=1, NF4=2) ----
// Repeats the full kernel set once per quant type for a fixed (T, BS) pair.

#define instantiate_bnb_quant_types(T, BS) \
  instantiate_bnb_all_kernels(T, BS, 1)    \
  instantiate_bnb_all_kernels(T, BS, 2)

// ---- Blocksize fan-out ----
// Repeats the quant-type fan-out for each instantiated blocksize
// (64, 128, 256, 512) of scalar type T.

#define instantiate_bnb_blocksizes(T) \
  instantiate_bnb_quant_types(T, 64)  \
  instantiate_bnb_quant_types(T, 128) \
  instantiate_bnb_quant_types(T, 256) \
  instantiate_bnb_quant_types(T, 512)

// ---- Instantiate for all scalar types ----
// Each line below expands to 4 blocksizes x 2 quant types x 4 kernels = 32
// explicit instantiations, so this file emits 96 named kernels in total.
// The scalar type token is stringized into the host name as written here
// (e.g. "bfloat16_t"), so any host-side lookup must use the same spelling.

instantiate_bnb_blocksizes(half)
instantiate_bnb_blocksizes(bfloat16_t)
instantiate_bnb_blocksizes(float)

// clang-format on