Buckets:

Mercity/FluxDistill / scripts /make_nvfp4_figures.py
Pranav2748's picture
download
raw
5.84 kB
"""Generate all report figures for the NVFP4 campaign -> report/figures/*.png + a 2x3 grid."""
import os, json
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
# Directory that receives every PNG written by this script; created at import
# time because all later savefig calls write beneath it.
OUT = "report/figures"
os.makedirs(OUT, exist_ok=True)

# Shared matplotlib defaults applied to every figure produced below.
plt.rcParams.update(
    {
        "figure.dpi": 130,
        "font.size": 9,
        "axes.grid": True,
        "grid.alpha": 0.3,
        "axes.axisbelow": True,
    }
)
# ---------------- data (2026-06-13 axis; held-out velocity loss, lower=better) ----------------
# Eval loss per quantization format: format -> {rank: eval_loss}.
# Insertion order is the order the lines are drawn and listed in the legend.
QUAL = {
    "NVFP4 W4A4 (g16)": {32: 0.0390, 64: 0.0364, 128: 0.0303},
    "NVFP4-W + FP8-A": {64: 0.0204, 128: 0.0169},
    "INT4 W4A4 (g64)": {64: 0.0742, 128: 0.0610},
    "INT4 W4A8 (smooth=0)": {16: 0.0348, 32: 0.0331, 64: 0.0297},
    "INT4 W4A8 (alpha=0.5)": {16: 0.0405, 32: 0.0362, 64: 0.0325},
}

# Line/marker color for each format key of QUAL.
COL = {
    "NVFP4 W4A4 (g16)": "#d62728",
    "NVFP4-W + FP8-A": "#9467bd",
    "INT4 W4A4 (g64)": "#1f77b4",
    "INT4 W4A8 (smooth=0)": "#2ca02c",
    "INT4 W4A8 (alpha=0.5)": "#7f9f7f",
}

# per-layer kernel speedup on Blackwell (x vs bf16), klein-4B shapes T=1536
PERLAYER = {
    "bf16": 1.0,
    "NVFP4 W4A4 r64": 2.75,
    "NVFP4 W4A4 r128": 2.49,
    "FP8 (W4+A8) r64/128": 1.21,
    "INT4 (any) on sm_120": 0.41,
}

# Pareto points: (blackwell per-layer speedup, eval_loss, label, color)
# Colors are looked up in COL so each point matches its format's line color.
PARETO = [
    (1.00, 0.0010, "bf16 (teacher)", "k"),
    (2.49, 0.0303, "NVFP4 W4A4 r128", COL["NVFP4 W4A4 (g16)"]),
    (2.75, 0.0364, "NVFP4 W4A4 r64", COL["NVFP4 W4A4 (g16)"]),
    (1.21, 0.0169, "NVFP4-W+FP8-A r128", COL["NVFP4-W + FP8-A"]),
    (1.21, 0.0204, "NVFP4-W+FP8-A r64", COL["NVFP4-W + FP8-A"]),
    (0.41, 0.0297, "INT4 W4A8 r64", COL["INT4 W4A8 (smooth=0)"]),
    (0.41, 0.0610, "INT4 W4A4 r128", COL["INT4 W4A4 (g64)"]),
]
# Load the benchmark results. The file is opened in a `with` block so the
# handle is closed deterministically instead of being left to the garbage
# collector, and the encoding is pinned so decoding does not depend on the
# platform default.
with open("outputs/nvfp4/benchmark.json", encoding="utf-8") as _fh:
    b = json.load(_fh)

# Index each run's rows by (batch, resolution) so the plotting functions can
# look up a configuration directly, e.g. bf[(1, "512x512")].
bf = {(r["batch"], r["res"]): r for r in b["bf16"]}
nv = {(r["batch"], r["res"]): r for r in b["nvfp4_fused"]}
def fig(name, w=5, h=4):
    """Create a single-axes figure together with its output path.

    Args:
        name: File stem of the figure (no directory, no extension).
        w: Figure width, forwarded as the first element of ``figsize``.
        h: Figure height, forwarded as the second element of ``figsize``.

    Returns:
        A ``(figure, axes, path)`` tuple where ``path`` is ``{OUT}/{name}.png``.
    """
    figure, axes = plt.subplots(figsize=(w, h))
    target = f"{OUT}/{name}.png"
    return figure, axes, target
