File size: 3,731 Bytes
2733f3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
"""Run the scripted-optimal baseline across all 12 templates × 5 procgen variants
and print a summary table. This is the smoke-check that the env is healthy and
the baseline ceiling is preserved.

Usage:
    python scripts/eval_baseline.py
    python scripts/eval_baseline.py --templates-only
    python scripts/eval_baseline.py --episodes-per-scenario 3
    python scripts/eval_baseline.py --output eval/results/baseline.jsonl
"""

from __future__ import annotations

import argparse
import json
from pathlib import Path
from statistics import mean

from unified_incident_env.models import UnifiedIncidentAction
from unified_incident_env.server.challenge import (
    SCENARIOS,
    list_baselines,
)
from unified_incident_env.server.environment import UnifiedIncidentEnvironment


def run_one(scenario_id: str) -> dict:
    """Replay the first listed baseline for ``scenario_id`` and summarise it.

    A fresh environment is reset to the scenario, then the baseline's actions
    are stepped in order until they run out or the environment reports
    ``done``.

    Args:
        scenario_id: Key into ``SCENARIOS`` identifying the scenario to run.

    Returns:
        A dict with the scenario id, its template id (falling back to the
        scenario id when the scenario has no ``template_id``), the procgen
        flag, and the final score, resolved flag, tick count and score
        breakdown read from the last observation.
    """
    environment = UnifiedIncidentEnvironment()
    observation = environment.reset(scenario_id=scenario_id)

    # Only the first baseline returned for the scenario is replayed.
    script = list_baselines(scenario_id=scenario_id).baselines[0]
    for scripted_step in script.actions:
        observation = environment.step(scripted_step.action)
        if observation.done:
            break

    # Scenario metadata is read after the replay, as a single lookup.
    scenario_meta = SCENARIOS[scenario_id]
    template_id = scenario_meta.get("template_id", scenario_id)
    is_procgen = scenario_meta.get("is_procgen", False)

    # Key order is kept stable because each summary becomes one JSONL row.
    summary = {
        "scenario_id": scenario_id,
        "template_id": template_id,
        "is_procgen": is_procgen,
        "final_score": float(observation.final_score),
        "incident_resolved": bool(observation.incident_resolved),
        "tick_count": int(observation.tick_count),
        "breakdown": dict(observation.score_breakdown),
    }
    return summary


def main() -> None:
    """Run the scripted baseline over the selected scenarios and report scores.

    Parses the command line, replays every selected scenario through
    ``run_one`` (``--episodes-per-scenario`` times each), prints a
    per-scenario table, a per-template mean/min/max table and an overall
    summary, and optionally writes one JSON object per episode to ``--output``.

    Raises:
        SystemExit: with status 2 (via ``parser.error``) when
            ``--episodes-per-scenario`` is below 1, and with a non-zero status
            when no scenario is selected. In both cases the result set would
            be empty and the summary statistics undefined.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--templates-only", action="store_true",
                        help="Run only the 12 base templates, skip procgen variants.")
    parser.add_argument("--episodes-per-scenario", type=int, default=1,
                        help="Number of times to run each scenario (deterministic, so default 1).")
    parser.add_argument("--output", type=str, default=None,
                        help="Optional JSONL output path.")
    args = parser.parse_args()

    # Reject counts that would produce no episodes: statistics.mean() raises
    # StatisticsError on empty data, which used to surface as a traceback.
    if args.episodes_per_scenario < 1:
        parser.error("--episodes-per-scenario must be >= 1")

    if args.templates_only:
        scenario_ids = sorted(sid for sid, sc in SCENARIOS.items() if not sc.get("is_procgen"))
    else:
        scenario_ids = sorted(SCENARIOS)

    results = [
        run_one(sid)
        for sid in scenario_ids
        for _ in range(args.episodes_per_scenario)
    ]

    # A smoke check that ran nothing must not look like a pass.
    if not results:
        raise SystemExit("No scenarios selected; nothing to evaluate.")

    print(f"\n{'scenario':<40} {'score':>7}  {'resolved':>9}  {'ticks':>5}")
    print("-" * 70)
    for r in results:
        flag = "OK" if r["incident_resolved"] else "X"
        print(f"{r['scenario_id']:<40} {r['final_score']:>7.3f}  {flag:>9}  {r['tick_count']:>5}")

    print()
    # Group the final scores by template so procgen variants roll up together.
    by_template: dict[str, list[float]] = {}
    for r in results:
        by_template.setdefault(r["template_id"], []).append(r["final_score"])
    print(f"{'template':<40} {'mean':>7}  {'min':>7}  {'max':>7}  {'n':>3}")
    print("-" * 70)
    for tid, scores in sorted(by_template.items()):
        print(f"{tid:<40} {mean(scores):>7.3f}  {min(scores):>7.3f}  {max(scores):>7.3f}  {len(scores):>3}")

    overall_mean = mean(r["final_score"] for r in results)
    overall_resolved = sum(r["incident_resolved"] for r in results)
    print(f"\nOverall: mean={overall_mean:.3f}, resolved={overall_resolved}/{len(results)}")
    if overall_mean > 0.80:
        print("WARNING: scripted baseline ceiling exceeded 0.80 — see docs/REWARD_DESIGN.md §4")

    if args.output:
        out = Path(args.output)
        out.parent.mkdir(parents=True, exist_ok=True)
        # Explicit encoding keeps the JSONL file identical across platforms.
        with out.open("w", encoding="utf-8") as f:
            f.writelines(json.dumps(r) + "\n" for r in results)
        print(f"\nWrote {len(results)} rows -> {out}")

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()