Buckets:
"""Head-to-head figures + sensitive-probe montage.
Reads outputs/eval/h2h/metrics.json (quality) + outputs/nvfp4/benchmark_headtohead.json (speed) and
writes report/figures/h2h_*.png. The montage reads the probe list from outputs/eval/probes.json and
one image per probe from outputs/eval/h2h_probes/<model>/<idx:05d>.png (one row per probe).
Usage: python3 scripts/42_h2h_figures.py
"""
| import os, json | |
| import matplotlib; matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| from PIL import Image, ImageDraw | |
# Output directory for every figure this script writes.
OUT = "report/figures"
os.makedirs(OUT, exist_ok=True)
plt.rcParams.update({"figure.dpi": 130, "font.size": 9, "axes.grid": True, "grid.alpha": 0.3, "axes.axisbelow": True})
# Quality metrics keyed by per-model dir-key; each value is read via `.get(metric)` below,
# so presumably {dir_key: {metric_name: value}} — confirm against the eval script.
with open("outputs/eval/h2h/metrics.json") as fh:
    M = json.load(fh)
# Map a substring of a metrics.json dir-key -> (display label, bar colour).
# Bars are drawn in this list's order: ours r128 (real), plain r0, ours r128 fake-quant, BFL fp8.
# There is no teacher entry: every quality metric plotted here is measured "vs teacher".
LABEL = [
    ("ours128", "ours\nNVFP4 r128", "#d62728"),
    ("fq0", "plain\nNVFP4 r0", "#1f77b4"),
    ("fq128", "ours r128\n(fake-q)", "#ff9896"),
    ("bfl_fp8", "BFL fp8\n(W8A8)", "#2ca02c"),
]
def get(key, metric, metrics=None):
    """Return ``metric`` for the model whose metrics dir-key matches ``key``.

    An exact dir-key match wins. Otherwise the first dir-key (in dict order)
    that contains ``key`` as a substring is used. This keeps the old behaviour
    for keys like "outputs/eval/h2h/fq0", while an exact key can no longer be
    shadowed by an earlier key that merely contains it.

    Args:
        key: short model key, e.g. "ours128" or "fq0" (see LABEL).
        metric: metric name inside that model's entry, e.g. "PSNR".
        metrics: mapping of dir-key -> metric dict to search; defaults to the
            module-level ``M`` loaded from metrics.json.

    Returns:
        The metric value, or None when no dir-key matches or the metric is absent.
    """
    if metrics is None:
        metrics = M
    if key in metrics:
        return metrics[key].get(metric)
    for dir_key, entry in metrics.items():
        if key in dir_key:
            return entry.get(metric)
    return None
# LABEL entries that have at least one matching dir-key in metrics.json, in LABEL order.
present = [
    (short_key, label_text, colour)
    for short_key, label_text, colour in LABEL
    if any(short_key in metrics_key for metrics_key in M)
]
def bar(metric, title, fname, lower_better=True, models=None):
    """Save a bar chart of one quality metric across the available models.

    Args:
        metric: metric name looked up with ``get`` (e.g. "LPIPS").
        title: plot title; a direction hint is appended based on ``lower_better``.
        fname: output file stem, written to ``{OUT}/{fname}.png``.
        lower_better: True when smaller values mean closer to the teacher.
        models: (key, label, colour) triples to plot; defaults to ``present``.

    Prints a message and returns without plotting when there are no models
    or when any model is missing ``metric``.
    """
    if models is None:
        models = present
    if not models:
        # Nothing matched in metrics.json: don't write an empty chart.
        print(f"skip {fname}: no models present")
        return
    xs = [lbl for _, lbl, _ in models]
    ys = [get(k, metric) for k, _, _ in models]
    cs = [c for _, _, c in models]
    if any(y is None for y in ys):
        print(f"skip {fname}: missing {metric}")
        return
    f, ax = plt.subplots(figsize=(5, 3.4))
    ax.bar(xs, ys, color=cs)
    for i, y in enumerate(ys):
        # Value label on top of each bar: 3 decimals below 10, otherwise 1.
        ax.text(i, y, f"{y:.3f}" if y < 10 else f"{y:.1f}", ha="center", va="bottom", fontsize=8)
    ax.set_title(title + (" (lower=closer to teacher)" if lower_better else " (higher=closer)"))
    ax.set_ylabel(metric)
    f.tight_layout()
    f.savefig(f"{OUT}/{fname}.png")
    plt.close(f)
    print("saved", fname)
# One chart per quality metric: (metric key, title, output stem, lower_better).
_QUALITY_BARS = (
    ("LPIPS", "LPIPS vs teacher", "h2h_lpips", True),
    ("PSNR", "PSNR vs teacher", "h2h_psnr", False),
    ("FID_vs_teacher", "FID vs teacher", "h2h_fid_teacher", True),
)
for _spec in _QUALITY_BARS:
    bar(*_spec)
