Buckets:
| """Save a real packed INT4 (W4A4) klein-4B checkpoint (Nunchaku int4 kernel). | |
| Same conversion loop as scripts/29, but precision='int4': group-64 symmetric int4 residual + | |
| bf16 per-group scales (quantize_pack_int4), no global/channel weight scale. INT4 W4A4 is the | |
| deployable format for RTX 20/30/40 (Turing/Ampere/Ada INT4 IMMA); on this Blackwell box it RUNS | |
| but is slow (no INT4 tensor cores) — so this is a CORRECTNESS check, not a speed bench. | |
| Usage: python3 -u scripts/30_save_int4.py RANK OUT_PATH | |
| """ | |
| import sys, json, os, time, torch | |
| from safetensors.torch import save_file | |
| from flux2distill.model_utils import load_pipeline | |
| from flux2distill.eval_utils import side_by_side | |
| from flux2distill.nunchaku_export import quantize_pack_int4 | |
| from nunchaku.models.transformers.transformer_flux2 import NunchakuFlux2Transformer2DModel | |
| from nunchaku.models.linear import SVDQW4A4Linear | |
# ---- CLI ----
# Fail with a usage line instead of a bare IndexError when arguments are missing.
if len(sys.argv) < 3:
    sys.exit(f"usage: {sys.argv[0]} RANK OUT_PATH")
RANK = int(sys.argv[1])  # rank passed to _patch_model and quantize_pack_int4
OUT = sys.argv[2]  # output .safetensors path; comparison montages are written next to it
# Fixed prompts rendered by both the bf16 teacher and the INT4 model for side-by-side review.
PROMPTS = [
    'a vintage bookshop storefront with a wooden sign that reads "THE OPEN PAGE"',
    "a close-up of a smiling young woman holding up five fingers, natural window light, sharp focus on the hand",
]
# Announce the run, load the bf16 pipeline on the GPU, and freeze its transformer.
# The same transformer object is later re-classed and patched in place.
print(f"=== save real INT4 W4A4 r{RANK} -> {OUT} ===", flush=True)
pipe = load_pipeline(device="cuda")
tf = pipe.transformer
tf.eval()
tf.requires_grad_(False)
def gen():
    """Render each prompt in PROMPTS with the current global `pipe`.

    Settings: 4 inference steps, guidance_scale=1.0, 512x512. Every prompt gets a
    fresh CUDA generator seeded with 0, so the teacher run and the later quantized
    run start from identical noise. Returns one image per prompt, in PROMPTS order.
    """
    def render(prompt):
        rng = torch.Generator("cuda").manual_seed(0)
        result = pipe(prompt=prompt, num_inference_steps=4, guidance_scale=1.0,
                      height=512, width=512, generator=rng)
        return result.images[0]

    return [render(prompt) for prompt in PROMPTS]


print("teacher gen (reference)...", flush=True)
t_imgs = gen()
# ---- capture source weights BEFORE patching ----
# Keys are (block kind, block index, name of the INT4 module the weight will feed):
#   "d" = tf.transformer_blocks, "s" = tf.single_transformer_blocks.
# Copies are needed because the conversion loop reads them after tf has been patched.
# torch.cat already returns a new tensor, so concatenated entries are not cloned again
# (the original cat(...).clone() made a redundant second GPU copy).
src = {}
for i, blk in enumerate(tf.transformer_blocks):
    attn = blk.attn
    # Separate Q/K/V projections are stacked along dim 0 (rows) into one fused weight.
    src[("d", i, "attn.to_qkv")] = torch.cat([attn.to_q.weight, attn.to_k.weight, attn.to_v.weight], 0)
    src[("d", i, "attn.to_out.0")] = attn.to_out[0].weight.clone()
    src[("d", i, "attn.to_added_qkv")] = torch.cat([attn.add_q_proj.weight, attn.add_k_proj.weight, attn.add_v_proj.weight], 0)
    src[("d", i, "attn.to_add_out")] = attn.to_add_out.weight.clone()
    for ff_name in ("ff", "ff_context"):
        ff = getattr(blk, ff_name)
        src[("d", i, f"{ff_name}.linear_in")] = ff.linear_in.weight.clone()
        src[("d", i, f"{ff_name}.linear_out")] = ff.linear_out.weight.clone()
for i, blk in enumerate(tf.single_transformer_blocks):
    # Fused input/output projections; the conversion loop slices them into sub-weights.
    src[("s", i, "Win")] = blk.attn.to_qkv_mlp_proj.weight.clone()
    src[("s", i, "Wout")] = blk.attn.to_out.weight.clone()
# ---- patch to fused INT4 model + convert every Linear ----
# Offsets into the single-block fused weights (previously the bare literals 9216 / 3072).
# NOTE(review): these presumably encode klein-4B's hidden size of 3072. Win would then be
# laid out as [q|k|v = 3*3072 rows | mlp_in rows], and Wout as [attn_out = 3072 cols | mlp_out cols].
# Confirm against the model config before reusing this for another model size.
SINGLE_QKV_ROWS = 9216
SINGLE_ATTN_OUT_COLS = 3072

