"""
Plot latest Spider benchmark results.

Outputs in the latest folder under benchmarks/results_pro/:
- metrics_overview.png: EM/SM/ExecAcc + latency (avg, p50, p95)
- latency_per_stage.png: bar of average per-stage latency
- latency_histogram.png: latency distribution across samples
"""

from __future__ import annotations

import json
from collections import Counter
from pathlib import Path

import matplotlib.pyplot as plt

ROOT = Path("benchmarks/results_pro")


def _latest_run_dir() -> Path:
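    """Return the directory of the most recent run, by summary.json mtime."""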
    summaries = sorted(
        ROOT.glob("*/summary.json"), key=lambda p: p.stat().st_mtime, reverse=True
    )
    if not summaries:
        raise SystemExit("❌ No benchmark results found under benchmarks/results_pro/")
    return summaries[0].parent


def _load_summary(run: Path) -> dict:
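    """Load aggregate metrics from the run's summary.json."""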
    return json.loads((run / "summary.json").read_text(encoding="utf-8"))


def _load_eval_rows(run: Path) -> list[dict]:
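    """Load per-sample evaluation rows from eval.jsonl (one JSON object per line)."""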
    lines = (run / "eval.jsonl").read_text(encoding="utf-8").splitlines()
    return [json.loads(x) for x in lines]


def plot_metrics_overview(run: Path, summary: dict) -> None:
    """EM/SM/ExecAcc on [0, 1], plus latency stats shown in seconds for scale."""
    labels = ["EM", "SM", "ExecAcc", "avg(s)", "p50(s)", "p95(s)"]
    values = [
        summary.get("EM", 0.0),
        summary.get("SM", 0.0),
        summary.get("ExecAcc", 0.0),
        summary.get("avg_latency_ms", 0.0) / 1000.0,
        summary.get("p50_latency_ms", 0.0) / 1000.0,
        summary.get("p95_latency_ms", 0.0) / 1000.0,
    ]

    plt.figure(figsize=(9, 5))
    bars = plt.bar(labels, values)
    for b, v in zip(bars, values):
        plt.text(b.get_x() + b.get_width() / 2, v, f"{v:.2f}", ha="center", va="bottom")
    plt.title("Metrics Overview (Spider)")
    plt.ylim(0, max(1.0, max(values) * 1.15))  # values is never empty here
    plt.tight_layout()
    plt.savefig(run / "metrics_overview.png")
    plt.close()


def plot_latency_hist(run: Path, rows: list[dict]) -> None:
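    """Histogram of per-sample end-to-end latency in milliseconds."""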
    latencies = [
        r.get("latency_ms", 0)
        for r in rows
        if isinstance(r.get("latency_ms"), (int, float))
    ]
    if not latencies:
        return
    plt.figure(figsize=(9, 4))
    plt.hist(latencies, bins=min(20, max(5, int(len(latencies) ** 0.5))))
    plt.title("Latency Distribution (ms)")
    plt.xlabel("Latency (ms)")
    plt.ylabel("Count")
    plt.tight_layout()
    plt.savefig(run / "latency_histogram.png")
    plt.close()


def plot_latency_per_stage(run: Path, summary: dict, rows: list[dict]) -> None:
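    """Average per-stage latency (ms), from summary aggregates or per-row traces."""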
    stages = [
        "detector",
        "planner",
        "generator",
        "safety",
        "executor",
        "verifier",
        "repair",
    ]
    # prefer summary keys if available; else derive from traces
    raw_values = [summary.get(f"{s}_avg_ms") for s in stages]
    # convert Nones to 0.0
    values: list[float] = [float(v or 0.0) for v in raw_values]

    if not any(values):
        totals = {s: 0.0 for s in stages}
        counts = {s: 0 for s in stages}
        for r in rows:
            trace = r.get("trace") or r.get("traces") or []
            for t in trace:
                s = t.get("stage")
                if s in totals:
                    ms = t.get("ms", t.get("duration_ms", 0.0))
                    try:
                        totals[s] += float(ms)
                        counts[s] += 1
                    except (TypeError, ValueError):  # non-numeric ms value
                        pass
        values = [round(totals[s] / counts[s], 2) if counts[s] else 0.0 for s in stages]

    plt.figure(figsize=(10, 5))
    bars = plt.bar(stages, values)
    for b, v in zip(bars, values):
        plt.text(
            b.get_x() + b.get_width() / 2,
            float(v),
            f"{v:.1f}",
            ha="center",
            va="bottom",
        )
    plt.title("Average Latency per Stage (ms)")
    plt.xlabel("Stage")
    plt.ylabel("Latency (ms)")
    plt.tight_layout()
    plt.savefig(run / "latency_per_stage.png")
    plt.close()


def plot_errors_overview(run: Path) -> None:
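    """Bar chart of error counts by error_type, read from trace.jsonl when present."""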
    p = run / "trace.jsonl"
    if not p.exists():
        return
    counts: Counter[str] = Counter()
    with p.open("r", encoding="utf-8") as f:
        for line in f:
            try:
                obj = json.loads(line)
            except json.JSONDecodeError:  # skip malformed lines
                continue
            for t in obj.get("trace", []):
                et = t.get("error_type")
                if et:
                    counts[et] += 1
    if not counts:
        return
    labels, values = zip(*sorted(counts.items(), key=lambda x: x[1], reverse=True))
    plt.figure(figsize=(9, 4))
    plt.bar(labels, values)
    plt.title("Errors by Type")
    plt.ylabel("Count")
    plt.tight_layout()
    plt.savefig(run / "errors_overview.png")
    plt.close()


def main() -> None:
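    """Locate the latest run, load its outputs, and write all plots."""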
    run = _latest_run_dir()
    print(f"πŸ“‚ Using latest run: {run.name}")
    summary = _load_summary(run)
    rows = _load_eval_rows(run)
    plot_metrics_overview(run, summary)
    plot_latency_hist(run, rows)
    plot_latency_per_stage(run, summary, rows)
    plot_errors_overview(run)
    print(
        "✅ Saved: metrics_overview.png, latency_per_stage.png, and (when data "
        "is available) latency_histogram.png, errors_overview.png"
    )


if __name__ == "__main__":
    main()