File size: 9,455 Bytes
7f59fb7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
#!/usr/bin/env python3
"""Plot budget-curve metrics from caption survey JSON outputs."""

from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Any

import matplotlib

# Select the Agg backend before pyplot is imported; this script only writes
# figures to files via savefig and never opens a window.
matplotlib.use("Agg")
import matplotlib.pyplot as plt


# Metrics plotted against the token budget. Each entry is a tuple of
#   (dotted key resolved inside one budget's summary via nested_get,
#    display label,
#    direction: "up" or "down", which selects the arrow appended to the label).
# NOTE(review): the direction presumably marks whether higher or lower values
# are preferable -- confirm with the survey definition.
METRICS = [
    ("coverage_rate", "Budget Eligibility@B", "up"),
    ("distinct_n.2", "Distinct-2@B", "up"),
    ("distinct_n.3", "Distinct-3@B", "up"),
    ("ngram_top_k_mass.2", "Top-100 Bigram Mass@B", "down"),
    ("ngram_top_k_mass.3", "Top-100 Trigram Mass@B", "down"),
    ("violation_rate", "Violation Rate@B", "down"),
    ("repeated_4gram_rate", "Repeated 4-gram Rate@B", "down"),
]


def parse_args() -> argparse.Namespace:
    """Parse the command line of the plotting script.

    Options:
        --input: survey JSON path; required, may be given several times
            (values are collected into a list).
        --output-dir: required directory for the generated PNG plots.
        --long-coverage-threshold: float, defaults to 0.5.

    Returns:
        The namespace produced by argparse from ``sys.argv``.
    """
    cli = argparse.ArgumentParser(description="Plot caption survey budget curves")
    option_table = (
        (
            "--input",
            {"action": "append", "required": True, "help": "Survey JSON path (repeatable)"},
        ),
        (
            "--output-dir",
            {"required": True, "help": "Directory for output PNG plots"},
        ),
        (
            "--long-coverage-threshold",
            {
                "type": float,
                "default": 0.5,
                "help": "budget-eligibility@64 threshold used to split long vs short regimes",
            },
        ),
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    return cli.parse_args()


def nested_get(mapping: dict[str, Any], path: str) -> float | None:
    """Resolve a dot-separated *path* inside nested dicts and return a float.

    Args:
        mapping: Outermost container to search.
        path: Keys joined by ".", e.g. ``"distinct_n.2"``.

    Returns:
        The leaf converted to ``float`` when it is an ``int`` or ``float``;
        ``None`` when a key is missing, an intermediate value is not a dict,
        or the leaf has any other type.
    """
    node: Any = mapping
    for segment in path.split("."):
        if isinstance(node, dict) and segment in node:
            node = node[segment]
        else:
            return None
    if isinstance(node, (int, float)):
        return float(node)
    return None


def _integer_budgets(length_controlled: dict[str, Any]) -> list[int]:
    """Return the keys of *length_controlled* that parse as integers, ascending."""
    budgets: list[int] = []
    for key in length_controlled:
        try:
            budgets.append(int(key))
        except (TypeError, ValueError):
            # A key that is not an integer budget cannot be placed on the
            # budget axis; skip it instead of aborting the whole load.
            continue
    budgets.sort()
    return budgets


def _measurement_fields(length_controlled: Any, full_reference: Any) -> dict[str, Any] | None:
    """Build the measurement part of a row, shared by both payload layouts.

    Returns ``None`` when *length_controlled* is not a dict or has no integer
    budget keys, which tells the caller to skip the record.
    """
    if not isinstance(length_controlled, dict):
        return None
    budgets = _integer_budgets(length_controlled)
    if not budgets:
        return None
    # Missing or non-numeric coverage at budget 64 counts as 0.0.
    coverage64 = nested_get(length_controlled.get("64", {}), "coverage_rate") or 0.0
    full = full_reference if isinstance(full_reference, dict) else {}
    # Prefer "avg_tokens"; fall back to "avg_lexical_tokens"; a JSON null is
    # treated like a missing value rather than passed to float().
    avg_tokens = full.get("avg_tokens")
    if avg_tokens is None:
        avg_tokens = full.get("avg_lexical_tokens")
    return {
        "avg_tokens": float(avg_tokens) if avg_tokens is not None else 0.0,
        "coverage64": float(coverage64),
        "budgets": budgets,
        "length_controlled": length_controlled,
    }


def load_rows(paths: list[str]) -> list[dict[str, Any]]:
    """Load plot rows from survey JSON files.

    Two payload layouts are recognised:

    * a dict with a ``"results"`` list, where each item carries a
      ``"summary"`` (or ``"survey_summary"``) dict and an optional
      ``"entry"`` dict with naming metadata;
    * a dict that is itself a summary, i.e. has ``"length_controlled"`` at
      the top level. Its name is the file stem without a trailing ``_1m`` /
      ``_50k``, and its family is guessed from the name.

    Records without usable ``length_controlled`` data, and payloads or items
    that are not dicts, are skipped.

    Args:
        paths: Paths of the JSON files to read (UTF-8).

    Returns:
        One dict per usable record with the keys ``name``, ``family``,
        ``group``, ``description``, ``captioner``, ``avg_tokens``,
        ``coverage64``, ``budgets`` (sorted ints) and ``length_controlled``.

    Raises:
        OSError: If a file cannot be read.
        json.JSONDecodeError: If a file is not valid JSON.
        ValueError, TypeError: If an average-token value cannot be converted
            to ``float``.
    """
    rows: list[dict[str, Any]] = []
    for raw_path in paths:
        source = Path(raw_path)
        payload = json.loads(source.read_text(encoding="utf-8"))
        if not isinstance(payload, dict):
            continue

        if "results" in payload:
            for item in payload.get("results") or []:
                if not isinstance(item, dict):
                    continue
                summary = item.get("summary") or item.get("survey_summary")
                if not isinstance(summary, dict):
                    continue
                measurements = _measurement_fields(
                    summary.get("length_controlled"),
                    summary.get("full_length_reference"),
                )
                if measurements is None:
                    continue
                entry = item.get("entry")
                if not isinstance(entry, dict):
                    entry = {}
                rows.append(
                    {
                        "name": entry.get("name", source.stem),
                        "family": entry.get("source_family", "unknown"),
                        "group": entry.get("group", "unknown"),
                        "description": entry.get("description", ""),
                        "captioner": entry.get("captioner", ""),
                        **measurements,
                    }
                )
            continue

        if "length_controlled" in payload:
            measurements = _measurement_fields(
                payload.get("length_controlled"),
                payload.get("full_length_reference"),
            )
            if measurements is None:
                continue
            name = source.stem.removesuffix("_1m").removesuffix("_50k")
            if "datacomp" in name:
                family = "datacomp"
            elif "pd12m" in name:
                family = "pd12m"
            else:
                family = "unknown"
            rows.append(
                {
                    "name": name,
                    "family": family,
                    "group": "direct_summary",
                    "description": "",
                    "captioner": "",
                    **measurements,
                }
            )
    return rows


def label_for_row(row: dict[str, Any]) -> str:
    """Return the legend label for *row*.

    A leading ``ours_`` becomes ``ours:`` and a leading ``ref_`` becomes
    ``ref:``; any other name is used unchanged. The row named
    ``ref_cc12m_qwen3vl8b`` additionally gets a trailing dagger.
    """
    raw_name = row["name"]
    text = raw_name
    for prefix, tag in (("ours_", "ours:"), ("ref_", "ref:")):
        if raw_name.startswith(prefix):
            text = tag + raw_name[len(prefix):]
            break
    dagger = "†" if raw_name == "ref_cc12m_qwen3vl8b" else ""
    return text + dagger


def decorate_metric_label(metric_label: str, direction: str) -> str:
    """Return *metric_label* followed by a space and a direction arrow.

    The arrow points up when *direction* equals ``"up"`` and down for every
    other value.
    """
    arrows = {True: "↑", False: "↓"}
    arrow = arrows[direction == "up"]
    return f"{metric_label} {arrow}"


def style_for_row(row: dict[str, Any]) -> dict[str, Any]:
    """Return the line-style keyword arguments used to plot *row*.

    Rows whose name starts with ``ours_`` get a thicker, more opaque solid
    line; all other rows get a thinner dashed line.
    """
    is_ours = row["name"].startswith("ours_")
    width, opacity, dash = (2.8, 0.95, "-") if is_ours else (1.6, 0.85, "--")
    return {"linewidth": width, "alpha": opacity, "linestyle": dash}


def series_for_metric(row: dict[str, Any], metric_key: str) -> tuple[list[int], list[float]]:
    """Collect the (budget, value) points of one metric for *row*.

    For every budget in ``row["budgets"]`` the metric is looked up with
    ``nested_get`` in ``row["length_controlled"][str(budget)]``; budgets for
    which the lookup yields ``None`` are left out.

    Returns:
        Two parallel lists: the budgets that have a value, and those values.
    """
    candidates = (
        (budget, nested_get(row["length_controlled"].get(str(budget), {}), metric_key))
        for budget in row["budgets"]
    )
    kept = [(budget, value) for budget, value in candidates if value is not None]
    budgets = [budget for budget, _ in kept]
    values = [value for _, value in kept]
    return budgets, values


def save_metric_plot(
    rows: list[dict[str, Any]],
    metric_key: str,
    metric_label: str,
    direction: str,
    regime_name: str,
    output_path: Path,
) -> None:
    """Plot one metric against the token budget for *rows* and save it as PNG.

    Rows are drawn in (family, name) order; a row with no value for the
    metric at any budget is skipped. The parent directory of *output_path*
    is created if needed.

    Args:
        rows: Rows as produced by ``load_rows``.
        metric_key: Dotted key of the metric inside a budget summary.
        metric_label: Human-readable metric name.
        direction: ``"up"`` or ``"down"``; selects the arrow in the labels.
        regime_name: Text shown in parentheses in the title.
        output_path: Destination file.
    """
    fig, ax = plt.subplots(figsize=(10.5, 6.2))
    try:
        for row in sorted(rows, key=lambda item: (item["family"], item["name"])):
            xs, ys = series_for_metric(row, metric_key)
            if not xs:
                continue
            ax.plot(xs, ys, marker="o", label=label_for_row(row), **style_for_row(row))

        decorated_label = decorate_metric_label(metric_label, direction)
        ax.set_title(f"{decorated_label} by Budget ({regime_name})")
        ax.set_xlabel("Token Budget")
        ax.set_ylabel(decorated_label)
        ax.set_xticks(sorted({budget for row in rows for budget in row["budgets"]}))
        ax.grid(True, alpha=0.25)
        # Only draw a legend when at least one series was plotted, matching
        # the guard used for the family figures.
        if ax.get_legend_handles_labels()[0]:
            ax.legend(fontsize=8, ncol=2)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        fig.tight_layout()
        fig.savefig(output_path, dpi=180)
    finally:
        # Release the figure even when layout or saving fails.
        plt.close(fig)


def save_family_plot(rows: list[dict[str, Any]], family: str, output_path: Path) -> None:
    """Save a 2x3 grid of budget curves for every row of one family.

    The six panels show the first six entries of ``METRICS``. Each row keeps
    the same colour in every panel, and the shared legend lists every row
    that was drawn in at least one panel. Nothing is written when no row
    belongs to *family*. The parent directory of *output_path* is created if
    needed.

    Args:
        rows: Rows as produced by ``load_rows``.
        family: Value of the ``"family"`` field to select.
        output_path: Destination PNG file.
    """
    # Sorting once fixes each row's position, which drives colour and legend order.
    family_rows = sorted(
        (row for row in rows if row["family"] == family),
        key=lambda item: item["name"],
    )
    if not family_rows:
        return
    fig, axes = plt.subplots(2, 3, figsize=(14, 8.5))
    try:
        # First line drawn for each row position, on whichever panel it appears.
        legend_handles: dict[int, Any] = {}
        for axis, (metric_key, metric_label, direction) in zip(axes.flatten(), METRICS[:6], strict=False):
            for position, row in enumerate(family_rows):
                xs, ys = series_for_metric(row, metric_key)
                if not xs:
                    continue
                # "C<n>" indexes the colour cycle, so the colour stays tied to
                # the row even when another row is skipped on this panel.
                (line,) = axis.plot(
                    xs,
                    ys,
                    marker="o",
                    label=label_for_row(row),
                    color=f"C{position}",
                    **style_for_row(row),
                )
                legend_handles.setdefault(position, line)
            axis.set_title(decorate_metric_label(metric_label, direction))
            axis.set_xlabel("Budget")
            axis.grid(True, alpha=0.25)
        if legend_handles:
            ordered = sorted(legend_handles)
            fig.legend(
                [legend_handles[position] for position in ordered],
                [label_for_row(family_rows[position]) for position in ordered],
                loc="lower center",
                ncol=2,
                fontsize=8,
            )
        fig.suptitle(f"{family} Budget Curves", y=0.98)
        fig.tight_layout(rect=(0, 0.05, 1, 0.96))
        output_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(output_path, dpi=180)
    finally:
        # Release the figure even when layout or saving fails.
        plt.close(fig)


def main() -> int:
    """Run the plotting pipeline end to end.

    Loads the rows, splits them into a long and a short regime by comparing
    ``coverage64`` with the CLI threshold, writes one overview plot per
    metric and regime plus one plot per family, then writes and prints a
    JSON manifest describing the run.

    Returns:
        0 on success.

    Raises:
        SystemExit: If no rows could be loaded from the inputs.
    """
    args = parse_args()
    rows = load_rows(args.input)
    if not rows:
        raise SystemExit("No survey rows loaded")

    output_dir = Path(args.output_dir)
    threshold = args.long_coverage_threshold
    long_rows: list[dict[str, Any]] = []
    short_rows: list[dict[str, Any]] = []
    for row in rows:
        if row["coverage64"] >= threshold:
            long_rows.append(row)
        elif row["coverage64"] < threshold:
            short_rows.append(row)

    # (sub-directory, title text, rows) for each regime, long first.
    regimes = (
        ("long", "long-regime", long_rows),
        ("short", "short-regime", short_rows),
    )
    for metric_key, metric_label, direction in METRICS:
        for folder, regime_name, regime_rows in regimes:
            if not regime_rows:
                continue
            save_metric_plot(
                regime_rows,
                metric_key,
                metric_label,
                direction,
                regime_name,
                output_dir / "overview" / folder / f"{metric_key.replace('.', '_')}.png",
            )

    families = sorted({row["family"] for row in rows})
    for family in families:
        save_family_plot(rows, family, output_dir / "families" / f"{family}.png")

    manifest = {
        "inputs": args.input,
        "output_dir": str(output_dir),
        "long_coverage_threshold": args.long_coverage_threshold,
        "rows_loaded": len(rows),
        "long_rows": [row["name"] for row in long_rows],
        "short_rows": [row["name"] for row in short_rows],
        "metrics": [key for key, _, _ in METRICS],
    }
    manifest_text = json.dumps(manifest, indent=2, ensure_ascii=False)
    manifest_path = output_dir / "plot_manifest.json"
    manifest_path.write_text(manifest_text + "\n", encoding="utf-8")
    print(manifest_text)
    return 0


# When run as a script, exit the process with main()'s return value as status.
if __name__ == "__main__":
    raise SystemExit(main())