# exps_e2e_nf/plot_cfg.py
# Uploaded by xingjianleng ("Upload folder using huggingface_hub"), commit aab061c (verified).
# === Global Variables ===
# Font sizes and stroke widths shared by every figure this script draws.
# The trailing comments say where each value is consumed in this file.
BASE_FONTSIZE = 16  # rcParams "axes.labelsize" (default axis-label size)
SMALL_FONTSIZE = 14  # x-axis label in plot_pair
MEDIUM_FONTSIZE = 15  # x/y tick labels (rcParams)
LARGE_FONTSIZE = 20  # y-axis label in plot_pair
LEGEND_FONTSIZE = 17  # legend text (rcParams and plot_pair)
ANNOTATION_FONTSIZE = 15  # numeric label next to each curve's minimum FID
LINEWIDTH = 6  # FID-vs-CFG curves
AXIS_LINEWIDTH = 1.0  # rcParams "axes.linewidth"; also the base for the derived widths below
GRID_LINEWIDTH = 0.5  # background grid
HLINE_LINEWIDTH = AXIS_LINEWIDTH * 2  # dashed horizontal line at each curve's minimum FID
MARKERSIZE = 10  # curve markers
MARKEREDGEWIDTH = AXIS_LINEWIDTH * 2  # outline of the white-filled curve markers
# =========================
import csv
import os
import re
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import NullFormatter
# Global matplotlib styling: serif text (Times New Roman, falling back to
# DejaVu Serif), the "cm" mathtext font set, and the font sizes / line widths
# defined in the constants at the top of this file.
mpl.rcParams.update(
    {
        "font.family": "serif",
        "font.serif": ["Times New Roman", "DejaVu Serif"],
        "mathtext.fontset": "cm",
        "axes.linewidth": AXIS_LINEWIDTH,
        "axes.labelsize": BASE_FONTSIZE,
        "xtick.labelsize": MEDIUM_FONTSIZE,
        "ytick.labelsize": MEDIUM_FONTSIZE,
        "legend.fontsize": LEGEND_FONTSIZE,
        "figure.dpi": 100,
        "savefig.dpi": 300,
    }
)

# Experiment result folders are looked up next to this script.
EXP_ROOT = os.path.dirname(os.path.abspath(__file__))

# Curve palette; plot_pair indexes it by a curve's position within its pair.
blue_color = "#3498db"
red_color = "#e74c3c"
colors = [blue_color, red_color]
# Experiment registry.
#   key   : experiment directory name (a sub-folder of EXP_ROOT)
#   value : (ckpt, optimal_gh_string, kind)
# ckpt and optimal_gh_string are kept as strings because load_curve splices
# them verbatim into the result-file name pattern.
# kind is "LDM-only" (variational, separate VAE) or "Joint" (e2e training).
EXPS = {
    "sit-xl-1-f16d32-e2e-novariational-logvar0.0": ("0250000", "0.35", "Joint"),
    "sit-xl-1-f16d64-e2e-novariational-logvar0.0": ("0250000", "0.55", "Joint"),
    "sit-xl-1-f16d32-ldm-imagenet256-f16d32-50e-250k": ("0250000", "0.75", "LDM-only"),
    "sit-xl-1-f16d64-ldm-imagenet256-f16d64-50e-250k": ("0250000", "0.825", "LDM-only"),
    "sit-xl-1-dinov2-b-enc8-repae-f16d32-scratch-0.5-0.5-400k": ("0400000", "0.8", "LDM-only"),
    "sit-xl-1-dinov2-f16d32-e2e-novariational-logvar0.0-repae-coeff0.5": ("0400000", "0.525", "Joint"),
}

