"""Grade an agent on all scenarios and print a JSON score.

Usage::

    python evaluate.py                       # default: heuristic agent
    python evaluate.py --agent random        # random baseline
    python evaluate.py --agent ppo           # loads ppo_hospital.zip if present
    python evaluate.py --output score.json   # also write the scores to a file
"""

from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path
from typing import Any, Callable

from agents import HeuristicAgent, RandomAgent
from grader import Grader

# Callable mapping an observation dict (and, for the PPO agent, the env's
# ``info`` dict) to a discrete action index.
# NOTE(review): the PPO agent is written to be called as ``agent(obs, info)``;
# confirm Grader uses the same convention for RandomAgent / HeuristicAgent.
AgentFn = Callable[..., int]

# Checkpoint loaded by ``--agent ppo``.
PPO_MODEL_PATH = Path("ppo_hospital.zip")

# Valid values for ``--agent``; kept in sync with the dispatch in _load_agent.
AGENT_NAMES = ("random", "heuristic", "ppo")


def _warn(message: str) -> None:
    """Print an ``[evaluate]``-prefixed diagnostic line to stderr."""
    print(f"[evaluate] {message}", file=sys.stderr)


def _load_ppo_model(model_path: Path) -> tuple[Any, bool]:
    """Load a PPO checkpoint, preferring MaskablePPO over vanilla PPO.

    MaskablePPO is what our trainer produces by default; vanilla PPO is kept
    for backwards compatibility with older checkpoints.

    Args:
        model_path: Path of the saved model archive.

    Returns:
        ``(model, is_maskable)``. ``is_maskable`` is True when the model was
        loaded with ``sb3_contrib.MaskablePPO`` and therefore accepts
        ``action_masks`` in ``predict``.

    Raises:
        Whatever ``stable_baselines3.PPO.load`` raises if the vanilla
        fallback also fails (or ImportError if stable_baselines3 is missing).
    """
    try:
        from sb3_contrib import MaskablePPO

        return MaskablePPO.load(str(model_path)), True
    except ImportError:
        pass  # sb3_contrib not installed: vanilla PPO is the only option.
    except Exception as exc:  # noqa: BLE001 - deliberate fallback below
        # Previously swallowed silently; surface the reason so a corrupt or
        # incompatible checkpoint is not mistaken for a vanilla-PPO one.
        _warn(f"MaskablePPO could not load {model_path} ({exc}); trying vanilla PPO.")

    from stable_baselines3 import PPO

    return PPO.load(str(model_path)), False


def _load_ppo_agent(model_path: Path = PPO_MODEL_PATH) -> AgentFn:
    """Build an agent callable backed by a trained PPO checkpoint.

    Falls back to :class:`HeuristicAgent` (with a warning on stderr) if the
    checkpoint file does not exist or the training/env modules cannot be
    imported.

    Args:
        model_path: Path of the saved model archive.

    Returns:
        A callable ``ppo_agent(obs, info=None) -> int`` or a HeuristicAgent.
    """
    if not model_path.exists():
        _warn(f"{model_path} not found; falling back to heuristic.")
        return HeuristicAgent()

    try:
        from agents.train_ppo import FlattenDictObsWrapper
        from hospital_env import HospitalEnv
    except ImportError as exc:  # pragma: no cover - import guard
        _warn(f"stable-baselines3 unavailable ({exc}); falling back to heuristic.")
        return HeuristicAgent()

    model, is_maskable = _load_ppo_model(model_path)

    # Used only for its observation-flattening logic; this env is never reset
    # or stepped here, so its own state says nothing about the grader's env.
    flattener = FlattenDictObsWrapper(HospitalEnv())

    def ppo_agent(obs: dict, info: dict | None = None) -> int:
        """Return the model's deterministic action for ``obs``.

        For maskable models the action mask is read from
        ``info["action_mask"]``, i.e. from the env the grader is currently
        running. If no mask is available, every action is allowed.
        """
        flat = flattener._flatten(obs)
        if not is_maskable:
            action, _state = model.predict(flat, deterministic=True)
            return int(action)

        mask = (info or {}).get("action_mask")
        if mask is None:
            import numpy as np  # noqa: PLC0415

            mask = np.ones(model.action_space.n, dtype=bool)
        action, _state = model.predict(flat, deterministic=True, action_masks=mask)
        return int(action)

    return ppo_agent


def _load_agent(name: str) -> AgentFn:
    """Return an agent callable for the given CLI name.

    If the PPO model is not found (or stable-baselines3 is not installed),
    fall back to the heuristic agent with a warning.

    Raises:
        ValueError: if ``name`` is not one of :data:`AGENT_NAMES`.
    """
    if name == "random":
        return RandomAgent(seed=0)
    if name == "heuristic":
        return HeuristicAgent()
    if name == "ppo":
        return _load_ppo_agent()
    raise ValueError(f"Unknown agent name: {name!r}")


def _positive_int(text: str) -> int:
    """argparse ``type=`` converter that accepts only integers >= 1."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
    return value


def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line options (``sys.argv[1:]`` when ``argv`` is None)."""
    parser = argparse.ArgumentParser(
        description="Grade an agent on all hospital scenarios and emit JSON."
    )
    parser.add_argument(
        "--agent",
        default="heuristic",
        choices=list(AGENT_NAMES),
        help="Agent to evaluate.",
    )
    parser.add_argument(
        "--episodes",
        type=_positive_int,
        default=5,
        help="Episodes per scenario (lower = faster, higher = lower variance).",
    )
    parser.add_argument(
        "--output",
        default=None,
        help="Optional JSON file to write the scores to.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Base seed for deterministic grading.",
    )
    return parser.parse_args(argv)


def main(argv: list[str] | None = None) -> dict:
    """Grade the selected agent and emit the scores.

    The scores are printed to stdout as indented JSON and, when ``--output``
    is given, also written to that file (a confirmation goes to stderr).

    Args:
        argv: Optional argument list; defaults to ``sys.argv[1:]``.

    Returns:
        The scores object returned by ``Grader.grade``.
    """
    args = _parse_args(argv)
    agent = _load_agent(args.agent)
    grader = Grader(n_episodes_per_scenario=args.episodes, base_seed=args.seed)
    scores = grader.grade(agent)

    text = json.dumps(scores, indent=2)
    print(text)
    if args.output:
        Path(args.output).write_text(text, encoding="utf-8")
        print(f"\nScores written to {args.output}", file=sys.stderr)
    return scores


if __name__ == "__main__":
    main()