Buckets:
| """Generate all report figures for the NVFP4 campaign -> report/figures/*.png + a 2x3 grid.""" | |
| import os, json | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
# Every figure this script renders is written under OUT; the directory is
# created up front (no error if it already exists).
OUT = "report/figures"
os.makedirs(OUT, exist_ok=True)

# Global matplotlib defaults applied to all panels drawn below.
plt.rcParams.update(
    {
        "figure.dpi": 130,
        "font.size": 9,
        "axes.grid": True,
        "grid.alpha": 0.3,
        "axes.axisbelow": True,  # keep grid lines behind the plotted artists
    }
)
# ---------------- data (2026-06-13 axis; held-out velocity loss, lower=better) ----------------
# Hex colours reused by both the per-format colour map and the Pareto points.
_RED = "#d62728"
_PURPLE = "#9467bd"
_BLUE = "#1f77b4"
_GREEN = "#2ca02c"
_SAGE = "#7f9f7f"

# format -> {rank: eval_loss}; insertion order is the plotting/legend order.
QUAL = {
    "NVFP4 W4A4 (g16)": {32: 0.0390, 64: 0.0364, 128: 0.0303},
    "NVFP4-W + FP8-A": {64: 0.0204, 128: 0.0169},
    "INT4 W4A4 (g64)": {64: 0.0742, 128: 0.0610},
    "INT4 W4A8 (smooth=0)": {16: 0.0348, 32: 0.0331, 64: 0.0297},
    "INT4 W4A8 (alpha=0.5)": {16: 0.0405, 32: 0.0362, 64: 0.0325},
}

# format -> line colour; keys must match the keys of QUAL.
COL = {
    "NVFP4 W4A4 (g16)": _RED,
    "NVFP4-W + FP8-A": _PURPLE,
    "INT4 W4A4 (g64)": _BLUE,
    "INT4 W4A8 (smooth=0)": _GREEN,
    "INT4 W4A8 (alpha=0.5)": _SAGE,
}

# per-layer kernel speedup on Blackwell (x vs bf16), klein-4B shapes T=1536
# Insertion order is the left-to-right bar order.
PERLAYER = {
    "bf16": 1.0,
    "NVFP4 W4A4 r64": 2.75,
    "NVFP4 W4A4 r128": 2.49,
    "FP8 (W4+A8) r64/128": 1.21,
    "INT4 (any) on sm_120": 0.41,
}

# Pareto points: (blackwell per-layer speedup, eval_loss, label, color)
PARETO = [
    (1.00, 0.0010, "bf16 (teacher)", "k"),
    (2.49, 0.0303, "NVFP4 W4A4 r128", _RED),
    (2.75, 0.0364, "NVFP4 W4A4 r64", _RED),
    (1.21, 0.0169, "NVFP4-W+FP8-A r128", _PURPLE),
    (1.21, 0.0204, "NVFP4-W+FP8-A r64", _PURPLE),
    (0.41, 0.0297, "INT4 W4A8 r64", _GREEN),
    (0.41, 0.0610, "INT4 W4A4 r128", _BLUE),
]
# Load the measured benchmark results. The file is opened in a `with` block so
# the handle is closed deterministically (the previous bare `open()` was never
# closed), and the encoding is pinned so parsing does not depend on the locale.
with open("outputs/nvfp4/benchmark.json", encoding="utf-8") as _benchmark_file:
    b = json.load(_benchmark_file)

# Index each run's records by (batch, res) so the plotting functions can look
# up a configuration directly, e.g. bf[(1, "512x512")]["total_s"].
bf = {(r["batch"], r["res"]): r for r in b["bf16"]}
nv = {(r["batch"], r["res"]): r for r in b["nvfp4_fused"]}
def fig(name, w=5, h=4):
    """Create a one-axes figure and the path it should be saved to.

    Args:
        name: File stem of the panel; the output path is ``{OUT}/{name}.png``.
        w: Figure width passed to ``figsize``.
        h: Figure height passed to ``figsize``.

    Returns:
        A ``(figure, axes, path)`` tuple.
    """
    figure, axes = plt.subplots(figsize=(w, h))
    target = f"{OUT}/{name}.png"
    return figure, axes, target
