# hospital-ed / evaluate.py
# Hugging Face file-viewer residue (not part of the program): uploaded by
# "testingaccc" in commit 0fe00d1 ("Upload folder using huggingface_hub"),
# verified, 4.25 kB.
"""Grade an agent on all scenarios and print a JSON score.

Usage::

    python evaluate.py                      # default: heuristic agent
    python evaluate.py --agent random       # random baseline
    python evaluate.py --agent ppo          # loads ppo_hospital.zip if present
    python evaluate.py --output score.json  # also write the scores to a file
"""
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
from typing import Callable
from agents import HeuristicAgent, RandomAgent
from grader import Grader
AgentFn = Callable[[dict], int]
def _load_agent(name: str) -> AgentFn:
"""Return an agent callable for the given CLI name.
If the PPO model is not found (or stable-baselines3 is not
installed), fall back to the heuristic agent with a warning.
"""
if name == "random":
return RandomAgent(seed=0)
if name == "heuristic":
return HeuristicAgent()
if name == "ppo":
model_path = Path("ppo_hospital.zip")
if not model_path.exists():
print(
f"[evaluate] {model_path} not found; falling back to heuristic.",
file=sys.stderr,
)
return HeuristicAgent()
try:
from agents.train_ppo import FlattenDictObsWrapper
from hospital_env import HospitalEnv
except ImportError as exc: # pragma: no cover - import guard
print(
f"[evaluate] stable-baselines3 unavailable ({exc}); "
"falling back to heuristic.",
file=sys.stderr,
)
return HeuristicAgent()
# Try MaskablePPO first (the default our trainer produces), fall
# back to vanilla PPO for backwards compatibility.
model = None
try:
from sb3_contrib import MaskablePPO
model = MaskablePPO.load(str(model_path))
is_maskable = True
except Exception:
from stable_baselines3 import PPO
model = PPO.load(str(model_path))
is_maskable = False
flattener = FlattenDictObsWrapper(HospitalEnv())
def ppo_agent(obs: dict, info: dict) -> int:
flat = flattener._flatten(obs)
if is_maskable:
# The mask comes from the env the grader is currently
# running, via info["action_mask"]. We can't use the
# flattener's own env because it's never reset to match.
mask = info.get("action_mask")
if mask is None:
import numpy as np # noqa: PLC0415
mask = np.ones(model.action_space.n, dtype=bool)
action, _state = model.predict(
flat, deterministic=True, action_masks=mask
)
else:
action, _state = model.predict(flat, deterministic=True)
return int(action)
return ppo_agent
raise ValueError(f"Unknown agent name: {name!r}")
def _parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Grade an agent on all hospital scenarios and emit JSON."
)
parser.add_argument(
"--agent",
default="heuristic",
choices=["random", "heuristic", "ppo"],
help="Agent to evaluate.",
)
parser.add_argument(
"--episodes",
type=int,
default=5,
help="Episodes per scenario (lower = faster, higher = lower variance).",
)
parser.add_argument(
"--output",
default=None,
help="Optional JSON file to write the scores to.",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="Base seed for deterministic grading.",
)
return parser.parse_args()
def main() -> dict:
    """Run the grading pipeline from the command line.

    Parses the CLI options, builds the requested agent, grades it with
    ``Grader`` and prints the scores to stdout as indented JSON.  When
    ``--output`` is given, the same JSON text is also written to that file
    and a confirmation message goes to stderr.

    Returns:
        The scores object produced by ``Grader.grade``.
    """
    options = _parse_args()
    policy = _load_agent(options.agent)
    results = Grader(
        n_episodes_per_scenario=options.episodes,
        base_seed=options.seed,
    ).grade(policy)

    rendered = json.dumps(results, indent=2)
    print(rendered)

    destination = options.output
    if not destination:
        # No output file requested: stdout is the only sink.
        return results

    Path(destination).write_text(rendered)
    print(f"\nScores written to {destination}", file=sys.stderr)
    return results


if __name__ == "__main__":
    main()