File size: 1,241 Bytes
89e75e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
#!/usr/bin/env python3
"""Run a reference StateShiftBench evaluation."""

import argparse
import json
import sys
from pathlib import Path

# Put the parent of this script's directory at the front of sys.path before
# the imports below run. Presumably that directory holds the `stateshiftbench`
# package, so the script works without the package being installed -- confirm
# against the repository layout.
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

from stateshiftbench.runner import run_case
from stateshiftbench.schemas import validate_case


def main():
    """Run every case in a data directory and write the episodes as JSON Lines.

    Command-line arguments:
        --data: directory containing the case ``*.json`` files (required).
        --strategy: ``direct`` (default) or ``stact_reference``; forwarded to
            ``run_case`` as the ``strategy`` keyword.
        --limit: optional non-negative cap on the number of cases; the first
            N files in sorted path order are used.
        --output: path of the file to write (required); missing parent
            directories are created.

    Each case file is parsed as JSON, passed to ``validate_case`` (its return
    value is ignored; presumably it raises on an invalid case) and then to
    ``run_case``. The returned episode is written as one key-sorted JSON
    object per line, and a summary line is printed at the end.

    Exits with status 2 (via ``parser.error``) when ``--data`` is not an
    existing directory or ``--limit`` is negative.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--data", required=True, help="Directory containing case JSON files.")
    parser.add_argument("--strategy", default="direct", choices=["direct", "stact_reference"])
    parser.add_argument("--limit", type=int, default=None)
    parser.add_argument("--output", required=True)
    args = parser.parse_args()

    data_dir = Path(args.data)
    # A mistyped directory would otherwise glob to nothing and silently
    # produce an empty results file.
    if not data_dir.is_dir():
        parser.error(f"--data is not a directory: {data_dir}")
    # A negative slice bound would drop cases from the END of the list rather
    # than limit the count, so reject it up front.
    if args.limit is not None and args.limit < 0:
        parser.error("--limit must be non-negative")

    # Sort so that the selected cases and the output order are deterministic.
    paths = sorted(data_dir.glob("*.json"))
    if args.limit is not None:
        paths = paths[: args.limit]

    output = Path(args.output)
    output.parent.mkdir(parents=True, exist_ok=True)
    # JSON is UTF-8 by specification; do not depend on the platform's default
    # locale encoding for either reading or writing.
    with output.open("w", encoding="utf-8") as f:
        for path in paths:
            case = json.loads(path.read_text(encoding="utf-8"))
            validate_case(case)
            episode = run_case(case, strategy=args.strategy)
            f.write(json.dumps(episode, sort_keys=True) + "\n")
    print(f"wrote {len(paths)} episodes to {output}")


# Run the CLI only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()