| |
| """Run a reference StateShiftBench evaluation.""" |
|
|
| import argparse |
| import json |
| import sys |
| from pathlib import Path |
|
|
| sys.path.insert(0, str(Path(__file__).resolve().parents[1])) |
|
|
| from stateshiftbench.runner import run_case |
| from stateshiftbench.schemas import validate_case |
|
|
|
|
def main(argv=None):
    """Run every case in a data directory and write the episodes as JSON Lines.

    Each ``*.json`` file in ``--data`` is processed in sorted file-name order,
    optionally truncated to the first ``--limit`` files. For each file, the
    JSON is parsed, checked with ``validate_case``, and passed to ``run_case``
    with the selected ``--strategy``. Each returned episode is written to
    ``--output`` as one JSON object per line, with keys sorted. The output's
    parent directories are created if needed.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) keeps the normal command-line behavior.

    Raises:
        SystemExit: On invalid arguments, including a ``--data`` path that is
            not a directory or a negative ``--limit``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--data", required=True, help="Directory containing case JSON files.")
    parser.add_argument("--strategy", default="direct", choices=["direct", "stact_reference"])
    parser.add_argument("--limit", type=int, default=None)
    parser.add_argument("--output", required=True)
    args = parser.parse_args(argv)

    data_dir = Path(args.data)
    # glob() on a missing directory silently yields nothing. Without this
    # check, a typo in --data would write an empty file and look like success.
    if not data_dir.is_dir():
        parser.error(f"--data is not a directory: {data_dir}")
    # A negative limit would slice from the end (paths[:-k]) and silently drop cases.
    if args.limit is not None and args.limit < 0:
        parser.error(f"--limit must be non-negative, got {args.limit}")

    paths = sorted(data_dir.glob("*.json"))
    if args.limit is not None:
        paths = paths[: args.limit]

    output = Path(args.output)
    output.parent.mkdir(parents=True, exist_ok=True)
    # Explicit UTF-8 so reading cases and writing results do not depend on
    # the platform's locale encoding.
    with output.open("w", encoding="utf-8") as f:
        for path in paths:
            try:
                case = json.loads(path.read_text(encoding="utf-8"))
                validate_case(case)
                episode = run_case(case, strategy=args.strategy)
            except Exception:
                # Name the failing case file. The original exception then
                # propagates unchanged.
                print(f"error while processing {path}", file=sys.stderr)
                raise
            f.write(json.dumps(episode, sort_keys=True) + "\n")
    print(f"wrote {len(paths)} episodes to {output}")
|
|
|
|
if __name__ == "__main__":
    # Script entry point. main() returns None, so the process exits with
    # status 0 on success; any exception from main() propagates as usual.
    raise SystemExit(main())
|
|