| |
# --- Font sizes ------------------------------------------------------------
# BASE/MEDIUM/LEGEND feed the matplotlib rcParams set later in this module;
# SMALL/LARGE/ANNOTATION are passed explicitly by the plotting code.
BASE_FONTSIZE = 16
SMALL_FONTSIZE = 14
MEDIUM_FONTSIZE = 15
LARGE_FONTSIZE = 20
LEGEND_FONTSIZE = 17
ANNOTATION_FONTSIZE = 15

# --- Line and marker geometry ----------------------------------------------
LINEWIDTH = 6
AXIS_LINEWIDTH = 1.0
GRID_LINEWIDTH = 0.5
MARKERSIZE = 10

# Reference lines and marker outlines are derived from the axis line width
# (twice as thick) so they scale together if AXIS_LINEWIDTH is changed.
HLINE_LINEWIDTH = 2 * AXIS_LINEWIDTH
MARKEREDGEWIDTH = 2 * AXIS_LINEWIDTH
| |
|
|
| import csv |
| import os |
| import re |
|
|
| import matplotlib as mpl |
| import matplotlib.pyplot as plt |
| import numpy as np |
| from matplotlib.ticker import NullFormatter |
|
|
# Global matplotlib defaults for every figure produced by this module.
# Applied in one call; keys are set in the same order as listed here.
mpl.rcParams.update(
    {
        # Typography: serif text with Computer Modern math.
        "font.family": "serif",
        "font.serif": ["Times New Roman", "DejaVu Serif"],
        "mathtext.fontset": "cm",
        # Axes and tick/legend text sizes come from the module constants.
        "axes.linewidth": AXIS_LINEWIDTH,
        "axes.labelsize": BASE_FONTSIZE,
        "xtick.labelsize": MEDIUM_FONTSIZE,
        "ytick.labelsize": MEDIUM_FONTSIZE,
        "legend.fontsize": LEGEND_FONTSIZE,
        # Screen resolution vs. resolution of saved files.
        "figure.dpi": 100,
        "savefig.dpi": 300,
    }
)
|
|
|
|
# Directory containing this script; experiment folders and the output folder
# are resolved relative to it.
_SCRIPT_PATH = os.path.abspath(__file__)
EXP_ROOT = os.path.dirname(_SCRIPT_PATH)

