| """Tests for bitsandbytes MPS 4-bit quantization kernels.""" |
|
|
| import pytest |
| import torch |
|
|
| from bitsandbytes_mps import ( |
| FP4, |
| NF4, |
| dequantize_4bit, |
| gemm_4bit, |
| gemv_4bit, |
| linear_4bit, |
| quantize_4bit, |
| ) |
|
|
| |
# 16-entry NF4 codebook: nibble index -> normalized value in [-1, 1].
# Values are presumably the NormalFloat4 quantiles from the QLoRA paper —
# TODO confirm against the bitsandbytes_mps kernel's embedded table.
NF4_CODEBOOK = [
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
    0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
    0.7229568362236023, 1.0,
]


# 16-entry FP4 codebook: indices 0-7 hold the non-negative values and
# indices 8-15 mirror them negated (0.0 appears twice, as +0 and -0).
FP4_CODEBOOK = [
    0.0, 0.005208333333, 0.66666667, 1.0, 0.33333333, 0.5, 0.16666667, 0.25,
    0.0, -0.005208333333, -0.66666667, -1.0, -0.33333333, -0.5, -0.16666667,
    -0.25,
]


# All tensors in this module are allocated on Apple's Metal (MPS) backend.
# NOTE(review): there is no skipif guard for MPS availability — on machines
# without MPS these tests will error rather than skip; consider a module-level
# pytestmark.
DEVICE = "mps"
|
|
|
|
def _reference_quantize_nf4(x_flat, blocksize):
    """Reference Python implementation of NF4 blockwise quantization.

    Returns ``(packed, absmax)``: ``packed`` stores two 4-bit codebook
    indices per uint8 byte (even-positioned element in the high nibble,
    odd-positioned element in the low nibble), and ``absmax`` stores one
    float32 scale per block of ``blocksize`` elements.
    """
    total = x_flat.numel()
    block_count = -(-total // blocksize)  # ceiling division
    absmax = torch.zeros(block_count, dtype=torch.float32)
    packed = torch.zeros((total + 1) // 2, dtype=torch.uint8)
    codes = torch.tensor(NF4_CODEBOOK, dtype=torch.float32)

    for blk in range(block_count):
        lo = blk * blocksize
        hi = min(lo + blocksize, total)
        chunk = x_flat[lo:hi].float()
        scale = chunk.abs().max().item()
        absmax[blk] = scale

        # Normalize into [-1, 1]; an all-zero block has no scale to divide by.
        if scale > 0:
            scaled = (chunk / scale).clamp(-1, 1)
        else:
            scaled = torch.zeros_like(chunk)

        pos = 0
        count = hi - lo
        while pos < count:
            # Nearest-codebook-entry lookup for each element of the pair.
            high = (codes - scaled[pos].item()).abs().argmin().item()
            low = 0
            if pos + 1 < count:
                low = (codes - scaled[pos + 1].item()).abs().argmin().item()

            # blocksize is even in practice, so lo + pos lands on byte edges.
            packed[(lo + pos) // 2] = (high << 4) | (low & 0x0F)
            pos += 2

    return packed, absmax
|
|
|
|
def _reference_dequantize_nf4(packed, absmax, blocksize, numel):
    """Reference Python implementation of NF4 blockwise dequantization.

    Inverse of ``_reference_quantize_nf4``: each 4-bit index (high nibble
    first within a byte) is looked up in the NF4 codebook and rescaled by
    its block's absmax. Returns a float32 tensor of ``numel`` elements.
    """
    codes = torch.tensor(NF4_CODEBOOK, dtype=torch.float32)
    result = torch.zeros(numel, dtype=torch.float32)

    for idx in range(numel):
        byte_val = packed[idx // 2].item()
        # Even positions were packed into the high nibble, odd into the low.
        if idx % 2 == 0:
            code = (byte_val >> 4) & 0x0F
        else:
            code = byte_val & 0x0F
        result[idx] = codes[code] * absmax[idx // blocksize].item()

    return result
|
|
|
|
| |
| |
| |
|
|
|
|
@pytest.mark.parametrize("blocksize", [64, 128])
@pytest.mark.parametrize("quant_type", [NF4, FP4])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_quantize_dequantize_roundtrip(blocksize, quant_type, dtype):
    """Test that quantize -> dequantize approximately recovers the original."""
    torch.manual_seed(42)
    numel = 1024
    original = torch.randn(numel, dtype=dtype, device=DEVICE)

    packed, absmax = quantize_4bit(
        original, blocksize=blocksize, quant_type=quant_type
    )

    # Two 4-bit values per byte; one float32 scale per block.
    assert packed.shape == (numel // 2,)
    assert packed.dtype == torch.uint8
    assert absmax.dtype == torch.float32
    assert absmax.shape == ((numel + blocksize - 1) // blocksize,)

    recovered = dequantize_4bit(
        packed,
        absmax,
        blocksize=blocksize,
        quant_type=quant_type,
        numel=numel,
        output_dtype=dtype,
    )

    assert recovered.shape == (numel,)
    assert recovered.dtype == dtype

    # 4-bit quantization is lossy, so compare direction rather than values;
    # do the comparison in float32 on CPU.
    similarity = torch.nn.functional.cosine_similarity(
        original.float().cpu().unsqueeze(0),
        recovered.float().cpu().unsqueeze(0),
    ).item()
    assert similarity > 0.95, f"Cosine similarity too low: {similarity}"
|
|
|
|
@pytest.mark.parametrize("blocksize", [64, 128])
def test_dequantize_matches_reference(blocksize):
    """Test dequantization matches the Python reference implementation."""
    torch.manual_seed(123)
    numel = 256
    data = torch.randn(numel, dtype=torch.float16, device=DEVICE)

    packed, absmax = quantize_4bit(data, blocksize=blocksize, quant_type=NF4)

    # Kernel output on the MPS device.
    kernel_out = dequantize_4bit(
        packed,
        absmax,
        blocksize=blocksize,
        quant_type=NF4,
        numel=numel,
        output_dtype=torch.float16,
    )

    # Pure-Python reference run on CPU copies of the same packed data.
    reference = _reference_dequantize_nf4(
        packed.cpu(), absmax.cpu(), blocksize, numel
    )

    torch.testing.assert_close(
        kernel_out.float().cpu(), reference, rtol=1e-3, atol=1e-3
    )
|
|
|
|
| |
| |
| |
|
|
|
|
@pytest.mark.parametrize("blocksize", [64, 128])
@pytest.mark.parametrize("quant_type", [NF4, FP4])
def test_gemv_correctness(blocksize, quant_type):
    """Test fused GEMV against dequantize + matmul reference."""
    torch.manual_seed(42)
    out_features, in_features = 256, 256

    # Quantize a random weight matrix as one flat row-major buffer.
    weight = torch.randn(
        out_features, in_features, dtype=torch.float16, device=DEVICE
    )
    packed, absmax = quantize_4bit(
        weight.flatten(), blocksize=blocksize, quant_type=quant_type
    )

    # View the packed/scale buffers with one row per output feature.
    packed_w = packed.view(out_features, in_features // 2)
    absmax_w = absmax.view(out_features, -1)

    vec = torch.randn(in_features, dtype=torch.float16, device=DEVICE)

    # Fused kernel under test.
    fused = gemv_4bit(
        vec,
        packed_w,
        absmax_w,
        output_features=out_features,
        blocksize=blocksize,
        quant_type=quant_type,
    )

    # Reference path: dequantize the whole matrix, then a dense matvec.
    dense = dequantize_4bit(
        packed,
        absmax,
        blocksize=blocksize,
        quant_type=quant_type,
        numel=out_features * in_features,
        output_dtype=torch.float16,
    ).view(out_features, in_features)
    expected = dense @ vec

    # Mean relative error tolerates fp16 accumulation differences.
    rel_error = (
        (fused.float() - expected.float()).abs().mean()
        / expected.float().abs().mean()
    )
    assert rel_error < 0.05, f"GEMV relative error too high: {rel_error}"
|
|
|
|
| |
| |
| |
|
|
|
|
@pytest.mark.parametrize("blocksize", [64, 128])
@pytest.mark.parametrize("quant_type", [NF4, FP4])
def test_gemm_correctness(blocksize, quant_type):
    """Test fused GEMM against dequantize + matmul reference."""
    torch.manual_seed(42)
    batch, out_features, in_features = 8, 128, 128

    weight = torch.randn(
        out_features, in_features, dtype=torch.float16, device=DEVICE
    )
    packed, absmax = quantize_4bit(
        weight.flatten(), blocksize=blocksize, quant_type=quant_type
    )

    # One packed row / scale row per output feature.
    packed_w = packed.view(out_features, in_features // 2)
    absmax_w = absmax.view(out_features, -1)

    activations = torch.randn(
        batch, in_features, dtype=torch.float16, device=DEVICE
    )

    # Fused kernel under test.
    fused = gemm_4bit(
        activations,
        packed_w,
        absmax_w,
        output_features=out_features,
        blocksize=blocksize,
        quant_type=quant_type,
    )

    # Reference path: dequantize, then dense X @ W^T.
    dense = dequantize_4bit(
        packed,
        absmax,
        blocksize=blocksize,
        quant_type=quant_type,
        numel=out_features * in_features,
        output_dtype=torch.float16,
    ).view(out_features, in_features)
    expected = activations @ dense.T

    rel_error = (
        (fused.float() - expected.float()).abs().mean()
        / expected.float().abs().mean()
    )
    assert rel_error < 0.05, f"GEMM relative error too high: {rel_error}"
|
|
|
|
| |
| |
| |
|
|
|
|
def test_linear_4bit_auto_select():
    """Test that linear_4bit auto-selects GEMV vs GEMM."""
    torch.manual_seed(42)
    out_features, in_features = 128, 128

    weight = torch.randn(
        out_features, in_features, dtype=torch.float16, device=DEVICE
    )
    packed, absmax = quantize_4bit(weight.flatten(), blocksize=64, quant_type=NF4)
    packed_w = packed.view(out_features, in_features // 2)
    absmax_w = absmax.view(out_features, -1)

    # 1-D input: expected to take the GEMV path and return a vector.
    vec = torch.randn(in_features, dtype=torch.float16, device=DEVICE)
    out_vec = linear_4bit(vec, packed_w, absmax_w, output_features=out_features)
    assert out_vec.shape == (out_features,)

    # 2-D batched input: expected to take the GEMM path and return a matrix.
    batch = torch.randn(4, in_features, dtype=torch.float16, device=DEVICE)
    out_mat = linear_4bit(batch, packed_w, absmax_w, output_features=out_features)
    assert out_mat.shape == (4, out_features)
|
|
|
|
# Allow running this file directly (python this_file.py) via pytest's CLI.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
|
|