#!/usr/bin/env python3
"""
Unit tests for NVFP4 kernel functions.

This tests dequantization and GEMM operations in isolation before
attempting full model inference.
"""

import sys
import traceback

import torch
import torch.nn.functional as F

# Import from local inference directory
from nvfp4_kernel import (
    dequantize_nvfp4,
    nvfp4_gemm_dequant,
    NVFP4_LUT,
    NVFP4_BLOCK_SIZE,
)

# Constants from quantization script
FP4_MAX = 6.0         # largest E2M1 magnitude
FP8_E4M3_MAX = 448.0  # largest finite FP8 E4M3 value
# Midpoints between consecutive E2M1 magnitudes (0, 0.5, 1, 1.5, 2, 3, 4, 6);
# searchsorted against these maps a magnitude to its nearest code 0-7.
E2M1_BOUNDS = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0], dtype=torch.float32)


def _pad_last_dim(tensor, multiple):
    """Return ``tensor`` zero-padded along its last dim to a multiple of ``multiple``."""
    remainder = tensor.shape[-1] % multiple
    if remainder == 0:
        return tensor
    filler = tensor.new_zeros((*tensor.shape[:-1], multiple - remainder))
    return torch.cat([tensor, filler], dim=-1)


def _to_blocks(weight, block_size):
    """Zero-pad the last dim to a multiple of ``block_size`` and split it into blocks.

    A 1-D ``(N,)`` input becomes ``(ceil(N / block_size), block_size)``; a 2-D
    ``(M, N)`` input becomes ``(M, ceil(N / block_size), block_size)``.
    """
    padded = _pad_last_dim(weight, block_size)
    return padded.reshape(*padded.shape[:-1], -1, block_size)


def compute_nvfp4_scales(fp32_weight, block_size=16):
    """
    Compute two-level NVFP4 scaling factors. Simplified version for testing.

    Args:
        fp32_weight: 1-D ``(N,)`` or 2-D ``(M, N)`` float tensor. The last
            dimension is zero-padded to a multiple of ``block_size``.
        block_size: Number of consecutive elements that share one block scale.

    Returns:
        ``(weight_scale, weight_scale_2)``: ``weight_scale`` holds one scale per
        block (FP8 E4M3 when that dtype is available, otherwise float32), shape
        ``(num_blocks,)`` or ``(M, num_blocks)``; ``weight_scale_2`` is a scalar
        global scale chosen so that the global amax maps to
        ``FP4_MAX * FP8_E4M3_MAX``.
    """
    # Global scale
    global_amax = fp32_weight.abs().max()
    weight_scale_2 = global_amax / (FP4_MAX * FP8_E4M3_MAX)
    if weight_scale_2.abs() < 1e-10:
        # All-zero (or negligible) weight: avoid dividing by ~0 below.
        weight_scale_2 = torch.tensor(1e-8, dtype=torch.float32, device=fp32_weight.device)

    # Per-block scale: each block's amax maps to FP4_MAX after both scales.
    per_block_amax = _to_blocks(fp32_weight, block_size).abs().amax(dim=-1)
    per_block_scale = (per_block_amax / (FP4_MAX * weight_scale_2)).clamp(min=1e-8)

    # Convert to FP8 E4M3. Note the 1e-8 floor is far below the smallest FP8
    # E4M3 subnormal (2**-9), so such blocks still end up with a zero scale.
    try:
        weight_scale = per_block_scale.to(torch.float8_e4m3fn)
    except (AttributeError, RuntimeError, TypeError):
        # AttributeError: this torch build has no float8_e4m3fn dtype at all.
        weight_scale = per_block_scale.to(torch.float32)

    return weight_scale, weight_scale_2


