// bitsandbytes MPS Metal kernels - NF4/FP4 codebook definitions and helpers // Adapted from bitsandbytes CUDA kernels (kernels.cu) for Apple Metal #pragma once #include using namespace metal; // ============================================================================ // Quant type enum (matches bitsandbytes common.h) // ============================================================================ enum BnBQuantType { BNB_FP4 = 1, BNB_NF4 = 2, }; // ============================================================================ // NF4 codebook - 16 values optimized for normal distributions // Maps 4-bit indices (0-15) to float values in [-1, 1] // ============================================================================ constant float NF4_CODEBOOK[16] = { -1.0f, -0.6961928009986877f, -0.5250730514526367f, -0.39491748809814453f, -0.28444138169288635f, -0.18477343022823334f, -0.09105003625154495f, 0.0f, 0.07958029955625534f, 0.16093020141124725f, 0.24611230194568634f, 0.33791524171829224f, 0.44070982933044434f, 0.5626170039176941f, 0.7229568362236023f, 1.0f, }; // ============================================================================ // FP4 codebook - 16 values using sign-magnitude FP4 encoding // Indices 0-7: non-negative, indices 8-15: negative (bit 3 = sign) // ============================================================================ constant float FP4_CODEBOOK[16] = { 0.0f, 0.005208333333f, 0.66666667f, 1.0f, 0.33333333f, 0.5f, 0.16666667f, 0.25f, 0.0f, -0.005208333333f, -0.66666667f, -1.0f, -0.33333333f, -0.5f, -0.16666667f, -0.25f, }; // ============================================================================ // Codebook accessor by quant_type template parameter // ============================================================================ template inline constant float* bnb_codebook() { if (quant_type == BNB_NF4) { return NF4_CODEBOOK; } else { return FP4_CODEBOOK; } } // ============================================================================ // NF4 quantization - binary search (matches CUDA dQuantizeNF4) // Input: normalized value in [-1, 1] // Output: 4-bit index (0-15) // ============================================================================ inline uchar quantize_nf4(float x) { if (x > 0.03979014977812767f) { if (x > 0.3893125355243683f) { if (x > 0.6427869200706482f) { return (x > 0.8614784181118011f) ? 15 : 14; } return (x > 0.5016634166240692f) ? 13 : 12; } if (x > 0.2035212516784668f) { return (x > 0.2920137718319893f) ? 11 : 10; } return (x > 0.1202552504837513f) ? 9 : 8; } if (x > -0.33967943489551544f) { if (x > -0.13791173323988914f) { return (x > -0.045525018125772476f) ? 7 : 6; } return (x > -0.23460740596055984f) ? 5 : 4; } if (x > -0.6106329262256622f) { return (x > -0.4599952697753906f) ? 3 : 2; } return (x > -0.8480964004993439f) ? 1 : 0; } // ============================================================================ // FP4 quantization - binary search (matches CUDA dQuantizeFP4) // Input: normalized value in [-1, 1] // Output: 4-bit index (0-15), MSB = sign bit // ============================================================================ inline uchar quantize_fp4(float x) { uchar sign = (x < 0.0f) ? 8 : 0; x = metal::abs(x); uchar code; if (x > 0.29166667f) { if (x > 0.75f) { code = (x > 0.8333333f) ? 3 : 2; } else { code = (x > 0.4166667f) ? 5 : 4; } } else { if (x > 0.0859375f) { code = (x > 0.20833333f) ? 7 : 6; } else { code = (x > 0.00260416f) ? 1 : 0; } } return sign | code; } // ============================================================================ // Generic quantize dispatch by quant_type // ============================================================================ template inline uchar bnb_quantize_value(float normalized) { if (quant_type == BNB_NF4) { return quantize_nf4(normalized); } else { return quantize_fp4(normalized); } } // ============================================================================ // Dequantize a single 4-bit value using codebook lookup // ============================================================================ template inline float bnb_dequantize_value(uchar nibble) { return bnb_codebook()[nibble & 0x0f]; } // ============================================================================ // BnB 4-bit dequantize for block loader (adapted from MLX affine dequantize) // Unpacks N values from packed bytes using codebook lookup. // // BnB packing: high nibble = first element, low nibble = second element // Each byte stores 2 4-bit values. // ============================================================================ template inline void bnb_dequantize( const device uint8_t* w, U absmax_val, threadgroup U* w_local) { constant float* codebook = bnb_codebook(); for (int i = 0; i < N / 2; i++) { uint8_t byte_val = w[i]; uint8_t high = (byte_val >> 4) & 0x0f; uint8_t low = byte_val & 0x0f; w_local[2 * i] = U(codebook[high]) * absmax_val; w_local[2 * i + 1] = U(codebook[low]) * absmax_val; } }