# 1) quality vs rank by format
def plot_quality(ax):
    """Draw one eval-loss-vs-rank line per format in QUAL onto *ax*.

    Each line is coloured via COL and labelled with its format name; ranks are
    plotted in ascending order.
    """
    for label, loss_by_rank in QUAL.items():
        points = sorted(loss_by_rank.items())  # ascending by rank
        ranks = [rank for rank, _ in points]
        losses = [loss for _, loss in points]
        ax.plot(ranks, losses, "o-", color=COL[label], label=label, lw=1.8, ms=5)

    ax.set_xlabel("low-rank branch rank")
    ax.set_ylabel("eval-loss (vel) ↓")
    ax.set_title("Quality vs rank — NVFP4 ≫ INT4 at matched bits")
    ax.legend(fontsize=7)
    ax.set_xticks([16, 32, 64, 128])
# 2) Pareto: quality vs realized Blackwell speed
def plot_pareto(ax):
    """Scatter the PARETO points (speedup on x, eval loss on y) onto *ax*.

    Every point gets a text label next to it. A dashed vertical line marks the
    1.0x bf16 baseline and a faint green arrow points toward the
    high-speedup / low-loss corner.
    """
    for speedup, loss, label, colour in PARETO:
        ax.scatter(speedup, loss, s=70, color=colour, zorder=3, edgecolor="k", lw=0.5)
        ax.annotate(
            label,
            (speedup, loss),
            fontsize=6.5,
            xytext=(4, 3),
            textcoords="offset points",
        )

    # bf16 reference line plus its caption.
    ax.axvline(1.0, color="gray", ls="--", lw=0.8)
    ax.text(1.02, 0.066, "bf16", color="gray", fontsize=7)

    ax.set_xlabel("realized per-layer speedup on Blackwell (×)")
    ax.set_ylabel("eval-loss ↓")
    ax.set_title("Pareto: quality vs speed (sm_120)\n↙ better")

    # Text-less annotation: only the arrow is drawn.
    arrow_style = dict(arrowstyle="->", color="green", alpha=.5)
    ax.annotate("", xy=(2.6, 0.012), xytext=(0.6, 0.058), arrowprops=arrow_style)
# 3) per-layer kernel speedup bars
def plot_perlayer(ax):
    """Draw one bar per PERLAYER entry showing its speedup vs bf16 on *ax*.

    Each bar is captioned with its value, and a dashed line marks 1.0x.
    """
    names = list(PERLAYER)
    vals = list(PERLAYER.values())
    slots = range(len(names))
    # Colours are positional: one per PERLAYER entry, in insertion order.
    cols = ["gray", "#d62728", "#d62728", "#9467bd", "#1f77b4"]

    ax.bar(slots, vals, color=cols)
    ax.axhline(1.0, color="k", ls="--", lw=0.8)

    # Value caption slightly above the top of each bar.
    for slot, speedup in zip(slots, vals):
        ax.text(slot, speedup + .04, f"{speedup:.2f}×", ha="center", fontsize=7)

    ax.set_xticks(slots)
    ax.set_xticklabels(names, rotation=25, ha="right", fontsize=6.5)
    ax.set_ylabel("speedup vs bf16 (×)")
    ax.set_title("Per-layer GEMM speedup (real kernels)")
