Buckets:

Mercity/FluxDistill / scripts /26_convert_test.py
Pranav2748's picture
download
raw
2.77 kB
"""Single-layer validation of the NVFP4 weight converter: quantize+pack a real klein-4B Linear,
load into the real SVDQW4A4Linear kernel, and check the kernel output matches our intended
dequantized weight (convention correct) and is close to bf16 (quant error reasonable)."""
import sys, torch
from nunchaku.models.linear import SVDQW4A4Linear
from flux2distill.nunchaku_export import quantize_pack_nvfp4
# Fixed seed: the random fallback weight and the test activation are both drawn
# from torch's default generator, so runs are repeatable.
torch.manual_seed(0)

# Device and dtype handed to the kernel module and used for the test activation.
dev = "cuda"
dt = torch.bfloat16

# Default geometry. OUT/IN are overwritten with the real layer's shape when the
# real weight loads; RANK is passed to both the converter and the kernel module;
# M is the middle dimension of the (1, M, IN) test activation.
OUT = 3072
IN = 3072
RANK = 128
M = 1536
# Choose the weight under test. Preferred: the first block's attention output
# projection of a real klein-4B transformer, so the value distribution is
# realistic. If any step of that path fails (import, load, device transfer), fall
# back to a small-magnitude random matrix of the default OUT x IN shape.
try:
    from flux2distill.model_utils import load_transformer

    tf = load_transformer(dtype="bfloat16", device="cpu")
    out_proj = tf.transformer_blocks[0].attn.to_out[0]
    # Upcast to fp32 first, then move to the target device; layout is (out, in).
    W = out_proj.weight.data.float().to(dev)
    OUT, IN = W.shape
    print(f"using real klein-4B Linear transformer_blocks.0.attn.to_out.0 ({OUT}x{IN})")
except Exception as e:
    # Deliberate best-effort: the reason for the fallback is echoed in the message.
    W = torch.randn(OUT, IN, device=dev) * 0.05
    print(f"using random weight ({OUT}x{IN}) [{e}]")
# Run the converter under test on the chosen weight. The result is indexed by
# string keys below ("W_eff", "qweight", "wscales", ...).
bufs = quantize_pack_nvfp4(
    W,
    rank=RANK,
    refine=3,
)
# Kept separately because it serves as one of the reference weights later on.
W_eff = bufs["W_eff"]

# Instantiate the real kernel module with matching geometry; its tensors are
# overwritten with the converter output afterwards.
m = SVDQW4A4Linear(
    IN,
    OUT,
    rank=RANK,
    bias=False,
    precision="nvfp4",
    torch_dtype=dt,
    device=dev,
)
def setbuf(name, t, module=None):
    """Copy tensor ``t`` into the tensor attribute ``name`` of a module.

    A one-line shape/dtype/element-count summary is printed first so a layout
    mismatch between converter output and kernel buffer is visible in the log.

    Args:
        name: Attribute name of the destination tensor on the module.
        t: Source tensor. It is reshaped to the destination's shape and cast to
            the destination's dtype before being copied in place.
        module: Object that owns the destination tensor. Defaults to the
            module-level ``m`` so existing two-argument calls keep working.

    Raises:
        ValueError: If source and destination hold a different number of
            elements. Raised before any data is written.
    """
    target = m if module is None else module
    p = getattr(target, name)
    print(f" {name}: buf{tuple(p.shape)}/{p.dtype} <- {tuple(t.shape)}/{t.dtype} numel {p.numel()}=={t.numel()}? {p.numel()==t.numel()}")
    # Fail with a message that names the buffer instead of relying on the
    # generic error reshape() would raise.
    if p.numel() != t.numel():
        raise ValueError(
            f"cannot load '{name}': destination {tuple(p.shape)} has {p.numel()} "
            f"elements but source {tuple(t.shape)} has {t.numel()}"
        )
    # In-place copy keeps the existing storage (and device) of the destination.
    p.data.copy_(t.reshape(p.shape).to(p.dtype))
# (destination attribute on the kernel module, key in the converter output).
# Order matches the original call sequence; "smooth_factor_orig" is filled from
# the same "smooth_factor" tensor as "smooth_factor".
for _dst, _src in (
    ("qweight", "qweight"),
    ("wscales", "wscales"),
    ("wcscales", "wcscales"),
    ("proj_down", "proj_down"),
    ("proj_up", "proj_up"),
    ("smooth_factor", "smooth_factor"),
    ("smooth_factor_orig", "smooth_factor"),
):
    setbuf(_dst, bufs[_src])

# wtscale is assigned as a plain attribute rather than copied through setbuf.
m.wtscale = bufs["wtscale"]
print(f" wtscale(alpha) = {m.wtscale:.6g}")
from flux2distill.svdquant import fake_quant_nvfp4

# Random test activation of shape (1, M, IN) in the kernel's dtype.
x = torch.randn(1, M, IN, device=dev, dtype=dt)
with torch.no_grad():
    # Output of the real kernel, upcast to fp32 for the comparisons.
    y_k = m(x).float()

    # fp32 copy of the activation shared by the references below.
    x32 = x.float()
    # Full-precision activation through the intended dequantized weight.
    y_eff = x32 @ W_eff.t()
    # Full-precision activation through the original weight.
    y_bf = x32 @ W.t()

    # Reference that ALSO 4-bit-quantizes the activation
    # (kernel: Q4(x)@Rq + x@L).
    xq = fake_quant_nvfp4(x, group=16).float()
    y_ref = xq @ bufs["Rq"].t() + x32 @ bufs["L"].t()
def rel(a, b):
    """Return the relative error of ``a`` with respect to ``b`` as a Python float.

    Computed as ``norm(a - b) / (norm(b) + 1e-8)``; the small constant keeps the
    division defined when ``b`` is all zeros.
    """
    diff_norm = (a - b).norm()
    ref_norm = b.norm()
    ratio = diff_norm / (ref_norm + 1e-8)
    return ratio.item()
# Report the four relative errors: kernel vs. the activation-quantized reference
# (the convention check), kernel vs. the dequantized weight with a full-precision
# activation, dequantized weight vs. original weight, and kernel vs. original.
print(f"\n rel(kernel, y_ref[Q4(x)@Rq + x@L]) = {rel(y_k, y_ref):.4f} <- CONVENTION (want ~0)")
print(f" rel(kernel, W_eff full-prec x) = {rel(y_k, y_eff):.4f} (incl. act-quant)")
print(f" rel(W_eff, bf16) = {rel(y_eff, y_bf):.4f} <- weight quant error")
print(f" rel(kernel, bf16) = {rel(y_k, y_bf):.4f} <- total W4A4")

# Pass/fail is decided on the convention check alone, with a 0.04 tolerance.
if rel(y_k, y_ref) < 0.04:
    verdict = "CONVENTION OK"
else:
    verdict = "CONVENTION MISMATCH (iterate)"
print("VERDICT:", verdict)

Xet Storage Details

Size:
2.77 kB
·
Xet hash:
b12596df5733d3e0b054c8364f0f20a5315da5a4106eca42ee41b26f14f9fcd5

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.