Spaces:
Sleeping
Sleeping
| """Benchmark module for running 911 dispatch RL tasks.""" | |
| from __future__ import annotations | |
| import asyncio | |
| import random | |
| from typing import Any | |
| from src.models import Action, DispatchAction | |
| from src.grading import grade_episode | |
| from src.openenv_environment import OpenEnvEnvironment | |
| from src.tasks.registry import TaskRegistry | |
def list_tasks() -> list[dict[str, Any]]:
    """Summarise every task known to ``TaskRegistry`` as a plain dictionary.

    Returns:
        One dict per registered task, in the order ``TaskRegistry.list_tasks()``
        yields them, with the keys ``"task_id"``, ``"name"`` and
        ``"difficulty"`` copied from the task's attributes of the same name.
    """
    summaries: list[dict[str, Any]] = []
    for entry in TaskRegistry.list_tasks():
        summary = {
            "task_id": entry.task_id,
            "name": entry.name,
            "difficulty": entry.difficulty,
        }
        summaries.append(summary)
    return summaries
async def _run_episode_async(
    task_id: str,
    seed: int,
    *,
    max_steps: int = 1000,
) -> tuple[float, list[float]]:
    """Play one episode with a seeded random policy and grade the outcome.

    At every step an action is drawn uniformly from ``env.legal_actions()``.
    When that list is empty, a DISPATCH of the first unit to the first
    incident is attempted instead; if the state has no units or no incidents
    the episode ends early.

    Args:
        task_id: Task identifier, passed to the environment constructor and
            to ``grade_episode``.
        seed: Passed to the environment constructor and used to seed the
            action-sampling RNG.
        max_steps: Upper bound on the number of environment steps. Defaults
            to 1000, the cap that used to be hard-coded here.

    Returns:
        A ``(final_score, rewards)`` pair: the value returned by
        ``grade_episode`` for the last captured state, and the per-step
        rewards in the order they were received.
    """
    env = OpenEnvEnvironment(task_id=task_id, seed=seed)
    rewards: list[float] = []
    final_state = None
    try:
        await env.reset()
        final_state = env.state()
        rng = random.Random(seed)
        for _ in range(max_steps):
            legal_actions = env.legal_actions()
            if legal_actions:
                action = rng.choice(legal_actions)
            else:
                # Fallback: attempt to dispatch the first unit to the first incident.
                st = env.state()
                if not st.units or not st.incidents:
                    break
                action = Action(
                    action_type=DispatchAction.DISPATCH,
                    unit_id=next(iter(st.units)),
                    incident_id=next(iter(st.incidents)),
                )
            # The observation is not consumed by this random policy.
            _obs, reward, done = await env.step(action)
            rewards.append(reward)
            final_state = env.state()
            if done:
                break
    finally:
        # Always release the environment, even when reset/step raises.
        env.close()

    if final_state is None:
        # Only reachable if env.state() itself returned None; errors from
        # reset()/step() propagate past this point. Grade an empty
        # placeholder state rather than passing None to the grader.
        from src.models import State

        final_state = State(
            units={},
            incidents={},
            episode_id="",
            step_count=0,
            task_id=task_id,
            city_time=0.0,
            metadata={},
        )

    # Score episodes the same way as the OpenEnv evaluation path.
    final_score = grade_episode(task_id=task_id, state=final_state, rewards=rewards)
    return final_score, rewards
def run_task(task_id: str, seed: int) -> dict[str, Any]:
    """Run a single episode of ``task_id`` synchronously and report the result.

    Args:
        task_id: Identifier of the task to run.
        seed: Seed forwarded to the episode runner.

    Returns:
        A dict with the keys ``"task_id"``, ``"seed"``, ``"score"`` (the graded
        score clamped into ``[0.0, 1.0]``) and ``"rewards"`` (the per-step
        reward list).
    """
    # The lookup result is discarded; presumably this call rejects unknown
    # task ids before an environment is built -- confirm against TaskRegistry.
    TaskRegistry.get(task_id)

    episode_score, episode_rewards = asyncio.run(_run_episode_async(task_id, seed))

    report: dict[str, Any] = {"task_id": task_id, "seed": seed}
    report["score"] = max(0.0, min(1.0, episode_score))
    report["rewards"] = episode_rewards
    return report
def run_all() -> dict[str, float]:
    """Run every registered task once and collect its score.

    Each task is run with a seed derived from its ``task_id``, so the same
    task always receives the same seed, in every process and on every run.

    Returns:
        A mapping from ``task_id`` to the ``"score"`` entry of the dict
        returned by ``run_task`` for that task.
    """
    # Function-scope import keeps this change self-contained (stdlib only).
    import zlib

    scores: dict[str, float] = {}
    for task in TaskRegistry.list_tasks():
        # The built-in hash() of a str is salted per interpreter process
        # (PYTHONHASHSEED), so it would give each task a different seed on
        # every run and make the scores irreproducible. CRC-32 of the UTF-8
        # bytes is stable; the modulus keeps the seed in 0..9999 as before.
        seed = zlib.crc32(task.task_id.encode("utf-8")) % 10000
        result = run_task(task.task_id, seed)
        scores[task.task_id] = result["score"]
    return scores
if __name__ == "__main__":
    # Script entry point: print the task catalogue, then run every task and
    # print one score line per task.
    print("Available tasks:")
    for task in list_tasks():
        print(f" - {task['task_id']}: {task['name']} ({task['difficulty']})")

    print("\nRunning all tasks...")
    scores = run_all()

    print("\nScores:")
    for task_id in scores:
        score = scores[task_id]
        print(f" {task_id}: {score:.3f}")