Buckets:

Mercity/FluxDistill / scripts /42_h2h_figures.py
Pranav2748's picture
download
raw
5.06 kB
"""Head-to-head figures + sensitive-probe montage.
Reads outputs/eval/h2h/metrics.json (quality) + outputs/nvfp4/benchmark_headtohead.json (speed) and
writes report/figures/h2h_*.png. The montage reads the probe list from outputs/eval/probes.json and
each model's images from outputs/eval/h2h_probes/<model>/<idx:05d>.png (e.g. 00000.png).
Usage: python3 scripts/42_h2h_figures.py
"""
import os, json
import matplotlib; matplotlib.use("Agg")
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw
# Output directory for every figure this script writes.
OUT = "report/figures"
os.makedirs(OUT, exist_ok=True)
# Shared figure style: compact font, light grid drawn beneath the plotted artists.
plt.rcParams.update({"figure.dpi": 130, "font.size": 9, "axes.grid": True, "grid.alpha": 0.3, "axes.axisbelow": True})
# Quality metrics keyed by eval dir-key; get() reads each value with .get(metric), so each is
# presumably a dict of metric name -> number (TODO confirm against the eval script).
# Read via a context manager so the handle is closed, and as UTF-8 (the JSON encoding).
with open("outputs/eval/h2h/metrics.json", encoding="utf-8") as _fh:
    M = json.load(_fh)
# Runs to plot, as (metrics.json dir-key, display label, colour), in plotting order:
# ours real NVFP4 r128, plain NVFP4 r0, ours r128 fake-q, BFL fp8.
# There is no teacher entry: every quality chart compares these runs *against* the teacher.
# Dir-keys are matched against metrics.json keys by get()/present (substring fallback).
LABEL = [
    ("ours128", "ours\nNVFP4 r128", "#d62728"),
    ("fq0", "plain\nNVFP4 r0", "#1f77b4"),
    ("fq128", "ours r128\n(fake-q)", "#ff9896"),
    ("bfl_fp8", "BFL fp8\n(W8A8)", "#2ca02c"),
]
def get(key, metric, metrics=None):
    """Return ``metric`` for the run whose metrics.json key matches ``key``.

    An exact key match wins. Otherwise the first key (in dict order) that
    contains ``key`` as a substring is used, so longer dir-keys embedding the
    short key still resolve. Returns None when no key matches or the matched
    run has no ``metric``.

    Args:
        key: short run identifier, e.g. "fq0".
        metric: metric name to read from the matched run.
        metrics: mapping of dir-key -> per-run metric dict; defaults to the
            module-level ``M`` loaded from metrics.json.
    """
    if metrics is None:
        metrics = M
    # Prefer the exact key: a pure substring scan could return a different run
    # whose name merely contains ``key`` and happens to come first.
    if key in metrics:
        return metrics[key].get(metric)
    for name, run in metrics.items():
        if key in name:
            return run.get(metric)
    return None
# LABEL entries that have at least one metrics.json key containing their dir-key, in LABEL order.
present = list(filter(lambda entry: any(entry[0] in mk for mk in M), LABEL))
def bar(metric, title, fname, lower_better=True):
    """Save a bar chart of ``metric`` for every run in ``present`` to ``OUT/<fname>.png``.

    Bars follow ``present`` order and colours, each annotated with its value.
    The chart is skipped (with a message) when no run is present or any
    present run lacks ``metric``.

    Args:
        metric: metric name read via get(); also used as the y-axis label.
        title: base title; a hint on which direction is closer is appended.
        fname: output file stem inside OUT.
        lower_better: chooses the title hint ("lower" vs "higher" = closer).
    """
    if not present:
        # Without this guard an empty axis would be saved under a metric name.
        print(f"skip {fname}: no runs present")
        return
    xs = [lbl for _, lbl, _ in present]
    ys = [get(k, metric) for k, _, _ in present]
    cs = [c for _, _, c in present]
    if any(y is None for y in ys):
        print(f"skip {fname}: missing {metric}")
        return
    f, ax = plt.subplots(figsize=(5, 3.4))
    try:
        ax.bar(xs, ys, color=cs)
        for i, y in enumerate(ys):
            # values below 10 get 3 decimals, larger values 1 decimal
            ax.text(i, y, f"{y:.3f}" if y < 10 else f"{y:.1f}", ha="center", va="bottom", fontsize=8)
        hint = " (lower=closer to teacher)" if lower_better else " (higher=closer)"
        ax.set_title(title + hint)
        ax.set_ylabel(metric)
        f.tight_layout()
        f.savefig(f"{OUT}/{fname}.png")
    finally:
        # always release the figure, even if drawing/saving fails
        plt.close(f)
    print("saved", fname)
