| |
| """Plot budget-curve metrics from caption survey JSON outputs.""" |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| from pathlib import Path |
| from typing import Any |
|
|
| import matplotlib |
|
|
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
|
|
|
|
# Metrics to plot. Each entry is (metric_key, metric_label, direction):
#   - metric_key: dotted path that series_for_metric() resolves with
#     nested_get() inside every per-budget summary (e.g. "distinct_n.2").
#   - metric_label: text used in plot titles and axis labels.
#   - direction: "up" makes decorate_metric_label() append "↑"; any other
#     value appends "↓".
# main() writes one overview plot per entry; save_family_plot() only uses the
# first six entries (METRICS[:6]) for its 2x3 grid.
METRICS = [
    ("coverage_rate", "Budget Eligibility@B", "up"),
    ("distinct_n.2", "Distinct-2@B", "up"),
    ("distinct_n.3", "Distinct-3@B", "up"),
    ("ngram_top_k_mass.2", "Top-100 Bigram Mass@B", "down"),
    ("ngram_top_k_mass.3", "Top-100 Trigram Mass@B", "down"),
    ("violation_rate", "Violation Rate@B", "down"),
    ("repeated_4gram_rate", "Repeated 4-gram Rate@B", "down"),
]
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the budget-curve plotter.

    Returns:
        Namespace with ``input`` (list of survey JSON paths, one element per
        ``--input`` occurrence), ``output_dir`` (string) and
        ``long_coverage_threshold`` (float, 0.5 unless overridden).
    """
    cli = argparse.ArgumentParser(description="Plot caption survey budget curves")
    cli.add_argument(
        "--input",
        required=True,
        action="append",
        help="Survey JSON path (repeatable)",
    )
    cli.add_argument(
        "--output-dir",
        required=True,
        help="Directory for output PNG plots",
    )
    threshold_options = {
        "type": float,
        "default": 0.5,
        "help": "budget-eligibility@64 threshold used to split long vs short regimes",
    }
    cli.add_argument("--long-coverage-threshold", **threshold_options)
    namespace = cli.parse_args()
    return namespace
|
|
|
|
def nested_get(mapping: dict[str, Any], path: str) -> float | None:
    """Resolve a dotted ``path`` through nested dicts and return it as a float.

    Args:
        mapping: Outermost dictionary to start from.
        path: Keys joined by ``"."``, e.g. ``"distinct_n.2"``.

    Returns:
        ``float(value)`` when every key along the path exists and the final
        value is an ``int`` or ``float`` (``bool`` is an ``int`` subclass, so
        it is converted too); ``None`` when a key is missing, an intermediate
        value is not a dict, or the final value is not numeric.
    """
    node: Any = mapping
    for key in path.split("."):
        if isinstance(node, dict) and key in node:
            node = node[key]
        else:
            return None
    if isinstance(node, (int, float)):
        return float(node)
    return None
|
|
|
|
def _row_from_summary(
    summary: dict[str, Any],
    *,
    name: str,
    family: str,
    group: str,
    description: str,
    captioner: str,
) -> dict[str, Any] | None:
    """Build one plot row from a survey summary, or ``None`` if it has no budgets.

    Args:
        summary: Dict holding ``length_controlled`` (per-budget summaries keyed
            by the budget as a string) and optionally ``full_length_reference``.
        name, family, group, description, captioner: Metadata copied into the row.

    Returns:
        The row dict consumed by the plotting functions, or ``None`` when
        ``length_controlled`` is missing, empty or not a dict.

    Raises:
        ValueError: If a ``length_controlled`` key is not an integer string.
    """
    length_controlled = summary.get("length_controlled") or {}
    if not isinstance(length_controlled, dict):
        return None
    budgets = sorted(int(key) for key in length_controlled)
    if not budgets:
        return None

    # Coverage at budget 64 is what main() uses to split long vs short regimes;
    # a missing or non-numeric value counts as 0.0.
    coverage64 = nested_get(length_controlled.get("64", {}), "coverage_rate") or 0.0

    full = summary.get("full_length_reference") or {}
    if not isinstance(full, dict):
        full = {}
    # Prefer avg_tokens, then avg_lexical_tokens. An explicit JSON null is
    # treated like a missing key so float() below cannot fail on None.
    avg_tokens = full.get("avg_tokens")
    if avg_tokens is None:
        avg_tokens = full.get("avg_lexical_tokens")
    if avg_tokens is None:
        avg_tokens = 0.0

    return {
        "name": name,
        "family": family,
        "group": group,
        "description": description,
        "captioner": captioner,
        "avg_tokens": float(avg_tokens),
        "coverage64": float(coverage64),
        "budgets": budgets,
        "length_controlled": length_controlled,
    }


def load_rows(paths: list[str]) -> list[dict[str, Any]]:
    """Load plot rows from survey JSON files.

    Two payload layouts are recognised:

    * Batch files with a top-level ``"results"`` list. Each item contributes a
      row built from its ``summary`` (or ``survey_summary``) dict, with metadata
      taken from its ``entry`` dict. Such a file is never also interpreted as a
      direct summary.
    * Direct summaries with a top-level ``"length_controlled"`` dict. The row
      name is the file stem without a trailing ``_1m`` / ``_50k`` suffix, and
      the family is inferred from ``datacomp`` / ``pd12m`` in that name.

    Payloads, items and summaries of an unexpected shape are skipped.

    Args:
        paths: Paths of the JSON files to read (UTF-8).

    Returns:
        One row dict per usable summary, in input order.
    """
    rows: list[dict[str, Any]] = []
    for raw_path in paths:
        source = Path(raw_path)
        payload = json.loads(source.read_text(encoding="utf-8"))
        if not isinstance(payload, dict):
            continue

        if "results" in payload:
            # `or []` tolerates an explicit `"results": null`.
            for item in payload.get("results") or []:
                if not isinstance(item, dict):
                    continue
                summary = item.get("summary") or item.get("survey_summary")
                if not isinstance(summary, dict):
                    continue
                entry = item.get("entry") or {}
                if not isinstance(entry, dict):
                    entry = {}
                row = _row_from_summary(
                    summary,
                    name=entry.get("name", source.stem),
                    family=entry.get("source_family", "unknown"),
                    group=entry.get("group", "unknown"),
                    description=entry.get("description", ""),
                    captioner=entry.get("captioner", ""),
                )
                if row is not None:
                    rows.append(row)
            continue

        if "length_controlled" in payload:
            name = source.stem.removesuffix("_1m").removesuffix("_50k")
            if "datacomp" in name:
                family = "datacomp"
            elif "pd12m" in name:
                family = "pd12m"
            else:
                family = "unknown"
            row = _row_from_summary(
                payload,
                name=name,
                family=family,
                group="direct_summary",
                description="",
                captioner="",
            )
            if row is not None:
                rows.append(row)
    return rows
|
|
|
|
def label_for_row(row: dict[str, Any]) -> str:
    """Return the legend label for ``row``.

    A leading ``ours_`` or ``ref_`` in ``row["name"]`` is rewritten to
    ``ours:`` / ``ref:``; other names are used unchanged. The row named
    ``ref_cc12m_qwen3vl8b`` additionally gets a trailing dagger.
    """
    name = row["name"]
    label = name
    for prefix, tag in (("ours_", "ours:"), ("ref_", "ref:")):
        if name.startswith(prefix):
            label = f"{tag}{name.removeprefix(prefix)}"
            break
    if name == "ref_cc12m_qwen3vl8b":
        return label + "†"
    return label
|
|
|
|
def decorate_metric_label(metric_label: str, direction: str) -> str:
    """Append a direction arrow to ``metric_label``.

    ``direction == "up"`` yields an up arrow; every other value yields a down
    arrow.
    """
    if direction == "up":
        suffix = "↑"
    else:
        suffix = "↓"
    return f"{metric_label} {suffix}"
|
|
|
|
def style_for_row(row: dict[str, Any]) -> dict[str, Any]:
    """Return the line-style keyword arguments used when plotting ``row``.

    Rows whose name starts with ``ours_`` get a thicker, more opaque solid
    line; all other rows get a thinner dashed line.
    """
    ours = row["name"].startswith("ours_")
    return {
        "linewidth": 2.8 if ours else 1.6,
        "alpha": 0.95 if ours else 0.85,
        "linestyle": "-" if ours else "--",
    }
|
|
|
|
def series_for_metric(row: dict[str, Any], metric_key: str) -> tuple[list[int], list[float]]:
    """Collect the (budget, value) points of one metric for ``row``.

    Args:
        row: Row dict with ``budgets`` (ints) and ``length_controlled``
            (per-budget summaries keyed by the budget as a string).
        metric_key: Dotted key resolved inside each per-budget summary.

    Returns:
        Two parallel lists: the budgets that have a numeric value for the
        metric, in the order of ``row["budgets"]``, and those values.
    """
    points = [
        (budget, nested_get(row["length_controlled"].get(str(budget), {}), metric_key))
        for budget in row["budgets"]
    ]
    # Budgets without a numeric value for this metric are dropped.
    present = [(budget, value) for budget, value in points if value is not None]
    xs: list[int] = [budget for budget, _ in present]
    ys: list[float] = [value for _, value in present]
    return xs, ys
|
|
|
|
def save_metric_plot(
    rows: list[dict[str, Any]],
    metric_key: str,
    metric_label: str,
    direction: str,
    regime_name: str,
    output_path: Path,
) -> None:
    """Plot one metric against token budget for all ``rows`` and save it as PNG.

    Args:
        rows: Row dicts to draw, one line each, ordered by (family, name).
        metric_key: Dotted metric key looked up in each per-budget summary.
        metric_label: Human-readable metric name for the title and y-axis.
        direction: Direction tag passed to ``decorate_metric_label``.
        regime_name: Regime description shown in the title.
        output_path: Destination file; missing parent directories are created.
    """
    fig, ax = plt.subplots(figsize=(10.5, 6.2))
    try:
        plotted = False
        for row in sorted(rows, key=lambda item: (item["family"], item["name"])):
            xs, ys = series_for_metric(row, metric_key)
            if not xs:
                continue
            ax.plot(xs, ys, marker="o", label=label_for_row(row), **style_for_row(row))
            plotted = True

        decorated_label = decorate_metric_label(metric_label, direction)
        ax.set_title(f"{decorated_label} by Budget ({regime_name})")
        ax.set_xlabel("Token Budget")
        ax.set_ylabel(decorated_label)
        ax.set_xticks(sorted({budget for row in rows for budget in row["budgets"]}))
        ax.grid(True, alpha=0.25)
        # Only request a legend when a line was drawn, mirroring the handle
        # check in save_family_plot; legend() on an axes without labelled
        # artists makes matplotlib emit a warning.
        if plotted:
            ax.legend(fontsize=8, ncol=2)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        fig.tight_layout()
        fig.savefig(output_path, dpi=180)
    finally:
        # Release the figure even if layout or saving fails.
        plt.close(fig)
|
|
|
|
def save_family_plot(rows: list[dict[str, Any]], family: str, output_path: Path) -> None:
    """Save a 2x3 grid of budget curves for the rows of one family.

    The first six entries of ``METRICS`` each get one panel. Every row keeps
    the same colour on all panels, and the shared legend lists every row that
    was drawn on at least one panel.

    Args:
        rows: All loaded row dicts; only those with ``row["family"] == family``
            are drawn.
        family: Family name to filter on; also used in the figure title.
        output_path: Destination file; missing parent directories are created.
            Nothing is written when the family has no rows.
    """
    family_rows = sorted(
        (row for row in rows if row["family"] == family),
        key=lambda item: item["name"],
    )
    if not family_rows:
        return

    fig, axes = plt.subplots(2, 3, figsize=(14, 8.5))
    try:
        axes = axes.flatten()
        # First drawn line per row index, used to build the shared legend.
        legend_lines: dict[int, Any] = {}
        for axis, (metric_key, metric_label, direction) in zip(axes, METRICS[:6], strict=False):
            for index, row in enumerate(family_rows):
                xs, ys = series_for_metric(row, metric_key)
                if not xs:
                    continue
                # Pin the colour to the row's position. Relying on each panel's
                # own colour cycle would shift colours whenever a row lacks a
                # metric, so the shared legend would no longer match.
                (line,) = axis.plot(
                    xs,
                    ys,
                    marker="o",
                    label=label_for_row(row),
                    color=f"C{index}",
                    **style_for_row(row),
                )
                legend_lines.setdefault(index, line)
            axis.set_title(decorate_metric_label(metric_label, direction))
            axis.set_xlabel("Budget")
            axis.grid(True, alpha=0.25)

        if legend_lines:
            order = sorted(legend_lines)
            fig.legend(
                [legend_lines[index] for index in order],
                [label_for_row(family_rows[index]) for index in order],
                loc="lower center",
                ncol=2,
                fontsize=8,
            )
        fig.suptitle(f"{family} Budget Curves", y=0.98)
        fig.tight_layout(rect=(0, 0.05, 1, 0.96))
        output_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(output_path, dpi=180)
    finally:
        # Release the figure even if layout or saving fails.
        plt.close(fig)
|
|
|
|
def main() -> int:
    """Run the plotter: load rows, write all plots and the manifest.

    Rows are split into a long and a short regime by comparing their
    ``coverage64`` with ``--long-coverage-threshold``. One overview plot per
    metric is written for each non-empty regime, then one plot per family, and
    finally ``plot_manifest.json``, whose content is also printed.

    Returns:
        ``0`` on success.

    Raises:
        SystemExit: If no rows could be loaded from the inputs.
    """
    args = parse_args()
    rows = load_rows(args.input)
    if not rows:
        raise SystemExit("No survey rows loaded")

    output_dir = Path(args.output_dir)
    threshold = args.long_coverage_threshold

    long_rows: list[dict[str, Any]] = []
    short_rows: list[dict[str, Any]] = []
    for row in rows:
        if row["coverage64"] >= threshold:
            long_rows.append(row)
        elif row["coverage64"] < threshold:
            short_rows.append(row)

    overview_dir = output_dir / "overview"
    regimes = (
        ("long", "long-regime", long_rows),
        ("short", "short-regime", short_rows),
    )
    for metric_key, metric_label, direction in METRICS:
        file_name = f"{metric_key.replace('.', '_')}.png"
        for subdir, regime_name, regime_rows in regimes:
            if not regime_rows:
                continue
            save_metric_plot(
                regime_rows,
                metric_key,
                metric_label,
                direction,
                regime_name,
                overview_dir / subdir / file_name,
            )

    families_dir = output_dir / "families"
    for family in sorted({row["family"] for row in rows}):
        save_family_plot(rows, family, families_dir / f"{family}.png")

    manifest = {
        "inputs": args.input,
        "output_dir": str(output_dir),
        "long_coverage_threshold": args.long_coverage_threshold,
        "rows_loaded": len(rows),
        "long_rows": [row["name"] for row in long_rows],
        "short_rows": [row["name"] for row in short_rows],
        "metrics": [entry[0] for entry in METRICS],
    }
    manifest_text = json.dumps(manifest, indent=2, ensure_ascii=False)
    manifest_path = output_dir / "plot_manifest.json"
    manifest_path.write_text(manifest_text + "\n", encoding="utf-8")
    print(manifest_text)
    return 0
|
|
|
|
# Script entry point: run main() and hand its return value to SystemExit.
if __name__ == "__main__":
    raise SystemExit(main())
|
|