# === Global Variables ===
# Text sizes passed to matplotlib's fontsize arguments / rcParams.
BASE_FONTSIZE = 16
SMALL_FONTSIZE = 14
MEDIUM_FONTSIZE = 15
LARGE_FONTSIZE = 20
LEGEND_FONTSIZE = 17
ANNOTATION_FONTSIZE = 15

# Line and marker geometry. The reference-line and marker-edge widths are
# derived from the axis spine width so they scale together.
LINEWIDTH = 6
AXIS_LINEWIDTH = 1.0
GRID_LINEWIDTH = 0.5
HLINE_LINEWIDTH = 2 * AXIS_LINEWIDTH
MARKERSIZE = 10
MARKEREDGEWIDTH = 2 * AXIS_LINEWIDTH
# =========================

import csv
import os
import re

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import NullFormatter

# Global matplotlib defaults: serif text (Times New Roman, falling back to
# DejaVu Serif), Computer Modern math, shared tick/label/legend sizes, and a
# 100 dpi on-screen vs. 300 dpi saved resolution.
mpl.rcParams.update(
    {
        "font.family": "serif",
        "font.serif": ["Times New Roman", "DejaVu Serif"],
        "mathtext.fontset": "cm",
        "axes.linewidth": AXIS_LINEWIDTH,
        "axes.labelsize": BASE_FONTSIZE,
        "xtick.labelsize": MEDIUM_FONTSIZE,
        "ytick.labelsize": MEDIUM_FONTSIZE,
        "legend.fontsize": LEGEND_FONTSIZE,
        "figure.dpi": 100,
        "savefig.dpi": 300,
    }
)

# Directory containing this script; experiment folders and the output
# directory are resolved relative to it.
EXP_ROOT = os.path.dirname(os.path.abspath(__file__))

# Per-curve colors, indexed by a curve's position within its pair.
blue_color, red_color = "#3498db", "#e74c3c"
colors = [blue_color, red_color]

# EXPS maps exp_dir_name -> (ckpt, optimal_gh_string, kind).
# kind is "LDM-only" (variational, separate VAE) or "Joint" (e2e training).
# Experiment directory name -> (ckpt, optimal_gh_string, kind).
# ``ckpt`` and ``optimal_gh_string`` are matched verbatim against result CSV
# file names in load_curve; ``kind`` prefixes the curve's legend label.
EXPS = {
    "sit-xl-1-f16d32-e2e-novariational-logvar0.0": (
        "0250000",
        "0.35",
        "Joint",
    ),
    "sit-xl-1-f16d64-e2e-novariational-logvar0.0": (
        "0250000",
        "0.55",
        "Joint",
    ),
    "sit-xl-1-f16d32-ldm-imagenet256-f16d32-50e-250k": (
        "0250000",
        "0.75",
        "LDM-only",
    ),
    "sit-xl-1-f16d64-ldm-imagenet256-f16d64-50e-250k": (
        "0250000",
        "0.825",
        "LDM-only",
    ),
    "sit-xl-1-dinov2-b-enc8-repae-f16d32-scratch-0.5-0.5-400k": (
        "0400000",
        "0.8",
        "LDM-only",
    ),
    "sit-xl-1-dinov2-f16d32-e2e-novariational-logvar0.0-repae-coeff0.5": (
        "0400000",
        "0.525",
        "Joint",
    ),
}

# One entry per output figure: (pair_name, exp_list, flip_below, label_suffix).
# ``flip_below`` holds indices into ``exp_list`` whose best-FID label is drawn
# below its reference line instead of above it.
PAIRS = [
    ("repa", [
        "sit-xl-1-dinov2-b-enc8-repae-f16d32-scratch-0.5-0.5-400k",
        "sit-xl-1-dinov2-f16d32-e2e-novariational-logvar0.0-repae-coeff0.5",
    ], {1}, " + REPA"),
    ("f16d32", [
        "sit-xl-1-f16d32-e2e-novariational-logvar0.0",
        "sit-xl-1-f16d32-ldm-imagenet256-f16d32-50e-250k",
    ], set(), ""),
    ("f16d64", [
        "sit-xl-1-f16d64-e2e-novariational-logvar0.0",
        "sit-xl-1-f16d64-ldm-imagenet256-f16d64-50e-250k",
    ], set(), ""),
]


