"""Tests for bitsandbytes MPS 4-bit quantization kernels.""" import pytest import torch from bitsandbytes_mps import ( FP4, NF4, dequantize_4bit, gemm_4bit, gemv_4bit, linear_4bit, quantize_4bit, ) # NF4 codebook values (matching bnb_types.h) NF4_CODEBOOK = [ -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0, ] FP4_CODEBOOK = [ 0.0, 0.005208333333, 0.66666667, 1.0, 0.33333333, 0.5, 0.16666667, 0.25, 0.0, -0.005208333333, -0.66666667, -1.0, -0.33333333, -0.5, -0.16666667, -0.25, ] DEVICE = "mps" def _reference_quantize_nf4(x_flat, blocksize): """Reference Python implementation of NF4 blockwise quantization.""" n = x_flat.numel() num_blocks = (n + blocksize - 1) // blocksize absmax = torch.zeros(num_blocks, dtype=torch.float32) packed = torch.zeros((n + 1) // 2, dtype=torch.uint8) codebook = torch.tensor(NF4_CODEBOOK, dtype=torch.float32) for b in range(num_blocks): start = b * blocksize end = min(start + blocksize, n) block = x_flat[start:end].float() am = block.abs().max().item() absmax[b] = am if am > 0: normalized = (block / am).clamp(-1, 1) else: normalized = torch.zeros_like(block) for i in range(0, end - start, 2): v0 = normalized[i].item() q0 = (codebook - v0).abs().argmin().item() q1 = 0 if i + 1 < end - start: v1 = normalized[i + 1].item() q1 = (codebook - v1).abs().argmin().item() byte_idx = (start + i) // 2 packed[byte_idx] = (q0 << 4) | (q1 & 0x0F) return packed, absmax def _reference_dequantize_nf4(packed, absmax, blocksize, numel): """Reference Python implementation of NF4 blockwise dequantization.""" codebook = torch.tensor(NF4_CODEBOOK, dtype=torch.float32) output = torch.zeros(numel, dtype=torch.float32) for i in range(numel): byte_idx = i // 2 block_idx = i // blocksize byte_val = packed[byte_idx].item() if i % 2 == 0: nibble = (byte_val >> 4) & 0x0F else: nibble = byte_val & 0x0F output[i] = codebook[nibble] * absmax[block_idx].item() return output # ============================================================================ # Quantization / Dequantization Tests # ============================================================================ @pytest.mark.parametrize("blocksize", [64, 128]) @pytest.mark.parametrize("quant_type", [NF4, FP4]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) def test_quantize_dequantize_roundtrip(blocksize, quant_type, dtype): """Test that quantize -> dequantize approximately recovers the original.""" torch.manual_seed(42) n = 1024 x = torch.randn(n, dtype=dtype, device=DEVICE) packed, absmax = quantize_4bit(x, blocksize=blocksize, quant_type=quant_type) assert packed.shape == (n // 2,) assert packed.dtype == torch.uint8 assert absmax.dtype == torch.float32 assert absmax.shape == ((n + blocksize - 1) // blocksize,) x_deq = dequantize_4bit( packed, absmax, blocksize=blocksize, quant_type=quant_type, numel=n, output_dtype=dtype, ) assert x_deq.shape == (n,) assert x_deq.dtype == dtype # 4-bit quantization has significant error; check correlation x_cpu = x.float().cpu() x_deq_cpu = x_deq.float().cpu() cosine_sim = torch.nn.functional.cosine_similarity( x_cpu.unsqueeze(0), x_deq_cpu.unsqueeze(0) ).item() assert cosine_sim > 0.95, f"Cosine similarity too low: {cosine_sim}" @pytest.mark.parametrize("blocksize", [64, 128]) def test_dequantize_matches_reference(blocksize): """Test dequantization matches the Python reference implementation.""" torch.manual_seed(123) n = 256 x = torch.randn(n, dtype=torch.float16, device=DEVICE) packed, absmax = quantize_4bit(x, blocksize=blocksize, quant_type=NF4) # GPU dequantize x_deq = dequantize_4bit( packed, absmax, blocksize=blocksize, quant_type=NF4, numel=n, output_dtype=torch.float16, ) # Reference dequantize (on CPU) x_ref = _reference_dequantize_nf4( packed.cpu(), absmax.cpu(), blocksize, n ) torch.testing.assert_close( x_deq.float().cpu(), x_ref, rtol=1e-3, atol=1e-3 ) # ============================================================================ # GEMV Tests # ============================================================================ @pytest.mark.parametrize("blocksize", [64, 128]) @pytest.mark.parametrize("quant_type", [NF4, FP4]) def test_gemv_correctness(blocksize, quant_type): """Test fused GEMV against dequantize + matmul reference.""" torch.manual_seed(42) N, K = 256, 256 # Create weight and quantize W = torch.randn(N, K, dtype=torch.float16, device=DEVICE) W_flat = W.flatten() packed, absmax = quantize_4bit(W_flat, blocksize=blocksize, quant_type=quant_type) # Reshape for GEMV packed_w = packed.view(N, K // 2) absmax_w = absmax.view(N, -1) # Input vector x = torch.randn(K, dtype=torch.float16, device=DEVICE) # Fused GEMV y = gemv_4bit(x, packed_w, absmax_w, output_features=N, blocksize=blocksize, quant_type=quant_type) # Reference: dequantize then matmul W_deq = dequantize_4bit(packed, absmax, blocksize=blocksize, quant_type=quant_type, numel=N*K, output_dtype=torch.float16) W_deq = W_deq.view(N, K) y_ref = W_deq @ x # Check relative error rel_error = (y.float() - y_ref.float()).abs().mean() / y_ref.float().abs().mean() assert rel_error < 0.05, f"GEMV relative error too high: {rel_error}" # ============================================================================ # GEMM Tests # ============================================================================ @pytest.mark.parametrize("blocksize", [64, 128]) @pytest.mark.parametrize("quant_type", [NF4, FP4]) def test_gemm_correctness(blocksize, quant_type): """Test fused GEMM against dequantize + matmul reference.""" torch.manual_seed(42) M, N, K = 8, 128, 128 W = torch.randn(N, K, dtype=torch.float16, device=DEVICE) W_flat = W.flatten() packed, absmax = quantize_4bit(W_flat, blocksize=blocksize, quant_type=quant_type) packed_w = packed.view(N, K // 2) absmax_w = absmax.view(N, -1) X = torch.randn(M, K, dtype=torch.float16, device=DEVICE) # Fused GEMM Y = gemm_4bit(X, packed_w, absmax_w, output_features=N, blocksize=blocksize, quant_type=quant_type) # Reference W_deq = dequantize_4bit(packed, absmax, blocksize=blocksize, quant_type=quant_type, numel=N*K, output_dtype=torch.float16) W_deq = W_deq.view(N, K) Y_ref = X @ W_deq.T rel_error = (Y.float() - Y_ref.float()).abs().mean() / Y_ref.float().abs().mean() assert rel_error < 0.05, f"GEMM relative error too high: {rel_error}" # ============================================================================ # Linear layer test # ============================================================================ def test_linear_4bit_auto_select(): """Test that linear_4bit auto-selects GEMV vs GEMM.""" torch.manual_seed(42) N, K = 128, 128 W = torch.randn(N, K, dtype=torch.float16, device=DEVICE) packed, absmax = quantize_4bit(W.flatten(), blocksize=64, quant_type=NF4) packed_w = packed.view(N, K // 2) absmax_w = absmax.view(N, -1) # Single vector - should use GEMV x = torch.randn(K, dtype=torch.float16, device=DEVICE) y = linear_4bit(x, packed_w, absmax_w, output_features=N) assert y.shape == (N,) # Batch - should use GEMM X = torch.randn(4, K, dtype=torch.float16, device=DEVICE) Y = linear_4bit(X, packed_w, absmax_w, output_features=N) assert Y.shape == (4, N) if __name__ == "__main__": pytest.main([__file__, "-v"])