File size: 5,258 Bytes
08ff31f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
#!/usr/bin/env python3
"""Aggregate LIBERO eval JSONs into a per-(speed, suite) CSV.

Layout expected:
    <results_dir>/speed_<tag>/{spatial,goal,object,long_t*_*}_<tag>.json

For each speed_<tag> subdirectory, computes per-suite (spatial / goal / object /
long) and overall success rate, mean_steps_success and mean_steps_all by
concatenating the `episodes` lists from the relevant json files. The five
`long_t*_*.json` shards are merged into a single `long` row.
"""

from __future__ import annotations

import argparse
import csv
import json
import math
from pathlib import Path

# Maps each reported suite name to the filename labels (file stem minus the
# trailing speed tag) whose episodes are pooled into that suite's row.
SUITE_LABELS = {
    # Single-file suites: the label is the suite name itself.
    **{name: [name] for name in ("spatial", "goal", "object")},
    # Five shard labels: long_t0_1, long_t2_3, long_t4_5, long_t6_7, long_t8_9.
    "long": [f"long_t{first}_{first + 1}" for first in range(0, 10, 2)],
}


def _label_from_stem(stem: str) -> str:
    # filename like "long_t0_1_0p75x" or "spatial_1x" -> drop the trailing speed tag
    return "_".join(stem.split("_")[:-1])


def _aggregate(episodes: list[dict]) -> dict:
    n = len(episodes)
    succ = [e for e in episodes if e.get("success")]
    steps_all = [e["steps"] for e in episodes]
    steps_succ = [e["steps"] for e in succ]
    return {
        "n_episodes": n,
        "n_success": len(succ),
        "success_rate": len(succ) / n if n else math.nan,
        "mean_steps_success": (sum(steps_succ) / len(steps_succ)) if steps_succ else math.nan,
        "mean_steps_all": (sum(steps_all) / n) if n else math.nan,
    }


def _collect_speed_dir(speed_dir: Path) -> dict[str, dict]:
    """Aggregate every ``*.json`` file in one ``speed_<tag>`` directory.

    Each file's ``"episodes"`` list (empty when the key is absent) is stored
    under the label derived from its stem. The labels listed for a suite in
    ``SUITE_LABELS`` are then pooled and aggregated. Suites with no episodes
    are skipped with a warning; suites with only some shards present are
    aggregated from what exists, also with a warning.

    Returns ``{suite_name: aggregate_dict}``, plus an ``"overall"`` entry
    pooling every included suite when at least one episode was found.
    """
    episodes_by_label: dict[str, list[dict]] = {}
    for json_path in sorted(speed_dir.glob("*.json")):
        shard_label = _label_from_stem(json_path.stem)
        with json_path.open() as handle:
            payload = json.load(handle)
        # A later file with the same label replaces the earlier one.
        episodes_by_label[shard_label] = payload.get("episodes", [])

    summary: dict[str, dict] = {}
    pooled: list[dict] = []
    for suite, shard_labels in SUITE_LABELS.items():
        absent = [lbl for lbl in shard_labels if lbl not in episodes_by_label]
        suite_episodes: list[dict] = []
        for lbl in shard_labels:
            suite_episodes += episodes_by_label.get(lbl, [])

        if suite_episodes:
            if absent:
                print(f"  [warn] {speed_dir.name}: suite={suite} missing shards {absent}")
            summary[suite] = _aggregate(suite_episodes)
            pooled += suite_episodes
        else:
            print(f"  [warn] {speed_dir.name}: no episodes for suite={suite} (missing={absent})")

    if pooled:
        summary["overall"] = _aggregate(pooled)
    return summary


def _speed_from_dirname(name: str) -> str:
    # "speed_0p75x" -> "0.75", "speed_1x" -> "1.0"
    tag = name.removeprefix("speed_").removesuffix("x")
    return tag.replace("p", ".") if "p" in tag else f"{float(tag):.1f}"


def main() -> None:
    """Command-line entry point.

    Parses ``results_dir`` and the optional ``-o/--output`` path, aggregates
    every ``speed_*`` subdirectory of ``results_dir``, writes one CSV row per
    (speed, suite) pair, and echoes the same rows as a console table. Exits
    via ``ArgumentParser.error`` when ``results_dir`` is not a directory or
    contains no ``speed_*`` subdirectories.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "results_dir",
        type=Path,
        help="Directory containing speed_<tag>/ subdirectories with eval JSONs.",
    )
    parser.add_argument(
        "-o",
        "--output",
        type=Path,
        default=None,
        help="Output CSV path (default: <results_dir>/eval_summary.csv)",
    )
    opts = parser.parse_args()
    results_dir = opts.results_dir

    if not results_dir.is_dir():
        parser.error(f"results_dir does not exist: {results_dir}")

    speed_dirs = sorted(p for p in results_dir.glob("speed_*") if p.is_dir())
    if not speed_dirs:
        parser.error(f"no speed_*/ subdirectories under {results_dir}")

    csv_path = opts.output or (results_dir / "eval_summary.csv")

    columns = [
        "speed",
        "speed_tag",
        "suite",
        "n_episodes",
        "n_success",
        "success_rate",
        "mean_steps_success",
        "mean_steps_all",
    ]
    # Fixed per-speed row order: the configured suites, then the pooled row.
    suite_sequence = [*SUITE_LABELS, "overall"]

    records: list[dict] = []
    for speed_dir in speed_dirs:
        speed_label = _speed_from_dirname(speed_dir.name)
        per_suite = _collect_speed_dir(speed_dir)
        raw_tag = speed_dir.name.removeprefix("speed_")
        for suite in suite_sequence:
            if suite in per_suite:
                stats = per_suite[suite]
                values = (
                    speed_label,
                    raw_tag,
                    suite,
                    stats["n_episodes"],
                    stats["n_success"],
                    round(stats["success_rate"], 4),
                    round(stats["mean_steps_success"], 2),
                    round(stats["mean_steps_all"], 2),
                )
                records.append(dict(zip(columns, values)))

    with csv_path.open("w", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=columns)
        writer.writeheader()
        writer.writerows(records)

    print(f"\nWrote {len(records)} rows -> {csv_path}")

    # Console echo of the same rows, as a fixed-width table.
    print()
    header = f"{'speed':<6} {'suite':<8} {'success':>10} {'sr':>7} {'steps_succ':>12} {'steps_all':>11}"
    print(header)
    for rec in records:
        cells = (
            f"{rec['speed']:<6}",
            f"{rec['suite']:<8}",
            f"{rec['n_success']:>4}/{rec['n_episodes']:<5}",
            f"{rec['success_rate']*100:>6.1f}%",
            f"{rec['mean_steps_success']:>12.1f}",
            f"{rec['mean_steps_all']:>11.1f}",
        )
        print(" ".join(cells))


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()