# Mercity/FluxDistill / scripts/38_bfl_dequant.py (Pranav2748)
"""Dequantize BFL's ModelOpt NVFP4 klein-4B into an ordinary bf16 single-file checkpoint, so that
diffusers `from_single_file` can ingest it. The standard converter already handles the
flux -> diffusers key renaming once the weights are unquantized; it only failed on the quantized layout.

NVFP4 (ModelOpt, two-level scaling):
    W = weight         U8[out, in/2]   (two E2M1 values per byte)
      * weight_scale   FP8[out, in/16] (one scale per group of 16 inputs)
      * weight_scale_2 F32             (one scale per tensor)
This is weight-only dequantization (activations stay bf16). input_scale (static activation
quantization) is dropped, so the result is an UPPER BOUND on BFL's quality (favourable to BFL).

The output is validated against the original klein-4B bf16 single-file (same flux naming); the
relative error should be a few percent, i.e. the expected 4-bit quantization error.

Usage: python3 scripts/38_bfl_dequant.py
"""
import torch
from safetensors import safe_open
from safetensors.torch import save_file
SRC = "models/klein-4b-nvfp4/flux-2-klein-4b-nvfp4.safetensors"
ORIG = "models/klein-4b/flux-2-klein-4b.safetensors"
DST = "models/klein-4b-nvfp4/flux-2-klein-4b-bf16-from-nvfp4.safetensors"
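# E2M1 magnitudes, indexed by the low 3 bits of each 4-bit code (bit 3 is the sign).
_E2M1 = torch.tensor([0., 0.5, 1., 1.5, 2., 3., 4., 6.])


def dequant(wu8, wscale_fp8, wscale2):
    """NVFP4 -> bf16: unpack two E2M1 codes per byte, then apply the per-group-16 FP8 scale
    and the per-tensor F32 scale."""
    out, inh = wu8.shape
    wu8 = wu8.to(torch.int32)

    def dec(n):
        # Decode 4-bit E2M1 codes: sign from bit 3, magnitude from the lookup table.
        sign = torch.where((n & 0x8) > 0, -1.0, 1.0)
        return sign * _E2M1[(n & 0x7).long()]

    lo, hi = wu8 & 0xF, (wu8 >> 4) & 0xF
    vals = torch.stack([dec(lo), dec(hi)], dim=-1).reshape(out, inh * 2)  # elem 2j=lo, 2j+1=hi
    ws = wscale_fp8.float().repeat_interleave(16, dim=1)  # (out, in): one FP8 scale per 16 inputs
    return (vals * ws * float(wscale2)).to(torch.bfloat16)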
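

# Illustrative pure-Python reference (hypothetical helper, not called by this script): decodes one
# packed NVFP4 byte into its two weights, assuming the same convention as dequant() above
# (low nibble = element 2j, high nibble = element 2j+1; bit 3 = sign, bits 0-2 index the E2M1
# magnitudes) and applying both scale levels as plain floats. Useful for spot-checking a few
# bytes by hand, e.g. 0x3A with group scale 2.0 and tensor scale 0.5 gives (-1.0, 1.5).
_E2M1_REF = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)


def _ref_decode_byte(b, group_scale, tensor_scale):
    """Return (elem_2j, elem_2j+1) for one packed byte, scaled by group_scale * tensor_scale."""
    def nib(n):
        mag = _E2M1_REF[n & 0x7]  # magnitude from the low 3 bits
        return -mag if n & 0x8 else mag  # bit 3 is the sign
    scale = group_scale * tensor_scale
    return nib(b & 0xF) * scale, nib((b >> 4) & 0xF) * scale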
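

# Iterate over the original checkpoint's keys, so only tensors that exist there are written;
# the ModelOpt scale tensors (weight_scale, weight_scale_2, input_scale) are therefore dropped.
with safe_open(SRC, framework="pt") as bf, safe_open(ORIG, framework="pt") as og:
    bkeys = set(bf.keys())
    out, errs = {}, []
    for k in og.keys():
        if k.endswith(".weight") and (k + "_scale") in bkeys:
            # Quantized weight: rebuild bf16 values from the packed codes and both scale levels.
            w = dequant(bf.get_tensor(k), bf.get_tensor(k + "_scale"), bf.get_tensor(k + "_scale_2"))
            out[k] = w
            # Relative (Frobenius-norm) error against the original bf16 weight.
            ow = og.get_tensor(k).float()
            rel = (w.float() - ow).norm() / (ow.norm() + 1e-8)
            errs.append((k, float(rel)))
        else:
            # Tensor that was not quantized: copy it from the NVFP4 file unchanged.
            out[k] = bf.get_tensor(k)

print(f"dequantized {len(errs)} quantized weights; copied {len(out) - len(errs)} non-quant tensors")
if errs:
    errs.sort(key=lambda x: -x[1])
    print("rel-err vs original klein-4B bf16 (worst 5, best 1, mean):")
    for k, e in errs[:5]:
        print(f" {e:.4f} {k}")
    print(f" {errs[-1][1]:.4f} {errs[-1][0]} (best)")
    print(f" mean={sum(e for _, e in errs) / len(errs):.4f}")
save_file(out, DST)
print(f"saved bf16 single-file -> {DST}")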
