| | |
| | """ |
| | Unit tests for NVFP4 kernel functions. |
| | |
| | This tests dequantization and GEMM operations in isolation before |
| | attempting full model inference. |
| | """ |
| |
|
| | import sys |
| | import torch |
| | import torch.nn.functional as F |
| |
|
| | |
| | from nvfp4_kernel import ( |
| | dequantize_nvfp4, |
| | nvfp4_gemm_dequant, |
| | NVFP4_LUT, |
| | NVFP4_BLOCK_SIZE |
| | ) |
| |
|
| | |
# Largest magnitude on the FP4 E2M1 grid.
FP4_MAX = 6.0
# Largest finite FP8 E4M3 value (the `float8_e4m3fn` variant used below).
FP8_E4M3_MAX = 448.0
# Rounding thresholds for the non-negative E2M1 grid {0, 0.5, 1, 1.5, 2, 3, 4, 6}:
# each entry is the midpoint between two neighbouring grid points, so
# `torch.searchsorted(E2M1_BOUNDS, x)` returns the 3-bit magnitude index of x.
E2M1_BOUNDS = torch.tensor(
    [
        (lower + upper) / 2
        for lower, upper in zip(
            (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0),
            (0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0),
        )
    ],
    dtype=torch.float32,
)
| |
|
| |
|
def compute_nvfp4_scales(fp32_weight, block_size=16):
    """
    Compute two-level NVFP4 scaling factors.

    Simplified version for testing. The last dimension of ``fp32_weight`` is
    zero-padded to a multiple of ``block_size`` and split into blocks. A
    per-tensor scale (``weight_scale_2``) is chosen so that the largest block
    scale maps to ``FP8_E4M3_MAX``; each block then gets its own scale
    (``weight_scale``) expressed relative to ``weight_scale_2``.

    Args:
        fp32_weight: 1-D or 2-D floating-point weight tensor.
        block_size: Number of consecutive elements that share one block scale.

    Returns:
        Tuple ``(weight_scale, weight_scale_2)``:
        - ``weight_scale``: per-block scales of shape ``(num_blocks,)`` for
          1-D input or ``(M, num_blocks)`` otherwise, stored as
          ``torch.float8_e4m3fn`` (``torch.float32`` if FP8 is unavailable).
        - ``weight_scale_2``: 0-dim per-tensor scale.
    """
    device = fp32_weight.device

    # Per-tensor scale chosen so that global_amax / (FP4_MAX * weight_scale_2)
    # equals FP8_E4M3_MAX: the block holding the global amax gets the largest
    # representable FP8 block scale.
    global_amax = fp32_weight.abs().max()
    weight_scale_2 = global_amax / (FP4_MAX * FP8_E4M3_MAX)

    # All-zero (or vanishingly small) weights: avoid a zero per-tensor scale,
    # which would turn the per-block division below into 0/0.
    if weight_scale_2.abs() < 1e-10:
        weight_scale_2 = torch.tensor(1e-8, dtype=torch.float32, device=device)

    M = fp32_weight.shape[0] if fp32_weight.dim() > 1 else 1
    N = fp32_weight.shape[-1]

    # Zero-pad the last dimension to whole blocks (zeros never raise a block's amax).
    N_padded = ((N + block_size - 1) // block_size) * block_size
    if N_padded != N:
        fp32_weight = F.pad(fp32_weight, (0, N_padded - N))

    # Group into blocks. `reshape` (rather than `view`) also accepts
    # non-contiguous inputs such as transposed weights.
    if fp32_weight.dim() == 1:
        weight_blocks = fp32_weight.reshape(-1, block_size)
    else:
        weight_blocks = fp32_weight.reshape(M, -1, block_size)

    per_block_amax = weight_blocks.abs().amax(dim=-1)
    per_block_scale = per_block_amax / (FP4_MAX * weight_scale_2)

    try:
        # Clamp to the smallest positive FP8 E4M3 value (the subnormal 2**-9).
        # Values up to 2**-10 would round to 0 on the cast, giving a zero
        # block scale that the quantizer would then divide by.
        weight_scale = per_block_scale.clamp(min=2.0 ** -9).to(torch.float8_e4m3fn)
    except (AttributeError, RuntimeError, TypeError):
        # AttributeError: PyTorch builds that do not define the FP8 dtype at
        # all. RuntimeError/TypeError: the dtype exists but the cast fails.
        weight_scale = per_block_scale.clamp(min=1e-8).to(torch.float32)

    return weight_scale, weight_scale_2
| |
|
| |
|
def quantize_to_nvfp4_packed(fp32_weight, weight_scale, weight_scale_2, block_size=16):
    """
    Quantize FP32 weight to NVFP4 packed uint8 format.

    Simplified version for testing. Each element is divided by its block's
    combined scale (``weight_scale * weight_scale_2``), rounded onto the FP4
    E2M1 grid, and encoded as a 4-bit code (bit 3 = sign, bits 0-2 =
    magnitude index). Two codes are packed per byte: the even-indexed element
    in the low nibble, the odd-indexed element in the high nibble.

    Args:
        fp32_weight: 1-D or 2-D floating-point weight tensor.
        weight_scale: Per-block scales (FP8 or FP32), one per ``block_size``
            elements of the last dimension, e.g. from ``compute_nvfp4_scales``.
        weight_scale_2: Per-tensor scale.
        block_size: Number of consecutive elements that share one block scale.

    Returns:
        uint8 tensor with one row per input row (or 1-D for 1-D input) whose
        last dimension is half the block-padded length of the input's last
        dimension (rounded up).
    """
    device = fp32_weight.device
    M = fp32_weight.shape[0] if fp32_weight.dim() > 1 else 1
    N = fp32_weight.shape[-1]

    # Zero-pad the last dimension to whole blocks so it lines up with the
    # per-block scales.
    N_padded = ((N + block_size - 1) // block_size) * block_size
    if N_padded != N:
        fp32_weight = F.pad(fp32_weight, (0, N_padded - N))

    # Group into blocks: (num_blocks, block_size) for 1-D input,
    # (M, num_blocks, block_size) otherwise.
    if fp32_weight.dim() == 1:
        weight_blocks = fp32_weight.reshape(-1, block_size)
    else:
        weight_blocks = fp32_weight.reshape(M, -1, block_size)

    # Map each block into the FP4 range with its combined (block * tensor)
    # scale. A zero combined scale (e.g. a tiny block scale flushed to 0 by
    # an FP8 cast) would yield NaN (0/0) or +/-inf (x/0); encode such blocks
    # as zeros instead of feeding non-finite values to the encoder.
    combined_scale = (weight_scale.to(torch.float32) * weight_scale_2).unsqueeze(-1)
    quotient = weight_blocks / combined_scale
    scaled_blocks = torch.where(combined_scale > 0, quotient, torch.zeros_like(quotient))

    # Back to one row per weight row (or a flat vector for 1-D input).
    if fp32_weight.dim() == 1:
        scaled_weight = scaled_blocks.reshape(-1)
    else:
        scaled_weight = scaled_blocks.reshape(M, -1)

    # Magnitude index 0..7 via the midpoint thresholds: a value exactly on a
    # threshold takes the lower index, and anything above the last threshold
    # saturates to index 7.
    magnitude_code = torch.searchsorted(E2M1_BOUNDS.to(device), scaled_weight.abs())
    sign_bit = (scaled_weight < 0).to(torch.uint8)
    code = (sign_bit << 3) | magnitude_code.to(torch.uint8)

    # Byte packing needs an even number of codes per row; after block padding
    # this can only trigger for an odd block_size.
    if code.shape[-1] % 2 != 0:
        code = torch.cat([code, code.new_zeros((*code.shape[:-1], 1))], dim=-1)

    # Two codes per byte: even-indexed code -> low nibble, odd-indexed -> high.
    return (code[..., 1::2] << 4) | code[..., 0::2]
| |
|
| |
|
def test_dequant_lookup_table():
    """Test 1: Verify NVFP4 lookup table values are correct.

    The table must have 16 entries: codes 0-7 are the positive E2M1
    magnitudes, and codes 8-15 are the same magnitudes with the sign bit set.
    """
    rule = "=" * 70
    print("\n" + rule)
    print("Test 1: NVFP4 Lookup Table")
    print(rule)

    positive = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
    expected = positive + [-value for value in positive]

    assert len(NVFP4_LUT) == 16, f"LUT should have 16 entries, got {len(NVFP4_LUT)}"

    # First entry that is not within 1e-6 of its expectation (a NaN entry
    # also counts as a mismatch because the comparison is False).
    mismatch = next(
        (
            (idx, got, want)
            for idx, (got, want) in enumerate(zip(NVFP4_LUT, expected))
            if not abs(got - want) < 1e-6
        ),
        None,
    )
    assert mismatch is None, f"LUT[{mismatch[0]}] = {mismatch[1]}, expected {mismatch[2]}"

    lut_values = NVFP4_LUT.tolist()
    print(f" PASS: Lookup table correct: {lut_values[:8]}")
    print(f" {lut_values[8:]}")
    print(" PASS: Test 1 PASSED\n")
| |
|
| |
|
def test_dequant_simple():
    """Test 2: Simple dequantization with known values.

    Dequantizes one hand-packed 16-element row with unit scales and checks
    that the E2M1 codes decode to their grid values and the rest to zero.
    """
    print("=" * 70)
    print("Test 2: Simple Dequantization")
    print("=" * 70)

    # One row of 8 bytes = 16 FP4 codes. Each byte holds two codes: the low
    # nibble is the even-indexed element and the high nibble the odd-indexed
    # one (the layout produced by quantize_to_nvfp4_packed):
    #   0x20 -> codes (0, 2) -> values (0.0, 1.0)
    #   0x54 -> codes (4, 5) -> values (2.0, 3.0)
    #   0x76 -> codes (6, 7) -> values (4.0, 6.0)
    packed = torch.tensor([
        [0x20, 0x54, 0x76, 0x00, 0x00, 0x00, 0x00, 0x00],
    ], dtype=torch.uint8)

    # A single block with unit block scale and unit tensor scale, so the
    # dequantized output is the raw E2M1 values.
    scale = torch.ones(1, 1, dtype=torch.float8_e4m3fn)
    scale_2 = torch.tensor([1.0], dtype=torch.float32)

    result = dequantize_nvfp4(packed, scale, scale_2, dtype=torch.float32)

    print(f" Packed: {packed[0].tolist()}")
    print(f" Scales: scale={scale.shape}, scale_2={scale_2.item()}")
    print(f" Result shape: {result.shape}")
    print(f" Result values: {result[0].tolist()}")

    expected_values = [0, 1, 2, 3, 4, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

    # Check the shape first: zip() below stops at the shorter sequence, so a
    # truncated or mis-shaped result would otherwise pass unnoticed.
    expected_shape = (1, len(expected_values))
    assert tuple(result.shape) == expected_shape, (
        f"Result shape mismatch: {tuple(result.shape)} != {expected_shape}"
    )

    for i, (val, expected) in enumerate(zip(result[0].tolist(), expected_values)):
        assert abs(val - expected) < 0.01, f"Position {i}: got {val}, expected {expected}"

    print(" PASS: Dequantization correct")
    print(" PASS: Test 2 PASSED\n")
| |
|
| |
|
def test_quantize_dequantize_roundtrip():
    """Test 3: Quantize then dequantize, check error is acceptable.

    Quantizes a seeded random [64, 256] FP32 weight with the helpers in this
    file, dequantizes it with ``dequantize_nvfp4``, and bounds the mean
    absolute and mean relative reconstruction error.
    """
    print("=" * 70)
    print("Test 3: Quantization-Dequantization Roundtrip")
    print("=" * 70)

    # Seeded so the error statistics checked below are reproducible.
    M, N = 64, 256
    torch.manual_seed(42)
    fp32_weight = torch.randn(M, N, dtype=torch.float32) * 2.0

    print(f" Input shape: {fp32_weight.shape}")
    print(f" Input range: [{fp32_weight.min():.3f}, {fp32_weight.max():.3f}]")

    scale, scale_2 = compute_nvfp4_scales(fp32_weight, block_size=16)
    print(f" Scale shape: {scale.shape}, scale_2: {scale_2.item():.6e}")

    # Two 4-bit codes per byte -> half as many columns. N is a multiple of the
    # block size, so no padding is involved.
    packed = quantize_to_nvfp4_packed(fp32_weight, scale, scale_2, block_size=16)
    print(f" Packed shape: {packed.shape} (expected [{M}, {N//2}])")
    assert packed.shape == (M, N // 2), (
        f"Packed shape mismatch: {tuple(packed.shape)} != {(M, N // 2)}"
    )

    dequantized = dequantize_nvfp4(packed, scale, scale_2, dtype=torch.float32)
    print(f" Dequantized shape: {dequantized.shape}")
    assert dequantized.shape == (M, N), (
        f"Dequantized shape mismatch: {tuple(dequantized.shape)} != {(M, N)}"
    )

    error = (fp32_weight - dequantized).abs()
    mean_error = error.mean().item()
    max_error = error.max().item()
    # The 1e-8 guards the division for (near-)zero weights.
    relative_error = (error / (fp32_weight.abs() + 1e-8)).mean().item()

    print(f" Mean absolute error: {mean_error:.6f}")
    print(f" Max absolute error: {max_error:.6f}")
    print(f" Mean relative error: {relative_error:.6f}")

    # Loose bounds: 4-bit quantization is coarse, so these are meant to catch
    # gross encode/decode mistakes (e.g. swapped nibbles or misapplied scales)
    # rather than to measure precision.
    assert mean_error < 1.0, f"Mean error too high: {mean_error}"
    assert relative_error < 0.5, f"Relative error too high: {relative_error}"

    print(" PASS: Roundtrip error acceptable for 4-bit quantization")
    print(" PASS: Test 3 PASSED\n")
| |
|
| |
|
def test_gemm_shapes():
    """Test 4: NVFP4 GEMM with various shapes.

    For each (M, N, K): quantize a random [N, K] weight with the helpers in
    this file, multiply a random bf16 [M, K] activation by it via
    ``nvfp4_gemm_dequant``, and require an [M, N] output with no NaN or Inf.
    """
    banner = "=" * 70
    print(banner)
    print("Test 4: NVFP4 GEMM Shape Tests")
    print(banner)

    shapes = ((32, 64, 128), (128, 256, 512), (64, 512, 256))

    for M, N, K in shapes:
        print(f"\n Testing GEMM: [{M}, {K}] @ [{N}, {K}].T = [{M}, {N}]")

        # Activations are drawn before the weight to keep the RNG order fixed.
        activations = torch.randn(M, K, dtype=torch.bfloat16)

        weight_fp32 = torch.randn(N, K, dtype=torch.float32) * 2.0
        scale, scale_2 = compute_nvfp4_scales(weight_fp32, block_size=16)
        packed_weight = quantize_to_nvfp4_packed(weight_fp32, scale, scale_2, block_size=16)

        print(f" Input: {activations.shape}, Weight: {packed_weight.shape}")
        print(f" Scales: {scale.shape}, {scale_2.shape}")

        output = nvfp4_gemm_dequant(activations, packed_weight, scale, scale_2)
        print(f" Output: {output.shape}")

        assert output.shape == (M, N), f"Output shape mismatch: {output.shape} != ({M}, {N})"
        assert not output.isnan().any(), "Output contains NaN"
        assert not output.isinf().any(), "Output contains Inf"

        print(" PASS: Shape correct, no NaN/Inf")

    print("\n PASS: All GEMM shape tests passed")
    print(" PASS: Test 4 PASSED\n")
| |
|
| |
|
def test_gemm_correctness():
    """Test 5: Verify NVFP4 GEMM output is close to reference.

    Compares ``nvfp4_gemm_dequant`` on an NVFP4-quantized weight against a
    plain bf16 ``F.linear`` using the unquantized weight.
    """
    print("=" * 70)
    print("Test 5: NVFP4 GEMM Correctness")
    print("=" * 70)

    M, N, K = 64, 128, 256

    # Seed locally: the thresholds below are statistical, so the data must
    # not depend on whatever RNG state earlier tests left behind.
    torch.manual_seed(0)
    x = torch.randn(M, K, dtype=torch.bfloat16)
    weight_fp32 = torch.randn(N, K, dtype=torch.float32) * 1.5

    scale, scale_2 = compute_nvfp4_scales(weight_fp32, block_size=16)
    packed_weight = quantize_to_nvfp4_packed(weight_fp32, scale, scale_2, block_size=16)

    result_nvfp4 = nvfp4_gemm_dequant(x, packed_weight, scale, scale_2)

    # Reference: the unquantized weight in bf16, matching the activation dtype.
    result_reference = F.linear(x, weight_fp32.to(torch.bfloat16))

    print(f" NVFP4 GEMM output: {result_nvfp4.shape}, dtype={result_nvfp4.dtype}")
    print(f" Reference output: {result_reference.shape}, dtype={result_reference.dtype}")

    reference = result_reference.float()
    error = (result_nvfp4.float() - reference).abs()
    mean_error = error.mean().item()
    max_error = error.max().item()
    # Normalize by the mean output magnitude rather than element-wise: with
    # zero-mean random inputs the outputs straddle zero, and per-element
    # |err| / |ref| explodes for outputs near zero, letting a handful of
    # elements dominate the average and making the check flaky.
    reference_magnitude = max(reference.abs().mean().item(), 1e-8)
    relative_error = mean_error / reference_magnitude

    print(f" Mean absolute error: {mean_error:.6f}")
    print(f" Max absolute error: {max_error:.6f}")
    print(f" Relative mean error: {relative_error:.6f}")

    # Loose bounds: only gross kernel errors should trip these. An all-zero
    # or uncorrelated output gives a relative mean error of about 1 or more.
    assert mean_error < 5.0, f"Mean error too high: {mean_error}"
    assert relative_error < 1.0, f"Relative error too high: {relative_error}"

    print(" PASS: NVFP4 GEMM output reasonably close to reference")
    print(" PASS: Test 5 PASSED\n")
| |
|
| |
|
def main():
    """Run all NVFP4 kernel unit tests.

    Returns:
        0 when every test passes; 1 when any test raises. For an
        ``AssertionError`` it reports a test failure, and for any other
        ``Exception`` an unexpected error; the traceback is printed either way.
    """
    rule = "=" * 70
    print("\n" + rule)
    print("NVFP4 Kernel Unit Tests")
    print(rule)
    print("Testing NVFP4 quantization/dequantization and GEMM operations")
    print("Expected runtime: < 30 seconds")
    print(rule)

    try:
        for run_test in (
            test_dequant_lookup_table,
            test_dequant_simple,
            test_quantize_dequantize_roundtrip,
            test_gemm_shapes,
            test_gemm_correctness,
        ):
            run_test()

        print(rule)
        print("PASS: ALL TESTS PASSED")
        print(rule)
        print("NVFP4 kernel functions are working correctly!")
        print("Ready to proceed with full model testing.")
        print(rule)
        return 0
    except Exception as exc:
        import traceback

        if isinstance(exc, AssertionError):
            print(f"\nFAIL: TEST FAILED: {exc}")
        else:
            print(f"\nFAIL: UNEXPECTED ERROR: {exc}")
        traceback.print_exc()
        return 1
| |
|
| |
|
if __name__ == "__main__":
    # Use main()'s 0/1 result as the process exit status.
    exit_code = main()
    sys.exit(exit_code)
| |
|