# 1) quality vs rank by format
def plot_quality(ax):
    """Plot eval loss against low-rank branch rank, one line per format.

    Args:
        ax: Matplotlib axes to draw on; it is modified in place.
    """
    for label, loss_by_rank in QUAL.items():
        # Ranks are sorted so each line runs left to right.
        ranks = sorted(loss_by_rank)
        losses = [loss_by_rank[rank] for rank in ranks]
        ax.plot(ranks, losses, "o-", color=COL[label], label=label, lw=1.8, ms=5)
    ax.set_xlabel("low-rank branch rank")
    ax.set_ylabel("eval-loss (vel) ↓")
    ax.set_title("Quality vs rank — NVFP4 ≫ INT4 at matched bits")
    ax.legend(fontsize=7)
    ax.set_xticks([16, 32, 64, 128])
# 2) Pareto: quality vs realized Blackwell speed
def plot_pareto(ax):
    """Scatter eval loss against per-layer speedup for every PARETO point.

    Each point is labelled with its configuration name, a dashed vertical
    line marks the bf16 baseline at 1.0x, and a green arrow indicates the
    direction of improvement.

    Args:
        ax: Matplotlib axes to draw on; it is modified in place.
    """
    for speedup, loss, label, color in PARETO:
        ax.scatter(speedup, loss, s=70, color=color, zorder=3, edgecolor="k", lw=0.5)
        ax.annotate(label, (speedup, loss), fontsize=6.5, xytext=(4, 3),
                    textcoords="offset points")
    ax.axvline(1.0, color="gray", ls="--", lw=0.8)
    ax.text(1.02, 0.066, "bf16", color="gray", fontsize=7)
    ax.set_xlabel("realized per-layer speedup on Blackwell (×)")
    ax.set_ylabel("eval-loss ↓")
    # Better means more speedup (right) and less loss (down), so the title
    # glyph must point lower-right to agree with the arrow drawn below.
    ax.set_title("Pareto: quality vs speed (sm_120)\n↘ better")
    # The arrow runs from xytext (tail) to xy (head): low speed / high loss
    # towards high speed / low loss.
    ax.annotate("", xy=(2.6, 0.012), xytext=(0.6, 0.058),
                arrowprops=dict(arrowstyle="->", color="green", alpha=.5))
# 3) per-layer kernel speedup bars
def plot_perlayer(ax):
    """Draw one bar per PERLAYER entry showing its speedup relative to bf16.

    Args:
        ax: Matplotlib axes to draw on; it is modified in place.
    """
    labels = list(PERLAYER)
    speedups = [PERLAYER[label] for label in labels]
    # One color per bar, positionally matched to PERLAYER's insertion order.
    bar_colors = ["gray", "#d62728", "#d62728", "#9467bd", "#1f77b4"]
    slots = range(len(labels))

    ax.bar(slots, speedups, color=bar_colors)
    # Dashed reference line at the bf16 baseline (1.0x).
    ax.axhline(1.0, color="k", ls="--", lw=0.8)
    for slot, value in enumerate(speedups):
        # Numeric label placed slightly above the top of each bar.
        ax.text(slot, value + .04, f"{value:.2f}×", ha="center", fontsize=7)
    ax.set_xticks(slots)
    ax.set_xticklabels(labels, rotation=25, ha="right", fontsize=6.5)
    ax.set_ylabel("speedup vs bf16 (×)")
    ax.set_title("Per-layer GEMM speedup (real kernels)")
# 4) end-to-end bf16 vs NVFP4-fused
def plot_e2e(ax):
    """Compare batch-1 total time of bf16 and NVFP4-fused at two resolutions.

    Draws grouped bars of ``total_s`` for each resolution and writes the
    bf16 / NVFP4 time ratio above each group.

    Args:
        ax: Matplotlib axes to draw on; it is modified in place.
    """
    resolutions = ["512x512", "1024x1024"]
    slots = range(len(resolutions))
    bar_w = 0.35
    bf16_secs = [bf[(1, res)]["total_s"] for res in resolutions]
    nvfp4_secs = [nv[(1, res)]["total_s"] for res in resolutions]

    # Two bars per resolution, offset half a bar width either side of the tick.
    ax.bar([i - bar_w/2 for i in slots], bf16_secs, bar_w, label="bf16", color="gray")
    ax.bar([i + bar_w/2 for i in slots], nvfp4_secs, bar_w, label="NVFP4-fused", color="#d62728")
    for i, (t_bf16, t_nvfp4) in enumerate(zip(bf16_secs, nvfp4_secs)):
        # Ratio label sits just above the taller bar of the pair.
        ax.text(i, max(t_bf16, t_nvfp4) + .05, f"{t_bf16/t_nvfp4:.2f}×",
                ha="center", fontsize=8, color="#d62728")
    ax.set_xticks(list(slots))
    ax.set_xticklabels(resolutions)
    ax.set_ylabel("s / image ↓")
    ax.set_title("End-to-end (batch 1, 4-step)")
    ax.legend(fontsize=7)