# Quality bars vs the teacher; PSNR is the only one where a higher value means closer.
bar(metric="LPIPS", title="LPIPS vs teacher", fname="h2h_lpips")
bar(metric="PSNR", title="PSNR vs teacher", fname="h2h_psnr", lower_better=False)
bar(metric="FID_vs_teacher", title="FID vs teacher", fname="h2h_fid_teacher", lower_better=True)
# Pareto: fidelity (PSNR vs teacher) vs Blackwell end-to-end speedup.
# Best-effort: any failure (missing benchmark file or key) only skips this figure.
try:
    with open("outputs/nvfp4/benchmark_headtohead.json", encoding="utf-8") as fh:
        spd = json.load(fh)
    e2e = spd["end_to_end"]["512px_4step_bs1"]
    # (speedup vs bf16, PSNR vs teacher, label, colour).
    # NOTE(review): plain r0 is pinned at speedup 1.0, presumably because the fq0 run is
    # fake-quantized and therefore executes at bf16 speed — confirm this is intended.
    pts = [
        (1.0, get("fq0", "PSNR"), "plain r0", "#1f77b4"),
        (e2e["ours_nvfp4_r128_fused"]["speedup"], get("ours128", "PSNR"), "ours r128 (real)", "#d62728"),
    ]
    pts = [p for p in pts if p[1] is not None]
    if not pts:
        print("pareto skipped: no PSNR values")
    else:
        f, ax = plt.subplots(figsize=(5, 3.6))
        try:
            for x, y, label, color in pts:
                ax.scatter([x], [y], c=color, s=70)
                ax.annotate(label, (x, y), textcoords="offset points", xytext=(6, 4), fontsize=8)
            ax.set_xlabel("end-to-end speedup vs bf16 (512px, Blackwell)")
            ax.set_ylabel("PSNR vs teacher (higher=closer)")
            ax.set_title("Quality (fidelity) vs speed")
            f.tight_layout()
            f.savefig(f"{OUT}/h2h_pareto.png")
        finally:
            plt.close(f)
        print("saved h2h_pareto")
except Exception as e:
    print("pareto skipped:", e)
# Speed + VRAM bars: bf16 teacher vs ours (NVFP4 r128 fused) at each benchmarked resolution.
# Best-effort: a missing file or key prints a message and skips the remaining figures here.
try:
    with open("outputs/nvfp4/benchmark_headtohead.json", encoding="utf-8") as fh:
        spd = json.load(fh)["end_to_end"]
    res = ["512px_4step_bs1", "1024px_4step_bs1"]
    # tick label = resolution prefix of each benchmark key, e.g. "512px_4step_bs1" -> "512px"
    ticks = [r.split("_", 1)[0] for r in res]
    w = 0.35  # bar width; teacher and ours sit side by side around each tick
    for metric, fname, ttl in [("s_img", "h2h_speed", "s/img (lower=faster)"), ("vram_gb", "h2h_vram", "peak VRAM GB")]:
        # Look values up before creating the figure so a missing key cannot leak an open figure.
        tv = [spd[r]["teacher_bf16"][metric] for r in res]
        ov = [spd[r]["ours_nvfp4_r128_fused"][metric] for r in res]
        x = range(len(res))
        f, ax = plt.subplots(figsize=(5, 3.4))
        try:
            ax.bar([i - w / 2 for i in x], tv, w, label="bf16 teacher", color="#7f7f7f")
            ax.bar([i + w / 2 for i in x], ov, w, label="ours NVFP4 r128", color="#d62728")
            ax.set_xticks(list(x))
            ax.set_xticklabels(ticks)
            ax.set_title(ttl)
            ax.legend()
            f.tight_layout()
            f.savefig(f"{OUT}/{fname}.png")
        finally:
            plt.close(f)
        print("saved", fname)
except Exception as e:
    print("speed/vram skipped:", e)
# Montage: teacher | ours r128 | plain r0 | BFL fp8 on the sensitive probes.
# One row per entry of outputs/eval/probes.json, one column per model whose probe dir exists.
cols = [("teacher", "outputs/eval/h2h_probes/teacher"),
        ("ours NVFP4 r128", "outputs/eval/h2h_probes/ours128"),
        ("plain NVFP4 r0", "outputs/eval/h2h_probes/fq0"),
        ("BFL fp8 (W8A8)", "outputs/eval/h2h_probes/bfl_fp8")]
cols = [(lbl, d) for lbl, d in cols if os.path.isdir(d)]
if cols:
    # Loaded only when there is something to draw: loading it unconditionally crashed the
    # script when probes.json was absent, instead of reaching the "skipped" branch below.
    with open("outputs/eval/probes.json", encoding="utf-8") as fh:
        probes = json.load(fh)
    S, pad, top, lft = 320, 6, 24, 110  # tile size, gap, header height, left label margin (px)
    W = lft + len(cols) * (S + pad) + pad
    H = top + len(probes) * (S + pad) + pad
    canvas = Image.new("RGB", (W, H), "white")
    dr = ImageDraw.Draw(canvas)
    for j, (lbl, _) in enumerate(cols):
        dr.text((lft + j * (S + pad) + 4, 6), lbl, fill="black")
    missing = 0
    for i, p in enumerate(probes):
        y0 = top + i * (S + pad)
        dr.text((4, y0 + S // 2), p["category"], fill="black")
        for j, (_, d) in enumerate(cols):
            fp = os.path.join(d, f"{p['idx']:05d}.png")
            if not os.path.exists(fp):
                missing += 1  # tile stays blank; reported after saving
                continue
            # context manager closes the file handle PIL keeps open for lazy loading
            with Image.open(fp) as im:
                tile = im.convert("RGB").resize((S, S))
            canvas.paste(tile, (lft + j * (S + pad), y0))
    canvas.save(f"{OUT}/h2h_montage.png")
    print("saved h2h_montage")
    if missing:
        print(f"h2h_montage: {missing} probe image(s) missing, left blank")
else:
    print("montage skipped: no probe dirs yet")

Xet Storage Details

Size:
5.06 kB
·
Xet hash:
d339027f9ff25f70494b468a98ebde2610a0b2cc86a8ab33fc1d2cb96f257ef5

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.