Arun-Sanjay commited on
Commit
1091ce2
·
1 Parent(s): bc2df98

Refactor models.py: inherit from openenv Action/Observation/State

Browse files

Adds DispatchPulseAction, DispatchPulseObservation, DispatchPulseState
that inherit directly from openenv.core.env_server.types base classes.
The internal simulation models (EmergencyCall, EmergencyUnit, etc.)
remain plain Pydantic models used inside the engine.

Files changed (1) hide show
  1. models.py +100 -20
models.py CHANGED
@@ -1,4 +1,16 @@
1
- """DispatchPulse data models. All Pydantic v2."""
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  from __future__ import annotations
4
 
@@ -7,12 +19,96 @@ from typing import List, Optional
7
 
8
  from pydantic import BaseModel, Field
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  class Position(BaseModel):
12
  """A 2D coordinate on the city grid (km)."""
13
 
14
- x: float = Field(..., ge=0.0, description="X coordinate in km")
15
- y: float = Field(..., ge=0.0, description="Y coordinate in km")
16
 
17
 
18
  class EmergencyType(str, Enum):
@@ -49,7 +145,7 @@ class UnitStatus(str, Enum):
49
 
50
  class EmergencyCall(BaseModel):
51
  call_id: str
52
- timestamp: int = Field(..., description="Minute of simulation when call arrived")
53
  caller_description: str
54
  location: Position
55
  true_type: EmergencyType
@@ -107,19 +203,3 @@ class Reward(BaseModel):
107
  triage_accuracy: float
108
  penalty: float
109
  details: str = ""
110
-
111
-
112
- class EnvironmentSnapshot(BaseModel):
113
- """Lightweight environment state snapshot for state() / debugging."""
114
-
115
- current_time: int
116
- episode_done: bool
117
- total_calls: int
118
- calls_dispatched: int
119
- calls_completed: int
120
- calls_timed_out: int
121
- calls_pending: int
122
- units_available: int
123
- step_count: int
124
- running_reward: float
125
- task_name: Optional[str] = None
 
1
+ """DispatchPulse data models.
2
+
3
+ Two layers:
4
+
5
+ 1. **OpenEnv interface models** — ``DispatchPulseAction``, ``DispatchPulseObservation``,
6
+ ``DispatchPulseState``. These inherit directly from openenv-core base classes
7
+ and form the wire format the server/client/grader exchange.
8
+
9
+ 2. **Internal simulation models** — ``Position``, ``EmergencyType``, ``Severity``,
10
+ ``UnitType``, ``UnitStatus``, ``EmergencyCall``, ``EmergencyUnit``, ``Hospital``,
11
+ ``WorldConfig``, ``Reward``. These are plain Pydantic models the simulation
12
+ engine uses internally; they never cross the OpenEnv boundary directly.
13
+ """
14
 
15
  from __future__ import annotations
16
 
 
19
 
20
  from pydantic import BaseModel, Field
21
 
22
+ # ---------------------------------------------------------------------------
23
+ # OpenEnv base classes
24
+ # ---------------------------------------------------------------------------
25
+ from openenv.core.env_server.types import Action, Observation, State
26
+
27
+
28
+ # ===========================================================================
29
+ # OpenEnv-facing wire types
30
+ # ===========================================================================
31
+
32
+
33
+ class DispatchPulseAction(Action):
34
+ """A single dispatcher action.
35
+
36
+ The agent supplies ``action_type`` plus optional fields. The simplest
37
+ possible interface for an LLM is the ``text`` field — the server will
38
+ parse it as a command line like ``"dispatch CALL-001 ALS-1 H1"``.
39
+
40
+ Supported action types:
41
+ - ``dispatch`` : send a unit to a call (call_id, unit_id, hospital_id?)
42
+ - ``classify`` : reclassify a call's severity (call_id, severity)
43
+ - ``callback`` : phone the caller back (call_id, message)
44
+ - ``wait`` : skip ahead in simulation time (minutes)
45
+ - ``view`` : free inspection (no time cost)
46
+ """
47
+
48
+ action_type: str = Field(
49
+ ..., description="One of: dispatch, classify, callback, wait, view"
50
+ )
51
+ text: str = Field(
52
+ default="",
53
+ description="Free-text representation of the action (e.g. 'dispatch CALL-001 ALS-1 H1')",
54
+ )
55
+ call_id: Optional[str] = Field(default=None)
56
+ unit_id: Optional[str] = Field(default=None)
57
+ hospital_id: Optional[str] = Field(default=None)
58
+ severity: Optional[int] = Field(default=None, ge=1, le=5)
59
+ message: Optional[str] = Field(default=None)
60
+ minutes: Optional[int] = Field(default=None, ge=1, le=5)
61
+
62
+
63
+ class DispatchPulseObservation(Observation):
64
+ """What the dispatcher sees each turn.
65
+
66
+ The ``text`` field is the human-readable dispatch center view that the
67
+ LLM agent reads. The structured fields underneath are useful for
68
+ programmatic agents and grading.
69
+ """
70
+
71
+ text: str = Field(default="", description="Formatted dispatch center view for the agent")
72
+ current_time: int = Field(default=0, description="Simulation minute")
73
+ time_limit: int = Field(default=30, description="Episode time limit (minutes)")
74
+ calls_pending: int = Field(default=0, description="Number of calls waiting for dispatch")
75
+ units_available: int = Field(default=0, description="Number of free units")
76
+ calls_completed: int = Field(default=0)
77
+ calls_timed_out: int = Field(default=0)
78
+ total_calls: int = Field(default=0)
79
+ last_action_error: Optional[str] = Field(
80
+ default=None, description="Error message from the last action, or None"
81
+ )
82
+ info_message: Optional[str] = Field(
83
+ default=None, description="Free-text message describing what just happened"
84
+ )
85
+
86
+
87
+ class DispatchPulseState(State):
88
+ """Internal state snapshot exposed via ``GET /state``."""
89
+
90
+ current_time: int = Field(default=0)
91
+ episode_done: bool = Field(default=False)
92
+ total_calls: int = Field(default=0)
93
+ calls_dispatched: int = Field(default=0)
94
+ calls_completed: int = Field(default=0)
95
+ calls_timed_out: int = Field(default=0)
96
+ calls_pending: int = Field(default=0)
97
+ units_available: int = Field(default=0)
98
+ running_reward: float = Field(default=0.0)
99
+ task_name: str = Field(default="easy")
100
+
101
+
102
+ # ===========================================================================
103
+ # Internal simulation models (plain Pydantic, never cross OpenEnv boundary)
104
+ # ===========================================================================
105
+
106
 
107
  class Position(BaseModel):
108
  """A 2D coordinate on the city grid (km)."""
109
 
110
+ x: float = Field(..., ge=0.0)
111
+ y: float = Field(..., ge=0.0)
112
 
113
 
114
  class EmergencyType(str, Enum):
 
145
 
146
  class EmergencyCall(BaseModel):
147
  call_id: str
148
+ timestamp: int
149
  caller_description: str
150
  location: Position
151
  true_type: EmergencyType
 
203
  triage_accuracy: float
204
  penalty: float
205
  details: str = ""