# One figure per entry: (pair_name, exp_list, flip_below, label_suffix).
#   pair_name    : stem of the output .png / .pdf files
#   exp_list     : EXPS keys drawn on the same axes, in drawing order
#   flip_below   : positions in exp_list whose minimum-FID label goes below
#                  the dashed reference line instead of above it
#   label_suffix : text appended to the kind in each legend label
# NOTE(review): plot_pair colours curves by position in exp_list, so "repa"
# draws LDM-only first while the other two pairs draw Joint first -- confirm
# the differing colour assignment is intended.
PAIRS = [
    (
        "repa",
        [
            "sit-xl-1-dinov2-b-enc8-repae-f16d32-scratch-0.5-0.5-400k",
            "sit-xl-1-dinov2-f16d32-e2e-novariational-logvar0.0-repae-coeff0.5",
        ],
        {1},
        " + REPA",
    ),
    (
        "f16d32",
        [
            "sit-xl-1-f16d32-e2e-novariational-logvar0.0",
            "sit-xl-1-f16d32-ldm-imagenet256-f16d32-50e-250k",
        ],
        set(),
        "",
    ),
    (
        "f16d64",
        [
            "sit-xl-1-f16d64-e2e-novariational-logvar0.0",
            "sit-xl-1-f16d64-ldm-imagenet256-f16d64-50e-250k",
        ],
        set(),
        "",
    ),
]
def load_curve(exp):
    """Return ``(cfgs, fids)`` for one experiment, sorted by CFG scale.

    Scans ``EXP_ROOT/<exp>`` for result files named
    ``<exp>_<ckpt>_cfg<scale>-0.0-<gh>-labelsampling-equal.csv``, where
    ``ckpt`` and ``gh`` are the experiment's entries in ``EXPS``. Files for
    any other checkpoint or ``gh`` value are ignored. The FID is taken from
    the ``FID`` column of the first data row of each matching file.

    Args:
        exp: Experiment directory name; must be a key of ``EXPS``.

    Returns:
        Two parallel lists ``(cfgs, fids)`` of floats ordered by increasing
        CFG scale. Both are empty when the experiment folder does not exist
        or holds no usable result file.

    Raises:
        KeyError: If ``exp`` is not registered in ``EXPS``.
        ValueError: If a matching file's ``FID`` cell is not a number.
    """
    ckpt, gh, _kind = EXPS[exp]
    folder = os.path.join(EXP_ROOT, exp)

    # A missing experiment folder is reported as "no data" so the caller can
    # warn and move on instead of the whole run aborting inside os.listdir.
    if not os.path.isdir(folder):
        return [], []

    pattern = re.compile(
        rf"{re.escape(exp)}_{re.escape(ckpt)}_cfg(\d+(?:\.\d+)?)"
        rf"-0\.0-{re.escape(gh)}-labelsampling-equal\.csv"
    )

    points = []
    for filename in os.listdir(folder):
        matched = pattern.fullmatch(filename)
        if matched is None:
            continue
        # newline="" is what the csv module asks for when reading files.
        with open(os.path.join(folder, filename), newline="") as handle:
            # Only the first data row is used, so do not read the rest.
            first_row = next(csv.DictReader(handle), None)
        if first_row is None or "FID" not in first_row:
            continue
        points.append((float(matched.group(1)), float(first_row["FID"])))

    if not points:
        return [], []
    points.sort()
    cfgs, fids = zip(*points)
    return list(cfgs), list(fids)
def _annotate_minimum(ax, cfgs, fids, color, below):
    """Mark a curve's lowest FID with a dashed line and a numeric label."""
    fid_values = np.asarray(fids, dtype=float)
    # nanargmin raises on all-NaN input; with no finite minimum there is
    # nothing meaningful to annotate.
    if np.all(np.isnan(fid_values)):
        return
    min_idx = int(np.nanargmin(fid_values))
    best_cfg, best_fid = cfgs[min_idx], fids[min_idx]
    ax.axhline(
        y=best_fid,
        color=color,
        linestyle="--",
        linewidth=HLINE_LINEWIDTH,
        alpha=0.7,
    )
    # The label offset is multiplicative, so the gap scales with the FID value.
    y_factor, valign = (0.985, "top") if below else (1.02, "bottom")
    ax.text(
        best_cfg, best_fid * y_factor,
        f"{best_fid:.2f}",
        ha="center", va=valign,
        fontsize=ANNOTATION_FONTSIZE, color=color,
        zorder=5,
    )


