recap-t2i-evaluation-code-2026 / eval_code /scripts /plot_caption_survey_curves.py
Authors
Initial anonymous NeurIPS 2026 E&D code and results release
7f59fb7 verified
#!/usr/bin/env python3
"""Plot budget-curve metrics from caption survey JSON outputs."""
from __future__ import annotations
import argparse
import json
from pathlib import Path
from typing import Any
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
# Metrics plotted for every budget. Each tuple is:
#   (dotted key into a per-budget summary dict, display label, preferred direction)
# where the direction is "up" when higher values are better and "down" otherwise.
# Order matters: it drives the overview plot order, the manifest, and which
# entries fit in the family grid.
METRICS: list[tuple[str, str, str]] = [
    # Higher is better.
    ("coverage_rate", "Budget Eligibility@B", "up"),
    ("distinct_n.2", "Distinct-2@B", "up"),
    ("distinct_n.3", "Distinct-3@B", "up"),
    # Lower is better.
    ("ngram_top_k_mass.2", "Top-100 Bigram Mass@B", "down"),
    ("ngram_top_k_mass.3", "Top-100 Trigram Mass@B", "down"),
    ("violation_rate", "Violation Rate@B", "down"),
    ("repeated_4gram_rate", "Repeated 4-gram Rate@B", "down"),
]
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Plot caption survey budget curves")
parser.add_argument("--input", action="append", required=True, help="Survey JSON path (repeatable)")
parser.add_argument("--output-dir", required=True, help="Directory for output PNG plots")
parser.add_argument(
"--long-coverage-threshold",
type=float,
default=0.5,
help="budget-eligibility@64 threshold used to split long vs short regimes",
)
return parser.parse_args()
def nested_get(mapping: dict[str, Any], path: str) -> float | None:
current: Any = mapping
for part in path.split("."):
if not isinstance(current, dict) or part not in current:
return None
current = current[part]
return float(current) if isinstance(current, (int, float)) else None
def _budget_keys(length_controlled: dict[str, Any]) -> list[int]:
    """Return the integer-valued keys of ``length_controlled`` in ascending order."""
    budgets: list[int] = []
    for key in length_controlled:
        try:
            budgets.append(int(key))
        except (TypeError, ValueError):
            # Skip auxiliary non-budget keys instead of aborting the whole run.
            continue
    return sorted(budgets)
def _first_float(mapping: dict[str, Any], keys: tuple[str, ...], default: float = 0.0) -> float:
    """Return the first value under ``keys`` that converts to float, else ``default``."""
    for key in keys:
        value = mapping.get(key)
        if value is None:
            continue
        try:
            return float(value)
        except (TypeError, ValueError):
            continue
    return default
def _infer_family(name: str) -> str:
    """Guess the source family of a direct-summary file from its (suffix-stripped) name."""
    if "datacomp" in name:
        return "datacomp"
    if "pd12m" in name:
        return "pd12m"
    return "unknown"
def _row_from_summary(summary: dict[str, Any], meta: dict[str, Any]) -> dict[str, Any] | None:
    """Build one plot row from a survey summary plus identifying metadata; ``None`` if it has no budgets."""
    length_controlled = summary.get("length_controlled") or {}
    if not isinstance(length_controlled, dict):
        return None
    budgets = _budget_keys(length_controlled)
    if not budgets:
        return None
    # Budget-64 coverage drives the long/short regime split; a missing entry counts as 0.0.
    cov64 = nested_get(length_controlled.get("64", {}), "coverage_rate") or 0.0
    full = summary.get("full_length_reference") or {}
    if not isinstance(full, dict):
        full = {}
    return {
        **meta,
        # Prefer avg_tokens, fall back to avg_lexical_tokens; null/non-numeric values fall through.
        "avg_tokens": _first_float(full, ("avg_tokens", "avg_lexical_tokens")),
        "coverage64": float(cov64),
        "budgets": budgets,
        "length_controlled": length_controlled,
    }