# Two-colour palette; curves pick their colour by position in this list.
blue_color = "#3498db"
red_color = "#e74c3c"
colors = [
    blue_color,
    red_color,
]
|
|
|
|
| |
| |
# Registry of experiments that can be plotted.
#
# key   -> experiment name (also the name of its results folder)
# value -> (checkpoint step, interval upper bound "gh", legend kind)
#
# The checkpoint and gh strings are matched verbatim against result file
# names, so they must be spelled exactly as they appear on disk.
EXPS = {
    # Runs evaluated at the 250k checkpoint.
    "sit-xl-1-f16d32-e2e-novariational-logvar0.0": ("0250000", "0.35", "Joint"),
    "sit-xl-1-f16d64-e2e-novariational-logvar0.0": ("0250000", "0.55", "Joint"),
    "sit-xl-1-f16d32-ldm-imagenet256-f16d32-50e-250k": ("0250000", "0.75", "LDM-only"),
    "sit-xl-1-f16d64-ldm-imagenet256-f16d64-50e-250k": ("0250000", "0.825", "LDM-only"),
    # Runs evaluated at the 400k checkpoint.
    "sit-xl-1-dinov2-b-enc8-repae-f16d32-scratch-0.5-0.5-400k": ("0400000", "0.8", "LDM-only"),
    "sit-xl-1-dinov2-f16d32-e2e-novariational-logvar0.0-repae-coeff0.5": ("0400000", "0.525", "Joint"),
}
|
|
|
|
# Figures to produce. Each entry is
#   (pair_name, [experiment names], flip_below, label_suffix)
# where pair_name becomes the output file stem, flip_below holds the positions
# (within the experiment list) whose best-FID annotation goes under its
# reference line, and label_suffix is appended to each legend kind.
#
# NOTE(review): the "repa" group lists the LDM-only run first, whereas the
# other two groups list the Joint run first. Colours are assigned by position,
# so Joint/LDM-only swap colours between figures -- confirm this is intended.
PAIRS = [
    (
        "repa",
        [
            "sit-xl-1-dinov2-b-enc8-repae-f16d32-scratch-0.5-0.5-400k",
            "sit-xl-1-dinov2-f16d32-e2e-novariational-logvar0.0-repae-coeff0.5",
        ],
        {1},
        " + REPA",
    ),
    (
        "f16d32",
        [
            "sit-xl-1-f16d32-e2e-novariational-logvar0.0",
            "sit-xl-1-f16d32-ldm-imagenet256-f16d32-50e-250k",
        ],
        set(),
        "",
    ),
    (
        "f16d64",
        [
            "sit-xl-1-f16d64-e2e-novariational-logvar0.0",
            "sit-xl-1-f16d64-ldm-imagenet256-f16d64-50e-250k",
        ],
        set(),
        "",
    ),
]
|
|
|
|
def load_curve(exp):
    """Return ``(cfgs, fids)`` for one experiment, sorted by CFG scale.

    The checkpoint and interval upper bound (``gh``) registered for *exp* in
    ``EXPS`` select which result files are read: only files in
    ``EXP_ROOT/<exp>`` named
    ``<exp>_<ckpt>_cfg<scale>-0.0-<gh>-labelsampling-equal.csv`` are used.
    The FID for a file is the ``FID`` column of its first data row; files that
    are empty or lack that column are ignored.

    Args:
        exp: Experiment name; must be a key of ``EXPS``.

    Returns:
        Two parallel lists ``(cfgs, fids)`` of floats ordered by CFG scale.
        Both are empty when the experiment folder does not exist or contains
        no usable result file.

    Raises:
        KeyError: If *exp* is not registered in ``EXPS``.
        ValueError: If a ``FID`` cell cannot be parsed as a float.
    """
    ckpt, gh, _label = EXPS[exp]
    folder = os.path.join(EXP_ROOT, exp)
    pat = re.compile(
        rf"{re.escape(exp)}_{ckpt}_cfg(\d+(?:\.\d+)?)-0\.0-{re.escape(gh)}-labelsampling-equal\.csv$"
    )

    try:
        filenames = os.listdir(folder)
    except FileNotFoundError:
        # A missing results folder is equivalent to "no points": returning
        # empty lists lets the caller warn and skip this experiment instead of
        # aborting every remaining plot.
        return [], []

    rows = []
    for fn in filenames:
        m = pat.match(fn)
        if not m:
            continue
        cfg = float(m.group(1))
        # newline="" is the file mode the csv module documents for readers.
        with open(os.path.join(folder, fn), newline="") as f:
            # Only the first data row is needed, so do not read the rest.
            first_row = next(csv.DictReader(f), None)
        if first_row is not None and "FID" in first_row:
            rows.append((cfg, float(first_row["FID"])))

    if not rows:
        return [], []
    rows.sort()
    cfgs, fids = zip(*rows)
    return list(cfgs), list(fids)
