| // bitsandbytes MPS Metal kernels - NF4/FP4 codebook definitions and helpers | |
| // Adapted from bitsandbytes CUDA kernels (kernels.cu) for Apple Metal | |
| using namespace metal; | |
| // ============================================================================ | |
| // Quant type enum (matches bitsandbytes common.h) | |
| // ============================================================================ | |
| enum BnBQuantType { | |
| BNB_FP4 = 1, | |
| BNB_NF4 = 2, | |
| }; | |
| // ============================================================================ | |
| // NF4 codebook - 16 values optimized for normal distributions | |
| // Maps 4-bit indices (0-15) to float values in [-1, 1] | |
| // ============================================================================ | |
| constant float NF4_CODEBOOK[16] = { | |
| -1.0f, | |
| -0.6961928009986877f, | |
| -0.5250730514526367f, | |
| -0.39491748809814453f, | |
| -0.28444138169288635f, | |
| -0.18477343022823334f, | |
| -0.09105003625154495f, | |
| 0.0f, | |
| 0.07958029955625534f, | |
| 0.16093020141124725f, | |
| 0.24611230194568634f, | |
| 0.33791524171829224f, | |
| 0.44070982933044434f, | |
| 0.5626170039176941f, | |
| 0.7229568362236023f, | |
| 1.0f, | |
| }; | |
| // ============================================================================ | |
| // FP4 codebook - 16 values using sign-magnitude FP4 encoding | |
| // Indices 0-7: non-negative, indices 8-15: negative (bit 3 = sign) | |
| // ============================================================================ | |
| constant float FP4_CODEBOOK[16] = { | |
| 0.0f, | |
| 0.005208333333f, | |
| 0.66666667f, | |
| 1.0f, | |
| 0.33333333f, | |
| 0.5f, | |
| 0.16666667f, | |
| 0.25f, | |
| 0.0f, | |
| -0.005208333333f, | |
| -0.66666667f, | |
| -1.0f, | |
| -0.33333333f, | |
| -0.5f, | |
| -0.16666667f, | |
| -0.25f, | |
| }; | |
| // ============================================================================ | |
| // Codebook accessor by quant_type template parameter | |
| // ============================================================================ | |
| template <int quant_type> | |
| inline constant float* bnb_codebook() { | |
| if (quant_type == BNB_NF4) { | |
| return NF4_CODEBOOK; | |
| } else { | |
| return FP4_CODEBOOK; | |
| } | |
| } | |
| // ============================================================================ | |
| // NF4 quantization - binary search (matches CUDA dQuantizeNF4) | |
| // Input: normalized value in [-1, 1] | |
| // Output: 4-bit index (0-15) | |
| // ============================================================================ | |
| inline uchar quantize_nf4(float x) { | |
| if (x > 0.03979014977812767f) { | |
| if (x > 0.3893125355243683f) { | |
| if (x > 0.6427869200706482f) { | |
| return (x > 0.8614784181118011f) ? 15 : 14; | |
| } | |
| return (x > 0.5016634166240692f) ? 13 : 12; | |
| } | |
| if (x > 0.2035212516784668f) { | |
| return (x > 0.2920137718319893f) ? 11 : 10; | |
| } | |
| return (x > 0.1202552504837513f) ? 9 : 8; | |
| } | |
| if (x > -0.33967943489551544f) { | |
| if (x > -0.13791173323988914f) { | |
| return (x > -0.045525018125772476f) ? 7 : 6; | |
| } | |
| return (x > -0.23460740596055984f) ? 5 : 4; | |
| } | |
| if (x > -0.6106329262256622f) { | |
| return (x > -0.4599952697753906f) ? 3 : 2; | |
| } | |
| return (x > -0.8480964004993439f) ? 1 : 0; | |
| } | |
| // ============================================================================ | |
| // FP4 quantization - binary search (matches CUDA dQuantizeFP4) | |
| // Input: normalized value in [-1, 1] | |
| // Output: 4-bit index (0-15), MSB = sign bit | |
| // ============================================================================ | |
| inline uchar quantize_fp4(float x) { | |
| uchar sign = (x < 0.0f) ? 8 : 0; | |
| x = metal::abs(x); | |
| uchar code; | |
| if (x > 0.29166667f) { | |
| if (x > 0.75f) { | |
| code = (x > 0.8333333f) ? 3 : 2; | |
| } else { | |
| code = (x > 0.4166667f) ? 5 : 4; | |
| } | |
| } else { | |
| if (x > 0.0859375f) { | |
| code = (x > 0.20833333f) ? 7 : 6; | |
| } else { | |
| code = (x > 0.00260416f) ? 1 : 0; | |
| } | |
| } | |
| return sign | code; | |
| } | |
| // ============================================================================ | |
| // Generic quantize dispatch by quant_type | |
| // ============================================================================ | |
| template <int quant_type> | |
| inline uchar bnb_quantize_value(float normalized) { | |
| if (quant_type == BNB_NF4) { | |
| return quantize_nf4(normalized); | |
| } else { | |
| return quantize_fp4(normalized); | |
| } | |
| } | |
| // ============================================================================ | |
| // Dequantize a single 4-bit value using codebook lookup | |
| // ============================================================================ | |
| template <int quant_type> | |
| inline float bnb_dequantize_value(uchar nibble) { | |
| return bnb_codebook<quant_type>()[nibble & 0x0f]; | |
| } | |
| // ============================================================================ | |
| // BnB 4-bit dequantize for block loader (adapted from MLX affine dequantize) | |
| // Unpacks N values from packed bytes using codebook lookup. | |
| // | |
| // BnB packing: high nibble = first element, low nibble = second element | |
| // Each byte stores 2 4-bit values. | |
| // ============================================================================ | |
| template <typename U, int N, int quant_type> | |
| inline void bnb_dequantize( | |
| const device uint8_t* w, | |
| U absmax_val, | |
| threadgroup U* w_local) { | |
| constant float* codebook = bnb_codebook<quant_type>(); | |
| for (int i = 0; i < N / 2; i++) { | |
| uint8_t byte_val = w[i]; | |
| uint8_t high = (byte_val >> 4) & 0x0f; | |
| uint8_t low = byte_val & 0x0f; | |
| w_local[2 * i] = U(codebook[high]) * absmax_val; | |
| w_local[2 * i + 1] = U(codebook[low]) * absmax_val; | |
| } | |
| } | |