def load_rows(paths: list[str]) -> list[dict[str, Any]]:
    """Load plot rows from one or more caption-survey JSON files.

    Two payload layouts are recognised:

    * a multi-entry file with a top-level ``"results"`` list; each item holds
      a ``"summary"`` (or ``"survey_summary"``) dict and optional ``"entry"``
      metadata (name, source_family, group, description, captioner);
    * a single-summary file with a top-level ``"length_controlled"`` dict,
      whose name is the file stem minus a ``_1m``/``_50k`` suffix and whose
      family is guessed from that name.

    Files matching neither layout, non-dict items, and summaries without any
    integer budget key are skipped.

    Args:
        paths: JSON file paths, read as UTF-8.

    Returns:
        One dict per usable summary with keys ``name``, ``family``, ``group``,
        ``description``, ``captioner``, ``avg_tokens``, ``coverage64``,
        ``budgets`` (sorted ints) and ``length_controlled`` (the raw dict).

    Raises:
        OSError: if a file cannot be read.
        json.JSONDecodeError: if a file is not valid JSON.
    """
    rows: list[dict[str, Any]] = []
    for raw_path in paths:
        path = Path(raw_path)
        payload = json.loads(path.read_text(encoding="utf-8"))
        if not isinstance(payload, dict):
            continue
        if "results" in payload:
            for item in payload.get("results") or []:
                if not isinstance(item, dict):
                    continue
                summary = item.get("summary") or item.get("survey_summary")
                if not isinstance(summary, dict):
                    continue
                entry = item.get("entry") or {}
                if not isinstance(entry, dict):
                    entry = {}
                row = _row_from_summary(
                    summary,
                    {
                        # `or` (not a .get default) so explicit nulls don't leak None into
                        # fields that are later sorted or string-processed.
                        "name": entry.get("name") or path.stem,
                        "family": entry.get("source_family") or "unknown",
                        "group": entry.get("group") or "unknown",
                        "description": entry.get("description", ""),
                        "captioner": entry.get("captioner", ""),
                    },
                )
                if row is not None:
                    rows.append(row)
        elif "length_controlled" in payload:
            name = path.stem.removesuffix("_1m").removesuffix("_50k")
            row = _row_from_summary(
                payload,
                {
                    "name": name,
                    "family": _infer_family(name),
                    "group": "direct_summary",
                    "description": "",
                    "captioner": "",
                },
            )
            if row is not None:
                rows.append(row)
    return rows
def label_for_row(row: dict[str, Any]) -> str:
    """Build the legend label for a row from its ``name``.

    ``ours_<x>`` becomes ``ours:<x>``, ``ref_<x>`` becomes ``ref:<x>``, and any
    other name is used verbatim. The row ``ref_cc12m_qwen3vl8b`` also gets a
    trailing dagger.
    """
    raw_name = row["name"]
    for prefix, tag in (("ours_", "ours:"), ("ref_", "ref:")):
        if raw_name.startswith(prefix):
            text = tag + raw_name[len(prefix):]
            break
    else:
        text = raw_name
    dagger = "†" if raw_name == "ref_cc12m_qwen3vl8b" else ""
    return text + dagger
def decorate_metric_label(metric_label: str, direction: str) -> str:
    """Append a direction arrow to a metric label.

    ``"up"`` appends ``↑``; any other direction appends ``↓``.
    """
    arrow_by_direction = {"up": "↑"}
    return f"{metric_label} {arrow_by_direction.get(direction, '↓')}"
def style_for_row(row: dict[str, Any]) -> dict[str, Any]:
    """Return matplotlib line keyword arguments for a row.

    Rows named ``ours_*`` get a thick solid line; every other row gets a
    thinner dashed line. A new dict is returned on every call.
    """
    if not row["name"].startswith("ours_"):
        return dict(linewidth=1.6, alpha=0.85, linestyle="--")
    return dict(linewidth=2.8, alpha=0.95, linestyle="-")
def series_for_metric(row: dict[str, Any], metric_key: str) -> tuple[list[int], list[float]]:
    """Collect the (budget, value) points of one metric for a row.

    Budgets are visited in ``row["budgets"]`` order; each is looked up in
    ``row["length_controlled"]`` under its string form. Budgets whose summary
    lacks a numeric value at ``metric_key`` are skipped.

    Returns:
        Parallel lists of budgets (x) and metric values (y).
    """
    per_budget = row["length_controlled"]
    points = [
        (budget, value)
        for budget in row["budgets"]
        if (value := nested_get(per_budget.get(str(budget), {}), metric_key)) is not None
    ]
    budgets_out = [budget for budget, _ in points]
    values_out = [value for _, value in points]
    return budgets_out, values_out
def save_metric_plot(
    rows: list[dict[str, Any]],
    metric_key: str,
    metric_label: str,
    direction: str,
    regime_name: str,
    output_path: Path,
) -> None:
    """Plot one metric against token budget for a set of rows and save a PNG.

    Rows are drawn in (family, name) order; rows with no values for the metric
    are skipped. X ticks are the union of all rows' budgets.

    Args:
        rows: Plot rows, each with ``name``, ``family``, ``budgets`` and
            ``length_controlled``.
        metric_key: Dotted key of the metric inside each per-budget summary.
        metric_label: Human-readable metric name.
        direction: ``"up"`` if higher is better, otherwise lower is better.
        regime_name: Text shown in parentheses in the title.
        output_path: Destination PNG; parent directories are created.
    """
    fig, ax = plt.subplots(figsize=(10.5, 6.2))
    try:
        for row in sorted(rows, key=lambda item: (item["family"], item["name"])):
            xs, ys = series_for_metric(row, metric_key)
            if not xs:
                continue
            ax.plot(xs, ys, marker="o", label=label_for_row(row), **style_for_row(row))
        decorated_label = decorate_metric_label(metric_label, direction)
        ax.set_title(f"{decorated_label} by Budget ({regime_name})")
        ax.set_xlabel("Token Budget")
        ax.set_ylabel(decorated_label)
        ax.set_xticks(sorted({budget for row in rows for budget in row["budgets"]}))
        ax.grid(True, alpha=0.25)
        # Only add a legend when something was plotted; otherwise matplotlib
        # warns that no labelled artists were found.
        handles, _ = ax.get_legend_handles_labels()
        if handles:
            ax.legend(fontsize=8, ncol=2)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        fig.tight_layout()
        fig.savefig(output_path, dpi=180)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)
