// bitsandbytes MPS Metal kernels - 4-bit quantized operations
// Adapted from MLX quantized.h for bitsandbytes NF4/FP4 format.
//
// Key differences from MLX affine quantization:
//   MLX: dequant(q) = scale * q_int + bias    (linear mapping)
//   BnB: dequant(q) = codebook[q_int] * absmax (lookup-based)
//
// Packing format:
//   BnB: high nibble = first element, low nibble = second element
//   Two 4-bit values per byte, pack_factor = 2

#include <metal_stdlib>
#include <metal_simdgroup>

#include "bnb_types.h"

using namespace metal;

#define MLX_MTL_CONST static constant constexpr const

MLX_MTL_CONST int SIMD_SIZE = 32;
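// Worked example of the packing above (the codebook values themselves live
// in bnb_types.h): a packed byte b = 0x2C inside a block whose absmax is 1.5
// dequantizes to two consecutive elements:
//
//   high nibble: (b >> 4) & 0x0f = 0x2 -> codebook[2]  * 1.5  (element 0)
//   low  nibble:  b       & 0x0f = 0xC -> codebook[12] * 1.5  (element 1)
//
// so blocksize elements occupy blocksize / 2 bytes and share one absmax.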
// ============================================================================
// BnBQuantizedBlockLoader
//
// Loads blocks of BnB 4-bit packed weights into threadgroup memory,
// performing codebook dequantization on the fly.
// Adapted from MLX QuantizedBlockLoader.
//
// Template parameters:
//   T             - output scalar type (float16_t, bfloat16_t, float)
//   BROWS         - number of rows in the tile
//   BCOLS         - number of columns in the tile (unpacked)
//   dst_ld        - leading dimension of destination (threadgroup memory)
//   reduction_dim - 0 for K along rows, 1 for K along columns
//   tgp_size      - threads per threadgroup
//   blocksize     - BnB blocksize (elements per absmax value)
//   quant_type    - BNB_FP4 (1) or BNB_NF4 (2)
// ============================================================================
template <
    typename T,
    short BROWS,
    short BCOLS,
    short dst_ld,
    short reduction_dim,
    short tgp_size,
    short blocksize,
    int quant_type>
struct BnBQuantizedBlockLoader {
  static_assert(
      BCOLS <= blocksize,
      "The blocksize should be larger than the tile columns");
  static_assert(
      blocksize % BCOLS == 0,
      "The blocksize should be divisible by the tile columns");

  MLX_MTL_CONST short pack_factor = 2;
  MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;
  MLX_MTL_CONST short n_reads =
      (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size;
  MLX_MTL_CONST short group_steps = blocksize / BCOLS;

  const int src_ld;
  const int tile_stride;
  short group_step_cnt;
  const int group_stride;

  const short thread_idx;
  const short bi;
  const short bj;

  threadgroup T* dst;
  const device uint8_t* src;
  const device float* absmax_ptr;

  BnBQuantizedBlockLoader(
      const device uint8_t* src_,
      const device float* absmax_,
      const int src_ld_,
      threadgroup T* dst_,
      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
      ushort simd_lane_id [[thread_index_in_simdgroup]])
      : src_ld(src_ld_),
        tile_stride(
            reduction_dim ? BCOLS_PACKED : BROWS * src_ld / pack_factor),
        group_step_cnt(0),
        group_stride(BROWS * src_ld / blocksize),
        thread_idx(simd_group_id * SIMD_SIZE + simd_lane_id),
        bi(n_reads * thread_idx / BCOLS_PACKED),
        bj((n_reads * thread_idx) % BCOLS_PACKED),
        dst(dst_ + bi * dst_ld + bj * pack_factor),
        src(src_ + bi * src_ld / pack_factor + bj),
        absmax_ptr(absmax_ + bi * src_ld / blocksize) {}

  void load_unsafe() const {
    if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
      return;
    }

    float am = *absmax_ptr;
    for (int i = 0; i < n_reads; i++) {
      bnb_dequantize<quant_type>(src + i, T(am), dst + i * pack_factor);
    }
  }

  void load_safe(short2 src_tile_dim) const {
    if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
      return;
    }

    // src_tile_dim is (columns, rows); bi indexes rows of the tile.
    if (reduction_dim == 1 && bi >= src_tile_dim.y) {
      for (int i = 0; i < n_reads * pack_factor; i++) {
        dst[i] = T(0);
      }
      return;
    }

    if (reduction_dim == 0 && bi >= src_tile_dim.x) {
      for (int i = 0; i < n_reads * pack_factor; i++) {
        dst[i] = T(0);
      }
      return;
    }

    float am = *absmax_ptr;
    for (int i = 0; i < n_reads; i++) {
      bnb_dequantize<quant_type>(src + i, T(am), dst + i * pack_factor);
    }
  }

  void next() {
    src += tile_stride;
    if (reduction_dim == 1) {
      if (group_steps > 1) {
        group_step_cnt++;
        if (group_step_cnt == group_steps) {
          group_step_cnt = 0;
          absmax_ptr++;
        }
      } else {
        absmax_ptr++;
      }
    } else {
      absmax_ptr += group_stride;
    }
  }
};

// ============================================================================
// BnB GEMV (matrix-vector multiply with 4-bit quantized weights)
//
// Computes y = dequant(W) @ x
//   W: [N, K/2] packed bytes, absmax: [N, ceil(K/blocksize)], x: [K], y: [N]
//
// Each simdgroup handles results_per_simdgroup output rows.
// Each thread processes values_per_thread elements of K per iteration.
// ============================================================================
template <typename T, int blocksize, int quant_type>
METAL_FUNC void bnb_qmv_impl(
    const device uint8_t* w,
    const device float* absmax,
    const device T* x,
    device T* y,
    const constant int& in_vec_size,
    const constant int& out_vec_size,
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int num_simdgroups = 2;
  constexpr int results_per_simdgroup = 4;
  constexpr int bytes_per_thread = 4;
  constexpr int values_per_thread = bytes_per_thread * 2;
  constexpr int block_size_k = values_per_thread * SIMD_SIZE;
  constexpr int scale_step_per_thread = blocksize / values_per_thread;

  constant float* codebook = bnb_codebook<quant_type>();

  typedef float U;
  thread U x_thread[values_per_thread];
  thread U result[results_per_simdgroup] = {0};

  const int K_packed = in_vec_size / 2;
  const int K_groups = (in_vec_size + blocksize - 1) / blocksize;

  const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
      simd_gid * results_per_simdgroup;
  if (out_row >= out_vec_size) {
    return;
  }
  const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row);

  const device uint8_t* ws =
      w + used_out_row * K_packed + simd_lid * bytes_per_thread;
  const device float* am =
      absmax + used_out_row * K_groups + simd_lid / scale_step_per_thread;
  const device T* xi = x + tid.x * in_vec_size + simd_lid * values_per_thread;
  y += tid.x * out_vec_size + used_out_row;

  int k = 0;
  for (; k < in_vec_size - block_size_k; k += block_size_k) {
    // Load x values
    for (int i = 0; i < values_per_thread; i++) {
      x_thread[i] = U(xi[i]);
    }

    // Compute dot product for each output row
    for (int row = 0; row < results_per_simdgroup; row++) {
      const device uint8_t* wl = ws + row * K_packed;
      U scale = U(am[row * K_groups]);
      U accum = 0;
      for (int i = 0; i < bytes_per_thread; i++) {
        uint8_t byte_val = wl[i];
        U w0 = U(codebook[(byte_val >> 4) & 0x0f]);
        U w1 = U(codebook[byte_val & 0x0f]);
        accum += x_thread[2 * i] * w0 + x_thread[2 * i + 1] * w1;
      }
      result[row] += accum * scale;
    }

    ws += block_size_k / 2;
    am += block_size_k / blocksize;
    xi += block_size_k;
  }

  // Handle remaining K elements
  const int remaining = clamp(
      static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
      0,
      values_per_thread);
  if (remaining > 0) {
    for (int i = 0; i < remaining; i++) {
      x_thread[i] = U(xi[i]);
    }
    for (int i = remaining; i < values_per_thread; i++) {
      x_thread[i] = 0;
    }

    for (int row = 0; row < results_per_simdgroup; row++) {
      const device uint8_t* wl = ws + row * K_packed;
      U scale = U(am[row * K_groups]);
      U accum = 0;
      int bytes_to_read = (remaining + 1) / 2;
      for (int i = 0; i < bytes_to_read; i++) {
        uint8_t byte_val = wl[i];
        U w0 = U(codebook[(byte_val >> 4) & 0x0f]);
        U w1 = U(codebook[byte_val & 0x0f]);
        accum += x_thread[2 * i] * w0 + x_thread[2 * i + 1] * w1;
      }
      result[row] += accum * scale;
    }
  }

  // Reduce across SIMD lanes
  for (int row = 0; row < results_per_simdgroup; row++) {
    result[row] = simd_sum(result[row]);
    if (simd_lid == 0) {
      y[row] = static_cast<T>(result[row]);
    }
  }
}
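// Dispatch sketch for bnb_qmv_impl (the host-side values here are
// assumptions, not part of this file): with num_simdgroups = 2 and
// results_per_simdgroup = 4, each 64-thread threadgroup produces 8 rows of
// y, so a host might launch
//
//   threadsPerThreadgroup = MTLSizeMake(2 * SIMD_SIZE, 1, 1);  // 2 simdgroups
//   threadgroupsPerGrid   = MTLSizeMake(B, ceil_div(N, 8), 1);
//
// where tid.x picks the batch row of x/y (B = 1 for a plain GEMV) and tid.y
// picks the block of 8 output rows; the out_row bound check above absorbs
// any overshoot in the last block.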
// ============================================================================
// BnB GEMM with transposed weight (y = x @ dequant(w).T)
//
// x: [M, K], w: [N, K/2] packed, absmax: [N, ceil(K/blocksize)], y: [M, N]
//
// Uses tiled matrix multiply with BnBQuantizedBlockLoader for on-the-fly
// dequantization of weights during the GEMM computation.
// ============================================================================
template <
    typename T,
    const int blocksize,
    const int quant_type,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
METAL_FUNC void bnb_qmm_t_impl(
    const device uint8_t* w,
    const device float* absmax,
    const device T* x,
    device T* y,
    threadgroup T* Xs,
    threadgroup T* Ws,
    const constant int& K,
    const constant int& N,
    const constant int& M,
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
  static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");

  (void)lid;

  constexpr int WM = 2;
  constexpr int WN = 2;
  constexpr int pack_factor = 2;
  constexpr int BK_padded = (BK + 16 / sizeof(T));

  using mma_t = mlx::steel::
      BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK_padded, BK_padded>;
  using loader_x_t =
      mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
  using loader_w_t = BnBQuantizedBlockLoader<
      T,
      BN,
      BK,
      BK_padded,
      1,
      WM * WN * SIMD_SIZE,
      blocksize,
      quant_type>;

  const int K_packed = K / pack_factor;
  const int K_groups = (K + blocksize - 1) / blocksize;

  const int y_row = tid.y * BM;
  const int y_col = tid.x * BN;

  x += y_row * static_cast<int64_t>(K);
  w += y_col * K_packed;
  absmax += y_col * K_groups;
  y += y_row * static_cast<int64_t>(N) + y_col;

  const short num_els = min(BM, M - y_row);
  const short num_outs = min(BN, N - y_col);

  loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
  loader_w_t loader_w(w, absmax, K, Ws, simd_gid, simd_lid);
  mma_t mma_op(simd_gid, simd_lid);

  if (num_els < BM) {
    if (num_outs < BN) {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_safe(short2(BK, num_els));
        loader_w.load_safe(short2(BK, num_outs));
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    } else {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_safe(short2(BK, num_els));
        loader_w.load_unsafe();
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    }
  } else {
    if (num_outs < BN) {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_unsafe();
        loader_w.load_safe(short2(BK, num_outs));
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    } else {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_unsafe();
        loader_w.load_unsafe();
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    }
  }

  // Store results
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (num_els < BM || num_outs < BN) {
    mma_op.store_result_safe(y, N, short2(num_outs, num_els));
  } else {
    mma_op.store_result(y, N);
  }
}
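// Dispatch sketch for bnb_qmm_t_impl (likewise an assumed host-side pairing):
// the loaders and MMA are sized for WM * WN * SIMD_SIZE = 128 threads per
// threadgroup, and each threadgroup computes one BM x BN = 32 x 32 tile of y:
//
//   threadsPerThreadgroup = MTLSizeMake(WM * WN * SIMD_SIZE, 1, 1);  // 128
//   threadgroupsPerGrid   = MTLSizeMake(ceil_div(N, BN), ceil_div(M, BM), 1);
//
// tid.x indexes column tiles of y (rows of the packed W) and tid.y indexes
// row tiles of x.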
// ============================================================================
// Kernel entry points
// ============================================================================

// ---- Standalone blockwise quantize ----
// Each thread handles one block of elements.
template <typename T, int blocksize, int quant_type>
[[kernel]] void bnb_quantize_blockwise(
    const device T* input [[buffer(0)]],
    device float* absmax [[buffer(1)]],
    device uint8_t* packed [[buffer(2)]],
    const constant int& n [[buffer(3)]],
    uint gid [[thread_position_in_grid]]) {
  const int num_blocks = (n + blocksize - 1) / blocksize;
  if (static_cast<int>(gid) >= num_blocks) {
    return;
  }

  int block_start = gid * blocksize;
  int block_end = min(block_start + blocksize, n);

  // Find absmax for this block
  float max_val = 0.0f;
  for (int i = block_start; i < block_end; i++) {
    float current = metal::abs(float(input[i]));
    max_val = metal::max(max_val, current);
  }
  absmax[gid] = max_val;

  float inv = (max_val > 0.0f) ? 1.0f / max_val : 0.0f;

  // Quantize and pack pairs of values
  int out_byte = block_start / 2;
  for (int i = block_start; i < block_end; i += 2) {
    float norm0 =
        (max_val > 0.0f) ? clamp(float(input[i]) * inv, -1.0f, 1.0f) : 0.0f;
    uchar q0 = bnb_quantize_value<quant_type>(norm0);
    uchar q1 = 0;
    if (i + 1 < block_end) {
      float norm1 = (max_val > 0.0f)
          ? clamp(float(input[i + 1]) * inv, -1.0f, 1.0f)
          : 0.0f;
      q1 = bnb_quantize_value<quant_type>(norm1);
    }
    packed[out_byte++] = (q0 << 4) | (q1 & 0x0f);
  }
}
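// Example of the blockwise layout (numbers chosen for illustration): with
// n = 100 and blocksize = 64 the kernel runs over num_blocks = 2 blocks.
// Thread 0 scans input[0..63] and writes absmax[0] and packed[0..31];
// thread 1 scans input[64..99] and writes absmax[1] and packed[32..49],
// with a trailing odd element (if any) packed against a zero low nibble.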
// ---- Standalone blockwise dequantize ----
// Each threadgroup handles one block. Threads within the group share the
// block's absmax.
template <typename T, int blocksize, int quant_type>
[[kernel]] void bnb_dequantize_blockwise(
    const device uint8_t* packed [[buffer(0)]],
    const device float* absmax [[buffer(1)]],
    device T* output [[buffer(2)]],
    const constant int& n [[buffer(3)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint tg_size [[threads_per_threadgroup]]) {
  const int num_blocks = (n + blocksize - 1) / blocksize;
  if (static_cast<int>(tgid) >= num_blocks) {
    return;
  }

  constant float* codebook = bnb_codebook<quant_type>();

  int block_start = tgid * blocksize;
  int block_end = min(block_start + blocksize, n);

  // Threadgroup variables cannot carry initializers; thread 0 writes the
  // scale and a barrier publishes it to the rest of the group.
  threadgroup float shared_scale;
  if (tid == 0) {
    shared_scale = absmax[tgid];
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  float scale = shared_scale;

  int pairs_in_block = (block_end - block_start + 1) / 2;
  for (int pair = static_cast<int>(tid); pair < pairs_in_block;
       pair += static_cast<int>(tg_size)) {
    int elem_idx = block_start + pair * 2;
    int byte_idx = elem_idx / 2;
    uint8_t byte_val = packed[byte_idx];
    uint8_t high = (byte_val >> 4) & 0x0f;
    uint8_t low = byte_val & 0x0f;
    output[elem_idx] = T(codebook[high] * scale);
    if (elem_idx + 1 < block_end) {
      output[elem_idx + 1] = T(codebook[low] * scale);
    }
  }
}

// ---- GEMV kernel entry point ----
// y = dequant(W) @ x
// W: [N, K/2], absmax: [N, K_groups], x: [K], y: [N]
template <typename T, int blocksize, int quant_type>
[[kernel]] void bnb_qmv(
    const device uint8_t* w [[buffer(0)]],
    const device float* absmax [[buffer(1)]],
    const device T* x [[buffer(2)]],
    device T* y [[buffer(3)]],
    const constant int& in_vec_size [[buffer(4)]],
    const constant int& out_vec_size [[buffer(5)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  bnb_qmv_impl<T, blocksize, quant_type>(
      w, absmax, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
}

// ---- GEMM (transposed weight) kernel entry point ----
// Y = X @ dequant(W).T
// X: [M, K], W: [N, K/2], absmax: [N, K_groups], Y: [M, N]
template <typename T, int blocksize, int quant_type>
[[kernel]] void bnb_qmm_t(
    const device uint8_t* w [[buffer(0)]],
    const device float* absmax [[buffer(1)]],
    const device T* x [[buffer(2)]],
    device T* y [[buffer(3)]],
    const constant int& K [[buffer(4)]],
    const constant int& N [[buffer(5)]],
    const constant int& M [[buffer(6)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int BM = 32;
  constexpr int BK = 32;
  constexpr int BN = 32;
  constexpr int BK_padded = (BK + 16 / sizeof(T));

  threadgroup T Xs[BM * BK_padded];
  threadgroup T Ws[BN * BK_padded];

  bnb_qmm_t_impl<T, blocksize, quant_type, BM, BK, BN>(
      w, absmax, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}
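// ============================================================================
// Example instantiations
//
// A sketch of how concrete variants could be exported for host lookup; the
// [[host_name]] strings and the (dtype, blocksize, quant_type) combinations
// below are illustrative assumptions, not definitions from the original
// sources. quant_type 2 = BNB_NF4 per the header comment.
// ============================================================================
template [[host_name("bnb_qmv_f16_nf4_64")]] [[kernel]]
decltype(bnb_qmv<float16_t, 64, 2>) bnb_qmv<float16_t, 64, 2>;

template [[host_name("bnb_qmm_t_f16_nf4_64")]] [[kernel]]
decltype(bnb_qmm_t<float16_t, 64, 2>) bnb_qmm_t<float16_t, 64, 2>;

template [[host_name("bnb_dequantize_blockwise_f16_nf4_64")]] [[kernel]]
decltype(bnb_dequantize_blockwise<float16_t, 64, 2>)
    bnb_dequantize_blockwise<float16_t, 64, 2>;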