# Re-class the loaded transformer in place so the Nunchaku patcher can rebuild its layers.
tf.__class__ = NunchakuFlux2Transformer2DModel
tf._patch_model(precision="int4", rank=RANK, torch_dtype=torch.bfloat16)
t0 = time.perf_counter()
n = 0
for name, m in tf.named_modules():
    if not isinstance(m, SVDQW4A4Linear):
        continue
    # to_empty gives the module fresh, uninitialized CUDA storage (presumably its tensors
    # were created without real storage by _patch_model); the packed buffers are filled below.
    m.to_empty(device="cuda")
    # NOTE(review): no source bias is ever copied, so this assumes the source Linears are bias-free.
    if m.bias is not None:
        m.bias.zero_()
    parts = name.split(".")
    idx = int(parts[1])
    local = ".".join(parts[2:])
    if parts[0] == "transformer_blocks":
        # Double-stream modules map 1:1 onto a captured (possibly Q/K/V-concatenated) weight.
        W = src[("d", idx, local)]
    elif local == "attn.qkv_proj":
        W = src[("s", idx, "Win")][:m.out_features]
    elif local == "attn.mlp_fc1":
        W = src[("s", idx, "Win")][SINGLE_QKV_ROWS:SINGLE_QKV_ROWS + m.out_features]
    elif local == "attn.out_proj":
        W = src[("s", idx, "Wout")][:, :m.in_features]
    elif local == "attn.mlp_fc2":
        W = src[("s", idx, "Wout")][:, SINGLE_ATTN_OUT_COLS:SINGLE_ATTN_OUT_COLS + m.in_features]
    else:
        # Without this branch, W would silently keep the previous iteration's weight
        # (or be unbound on the first iteration), and the wrong matrix would be quantized.
        raise KeyError(f"no source-weight mapping for INT4 module {name!r}")
    bufs = quantize_pack_int4(W.float().cuda(), rank=RANK, refine=3)
    for nm in ("qweight", "wscales", "proj_down", "proj_up", "smooth_factor"):
        dst = getattr(m, nm)
        dst.data.copy_(bufs[nm].reshape(dst.shape).to(dst.dtype))
    # smooth_factor_orig receives the same smooth factor as smooth_factor.
    m.smooth_factor_orig.data.copy_(bufs["smooth_factor"].reshape(m.smooth_factor_orig.shape).to(m.smooth_factor_orig.dtype))
    n += 1
if n == 0:
    # Saving now would write an unconverted model labelled "int4".
    raise RuntimeError("_patch_model produced no SVDQW4A4Linear modules; nothing was converted")
del src
torch.cuda.empty_cache()
print(f"converted {n} Linears in {time.perf_counter()-t0:.0f}s", flush=True)
# ---- save ----
# os.path.dirname("model.safetensors") is "", and os.makedirs("") raises
# FileNotFoundError, so only create a directory when OUT actually names one.
out_dir = os.path.dirname(OUT)
if out_dir:
    os.makedirs(out_dir, exist_ok=True)
# Materialize a contiguous CPU copy of every tensor for serialization.
sd = {k: v.contiguous().cpu() for k, v in tf.state_dict().items()}
# safetensors metadata is a str -> str mapping, hence str(RANK).
meta = {"rank": str(RANK), "precision": "int4", "model": "klein-4B",
        "format": "NunchakuFlux2Transformer2DModel state_dict",
        "note": "INT4 W4A4; deployable+fast on RTX 20/30/40, slow on Blackwell (no INT4 tensor cores)"}
save_file(sd, OUT, metadata=meta)
# Tensor payload size only (excludes the safetensors header).
gb = sum(v.numel() * v.element_size() for v in sd.values()) / 1e9
print(f"SAVED -> {OUT} ({gb:.2f}GB)", flush=True)
# ---- correctness: render on the real int4 kernel (slow here), montage vs teacher ----
print("quant gen (correctness check; INT4 is slow on Blackwell)...", flush=True)
q_imgs = gen()
# Montages go next to the checkpoint. Fall back to "." when OUT has no directory part;
# otherwise the final message would print a misleading root-anchored "/cmp_int4_..." path.
cmpdir = os.path.dirname(OUT) or "."
for i, (prompt, t, q) in enumerate(zip(PROMPTS, t_imgs, q_imgs)):
    montage = side_by_side(t, q, "teacher", f"INT4-W4A4-r{RANK}", prompt)
    montage.save(os.path.join(cmpdir, f"cmp_int4_r{RANK}_{i}.png"))
print(f"DONE int4 r{RANK} -> {OUT} (montages: {cmpdir}/cmp_int4_r{RANK}_*.png)", flush=True)
Xet Storage Details
- Size:
- 4.96 kB
- Xet hash:
- 5878975b1b5d522b611b078891477e43dcddce1c52f524cf2f548fd77980c84d
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.