Arun-Sanjay committed on
Commit
82ce364
·
1 Parent(s): 7c367b4

Fix Phase 2: add GET /tasks and POST /grader endpoints

Browse files

Phase 2 validator requires the environment to declare >=3 tasks with
graders that the grader service can discover. The previous create_app()
setup only exposed generic /reset /step /state routes, so the validator
reported 'Not enough tasks with graders'.

Adds three DispatchPulse-specific endpoints on top of create_app:
- GET /tasks list all 3 tasks with metadata
- GET /tasks/{task_id} single task metadata
- POST /grader score an episode (two modes:
1) silent run of the given task
2) replay a scripted action log)

Also declares the 3 tasks explicitly in openenv.yaml so static
validators that parse the manifest can discover them too.

Files changed (2) hide show
  1. openenv.yaml +29 -0
  2. server/app.py +223 -12
openenv.yaml CHANGED
@@ -4,3 +4,32 @@ type: space
4
  runtime: fastapi
5
  app: server.app:app
6
  port: 8000
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  runtime: fastapi
5
  app: server.app:app
6
  port: 8000
7
+
8
+ # Graded tasks — each has a deterministic grader returning a score in [0.0, 1.0]
9
+ tasks:
10
+ - id: easy
11
+ name: easy
12
+ difficulty: easy
13
+ description: >
14
+ Routine urban shift. Five calls over 30 minutes, four units (ALS, BLS,
15
+ fire engine, police), one well-equipped hospital. Callers report
16
+ accurately. Optimal play scores ~0.85+.
17
+ has_grader: true
18
+
19
+ - id: medium
20
+ name: medium
21
+ difficulty: medium
22
+ description: >
23
+ Urban scenario. 15 calls in 45 minutes, 6 units, 2 hospitals. Mass
24
+ casualty bus accident at minute 12 and 20% caller inaccuracy.
25
+ Reasonable play scores ~0.55-0.70.
26
+ has_grader: true
27
+
28
+ - id: hard
29
+ name: hard
30
+ difficulty: hard
31
+ description: >
32
+ Earthquake response scenario. 30 calls in 60 minutes, 8 units,
33
+ 3 hospitals (one on diversion). 35% caller misreporting due to panic.
34
+ Strong play scores ~0.40-0.55.
35
+ has_grader: true
server/app.py CHANGED
@@ -1,20 +1,24 @@
1
  """FastAPI application for DispatchPulse.
2
 
3
- Uses ``create_app(env_factory, ActionCls, ObservationCls)`` from
4
- openenv-core's HTTP server. This wrapper:
5
-
6
- - When ``ENABLE_WEB_INTERFACE=true`` (set in the Dockerfile): serves the
7
- OpenEnv Gradio web UI at ``/`` so judges visiting the Space in a browser
8
- see a friendly project page.
9
- - Always registers the standard ``/reset``, ``/step``, ``/state``,
10
- ``/health``, ``/metadata``, ``/schema``, and ``/ws`` routes — these are
11
- what the hackathon grader actually hits.
12
  """
13
 
14
  from __future__ import annotations
15
 
16
  import os
17
  import sys
 
 
 
 
18
 
19
  # Support both in-repo and standalone imports.
20
  try:
@@ -25,15 +29,18 @@ except ImportError: # pragma: no cover
25
  from openenv.core.env_server.http_server import create_app
26
  from server.environment import DispatchPulseEnvironment
27
 
28
- # Import the typed Action / Observation classes from the project root models.py
29
  _PKG_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
30
  if _PKG_ROOT not in sys.path:
31
  sys.path.insert(0, _PKG_ROOT)
32
 
33
  from models import DispatchPulseAction, DispatchPulseObservation # noqa: E402
 
 
 
 
34
 
