OpenRA-Bench / scripts /audit_scenarios.py
yxc20098's picture
Defect-fix wave: restore the no-cheat bar across 10 packs
e08e6a1
Raw
History Blame Contribute Delete
7.9 kB
"""Re-audit every active pack against the no-cheat bar.
The bench's central authoring rule is "every lazy / brute / stall
policy must LOSE on every level"; a pack where the `stall` (observe-
only) policy WINS at any level is defective ("laziest-play-wins").
This script re-runs the static bar across all 200 packs after the
fact — engine fixes since authoring can have drifted things, so the
audit catches benchmark rot.
Optionally adds an EMPIRICAL layer: from one or more
`run_eval --out` report JSONs, flags packs where every model wins
(too easy / a trivial idiom dominates) or every model loses
(unsolvable / a load-bearing predicate is mis-tuned).
Run from the repo root:
python scripts/audit_scenarios.py [report.json ...]
"""
from __future__ import annotations
import json
import sys
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
# Project-relative imports.
ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(ROOT))
from openra_bench.eval_core import run_level # noqa: E402
from openra_bench.scenarios import load_pack # noqa: E402
from openra_bench.scenarios.loader import PACKS_DIR, compile_level # noqa: E402
# Difficulty levels compiled and run for every active pack.
LEVELS = ("easy", "medium", "hard")
# Seeds passed to `run_level` for each (pack, level) cell; one episode
# is run per seed.
SEEDS = (1,)
# Packs where "do nothing" is the INTENDED winning policy by design —
# the audit's stall-wins-is-a-defect rule doesn't apply. Each pack
# here is justified in its own test file's docstring.
STALL_WINS_BY_DESIGN: frozenset[str] = frozenset({
# Positional-discipline pack: pre-positioned `stance:2` flanker
# crossfire — the player's job is NOT to move the flankers, so
# `observe` IS the intended policy. See tests/test_def_with_ambush.py.
"def-with-ambush",
})
def stall(_render_state, Command):
    """Observe-only policy used as the "laziest play" baseline.

    The rendered state is ignored; the policy always answers with a
    single-element list holding whatever ``Command.observe()`` returns.
    """
    observe_only = Command.observe()
    return [observe_only]
def _active_packs() -> list[Path]:
    """Return the pack YAML files the audit should cover, in sorted order.

    Scans ``PACKS_DIR`` for ``*.yaml`` files and leaves out:

    * files whose name starts with ``_`` or ``TEMPLATE``;
    * packs that ``load_pack`` cannot load — these are reported on
      stderr so a broken pack does not silently vanish from the audit;
    * packs whose ``meta.status`` is ``"quarantine"`` (a pack without a
      ``status`` attribute is treated as ``"active"``).
    """
    active: list[Path] = []
    for path in sorted(PACKS_DIR.glob("*.yaml")):
        if path.name.startswith(("_", "TEMPLATE")):
            continue
        try:
            pack = load_pack(path)
        except Exception as exc:  # noqa: BLE001
            # Stay best-effort (keep auditing the remaining packs), but
            # make the omission visible instead of swallowing it.
            print(f"[static] skipping {path.name}: could not load "
                  f"({type(exc).__name__}: {exc})", file=sys.stderr)
            continue
        if getattr(pack.meta, "status", "active") == "quarantine":
            continue
        active.append(path)
    return active
