Spaces:
Sleeping
Sleeping
"""Grade an agent on all scenarios and print a JSON score.

Usage::

    python evaluate.py                      # default: heuristic agent
    python evaluate.py --agent random       # random baseline
    python evaluate.py --agent ppo          # loads ppo_hospital.zip if present
    python evaluate.py --output score.json
"""
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import sys | |
| from pathlib import Path | |
| from typing import Callable | |
| from agents import HeuristicAgent, RandomAgent | |
| from grader import Grader | |
# Type of the callables returned by ``_load_agent``: takes an observation
# dict and returns the chosen action as an int.
AgentFn = Callable[[dict], int]
| def _load_agent(name: str) -> AgentFn: | |
| """Return an agent callable for the given CLI name. | |
| If the PPO model is not found (or stable-baselines3 is not | |
| installed), fall back to the heuristic agent with a warning. | |
| """ | |
| if name == "random": | |
| return RandomAgent(seed=0) | |
| if name == "heuristic": | |
| return HeuristicAgent() | |
| if name == "ppo": | |
| model_path = Path("ppo_hospital.zip") | |
| if not model_path.exists(): | |
| print( | |
| f"[evaluate] {model_path} not found; falling back to heuristic.", | |
| file=sys.stderr, | |
| ) | |
| return HeuristicAgent() | |
| try: | |
| from agents.train_ppo import FlattenDictObsWrapper | |
| from hospital_env import HospitalEnv | |
| except ImportError as exc: # pragma: no cover - import guard | |
| print( | |
| f"[evaluate] stable-baselines3 unavailable ({exc}); " | |
| "falling back to heuristic.", | |
| file=sys.stderr, | |
| ) | |
| return HeuristicAgent() | |
| # Try MaskablePPO first (the default our trainer produces), fall | |
| # back to vanilla PPO for backwards compatibility. | |
| model = None | |
| try: | |
| from sb3_contrib import MaskablePPO | |
| model = MaskablePPO.load(str(model_path)) | |
| is_maskable = True | |
| except Exception: | |
| from stable_baselines3 import PPO | |
| model = PPO.load(str(model_path)) | |
| is_maskable = False | |
| flattener = FlattenDictObsWrapper(HospitalEnv()) | |
| def ppo_agent(obs: dict, info: dict) -> int: | |
| flat = flattener._flatten(obs) | |
| if is_maskable: | |
| # The mask comes from the env the grader is currently | |
| # running, via info["action_mask"]. We can't use the | |
| # flattener's own env because it's never reset to match. | |
| mask = info.get("action_mask") | |
| if mask is None: | |
| import numpy as np # noqa: PLC0415 | |
| mask = np.ones(model.action_space.n, dtype=bool) | |
| action, _state = model.predict( | |
| flat, deterministic=True, action_masks=mask | |
| ) | |
| else: | |
| action, _state = model.predict(flat, deterministic=True) | |
| return int(action) | |
| return ppo_agent | |
| raise ValueError(f"Unknown agent name: {name!r}") | |
| def _parse_args() -> argparse.Namespace: | |
| parser = argparse.ArgumentParser( | |
| description="Grade an agent on all hospital scenarios and emit JSON." | |
| ) | |
| parser.add_argument( | |
| "--agent", | |
| default="heuristic", | |
| choices=["random", "heuristic", "ppo"], | |
| help="Agent to evaluate.", | |
| ) | |
| parser.add_argument( | |
| "--episodes", | |
| type=int, | |
| default=5, | |
| help="Episodes per scenario (lower = faster, higher = lower variance).", | |
| ) | |
| parser.add_argument( | |
| "--output", | |
| default=None, | |
| help="Optional JSON file to write the scores to.", | |
| ) | |
| parser.add_argument( | |
| "--seed", | |
| type=int, | |
| default=42, | |
| help="Base seed for deterministic grading.", | |
| ) | |
| return parser.parse_args() | |
def main() -> dict:
    """Run the CLI: grade the selected agent and report its scores.

    The scores are printed to stdout as indented JSON.  When ``--output``
    is given, the same text is also written to that file and a
    confirmation line is printed to stderr.

    Returns:
        The scores object returned by ``Grader.grade``.
    """
    options = _parse_args()
    # The agent is loaded before the grader is built, so any fallback
    # warnings from loading appear first.
    chosen_agent = _load_agent(options.agent)
    scores = Grader(
        n_episodes_per_scenario=options.episodes,
        base_seed=options.seed,
    ).grade(chosen_agent)

    rendered = json.dumps(scores, indent=2)
    print(rendered)

    destination = options.output
    if not destination:
        return scores

    Path(destination).write_text(rendered)
    print(f"\nScores written to {destination}", file=sys.stderr)
    return scores
# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()