35
- # Pass the class (factory) so each session gets its own env instance.
36
- # ``env_name`` controls the web UI title and README lookup.
37
  app = create_app(
38
  DispatchPulseEnvironment,
39
  DispatchPulseAction,
@@ -43,6 +50,210 @@ app = create_app(
43
  )
44
 
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  def main() -> None:
47
  """Entry point for ``uv run server`` or direct execution."""
48
  import uvicorn
 
1
  """FastAPI application for DispatchPulse.
2
 
3
+ Uses ``create_app(...)`` from openenv-core for the standard ``/reset``,
4
+ ``/step``, ``/state``, ``/health``, ``/metadata``, ``/schema``, ``/ws`` routes
5
+ plus the Gradio UI at ``/`` (when ``ENABLE_WEB_INTERFACE=true``).
6
+
7
+ On top of that baseline we add two DispatchPulse-specific endpoints the
8
+ hackathon grader discovers:
9
+
10
+ - ``GET /tasks`` — list the 3 graded tasks with metadata
11
+ - ``POST /grader`` — score a silent run or a replayed action log for a task
12
  """
13
 
14
  from __future__ import annotations
15
 
16
  import os
17
  import sys
18
+ from typing import Any, Dict, List, Optional
19
+
20
+ from fastapi import HTTPException
21
+ from pydantic import BaseModel, Field
22
 
23
  # Support both in-repo and standalone imports.
24
  try:
 
29
  from openenv.core.env_server.http_server import create_app
30
  from server.environment import DispatchPulseEnvironment
31
 
32
+ # Import project root modules
33
  _PKG_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
34
  if _PKG_ROOT not in sys.path:
35
  sys.path.insert(0, _PKG_ROOT)
36
 
37
  from models import DispatchPulseAction, DispatchPulseObservation # noqa: E402
38
+ from grader import grade_simulation # noqa: E402
39
+ from reward import calculate_episode_reward # noqa: E402
40
+ from scenario_loader import VALID_TASKS, load_scenario # noqa: E402
41
+ from simulation import DispatchSimulation # noqa: E402
42
 
43
+ # Create the standard OpenEnv app (Gradio UI + HTTP API routes).
 
44
  app = create_app(
45
  DispatchPulseEnvironment,
46
  DispatchPulseAction,
 
50
  )
51
 
52
 
53
+ # ---------------------------------------------------------------------------
54
+ # Task catalog — 3 graded tasks with metadata for GET /tasks
55
+ # ---------------------------------------------------------------------------
56
+
57
+
58
class TaskInfo(BaseModel):
    """Describe one graded DispatchPulse task as exposed by the task endpoints."""

    # --- identity -----------------------------------------------------------
    task_id: str
    name: str
    difficulty: str = Field(..., description="easy | medium | hard")
    description: str

    # --- episode budget -----------------------------------------------------
    max_steps: int
    time_limit_minutes: int

    # --- scenario size ------------------------------------------------------
    num_calls: int
    num_units: int
    num_hospitals: int

    # --- scenario knobs -----------------------------------------------------
    caller_inaccuracy: float

    # Defaults to True, so a task is advertised as gradeable unless stated otherwise.
    has_grader: bool = True
72
+
73
+
74
class TaskListResponse(BaseModel):
    """Body returned by ``GET /tasks``: the task records plus how many there are."""

    # One entry per task in the catalog.
    tasks: List[TaskInfo]
    # Number of entries in ``tasks``.
    count: int
79
+
80
+
81
def _task_info(task_id: str) -> TaskInfo:
    """Assemble the ``TaskInfo`` record for *task_id* from its loaded scenario.

    Missing scenario keys fall back to defaults: the task id for the name, an
    empty description, 30 minutes for the time limit, empty collections for
    calls/units/hospitals, and 0.0 for caller inaccuracy.
    """
    scenario = load_scenario(task_id)
    settings = scenario.get("world_config", {}) or {}
    raw_description = scenario.get("description") or ""

    fields = {
        "task_id": task_id,
        "name": scenario.get("name", task_id),
        # The task id doubles as the difficulty label.
        "difficulty": task_id,
        "description": raw_description.strip(),
        # max_steps and time_limit_minutes both report the same scenario setting.
        "max_steps": int(settings.get("time_limit_minutes", 30)),
        "time_limit_minutes": int(settings.get("time_limit_minutes", 30)),
        "num_calls": len(scenario.get("calls", [])),
        "num_units": len(scenario.get("units", [])),
        "num_hospitals": len(scenario.get("hospitals", [])),
        "caller_inaccuracy": float(scenario.get("caller_inaccuracy", 0.0)),
        "has_grader": True,
    }
    return TaskInfo(**fields)
97
+
98
+
99
@app.get("/tasks", tags=["DispatchPulse"], response_model=TaskListResponse)
def list_tasks() -> TaskListResponse:
    """Return the full list of graded tasks.

    DispatchPulse ships with exactly three deterministic tasks — ``easy``,
    ``medium``, ``hard`` — each with its own grader that returns a score in
    [0.0, 1.0] at episode end.
    """
    # One metadata record per entry in VALID_TASKS, in catalog order.
    catalog = list(map(_task_info, VALID_TASKS))
    return TaskListResponse(tasks=catalog, count=len(catalog))
109
+
110
+
111
@app.get("/tasks/{task_id}", tags=["DispatchPulse"], response_model=TaskInfo)
def get_task(task_id: str) -> TaskInfo:
    """Return metadata for a single task by id.

    Raises:
        HTTPException: 404 when ``task_id`` is not one of ``VALID_TASKS``.
    """
    if task_id in VALID_TASKS:
        return _task_info(task_id)

    # Unknown id: tell the caller which ids are accepted.
    valid_ids = ", ".join(VALID_TASKS)
    raise HTTPException(
        status_code=404,
        detail=f"unknown task_id '{task_id}' (valid: {valid_ids})",
    )
120
+
121
+
122
+ # ---------------------------------------------------------------------------
123
+ # Grader — POST /grader
124
+ # ---------------------------------------------------------------------------
125
+
126
+
127
class GraderRequest(BaseModel):
    """Request body for POST /grader.

    Every request is graded against a fresh simulation built from ``task_id``
    and ``seed``; no server-side episode state is consulted. Supplying
    ``actions`` replays that scripted log before scoring. Leaving ``actions``
    out scores a silent run in which no decisions are made. When ``task_id``
    is omitted the handler grades the ``easy`` task.
    """

    task_id: Optional[str] = Field(
        default=None, description="One of: easy | medium | hard"
    )
    seed: int = Field(default=42, description="Random seed for reproducibility")
    actions: Optional[List[Dict[str, Any]]] = Field(
        default=None,
        description=(
            "Ordered list of actions to replay (each item has "
            "action_type and any required args). When omitted, the grader "
            "runs the episode to completion with no decisions and scores that."
        ),
    )
147
+
148
+
149
class GraderResult(BaseModel):
    """Body returned by ``POST /grader`` for one graded episode."""

    # --- verdict ------------------------------------------------------------
    task_id: str
    # Validation rejects any value outside the closed interval [0.0, 1.0].
    score: float = Field(..., ge=0.0, le=1.0)
    passed: bool
    details: str

    # --- reward components --------------------------------------------------
    survival_score: float
    efficiency_score: float
    triage_accuracy: float
    penalty: float

    # --- call counts --------------------------------------------------------
    completed_calls: int
    timed_out_calls: int
    total_calls: int
163
+
164
+
165
+ def _replay_actions(sim: DispatchSimulation, actions: List[Dict[str, Any]]) -> None:
166
+ """Replay a scripted action list through a fresh simulation."""
167
+ max_steps = 500
168
+ for idx, act in enumerate(actions):
169
+ if idx >= max_steps or sim.episode_done:
170
+ break
171
+ atype = (act.get("action_type") or "").strip().lower()
172
+ if atype == "dispatch":
173
+ sim.dispatch(
174
+ call_id=str(act.get("call_id", "")),
175
+ unit_id=str(act.get("unit_id", "")),
176
+ hospital_id=act.get("hospital_id"),
177
+ )
178
+ sim.advance_time(1)
179
+ elif atype == "classify":
180
+ try:
181
+ sev = int(act.get("severity", 3))
182
+ except (TypeError, ValueError):
183
+ sev = 3
184
+ sim.classify(str(act.get("call_id", "")), sev)
185
+ sim.advance_time(1)
186
+ elif atype == "callback":
187
+ sim.callback(
188
+ str(act.get("call_id", "")),
189
+ str(act.get("message", act.get("question", ""))),
190
+ )
191
+ sim.advance_time(1)
192
+ elif atype == "wait":
193
+ try:
194
+ mins = int(act.get("minutes", 1))
195
+ except (TypeError, ValueError):
196
+ mins = 1
197
+ sim.advance_time(max(1, min(mins, sim.config.max_wait_step_minutes)))
198
+ elif atype == "view":
199
+ continue
200
+ else:
201
+ sim.advance_time(1)
202
+
203
+ # If we ran out of actions before the episode ended, fast-forward the
204
+ # clock so all remaining calls time out and the episode terminates.
205
+ while not sim.episode_done:
206
+ sim.advance_time(sim.config.time_limit_minutes)
207
+
208
+
209
@app.post("/grader", tags=["DispatchPulse"], response_model=GraderResult)
def grade_task(payload: GraderRequest) -> GraderResult:
    """Run the grader for a task.

    Two modes:
    1. ``task_id`` only → score a silent run (all calls timeout) as a
       sanity check that the task loads and has a valid grader.
    2. ``task_id + actions`` → replay the scripted action log then score.

    A missing ``task_id`` falls back to ``easy``. The reported ``score`` is
    clamped to [0.0, 1.0] so an out-of-range reward total cannot fail
    response validation.

    Raises:
        HTTPException: 404 when ``task_id`` is not one of ``VALID_TASKS``.
    """
    task_id = (payload.task_id or "easy").strip().lower()
    if task_id not in VALID_TASKS:
        raise HTTPException(
            status_code=404,
            detail=f"unknown task_id '{task_id}' (valid: {', '.join(VALID_TASKS)})",
        )

    scenario = load_scenario(task_id)
    sim = DispatchSimulation(scenario, seed=int(payload.seed))

    # With no actions the helper skips straight to fast-forwarding the clock
    # until the episode ends, which is exactly the silent-run mode.
    _replay_actions(sim, payload.actions or [])

    reward = calculate_episode_reward(
        sim.completed_calls,
        sim.timed_out_calls,
        sim.total_calls(),
        sim.dispatches,
    )

    # GraderResult.score is constrained to [0, 1]; clamp rather than let an
    # out-of-range total surface as a server error.
    score = min(1.0, max(0.0, float(reward.total)))

    return GraderResult(
        task_id=task_id,
        score=score,
        passed=score >= 0.20,
        details=reward.details,
        survival_score=reward.survival_score,
        efficiency_score=reward.efficiency_score,
        triage_accuracy=reward.triage_accuracy,
        penalty=reward.penalty,
        completed_calls=len(sim.completed_calls),
        timed_out_calls=len(sim.timed_out_calls),
        total_calls=sim.total_calls(),
    )
255
+
256
+
257
  def main() -> None:
258
  """Entry point for ``uv run server`` or direct execution."""
259
  import uvicorn