File size: 3,083 Bytes
4904e85
 
 
 
 
 
 
 
 
1bc6b3d
4904e85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fe39ff6
4904e85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1bc6b3d
 
fe39ff6
4904e85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
"""Benchmark module for running 911 dispatch RL tasks."""

from __future__ import annotations

import asyncio
import random
from typing import Any

from src.models import Action, DispatchAction
from src.grading import grade_episode
from src.openenv_environment import OpenEnvEnvironment
from src.tasks.registry import TaskRegistry


def list_tasks() -> list[dict[str, Any]]:
    """Return one summary dict per task known to ``TaskRegistry``.

    Each dict carries the task's ``task_id``, ``name`` and ``difficulty``
    attributes, in the order the registry lists the tasks.
    """
    summaries: list[dict[str, Any]] = []
    for entry in TaskRegistry.list_tasks():
        summaries.append(
            {
                "task_id": entry.task_id,
                "name": entry.name,
                "difficulty": entry.difficulty,
            }
        )
    return summaries


async def _run_episode_async(
    task_id: str, seed: int, *, max_steps: int = 1000
) -> tuple[float, list[float]]:
    """Play one episode of ``task_id`` with a seeded random policy and grade it.

    Args:
        task_id: Task identifier handed to ``OpenEnvEnvironment`` and to
            ``grade_episode``.
        seed: Passed to the environment and also used to seed the RNG that
            picks among the legal actions.
        max_steps: Upper bound on the number of ``env.step`` calls. The
            episode ends earlier when the environment reports ``done`` or
            when no action can be constructed.

    Returns:
        ``(final_score, rewards)``: the value returned by ``grade_episode``
        for the last observed state, and the per-step rewards in order.
    """
    env = OpenEnvEnvironment(task_id=task_id, seed=seed)
    rewards: list[float] = []
    final_state = None

    try:
        await env.reset()
        final_state = env.state()

        rng = random.Random(seed)
        for _ in range(max_steps):
            legal_actions = env.legal_actions()
            if legal_actions:
                action = rng.choice(legal_actions)
            else:
                # Fallback: attempt to dispatch the first unit to the first incident.
                st = env.state()
                if not st.units or not st.incidents:
                    # Nothing to dispatch and nothing to dispatch to: stop early.
                    break
                unit_id = next(iter(st.units.keys()))
                incident_id = next(iter(st.incidents.keys()))
                action = Action(
                    action_type=DispatchAction.DISPATCH,
                    unit_id=unit_id,
                    incident_id=incident_id,
                )

            # The observation is not needed here; the state is re-read below.
            _observation, reward, done = await env.step(action)
            rewards.append(reward)

            final_state = env.state()

            if done:
                break
    finally:
        # Always release the environment, even if reset/step raised.
        env.close()

    if final_state is None:
        # Only reachable when env.state() returned None; grade against an
        # empty placeholder state instead of passing None to the grader.
        from src.models import State

        final_state = State(
            units={},
            incidents={},
            episode_id="",
            step_count=0,
            task_id=task_id,
            city_time=0.0,
            metadata={},
        )

    # Score episodes the same way as the OpenEnv evaluation path.
    final_score = grade_episode(task_id=task_id, state=final_state, rewards=rewards)
    return final_score, rewards


def run_task(task_id: str, seed: int) -> dict[str, Any]:
    """Run one seeded episode of ``task_id`` synchronously and summarize it.

    Returns a dict with the keys ``task_id``, ``seed``, ``score`` (the graded
    score clamped into ``[0.0, 1.0]``) and ``rewards`` (the per-step rewards).
    """
    # The lookup result is discarded; presumably this fails fast for an
    # unknown task id -- confirm against TaskRegistry.get.
    TaskRegistry.get(task_id)

    episode_score, step_rewards = asyncio.run(_run_episode_async(task_id, seed))
    clamped_score = max(0.0, min(1.0, episode_score))

    report: dict[str, Any] = {"task_id": task_id, "seed": seed}
    report["score"] = clamped_score
    report["rewards"] = step_rewards
    return report


def run_all() -> dict[str, float]:
    """Run every registered task once and return ``{task_id: score}``.

    The per-task seed is derived from a CRC32 of the task id, reduced to the
    range 0..9999. The built-in ``hash`` is deliberately avoided: ``str``
    hashes are salted per interpreter process (see ``PYTHONHASHSEED``), which
    would make the seeds, and therefore the benchmark scores, differ from one
    run to the next.
    """
    import zlib

    scores: dict[str, float] = {}
    for task in TaskRegistry.list_tasks():
        # Stable across processes and platforms, unlike hash(str).
        seed = zlib.crc32(task.task_id.encode("utf-8")) % 10000
        result = run_task(task.task_id, seed)
        scores[task.task_id] = result["score"]
    return scores


def main() -> None:
    """Print the task catalogue, run every task, then print each score.

    Kept in a function so the loop variables do not leak into module globals
    and the script entry point can be invoked programmatically.
    """
    print("Available tasks:")
    for task in list_tasks():
        print(f"  - {task['task_id']}: {task['name']} ({task['difficulty']})")
    # Announce before the (potentially slow) benchmark run starts.
    print("\nRunning all tasks...")
    scores = run_all()
    print("\nScores:")
    for task_id, score in scores.items():
        print(f"  {task_id}: {score:.3f}")


if __name__ == "__main__":
    main()