Buckets:

Mercity/FluxDistill/scripts/29_save_real_quant.py
Pranav2748's picture
download
raw
4.97 kB
"""Save a DEPLOYABLE NVFP4 W4A4 klein-4B checkpoint at a given rank (real Nunchaku FP4 kernel).

The conversion block is the proven one from scripts/28, parameterized by RANK and OUT, with the
heavy benchmark matrix replaced by a light teacher-vs-quant correctness montage. Saves the
NunchakuFlux2Transformer2DModel state_dict plus per-layer wtscales (JSON-encoded in the
safetensors metadata, same convention as scripts/28's r128 deploy file).

Usage: python3 -u scripts/29_save_real_quant.py RANK OUT_PATH

Arguments:
    RANK      integer rank handed to the NVFP4 model patching and to the per-layer
              quantize/pack step.
    OUT_PATH  destination .safetensors file.

Outputs:
    OUT_PATH                                 converted state_dict; metadata keys are
                                             wtscales, rank, precision, model, format.
    <directory of OUT_PATH>/cmp_r{RANK}_{i}.png
                                             side-by-side teacher vs. quantized render,
                                             one per prompt.
"""
import sys, json, os, time, torch
from safetensors.torch import save_file
from flux2distill.model_utils import load_pipeline
from flux2distill.eval_utils import side_by_side
from flux2distill.nunchaku_export import quantize_pack_nvfp4
from nunchaku.models.transformers.transformer_flux2 import NunchakuFlux2Transformer2DModel
from nunchaku.models.linear import SVDQW4A4Linear
# ---- CLI arguments ----
# Fail with the usage line instead of a bare IndexError / ValueError traceback.
if len(sys.argv) < 3:
    sys.exit("Usage: python3 -u scripts/29_save_real_quant.py RANK OUT_PATH")
try:
    RANK = int(sys.argv[1])
except ValueError:
    sys.exit(f"RANK must be an integer, got {sys.argv[1]!r}")
OUT = sys.argv[2]

# Prompts rendered by both the teacher and the quantized model for the visual check.
PROMPTS = [
    'a vintage bookshop storefront with a wooden sign that reads "THE OPEN PAGE"',
    "a close-up of a smiling young woman holding up five fingers, natural window light, sharp focus on the hand",
]
print(f"=== save real NVFP4 W4A4 r{RANK} -> {OUT} ===", flush=True)
pipe = load_pipeline(device="cuda")
tf = pipe.transformer
tf.eval().requires_grad_(False)  # inference only; no gradients needed anywhere below
def gen():
    """Render every prompt in PROMPTS with the module-level `pipe`; return the images in order.

    Each prompt gets a freshly seeded CUDA generator (seed 0), 4 inference steps,
    guidance_scale 1.0, and a 512x512 output. Because of that, repeated calls on
    different transformer weights produce directly comparable images.
    """
    def _render(prompt):
        rng = torch.Generator("cuda").manual_seed(0)
        result = pipe(
            prompt=prompt,
            num_inference_steps=4,
            guidance_scale=1.0,
            height=512,
            width=512,
            generator=rng,
        )
        return result.images[0]

    return [_render(prompt) for prompt in PROMPTS]
print("teacher gen (reference)...", flush=True)
t_imgs = gen()  # reference renders from the unmodified teacher transformer
# ---- capture source weights BEFORE patching (patching discards the fused originals) ----
# src maps (block kind, block index, weight name) -> cloned weight tensor:
#   "d": entries from tf.transformer_blocks, named by the Linear's local name after patching
#        (q/k/v and added q/k/v are concatenated along dim 0 into one fused weight);
#   "s": entries from tf.single_transformer_blocks, "Win" = attn.to_qkv_mlp_proj weight,
#        "Wout" = attn.to_out weight (both still fused; sliced later during conversion).
src = {}
for blk_idx, dbl in enumerate(tf.transformer_blocks):
    att = dbl.attn
    for local, w in (
        ("attn.to_qkv", torch.cat([att.to_q.weight, att.to_k.weight, att.to_v.weight], 0)),
        ("attn.to_out.0", att.to_out[0].weight),
        ("attn.to_added_qkv", torch.cat([att.add_q_proj.weight, att.add_k_proj.weight, att.add_v_proj.weight], 0)),
        ("attn.to_add_out", att.to_add_out.weight),
        ("ff.linear_in", dbl.ff.linear_in.weight),
        ("ff.linear_out", dbl.ff.linear_out.weight),
        ("ff_context.linear_in", dbl.ff_context.linear_in.weight),
        ("ff_context.linear_out", dbl.ff_context.linear_out.weight),
    ):
        src[("d", blk_idx, local)] = w.clone()
for blk_idx, sgl in enumerate(tf.single_transformer_blocks):
    src[("s", blk_idx, "Win")] = sgl.attn.to_qkv_mlp_proj.weight.clone()
    src[("s", blk_idx, "Wout")] = sgl.attn.to_out.weight.clone()
