eousphoros committed
Commit e82122f · verified · 1 Parent(s): a5ceb64

Upload inference/test_nvfp4_kernel.py with huggingface_hub

Files changed (1)
  1. inference/test_nvfp4_kernel.py +370 -0
inference/test_nvfp4_kernel.py ADDED
@@ -0,0 +1,370 @@
#!/usr/bin/env python3
"""
Unit tests for NVFP4 kernel functions.

These tests exercise dequantization and GEMM operations in isolation
before attempting full model inference.
"""

import sys
import torch
import torch.nn.functional as F

# Import from the local inference directory
from nvfp4_kernel import (
    dequantize_nvfp4,
    nvfp4_gemm_dequant,
    NVFP4_LUT,
    NVFP4_BLOCK_SIZE,
)

# Constants mirrored from the quantization script
FP4_MAX = 6.0         # largest magnitude representable in FP4 (E2M1)
FP8_E4M3_MAX = 448.0  # largest magnitude representable in FP8 (E4M3)
E2M1_BOUNDS = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0], dtype=torch.float32)
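

# --- Illustrative sketch (editor addition, not in the original commit): E2M1_BOUNDS
# --- lists the midpoints between the eight non-negative E2M1 magnitudes
# --- {0, 0.5, 1, 1.5, 2, 3, 4, 6}, so torch.searchsorted maps any magnitude to its
# --- nearest 3-bit code. A minimal demonstration using only the constant above:
def _demo_e2m1_rounding():
    values = torch.tensor([0.1, 0.6, 1.4, 2.6, 7.0])
    codes = torch.searchsorted(E2M1_BOUNDS, values)
    # 0.1 -> 0.0, 0.6 -> 0.5, 1.4 -> 1.5, 2.6 -> 3.0, 7.0 -> 6.0 (saturates)
    assert codes.tolist() == [0, 1, 3, 5, 7]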


def compute_nvfp4_scales(fp32_weight, block_size=16):
    """
    Compute two-level NVFP4 scaling factors.

    Simplified version for testing.
    """
    # Global scale: map the tensor-wide amax onto the FP4 * FP8 dynamic range
    global_amax = fp32_weight.abs().max()
    weight_scale_2 = global_amax / (FP4_MAX * FP8_E4M3_MAX)

    if weight_scale_2.abs() < 1e-10:
        weight_scale_2 = torch.tensor(1e-8, dtype=torch.float32, device=fp32_weight.device)

    # Per-block scale
    M = fp32_weight.shape[0] if fp32_weight.dim() > 1 else 1
    N = fp32_weight.shape[-1]

    # Pad the last dimension to a multiple of block_size if needed
    N_padded = ((N + block_size - 1) // block_size) * block_size
    if N_padded != N:
        if fp32_weight.dim() == 1:
            padded = torch.zeros(N_padded, dtype=fp32_weight.dtype, device=fp32_weight.device)
            padded[:N] = fp32_weight
            fp32_weight = padded
        else:
            padded = torch.zeros(M, N_padded, dtype=fp32_weight.dtype, device=fp32_weight.device)
            padded[:, :N] = fp32_weight
            fp32_weight = padded

    # Reshape into blocks of block_size elements
    if fp32_weight.dim() == 1:
        weight_blocks = fp32_weight.view(-1, block_size)
    else:
        weight_blocks = fp32_weight.view(M, -1, block_size)

    # Per-block amax -> per-block scale, clamped away from zero
    per_block_amax = weight_blocks.abs().amax(dim=-1)
    per_block_scale = per_block_amax / (FP4_MAX * weight_scale_2)
    per_block_scale = per_block_scale.clamp(min=1e-8)

    # Convert to FP8 E4M3; fall back to FP32 if the dtype is unsupported
    try:
        weight_scale = per_block_scale.to(torch.float8_e4m3fn)
    except (RuntimeError, TypeError):
        weight_scale = per_block_scale.to(torch.float32)

    return weight_scale, weight_scale_2
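

# --- Illustrative sketch (editor addition, not in the original commit): a worked
# --- example of the two-level scaling above. For a single block whose amax is 4.0,
# --- the global scale is 4 / (6 * 448) and the per-block scale is then
# --- 4 / (6 * scale_2) = 448.0, exactly FP8_E4M3_MAX, so the block's largest value
# --- lands on the FP4 endpoint:
def _demo_two_level_scales():
    w = torch.zeros(1, 16)
    w[0, 0] = 4.0
    scale, scale_2 = compute_nvfp4_scales(w, block_size=16)
    assert abs(scale_2.item() - 4.0 / (6.0 * 448.0)) < 1e-9
    assert abs(scale.to(torch.float32).item() - 448.0) < 1e-3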


def quantize_to_nvfp4_packed(fp32_weight, weight_scale, weight_scale_2, block_size=16):
    """
    Quantize an FP32 weight to NVFP4 packed uint8 format.

    Simplified version for testing.
    """
    device = fp32_weight.device
    M = fp32_weight.shape[0] if fp32_weight.dim() > 1 else 1
    N = fp32_weight.shape[-1]

    # Pad the last dimension to a multiple of block_size if needed
    N_padded = ((N + block_size - 1) // block_size) * block_size
    if N_padded != N:
        if fp32_weight.dim() == 1:
            padded = torch.zeros(N_padded, dtype=fp32_weight.dtype, device=device)
            padded[:N] = fp32_weight
            fp32_weight = padded
        else:
            padded = torch.zeros(M, N_padded, dtype=fp32_weight.dtype, device=device)
            padded[:, :N] = fp32_weight
            fp32_weight = padded

    # Reshape into blocks
    if fp32_weight.dim() == 1:
        weight_blocks = fp32_weight.view(-1, block_size)
    else:
        weight_blocks = fp32_weight.view(M, -1, block_size)

    # Apply the combined two-level scaling
    combined_scale = weight_scale.to(torch.float32) * weight_scale_2
    scaled_weight = weight_blocks / combined_scale.unsqueeze(-1)

    # Flatten back to the original layout
    if fp32_weight.dim() == 1:
        scaled_weight = scaled_weight.view(-1)
    else:
        scaled_weight = scaled_weight.view(M, -1)

    # E2M1 decision boundaries (midpoints between representable magnitudes)
    e2m1_bounds = E2M1_BOUNDS.to(device)

    # Extract sign bits and absolute values
    sign_bit = (scaled_weight < 0).to(torch.uint8)
    weight_abs = scaled_weight.abs()

    # Quantize magnitudes to E2M1 codes in [0, 7]
    magnitude_code = torch.searchsorted(e2m1_bounds, weight_abs)

    # Combine sign bit and magnitude into a 4-bit code
    code = (sign_bit << 3) | magnitude_code.to(torch.uint8)

    # Pack two 4-bit codes per byte (even index in the low nibble)
    N_current = code.shape[-1]
    if N_current % 2 != 0:
        # Pad to an even length
        if code.dim() == 1:
            padded = torch.zeros(N_current + 1, dtype=torch.uint8, device=device)
            padded[:N_current] = code
            code = padded
        else:
            padded = torch.zeros(M, N_current + 1, dtype=torch.uint8, device=device)
            padded[:, :N_current] = code
            code = padded

    if code.dim() == 1:
        packed = (code[1::2] << 4) | code[0::2]
    else:
        packed = (code[:, 1::2] << 4) | code[:, 0::2]

    return packed
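

# --- Illustrative sketch (editor addition, not in the original commit): the packing
# --- convention used above, spelled out by hand. Each 4-bit code is
# --- (sign << 3) | magnitude, and two codes share a byte with the even index in the
# --- low nibble:
def _demo_nibble_packing():
    lo, hi = 0b0010, 0b1010       # codes for +1.0 and -1.0
    byte = (hi << 4) | lo         # 0xA2, matching the layout produced above
    assert byte == 0xA2
    assert (byte & 0x0F, byte >> 4) == (lo, hi)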


def test_dequant_lookup_table():
    """Test 1: Verify the NVFP4 lookup table values are correct."""
    print("\n" + "=" * 70)
    print("Test 1: NVFP4 Lookup Table")
    print("=" * 70)

    expected = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
                -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]

    assert len(NVFP4_LUT) == 16, f"LUT should have 16 entries, got {len(NVFP4_LUT)}"

    for i, (actual, expected_val) in enumerate(zip(NVFP4_LUT, expected)):
        assert abs(actual - expected_val) < 1e-6, f"LUT[{i}] = {actual}, expected {expected_val}"

    print(f"  PASS: Lookup table correct: {NVFP4_LUT.tolist()[:8]}")
    print(f"                              {NVFP4_LUT.tolist()[8:]}")
    print("  PASS: Test 1 PASSED\n")


def test_dequant_simple():
    """Test 2: Simple dequantization with known values."""
    print("=" * 70)
    print("Test 2: Simple Dequantization")
    print("=" * 70)

    # Packed bytes for the codes [0, 2, 4, 5, 6, 7, 0, 0], i.e. the values
    # [0, 1, 2, 3, 4, 6, 0, 0] (codes: 0=0.0, 2=1.0, 4=2.0, 5=3.0, 6=4.0, 7=6.0).
    # Packing layout: (high << 4) | low, even index in the low nibble.
    packed = torch.tensor([
        [0x20, 0x54, 0x76, 0x00, 0x00, 0x00, 0x00, 0x00],
    ], dtype=torch.uint8)

    # Uniform scales for simplicity
    scale = torch.ones(1, 1, dtype=torch.float8_e4m3fn)
    scale_2 = torch.tensor([1.0], dtype=torch.float32)

    result = dequantize_nvfp4(packed, scale, scale_2, dtype=torch.float32)

    print(f"  Packed: {packed[0].tolist()}")
    print(f"  Scales: scale={scale.shape}, scale_2={scale_2.item()}")
    print(f"  Result shape: {result.shape}")
    print(f"  Result values: {result[0].tolist()}")

    # One 16-element block: [0, 1, 2, 3, 4, 6] followed by zeros
    expected_values = [0, 1, 2, 3, 4, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
    for i, (val, expected) in enumerate(zip(result[0].tolist(), expected_values)):
        assert abs(val - expected) < 0.01, f"Position {i}: got {val}, expected {expected}"

    print("  PASS: Dequantization correct")
    print("  PASS: Test 2 PASSED\n")


def test_quantize_dequantize_roundtrip():
    """Test 3: Quantize then dequantize, and check the error is acceptable."""
    print("=" * 70)
    print("Test 3: Quantization-Dequantization Roundtrip")
    print("=" * 70)

    # Test tensor with values mostly inside the representable range
    M, N = 64, 256
    torch.manual_seed(42)
    fp32_weight = torch.randn(M, N, dtype=torch.float32) * 2.0  # roughly [-6, 6]

    print(f"  Input shape: {fp32_weight.shape}")
    print(f"  Input range: [{fp32_weight.min():.3f}, {fp32_weight.max():.3f}]")

    # Compute scales
    scale, scale_2 = compute_nvfp4_scales(fp32_weight, block_size=16)
    print(f"  Scale shape: {scale.shape}, scale_2: {scale_2.item():.6e}")

    # Quantize
    packed = quantize_to_nvfp4_packed(fp32_weight, scale, scale_2, block_size=16)
    print(f"  Packed shape: {packed.shape} (expected [{M}, {N // 2}])")
    assert packed.shape == (M, N // 2), "Packed shape mismatch"

    # Dequantize
    dequantized = dequantize_nvfp4(packed, scale, scale_2, dtype=torch.float32)
    print(f"  Dequantized shape: {dequantized.shape}")
    assert dequantized.shape == (M, N), "Dequantized shape mismatch"

    # Compute error statistics
    error = (fp32_weight - dequantized).abs()
    mean_error = error.mean().item()
    max_error = error.max().item()
    relative_error = (error / (fp32_weight.abs() + 1e-8)).mean().item()

    print(f"  Mean absolute error: {mean_error:.6f}")
    print(f"  Max absolute error: {max_error:.6f}")
    print(f"  Mean relative error: {relative_error:.6f}")

    # 4-bit quantization loses precision, but the error should stay bounded
    assert mean_error < 1.0, f"Mean error too high: {mean_error}"
    assert relative_error < 0.5, f"Relative error too high: {relative_error}"

    print("  PASS: Roundtrip error acceptable for 4-bit quantization")
    print("  PASS: Test 3 PASSED\n")
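

# --- Illustrative sketch (editor addition, not in the original commit): why a
# --- mean-error bound of about 1.0 is plausible above. The widest gap between
# --- adjacent E2M1 magnitudes is 6 - 4 = 2, so the worst-case rounding error within
# --- a block is roughly half that gap times the block's combined scale:
def _demo_worstcase_step():
    lut = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])
    max_gap = (lut[1:] - lut[:-1]).max().item()
    assert max_gap == 2.0  # worst-case rounding error ~ 1.0 * combined block scale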


def test_gemm_shapes():
    """Test 4: NVFP4 GEMM with various shapes."""
    print("=" * 70)
    print("Test 4: NVFP4 GEMM Shape Tests")
    print("=" * 70)

    test_cases = [
        (32, 64, 128),    # small
        (128, 256, 512),  # medium
        (64, 512, 256),   # asymmetric
    ]

    for M, N, K in test_cases:
        print(f"\n  Testing GEMM: [{M}, {K}] @ [{N}, {K}].T = [{M}, {N}]")

        # Input activation
        x = torch.randn(M, K, dtype=torch.bfloat16)

        # Quantized weight
        weight_fp32 = torch.randn(N, K, dtype=torch.float32) * 2.0
        scale, scale_2 = compute_nvfp4_scales(weight_fp32, block_size=16)
        packed_weight = quantize_to_nvfp4_packed(weight_fp32, scale, scale_2, block_size=16)

        print(f"    Input: {x.shape}, Weight: {packed_weight.shape}")
        print(f"    Scales: {scale.shape}, {scale_2.shape}")

        # Run NVFP4 GEMM
        result = nvfp4_gemm_dequant(x, packed_weight, scale, scale_2)

        print(f"    Output: {result.shape}")
        assert result.shape == (M, N), f"Output shape mismatch: {result.shape} != ({M}, {N})"

        # Verify the output contains no NaN/Inf
        assert not torch.isnan(result).any(), "Output contains NaN"
        assert not torch.isinf(result).any(), "Output contains Inf"

        print("    PASS: Shape correct, no NaN/Inf")

    print("\n  PASS: All GEMM shape tests passed")
    print("  PASS: Test 4 PASSED\n")


def test_gemm_correctness():
    """Test 5: Verify NVFP4 GEMM output is close to a reference GEMM."""
    print("=" * 70)
    print("Test 5: NVFP4 GEMM Correctness")
    print("=" * 70)

    M, N, K = 64, 128, 256

    # Test tensors
    x = torch.randn(M, K, dtype=torch.bfloat16)
    weight_fp32 = torch.randn(N, K, dtype=torch.float32) * 1.5

    # Quantize the weight
    scale, scale_2 = compute_nvfp4_scales(weight_fp32, block_size=16)
    packed_weight = quantize_to_nvfp4_packed(weight_fp32, scale, scale_2, block_size=16)

    # Run NVFP4 GEMM
    result_nvfp4 = nvfp4_gemm_dequant(x, packed_weight, scale, scale_2)

    # Reference GEMM against the unquantized weight (cast to bfloat16)
    result_reference = F.linear(x, weight_fp32.to(torch.bfloat16))

    print(f"  NVFP4 GEMM output: {result_nvfp4.shape}, dtype={result_nvfp4.dtype}")
    print(f"  Reference output:  {result_reference.shape}, dtype={result_reference.dtype}")

    # Compute error statistics
    error = (result_nvfp4.float() - result_reference.float()).abs()
    mean_error = error.mean().item()
    max_error = error.max().item()
    relative_error = (error / (result_reference.float().abs() + 1e-8)).mean().item()

    print(f"  Mean absolute error: {mean_error:.6f}")
    print(f"  Max absolute error: {max_error:.6f}")
    print(f"  Mean relative error: {relative_error:.6f}")

    # 4-bit quantization introduces significant but not catastrophic error
    assert mean_error < 5.0, f"Mean error too high: {mean_error}"
    assert relative_error < 1.0, f"Relative error too high: {relative_error}"

    print("  PASS: NVFP4 GEMM output reasonably close to reference")
    print("  PASS: Test 5 PASSED\n")
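

# --- Illustrative sketch (editor addition, not in the original commit): a plain
# --- dequantize-then-matmul reference, assuming nvfp4_gemm_dequant is numerically
# --- equivalent to dequantizing the packed weight and calling F.linear. Handy when
# --- debugging kernel mismatches:
def _reference_dequant_gemm(x, packed_weight, scale, scale_2):
    w = dequantize_nvfp4(packed_weight, scale, scale_2, dtype=x.dtype)
    return F.linear(x, w)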


def main():
    """Run all NVFP4 kernel unit tests."""
    print("\n" + "=" * 70)
    print("NVFP4 Kernel Unit Tests")
    print("=" * 70)
    print("Testing NVFP4 quantization/dequantization and GEMM operations")
    print("Expected runtime: < 30 seconds")
    print("=" * 70)

    try:
        # Run all tests
        test_dequant_lookup_table()
        test_dequant_simple()
        test_quantize_dequantize_roundtrip()
        test_gemm_shapes()
        test_gemm_correctness()

        # Summary
        print("=" * 70)
        print("PASS: ALL TESTS PASSED")
        print("=" * 70)
        print("NVFP4 kernel functions are working correctly!")
        print("Ready to proceed with full model testing.")
        print("=" * 70)

        return 0

    except AssertionError as e:
        print(f"\nFAIL: TEST FAILED: {e}")
        import traceback
        traceback.print_exc()
        return 1
    except Exception as e:
        print(f"\nFAIL: UNEXPECTED ERROR: {e}")
        import traceback
        traceback.print_exc()
        return 1


if __name__ == "__main__":
    sys.exit(main())
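
A minimal way to run these tests, assuming nvfp4_kernel.py sits next to this file in inference/:

    cd inference
    python test_nvfp4_kernel.py   # exits 0 on success, 1 on the first failure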