Buckets:
"""Generate eval images on the REAL kernel + measure speed/VRAM (one model per process).
Modes:
  teacher     bf16 teacher (baseline speed/VRAM + reference images)
  ours:RANK   REAL Nunchaku NVFP4 W4A4 fused kernel (in-memory convert via our exporter) — the
              actual deployable artifact; measures real per-step latency + peak VRAM.
Always generates+times (no skip) so the speed number is real. Reads outputs/eval/prompts.json
(override the path with the PROMPTS_JSON environment variable);
seed=idx (paired across modes). Saves {OUT}/{idx:05d}.png + outputs/eval/timing_<TAG>.json.
Usage: python3 -u scripts/35_gen_real.py MODE OUT_DIR TAG [START] [COUNT] [RES]
"""
| import sys, json, os, time, statistics as st, torch | |
| from flux2distill.model_utils import load_pipeline | |
# --- CLI arguments ---------------------------------------------------------
# MODE, OUT_DIR and TAG are required. START/COUNT select a slice of the prompt
# list and RES is passed to the pipeline as both height and width.
if len(sys.argv) < 4:
    # Fail with a readable usage line instead of a bare IndexError.
    raise SystemExit(f"usage: {sys.argv[0]} MODE OUT_DIR TAG [START] [COUNT] [RES]")
MODE, OUT, TAG = sys.argv[1], sys.argv[2], sys.argv[3]
START = int(sys.argv[4]) if len(sys.argv) > 4 else 0
COUNT = int(sys.argv[5]) if len(sys.argv) > 5 else 10**9  # default: everything from START on
RES = int(sys.argv[6]) if len(sys.argv) > 6 else 512
os.makedirs(OUT, exist_ok=True)

# Read the prompt list through a context manager so the handle is closed
# deterministically. PROMPTS_JSON overrides the default location.
with open(os.environ.get('PROMPTS_JSON', 'outputs/eval/prompts.json'), encoding='utf-8') as prompts_file:
    prompts = json.load(prompts_file)[START:START + COUNT]
print(f"=== REAL-gen MODE={MODE} OUT={OUT} TAG={TAG} N={len(prompts)} RES={RES} ===", flush=True)

# Load the pipeline on the GPU and freeze its transformer (eval mode, no grads).
pipe = load_pipeline(device='cuda')
tf = pipe.transformer
tf.eval().requires_grad_(False)
if MODE.startswith('ours:'):
    # "ours:RANK": turn the loaded transformer into the Nunchaku NVFP4 W4A4 model in memory,
    # filling every SVDQW4A4Linear from the weights of the pipeline loaded above.
    RANK = int(MODE.split(':')[1])
    from flux2distill.nunchaku_export import quantize_pack_nvfp4
    from nunchaku.models.transformers.transformer_flux2 import NunchakuFlux2Transformer2DModel
    from nunchaku.models.linear import SVDQW4A4Linear

    # Clone the source weights BEFORE the model is patched, keyed by
    # (block type, block index, module path inside the block).
    src = {}
    for i, b in enumerate(tf.transformer_blocks):
        a = b.attn
        # q/k/v (and the added q/k/v) are stacked along dim 0 into one fused weight.
        src[("d", i, "attn.to_qkv")] = torch.cat([a.to_q.weight, a.to_k.weight, a.to_v.weight], 0).clone()
        src[("d", i, "attn.to_out.0")] = a.to_out[0].weight.clone()
        src[("d", i, "attn.to_added_qkv")] = torch.cat([a.add_q_proj.weight, a.add_k_proj.weight, a.add_v_proj.weight], 0).clone()
        src[("d", i, "attn.to_add_out")] = a.to_add_out.weight.clone()
        src[("d", i, "ff.linear_in")] = b.ff.linear_in.weight.clone()
        src[("d", i, "ff.linear_out")] = b.ff.linear_out.weight.clone()
        src[("d", i, "ff_context.linear_in")] = b.ff_context.linear_in.weight.clone()
        src[("d", i, "ff_context.linear_out")] = b.ff_context.linear_out.weight.clone()
    for i, b in enumerate(tf.single_transformer_blocks):
        src[("s", i, "Win")] = b.attn.to_qkv_mlp_proj.weight.clone()
        src[("s", i, "Wout")] = b.attn.to_out.weight.clone()

    # Re-class the module in place and let Nunchaku patch in its quantized layers.
    tf.__class__ = NunchakuFlux2Transformer2DModel
    tf._patch_model(precision="nvfp4", rank=RANK, torch_dtype=torch.bfloat16)

    t0 = time.time()
    n = 0
    for name, m in tf.named_modules():
        if not isinstance(m, SVDQW4A4Linear):
            continue
        m.to_empty(device="cuda")
        if m.bias is not None:
            m.bias.zero_()

        # Module names look like "<block list>.<index>.<path inside block>".
        parts = name.split(".")
        bt = "d" if parts[0] == "transformer_blocks" else "s"
        idx = int(parts[1])
        local = ".".join(parts[2:])

        if bt == "d":
            W = src[("d", idx, local)]
        # Single blocks: the fused input/output projections are sliced per target module.
        # NOTE(review): 9216 and 3072 are hard-coded split offsets — presumably the fused
        # QKV row count and the attention column count of this model; confirm against config.
        elif local == "attn.qkv_proj":
            W = src[("s", idx, "Win")][:m.out_features]
        elif local == "attn.mlp_fc1":
            W = src[("s", idx, "Win")][9216:9216 + m.out_features]
        elif local == "attn.out_proj":
            W = src[("s", idx, "Wout")][:, :m.in_features]
        elif local == "attn.mlp_fc2":
            W = src[("s", idx, "Wout")][:, 3072:3072 + m.in_features]
        else:
            # Without this branch an unmapped module would silently be quantized from the
            # PREVIOUS iteration's W (or hit a NameError on the very first module).
            raise RuntimeError(f"no source weight mapping for quantized module {name!r}")

        bufs = quantize_pack_nvfp4(W.float().cuda(), rank=RANK, refine=3)
        for nm in ("qweight", "wscales", "wcscales", "proj_down", "proj_up", "smooth_factor"):
            getattr(m, nm).data.copy_(bufs[nm].reshape(getattr(m, nm).shape).to(getattr(m, nm).dtype))
        m.smooth_factor_orig.data.copy_(bufs["smooth_factor"].reshape(m.smooth_factor_orig.shape).to(m.smooth_factor_orig.dtype))
        m.wtscale = bufs["wtscale"]
        n += 1

    # Drop the cloned source weights and release cached blocks before generation starts.
    del src
    torch.cuda.empty_cache()
    print(f"converted {n} Linears -> REAL NVFP4 W4A4 kernel in {time.time()-t0:.0f}s", flush=True)