def _one(pack_path: Path, level: str, seed: int) -> dict:
"""Run stall on one cell; report whether it (defectively) won."""
try:
c = compile_level(load_pack(pack_path), level)
if not c.map_supported:
return {"pack": pack_path.stem, "level": level, "seed": seed,
"outcome": "skip", "reason": "map_not_loadable"}
res = run_level(c, stall, seed=seed)
return {"pack": pack_path.stem, "level": level, "seed": seed,
"outcome": res.outcome, "turns": res.turns}
except Exception as e: # noqa: BLE001
return {"pack": pack_path.stem, "level": level, "seed": seed,
"outcome": "error", "reason": f"{type(e).__name__}: {e}"[:120]}
def static_audit(concurrency: int = 8) -> list[dict]:
    """Run the stall policy on every active pack × level × seed.

    Episodes are submitted to a thread pool with ``concurrency`` workers.
    The per-cell records from ``_one`` are returned in completion order,
    and a progress line is written to stderr after every 50 finished
    episodes.
    """
    packs = _active_packs()
    cells = [
        (pack, level, seed)
        for pack in packs
        for level in LEVELS
        for seed in SEEDS
    ]
    total = len(cells)
    print(f"[static] {len(packs)} packs × {len(LEVELS)} levels × "
          f"{len(SEEDS)} seeds = {total} episodes", file=sys.stderr)

    records: list[dict] = []
    with ThreadPoolExecutor(max_workers=concurrency) as pool:
        pending = [pool.submit(_one, pack, level, seed)
                   for pack, level, seed in cells]
        finished = 0
        for future in as_completed(pending):
            records.append(future.result())
            finished += 1
            if finished % 50 == 0:
                print(f" {finished}/{total} done", file=sys.stderr)
    return records
def report_static(records: list[dict]) -> None:
    """Print the static-audit summary and the defect list to stdout.

    ``records`` are per-cell dicts as produced by ``static_audit``; each
    must carry ``pack``, ``level`` and ``outcome`` (``reason`` is
    optional).

    Per pack:

    * packs listed in ``STALL_WINS_BY_DESIGN`` are reported as exempt and
      are not judged;
    * a pack is "clean" when every one of its records is a ``loss``;
    * a ``win`` or ``draw`` is listed as a defect, while ``error`` and
      ``skip`` records are tallied separately.

    Every record is classified on its own, so when several seeds are run
    per level a winning seed cannot be masked by another seed of the
    same level. Identical entries are reported once. At most 20 errors
    are listed in detail.
    """
    by_pack: dict[str, list[dict]] = defaultdict(list)
    for rec in records:
        by_pack[rec["pack"]].append(rec)

    stall_wins: list[tuple[str, str]] = []     # (pack, level label)
    errors: list[tuple[str, str, str]] = []    # (pack, level, reason)
    skips: list[tuple[str, str, str]] = []     # (pack, level, reason)
    exempt_noted: list[str] = []
    clean_packs = 0

    def _note(bucket: list, entry: tuple) -> None:
        # Several seeds of one level can produce the same entry; keep one.
        if entry not in bucket:
            bucket.append(entry)

    for pack, rs in by_pack.items():
        if pack in STALL_WINS_BY_DESIGN:
            exempt_noted.append(pack)
            continue
        if all(r["outcome"] == "loss" for r in rs):
            clean_packs += 1
            continue
        # Walk every record rather than one-per-level so no seed is dropped.
        for r in rs:
            lv, outcome = r["level"], r["outcome"]
            if outcome == "win":
                _note(stall_wins, (pack, lv))
            elif outcome == "draw":
                _note(stall_wins, (pack, f"{lv}/DRAW"))
            elif outcome == "error":
                _note(errors, (pack, lv, r.get("reason", "")))
            elif outcome == "skip":
                _note(skips, (pack, lv, r.get("reason", "")))

    print("=" * 72)
    print(f"STATIC AUDIT — stall policy on {len(by_pack)} packs × {len(LEVELS)} levels")
    print("=" * 72)
    print(f" clean (all {len(LEVELS)} levels LOSS — intended) : {clean_packs}")
    print(f" defects (stall wins or draws) : {len(stall_wins)}")
    print(f" errors : {len(errors)}")
    print(f" skipped (map not loadable) : {len(skips)}")
    print(f" exempt (stall-wins-by-design) : {len(exempt_noted)}"
          f" {exempt_noted if exempt_noted else ''}")
    if stall_wins:
        print("\n--- DEFECTS — laziest play wins (re-author or retune) ---")
        for pack, lv in sorted(stall_wins):
            print(f" {pack} [{lv}]")
    if errors:
        print("\n--- ERRORS ---")
        for pack, lv, why in sorted(errors)[:20]:
            print(f" {pack} [{lv}] {why}")
        if len(errors) > 20:
            print(f" ...{len(errors) - 20} more")