# ---- patch to fused NVFP4 model + convert every Linear ----
# Offsets of the MLP half inside the single-block fused weights captured in `src`.
# NOTE(review): 9216 == 3 * 3072, presumably the fused q/k/v width and the hidden size
# of klein-4B — confirm before reusing this script for a different model size.
SINGLE_MLP_FC1_ROW0 = 9216  # first row of attn.mlp_fc1 inside the fused "Win" weight
SINGLE_MLP_FC2_COL0 = 3072  # first column of attn.mlp_fc2 inside the fused "Wout" weight


def _source_weight(name, m):
    """Return the captured dense weight that quantized Linear `name` (module `m`) replaces.

    `name` is split as "<prefix>.<block index>.<local name>". A "transformer_blocks"
    prefix is looked up directly in `src`; any other prefix is treated as a single block,
    whose fused Win/Wout weights are sliced using `m.out_features` / `m.in_features`.

    Raises:
        KeyError: if no captured weight maps to `name`. Previously an unmatched
            single-block layer silently reused the previous layer's weight.
    """
    parts = name.split(".")
    idx = int(parts[1])
    local = ".".join(parts[2:])
    if parts[0] == "transformer_blocks":
        return src[("d", idx, local)]
    if local == "attn.qkv_proj":
        return src[("s", idx, "Win")][:m.out_features]
    if local == "attn.mlp_fc1":
        return src[("s", idx, "Win")][SINGLE_MLP_FC1_ROW0:SINGLE_MLP_FC1_ROW0 + m.out_features]
    if local == "attn.out_proj":
        return src[("s", idx, "Wout")][:, :m.in_features]
    if local == "attn.mlp_fc2":
        return src[("s", idx, "Wout")][:, SINGLE_MLP_FC2_COL0:SINGLE_MLP_FC2_COL0 + m.in_features]
    raise KeyError(f"no captured source weight maps to quantized Linear {name!r}")


tf.__class__ = NunchakuFlux2Transformer2DModel
tf._patch_model(precision="nvfp4", rank=RANK, torch_dtype=torch.bfloat16)
wtscales = {}
t0 = time.time()
n = 0
for name, m in tf.named_modules():
    if not isinstance(m, SVDQW4A4Linear):
        continue
    W = _source_weight(name, m)  # resolve first so a bad mapping fails before mutating m
    # Materialize the module's (meta/unallocated) tensors on the GPU, then overwrite them.
    m.to_empty(device="cuda")
    if m.bias is not None:
        m.bias.data.zero_()
    bufs = quantize_pack_nvfp4(W.float().cuda(), rank=RANK, refine=3)
    for nm in ("qweight", "wscales", "wcscales", "proj_down", "proj_up", "smooth_factor"):
        dst = getattr(m, nm)
        dst.data.copy_(bufs[nm].reshape(dst.shape).to(dst.dtype))
    # The same smoothing factor is also copied into smooth_factor_orig.
    m.smooth_factor_orig.data.copy_(
        bufs["smooth_factor"].reshape(m.smooth_factor_orig.shape).to(m.smooth_factor_orig.dtype)
    )
    # NOTE(review): wtscales is later passed to json.dumps, so bufs["wtscale"] is presumably
    # a plain Python number — confirm against quantize_pack_nvfp4.
    m.wtscale = bufs["wtscale"]
    wtscales[name] = bufs["wtscale"]
    n += 1
del src  # drop the dense copies before saving / rendering
torch.cuda.empty_cache()
print(f"converted {n} Linears in {time.time()-t0:.0f}s", flush=True)
# ---- save ----
out_dir = os.path.dirname(OUT)
# A bare filename has an empty dirname, and os.makedirs("") raises FileNotFoundError.
if out_dir:
    os.makedirs(out_dir, exist_ok=True)
# Contiguous CPU copies of every state_dict tensor for safetensors serialization.
sd = {k: v.contiguous().cpu() for k, v in tf.state_dict().items()}
# Metadata values are all strings; per-layer wtscales are JSON-encoded into one entry.
meta = {"wtscales": json.dumps(wtscales), "rank": str(RANK), "precision": "nvfp4",
        "model": "klein-4B", "format": "NunchakuFlux2Transformer2DModel state_dict"}
save_file(sd, OUT, metadata=meta)
gb = sum(v.numel() * v.element_size() for v in sd.values()) / 1e9  # raw tensor bytes, in GB
print(f"SAVED -> {OUT} ({gb:.2f}GB)", flush=True)
# ---- correctness: render the same prompts on the real FP4 kernel, montage vs teacher ----
print("quant gen (correctness check)...", flush=True)
q_imgs = gen()
# dirname() is "" for a bare OUT filename; "." writes to the same place (the cwd) but
# keeps the path printed below from turning into a bogus absolute "/cmp_r..." path.
cmpdir = os.path.dirname(OUT) or "."
for i, (prompt, t, q) in enumerate(zip(PROMPTS, t_imgs, q_imgs)):
    montage = side_by_side(t, q, "teacher", f"NVFP4-W4A4-r{RANK}", prompt)
    montage.save(os.path.join(cmpdir, f"cmp_r{RANK}_{i}.png"))
print(f"DONE r{RANK} -> {OUT} (montages: {cmpdir}/cmp_r{RANK}_*.png)", flush=True)

Xet Storage Details

Size:
4.97 kB
·
Xet hash:
50410f86d1f04c5ca2cbbb102fb44534d187d2ef1c4fd3b9273456016b8e10d9

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.