# Pareto: fidelity (PSNR vs teacher) vs Blackwell end-to-end speedup.
# Best-effort: any failure (missing benchmark file, missing key, ...) prints a
# message and skips only this figure.
try:
    with open("outputs/nvfp4/benchmark_headtohead.json") as fh:
        spd = json.load(fh)
    e2e = spd["end_to_end"]["512px_4step_bs1"]
    pts = [
        # NOTE(review): plain r0 is placed at x=1.0 (the bf16 baseline), not at a
        # measured speedup — confirm this is the intended reference position.
        (1.0, get("fq0", "PSNR"), "plain r0", "#1f77b4"),
        (e2e["ours_nvfp4_r128_fused"]["speedup"], get("ours128", "PSNR"), "ours r128 (real)", "#d62728"),
    ]
    # Models without a PSNR entry in metrics.json cannot be placed on the y-axis.
    pts = [pt for pt in pts if pt[1] is not None]
    if not pts:
        print("pareto skipped: no PSNR values")
    else:
        f, ax = plt.subplots(figsize=(5, 3.6))
        for x, y, l, c in pts:
            ax.scatter([x], [y], c=c, s=70)
            ax.annotate(l, (x, y), textcoords="offset points", xytext=(6, 4), fontsize=8)
        ax.set_xlabel("end-to-end speedup vs bf16 (512px, Blackwell)")
        ax.set_ylabel("PSNR vs teacher (higher=closer)")
        ax.set_title("Quality (fidelity) vs speed")
        f.tight_layout()
        f.savefig(f"{OUT}/h2h_pareto.png")
        plt.close(f)
        print("saved h2h_pareto")
except Exception as e:
    print("pareto skipped:", e)
# Speed + VRAM bars (bf16 teacher vs ours NVFP4 r128) at each benchmarked resolution.
# Best-effort: a missing benchmark file or key prints a message and skips the
# remaining figures of this section.
try:
    with open("outputs/nvfp4/benchmark_headtohead.json") as fh:
        spd = json.load(fh)["end_to_end"]
    res = ["512px_4step_bs1", "1024px_4step_bs1"]
    # Tick labels are derived from the benchmark keys ("512px_4step_bs1" -> "512px"),
    # so adding a resolution to `res` needs no second edit.
    ticks = [r.split("_", 1)[0] for r in res]
    x = range(len(res))
    w = 0.35  # width of each of the two side-by-side bars per resolution
    for metric, fname, ttl in [("s_img", "h2h_speed", "s/img (lower=faster)"), ("vram_gb", "h2h_vram", "peak VRAM GB")]:
        # Read the values before creating the figure, so a missing key cannot
        # leave an unclosed figure behind.
        tv = [spd[r]["teacher_bf16"][metric] for r in res]
        ov = [spd[r]["ours_nvfp4_r128_fused"][metric] for r in res]
        f, ax = plt.subplots(figsize=(5, 3.4))
        ax.bar([i - w / 2 for i in x], tv, w, label="bf16 teacher", color="#7f7f7f")
        ax.bar([i + w / 2 for i in x], ov, w, label="ours NVFP4 r128", color="#d62728")
        ax.set_xticks(list(x))
        ax.set_xticklabels(ticks)
        ax.set_title(ttl)
        ax.legend()
        f.tight_layout()
        f.savefig(f"{OUT}/{fname}.png")
        plt.close(f)
        print("saved", fname)
except Exception as e:
    print("speed/vram skipped:", e)
# Montage: columns teacher | ours r128 | plain r0 | BFL fp8, one row per probe listed
# in outputs/eval/probes.json. Columns whose probe directory does not exist are dropped;
# a missing individual image leaves its tile blank.
cols = [("teacher", "outputs/eval/h2h_probes/teacher"),
        ("ours NVFP4 r128", "outputs/eval/h2h_probes/ours128"),
        ("plain NVFP4 r0", "outputs/eval/h2h_probes/fq0"),
        ("BFL fp8 (W8A8)", "outputs/eval/h2h_probes/bfl_fp8")]
cols = [(l, d) for l, d in cols if os.path.isdir(d)]
probes_path = "outputs/eval/probes.json"
if not cols:
    print("montage skipped: no probe dirs yet")
elif not os.path.exists(probes_path):
    # Load the probe list only when there is something to draw; previously a
    # missing probes.json crashed the script even with no probe dirs.
    print(f"montage skipped: {probes_path} not found")
else:
    with open(probes_path) as fh:
        probes = json.load(fh)
    # Tile edge, gutter, header-row height, label-column width (all in px).
    S, pad, top, lft = 320, 6, 24, 110
    W = lft + len(cols) * (S + pad) + pad
    H = top + len(probes) * (S + pad) + pad
    canvas = Image.new("RGB", (W, H), "white")
    dr = ImageDraw.Draw(canvas)
    for j, (lbl, _) in enumerate(cols):
        dr.text((lft + j * (S + pad) + 4, 6), lbl, fill="black")
    for i, p in enumerate(probes):
        y0 = top + i * (S + pad)
        dr.text((4, y0 + S // 2), p["category"], fill="black")
        for j, (_, d) in enumerate(cols):
            fp = os.path.join(d, f"{p['idx']:05d}.png")
            if os.path.exists(fp):
                # Close each source image once its resized copy is made.
                with Image.open(fp) as im:
                    tile = im.convert("RGB").resize((S, S))
                canvas.paste(tile, (lft + j * (S + pad), y0))
    canvas.save(f"{OUT}/h2h_montage.png")
    print("saved h2h_montage")
Xet Storage Details
- Size:
- 5.06 kB
- Xet hash:
- d339027f9ff25f70494b468a98ebde2610a0b2cc86a8ab33fc1d2cb96f257ef5
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.