Buckets:

Mercity/FluxDistill / scripts /test_nvfp4.py
Pranav2748's picture
download
raw
2.56 kB
"""Unit test for the NVFP4 / FP8 fake-quant primitives (run on CUDA — fp8 casts need it)."""
import torch
from flux2distill.svdquant import (_round_e2m1, fake_quant_nvfp4, fake_quant_fp8,
fake_quant_weight, fake_quant_act, SVDQuantLinear,
_E2M1_LEVELS)
# Every tensor in this script is created on the GPU (see the module docstring).
dev = "cuda"
ok = True

# 1. E2M1 rounding: each rounded value must be a member of _E2M1_LEVELS.
probe = torch.tensor(
    [0.0, 0.1, 0.3, 0.6, 0.9, 1.2, 1.7, 2.4, 3.2, 5.5, 7.0], device=dev
)
# Clamp first so the 7.0 probe is capped at 6.0 before it is rounded.
snapped_values = _round_e2m1(probe.clamp(max=6.0)).tolist()
print("1) e2m1 round:", [round(v, 2) for v in snapped_values])
assert set(snapped_values) <= set(_E2M1_LEVELS), "values not on E2M1 grid!"

# Spot checks as (input, expected rounded value) pairs.
for raw, expected in ((0.3, 0.5), (2.4, 2.0), (3.2, 3.0)):
    assert _round_e2m1(torch.tensor([raw], device=dev)).item() == expected
# 2. Weight reconstruction on a Gaussian matrix: int4 (group 64) vs NVFP4 (group 16).
torch.manual_seed(0)  # fixed seed so the printed errors are reproducible
W = torch.randn(512, 512, device=dev) * 0.05


def rel(q, W):
    """Return the relative error ||q - W|| / ||W|| as a Python float."""
    diff_norm = (q - W).norm()
    return (diff_norm / W.norm()).item()


int4_err = rel(fake_quant_weight(W, 4, 64), W)
nvfp4_err = rel(fake_quant_nvfp4(W, 16), W)
print(f"2) gaussian weight rel-err: int4-g64={int4_err:.4f} nvfp4-g16={nvfp4_err:.4f}")
# 3. Same comparison after scaling column 0 by 60x to create an outlier channel
#    (the case where the finer NVFP4 grouping is expected to help).
W_outlier = W.clone()  # clone so the matrix from test 2 is left untouched
W_outlier[:, 0].mul_(60.0)
int4_outlier_err = rel(fake_quant_weight(W_outlier, 4, 64), W_outlier)
nvfp4_outlier_err = rel(fake_quant_nvfp4(W_outlier, 16), W_outlier)
print(
    f"3) +60x outlier col rel-err: int4-g64={int4_outlier_err:.4f} "
    f"nvfp4-g16={nvfp4_outlier_err:.4f}"
)
# 4. Activation reconstruction: FP8 vs int8 fake-quant on the same random batch.
#    NOTE(review): the trailing 0 argument presumably selects per-token scaling
#    (that is what the printed label says) -- confirm against svdquant.
acts = torch.randn(8, 512, device=dev)
fp8_err = rel(fake_quant_fp8(acts, 0), acts)
int8_err = rel(fake_quant_act(acts, 8, 0), acts)
print(f"4) fp8 act rel-err (per-token)={fp8_err:.4f} int8={int8_err:.4f}")
# 5. Module forward (NVFP4 W4A4): the output must be finite and have the expected shape.
m = SVDQuantLinear(512, 256, True, 16, w_bits=4, a_bits=4, w_group=16, a_group=16,
                   dtype=torch.bfloat16, w_fmt="nvfp4", a_fmt="nvfp4").to(dev)
# Fill the module's tensors in place under no_grad. If any of them is an
# nn.Parameter with requires_grad=True, an in-place write outside no_grad raises
# "a leaf Variable that requires grad is being used in an in-place operation".
# NOTE(review): whether lora_down / lora_up / w_res are parameters or buffers is
# not visible here; no_grad is harmless in either case.
with torch.no_grad():
    m.lora_down.normal_(0, 0.02)
    m.lora_up.normal_(0, 0.02)
    # Store an already NVFP4-fake-quantized random weight, cast to bfloat16.
    m.w_res.copy_(
        fake_quant_nvfp4(torch.randn(256, 512, device=dev) * 0.05, 16).to(torch.bfloat16)
    )
x = torch.randn(2, 10, 512, device=dev, dtype=torch.bfloat16)
y = m(x)
print(f"5) nvfp4 module fwd: shape={tuple(y.shape)} finite={torch.isfinite(y).all().item()}")
# Separate asserts so a failure says which property broke.
assert y.shape == (2, 10, 256), f"unexpected output shape {tuple(y.shape)}"
assert torch.isfinite(y).all(), "module output contains non-finite values"
# 6. Back-compat: a module built without w_fmt / a_fmt must default to the "int"
#    formats. Only the format attributes are checked here, not numerical output.
mi = SVDQuantLinear(512, 256, True, 16, dtype=torch.bfloat16)
assert mi.w_fmt == "int"
assert mi.a_fmt == "int"
print("6) int default back-compat OK")

print("ALL NVFP4 TESTS PASSED")

Xet Storage Details

Size:
2.56 kB
·
Xet hash:
fb07102361cad661a3cb2ef6f7a649806e5b0b8326775125c7ede07d8f253e32

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.