elif MODE != 'teacher':
    raise SystemExit(f"unknown MODE {MODE}")
# Per-step transformer timing: route every transformer forward through a wrapper
# that records its wall time with CUDA synchronised before and after the call.
# NOTE(review): "per step" assumes the pipeline calls the transformer once per
# denoising step — confirm against the pipeline implementation.
steps = []  # seconds per transformer forward call, in call order
_fwd = tf.forward


def timed(*args, **kwargs):
    """Run the original forward and append its synchronised wall time to `steps`."""
    torch.cuda.synchronize()
    began = time.perf_counter()
    result = _fwd(*args, **kwargs)
    torch.cuda.synchronize()
    steps.append(time.perf_counter() - began)
    return result


tf.forward = timed
def gen(prompt, seed):
    """Return the first image the pipeline produces for `prompt`.

    A CUDA generator seeded with `seed` is passed to the pipeline, together with
    4 inference steps, guidance scale 1.0 and a RES x RES output size.
    """
    # NO autocast: the fused NVFP4 kernel manages its own dtypes (autocast segfaults it).
    rng = torch.Generator('cuda')
    rng.manual_seed(seed)
    result = pipe(
        prompt=prompt,
        num_inference_steps=4,
        guidance_scale=1.0,
        height=RES,
        width=RES,
        generator=rng,
    )
    return result.images[0]
# Reset the peak-memory counter so the reported peak covers generation only
# (anything allocated and freed before this point no longer counts).
torch.cuda.reset_peak_memory_stats()
t0 = time.time()
times = []  # end-to-end seconds per image (pipeline call + PNG save)
warmup_calls = None  # transformer forward calls made while generating the first image
for d in prompts:
    ts = time.perf_counter()
    # Seed = prompt idx, so the same prompt gets the same seed in every mode.
    gen(d['prompt'], d['idx']).save(os.path.join(OUT, f"{d['idx']:05d}.png"))
    times.append(time.perf_counter() - ts)
    if warmup_calls is None:
        # Measure how many forward calls the first image took instead of assuming 4,
        # so the warm-up slice stays right if the step count ever changes.
        warmup_calls = len(steps)
    if len(times) % 25 == 0:
        print(f" {len(times)}/{len(prompts)} ({time.time()-t0:.0f}s)", flush=True)

# Drop the first image (warm-up) from the speed stats; with a single image keep it
# so the stats are not empty.
sp = times[1:] if len(times) > 1 else times
if warmup_calls is None:
    warmup_calls = 0  # no prompts were processed
stp = steps[warmup_calls:] if len(steps) > warmup_calls else steps
timing = {"mode": MODE, "tag": TAG, "res": RES, "n": len(times),
          "s_per_img_median": round(st.median(sp), 4) if sp else None,
          "img_per_s": round(len(sp) / sum(sp), 3) if sp else None,
          "step_ms_median": round(st.median(stp) * 1000, 2) if stp else None,
          "peak_vram_gb": round(torch.cuda.max_memory_allocated() / 1e9, 2)}

# Make sure the destination directory exists (the prompt file may live elsewhere
# via PROMPTS_JSON) and close the file deterministically after writing.
timing_path = f"outputs/eval/timing_{TAG}.json"
os.makedirs(os.path.dirname(timing_path), exist_ok=True)
with open(timing_path, "w") as timing_file:
    json.dump(timing, timing_file, indent=2)
print(f"DONE {MODE} -> {OUT} | {timing}", flush=True)
Xet Storage Details
- Size:
- 5.6 kB
- Xet hash:
- f22517523d15b07eb48deec4868e9ed327d31617312746dc6b6cac39fe5a6860
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.