# 4) end-to-end bf16 vs NVFP4-fused
def plot_e2e(ax):
    """Draw grouped bars of batch-1 ``total_s`` for bf16 and NVFP4-fused on *ax*.

    One group per resolution; above each group the bf16/NVFP4 time ratio is
    printed as the speedup.
    """
    res = ["512x512", "1024x1024"]
    wdt = 0.35
    slots = range(len(res))
    bft = [bf[(1, r)]["total_s"] for r in res]
    nvt = [nv[(1, r)]["total_s"] for r in res]

    # Side-by-side bars: bf16 shifted left, NVFP4 shifted right of the slot.
    left_edges = [i - wdt/2 for i in slots]
    right_edges = [i + wdt/2 for i in slots]
    ax.bar(left_edges, bft, wdt, label="bf16", color="gray")
    ax.bar(right_edges, nvt, wdt, label="NVFP4-fused", color="#d62728")

    # Speedup caption sits just above the taller bar of each pair.
    for i, (t_bf, t_nv) in enumerate(zip(bft, nvt)):
        ax.text(i, max(t_bf, t_nv) + .05, f"{t_bf/t_nv:.2f}×", ha="center", fontsize=8, color="#d62728")

    ax.set_xticks(list(slots))
    ax.set_xticklabels(res)
    ax.set_ylabel("s / image ↓")
    ax.set_title("End-to-end (batch 1, 4-step)")
    ax.legend(fontsize=7)
# 5) throughput vs batch (bf16 saturation)
def plot_batch(ax):
    """Plot bf16 ``img_per_s`` against batch size on *ax*, one line per resolution.

    Batch sizes 1, 2 and 4 are looked up in the bf16 benchmark records.
    """
    series = [("512x512", "o-"), ("1024x1024", "s-")]
    for resolution, marker in series:
        bs = [1, 2, 4]
        throughput = [bf[(batch, resolution)]["img_per_s"] for batch in bs]
        ax.plot(bs, throughput, marker, label=f"bf16 {resolution}", lw=1.6)

    ax.set_xlabel("batch size")
    ax.set_ylabel("img / s")
    ax.set_xticks([1, 2, 4])
    ax.set_title("Throughput vs batch — 4B is GPU-saturated")
    ax.legend(fontsize=7)
# 6) VRAM
def plot_vram(ax):
    """Draw grouped bars of batch-1 ``vram_gb`` for bf16 and NVFP4-fused on *ax*.

    One group per resolution.
    """
    res = ["512x512", "1024x1024"]
    wdt = 0.35
    slots = range(len(res))
    bf_vram = [bf[(1, r)]["vram_gb"] for r in res]
    nv_vram = [nv[(1, r)]["vram_gb"] for r in res]

    ax.bar([i - wdt/2 for i in slots], bf_vram, wdt, label="bf16", color="gray")
    ax.bar([i + wdt/2 for i in slots], nv_vram, wdt, label="NVFP4-fused", color="#d62728")

    ax.set_xticks(list(slots))
    ax.set_xticklabels(res)
    ax.set_ylabel("peak VRAM (GB) ↓")
    # NOTE(review): the percentage in the title is a literal, not derived from
    # the benchmark data - confirm it still matches benchmark.json.
    ax.set_title("VRAM (−24%)")
    ax.legend(fontsize=7)
# (output file stem, drawing function) for every panel, in grid order
# (row-major across the 2x3 overview).
panels = [
    ("quality_vs_rank", plot_quality),
    ("pareto_quality_speed", plot_pareto),
    ("perlayer_speedup", plot_perlayer),
    ("end_to_end", plot_e2e),
    ("throughput_batch", plot_batch),
    ("vram", plot_vram),
]

# Pass 1: render each panel as its own PNG.
for stem, draw in panels:
    figure, axis, path = fig(stem)
    draw(axis)
    figure.tight_layout()
    figure.savefig(path)
    plt.close(figure)
    print("saved", path)

# grand 2x3 grid
# Pass 2: redraw all six panels onto one shared figure.
F, axes = plt.subplots(2, 3, figsize=(15, 9))
for (_, draw), axis in zip(panels, axes.flat):
    draw(axis)
F.suptitle("FLUX.2 klein-4B — NVFP4 SVDQuant on RTX PRO 4500 Blackwell (sm_120): quality + speed", fontsize=13, y=1.00)
F.tight_layout()
F.savefig(f"{OUT}/GRID_overview.png", bbox_inches="tight")
plt.close(F)
print(f"saved {OUT}/GRID_overview.png")
Xet Storage Details
- Size:
- 5.84 kB
- Xet hash:
- 666d3a08bd0526c38630f31efc4a8f8f77c27168d49e1f63d4db080b4838a90b
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.