def save_family_plot(rows: list[dict[str, Any]], family: str, output_path: Path) -> None:
    """Save a 2x3 grid of budget curves for all rows of one source family.

    The grid shows the first six entries of ``METRICS``; later entries do not
    fit and are omitted from this figure. One shared legend is placed below
    the grid. Nothing is written when no row belongs to ``family``.

    Args:
        rows: All loaded plot rows; only those whose ``family`` matches are drawn.
        family: Family to plot; also used in the figure title.
        output_path: Destination PNG; parent directories are created.
    """
    # Sort once: the same row order is used in every panel.
    family_rows = sorted(
        (row for row in rows if row["family"] == family),
        key=lambda item: item["name"],
    )
    if not family_rows:
        return
    panel_metrics = METRICS[:6]
    fig, axes = plt.subplots(2, 3, figsize=(14, 8.5))
    try:
        axes = axes.flatten()
        for axis, (metric_key, metric_label, direction) in zip(axes, panel_metrics, strict=False):
            for row in family_rows:
                xs, ys = series_for_metric(row, metric_key)
                if not xs:
                    continue
                axis.plot(xs, ys, marker="o", label=label_for_row(row), **style_for_row(row))
            axis.set_title(decorate_metric_label(metric_label, direction))
            axis.set_xlabel("Budget")
            axis.grid(True, alpha=0.25)
        # Hide panels left without a metric (only possible if METRICS has < 6 entries).
        for axis in axes[len(panel_metrics):]:
            axis.set_visible(False)
        # Build the shared legend from every panel, not only the first: a row
        # with no data for the first metric would otherwise be left out.
        legend_handles: dict[str, Any] = {}
        for axis in axes:
            handles, labels = axis.get_legend_handles_labels()
            for handle, label in zip(handles, labels):
                legend_handles.setdefault(label, handle)
        if legend_handles:
            fig.legend(
                list(legend_handles.values()),
                list(legend_handles),
                loc="lower center",
                ncol=2,
                fontsize=8,
            )
        fig.suptitle(f"{family} Budget Curves", y=0.98)
        # Reserve the bottom 5% of the figure for the shared legend.
        fig.tight_layout(rect=(0, 0.05, 1, 0.96))
        output_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(output_path, dpi=180)
    finally:
        plt.close(fig)
def main() -> int:
    """Load survey rows, render every plot, and write a JSON manifest.

    Rows are split into a long and a short regime by comparing their budget-64
    coverage with ``--long-coverage-threshold``. Each metric in ``METRICS`` is
    plotted once per non-empty regime under ``overview/{long,short}/``, and
    one multi-panel figure per family goes under ``families/``. The manifest
    is written to ``<output-dir>/plot_manifest.json`` and echoed to stdout.

    Returns:
        0 on success.

    Raises:
        SystemExit: if no rows could be loaded from the inputs.
    """
    args = parse_args()
    rows = load_rows(args.input)
    if not rows:
        raise SystemExit("No survey rows loaded")
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    threshold = args.long_coverage_threshold

    long_rows: list[dict[str, Any]] = []
    short_rows: list[dict[str, Any]] = []
    for row in rows:
        # Single-pass partition: a NaN coverage fails both `>=` and `<`, so two
        # independent filters would silently drop such a row from both regimes.
        if row["coverage64"] >= threshold:
            long_rows.append(row)
        else:
            short_rows.append(row)

    regimes = (
        ("long", "long-regime", long_rows),
        ("short", "short-regime", short_rows),
    )
    for metric_key, metric_label, direction in METRICS:
        file_name = f"{metric_key.replace('.', '_')}.png"
        for subdir, regime_name, regime_rows in regimes:
            if not regime_rows:
                continue
            save_metric_plot(
                regime_rows,
                metric_key,
                metric_label,
                direction,
                regime_name,
                output_dir / "overview" / subdir / file_name,
            )
    for family in sorted({row["family"] for row in rows}):
        save_family_plot(rows, family, output_dir / "families" / f"{family}.png")

    manifest = {
        "inputs": args.input,
        "output_dir": str(output_dir),
        "long_coverage_threshold": threshold,
        "rows_loaded": len(rows),
        "long_rows": [row["name"] for row in long_rows],
        "short_rows": [row["name"] for row in short_rows],
        "metrics": [metric_key for metric_key, _, _ in METRICS],
    }
    manifest_text = json.dumps(manifest, indent=2, ensure_ascii=False)
    (output_dir / "plot_manifest.json").write_text(manifest_text + "\n", encoding="utf-8")
    print(manifest_text)
    return 0
if __name__ == "__main__":
    # Use main()'s integer return value as the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)