Buckets:
| """Single-layer validation of the NVFP4 weight converter: quantize+pack a real klein-4B Linear, | |
| load into the real SVDQW4A4Linear kernel, and check the kernel output matches our intended | |
| dequantized weight (convention correct) and is close to bf16 (quant error reasonable).""" | |
| import sys, torch | |
| from nunchaku.models.linear import SVDQW4A4Linear | |
| from flux2distill.nunchaku_export import quantize_pack_nvfp4 | |
# Fixed seed so the random fallback weight and the random activation are reproducible.
torch.manual_seed(0)
dev = "cuda"
dt = torch.bfloat16

# Default layer geometry; OUT/IN are overwritten below when a real weight is loaded.
OUT = 3072
IN = 3072
RANK = 128
M = 1536

# A realistic weight: use a real klein-4B out-proj if it can be loaded, otherwise fall
# back to a small-magnitude random matrix (the failure reason is printed, not raised).
try:
    from flux2distill.model_utils import load_transformer

    tf = load_transformer(dtype="bfloat16", device="cpu")
    out_proj = tf.transformer_blocks[0].attn.to_out[0]
    W = out_proj.weight.data.float().to(dev)  # (out,in)
    OUT, IN = W.shape
    print(f"using real klein-4B Linear transformer_blocks.0.attn.to_out.0 ({OUT}x{IN})")
except Exception as e:
    W = torch.randn(OUT, IN, device=dev) * 0.05
    print(f"using random weight ({OUT}x{IN}) [{e}]")

# Quantize + pack the weight, then build the kernel module that will receive the buffers.
bufs = quantize_pack_nvfp4(W, rank=RANK, refine=3)
W_eff = bufs["W_eff"]
m = SVDQW4A4Linear(
    IN,
    OUT,
    rank=RANK,
    bias=False,
    precision="nvfp4",
    torch_dtype=dt,
    device=dev,
)
def setbuf(name, t):
    """Copy tensor ``t`` into attribute ``name`` of the module-level layer ``m``.

    Prints the destination and source shape/dtype together with whether their
    element counts agree, then reshapes ``t`` to the destination shape, casts it
    to the destination dtype and copies it in place. Returns ``None``.
    """
    target = getattr(m, name)
    n_dst = target.numel()
    n_src = t.numel()
    print(f" {name}: buf{tuple(target.shape)}/{target.dtype} <- {tuple(t.shape)}/{t.dtype} numel {n_dst}=={n_src}? {n_dst==n_src}")
    converted = t.reshape(target.shape).to(target.dtype)
    target.data.copy_(converted)
# Load the packed buffers into the kernel module: (attribute on m, key in bufs).
for dst_name, src_key in (
    ("qweight", "qweight"),
    ("wscales", "wscales"),
    ("wcscales", "wcscales"),
    ("proj_down", "proj_down"),
    ("proj_up", "proj_up"),
    ("smooth_factor", "smooth_factor"),
    ("smooth_factor_orig", "smooth_factor"),  # filled from the same source tensor
):
    setbuf(dst_name, bufs[src_key])

m.wtscale = bufs["wtscale"]
print(f" wtscale(alpha) = {m.wtscale:.6g}")

from flux2distill.svdquant import fake_quant_nvfp4

# One random activation batch, run through the kernel and through three references.
x = torch.randn(1, M, IN, device=dev, dtype=dt)
with torch.no_grad():
    y_k = m(x).float()
    x32 = x.float()
    y_eff = x32 @ W_eff.t()  # full-prec x
    y_bf = x32 @ W.t()
    # reference that ALSO 4-bit-quantizes the activation (kernel: Q4(x)@Rq + x@L)
    xq = fake_quant_nvfp4(x, group=16).float()
    y_ref = xq @ bufs["Rq"].t() + x32 @ bufs["L"].t()
def rel(a, b, eps=1e-8):
    """Return the relative error of ``a`` with respect to ``b`` as a Python float.

    Computes ``||a - b|| / (||b|| + eps)`` using ``Tensor.norm()`` on both terms.

    Args:
        a: Tensor being evaluated.
        b: Reference tensor; must be broadcast-compatible with ``a``.
        eps: Small constant added to the denominator so an all-zero reference
            does not divide by zero. Defaults to ``1e-8``.

    Returns:
        The relative error as a ``float`` (via ``.item()``).
    """
    diff_norm = (a - b).norm()
    ref_norm = b.norm()
    return (diff_norm / (ref_norm + eps)).item()
# Report the four relative errors, then a pass/fail verdict on the packing convention.
# The convention error is computed once and reused for both the printout and the verdict.
CONVENTION_TOL = 0.04  # max relative error vs. the act-quantized reference to call it OK

err_convention = rel(y_k, y_ref)
err_vs_w_eff = rel(y_k, y_eff)
err_weight_quant = rel(y_eff, y_bf)
err_total = rel(y_k, y_bf)

print(f"\n rel(kernel, y_ref[Q4(x)@Rq + x@L]) = {err_convention:.4f} <- CONVENTION (want ~0)")
print(f" rel(kernel, W_eff full-prec x) = {err_vs_w_eff:.4f} (incl. act-quant)")
print(f" rel(W_eff, bf16) = {err_weight_quant:.4f} <- weight quant error")
print(f" rel(kernel, bf16) = {err_total:.4f} <- total W4A4")
print("VERDICT:", "CONVENTION OK" if err_convention < CONVENTION_TOL else "CONVENTION MISMATCH (iterate)")
Xet Storage Details
- Size: 2.77 kB
- Xet hash: b12596df5733d3e0b054c8364f0f20a5315da5a4106eca42ee41b26f14f9fcd5
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.