|
|
|
|
def plot_pair(pair_name, exp_list, output_dir, flip_below=frozenset(), label_suffix=""):
    """Plot FID-vs-CFG curves for a group of experiments and save PNG and PDF.

    One curve is drawn per experiment in *exp_list*. For each curve a dashed
    horizontal line marks its lowest FID and the value is written next to the
    minimum. Experiments without any data points are reported and skipped; if
    no experiment has data, nothing is written.

    Args:
        pair_name: Stem of the output files (``<pair_name>.png`` / ``.pdf``).
        exp_list: Experiment names (keys of ``EXPS``) to draw, in order.
            Colours are taken from ``colors`` by position and wrap around if
            there are more experiments than colours.
        output_dir: Directory for the output files; created if missing.
        flip_below: Positions in *exp_list* whose best-FID annotation is
            placed below the reference line instead of above it.
        label_suffix: Text appended to each experiment kind in the legend.

    Returns:
        None.
    """
    fig, ax = plt.subplots(figsize=(7, 6), dpi=300)

    all_fids = []
    for i, exp in enumerate(exp_list):
        _ckpt, gh, kind = EXPS[exp]
        label = f"{kind}{label_suffix} (Intv. [0.0, {gh}])"
        cfgs, fids = load_curve(exp)
        if not cfgs:
            print(f"[Warn] No curve points for {exp} at gh={gh}")
            continue
        all_fids.extend(fids)

        # Wrap around the palette so groups larger than it do not raise
        # IndexError.
        color = colors[i % len(colors)]

        ax.plot(
            cfgs, fids,
            label=label,
            marker="o",
            markersize=MARKERSIZE,
            linewidth=LINEWIDTH,
            linestyle="-",
            color=color,
            markerfacecolor="white",
            markeredgecolor=color,
            alpha=0.9,
            markeredgewidth=MARKEREDGEWIDTH,
        )

        # Mark the best (lowest) FID. The search ignores NaN so it agrees with
        # the NaN-aware axis range below; a curve with no valid value gets no
        # marker.
        if np.isnan(fids).all():
            continue
        min_idx = int(np.nanargmin(fids))
        best_cfg, best_fid = cfgs[min_idx], fids[min_idx]
        ax.axhline(
            y=best_fid,
            color=color,
            linestyle="--",
            linewidth=HLINE_LINEWIDTH,
            alpha=0.7,
        )
        # Offsets are multiplicative so the gap looks similar on linear and
        # log axes.
        if i in flip_below:
            text_y, valign = best_fid * 0.985, "top"
        else:
            text_y, valign = best_fid * 1.02, "bottom"
        ax.text(
            best_cfg, text_y,
            f"{best_fid:.2f}",
            ha="center", va=valign,
            fontsize=ANNOTATION_FONTSIZE, color=color,
            zorder=5,
        )

    if not all_fids:
        plt.close(fig)
        print(f"[Warn] Skipping {pair_name}: no data.")
        return

    ax.legend(
        fontsize=LEGEND_FONTSIZE,
        framealpha=0.95,
        loc="best",
        frameon=False,
    )
    ax.grid(True, alpha=0.3, linestyle="--", linewidth=GRID_LINEWIDTH)
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)

    ax.set_xlabel("CFG Scale", fontsize=SMALL_FONTSIZE)
    ax.set_ylabel("⟵ FID Score", fontsize=LARGE_FONTSIZE)

    # Switch to a log y-axis when the largest FID is more than 3x the
    # smallest, and label it with plain numbers from a fixed candidate list.
    fid_min, fid_max = float(np.nanmin(all_fids)), float(np.nanmax(all_fids))
    if fid_max / max(fid_min, 1e-6) > 3:
        ax.set_yscale("log")
        tick_candidates = np.array([1.5, 2, 3, 4, 5, 6, 8, 10, 15, 20, 25, 30, 40, 50])
        yticks = tick_candidates[
            (tick_candidates >= fid_min * 0.9) & (tick_candidates <= fid_max * 1.1)
        ]
        if len(yticks):
            ax.set_yticks(yticks)
            ax.set_yticklabels([f"{t:g}" for t in yticks])
        # NOTE(review): the original indentation was lost, so it is unclear
        # whether the minor ticks were hidden only when custom ticks were set.
        # Here they are hidden whenever the log scale is active -- confirm.
        ax.yaxis.set_minor_locator(plt.NullLocator())
        ax.yaxis.set_minor_formatter(NullFormatter())

    # Use the figure object rather than pyplot's implicit current figure.
    fig.tight_layout()
    os.makedirs(output_dir, exist_ok=True)
    out_png = os.path.join(output_dir, f"{pair_name}.png")
    out_pdf = os.path.join(output_dir, f"{pair_name}.pdf")
    fig.savefig(out_png)
    fig.savefig(out_pdf)
    plt.close(fig)
    print(f"Wrote {out_png} and {out_pdf}")
|
|
|
|
# Script entry point: render every configured group into <EXP_ROOT>/cfg_plots.
if __name__ == "__main__":
    output_dir = os.path.join(EXP_ROOT, "cfg_plots")
    for pair in PAIRS:
        name, experiments, below, suffix = pair
        plot_pair(
            name,
            experiments,
            output_dir,
            flip_below=below,
            label_suffix=suffix,
        )
|
|