File size: 5,475 Bytes
20347e1 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 | // bitsandbytes MPS Metal kernels - NF4/FP4 codebook definitions and helpers
// Adapted from bitsandbytes CUDA kernels (kernels.cu) for Apple Metal
#pragma once
#include <metal_stdlib>
using namespace metal;
// ============================================================================
// Quant type enum (matches bitsandbytes common.h)
// ============================================================================
enum BnBQuantType {
BNB_FP4 = 1,
BNB_NF4 = 2,
};
// ============================================================================
// NF4 codebook - 16 values optimized for normal distributions
// Maps 4-bit indices (0-15) to float values in [-1, 1]
// ============================================================================
constant float NF4_CODEBOOK[16] = {
-1.0f,
-0.6961928009986877f,
-0.5250730514526367f,
-0.39491748809814453f,
-0.28444138169288635f,
-0.18477343022823334f,
-0.09105003625154495f,
0.0f,
0.07958029955625534f,
0.16093020141124725f,
0.24611230194568634f,
0.33791524171829224f,
0.44070982933044434f,
0.5626170039176941f,
0.7229568362236023f,
1.0f,
};
// ============================================================================
// FP4 codebook - 16 values using sign-magnitude FP4 encoding
// Indices 0-7: non-negative, indices 8-15: negative (bit 3 = sign)
// ============================================================================
constant float FP4_CODEBOOK[16] = {
0.0f,
0.005208333333f,
0.66666667f,
1.0f,
0.33333333f,
0.5f,
0.16666667f,
0.25f,
0.0f,
-0.005208333333f,
-0.66666667f,
-1.0f,
-0.33333333f,
-0.5f,
-0.16666667f,
-0.25f,
};
// ============================================================================
// Codebook accessor by quant_type template parameter
// ============================================================================
template <int quant_type>
inline constant float* bnb_codebook() {
if (quant_type == BNB_NF4) {
return NF4_CODEBOOK;
} else {
return FP4_CODEBOOK;
}
}
// ============================================================================
// NF4 quantization - binary search (matches CUDA dQuantizeNF4)
// Input: normalized value in [-1, 1]
// Output: 4-bit index (0-15)
// ============================================================================
inline uchar quantize_nf4(float x) {
if (x > 0.03979014977812767f) {
if (x > 0.3893125355243683f) {
if (x > 0.6427869200706482f) {
return (x > 0.8614784181118011f) ? 15 : 14;
}
return (x > 0.5016634166240692f) ? 13 : 12;
}
if (x > 0.2035212516784668f) {
return (x > 0.2920137718319893f) ? 11 : 10;
}
return (x > 0.1202552504837513f) ? 9 : 8;
}
if (x > -0.33967943489551544f) {
if (x > -0.13791173323988914f) {
return (x > -0.045525018125772476f) ? 7 : 6;
}
return (x > -0.23460740596055984f) ? 5 : 4;
}
if (x > -0.6106329262256622f) {
return (x > -0.4599952697753906f) ? 3 : 2;
}
return (x > -0.8480964004993439f) ? 1 : 0;
}
// ============================================================================
// FP4 quantization - binary search (matches CUDA dQuantizeFP4)
// Input: normalized value in [-1, 1]
// Output: 4-bit index (0-15), MSB = sign bit
// ============================================================================
inline uchar quantize_fp4(float x) {
uchar sign = (x < 0.0f) ? 8 : 0;
x = metal::abs(x);
uchar code;
if (x > 0.29166667f) {
if (x > 0.75f) {
code = (x > 0.8333333f) ? 3 : 2;
} else {
code = (x > 0.4166667f) ? 5 : 4;
}
} else {
if (x > 0.0859375f) {
code = (x > 0.20833333f) ? 7 : 6;
} else {
code = (x > 0.00260416f) ? 1 : 0;
}
}
return sign | code;
}
// ============================================================================
// Generic quantize dispatch by quant_type
// ============================================================================
template <int quant_type>
inline uchar bnb_quantize_value(float normalized) {
if (quant_type == BNB_NF4) {
return quantize_nf4(normalized);
} else {
return quantize_fp4(normalized);
}
}
// ============================================================================
// Dequantize a single 4-bit value using codebook lookup
// ============================================================================
template <int quant_type>
inline float bnb_dequantize_value(uchar nibble) {
return bnb_codebook<quant_type>()[nibble & 0x0f];
}
// ============================================================================
// BnB 4-bit dequantize for block loader (adapted from MLX affine dequantize)
// Unpacks N values from packed bytes using codebook lookup.
//
// BnB packing: high nibble = first element, low nibble = second element
// Each byte stores 2 4-bit values.
// ============================================================================
template <typename U, int N, int quant_type>
inline void bnb_dequantize(
const device uint8_t* w,
U absmax_val,
threadgroup U* w_local) {
constant float* codebook = bnb_codebook<quant_type>();
for (int i = 0; i < N / 2; i++) {
uint8_t byte_val = w[i];
uint8_t high = (byte_val >> 4) & 0x0f;
uint8_t low = byte_val & 0x0f;
w_local[2 * i] = U(codebook[high]) * absmax_val;
w_local[2 * i + 1] = U(codebook[low]) * absmax_val;
}
}
|