def quantize_to_nvfp4_packed(fp32_weight, weight_scale, weight_scale_2, block_size=16):
    """
    Quantize FP32 weight to NVFP4 packed uint8 format. Simplified version for testing.

    Each element becomes a 4-bit code ``(sign << 3) | magnitude`` where
    ``magnitude`` (0-7) indexes the E2M1 values 0, 0.5, 1, 1.5, 2, 3, 4, 6.
    Two codes are packed per byte: element ``2i`` in the low nibble and
    element ``2i + 1`` in the high nibble.

    Args:
        fp32_weight: 1-D ``(N,)`` or 2-D ``(M, N)`` float tensor.
        weight_scale: Per-block scales as returned by ``compute_nvfp4_scales``.
        weight_scale_2: Global scale as returned by ``compute_nvfp4_scales``.
        block_size: Block size that was used to compute ``weight_scale``.

    Returns:
        uint8 tensor whose last dimension is the block-padded length halved
        (e.g. ``(M, N_padded // 2)``).
    """
    weight_blocks = _to_blocks(fp32_weight, block_size)

    # Apply scaling. A block whose (FP8) scale is 0 would give 0/0 = NaN, so
    # such blocks are quantized to code 0 instead; their zero scale makes them
    # dequantize to 0 either way.
    combined_scale = (weight_scale.to(torch.float32) * weight_scale_2).unsqueeze(-1)
    scaled_weight = torch.where(
        combined_scale > 0,
        weight_blocks / combined_scale,
        torch.zeros_like(weight_blocks),
    ).flatten(-2)

    # Sign bit plus nearest E2M1 magnitude code [0-7]; magnitudes above 5.0
    # (including anything beyond 6) saturate to code 7.
    e2m1_bounds = E2M1_BOUNDS.to(scaled_weight.device)
    sign_bit = (scaled_weight < 0).to(torch.uint8)
    magnitude_code = torch.searchsorted(e2m1_bounds, scaled_weight.abs()).to(torch.uint8)
    code = (sign_bit << 3) | magnitude_code

    # Pack two 4-bit values per byte; pad to an even length first (only
    # needed when block_size is odd).
    code = _pad_last_dim(code, 2)
    return (code[..., 1::2] << 4) | code[..., 0::2]


def test_dequant_lookup_table():
    """Test 1: Verify NVFP4 lookup table values are correct."""
    print("\n" + "=" * 70)
    print("Test 1: NVFP4 Lookup Table")
    print("=" * 70)

    expected = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
                -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]

    assert len(NVFP4_LUT) == 16, f"LUT should have 16 entries, got {len(NVFP4_LUT)}"
    for i, (actual, expected_val) in enumerate(zip(NVFP4_LUT, expected)):
        assert abs(float(actual) - expected_val) < 1e-6, f"LUT[{i}] = {actual}, expected {expected_val}"

    print(f" PASS: Lookup table correct: {NVFP4_LUT.tolist()[:8]}")
    print(f" {NVFP4_LUT.tolist()[8:]}")
    print(" PASS: Test 1 PASSED\n")


