Spaces:
Sleeping
Sleeping
File size: 3,083 Bytes
4904e85 1bc6b3d 4904e85 fe39ff6 4904e85 1bc6b3d fe39ff6 4904e85 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 | """Benchmark module for running 911 dispatch RL tasks."""
from __future__ import annotations
import asyncio
import random
from typing import Any
from src.models import Action, DispatchAction
from src.grading import grade_episode
from src.openenv_environment import OpenEnvEnvironment
from src.tasks.registry import TaskRegistry
def list_tasks() -> list[dict[str, Any]]:
    """Return a summary dict for every task known to ``TaskRegistry``.

    Returns:
        One dict per task, in the order ``TaskRegistry.list_tasks()`` yields
        them, each with the keys ``"task_id"``, ``"name"`` and
        ``"difficulty"`` copied from the task's attributes of the same name.
    """
    summaries: list[dict[str, Any]] = []
    for entry in TaskRegistry.list_tasks():
        summaries.append(
            {
                "task_id": entry.task_id,
                "name": entry.name,
                "difficulty": entry.difficulty,
            }
        )
    return summaries
async def _run_episode_async(task_id: str, seed: int) -> tuple[float, list[float]]:
    """Play one episode of ``task_id`` with a seeded random policy and grade it.

    The environment is reset and then stepped for at most 1000 iterations.
    On each iteration an action is drawn from ``env.legal_actions()`` with a
    ``random.Random(seed)`` generator. When no legal actions are reported, a
    DISPATCH of the first unit to the first incident is attempted instead; if
    the current state has no units or no incidents, the episode stops.

    Args:
        task_id: Identifier passed to the environment and to the grader.
        seed: Seed used for both the environment and the action sampler.

    Returns:
        A ``(final_score, rewards)`` pair: the value returned by
        ``grade_episode`` and the per-step rewards in the order received.
    """
    environment = OpenEnvEnvironment(task_id=task_id, seed=seed)
    reward_log: list[float] = []
    last_state = None

    try:
        await environment.reset()
        last_state = environment.state()
        sampler = random.Random(seed)

        for _step in range(1000):
            choices = environment.legal_actions()
            if not choices:
                # Fallback: attempt to dispatch the first unit to the first incident.
                snapshot = environment.state()
                if not (snapshot.units and snapshot.incidents):
                    break
                first_unit = next(iter(snapshot.units.keys()))
                first_incident = next(iter(snapshot.incidents.keys()))
                chosen = Action(
                    action_type=DispatchAction.DISPATCH,
                    unit_id=first_unit,
                    incident_id=first_incident,
                )
            else:
                chosen = sampler.choice(choices)

            # The observation returned by step() is not used by this policy.
            _, step_reward, finished = await environment.step(chosen)
            reward_log.append(step_reward)
            last_state = environment.state()
            if finished:
                break
    finally:
        # Always release the environment, even if reset() or step() raised.
        environment.close()

    if last_state is None:
        # No state was captured: hand the grader an empty placeholder State.
        from src.models import State

        last_state = State(
            units={},
            incidents={},
            episode_id="",
            step_count=0,
            task_id=task_id,
            city_time=0.0,
            metadata={},
        )

    # Score episodes the same way as the OpenEnv evaluation path.
    episode_score = grade_episode(
        task_id=task_id, state=last_state, rewards=reward_log
    )
    return episode_score, reward_log
def run_task(task_id: str, seed: int) -> dict[str, Any]:
    """Run a single episode of ``task_id`` synchronously and report the result.

    Args:
        task_id: Identifier of the task to run.
        seed: Seed forwarded to the episode runner.

    Returns:
        A dict with the keys ``"task_id"``, ``"seed"``, ``"score"`` (the
        graded score limited to the range 0.0-1.0) and ``"rewards"`` (the
        per-step reward list).
    """
    # The lookup result is discarded; presumably this call rejects unknown
    # task IDs before an environment is built -- confirm against TaskRegistry.
    TaskRegistry.get(task_id)

    raw_score, reward_trace = asyncio.run(_run_episode_async(task_id, seed))

    # Clamp in two explicit steps: cap at 1.0 first, then floor at 0.0.
    capped = min(1.0, raw_score)
    bounded = max(0.0, capped)

    return dict(
        task_id=task_id,
        seed=seed,
        score=bounded,
        rewards=reward_trace,
    )
def run_all() -> dict[str, float]:
    """Run every registered task once and collect its score.

    Each task is run with a seed derived from its ID: the CRC-32 of the
    UTF-8 encoded ID, reduced modulo 10000. CRC-32 is used instead of the
    built-in ``hash()`` because ``hash()`` of a ``str`` is salted per
    interpreter process (unless ``PYTHONHASHSEED`` is fixed), which would
    give every benchmark run different seeds and non-reproducible scores.

    Returns:
        Mapping of task ID to the ``"score"`` entry of ``run_task``'s result.
    """
    # Function-scope import so this block does not depend on the file header.
    import zlib

    scores: dict[str, float] = {}
    for task in TaskRegistry.list_tasks():
        # Stable across processes and platforms, unlike hash(str).
        seed = zlib.crc32(str(task.task_id).encode("utf-8")) % 10000
        scores[task.task_id] = run_task(task.task_id, seed)["score"]
    return scores
if __name__ == "__main__":
    # Script entry point: list the registered tasks, run each one once,
    # then print the resulting score per task to three decimal places.
    print("Available tasks:")
    for info in list_tasks():
        task_key = info["task_id"]
        label = info["name"]
        level = info["difficulty"]
        print(f" - {task_key}: {label} ({level})")

    print("\nRunning all tasks...")
    results = run_all()

    print("\nScores:")
    for task_key, value in results.items():
        print(f" {task_key}: {value:.3f}")
|