# 5) throughput vs batch (bf16 saturation)
def plot_batch(ax):
    """Plot bf16 images-per-second against batch size for two resolutions.

    Args:
        ax: Matplotlib axes to draw on; it is modified in place.
    """
    batch_sizes = [1, 2, 4]
    series = (("512x512", "o-"), ("1024x1024", "s-"))
    for resolution, line_style in series:
        rates = [bf[(size, resolution)]["img_per_s"] for size in batch_sizes]
        ax.plot(batch_sizes, rates, line_style, label=f"bf16 {resolution}", lw=1.6)
    ax.set_xlabel("batch size")
    ax.set_ylabel("img / s")
    ax.set_xticks([1, 2, 4])
    ax.set_title("Throughput vs batch — 4B is GPU-saturated")
    ax.legend(fontsize=7)
# 6) VRAM
def plot_vram(ax):
    """Draw grouped bars of batch-1 peak VRAM for bf16 and NVFP4-fused.

    Args:
        ax: Matplotlib axes to draw on; it is modified in place.
    """
    resolutions = ["512x512", "1024x1024"]
    slots = range(len(resolutions))
    bar_w = 0.35
    bf16_gb = [bf[(1, res)]["vram_gb"] for res in resolutions]
    nvfp4_gb = [nv[(1, res)]["vram_gb"] for res in resolutions]

    # Two bars per resolution, offset half a bar width either side of the tick.
    ax.bar([i - bar_w/2 for i in slots], bf16_gb, bar_w, label="bf16", color="gray")
    ax.bar([i + bar_w/2 for i in slots], nvfp4_gb, bar_w, label="NVFP4-fused", color="#d62728")
    ax.set_xticks(list(slots))
    ax.set_xticklabels(resolutions)
    ax.set_ylabel("peak VRAM (GB) ↓")
    # NOTE(review): the percentage in the title is a fixed string, not derived
    # from the plotted values — confirm it still matches benchmark.json.
    ax.set_title("VRAM (−24%)")
    ax.legend(fontsize=7)
# (file stem, drawing function) for every panel, in grid order.
panels = [
    ("quality_vs_rank", plot_quality),
    ("pareto_quality_speed", plot_pareto),
    ("perlayer_speedup", plot_perlayer),
    ("end_to_end", plot_e2e),
    ("throughput_batch", plot_batch),
    ("vram", plot_vram),
]

# Render each panel as its own PNG; the figure is closed after saving so
# figures do not accumulate in memory.
for nm, fn in panels:
    f, ax, p = fig(nm)
    fn(ax)
    f.tight_layout()
    f.savefig(p)
    plt.close(f)
    print("saved", p)
# grand 2x3 grid
# Re-draw every panel into one shared 2x3 figure, pairing the panels with the
# axes in flattened (row-major) order.
F, axes = plt.subplots(2, 3, figsize=(15, 9))
for (nm, fn), ax in zip(panels, axes.flat):
    fn(ax)
F.suptitle(
    "FLUX.2 klein-4B — NVFP4 SVDQuant on RTX PRO 4500 Blackwell (sm_120): quality + speed",
    fontsize=13,
    y=1.00,
)
F.tight_layout()
F.savefig(f"{OUT}/GRID_overview.png", bbox_inches="tight")
plt.close(F)
print(f"saved {OUT}/GRID_overview.png")

Xet Storage Details

Size:
5.84 kB
·
Xet hash:
666d3a08bd0526c38630f31efc4a8f8f77c27168d49e1f63d4db080b4838a90b

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.