def test_dequant_simple():
    """Test 2: Simple dequantization with known values."""
    print("=" * 70)
    print("Test 2: Simple Dequantization")
    print("=" * 70)

    # Create simple test case: packed values representing [0, 1.0, 2.0, 3.0, ...]
    # Codes: 0=0.0, 2=1.0, 4=2.0, 5=3.0, 6=4.0, 7=6.0
    # Pack: (high << 4) | low
    packed = torch.tensor([
        [0x20, 0x54, 0x76, 0x00, 0x00, 0x00, 0x00, 0x00],  # [0,2,4,5,6,7,0,0] -> [0,1,2,3,4,6,0,0]
    ], dtype=torch.uint8)

    # Uniform scales for simplicity (create in float32, then cast to FP8)
    scale = torch.ones(1, 1).to(torch.float8_e4m3fn)
    scale_2 = torch.tensor([1.0], dtype=torch.float32)

    result = dequantize_nvfp4(packed, scale, scale_2, dtype=torch.float32)

    print(f" Packed: {packed[0].tolist()}")
    print(f" Scales: scale={scale.shape}, scale_2={scale_2.item()}")
    print(f" Result shape: {result.shape}")
    print(f" Result values: {result[0].tolist()}")

    # Expected: [0, 1, 2, 3, 4, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
    expected_values = [0, 1, 2, 3, 4, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
    for i, (val, expected) in enumerate(zip(result[0].tolist(), expected_values)):
        assert abs(val - expected) < 0.01, f"Position {i}: got {val}, expected {expected}"

    print(" PASS: Dequantization correct")
    print(" PASS: Test 2 PASSED\n")


def test_quantize_dequantize_roundtrip():
    """Test 3: Quantize then dequantize, check error is acceptable."""
    print("=" * 70)
    print("Test 3: Quantization-Dequantization Roundtrip")
    print("=" * 70)

    M, N = 64, 256
    torch.manual_seed(42)
    fp32_weight = torch.randn(M, N, dtype=torch.float32) * 2.0  # std 2, so roughly [-8, 8]

    print(f" Input shape: {fp32_weight.shape}")
    print(f" Input range: [{fp32_weight.min():.3f}, {fp32_weight.max():.3f}]")

    # Compute scales
    scale, scale_2 = compute_nvfp4_scales(fp32_weight, block_size=16)
    print(f" Scale shape: {scale.shape}, scale_2: {scale_2.item():.6e}")

    # Quantize
    packed = quantize_to_nvfp4_packed(fp32_weight, scale, scale_2, block_size=16)
    print(f" Packed shape: {packed.shape} (expected [{M}, {N // 2}])")
    assert packed.shape == (M, N // 2), f"Packed shape mismatch: {tuple(packed.shape)} != ({M}, {N // 2})"

    # Dequantize
    dequantized = dequantize_nvfp4(packed, scale, scale_2, dtype=torch.float32)
    print(f" Dequantized shape: {dequantized.shape}")
    assert dequantized.shape == (M, N), f"Dequantized shape mismatch: {tuple(dequantized.shape)} != ({M}, {N})"

    # Compute error. The element-wise relative error is bounded (~1 at most)
    # here because rounding is to the nearest grid point of the same block.
    error = (fp32_weight - dequantized).abs()
    mean_error = error.mean().item()
    max_error = error.max().item()
    relative_error = (error / (fp32_weight.abs() + 1e-8)).mean().item()

    print(f" Mean absolute error: {mean_error:.6f}")
    print(f" Max absolute error: {max_error:.6f}")
    print(f" Mean relative error: {relative_error:.6f}")

    # For 4-bit quantization, we expect some error but should be reasonable
    assert mean_error < 1.0, f"Mean error too high: {mean_error}"
    assert relative_error < 0.5, f"Relative error too high: {relative_error}"

    print(" PASS: Roundtrip error acceptable for 4-bit quantization")
    print(" PASS: Test 3 PASSED\n")


def test_gemm_shapes():
    """Test 4: NVFP4 GEMM with various shapes."""
    print("=" * 70)
    print("Test 4: NVFP4 GEMM Shape Tests")
    print("=" * 70)

    torch.manual_seed(42)  # reproducible inputs
    test_cases = [
        (32, 64, 128),    # Small
        (128, 256, 512),  # Medium
        (64, 512, 256),   # Asymmetric
    ]

    for M, N, K in test_cases:
        print(f"\n Testing GEMM: [{M}, {K}] @ [{N}, {K}].T = [{M}, {N}]")

        # Create input activation
        x = torch.randn(M, K, dtype=torch.bfloat16)

        # Create quantized weight
        weight_fp32 = torch.randn(N, K, dtype=torch.float32) * 2.0
        scale, scale_2 = compute_nvfp4_scales(weight_fp32, block_size=16)
        packed_weight = quantize_to_nvfp4_packed(weight_fp32, scale, scale_2, block_size=16)

        print(f" Input: {x.shape}, Weight: {packed_weight.shape}")
        print(f" Scales: {scale.shape}, {scale_2.shape}")

        # Run NVFP4 GEMM
        result = nvfp4_gemm_dequant(x, packed_weight, scale, scale_2)

        print(f" Output: {result.shape}")
        assert result.shape == (M, N), f"Output shape mismatch: {result.shape} != ({M}, {N})"

        # Verify no NaN/Inf
        assert not torch.isnan(result).any(), "Output contains NaN"
        assert not torch.isinf(result).any(), "Output contains Inf"

        print(" PASS: Shape correct, no NaN/Inf")

    print("\n PASS: All GEMM shape tests passed")
    print(" PASS: Test 4 PASSED\n")


def test_gemm_correctness():
    """Test 5: Verify NVFP4 GEMM output is close to reference."""
    print("=" * 70)
    print("Test 5: NVFP4 GEMM Correctness")
    print("=" * 70)

    M, N, K = 64, 128, 256
    torch.manual_seed(42)  # reproducible inputs

    # Create test tensors
    x = torch.randn(M, K, dtype=torch.bfloat16)
    weight_fp32 = torch.randn(N, K, dtype=torch.float32) * 1.5

    # Quantize weight
    scale, scale_2 = compute_nvfp4_scales(weight_fp32, block_size=16)
    packed_weight = quantize_to_nvfp4_packed(weight_fp32, scale, scale_2, block_size=16)

    # Run NVFP4 GEMM
    result_nvfp4 = nvfp4_gemm_dequant(x, packed_weight, scale, scale_2)

    # Run reference GEMM in FP32 with the unquantized weight
    result_reference = F.linear(x.float(), weight_fp32)

    print(f" NVFP4 GEMM output: {result_nvfp4.shape}, dtype={result_nvfp4.dtype}")
    print(f" Reference output: {result_reference.shape}, dtype={result_reference.dtype}")

    # Compute error
    error = (result_nvfp4.float() - result_reference).abs()
    mean_error = error.mean().item()
    max_error = error.max().item()
    # Normalized (Frobenius) error. An element-wise mean of error / |ref| is
    # dominated by outputs that happen to be near zero (a ratio of roughly
    # Gaussian quantities is heavy-tailed), which makes it an unstable metric.
    relative_error = (error.norm() / result_reference.norm().clamp_min(1e-8)).item()

    print(f" Mean absolute error: {mean_error:.6f}")
    print(f" Max absolute error: {max_error:.6f}")
    print(f" Relative error (||err|| / ||ref||): {relative_error:.6f}")

    # Due to 4-bit quantization, expect significant error but not catastrophic.
    # 4-bit block-scaled weights typically land around 0.1 here; layout or
    # scale bugs push the normalized error to ~1 or more.
    assert mean_error < 5.0, f"Mean error too high: {mean_error}"
    assert relative_error < 0.3, f"Relative error too high: {relative_error}"

    print(" PASS: NVFP4 GEMM output reasonably close to reference")
    print(" PASS: Test 5 PASSED\n")


def main():
    """Run all NVFP4 kernel unit tests; return a process exit code (0 = all passed)."""
    print("\n" + "=" * 70)
    print("NVFP4 Kernel Unit Tests")
    print("=" * 70)
    print("Testing NVFP4 quantization/dequantization and GEMM operations")
    print("Expected runtime: < 30 seconds")
    print("=" * 70)

    tests = (
        test_dequant_lookup_table,
        test_dequant_simple,
        test_quantize_dequantize_roundtrip,
        test_gemm_shapes,
        test_gemm_correctness,
    )

    try:
        # Run all tests
        for test in tests:
            test()

        # Summary
        print("=" * 70)
        print("PASS: ALL TESTS PASSED")
        print("=" * 70)
        print("NVFP4 kernel functions are working correctly!")
        print("Ready to proceed with full model testing.")
        print("=" * 70)
        return 0

    except AssertionError as e:
        print(f"\nFAIL: TEST FAILED: {e}")
        traceback.print_exc()
        return 1
    except Exception as e:
        # Top-level boundary: report anything unexpected and fail the run.
        print(f"\nFAIL: UNEXPECTED ERROR: {e}")
        traceback.print_exc()
        return 1


if __name__ == "__main__":
    sys.exit(main())