def load_curve(exp, root=None):
    """Return ``(cfgs, fids)`` for ``exp``, sorted by CFG scale.

    Only result CSVs named
    ``<exp>_<ckpt>_cfg<CFG>-0.0-<gh>-labelsampling-equal.csv`` are read, where
    ``ckpt`` and ``gh`` come from ``EXPS[exp]``. The FID is taken from the
    ``FID`` column of the first data row of each matching file.

    Args:
        exp: Experiment directory name; must be a key of ``EXPS``.
        root: Directory that contains the experiment folders. Defaults to
            ``EXP_ROOT``.

    Returns:
        Two parallel lists ``(cfgs, fids)`` of floats. Both are empty when the
        experiment folder is missing or holds no usable matching CSV.

    Raises:
        KeyError: if ``exp`` is not a key of ``EXPS``.
    """
    ckpt, gh, _kind = EXPS[exp]
    folder = os.path.join(EXP_ROOT if root is None else root, exp)
    if not os.path.isdir(folder):
        # Report "no data" instead of aborting the whole plotting run; the
        # caller already warns about and skips empty curves.
        return [], []

    pat = re.compile(
        rf"{re.escape(exp)}_{re.escape(ckpt)}_cfg(\d+(?:\.\d+)?)-0\.0-{re.escape(gh)}"
        r"-labelsampling-equal\.csv$"
    )
    rows = []
    for fn in os.listdir(folder):
        m = pat.match(fn)
        if m is None:
            continue
        path = os.path.join(folder, fn)
        # newline="" is what the csv module requires for correct parsing.
        with open(path, newline="", encoding="utf-8") as f:
            first = next(csv.DictReader(f), None)
        if first is None or "FID" not in first:
            continue
        try:
            fid = float(first["FID"])
        except (TypeError, ValueError):
            # A single malformed result file should not kill the run.
            print(f"[Warn] Unparseable FID in {path}: {first['FID']!r}")
            continue
        rows.append((float(m.group(1)), fid))

    rows.sort()
    return [cfg for cfg, _ in rows], [fid for _, fid in rows]


def _annotate_best(ax, cfgs, fids, color, below):
    """Mark a curve's lowest finite FID with a dashed line and a value label.

    Curves with no finite FID get no marking.
    """
    finite = np.isfinite(fids)
    if not finite.any():
        return
    # Ignore NaN/inf so a missing evaluation is never reported as the best.
    min_idx = int(np.nanargmin(np.where(finite, fids, np.nan)))
    best_cfg, best_fid = cfgs[min_idx], fids[min_idx]
    ax.axhline(
        y=best_fid,
        color=color,
        linestyle="--",
        linewidth=HLINE_LINEWIDTH,
        alpha=0.7,
    )
    # The label hangs just under the line (va="top") or sits just above it
    # (va="bottom"); callers choose per curve via ``below``.
    y, va = (best_fid * 0.985, "top") if below else (best_fid * 1.02, "bottom")
    ax.text(
        best_cfg,
        y,
        f"{best_fid:.2f}",
        ha="center",
        va=va,
        fontsize=ANNOTATION_FONTSIZE,
        color=color,
        zorder=5,
    )


def _style_axes(ax):
    """Apply the shared legend, grid, spine and axis-label styling."""
    ax.legend(
        fontsize=LEGEND_FONTSIZE,
        framealpha=0.95,
        loc="best",
        frameon=False,
    )
    ax.grid(True, alpha=0.3, linestyle="--", linewidth=GRID_LINEWIDTH)
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.set_xlabel("CFG Scale", fontsize=SMALL_FONTSIZE)
    # The arrow presumably signals "lower is better" for FID.
    ax.set_ylabel("⟵ FID Score", fontsize=LARGE_FONTSIZE)