def empirical_audit(report_paths: list[str]) -> None:
    """Flag cells that every model wins or every model loses.

    Each path in ``report_paths`` should be a ``run_eval`` JSON report: an
    object with an ``episodes`` list whose items carry ``cell`` and
    ``outcome``, and optionally a ``model`` name (the file stem is used
    when it is missing or empty). Only ``win`` / ``loss`` / ``draw``
    episodes are considered.

    A cell is judged only when every *successfully loaded, distinct*
    model has at least one episode on it:

    * "too easy"   — every recorded outcome for the cell is a win;
    * "unsolvable" — every recorded outcome for the cell is a loss.

    Reports that cannot be read or parsed are reported on stderr and
    excluded from the model count, so one bad path does not suppress
    every finding. The flags need at least two models to be meaningful;
    a note is written to stderr when fewer were loaded.
    """
    print()
    print("=" * 72)
    print(f"EMPIRICAL AUDIT — {len(report_paths)} model report(s)")
    print("=" * 72)

    # cell → list of (model, outcome)
    per_cell: dict[str, list[tuple[str, str]]] = defaultdict(list)
    models: set[str] = set()
    for path in report_paths:
        try:
            # json.loads detects the encoding itself when handed bytes.
            rep = json.loads(Path(path).read_bytes())
        except Exception as exc:  # noqa: BLE001
            print(f" could not load {path}: {exc}", file=sys.stderr)
            continue
        if not isinstance(rep, dict):
            print(f" could not load {path}: top-level JSON is not an object",
                  file=sys.stderr)
            continue
        model = rep.get("model") or Path(path).stem
        models.add(model)
        for episode in rep.get("episodes", []):
            cell = episode.get("cell")
            if cell is None:
                continue  # cannot attribute the episode to any cell
            if episode.get("outcome") in {"win", "loss", "draw"}:
                per_cell[cell].append((model, episode["outcome"]))

    # Count the models that actually loaded, not the paths given: a failed
    # or duplicated report must not make full coverage unreachable.
    n_models = len(models)
    if n_models < 2:
        print(f" note: only {n_models} distinct model(s) loaded; the "
              "too-easy/unsolvable flags need at least 2 to be meaningful",
              file=sys.stderr)

    covered = 0
    too_easy, unsolvable = [], []
    for cell, rs in per_cell.items():
        if len({m for m, _ in rs}) < n_models:
            continue  # not all models covered this cell
        covered += 1
        outs = [o for _, o in rs]
        if all(o == "win" for o in outs):
            too_easy.append(cell)
        elif all(o == "loss" for o in outs):
            unsolvable.append(cell)

    print(f" cells covered by ALL {n_models} models : {covered}")
    print(f" too easy (every model wins) : {len(too_easy)}")
    print(f" unsolvable (every model loses) : {len(unsolvable)}")
    if too_easy:
        print("\n--- TOO EASY (every model wins) ---")
        for c in sorted(too_easy):
            print(f" {c}")
    if unsolvable:
        print("\n--- UNSOLVABLE (every model loses) ---")
        for c in sorted(unsolvable):
            print(f" {c}")
def main(argv: list[str]) -> int:
    """Run the static audit, then the empirical audit if reports are given.

    ``argv`` is the full argument vector; everything after the program
    name is treated as a ``run_eval`` report path. Always returns 0.
    """
    report_static(static_audit())
    report_paths = argv[1:]
    if report_paths:
        empirical_audit(report_paths)
    return 0


if __name__ == "__main__":
    raise SystemExit(main(sys.argv))