Buckets:
"""Unit test for the NVFP4 / FP8 fake-quant primitives (run on CUDA — fp8 casts need it).

Run directly (``python <this file>``) to execute every check in order, printing one
diagnostic line per check. The ``test_*`` functions take no arguments, so a test
runner can also collect them. Failures raise ``AssertionError`` explicitly (not via
bare ``assert``), so the checks still run under ``python -O``.
"""
import math

import torch
from flux2distill.svdquant import (_round_e2m1, fake_quant_nvfp4, fake_quant_fp8,
                                   fake_quant_weight, fake_quant_act, SVDQuantLinear,
                                   _E2M1_LEVELS)

DEV = "cuda"


def _check(cond, msg):
    """Raise ``AssertionError(msg)`` when *cond* is falsy (unlike ``assert``, survives -O)."""
    if not cond:
        raise AssertionError(msg)


def rel(q, ref):
    """Return the relative reconstruction error ``||q - ref|| / ||ref||`` as a Python float."""
    return ((q - ref).norm() / ref.norm()).item()


def _check_rel_err(label, err):
    """Sanity-bound a reconstruction error.

    Returning all zeros gives an error of exactly 1.0. A finite error below that is
    the minimum any working quantizer must achieve, so anything else points to a
    broken quantizer rather than a precision trade-off.
    """
    _check(math.isfinite(err) and 0.0 <= err < 1.0, f"{label}: implausible rel-err {err!r}")


def test_e2m1_rounding():
    """1. E2M1 rounding snaps values onto the E2M1 level grid."""
    a = torch.tensor([0.0, 0.1, 0.3, 0.6, 0.9, 1.2, 1.7, 2.4, 3.2, 5.5, 7.0], device=DEV)
    r = _round_e2m1(a.clamp(max=6.0))  # 6.0 is the largest E2M1 magnitude
    print("1) e2m1 round:", [round(v, 2) for v in r.tolist()])
    # Normalise to Python floats so the subset test works whether _E2M1_LEVELS is a
    # list/tuple of floats or a 1-D tensor (a set of 0-d tensors hashes by identity).
    levels = {float(v) for v in _E2M1_LEVELS}
    _check(set(r.tolist()) <= levels, "values not on E2M1 grid!")
    # Round-to-nearest spot checks.
    for value, expected in ((0.3, 0.5), (2.4, 2.0), (3.2, 3.0)):
        got = _round_e2m1(torch.tensor([value], device=DEV)).item()
        _check(got == expected, f"_round_e2m1({value}) = {got}, expected {expected}")


def test_weight_recon():
    """2./3. Weight recon: int4-g64 vs nvfp4-g16, plain gaussian and with a 60x outlier column."""
    torch.manual_seed(0)
    W = torch.randn(512, 512, device=DEV) * 0.05
    i4 = rel(fake_quant_weight(W, 4, 64), W)
    nv = rel(fake_quant_nvfp4(W, 16), W)
    print(f"2) gaussian weight rel-err: int4-g64={i4:.4f} nvfp4-g16={nv:.4f}")
    # 3. Outlier channel, where finer grouping should help nvfp4. The comparison
    # is reported only; just the sanity bounds below are enforced.
    W2 = W.clone()
    W2[:, 0] *= 60.0
    i4o = rel(fake_quant_weight(W2, 4, 64), W2)
    nvo = rel(fake_quant_nvfp4(W2, 16), W2)
    print(f"3) +60x outlier col rel-err: int4-g64={i4o:.4f} nvfp4-g16={nvo:.4f}")
    for label, err in (("int4-g64", i4), ("nvfp4-g16", nv),
                       ("int4-g64 outlier", i4o), ("nvfp4-g16 outlier", nvo)):
        _check_rel_err(label, err)


def test_fp8_act():
    """4. Activation recon: FP8 vs int8."""
    xa = torch.randn(8, 512, device=DEV)
    # The trailing 0 is presumably the per-token grouping mode (per the printed
    # label) — confirm against fake_quant_fp8 / fake_quant_act.
    fp8 = rel(fake_quant_fp8(xa, 0), xa)
    i8 = rel(fake_quant_act(xa, 8, 0), xa)
    print(f"4) fp8 act rel-err (per-token)={fp8:.4f} int8={i8:.4f}")
    _check_rel_err("fp8", fp8)
    _check_rel_err("int8", i8)


def test_nvfp4_module():
    """5. SVDQuantLinear forward in nvfp4 W4A4 mode yields finite output of the right shape."""
    # Positional args presumably (in_features, out_features, bias, rank) — confirm.
    m = SVDQuantLinear(512, 256, True, 16, w_bits=4, a_bits=4, w_group=16, a_group=16,
                       dtype=torch.bfloat16, w_fmt="nvfp4", a_fmt="nvfp4").to(DEV)
    # In-place writes into module tensors go through no_grad: if these are
    # nn.Parameters with requires_grad=True, in-place ops on leaf tensors
    # would otherwise raise a RuntimeError.
    with torch.no_grad():
        m.lora_down.normal_(0, 0.02)
        m.lora_up.normal_(0, 0.02)
        w_res = fake_quant_nvfp4(torch.randn(256, 512, device=DEV) * 0.05, 16)
        m.w_res.copy_(w_res.to(torch.bfloat16))
    x = torch.randn(2, 10, 512, device=DEV, dtype=torch.bfloat16)
    y = m(x)
    finite = torch.isfinite(y).all().item()
    print(f"5) nvfp4 module fwd: shape={tuple(y.shape)} finite={finite}")
    _check(tuple(y.shape) == (2, 10, 256) and finite,
           f"bad nvfp4 forward: shape={tuple(y.shape)} finite={finite}")


def test_int_default_fmt():
    """6. Back-compat: weight/activation formats default to "int" when not specified.

    Only the default attribute values are checked; int-path outputs are not
    compared against a reference here.
    """
    mi = SVDQuantLinear(512, 256, True, 16, dtype=torch.bfloat16)
    _check(mi.w_fmt == "int" and mi.a_fmt == "int",
           f"unexpected default fmts: w_fmt={mi.w_fmt!r} a_fmt={mi.a_fmt!r}")
    print("6) int default back-compat OK")


def main():
    """Run every check in order.

    Checks 4 and 5 draw from the RNG stream seeded in check 2, so the order keeps
    the printed numbers reproducible.
    """
    if not torch.cuda.is_available():
        raise SystemExit("CUDA is required: the fp8 fake-quant casts need a GPU.")
    test_e2m1_rounding()
    test_weight_recon()
    test_fp8_act()
    test_nvfp4_module()
    test_int_default_fmt()
    print("ALL NVFP4 TESTS PASSED")


if __name__ == "__main__":
    main()
Xet Storage Details
- Size: 2.56 kB
- Xet hash: fb07102361cad661a3cb2ef6f7a649806e5b0b8326775125c7ede07d8f253e32

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.