def _apply_fid_yscale(ax, all_fids):
    """Use a log y-axis with plain-number ticks when max/min FID exceeds 3."""
    fid_min, fid_max = float(np.nanmin(all_fids)), float(np.nanmax(all_fids))
    if fid_max / max(fid_min, 1e-6) > 3:
        ax.set_yscale("log")
        tick_candidates = np.array(
            [1.5, 2, 3, 4, 5, 6, 8, 10, 15, 20, 25, 30, 40, 50]
        )
        yticks = tick_candidates[
            (tick_candidates >= fid_min * 0.9) & (tick_candidates <= fid_max * 1.1)
        ]
        if len(yticks):
            ax.set_yticks(yticks)
            ax.set_yticklabels([f"{t:g}" for t in yticks])
    # Suppress minor ticks and their labels so only the chosen ticks show.
    ax.yaxis.set_minor_locator(plt.NullLocator())
    ax.yaxis.set_minor_formatter(NullFormatter())


def plot_pair(pair_name, exp_list, output_dir, flip_below=frozenset(), label_suffix=""):
    """Plot FID-vs-CFG curves for a group of experiments and save PNG + PDF.

    Args:
        pair_name: Output file stem; writes ``<pair_name>.png`` and
            ``<pair_name>.pdf``.
        exp_list: Experiment names (keys of ``EXPS``). The i-th experiment is
            drawn in ``colors[i]``, cycling when there are more experiments
            than colors.
        output_dir: Directory to write into; created if missing.
        flip_below: Indices into ``exp_list`` whose best-FID label is drawn
            below, rather than above, its reference line.
        label_suffix: Text inserted after each curve's ``kind`` in the legend.

    Experiments without matching results are skipped with a warning. If none
    has data, no files are written.
    """
    fig, ax = plt.subplots(figsize=(7, 6), dpi=300)
    try:
        all_fids = []
        for i, exp in enumerate(exp_list):
            _ckpt, gh, kind = EXPS[exp]
            color = colors[i % len(colors)]
            label = f"{kind}{label_suffix} (Intv. [0.0, {gh}])"
            cfgs, fids = load_curve(exp)
            if not cfgs:
                print(f"[Warn] No curve points for {exp} at gh={gh}")
                continue
            all_fids.extend(fids)
            ax.plot(
                cfgs,
                fids,
                label=label,
                marker="o",
                markersize=MARKERSIZE,
                linewidth=LINEWIDTH,
                linestyle="-",
                color=color,
                markerfacecolor="white",
                markeredgecolor=color,
                alpha=0.9,
                markeredgewidth=MARKEREDGEWIDTH,
            )
            _annotate_best(ax, cfgs, fids, color, below=i in flip_below)

        if not all_fids:
            print(f"[Warn] Skipping {pair_name}: no data.")
            return

        _style_axes(ax)
        _apply_fid_yscale(ax, all_fids)
        fig.tight_layout()

        os.makedirs(output_dir, exist_ok=True)
        out_png = os.path.join(output_dir, f"{pair_name}.png")
        out_pdf = os.path.join(output_dir, f"{pair_name}.pdf")
        fig.savefig(out_png)
        fig.savefig(out_pdf)
    finally:
        # Always release the figure, even if styling or saving fails.
        plt.close(fig)
    print(f"Wrote {out_png} and {out_pdf}")


if __name__ == "__main__":
    output_dir = os.path.join(EXP_ROOT, "cfg_plots")
    for pair_name, exp_list, flip_below, label_suffix in PAIRS:
        plot_pair(
            pair_name,
            exp_list,
            output_dir,
            flip_below=flip_below,
            label_suffix=label_suffix,
        )