def _apply_log_yscale(ax, fid_min, fid_max):
    """Switch the y-axis to log scale with plain-number major ticks only."""
    ax.set_yscale("log")
    tick_candidates = np.array([1.5, 2, 3, 4, 5, 6, 8, 10, 15, 20, 25, 30, 40, 50])
    in_range = (tick_candidates >= fid_min * 0.9) & (tick_candidates <= fid_max * 1.1)
    yticks = tick_candidates[in_range]
    if len(yticks):
        ax.set_yticks(yticks)
        ax.set_yticklabels([f"{t:g}" for t in yticks])
    # Drop minor ticks and their labels from the log axis.
    ax.yaxis.set_minor_locator(plt.NullLocator())
    ax.yaxis.set_minor_formatter(NullFormatter())


def plot_pair(pair_name, exp_list, output_dir, flip_below=frozenset(), label_suffix=""):
    """Plot FID against CFG scale for a group of experiments and save it.

    Each experiment in ``exp_list`` becomes one curve, with its minimum FID
    marked by a dashed horizontal line and a numeric label. The figure is
    written to ``<output_dir>/<pair_name>.png`` and ``.pdf``. If none of the
    experiments has data, a warning is printed and nothing is written.

    Args:
        pair_name: Stem of the output file names.
        exp_list: Experiment names (keys of ``EXPS``) to draw, in order.
            Colours are taken from ``colors`` by position and cycle when
            there are more experiments than colours.
        output_dir: Directory for the output files; created if missing.
        flip_below: Positions in ``exp_list`` whose minimum-FID label is
            placed below the reference line instead of above it.
        label_suffix: Text appended to the experiment kind in legend labels.
    """
    fig, ax = plt.subplots(figsize=(7, 6), dpi=300)
    all_fids = []
    for i, exp in enumerate(exp_list):
        _ckpt, gh, kind = EXPS[exp]
        cfgs, fids = load_curve(exp)
        if not cfgs:
            print(f"[Warn] No curve points for {exp} at gh={gh}")
            continue
        # Cycle the palette so a group larger than it cannot raise IndexError.
        color = colors[i % len(colors)]
        all_fids.extend(fids)
        ax.plot(
            cfgs, fids,
            label=f"{kind}{label_suffix} (Intv. [0.0, {gh}])",
            marker="o",
            markersize=MARKERSIZE,
            linewidth=LINEWIDTH,
            linestyle="-",
            color=color,
            markerfacecolor="white",
            markeredgecolor=color,
            alpha=0.9,
            markeredgewidth=MARKEREDGEWIDTH,
        )
        _annotate_minimum(ax, cfgs, fids, color, below=i in flip_below)

    if not all_fids:
        plt.close(fig)
        print(f"[Warn] Skipping {pair_name}: no data.")
        return

    ax.legend(
        fontsize=LEGEND_FONTSIZE,
        framealpha=0.95,
        loc="best",
        frameon=False,
    )
    ax.grid(True, alpha=0.3, linestyle="--", linewidth=GRID_LINEWIDTH)
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.set_xlabel("CFG Scale", fontsize=SMALL_FONTSIZE)
    ax.set_ylabel("⟵ FID Score", fontsize=LARGE_FONTSIZE)

    # Auto: log scale when the dynamic range is large.
    fid_min, fid_max = float(np.nanmin(all_fids)), float(np.nanmax(all_fids))
    if fid_max / max(fid_min, 1e-6) > 3:
        _apply_log_yscale(ax, fid_min, fid_max)

    # Work on the explicit figure rather than pyplot's implicit current one.
    fig.tight_layout()
    os.makedirs(output_dir, exist_ok=True)
    out_png = os.path.join(output_dir, f"{pair_name}.png")
    out_pdf = os.path.join(output_dir, f"{pair_name}.pdf")
    fig.savefig(out_png)
    fig.savefig(out_pdf)
    plt.close(fig)
    print(f"Wrote {out_png} and {out_pdf}")
if __name__ == "__main__":
    # Render every configured pair into the "cfg_plots" folder next to this script.
    output_dir = os.path.join(EXP_ROOT, "cfg_plots")
    for name, experiments, flips, suffix in PAIRS:
        plot_pair(
            name,
            experiments,
            output_dir,
            flip_below=flips,
            label_suffix=suffix,
        )