aparekh02 committed · Commit cb054fe · verified · Parent(s): 1c6e7aa

initial push: overflow_env with Gradio RL demo UI

DESIGN.md ADDED
# Overflow Environment — Low-Level Design Document

## Table of Contents

1. [Architecture Overview](#1-architecture-overview)
2. [File-by-File Breakdown](#2-file-by-file-breakdown)
3. [Data Models (Wire Format)](#3-data-models-wire-format)
4. [Simulation Internals](#4-simulation-internals)
5. [Step-by-Step Execution Pipeline](#5-step-by-step-execution-pipeline)
6. [Distance and Collision Model](#6-distance-and-collision-model)
7. [Reward Function — Complete Breakdown](#7-reward-function--complete-breakdown)
8. [Scripted Car AI](#8-scripted-car-ai)
9. [Action Parsing — How LLM Output Becomes a Decision](#9-action-parsing--how-llm-output-becomes-a-decision)
10. [Observation Text Format](#10-observation-text-format)
11. [Server Protocol — What Training Scripts Must Send](#11-server-protocol--what-training-scripts-must-send)
12. [Training Integration — GRPO / TRL](#12-training-integration--grpo--trl)
13. [Episode Dynamics and RL Characteristics](#13-episode-dynamics-and-rl-characteristics)
14. [Configuration Constants](#14-configuration-constants)
15. [Docker and Deployment](#15-docker-and-deployment)

---

## 1. Architecture Overview

```
┌──────────────────────────────────────────────────────────┐
│                  Training Script (GRPO)                  │
│   calls reset(), reads observation, calls step(action)   │
└──────────────────────────┬───────────────────────────────┘
                           │ WebSocket (persistent session)
                           │ JSON messages over ws://host:8000/ws
                           ▼
┌──────────────────────────────────────────────────────────┐
│                 FastAPI Server (app.py)                  │
│  create_app(OverflowEnvironment, OverflowAction,         │
│             OverflowObservation)                         │
│                                                          │
│  Endpoints:                                              │
│    WS   /ws      ← primary (stateful session)            │
│    POST /reset   ← HTTP fallback                         │
│    POST /step    ← HTTP fallback                         │
│    GET  /state   ← HTTP fallback                         │
│    GET  /health  ← health check                          │
│    GET  /schema  ← JSON schemas for action/obs/state     │
└──────────────────────────┬───────────────────────────────┘
                           │
                           ▼
┌──────────────────────────────────────────────────────────┐
│            OverflowEnvironment (pure Python)             │
│                                                          │
│  Internal state:                                         │
│    _cars: List[Car]        (5 cars, car 0 = agent)       │
│    _state: OverflowState   (episode tracking)            │
│    _rng: random.Random     (seeded per episode)          │
│    _done: bool                                           │
│                                                          │
│  Methods:                                                │
│    reset(seed, episode_id) → OverflowObservation         │
│    step(OverflowAction)    → OverflowObservation         │
│    state (property)        → OverflowState               │
└──────────────────────────────────────────────────────────┘
```

**Key invariant**: The training loop calls `reset()`. The LLM agent only calls `step()` via the training harness. Agents can never reset — if they could undo consequences, training breaks.

**Session model**: Each WebSocket connection gets its own `OverflowEnvironment` instance. The `create_app` function receives the class (factory), not an instance. When a WebSocket connects, the server instantiates a fresh environment for that session.
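The per-connection session model can be sketched as a small factory registry. This is an illustrative sketch, not the actual OpenEnv server code; `SessionManager` and `DummyEnv` are hypothetical names.

```python
# Hypothetical sketch of the per-WebSocket session model: the server holds
# the environment CLASS and instantiates a fresh environment per connection.
from typing import Dict, Type


class SessionManager:
    def __init__(self, env_cls: Type) -> None:
        self._env_cls = env_cls              # factory (class), not an instance
        self._sessions: Dict[str, object] = {}

    def connect(self, session_id: str) -> object:
        # New environment per connection: sessions never share episode state.
        self._sessions[session_id] = self._env_cls()
        return self._sessions[session_id]

    def disconnect(self, session_id: str) -> None:
        self._sessions.pop(session_id, None)


class DummyEnv:
    """Stand-in for OverflowEnvironment."""


manager = SessionManager(DummyEnv)
a = manager.connect("ws-1")
b = manager.connect("ws-2")
assert a is not b  # independent episodes per connection
```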
---

## 2. File-by-File Breakdown

### `models.py` — Pydantic data models

Defines three classes inheriting from OpenEnv core types:

| Class | Parent | Purpose |
|-------|--------|---------|
| `OverflowAction(Action)` | `openenv.core.env_server.types.Action` | What the LLM sends each step |
| `OverflowObservation(Observation)` | `openenv.core.env_server.types.Observation` | What the environment returns |
| `OverflowState(State)` | `openenv.core.env_server.types.State` | Internal state exposed via `/state` |

All three are Pydantic `BaseModel` subclasses. The parent classes provide `metadata: Dict[str, Any]` (on Action and Observation) and `episode_id: str`, `step_count: int` (on State). The parent `Observation` provides `done: bool` and `reward: float | None`.

### `server/overflow_environment.py` — All game logic

Contains:
- `Car` dataclass — per-car state (id, lane, position, speed, goal, is_agent, reached_goal)
- `_parse_decision()` — tolerant action parser
- `_compute_reasoning_bonus()` — reasoning quality scorer
- `_scripted_car_action()` — NPC car AI
- `_apply_action()` — mutates a car's speed/lane
- `_generate_scene_description()` — builds the text observation
- `OverflowEnvironment(Environment)` — the main class with `reset()`, `step()`, `state`

### `server/app.py` — FastAPI wiring

Inspects `create_app` to determine whether it expects a factory (class) or an instance, then passes `OverflowEnvironment`, `OverflowAction`, and `OverflowObservation` to it. The resulting `app` object is what uvicorn serves.

### `client.py` — WebSocket client

`OverflowEnv(EnvClient[OverflowAction, OverflowObservation, OverflowState])` with three required methods:
- `_step_payload(action)` — serializes `OverflowAction` to `{"decision": ..., "reasoning": ...}`
- `_parse_result(payload)` — deserializes server JSON into `StepResult[OverflowObservation]`
- `_parse_state(payload)` — deserializes server JSON into `OverflowState`

### `__init__.py` — Public API

Exports: `OverflowAction`, `OverflowObservation`, `OverflowState`, `OverflowEnv`.

---

## 3. Data Models (Wire Format)

### OverflowAction — What the training script sends to `/step`

```json
{
  "action": {
    "decision": "brake",
    "reasoning": "Car 3 is ahead in my lane, 15 units away, going slower. I should brake."
  }
}
```

| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `decision` | `str` | No | `"maintain"` | One of: `accelerate`, `brake`, `lane_change_left`, `lane_change_right`, `maintain` |
| `reasoning` | `str` | No | `""` | Free-text chain-of-thought. Affects reward via the reasoning bonus (0.0–2.0). |

The `decision` field is parsed tolerantly — see Section 9.

### OverflowObservation — What the server returns

Each observation carries **both** text (for the LLM) and structured data (for the frontend/viz).

```json
{
  "observation": {
    "scene_description": "You are Car 0 in lane 2, position 45, speed 60.\n...",
    "incident_report": "Observer: No incidents this step.",
    "done": false,
    "reward": 1.45,
    "cars": [
      {"carId": 0, "lane": 2, "position": {"x": 45.0, "y": 7.4}, "speed": 60.0, "acceleration": 5.0},
      {"carId": 1, "lane": 1, "position": {"x": 43.0, "y": 3.7}, "speed": 55.0, "acceleration": 0.0}
    ],
    "proximities": [
      {"carA": 0, "carB": 1, "distance": 10.5}
    ],
    "lane_occupancies": [
      {"lane": 1, "carIds": [1]},
      {"lane": 2, "carIds": [0]}
    ],
    "metadata": {}
  },
  "reward": 1.45,
  "done": false
}
```

#### Text fields (for the LLM)

| Field | Type | Description |
|-------|------|-------------|
| `scene_description` | `str` | Multi-line text describing all cars. This is what the LLM reads. |
| `incident_report` | `str` | Observer output. Either `"Observer: No incidents this step."` or a list of CRASH/NEAR MISS events. |

#### Structured fields (for the frontend — compatible with Overflow frontend types)

| Field | Type | Frontend equivalent |
|-------|------|---------------------|
| `cars` | `CarStateData[]` | `CarState[]` — `{carId, lane, position: {x, y}, speed, acceleration}` |
| `proximities` | `ProximityData[]` | `{carA, carB, distance}[]` — pairwise distances for close cars |
| `lane_occupancies` | `LaneOccupancyData[]` | `{lane, carIds}[]` — which cars are in each lane |

Position `y` is computed as `lane * 3.7` (lane width in metres), matching the frontend's `makeCar` convention.
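The lane-to-`y` convention above can be sketched as a small builder for the `cars` field. This is a sketch under the stated assumption that lane width is 3.7 metres; `to_car_state` is an illustrative helper, not the actual server code.

```python
LANE_WIDTH_M = 3.7  # assumed lane width, matching the frontend's makeCar

def to_car_state(car_id: int, lane: int, position: float,
                 speed: float, acceleration: float) -> dict:
    """Build one CarStateData dict as sent in the observation's `cars` field."""
    return {
        "carId": car_id,
        "lane": lane,
        # y is derived from the lane index, x is the along-road position
        "position": {"x": position, "y": round(lane * LANE_WIDTH_M, 1)},
        "speed": speed,
        "acceleration": acceleration,
    }

print(to_car_state(0, 2, 45.0, 60.0, 5.0)["position"])  # {'x': 45.0, 'y': 7.4}
```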
#### Common fields

| Field | Type | Description |
|-------|------|-------------|
| `done` | `bool` | `true` if the episode ended (crash, goal reached, or max steps). |
| `reward` | `float` | Scalar reward for this step. Sum of all reward components. |

The `reward` and `done` fields appear both inside `observation` and at the top level of the response (OpenEnv convention).

### OverflowState — What `/state` returns

```json
{
  "episode_id": "a1b2c3d4-...",
  "step_count": 17,
  "crash_count": 0,
  "near_miss_count": 23,
  "cars_reached_goal": 1,
  "total_cars": 5
}
```

| Field | Type | Description |
|-------|------|-------------|
| `episode_id` | `str` | UUID for this episode. Set on `reset()`. |
| `step_count` | `int` | How many `step()` calls have been made. |
| `crash_count` | `int` | Cumulative crash events (each pair counts as 1). |
| `near_miss_count` | `int` | Cumulative near-miss events (each pair counts as 1). |
| `cars_reached_goal` | `int` | How many cars (including scripted) reached their goal. |
| `total_cars` | `int` | Always 5. |

---

## 4. Simulation Internals

### The Road

- 3 lanes, numbered 1, 2, 3 (1 = leftmost, 3 = rightmost)
- Road length: ~200 position units
- No wrapping — cars move forward from low positions toward high positions
- Lanes are conceptually 10 units apart for distance calculations

### Car State

Each car is a `Car` dataclass:

```python
@dataclass
class Car:
    car_id: int           # 0 = agent, 1–4 = scripted
    lane: int             # 1, 2, or 3
    position: float       # 0.0 to ~200.0 (along the road)
    speed: float          # 20.0 to 90.0
    goal_position: float  # 160.0 to 195.0
    is_agent: bool        # True only for car 0
    reached_goal: bool    # True once position >= goal_position
```

### Initialization (reset)

On `reset(seed=N)`:
1. A `random.Random(seed)` RNG is created (deterministic replays with the same seed).
2. 5 cars are spawned:
   - **Lane**: random 1–3
   - **Position**: random 10–80 (spread across the first half of the road)
   - **Speed**: random 40–70
   - **Goal**: random 160–195
3. No two cars occupy the same 10-unit segment in the same lane at spawn (deconflicted via a `(lane, position // 10)` hash).
4. Car 0 is the agent. Cars 1–4 are scripted.

### Movement

Each step, every active (non-goal-reached) car moves forward:

```
car.position += car.speed * 0.1
```

This means a car at speed 60 moves 6.0 units per step. At that rate, traversing the ~120-unit gap from the starting zone (10–80) to the goal zone (160–195) takes roughly 20 steps. Faster cars (speed 90) move 9.0 units/step and reach goals sooner.
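The steps-to-goal arithmetic above can be checked with a tiny simulation of the movement rule. This sketch only replays `position += speed * 0.1`; it ignores braking, acceleration, and collisions.

```python
def steps_to_goal(position: float, goal: float, speed: float) -> int:
    """Count steps until `position` reaches `goal` at a constant speed."""
    steps = 0
    while position < goal:
        position += speed * 0.1  # the per-step movement rule
        steps += 1
    return steps

# A 120-unit gap at speed 60 (6 units/step) takes exactly 20 steps.
print(steps_to_goal(60.0, 180.0, 60.0))  # 20
```

At speed 90 the same worst-case spread (spawn 10, goal 195) closes in about 21 steps, well under the 100-step limit.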
---

## 5. Step-by-Step Execution Pipeline

When `step(action)` is called, the following happens **in this exact order**:

```
1.  GUARD: if episode is already done → return stale observation with reward=0.0
2.  INCREMENT step_count
3.  PARSE the agent's action → one of {accelerate, brake, lane_change_left, lane_change_right, maintain}
4.  APPLY action to Car 0 (mutate speed or lane)
5.  COMPUTE scripted actions for Cars 1–4 and APPLY them
6.  MOVE all active cars forward: position += speed * 0.1
7.  COLLISION DETECTION (pairwise over all active cars):
    - distance < 5.0  → CRASH (reward -5.0, episode ends)
    - distance < 15.0 → NEAR MISS (reward -1.0 per pair)
8.  If no crash:
    a. Check if Car 0 reached its goal → reward +3.0, episode ends
    b. Check if scripted cars reached their goals (state tracking only)
    c. If the episode is not ending → SAFE STEP bonus: reward +0.5
9.  REASONING BONUS: score the reasoning text → reward +0.0 to +2.0
10. MAX STEPS CHECK: if step_count >= 100 → episode ends
11. BUILD observation text and incident report
12. RETURN OverflowObservation(scene_description, incident_report, done, reward)
```

**Important ordering detail**: Actions are applied (steps 4–5) **before** movement (step 6), so the agent's speed/lane change takes effect for this step's movement. Collision detection (step 7) happens **after** movement, on the new positions.

**Reward accumulation within a step**: A single step's reward is the **sum** of all applicable components. For example, with 2 near-miss pairs, a surviving agent, and a reasoning bonus of 1.5, the reward is `(-1.0 * 2) + 0.5 + 1.5 = 0.0`.
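The per-step sum can be reproduced with simple arithmetic. This is a simplified sketch of how the components combine (constants mirror Section 14); the exact interaction of crash and near-miss penalties in the real implementation may differ.

```python
REWARD_CRASH = -5.0
REWARD_NEAR_MISS = -1.0
REWARD_SAFE_STEP = 0.5
REWARD_REACHED_GOAL = 3.0

def step_reward(near_miss_pairs: int, crashed: bool, reached_goal: bool,
                reasoning_bonus: float) -> float:
    """Sum the per-step reward components described above (sketch)."""
    reward = 0.0
    if crashed:
        reward += REWARD_CRASH
    else:
        reward += REWARD_NEAR_MISS * near_miss_pairs
        if reached_goal:
            reward += REWARD_REACHED_GOAL
        else:
            reward += REWARD_SAFE_STEP
    reward += reasoning_bonus  # applies even on crash/goal steps
    return reward

print(step_reward(2, False, False, 1.5))  # (-1.0 * 2) + 0.5 + 1.5 = 0.0
```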
---

## 6. Distance and Collision Model

Distance between two cars uses a weighted Euclidean formula:

```python
from math import sqrt

def distance_to(self, other):
    lane_diff = abs(self.lane - other.lane) * 10.0
    pos_diff = abs(self.position - other.position)
    return sqrt(lane_diff**2 + pos_diff**2)
```

**Implications**:
- Two cars in the **same lane** at positions 45 and 50: distance = 5.0 (exactly at the crash threshold)
- Two cars in **adjacent lanes** (e.g., lane 1 and lane 2) at the same position: distance = 10.0 (near miss, not crash)
- Two cars **two lanes apart** at the same position: distance = 20.0 (safe, no incident)
- Two cars in adjacent lanes, 10 units apart longitudinally: distance = sqrt(100 + 100) ≈ 14.1 (near miss)
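The four cases listed above can be verified with a standalone version of the formula:

```python
from math import sqrt

def distance(lane_a: int, pos_a: float, lane_b: int, pos_b: float) -> float:
    """Weighted Euclidean distance: lanes count as 10 units apart."""
    lane_diff = abs(lane_a - lane_b) * 10.0
    pos_diff = abs(pos_a - pos_b)
    return sqrt(lane_diff ** 2 + pos_diff ** 2)

print(distance(1, 45.0, 1, 50.0))            # 5.0  — exactly at the crash threshold
print(distance(1, 45.0, 2, 45.0))            # 10.0 — near miss, not crash
print(distance(1, 45.0, 3, 45.0))            # 20.0 — safe
print(round(distance(1, 45.0, 2, 55.0), 1))  # 14.1 — near miss
```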
**Key insight for the agent**: Lane changes provide safety via the 10-unit lane multiplier. Staying in the same lane as another car is the primary crash risk. The agent should use lane changes proactively to maintain distance from cars in its lane.

### Collision detection scope

Detection is **pairwise over ALL active cars**, not just agent-involving pairs. If Car 2 and Car 3 crash, the episode still ends with -5.0 reward. This means the agent is implicitly responsible for the overall traffic flow — it should avoid creating situations where its actions cause chain reactions among scripted cars.

---

## 7. Reward Function — Complete Breakdown

### Per-step reward components

| Component | Value | Condition | Stacks? |
|-----------|-------|-----------|---------|
| **Crash** | -5.0 | Any pair distance < 5.0 | Once (episode ends) |
| **Near miss** | -1.0 | Per pair with distance < 15.0 | Yes, per pair (can be -2.0, -3.0, etc.) |
| **Safe step** | +0.5 | No crash and episode not ending this step | Once per step |
| **Goal reached** | +3.0 | Car 0's position >= goal_position | Once (episode ends) |
| **Reasoning bonus** | +0.0 to +2.0 | Based on reasoning text quality | Once per step |

### Reasoning bonus scoring

The bonus has three sub-components, capped at 2.0 total:

**Length bonus** (up to 0.5):
- `len > 20` chars → +0.2
- `len > 50` chars → +0.15
- `len > 100` chars → +0.15

**Keyword awareness** (up to 1.0):
Each keyword found → +0.2, capped at 1.0. Keywords: `ahead`, `behind`, `lane`, `speed`, `distance`, `safe`, `danger`, `collision`, `brake`, `gap`, `close`, `slow`, `fast`, `goal`, `position`.

**Structure bonus** (up to 0.5):
- Contains `<think>` or `because` → +0.25
- Contains `therefore`, `so i should`, `best option`, or `i will` → +0.25
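The scoring rules above translate to a scorer like the following. This is a sketch consistent with the stated rules; the actual `_compute_reasoning_bonus()` may differ in details.

```python
KEYWORDS = ["ahead", "behind", "lane", "speed", "distance", "safe", "danger",
            "collision", "brake", "gap", "close", "slow", "fast", "goal",
            "position"]

def reasoning_bonus(reasoning: str) -> float:
    """Score reasoning text per the length/keyword/structure rules (sketch)."""
    text = reasoning.lower()
    bonus = 0.0
    # Length bonus (up to 0.5)
    if len(reasoning) > 20:
        bonus += 0.2
    if len(reasoning) > 50:
        bonus += 0.15
    if len(reasoning) > 100:
        bonus += 0.15
    # Keyword awareness (up to 1.0)
    bonus += min(1.0, 0.2 * sum(1 for kw in KEYWORDS if kw in text))
    # Structure bonus (up to 0.5)
    if "<think>" in text or "because" in text:
        bonus += 0.25
    if any(p in text for p in ("therefore", "so i should", "best option", "i will")):
        bonus += 0.25
    return min(2.0, bonus)

print(reasoning_bonus(""))  # 0.0 — empty reasoning earns nothing
```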
### Typical reward ranges per step

| Scenario | Typical reward |
|----------|---------------|
| Safe step, no reasoning | +0.5 |
| Safe step, decent reasoning | +1.0 to +2.0 |
| Safe step, excellent reasoning | +2.0 to +2.5 |
| 1 near miss, decent reasoning | -0.5 to +0.5 |
| 2 near misses, decent reasoning | -1.5 to -0.5 |
| Crash (any) | -5.0 + reasoning bonus |
| Goal reached, good reasoning | +3.0 + reasoning bonus |

### Episode return (total reward) characteristics

Based on testing with seed=42:
- A "maintain" strategy with decent reasoning earns ~1.1 per step × ~17 steps ≈ 18.7 total, minus near-miss penalties
- Aggressive "accelerate" strategies reach the goal faster but accumulate more near misses
- Smart strategies that use lane changes and braking to avoid near misses maximize total reward

---

## 8. Scripted Car AI

Cars 1–4 use `_scripted_car_action(car, all_cars, rng)`:

```
1. Find the nearest car AHEAD in the SAME LANE
2. If that car is < 20 units ahead → "brake"
3. Else if speed < 60 and 10% random chance → "accelerate"
4. Else if 5% random chance → lane change (random left/right, respecting boundaries)
5. Else → "maintain"
```
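The decision ladder above can be sketched as follows. `CarView` and `scripted_action` are illustrative stand-ins for the real `Car` dataclass and `_scripted_car_action()`; this sketch omits the lane-boundary check on step 4.

```python
import random
from dataclasses import dataclass


@dataclass
class CarView:
    """Minimal stand-in for the Car fields the NPC policy reads."""
    lane: int
    position: float
    speed: float


def scripted_action(car: CarView, all_cars: list, rng: random.Random) -> str:
    # 1. Nearest car ahead in the same lane
    ahead = [c for c in all_cars
             if c is not car and c.lane == car.lane and c.position > car.position]
    # 2. Brake if blocked within 20 units
    if ahead and min(c.position for c in ahead) - car.position < 20:
        return "brake"
    # 3. Occasionally speed up when slow
    if car.speed < 60 and rng.random() < 0.10:
        return "accelerate"
    # 4. Rarely change lanes (boundary check omitted in this sketch)
    if rng.random() < 0.05:
        return rng.choice(["lane_change_left", "lane_change_right"])
    # 5. Default
    return "maintain"


blocked = CarView(lane=1, position=40.0, speed=50.0)
leader = CarView(lane=1, position=55.0, speed=45.0)
print(scripted_action(blocked, [blocked, leader], random.Random(0)))  # "brake"
```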
**Characteristics**:
- Scripted cars are mostly passive — they maintain speed
- They brake reactively when blocked (but only for same-lane cars ahead)
- They rarely change lanes (5% per step), making their behavior somewhat predictable
- They never intentionally avoid the agent — they only react to cars directly ahead
- They can accumulate near misses and crashes among themselves

This creates an environment where a smart agent can learn to navigate around largely predictable but occasionally erratic traffic.

---

## 9. Action Parsing — How LLM Output Becomes a Decision

The parser `_parse_decision(action)` is intentionally forgiving. It tries three strategies in order:

### Strategy 1: Direct field match
```python
decision = action.decision.strip().lower().replace(" ", "_")
# If it's one of {accelerate, brake, lane_change_left, lane_change_right, maintain} → use it
```

### Strategy 2: XML tag extraction
```python
text = f"{action.decision} {action.reasoning}".lower()
match = re.search(r"<action>\s*(\w+)\s*</action>", text)
# If found and valid → use it
```

This handles LLM outputs like:
```
decision: "think about it"
reasoning: "<think>Car ahead is close</think><action>brake</action>"
```

### Strategy 3: Keyword scan
```python
for v in {"accelerate", "brake", "lane_change_left", "lane_change_right", "maintain"}:
    if v in text:
        return v
```

This handles outputs like `decision: "I want to accelerate now"`.

### Fallback
If nothing matches → `"maintain"` (safe default).

**For training scripts**: The cleanest format is to put the exact decision string in the `decision` field. The tolerant parsing exists so that LLMs early in training (before they learn the format) still produce valid actions rather than crashing.
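The three strategies combine into a parser like this standalone sketch (the real `_parse_decision` takes an action object rather than two strings):

```python
import re

VALID = {"accelerate", "brake", "lane_change_left", "lane_change_right", "maintain"}

def parse_decision(decision: str, reasoning: str = "") -> str:
    """Tolerant three-strategy parser sketch."""
    d = decision.strip().lower().replace(" ", "_")
    if d in VALID:                            # Strategy 1: direct field match
        return d
    text = f"{decision} {reasoning}".lower()
    m = re.search(r"<action>\s*(\w+)\s*</action>", text)
    if m and m.group(1) in VALID:             # Strategy 2: XML tag extraction
        return m.group(1)
    for v in VALID:                           # Strategy 3: keyword scan
        if v in text:
            return v
    return "maintain"                         # Fallback: safe default

print(parse_decision("Lane Change Left"))                 # lane_change_left
print(parse_decision("think", "<action>brake</action>"))  # brake
print(parse_decision("I want to accelerate now"))         # accelerate
print(parse_decision("dunno"))                            # maintain
```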
---

## 10. Observation Text Format

The `scene_description` field is a multi-line string that the LLM reads as its input. Example:

```
You are Car 0 in lane 2, position 45, speed 60.
Goal: reach position 180.
Nearby cars:
- Car 1: lane 1, position 43, speed 55
- Car 2: lane 3, position 48, speed 70
- Car 3: lane 2, position 65, speed 50 [AHEAD IN YOUR LANE - 20 units away]
- Car 4: lane 1, position 30, speed 65
```

**Annotations added**:
- `[AHEAD IN YOUR LANE - N units away]` — same lane, ahead of the agent
- `[BEHIND IN YOUR LANE - N units away]` — same lane, behind the agent
- `[REACHED GOAL]` — car has finished

The `incident_report` is separate:
- No incidents: `"Observer: No incidents this step."`
- With incidents: one line per event, e.g.:
```
NEAR MISS between Car 0 and Car 3 (distance: 12.5)
Car 0 reached its goal at position 180!
```

---

## 11. Server Protocol — What Training Scripts Must Send

### WebSocket Protocol (Primary — for training)

Connect to `ws://host:8000/ws`. All messages are JSON.

#### Reset

**Send:**
```json
{"type": "reset", "data": {"seed": 42}}
```

`data` can include `seed` (int) and/or `episode_id` (str). Both are optional.

**Receive:**
```json
{
  "type": "observation",
  "data": {
    "observation": {
      "scene_description": "You are Car 0 in lane 3, position 24, speed 40.\n...",
      "incident_report": "",
      "done": false,
      "reward": 0.0,
      "metadata": {}
    },
    "reward": 0.0,
    "done": false
  }
}
```

#### Step

**Send:**
```json
{
  "type": "step",
  "data": {
    "decision": "brake",
    "reasoning": "Car ahead is close, braking to maintain safe distance."
  }
}
```

**Receive:**
```json
{
  "type": "observation",
  "data": {
    "observation": {
      "scene_description": "You are Car 0 in lane 3, position 27, speed 35.\n...",
      "incident_report": "Observer: No incidents this step.",
      "done": false,
      "reward": 2.25,
      "metadata": {}
    },
    "reward": 2.25,
    "done": false
  }
}
```

#### State

**Send:**
```json
{"type": "state"}
```

**Receive:**
```json
{
  "type": "state",
  "data": {
    "episode_id": "a1b2c3d4-...",
    "step_count": 7,
    "crash_count": 0,
    "near_miss_count": 3,
    "cars_reached_goal": 0,
    "total_cars": 5
  }
}
```

#### Close

**Send:**
```json
{"type": "close"}
```
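The message shapes above can be built with small helpers; this sketch only constructs the JSON payloads (the `reset_msg`/`step_msg` names are illustrative). Sending them over an actual WebSocket connection, e.g. with the `websockets` package, is left to the training script.

```python
import json
from typing import Optional

def reset_msg(seed: Optional[int] = None) -> str:
    """Build the JSON text for a reset message to ws://host:8000/ws."""
    data = {} if seed is None else {"seed": seed}
    return json.dumps({"type": "reset", "data": data})

def step_msg(decision: str, reasoning: str = "") -> str:
    """Build the JSON text for a step message."""
    return json.dumps({"type": "step",
                       "data": {"decision": decision, "reasoning": reasoning}})

print(reset_msg(42))  # {"type": "reset", "data": {"seed": 42}}
print(step_msg("brake", "Car ahead is close."))
```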
### HTTP Protocol (Fallback — for simple testing)

Note: In factory mode the HTTP API creates a **new environment instance per endpoint**, so `/reset` and `/step` calls hit separate instances. Use the WebSocket for stateful multi-step episodes.

```
POST /reset   Body: {"seed": 42}                                          → {"observation": {...}, "reward": 0.0, "done": false}
POST /step    Body: {"action": {"decision": "brake", "reasoning": "..."}} → {"observation": {...}, "reward": ..., "done": ...}
GET  /state                                                               → {"episode_id": ..., "step_count": ..., ...}
GET  /health                                                              → {"status": "healthy"}
GET  /schema                                                              → {"action": {...}, "observation": {...}, "state": {...}}
```

### Using the Python Client

```python
from overflow_env import OverflowEnv, OverflowAction

with OverflowEnv(base_url="http://localhost:8000") as env:
    result = env.reset(seed=42)
    # result is StepResult[OverflowObservation]
    #   result.observation.scene_description — the text for the LLM
    #   result.observation.incident_report   — observer output
    #   result.reward — float
    #   result.done   — bool

    while not result.done:
        # Feed scene_description to the LLM, get decision + reasoning back
        llm_decision, llm_reasoning = call_llm(result.observation.scene_description)

        action = OverflowAction(decision=llm_decision, reasoning=llm_reasoning)
        result = env.step(action)

    # Episode over
    state = env.state()
    print(f"Steps: {state.step_count}, Crashes: {state.crash_count}")
```

---

## 12. Training Integration — GRPO / TRL

### System prompt for the LLM

The training script should set a system prompt like:

```
You are an autonomous vehicle controller. Each turn you receive a traffic scene description.
You must output a driving decision and your reasoning.

Available decisions: accelerate, brake, lane_change_left, lane_change_right, maintain

Output format:
<think>Your reasoning about the traffic situation</think>
<action>your_decision</action>
```

### What the training loop does each episode

```python
# 1. Reset environment
result = env.reset(seed=episode_seed)

# 2. Build initial prompt
messages = [
    {"role": "system", "content": SYSTEM_PROMPT},
    {"role": "user", "content": result.observation.scene_description}
]

trajectory_rewards = []

# 3. Loop until done
while not result.done:
    # 3a. Get LLM completion
    completion = model.generate(messages)  # the text the LLM produces

    # 3b. Parse LLM output into an action
    # The environment's parser is tolerant, but for clean training
    # you might also parse on the client side
    action = OverflowAction(
        decision=extract_decision(completion),
        reasoning=completion  # pass full text as reasoning
    )

    # 3c. Step
    result = env.step(action)
    trajectory_rewards.append(result.reward)

    # 3d. Append to conversation for next turn
    messages.append({"role": "assistant", "content": completion})
    messages.append({"role": "user", "content": (
        result.observation.scene_description + "\n" +
        result.observation.incident_report
    )})

# 4. Compute episode return for GRPO
episode_return = sum(trajectory_rewards)
```

### GRPO reward signal

For GRPO (Group Relative Policy Optimization), the reward signal is the **episode return** — the sum of all per-step rewards across the episode. The environment is designed so that:

- **Positive episode returns** (agent reached the goal safely with good reasoning) indicate good behavior
- **Negative episode returns** (crashes, many near misses) indicate bad behavior
- The **reasoning bonus** provides per-step reward shaping that encourages the LLM to explain its thinking, which improves interpretability and can speed up learning

### Constructing the reward for TRL

If using TRL's `OnlineDPOTrainer` or `GRPOTrainer`:

```python
# Per-step reward is already in result.reward
# For token-level reward (assign to last token of each turn):
rewards_per_turn = trajectory_rewards  # list of floats, one per step

# For episode-level reward (assign to last token of episode):
episode_reward = sum(trajectory_rewards)
```
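GRPO compares each rollout's return against the other rollouts of the same prompt. The sketch below shows the group-relative baseline idea for illustration only; TRL's `GRPOTrainer` computes this normalization internally.

```python
def group_advantages(returns: list) -> list:
    """Normalize episode returns within a group of rollouts (GRPO-style)."""
    mean = sum(returns) / len(returns)
    var = sum((r - mean) ** 2 for r in returns) / len(returns)
    std = var ** 0.5 or 1.0  # guard against identical returns (std == 0)
    return [(r - mean) / std for r in returns]

# Four rollouts of the same seed: the crash episode (-3.0) gets a strongly
# negative advantage, the best episode a positive one.
adv = group_advantages([18.7, 25.0, -3.0, 31.2])
print([round(a, 2) for a in adv])
```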
---

## 13. Episode Dynamics and RL Characteristics

### Episode length distribution

| Scenario | Typical length |
|----------|---------------|
| Aggressive accelerate → goal | 12–20 steps |
| Moderate maintain → goal | 18–30 steps |
| Conservative braking | 30–50+ steps |
| Crash (bad luck or bad driving) | 5–15 steps |
| Max steps timeout | 100 steps |

### What makes this environment learnable

1. **Clear signal**: Crashes give -5.0, goals give +3.0. The agent quickly learns that crashing is bad and reaching the goal is good.

2. **Gradual improvement**: Near misses (-1.0 each) provide an intermediate signal. An agent that learns to avoid near misses gets higher returns than one that only avoids crashes.

3. **Speed-accuracy tradeoff**: Accelerating reaches the goal faster (more +3.0 episodes) but increases crash/near-miss risk. The optimal policy accelerates when safe and brakes or changes lanes when needed.

4. **Reasoning is rewarded**: The reasoning bonus (up to +2.0/step) means that over a 20-step episode, reasoning alone can contribute up to +40.0. This incentivizes the LLM to produce structured, situation-aware reasoning.

5. **Stochasticity**: Scripted cars have random elements (10% accelerate, 5% lane change). The same seed produces the same episode, but different seeds produce different traffic patterns, forcing the agent to generalize.

6. **All-pairs collision**: The agent is rewarded and punished for the entire traffic system, not just its own car, so it must be aware of the overall traffic flow.

### Typical learning progression

1. **Random policy**: Mostly "maintain", occasional random actions. Episode return: 0 to 15 (depending on luck).
2. **Basic safety**: The agent learns to brake when the car ahead is close. Fewer crashes, more goals. Episode return: 10 to 25.
3. **Strategic driving**: The agent learns to change lanes proactively, accelerate when clear, and brake early. Episode return: 20 to 40.
4. **Optimized reasoning**: The agent produces structured reasoning with relevant keywords, maximizing the reasoning bonus. Episode return: 30 to 60.

### Reproducibility

Passing `seed=N` to `reset()` produces deterministic initial conditions and scripted car behavior (since the `random.Random` instance is seeded). The same seed plus the same agent actions yields the same trajectory. This is critical for GRPO, which compares multiple rollouts of the same prompt.
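The seeding property can be demonstrated with a toy spawn function; this is illustrative, not the actual spawn code.

```python
import random

def spawn_positions(seed: int, n: int = 5) -> list:
    """Illustrative: same seed -> same spawn layout, mirroring reset(seed=N)."""
    rng = random.Random(seed)
    return [round(rng.uniform(10, 80), 2) for _ in range(n)]

assert spawn_positions(42) == spawn_positions(42)  # deterministic replay
assert spawn_positions(42) != spawn_positions(43)  # different traffic patterns
```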
706
+ ---
707
+
708
+ ## 14. Configuration Constants
709
+
710
+ All constants are defined at the top of `server/overflow_environment.py`:
711
+
712
+ ```python
713
+ NUM_LANES = 3 # Number of road lanes
714
+ ROAD_LENGTH = 200 # Conceptual road length (units)
715
+ NUM_CARS = 5 # Total cars (1 agent + 4 scripted)
716
+ MAX_STEPS = 100 # Maximum steps before forced termination
717
+ CRASH_DISTANCE = 5.0 # Distance threshold for crash
718
+ NEAR_MISS_DISTANCE = 15.0 # Distance threshold for near miss
719
+
720
+ REWARD_CRASH = -5.0 # Reward for any crash
721
+ REWARD_NEAR_MISS = -1.0 # Reward per near-miss pair
722
+ REWARD_SAFE_STEP = 0.5 # Reward for surviving a step
723
+ REWARD_REACHED_GOAL = 3.0 # Reward for reaching goal
724
+ REWARD_REASONING_MAX = 2.0 # Maximum reasoning quality bonus
725
+
726
+ MIN_SPEED = 20 # Minimum car speed
727
+ MAX_SPEED = 90 # Maximum car speed
728
+ SPEED_DELTA = 5 # Speed change per accelerate/brake
729
+ ```
730
+
731
+ To tune difficulty:
732
+ - **Easier**: Increase `CRASH_DISTANCE` and `NEAR_MISS_DISTANCE`, decrease `NUM_CARS`, widen starting positions
733
+ - **Harder**: Decrease distances, increase `NUM_CARS`, narrow starting positions, increase `MAX_SPEED`
734
+ - **Longer episodes**: Increase `ROAD_LENGTH` or decrease starting speeds
735
+ - **More reasoning incentive**: Increase `REWARD_REASONING_MAX`
736
+
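Since these are plain module-level constants, a training script can keep named difficulty presets and merge them over the defaults. A sketch; the preset values are illustrative, not tuned settings:

```python
# Baseline values for a subset of the constants above.
DEFAULTS = {"CRASH_DISTANCE": 5.0, "NEAR_MISS_DISTANCE": 15.0, "NUM_CARS": 5}

# Illustrative presets following the tuning guidance above.
PRESETS = {
    "easy": {"CRASH_DISTANCE": 7.0, "NEAR_MISS_DISTANCE": 20.0, "NUM_CARS": 3},
    "hard": {"CRASH_DISTANCE": 3.0, "NEAR_MISS_DISTANCE": 10.0, "NUM_CARS": 7},
}

def config_for(name: str) -> dict:
    """Merge a named preset over the defaults; unknown names fall back to defaults."""
    return {**DEFAULTS, **PRESETS.get(name, {})}

print(config_for("easy")["NUM_CARS"])  # 3
```

The resulting dict can then be applied however the environment is constructed, e.g. by setting the corresponding attributes on the module before instantiation.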
737
+ ---
738
+
739
+ ## 15. Docker and Deployment
740
+
741
+ ### Local development
742
+
743
+ ```bash
744
+ uvicorn overflow_env.server.app:app --host 0.0.0.0 --port 8000 --reload
745
+ ```
746
+
747
+ ### Docker build
748
+
749
+ ```bash
750
+ # From the overflow_env/ directory:
751
+ docker build -t overflow-env:latest -f server/Dockerfile .
752
+ docker run -p 8000:8000 overflow-env:latest
753
+ ```
754
+
755
+ The Dockerfile uses a multi-stage build:
756
+ 1. **Builder stage**: Installs dependencies with `uv sync` into a `.venv`
757
+ 2. **Runtime stage**: Copies the `.venv` and source code, runs uvicorn
758
+
759
+ Base image: `ghcr.io/meta-pytorch/openenv-base:latest`
760
+
761
+ ### Push to HuggingFace Spaces
762
+
763
+ ```bash
764
+ openenv push --repo-id username/overflow-env
765
+ ```
766
+
767
+ ### Connect from training script
768
+
769
+ ```python
770
+ # Local
771
+ env = OverflowEnv(base_url="http://localhost:8000")
772
+
773
+ # Docker
774
+ env = OverflowEnv.from_docker_image("overflow-env:latest")
775
+
776
+ # HuggingFace Space
777
+ env = OverflowEnv.from_env("username/overflow-env")
778
+ ```
779
+
780
+ ### openenv.yaml manifest
781
+
782
+ ```yaml
783
+ spec_version: 1
784
+ name: overflow_env
785
+ type: space
786
+ runtime: fastapi
787
+ app: server.app:app
788
+ port: 8000
789
+ ```
790
+
791
+ This tells OpenEnv tooling how to find and run the environment.
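Regardless of deployment target, a training script can wait for the server to become ready before connecting. A minimal sketch, assuming the `/health` endpoint referenced by the Docker HEALTHCHECK; the helper name is hypothetical:

```python
import time
import urllib.error
import urllib.request

def wait_for_health(base_url: str, timeout_s: float = 30.0) -> bool:
    """Poll GET {base_url}/health until it answers 200 or the timeout expires."""
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        try:
            with urllib.request.urlopen(f"{base_url}/health", timeout=3) as resp:
                if resp.status == 200:
                    return True
        except (urllib.error.URLError, OSError):
            pass  # server not up yet; retry
        time.sleep(0.5)
    return False
```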
Dockerfile ADDED
@@ -0,0 +1,22 @@
1
+ FROM python:3.11-slim
2
+
3
+ WORKDIR /app
4
+
5
+ # Install system deps
6
+ RUN apt-get update && \
7
+ apt-get install -y --no-install-recommends git curl && \
8
+ rm -rf /var/lib/apt/lists/*
9
+
10
+ # Copy environment code into a proper package directory
11
+ COPY . /app/overflow_env
12
+
13
+ # Install dependencies via pip using requirements.txt
14
+ RUN pip install --no-cache-dir -r /app/overflow_env/server/requirements.txt
15
+
16
+ HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
17
+ CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')" || exit 1
18
+
19
+ EXPOSE 8000
20
+
21
+ ENV ENABLE_WEB_INTERFACE=true
22
+ CMD ["uvicorn", "overflow_env.server.app:app", "--host", "0.0.0.0", "--port", "8000"]
README.md CHANGED
@@ -1,12 +1,70 @@
1
  ---
2
- title: Overflow Openenv
3
- emoji: πŸ“Š
4
- colorFrom: red
5
- colorTo: gray
6
  sdk: gradio
7
- sdk_version: 6.9.0
8
  app_file: app.py
9
  pinned: false
 
 
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
  ---
2
+ title: Overflow OpenENV
3
+ emoji: πŸš—
4
+ colorFrom: blue
5
+ colorTo: green
6
  sdk: gradio
7
+ sdk_version: 4.44.0
8
  app_file: app.py
9
  pinned: false
10
+ tags:
11
+ - openenv
12
  ---
13
 
14
+ # Overflow Environment
15
+
16
+ An autonomous vehicle fleet oversight environment for [OpenEnv](https://github.com/meta-pytorch/OpenEnv).
17
+
18
+ ## Overview
19
+
20
+ A 2D road grid with N cars. One car (Car 0) is controlled by an LLM agent, while other cars follow simple scripted driving rules. An observer detects crashes and near-misses each step and computes rewards based on safety.
21
+
22
+ ## Quick Start
23
+
24
+ ```bash
25
+ # Install dependencies
26
+ pip install -e .
27
+
28
+ # Run the server
29
+ uvicorn server.app:app --host 0.0.0.0 --port 8000 --reload
30
+ ```
31
+
32
+ ```python
33
+ from overflow_env import OverflowEnv, OverflowAction
34
+
35
+ with OverflowEnv(base_url="http://localhost:8000") as env:
36
+ result = env.reset()
37
+ print(result.observation.scene_description)
38
+
39
+ action = OverflowAction(decision="maintain", reasoning="Road is clear ahead.")
40
+ result = env.step(action)
41
+ print(result.observation.incident_report)
42
+ print(f"Reward: {result.reward}, Done: {result.done}")
43
+ ```
44
+
45
+ ## Action Space
46
+
47
+ | Decision | Effect |
48
+ |----------|--------|
49
+ | `accelerate` | Increase speed by 5 |
50
+ | `brake` | Decrease speed by 5 |
51
+ | `lane_change_left` | Move to left lane |
52
+ | `lane_change_right` | Move to right lane |
53
+ | `maintain` | Keep current speed and lane |
54
+
55
+ ## Reward Structure
56
+
57
+ | Event | Reward |
58
+ |-------|--------|
59
+ | Crash (distance < 5) | -5.0 |
60
+ | Near miss (distance < 15) | -1.0 |
61
+ | Safe step toward goal | +0.5 |
62
+ | Reached goal | +3.0 |
63
+ | Reasoning quality bonus | +0.0 to +2.0 |
64
+
65
+ ## Environment Details
66
+
67
+ - **Road**: 3 lanes, ~200 units long
68
+ - **Cars**: 5 total (1 agent + 4 scripted)
69
+ - **Max steps**: 100 per episode
70
+ - **Speed range**: 20–90 units
__init__.py ADDED
@@ -0,0 +1,27 @@
1
+ """Overflow Environment β€” Autonomous vehicle fleet oversight for OpenEnv."""
2
+
3
+ try:
4
+ from .client import OverflowEnv
5
+ except ImportError:
6
+ OverflowEnv = None # openenv-core not installed; training-only mode
7
+
8
+ from .models import (
9
+ CarStateData,
10
+ LaneOccupancyData,
11
+ OverflowAction,
12
+ OverflowObservation,
13
+ OverflowState,
14
+ Position,
15
+ ProximityData,
16
+ )
17
+
18
+ __all__ = [
19
+ "OverflowAction",
20
+ "OverflowObservation",
21
+ "OverflowState",
22
+ "OverflowEnv",
23
+ "CarStateData",
24
+ "Position",
25
+ "ProximityData",
26
+ "LaneOccupancyData",
27
+ ]
app.py ADDED
@@ -0,0 +1,424 @@
1
+ """
2
+ OpenENV RL Demo β€” Gradio UI entrypoint for HuggingFace Spaces.
3
+
4
+ Runs inside the overflow_env package root. All imports use absolute paths
5
+ so they work both as a package (installed) and as a Space (flat root).
6
+ """
7
+
8
+ import sys, os
9
+ # When running as HF Space, make server/ importable with absolute paths
10
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
11
+
12
+ import math, time, threading
13
+ import numpy as np
14
+ import torch
15
+ import torch.optim as optim
16
+ import matplotlib
17
+ matplotlib.use("Agg")
18
+ import matplotlib.pyplot as plt
19
+ import matplotlib.patches as patches
20
+ import gradio as gr
21
+
22
+ from server.overflow_environment import OverflowEnvironment
23
+ from models import OverflowAction
24
+ from policies.flat_mlp_policy import FlatMLPPolicy
25
+ from policies.policy_spec import build_obs, build_ticket_vector, OBS_DIM
26
+
27
+
28
+ STEPS_PER_EPISODE = 20
29
+ NUM_LANES = 3
30
+ ROAD_LENGTH = 200
31
+
32
+
33
+ # ── Observation adapter ───────────────────────────────────────────────────────
34
+
35
+ def obs_to_vec(overflow_obs) -> np.ndarray:
36
+ cars = overflow_obs.cars
37
+ if not cars:
38
+ return np.zeros(OBS_DIM, dtype=np.float32)
39
+ ego = next((c for c in cars if c.carId == 0), cars[0])
40
+ ego_spd = ego.speed / 4.5
41
+ ego_x = ego.position.x
42
+ ego_y = (ego.lane - 2) * 3.7
43
+ tickets = []
44
+ for car in cars:
45
+ if car.carId == 0:
46
+ continue
47
+ rx = car.position.x - ego.position.x
48
+ ry = (car.lane - ego.lane) * 3.7
49
+ cs = car.speed / 4.5
50
+ d = math.sqrt(rx**2 + ry**2)
51
+ if d > 80:
52
+ continue
53
+ cl = max(ego_spd - cs * math.copysign(1, max(rx, 0.01)), 0.1)
54
+ tickets.append(build_ticket_vector(
55
+ severity_weight=1.0 if d < 8 else 0.75 if d < 15 else 0.5,
56
+ ttl=5.0, pos_x=rx, pos_y=ry, pos_z=0.0,
57
+ vel_x=cs, vel_y=0.0, vel_z=0.0, heading=0.0,
58
+ size_length=4.0, size_width=2.0, size_height=1.5,
59
+ distance=d, time_to_collision=min(d / cl, 30.0),
60
+ bearing=math.atan2(ry, max(rx, 0.01)),
61
+ ticket_type="collision_risk", entity_type="vehicle", confidence=1.0,
62
+ ))
63
+ tv = np.array(tickets, dtype=np.float32) if tickets else None
64
+ return build_obs(ego_x=ego_x, ego_y=ego_y, ego_z=0.0,
65
+ ego_vx=ego_spd, ego_vy=0.0,
66
+ heading=0.0, speed=ego_spd,
67
+ steer=0.0, throttle=0.5, brake=0.0,
68
+ ticket_vectors=tv)
69
+
70
+
71
+ def action_to_decision(a: np.ndarray) -> str:
72
+ s, t, b = float(a[0]), float(a[1]), float(a[2])
73
+ if abs(s) > 0.35: return "lane_change_left" if s < 0 else "lane_change_right"
74
+ if b > 0.25: return "brake"
75
+ if t > 0.20: return "accelerate"
76
+ return "maintain"
77
+
78
+
79
+ # ── Global training state ─────────────────────────────────────────────────────
80
+
81
+ policy = FlatMLPPolicy(obs_dim=OBS_DIM)
82
+ optimizer = optim.Adam(policy.parameters(), lr=3e-4, eps=1e-5)
83
+
84
+ _buf_obs = []
85
+ _buf_acts = []
86
+ _buf_rews = []
87
+ _buf_logps = []
88
+ _buf_vals = []
89
+ _buf_dones = []
90
+
91
+ episode_history = []
92
+ step_log = []
93
+ _running = False
94
+ _lock = threading.Lock()
95
+
96
+
97
+ def _ppo_mini_update():
98
+ if len(_buf_obs) < 2:
99
+ return
100
+ obs_t = torch.tensor(np.array(_buf_obs), dtype=torch.float32)
101
+ acts_t = torch.tensor(np.array(_buf_acts), dtype=torch.float32)
102
+ rews_t = torch.tensor(_buf_rews, dtype=torch.float32)
103
+ logp_t = torch.tensor(_buf_logps, dtype=torch.float32)
104
+ vals_t = torch.tensor(_buf_vals, dtype=torch.float32)
105
+ done_t = torch.tensor(_buf_dones, dtype=torch.float32)
106
+
107
+ gamma, lam = 0.99, 0.95
108
+ adv = torch.zeros_like(rews_t)
109
+ gae = 0.0
110
+ for t in reversed(range(len(rews_t))):
111
+ nv = 0.0 if t == len(rews_t) - 1 else float(vals_t[t + 1])
112
+ d = rews_t[t] + gamma * nv * (1 - done_t[t]) - vals_t[t]
113
+ gae = d + gamma * lam * (1 - done_t[t]) * gae
114
+ adv[t] = gae
115
+ ret = adv + vals_t
116
+ adv = (adv - adv.mean()) / (adv.std() + 1e-8)
117
+
118
+ policy.train()
119
+ act_mean, val = policy(obs_t)
120
+ val = val.squeeze(-1)
121
+ dist = torch.distributions.Normal(act_mean, torch.ones_like(act_mean) * 0.3)
122
+ logp = dist.log_prob(acts_t).sum(dim=-1)
123
+ entropy = dist.entropy().sum(dim=-1).mean()
124
+ ratio = torch.exp(logp - logp_t)
125
+ pg = torch.max(-adv * ratio, -adv * ratio.clamp(0.8, 1.2)).mean()
126
+ vf = 0.5 * ((val - ret) ** 2).mean()
127
+ loss = pg + 0.5 * vf - 0.02 * entropy
128
+ optimizer.zero_grad()
129
+ loss.backward()
130
+ torch.nn.utils.clip_grad_norm_(policy.parameters(), 0.5)
131
+ optimizer.step()
132
+
133
+
134
+ def run_episodes_loop():
135
+ global _running
136
+ ep_num = 0
137
+ env = OverflowEnvironment()
138
+
139
+ while _running:
140
+ ep_num += 1
141
+ obs = env.reset()
142
+ ep_rew = 0.0
143
+ outcome = "timeout"
144
+
145
+ _buf_obs.clear(); _buf_acts.clear(); _buf_rews.clear()
146
+ _buf_logps.clear(); _buf_vals.clear(); _buf_dones.clear()
147
+
148
+ for step in range(1, STEPS_PER_EPISODE + 1):
149
+ if not _running:
150
+ break
151
+
152
+ obs_vec = obs_to_vec(obs)
153
+ policy.eval()
154
+ with torch.no_grad():
155
+ obs_t = torch.tensor(obs_vec, dtype=torch.float32).unsqueeze(0)
156
+ act_mean, val = policy(obs_t)
157
+ dist = torch.distributions.Normal(act_mean.squeeze(0),
158
+ torch.ones(3) * 0.3)
159
+ action = dist.sample().clamp(-1, 1)
160
+ logp = dist.log_prob(action).sum()
161
+
162
+ decision = action_to_decision(action.numpy())
163
+ obs = env.step(OverflowAction(decision=decision, reasoning=""))
164
+ reward = float(obs.reward or 0.0)
165
+ done = obs.done
166
+ ep_rew += reward
167
+
168
+ _buf_obs.append(obs_vec)
169
+ _buf_acts.append(action.numpy())
170
+ _buf_rews.append(reward)
171
+ _buf_logps.append(float(logp))
172
+ _buf_vals.append(float(val.squeeze()))
173
+ _buf_dones.append(float(done))
174
+
175
+ with _lock:
176
+ step_log.append({
177
+ "ep": ep_num,
178
+ "step": step,
179
+ "decision": decision,
180
+ "reward": round(reward, 2),
181
+ "ep_reward": round(ep_rew, 2),
182
+ "incident": obs.incident_report or "",
183
+ "cars": [(c.carId, c.lane, c.position.x, c.speed)
184
+ for c in obs.cars],
185
+ })
186
+
187
+ if done:
188
+ outcome = "CRASH" if "CRASH" in (obs.incident_report or "") else "GOAL"
189
+ break
190
+
191
+ time.sleep(0.6)
192
+
193
+ _ppo_mini_update()
194
+
195
+ with _lock:
196
+ episode_history.append({
197
+ "ep": ep_num,
198
+ "steps": step,
199
+ "reward": round(ep_rew, 2),
200
+ "outcome": outcome,
201
+ })
202
+
203
+
204
+ # ── Plot helpers ──────────────────────────────────────────────────────────────
205
+
206
+ DECISION_COLORS = {
207
+ "accelerate": "#22c55e",
208
+ "brake": "#ef4444",
209
+ "lane_change_left": "#f59e0b",
210
+ "lane_change_right": "#f59e0b",
211
+ "maintain": "#60a5fa",
212
+ }
213
+
214
+
215
+ def render_road(cars_snapshot, last_decision, last_incident):
216
+ fig, ax = plt.subplots(figsize=(10, 2.8))
217
+ fig.patch.set_facecolor("#0f172a")
218
+ ax.set_facecolor("#1e293b")
219
+
220
+ ax.set_xlim(0, ROAD_LENGTH)
221
+ ax.set_ylim(0, NUM_LANES + 1)
222
+ ax.set_yticks([])
223
+ ax.set_xlabel("Position", color="#94a3b8", fontsize=9)
224
+ ax.tick_params(colors="#94a3b8")
225
+ for spine in ax.spines.values():
226
+ spine.set_edgecolor("#334155")
227
+
228
+ for lane in range(1, NUM_LANES):
229
+ ax.axhline(y=lane + 0.5, color="#334155", linewidth=1, linestyle="--", alpha=0.6)
230
+
231
+ for lane in range(1, NUM_LANES + 1):
232
+ ax.text(2, lane, f"L{lane}", color="#475569", fontsize=8, va="center")
233
+
234
+ ax.axvspan(160, ROAD_LENGTH, alpha=0.12, color="#22c55e")
235
+ ax.text(162, NUM_LANES + 0.6, "GOAL ZONE", color="#22c55e", fontsize=7, alpha=0.8)
236
+
237
+ car_w, car_h = 8, 0.55
238
+ for car_id, lane, pos_x, speed in cars_snapshot:
239
+ is_ego = car_id == 0
240
+ color = "#3b82f6" if is_ego else "#94a3b8"
241
+ outline = "#60a5fa" if is_ego else "#475569"
242
+ lw = 2.0 if is_ego else 1.0
243
+ rect = patches.FancyBboxPatch(
244
+ (pos_x - car_w / 2, lane - car_h / 2),
245
+ car_w, car_h,
246
+ boxstyle="round,pad=0.05",
247
+ facecolor=color, edgecolor=outline, linewidth=lw, alpha=0.92,
248
+ )
249
+ ax.add_patch(rect)
250
+ label = f"{'EGO' if is_ego else f'C{car_id}'}\n{speed:.0f}"
251
+ ax.text(pos_x, lane, label, ha="center", va="center",
252
+ fontsize=6.5, color="white", fontweight="bold" if is_ego else "normal")
253
+
254
+ dec_color = DECISION_COLORS.get(last_decision, "#60a5fa")
255
+ ax.text(ROAD_LENGTH - 2, NUM_LANES + 0.65,
256
+ f"Action: {last_decision.replace('_', ' ').upper()}",
257
+ color=dec_color, fontsize=8, fontweight="bold", ha="right")
258
+
259
+ if "CRASH" in last_incident:
260
+ ax.text(ROAD_LENGTH / 2, NUM_LANES + 0.65, "CRASH",
261
+ color="#ef4444", fontsize=10, fontweight="bold", ha="center")
262
+ elif "NEAR MISS" in last_incident:
263
+ ax.text(ROAD_LENGTH / 2, NUM_LANES + 0.65, "NEAR MISS",
264
+ color="#f59e0b", fontsize=9, fontweight="bold", ha="center")
265
+ elif "GOAL" in last_incident:
266
+ ax.text(ROAD_LENGTH / 2, NUM_LANES + 0.65, "GOAL REACHED",
267
+ color="#22c55e", fontsize=10, fontweight="bold", ha="center")
268
+
269
+ plt.tight_layout(pad=0.3)
270
+ return fig
271
+
272
+
273
+ def render_reward_curve(eps):
274
+ fig, ax = plt.subplots(figsize=(10, 2.8))
275
+ fig.patch.set_facecolor("#0f172a")
276
+ ax.set_facecolor("#1e293b")
277
+ for spine in ax.spines.values():
278
+ spine.set_edgecolor("#334155")
279
+ ax.tick_params(colors="#94a3b8")
280
+ ax.set_xlabel("Episode", color="#94a3b8", fontsize=9)
281
+ ax.set_ylabel("Total Reward", color="#94a3b8", fontsize=9)
282
+
283
+ if not eps:
284
+ ax.text(0.5, 0.5, "Waiting for episodes...", transform=ax.transAxes,
285
+ ha="center", va="center", color="#475569", fontsize=11)
286
+ plt.tight_layout(pad=0.3)
287
+ return fig
288
+
289
+ xs = [e["ep"] for e in eps]
290
+ ys = [e["reward"] for e in eps]
291
+ outcome_colors = {"CRASH": "#ef4444", "GOAL": "#22c55e", "timeout": "#60a5fa"}
292
+ for x, y, e in zip(xs, ys, eps):
293
+ ax.bar(x, y, color=outcome_colors.get(e["outcome"], "#60a5fa"), alpha=0.6, width=0.7)
294
+
295
+ if len(ys) >= 3:
296
+ w = min(5, len(ys))
297
+ smoothed = np.convolve(ys, np.ones(w) / w, mode="valid")
298
+ ax.plot(xs[w - 1:], smoothed, color="#f8fafc", linewidth=2)
299
+
300
+ ax.axhline(0, color="#334155", linewidth=0.8)
301
+
302
+ from matplotlib.patches import Patch
303
+ legend_els = [Patch(facecolor="#ef4444", label="crash"),
304
+ Patch(facecolor="#22c55e", label="goal"),
305
+ Patch(facecolor="#60a5fa", label="timeout")]
306
+ ax.legend(handles=legend_els, facecolor="#1e293b", labelcolor="#94a3b8",
307
+ fontsize=8, framealpha=0.6, edgecolor="#334155", loc="upper left")
308
+
309
+ plt.tight_layout(pad=0.3)
310
+ return fig
311
+
312
+
313
+ # ── Gradio UI ─────────────────────────────────────────────────────────────────
314
+
315
+ def start_training():
316
+ global _running
317
+ if not _running:
318
+ _running = True
319
+ step_log.clear()
320
+ episode_history.clear()
321
+ threading.Thread(target=run_episodes_loop, daemon=True).start()
322
+ return gr.update(value="Running...", interactive=False), gr.update(interactive=True)
323
+
324
+
325
+ def stop_training():
326
+ global _running
327
+ _running = False
328
+ return gr.update(value="Start", interactive=True), gr.update(interactive=False)
329
+
330
+
331
+ def get_updates():
332
+ with _lock:
333
+ logs = list(step_log[-20:])
334
+ eps = list(episode_history[-50:])
335
+ last = step_log[-1] if step_log else None
336
+
337
+ road_fig = render_road(last["cars"], last["decision"], last["incident"]) if last \
338
+ else render_road([], "maintain", "")
339
+ reward_fig = render_reward_curve(eps)
340
+
341
+ lines = []
342
+ for e in reversed(logs):
343
+ flag = ""
344
+ if "CRASH" in e["incident"]: flag = " πŸ’₯"
345
+ elif "GOAL" in e["incident"]: flag = " βœ“"
346
+ elif "NEAR MISS" in e["incident"]: flag = " ⚠"
347
+ lines.append(
348
+ f"ep {e['ep']:>3d} | step {e['step']:>2d} | "
349
+ f"{e['decision']:<20} | r={e['reward']:>+6.2f} | "
350
+ f"ep_total={e['ep_reward']:>7.2f}{flag}"
351
+ )
352
+ step_text = "\n".join(lines) if lines else "Waiting for first episode..."
353
+
354
+ ep_lines = ["Episode | Steps | Total Reward | Outcome", "-" * 44]
355
+ for e in reversed(eps[-15:]):
356
+ ep_lines.append(
357
+ f" {e['ep']:>4d} | {e['steps']:>3d} | "
358
+ f" {e['reward']:>+8.2f} | {e['outcome']}"
359
+ )
360
+ ep_text = "\n".join(ep_lines) if eps else "No episodes completed yet."
361
+
362
+ if len(eps) >= 2:
363
+ rewards = [e["reward"] for e in eps]
364
+ n = len(rewards)
365
+ half = max(n // 2, 1)
366
+ early = sum(rewards[:half]) / half
367
+ late = sum(rewards[half:]) / max(n - half, 1)
368
+ arrow = "↑ improving" if late > early else "↓ declining"
369
+ trend_text = f"Early {half} eps: {early:+.2f} β†’ Last {n-half} eps: {late:+.2f} {arrow}"
370
+ else:
371
+ trend_text = "Collecting data..."
372
+
373
+ status = "● RUNNING" if _running else "β–  STOPPED"
374
+ return road_fig, reward_fig, step_text, ep_text, trend_text, status
375
+
376
+
377
+ _EMPTY_ROAD = render_road([], "maintain", "")
378
+ _EMPTY_REWARD = render_reward_curve([])
379
+
380
+ with gr.Blocks(title="OpenENV RL Demo", theme=gr.themes.Base()) as demo:
381
+ gr.Markdown(
382
+ "# OpenENV RL β€” Live Policy Training\n"
383
+ "**FlatMLPPolicy** drives Car 0 on a 3-lane road for 20 steps per episode. "
384
+ "PPO mini-update after each episode β€” watch rewards trend upward over time."
385
+ )
386
+
387
+ with gr.Row():
388
+ start_btn = gr.Button("Start", variant="primary", scale=1)
389
+ stop_btn = gr.Button("Stop", variant="stop", interactive=False, scale=1)
390
+ status_box = gr.Textbox(value="β–  STOPPED", label="Status",
391
+ interactive=False, scale=0, min_width=130)
392
+
393
+ gr.Markdown("### Road View")
394
+ road_plot = gr.Plot(value=_EMPTY_ROAD, show_label=False)
395
+
396
+ gr.Markdown("### Episode Reward Curve")
397
+ reward_plot = gr.Plot(value=_EMPTY_REWARD, show_label=False)
398
+
399
+ gr.Markdown("### Live Step Feed (last 20 steps)")
400
+ step_display = gr.Textbox(
401
+ value="Press Start to begin...",
402
+ lines=14, max_lines=14, interactive=False,
403
+ )
404
+
405
+ with gr.Row():
406
+ with gr.Column():
407
+ gr.Markdown("### Episode History")
408
+ ep_display = gr.Textbox(lines=10, interactive=False)
409
+ with gr.Column():
410
+ gr.Markdown("### Reward Trend")
411
+ trend_display = gr.Textbox(lines=3, interactive=False)
412
+
413
+ timer = gr.Timer(value=1.0)
414
+ timer.tick(
415
+ fn=get_updates,
416
+ outputs=[road_plot, reward_plot, step_display, ep_display, trend_display, status_box],
417
+ )
418
+
419
+ start_btn.click(fn=start_training, outputs=[start_btn, stop_btn])
420
+ stop_btn.click(fn=stop_training, outputs=[start_btn, stop_btn])
421
+
422
+
423
+ if __name__ == "__main__":
424
+ demo.launch()
client.py ADDED
@@ -0,0 +1,92 @@
1
+ """
2
+ Overflow Environment Client.
3
+
4
+ Provides the client for connecting to an Overflow Environment server
5
+ via WebSocket for persistent sessions.
6
+ """
7
+
8
+ from typing import Any, Dict, List
9
+
10
+ from openenv.core.client_types import StepResult
11
+ from openenv.core.env_client import EnvClient
12
+
13
+ from .models import (
14
+ CarStateData,
15
+ LaneOccupancyData,
16
+ OverflowAction,
17
+ OverflowObservation,
18
+ OverflowState,
19
+ Position,
20
+ ProximityData,
21
+ )
22
+
23
+
24
+ class OverflowEnv(EnvClient[OverflowAction, OverflowObservation, OverflowState]):
25
+ """
26
+ WebSocket client for the Overflow Environment.
27
+
28
+ Example:
29
+ >>> with OverflowEnv(base_url="http://localhost:8000") as env:
30
+ ... result = env.reset()
31
+ ... print(result.observation.scene_description)
32
+ ... print(result.observation.cars) # structured car data
33
+ ... action = OverflowAction(decision="maintain", reasoning="Safe for now")
34
+ ... result = env.step(action)
35
+ """
36
+
37
+ def _step_payload(self, action: OverflowAction) -> Dict[str, Any]:
38
+ """Convert OverflowAction to JSON payload for step request."""
39
+ return {
40
+ "decision": action.decision,
41
+ "reasoning": action.reasoning,
42
+ }
43
+
44
+ def _parse_result(self, payload: Dict[str, Any]) -> StepResult[OverflowObservation]:
45
+ """Parse server response into StepResult[OverflowObservation]."""
46
+ obs_data = payload.get("observation", {})
47
+
48
+ # Parse structured car data
49
+ cars = [
50
+ CarStateData(
51
+ carId=c["carId"],
52
+ lane=c["lane"],
53
+ position=Position(**c["position"]),
54
+ speed=c["speed"],
55
+ acceleration=c.get("acceleration", 0.0),
56
+ )
57
+ for c in obs_data.get("cars", [])
58
+ ]
59
+
60
+ proximities = [
61
+ ProximityData(**p) for p in obs_data.get("proximities", [])
62
+ ]
63
+
64
+ lane_occupancies = [
65
+ LaneOccupancyData(**lo) for lo in obs_data.get("lane_occupancies", [])
66
+ ]
67
+
68
+ observation = OverflowObservation(
69
+ scene_description=obs_data.get("scene_description", ""),
70
+ incident_report=obs_data.get("incident_report", ""),
71
+ done=payload.get("done", False),
72
+ reward=payload.get("reward"),
73
+ cars=cars,
74
+ proximities=proximities,
75
+ lane_occupancies=lane_occupancies,
76
+ )
77
+ return StepResult(
78
+ observation=observation,
79
+ reward=payload.get("reward"),
80
+ done=payload.get("done", False),
81
+ )
82
+
83
+ def _parse_state(self, payload: Dict[str, Any]) -> OverflowState:
84
+ """Parse server response into OverflowState."""
85
+ return OverflowState(
86
+ episode_id=payload.get("episode_id"),
87
+ step_count=payload.get("step_count", 0),
88
+ crash_count=payload.get("crash_count", 0),
89
+ near_miss_count=payload.get("near_miss_count", 0),
90
+ cars_reached_goal=payload.get("cars_reached_goal", 0),
91
+ total_cars=payload.get("total_cars", 5),
92
+ )
models.py ADDED
@@ -0,0 +1,134 @@
1
+ """
2
+ Data models for the Overflow Environment.
3
+
4
+ An autonomous vehicle fleet oversight environment where an LLM agent
5
+ controls one car on a 2D road grid while other cars follow scripted rules.
6
+
7
+ Structured observation fields (cars, proximities, lane_occupancies) are
8
+ compatible with the Overflow frontend's CarState / AnomalyObservation types.
9
+ """
10
+
11
+ from typing import Any, Dict, List, Optional
12
+
13
+ from pydantic import BaseModel, Field
14
+
15
+ try:
16
+ from openenv.core.env_server.types import Action, Observation, State
17
+ except ImportError:
18
+ class Action(BaseModel): pass
19
+ class Observation(BaseModel):
20
+ done: bool = False
21
+ reward: float = 0.0
22
+ class State(BaseModel):
23
+ episode_id: str = ""
24
+ step_count: int = 0
25
+
26
+ # ── Structured sub-models (frontend-compatible) ─────────────────────────
27
+
28
+
29
+ class Position(BaseModel):
30
+ """2D position on the road. x = longitudinal, y = lateral."""
31
+
32
+ x: float = 0.0
33
+ y: float = 0.0
34
+
35
+
36
+ class CarStateData(BaseModel):
37
+ """
38
+ Structured per-car snapshot β€” matches the frontend CarState interface.
39
+
40
+ Frontend type:
41
+ interface CarState {
42
+ carId: number; lane: number;
43
+ position: { x: number; y: number };
44
+ speed: number; acceleration: number;
45
+ }
46
+ """
47
+
48
+ carId: int
49
+ lane: int
50
+ position: Position
51
+ speed: float
52
+ acceleration: float = 0.0
53
+
54
+
55
+ class ProximityData(BaseModel):
56
+ """Pairwise distance between two cars."""
57
+
58
+ carA: int
59
+ carB: int
60
+ distance: float
61
+
62
+
63
+ class LaneOccupancyData(BaseModel):
64
+ """Which cars are in a given lane."""
65
+
66
+ lane: int
67
+ carIds: List[int]
68
+
69
+
70
+ # ── OpenEnv core models ─────────────────────────────────────────────────
71
+
72
+
73
+ class OverflowAction(Action):
74
+ """
75
+ Action for the Overflow environment.
76
+
77
+ The LLM agent outputs a driving decision and optional reasoning.
78
+ """
79
+
80
+ decision: str = Field(
81
+ default="maintain",
82
+ description="Driving decision: accelerate, brake, lane_change_left, lane_change_right, maintain",
83
+ )
84
+ reasoning: str = Field(
85
+ default="",
86
+ description="The LLM's chain-of-thought reasoning for this decision",
87
+ )
88
+
89
+
90
+ class OverflowObservation(Observation):
91
+ """
92
+ Observation from the Overflow environment.
93
+
94
+ Contains both:
95
+ - Text fields (scene_description, incident_report) for the LLM to read.
96
+ - Structured fields (cars, proximities, lane_occupancies) for the frontend
97
+ to render, matching the Overflow frontend AnomalyObservation shape.
98
+ """
99
+
100
+ # ── Text (for the LLM) ──
101
+ scene_description: str = Field(
102
+ default="", description="Text description of the traffic scene"
103
+ )
104
+ incident_report: str = Field(
105
+ default="", description="Observer's incident report, empty if no incident"
106
+ )
107
+
108
+ # ── Structured (for the frontend / viz) ──
109
+ cars: List[CarStateData] = Field(
110
+ default_factory=list, description="Structured state of every car"
111
+ )
112
+ proximities: List[ProximityData] = Field(
113
+ default_factory=list, description="Pairwise proximity measurements"
114
+ )
115
+ lane_occupancies: List[LaneOccupancyData] = Field(
116
+ default_factory=list, description="Per-lane vehicle occupancy"
117
+ )
118
+
119
+
120
+ class OverflowState(State):
121
+ """
122
+ Internal state for the Overflow environment.
123
+ """
124
+
125
+ crash_count: int = Field(default=0, description="Number of crashes this episode")
126
+ near_miss_count: int = Field(
127
+ default=0, description="Number of near misses this episode"
128
+ )
129
+ cars_reached_goal: int = Field(
130
+ default=0, description="Number of cars that reached their goal"
131
+ )
132
+ total_cars: int = Field(
133
+ default=5, description="Total number of cars in the simulation"
134
+ )
openenv.yaml ADDED
@@ -0,0 +1,6 @@
1
+ spec_version: 1
2
+ name: overflow_env
3
+ type: space
4
+ runtime: fastapi
5
+ app: server.app:app
6
+ port: 8000
policies/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .base_policy import BasePolicy
2
+ from .flat_mlp_policy import FlatMLPPolicy
3
+ from .ticket_attention_policy import TicketAttentionPolicy
4
+
5
+ __all__ = ["BasePolicy", "FlatMLPPolicy", "TicketAttentionPolicy"]
policies/__pycache__/__init__.cpython-314.pyc ADDED
Binary file (381 Bytes). View file
 
policies/__pycache__/base_policy.cpython-314.pyc ADDED
Binary file (4.05 kB). View file
 
policies/__pycache__/flat_mlp_policy.cpython-314.pyc ADDED
Binary file (3.71 kB). View file
 
policies/__pycache__/policy_spec.cpython-314.pyc ADDED
Binary file (18.4 kB). View file
 
policies/__pycache__/ticket_attention_policy.cpython-314.pyc ADDED
Binary file (11.3 kB). View file
 
policies/base_policy.py ADDED
@@ -0,0 +1,66 @@
1
+ """
2
+ BasePolicy β€” abstract interface all policies implement.
3
+
4
+ All policies expose the same predict() and train_step() API so the
5
+ curriculum trainer can swap them out transparently.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import abc
11
+ from typing import Any, Dict, Optional, Tuple
12
+
13
+ import numpy as np
14
+ import torch
15
+ import torch.nn as nn
16
+
17
+
18
+ class BasePolicy(nn.Module, abc.ABC):
19
+ """
20
+ Abstract base for all driving policies.
21
+
22
+ Subclasses implement:
23
+ forward(obs_tensor) β†’ action_tensor, value_tensor
24
+ encode_obs(obs_np) β†’ torch.Tensor
25
+ """
26
+
27
+ def __init__(self, obs_dim: int, action_dim: int = 3):
28
+ super().__init__()
29
+ self.obs_dim = obs_dim
30
+ self.action_dim = action_dim
31
+
32
+ @abc.abstractmethod
33
+ def forward(
34
+ self, obs: torch.Tensor
35
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
36
+ """
37
+ Returns:
38
+ action_mean β€” shape (B, action_dim)
39
+ value β€” shape (B, 1)
40
+ """
41
+ ...
42
+
43
+ def predict(
44
+ self,
45
+ obs: np.ndarray,
46
+ deterministic: bool = False,
47
+ ) -> np.ndarray:
48
+ """Numpy in, numpy out. Used by the env during rollout."""
49
+ self.eval()
50
+ with torch.no_grad():
51
+ t = torch.as_tensor(obs, dtype=torch.float32).unsqueeze(0)
52
+ mean, _ = self.forward(t)
53
+ if deterministic:
54
+ action = mean
55
+ else:
56
+ action = mean + torch.randn_like(mean) * 0.1
57
+ return action.squeeze(0).numpy()
58
+
59
+ @staticmethod
60
+ def _mlp(dims: list[int], activation=nn.Tanh) -> nn.Sequential:
61
+ layers = []
62
+ for i in range(len(dims) - 1):
63
+ layers.append(nn.Linear(dims[i], dims[i + 1]))
64
+ if i < len(dims) - 2:
65
+ layers.append(activation())
66
+ return nn.Sequential(*layers)
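The `_mlp` helper interleaves an activation after every Linear layer except the last, so the network's final output is unsquashed. A framework-free sketch of that loop (the `mlp_layout` name is illustrative, not part of the repo) makes the resulting layer pattern explicit:

```python
def mlp_layout(dims, activation="Tanh"):
    # Mirrors BasePolicy._mlp: one Linear between consecutive dims,
    # with an activation after every Linear except the final one.
    layers = []
    for i in range(len(dims) - 1):
        layers.append(("Linear", dims[i], dims[i + 1]))
        if i < len(dims) - 2:
            layers.append((activation,))
    return layers

# An 11 -> 64 -> 3 network: Linear, Tanh, Linear (no trailing activation)
layout = mlp_layout([11, 64, 3])
```

Leaving the last layer bare lets callers decide whether to squash outputs (e.g. Tanh for actions) or keep them unbounded (e.g. a critic's value estimate).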
policies/flat_mlp_policy.py ADDED
@@ -0,0 +1,50 @@
+ """
+ FlatMLPPolicy — sanity-check baseline.
+
+ Concatenates the full observation (ego + all tickets flattened) and passes
+ it through a standard MLP. No attention, no structure.
+
+ Use this to:
+     1. Verify the reward signal and environment are working
+     2. Establish a performance floor
+     3. Confirm that TicketAttentionPolicy actually improves over this
+
+ If FlatMLPPolicy can't learn Stage 1 survival, the reward or env is broken.
+ """
+
+ from __future__ import annotations
+
+ import torch
+ import torch.nn as nn
+
+ from .base_policy import BasePolicy
+
+
+ class FlatMLPPolicy(BasePolicy):
+     """MLP with three hidden layers over the full flat observation."""
+
+     def __init__(self, obs_dim: int, hidden: int = 256):
+         super().__init__(obs_dim)
+
+         self.actor = nn.Sequential(
+             nn.Linear(obs_dim, hidden), nn.LayerNorm(hidden), nn.Tanh(),
+             nn.Linear(hidden, hidden), nn.Tanh(),
+             nn.Linear(hidden, hidden // 2), nn.Tanh(),
+             nn.Linear(hidden // 2, 3), nn.Tanh(),
+         )
+         self.critic = nn.Sequential(
+             nn.Linear(obs_dim, hidden), nn.Tanh(),
+             nn.Linear(hidden, hidden // 2), nn.Tanh(),
+             nn.Linear(hidden // 2, 1),
+         )
+         self._init_weights()
+
+     def _init_weights(self):
+         for m in self.modules():
+             if isinstance(m, nn.Linear):
+                 nn.init.orthogonal_(m.weight, gain=1.0)
+                 nn.init.zeros_(m.bias)
+         nn.init.orthogonal_(self.actor[-2].weight, gain=0.01)
+
+     def forward(self, obs: torch.Tensor):
+         return self.actor(obs), self.critic(obs)
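Because FlatMLPPolicy consumes the flat observation directly, its first layer dominates the weight budget. A back-of-envelope parameter count for the actor, a sketch assuming the default hidden=256 and the 603-dim observation from policy_spec:

```python
def linear_params(n_in, n_out):
    return n_in * n_out + n_out  # weight matrix + bias vector

obs_dim, hidden = 603, 256  # OBS_DIM from policy_spec, default hidden

actor_params = (
    linear_params(obs_dim, hidden)        # 603 -> 256
    + 2 * hidden                          # LayerNorm(256): gamma + beta
    + linear_params(hidden, hidden)       # 256 -> 256
    + linear_params(hidden, hidden // 2)  # 256 -> 128
    + linear_params(hidden // 2, 3)       # 128 -> 3
)
first_layer = linear_params(obs_dim, hidden)
```

The 603-to-256 input layer alone accounts for over half the actor's parameters, which is part of why the attention policy's per-ticket encoder is a more parameter-efficient way to consume the ticket matrix.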
policies/policy_spec.py ADDED
@@ -0,0 +1,409 @@
+ """
+ Policy data input specifications — formal contracts for observation, action, and ticket data.
+
+ This module defines the exact data shapes, normalization ranges, and semantic meaning
+ of every field consumed by OpenENV policies. Use this as the reference when:
+
+     1. Building a new environment that targets these policies
+     2. Writing a bridge/adapter from a different simulator
+     3. Implementing a new policy that must interoperate with the existing set
+
+ All policies share the same raw observation layout (EGO + ticket matrix).
+ Specialized policies (ThreatAvoidance, SystemFailure) select subsets internally.
+
+ Example usage:
+     from openenv.policies.policy_spec import ObsSpec, ActionSpec, validate_obs
+
+     spec = ObsSpec()
+     obs = my_env.get_observation()
+     validate_obs(obs, spec)  # raises ValueError on shape/dtype mismatch
+ """
+
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import Dict, List, Optional, Tuple
+
+ import numpy as np
+
+
+ # ── Ego state specification ──────────────────────────────────────────────────
+
+ EGO_STATE_DIM = 11
+
+
+ @dataclass(frozen=True)
+ class EgoField:
+     """Description of a single ego state field."""
+     index: int
+     name: str
+     unit: str
+     raw_range: Tuple[float, float]  # physical range before normalization
+     norm_divisor: float             # obs_value = raw_value / norm_divisor
+     description: str
+
+
+ EGO_FIELDS: List[EgoField] = [
+     EgoField(0, "x", "m", (-5000, 5000), 1000.0, "Forward displacement from episode start"),
+     EgoField(1, "y", "m", (-6.0, 6.0), 3.7, "Lateral displacement (0 = lane center, + = left)"),
+     EgoField(2, "z", "m", (-10, 10), 10.0, "Vertical position (flat road = 0)"),
+     EgoField(3, "vx", "m/s", (-20, 20), 20.0, "Forward velocity in world frame"),
+     EgoField(4, "vy", "m/s", (-20, 20), 20.0, "Lateral velocity in world frame"),
+     EgoField(5, "vz", "m/s", (0, 0), 1.0, "Vertical velocity (always 0 on flat road)"),
+     EgoField(6, "heading_sin", "rad", (-1, 1), 1.0, "sin(heading angle), 0 = forward"),
+     EgoField(7, "heading_cos", "rad", (-1, 1), 1.0, "cos(heading angle), 1 = forward"),
+     EgoField(8, "speed", "m/s", (0, 20), 20.0, "Scalar speed = sqrt(vx^2 + vy^2)"),
+     EgoField(9, "steer", "norm", (-1, 1), 1.0, "Current steering command [-1=full left, 1=full right]"),
+     EgoField(10, "net_drive", "norm", (-1, 1), 1.0, "throttle - brake [-1=full brake, 1=full throttle]"),
+ ]
+
+
+ # ── Ticket vector specification ──────────────────────────────────────────────
+
+ TICKET_VECTOR_DIM = 37  # 18 fixed + 14 type one-hot + 5 entity one-hot
+ MAX_TICKETS = 16
+
+ # Ticket types (14 total) — one-hot encoded starting at index 18
+ TICKET_TYPES = [
+     "collision_risk", "sudden_brake", "side_impact", "head_on",
+     "merge_cut", "rear_end_risk",
+     "pedestrian_crossing", "cyclist_lane",
+     "tire_blowout", "brake_fade", "steering_loss", "sensor_occlusion",
+     "road_hazard", "weather_visibility",
+ ]
+
+ # Entity types (5 total) — one-hot encoded after ticket types
+ ENTITY_TYPES = ["vehicle", "pedestrian", "cyclist", "obstacle", "system"]
+
+ # Verify dimension
+ assert 18 + len(TICKET_TYPES) + len(ENTITY_TYPES) == TICKET_VECTOR_DIM, (
+     f"Ticket vector dim mismatch: 18 + {len(TICKET_TYPES)} + {len(ENTITY_TYPES)} "
+     f"!= {TICKET_VECTOR_DIM}"
+ )
+
+
+ @dataclass(frozen=True)
+ class TicketField:
+     """Description of a single ticket vector field."""
+     offset: int  # index within the TICKET_VECTOR_DIM vector
+     length: int  # number of floats
+     name: str
+     unit: str
+     raw_range: Tuple[float, float]
+     norm_divisor: float
+     description: str
+
+
+ TICKET_FIELDS: List[TicketField] = [
+     TicketField(0, 1, "severity_weight", "norm", (0, 1), 1.0, "Severity: 0.25=LOW, 0.5=MED, 0.75=HIGH, 1.0=CRITICAL"),
+     TicketField(1, 1, "ttl_norm", "s", (0, 10), 10.0, "Time-to-live remaining, clamped to [0,1]"),
+     TicketField(2, 1, "pos_x", "m", (-100, 100), 100.0, "Ego-relative X (forward positive)"),
+     TicketField(3, 1, "pos_y", "m", (-50, 50), 50.0, "Ego-relative Y (left positive)"),
+     TicketField(4, 1, "pos_z", "m", (-10, 10), 10.0, "Ego-relative Z (up positive)"),
+     TicketField(5, 1, "vel_x", "m/s", (-30, 30), 30.0, "Entity velocity X in world frame"),
+     TicketField(6, 1, "vel_y", "m/s", (-30, 30), 30.0, "Entity velocity Y in world frame"),
+     TicketField(7, 1, "vel_z", "m/s", (-10, 10), 10.0, "Entity velocity Z in world frame"),
+     TicketField(8, 1, "heading_sin", "rad", (-1, 1), 1.0, "sin(entity heading relative to ego)"),
+     TicketField(9, 1, "heading_cos", "rad", (-1, 1), 1.0, "cos(entity heading relative to ego)"),
+     TicketField(10, 1, "size_length", "m", (0, 10), 10.0, "Entity bounding box length"),
+     TicketField(11, 1, "size_width", "m", (0, 5), 5.0, "Entity bounding box width"),
+     TicketField(12, 1, "size_height", "m", (0, 4), 4.0, "Entity bounding box height"),
+     TicketField(13, 1, "distance_norm", "m", (0, 100), 100.0, "Euclidean distance to ego, clamped to [0,1]"),
+     TicketField(14, 1, "ttc_norm", "s", (0, 30), 30.0, "Time-to-collision, clamped to [0,1]. 1.0 = no collision"),
+     TicketField(15, 1, "bearing_sin", "rad", (-1, 1), 1.0, "sin(bearing angle from ego forward axis)"),
+     TicketField(16, 1, "bearing_cos", "rad", (-1, 1), 1.0, "cos(bearing angle from ego forward axis)"),
+     TicketField(17, 1, "confidence", "norm", (0, 1), 1.0, "Perception confidence [0=unreliable, 1=certain]"),
+     TicketField(18, len(TICKET_TYPES), "type_onehot", "bool", (0, 1), 1.0, "One-hot ticket type"),
+     TicketField(18 + len(TICKET_TYPES), len(ENTITY_TYPES), "entity_onehot", "bool", (0, 1), 1.0, "One-hot entity type"),
+ ]
+
+
+ # ── Full observation specification ───────────────────────────────────────────
+
+ OBS_DIM = EGO_STATE_DIM + MAX_TICKETS * TICKET_VECTOR_DIM  # 11 + 16*37 = 603
+
+
+ @dataclass(frozen=True)
+ class ObsSpec:
+     """Complete observation space specification."""
+     ego_dim: int = EGO_STATE_DIM
+     ticket_dim: int = TICKET_VECTOR_DIM
+     max_tickets: int = MAX_TICKETS
+     total_dim: int = OBS_DIM
+     dtype: str = "float32"
+     value_range: Tuple[float, float] = (-1.0, 1.0)
+
+     # Layout: obs[0:ego_dim] = ego state
+     #         obs[ego_dim:] reshaped to (max_tickets, ticket_dim)
+     # Tickets are sorted by severity desc, distance asc. Zero-padded rows = empty slots.
+
+
+ # ── Action specification ─────────────────────────────────────────────────────
+
+ @dataclass(frozen=True)
+ class ActionField:
+     index: int
+     name: str
+     raw_range: Tuple[float, float]
+     description: str
+
+
+ ACTION_DIM = 3
+
+ ACTION_FIELDS: List[ActionField] = [
+     ActionField(0, "steer", (-1.0, 1.0), "Steering command. -1=full left, +1=full right. Scaled by MAX_STEER=0.6 rad"),
+     ActionField(1, "throttle", (-1.0, 1.0), "Throttle command. Only positive values used (clipped to [0,1]). Scaled by MAX_ACCEL=4.0 m/s^2"),
+     ActionField(2, "brake", (-1.0, 1.0), "Brake command. Only positive values used (clipped to [0,1]). Scaled by MAX_BRAKE=8.0 m/s^2"),
+ ]
+
+
+ @dataclass(frozen=True)
+ class ActionSpec:
+     """Action space specification."""
+     dim: int = ACTION_DIM
+     dtype: str = "float32"
+     value_range: Tuple[float, float] = (-1.0, 1.0)
+
+
+ # ── Policy input requirements ────────────────────────────────────────────────
+
+ @dataclass(frozen=True)
+ class PolicyInputSpec:
+     """Describes what a specific policy reads from the observation."""
+     name: str
+     reads_ego: bool
+     ego_indices: Tuple[int, ...]   # which ego fields are used
+     reads_tickets: bool
+     ticket_filter: Optional[str]   # None = all, or "kinematic" / "failure"
+     max_tickets_used: int          # how many ticket slots the policy actually reads
+     requires_history: bool         # whether GRU/recurrent hidden state is needed
+     description: str
+
+
+ POLICY_SPECS: Dict[str, PolicyInputSpec] = {
+     "SurvivalPolicy": PolicyInputSpec(
+         name="SurvivalPolicy",
+         reads_ego=True,
+         ego_indices=tuple(range(EGO_STATE_DIM)),
+         reads_tickets=False,
+         ticket_filter=None,
+         max_tickets_used=0,
+         requires_history=False,
+         description="Stage 1 baseline. Reads only ego state (first 11 dims). "
+                     "Ticket portion of obs is ignored entirely.",
+     ),
+     "FlatMLPPolicy": PolicyInputSpec(
+         name="FlatMLPPolicy",
+         reads_ego=True,
+         ego_indices=tuple(range(EGO_STATE_DIM)),
+         reads_tickets=True,
+         ticket_filter=None,
+         max_tickets_used=MAX_TICKETS,
+         requires_history=False,
+         description="Sanity-check baseline. Reads full flat observation (ego + all tickets "
+                     "concatenated). No attention or structure.",
+     ),
+     "TicketAttentionPolicy": PolicyInputSpec(
+         name="TicketAttentionPolicy",
+         reads_ego=True,
+         ego_indices=tuple(range(EGO_STATE_DIM)),
+         reads_tickets=True,
+         ticket_filter=None,
+         max_tickets_used=MAX_TICKETS,
+         requires_history=False,
+         description="Main policy (Stage 2+). Cross-attention: ego queries ticket set. "
+                     "Order-invariant over tickets. Padding mask on zero-rows.",
+     ),
+     "ThreatAvoidancePolicy": PolicyInputSpec(
+         name="ThreatAvoidancePolicy",
+         reads_ego=True,
+         ego_indices=tuple(range(EGO_STATE_DIM)),
+         reads_tickets=True,
+         ticket_filter="kinematic",
+         max_tickets_used=1,
+         requires_history=False,
+         description="Specialist for kinematic threats (collision_risk, sudden_brake, "
+                     "side_impact, head_on, merge_cut, rear_end_risk). Extracts the "
+                     "highest-severity kinematic ticket and gates between brake/evade branches.",
+     ),
+     "SystemFailurePolicy": PolicyInputSpec(
+         name="SystemFailurePolicy",
+         reads_ego=True,
+         ego_indices=tuple(range(EGO_STATE_DIM)),
+         reads_tickets=True,
+         ticket_filter="failure",
+         max_tickets_used=1,
+         requires_history=False,
+         description="Specialist for onboard failures (tire_blowout, brake_fade, steering_loss). "
+                     "Mixture-of-experts with one expert per failure type. Initialized with "
+                     "domain-correct response priors.",
+     ),
+     "RecurrentPolicy": PolicyInputSpec(
+         name="RecurrentPolicy",
+         reads_ego=True,
+         ego_indices=tuple(range(EGO_STATE_DIM)),
+         reads_tickets=True,
+         ticket_filter=None,
+         max_tickets_used=MAX_TICKETS,
+         requires_history=True,
+         description="GRU-based policy for partial observability (Stage 4+). Carries hidden "
+                     "state across timesteps. Requires h_prev to be tracked by caller.",
+     ),
+ }
+
+
+ # ── Validation helpers ───────────────────────────────────────────────────────
+
+ def validate_obs(obs: np.ndarray, spec: Optional[ObsSpec] = None) -> None:
+     """
+     Validate an observation array against the spec.
+     Raises ValueError with a descriptive message on any mismatch.
+     """
+     spec = spec or ObsSpec()
+     if obs.ndim != 1:
+         raise ValueError(f"Observation must be 1D, got shape {obs.shape}")
+     if obs.shape[0] != spec.total_dim:
+         raise ValueError(
+             f"Observation dim mismatch: expected {spec.total_dim}, got {obs.shape[0]}. "
+             f"Check ego_dim ({spec.ego_dim}) + max_tickets ({spec.max_tickets}) "
+             f"* ticket_dim ({spec.ticket_dim})"
+         )
+     if obs.dtype != np.float32:
+         raise ValueError(f"Observation dtype must be float32, got {obs.dtype}")
+
+
+ def validate_action(action: np.ndarray) -> None:
+     """Validate an action array."""
+     if action.shape != (ACTION_DIM,):
+         raise ValueError(f"Action shape mismatch: expected ({ACTION_DIM},), got {action.shape}")
+     if np.any(action < -1.0) or np.any(action > 1.0):
+         raise ValueError(f"Action values must be in [-1, 1], got min={action.min()}, max={action.max()}")
+
+
+ def build_obs(
+     ego_x: float, ego_y: float, ego_z: float,
+     ego_vx: float, ego_vy: float,
+     heading: float, speed: float,
+     steer: float, throttle: float, brake: float,
+     ticket_vectors: Optional[np.ndarray] = None,
+     max_tickets: int = MAX_TICKETS,
+ ) -> np.ndarray:
+     """
+     Build a valid observation vector from raw values.
+
+     This is the primary entry point for external environments that want to
+     produce observations compatible with OpenENV policies.
+
+     Parameters
+     ----------
+     ego_x : forward displacement from episode start (metres)
+     ego_y : lateral displacement from lane center (metres, + = left)
+     ego_z : vertical position (metres)
+     ego_vx : forward velocity (m/s)
+     ego_vy : lateral velocity (m/s)
+     heading : heading angle (radians, 0 = forward)
+     speed : scalar speed (m/s)
+     steer : current steering command [-1, 1]
+     throttle : current throttle command [0, 1]
+     brake : current brake command [0, 1]
+     ticket_vectors : (N, TICKET_VECTOR_DIM) array of ticket vectors, or None.
+         Use EventTicket.to_vector() or build_ticket_vector() to create these.
+     max_tickets : number of ticket slots (must match policy expectation, default 16)
+
+     Returns
+     -------
+     obs : np.ndarray of shape (EGO_STATE_DIM + max_tickets * TICKET_VECTOR_DIM,)
+     """
+     import math
+
+     ego = np.array([
+         ego_x / 1000.0,
+         ego_y / 3.7,        # ROAD_HALF_WIDTH
+         ego_z / 10.0,
+         ego_vx / 20.0,      # MAX_SPEED
+         ego_vy / 20.0,
+         0.0,                # vz (flat road)
+         math.sin(heading),
+         math.cos(heading),
+         speed / 20.0,
+         steer,
+         throttle - brake,   # net drive signal
+     ], dtype=np.float32)
+
+     ticket_matrix = np.zeros((max_tickets, TICKET_VECTOR_DIM), dtype=np.float32)
+     if ticket_vectors is not None:
+         n = min(len(ticket_vectors), max_tickets)
+         ticket_matrix[:n] = ticket_vectors[:n]
+
+     return np.concatenate([ego, ticket_matrix.flatten()])
+
+
+ def build_ticket_vector(
+     severity_weight: float,
+     ttl: float,
+     pos_x: float, pos_y: float, pos_z: float,
+     vel_x: float, vel_y: float, vel_z: float,
+     heading: float,
+     size_length: float, size_width: float, size_height: float,
+     distance: float,
+     time_to_collision: Optional[float],
+     bearing: float,
+     ticket_type: str,
+     entity_type: str,
+     confidence: float = 1.0,
+ ) -> np.ndarray:
+     """
+     Build a single ticket vector from raw values without needing the full
+     EventTicket class. Use this when adapting a different simulator.
+
+     Parameters
+     ----------
+     severity_weight : 0.25 (LOW), 0.5 (MEDIUM), 0.75 (HIGH), 1.0 (CRITICAL)
+     ttl : seconds remaining until ticket expires
+     pos_x/y/z : ego-relative position (metres)
+     vel_x/y/z : entity velocity in world frame (m/s)
+     heading : entity heading relative to ego (radians)
+     size_length/width/height : entity bounding box (metres)
+     distance : euclidean distance to ego (metres)
+     time_to_collision : seconds until collision, or None if no collision course
+     bearing : angle from ego forward axis (radians)
+     ticket_type : one of TICKET_TYPES (e.g., "collision_risk")
+     entity_type : one of ENTITY_TYPES (e.g., "vehicle")
+     confidence : perception confidence [0, 1]
+
+     Returns
+     -------
+     vec : np.ndarray of shape (TICKET_VECTOR_DIM,) = (37,)
+     """
+     import math
+
+     ttc_norm = min((time_to_collision if time_to_collision is not None else 30.0) / 30.0, 1.0)
+
+     type_oh = [0.0] * len(TICKET_TYPES)
+     entity_oh = [0.0] * len(ENTITY_TYPES)
+
+     if ticket_type in TICKET_TYPES:
+         type_oh[TICKET_TYPES.index(ticket_type)] = 1.0
+     else:
+         raise ValueError(f"Unknown ticket_type '{ticket_type}'. Must be one of {TICKET_TYPES}")
+
+     if entity_type in ENTITY_TYPES:
+         entity_oh[ENTITY_TYPES.index(entity_type)] = 1.0
+     else:
+         raise ValueError(f"Unknown entity_type '{entity_type}'. Must be one of {ENTITY_TYPES}")
+
+     vec = [
+         severity_weight,
+         min(ttl / 10.0, 1.0),
+         pos_x / 100.0,
+         pos_y / 50.0,
+         pos_z / 10.0,
+         vel_x / 30.0,
+         vel_y / 30.0,
+         vel_z / 10.0,
+         math.sin(heading),
+         math.cos(heading),
+         size_length / 10.0,
+         size_width / 5.0,
+         size_height / 4.0,
+         min(distance / 100.0, 1.0),
+         ttc_norm,
+         math.sin(bearing),
+         math.cos(bearing),
+         confidence,
+         *type_oh,
+         *entity_oh,
+     ]
+     return np.array(vec, dtype=np.float32)
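A minimal end-to-end check of the layout defined above: pack an ego vector plus a ticket matrix into the flat 603-dim observation, then recover the structure the way a policy would. This is a numpy-only sketch with made-up values, not repo code:

```python
import numpy as np

EGO_STATE_DIM, MAX_TICKETS, TICKET_VECTOR_DIM = 11, 16, 37
OBS_DIM = EGO_STATE_DIM + MAX_TICKETS * TICKET_VECTOR_DIM  # 11 + 16*37 = 603

ego = np.zeros(EGO_STATE_DIM, dtype=np.float32)
ego[8] = 12.0 / 20.0  # field 8 "speed", normalized by MAX_SPEED = 20 m/s

tickets = np.zeros((MAX_TICKETS, TICKET_VECTOR_DIM), dtype=np.float32)
tickets[0, 0] = 0.75           # field 0 "severity_weight": one HIGH ticket
tickets[0, 13] = 25.0 / 100.0  # field 13 "distance_norm": 25 m away

obs = np.concatenate([ego, tickets.flatten()])

# A policy recovers the structure by slicing and reshaping:
ego_back = obs[:EGO_STATE_DIM]
tk_back = obs[EGO_STATE_DIM:].reshape(MAX_TICKETS, TICKET_VECTOR_DIM)
is_padding = np.abs(tk_back).sum(axis=-1) == 0  # zero rows = empty slots
```

Note that the zero-row padding convention is what makes the attention policy's padding mask recoverable from the observation alone, with no separate ticket-count channel.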
policies/ticket_attention_policy.py ADDED
@@ -0,0 +1,227 @@
+ """
+ TicketAttentionPolicy — the main policy (Stage 2+).
+
+ Architecture: two-pass "reflective" cross-attention.
+
+     Pass 1: ego queries tickets → raw threat context
+     Pass 2: (ego + raw context) queries tickets again → refined context
+     This forces the policy to "think twice" — first perceive, then plan.
+
+     [ego | refined_context] → steer head  → steer action
+                             → drive head  → throttle, brake
+                             → critic head → value
+
+ Why two-pass:
+     The first pass gathers what threats exist. The second pass re-examines
+     tickets knowing what the overall threat picture looks like. This prevents
+     the impulsive single-shot responses that cause wild oscillation.
+
+ Why separate heads:
+     Steering requires smooth, conservative output (off-road = death).
+     Throttle/brake can be more aggressive. Separate heads + separate
+     noise levels let each dimension learn at its own pace.
+ """
+
+ from __future__ import annotations
+
+ import torch
+ import torch.nn as nn
+
+ from .base_policy import BasePolicy
+
+ # Must stay in sync with policies.policy_spec
+ EGO_STATE_DIM = 11
+ MAX_TICKETS = 16
+ TICKET_VECTOR_DIM = 37
+
+
+ class TicketAttentionPolicy(BasePolicy):
+     """
+     Two-pass reflective attention policy.
+
+     Pass 1: perceive — what threats exist?
+     Pass 2: plan — given what I see, which threats matter most?
+     Output: separate steer head (conservative) + drive head (throttle/brake)
+     """
+
+     def __init__(
+         self,
+         obs_dim: int,
+         ego_embed: int = 64,
+         ticket_embed: int = 64,
+         n_heads: int = 4,
+         hidden: int = 256,
+     ):
+         super().__init__(obs_dim)
+         assert ego_embed % n_heads == 0
+         assert ticket_embed == ego_embed
+
+         self.ego_embed = ego_embed
+         self.max_tickets = MAX_TICKETS
+         self.ticket_dim = TICKET_VECTOR_DIM
+
+         # ── Encoders ──────────────────────────────────────────────────────
+         self.ego_encoder = nn.Sequential(
+             nn.Linear(EGO_STATE_DIM, hidden // 2),
+             nn.LayerNorm(hidden // 2),
+             nn.Tanh(),
+             nn.Linear(hidden // 2, ego_embed),
+             nn.LayerNorm(ego_embed),
+         )
+         self.ticket_encoder = nn.Sequential(
+             nn.Linear(TICKET_VECTOR_DIM, hidden // 2),
+             nn.LayerNorm(hidden // 2),
+             nn.ReLU(),
+             nn.Linear(hidden // 2, ticket_embed),
+             nn.LayerNorm(ticket_embed),
+         )
+
+         # ── Pass 1: perceive (ego queries tickets) ───────────────────────
+         self.attn_pass1 = nn.MultiheadAttention(
+             embed_dim=ego_embed, num_heads=n_heads,
+             dropout=0.0, batch_first=True,
+         )
+         self.norm1 = nn.LayerNorm(ego_embed)
+
+         # ── Reflection gate: fuse ego + pass1 context for second query ───
+         self.reflect_proj = nn.Sequential(
+             nn.Linear(ego_embed * 2, ego_embed),
+             nn.LayerNorm(ego_embed),
+             nn.Tanh(),
+         )
+
+         # ── Pass 2: plan (refined query re-attends to tickets) ───────────
+         self.attn_pass2 = nn.MultiheadAttention(
+             embed_dim=ego_embed, num_heads=n_heads,
+             dropout=0.0, batch_first=True,
+         )
+         self.norm2 = nn.LayerNorm(ego_embed)
+
+         # ── Fused representation ─────────────────────────────────────────
+         fused_dim = ego_embed + ego_embed  # ego + refined context
+
+         # ── Steer head (conservative, smooth output) ─────────────────────
+         self.steer_head = nn.Sequential(
+             nn.Linear(fused_dim, hidden // 2),
+             nn.LayerNorm(hidden // 2),
+             nn.Tanh(),
+             nn.Linear(hidden // 2, hidden // 4),
+             nn.Tanh(),
+             nn.Linear(hidden // 4, 1),
+             nn.Tanh(),
+         )
+
+         # ── Drive head (throttle + brake) ────────────────────────────────
+         self.drive_head = nn.Sequential(
+             nn.Linear(fused_dim, hidden // 2),
+             nn.LayerNorm(hidden // 2),
+             nn.Tanh(),
+             nn.Linear(hidden // 2, hidden // 4),
+             nn.Tanh(),
+             nn.Linear(hidden // 4, 2),
+             nn.Tanh(),
+         )
+
+         # ── Critic head ──────────────────────────────────────────────────
+         self.critic = nn.Sequential(
+             nn.Linear(fused_dim, hidden),
+             nn.LayerNorm(hidden),
+             nn.Tanh(),
+             nn.Linear(hidden, hidden // 2),
+             nn.Tanh(),
+             nn.Linear(hidden // 2, 1),
+         )
+
+         self._init_weights()
+
+     def _init_weights(self):
+         for m in self.modules():
+             if isinstance(m, nn.Linear):
+                 nn.init.orthogonal_(m.weight, gain=1.0)
+                 if m.bias is not None:
+                     nn.init.zeros_(m.bias)
+         # Very small initial actions — start by doing almost nothing
+         nn.init.orthogonal_(self.steer_head[-2].weight, gain=0.01)
+         nn.init.orthogonal_(self.drive_head[-2].weight, gain=0.01)
+         # Critic starts near zero
+         nn.init.orthogonal_(self.critic[-1].weight, gain=0.1)
+
+     def _attend(self, attn_module, norm_module, query, tk_emb, is_padding, all_empty):
+         """Run one attention pass with NaN-safe masking."""
+         B = query.shape[0]
+         q = query if query.dim() == 3 else query.unsqueeze(1)
+
+         if all_empty.all():
+             return torch.zeros(B, self.ego_embed, device=query.device)
+
+         safe_mask = is_padding.clone()
+         safe_mask[all_empty, 0] = False
+         attn_out, _ = attn_module(
+             query=q, key=tk_emb, value=tk_emb,
+             key_padding_mask=safe_mask,
+         )
+         context = attn_out.squeeze(1)
+         context[all_empty] = 0.0
+         return norm_module(context)
+
+     def forward(self, obs: torch.Tensor):
+         B = obs.shape[0]
+
+         # Split observation
+         ego_raw = obs[:, :EGO_STATE_DIM]
+         tk_raw = obs[:, EGO_STATE_DIM:].view(B, self.max_tickets, self.ticket_dim)
+
+         # Encode
+         ego_emb = self.ego_encoder(ego_raw)
+         tk_emb = self.ticket_encoder(tk_raw)
+
+         # Padding mask
+         is_padding = (tk_raw.abs().sum(dim=-1) == 0)
+         all_empty = is_padding.all(dim=-1)
+
+         # ── Pass 1: perceive ─────────────────────────────────────────────
+         ctx1 = self._attend(self.attn_pass1, self.norm1,
+                             ego_emb, tk_emb, is_padding, all_empty)
+
+         # ── Reflect: combine ego + initial context into refined query ────
+         reflected = self.reflect_proj(torch.cat([ego_emb, ctx1], dim=-1))
+
+         # ── Pass 2: plan (re-attend with richer query) ───────────────────
+         ctx2 = self._attend(self.attn_pass2, self.norm2,
+                             reflected, tk_emb, is_padding, all_empty)
+
+         # ── Fuse and decode ──────────────────────────────────────────────
+         fused = torch.cat([ego_emb, ctx2], dim=-1)
+
+         steer = self.steer_head(fused)              # (B, 1)
+         drive = self.drive_head(fused)              # (B, 2)
+         action = torch.cat([steer, drive], dim=-1)  # (B, 3)
+         value = self.critic(fused)                  # (B, 1)
+
+         return action, value
+
+     def get_attention_weights(self, obs: torch.Tensor) -> torch.Tensor:
+         """Returns pass-2 attention weights for interpretability."""
+         B = obs.shape[0]
+         ego_raw = obs[:, :EGO_STATE_DIM]
+         tk_raw = obs[:, EGO_STATE_DIM:].view(B, self.max_tickets, self.ticket_dim)
+         ego_emb = self.ego_encoder(ego_raw)
+         tk_emb = self.ticket_encoder(tk_raw)
+         is_padding = (tk_raw.abs().sum(dim=-1) == 0)
+         all_empty = is_padding.all(dim=-1)
+
+         # Pass 1
+         ctx1 = self._attend(self.attn_pass1, self.norm1,
+                             ego_emb, tk_emb, is_padding, all_empty)
+         reflected = self.reflect_proj(torch.cat([ego_emb, ctx1], dim=-1))
+
+         # Pass 2 — get weights
+         safe_mask = is_padding.clone()
+         safe_mask[all_empty, 0] = False
+         query = reflected.unsqueeze(1)
+         _, weights = self.attn_pass2(
+             query=query, key=tk_emb, value=tk_emb,
+             key_padding_mask=safe_mask,
+             need_weights=True, average_attn_weights=False,
+         )
+         weights[all_empty] = 0.0
+         return weights
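The NaN-safety trick in `_attend` is worth isolating: if every ticket slot in a batch row is padding, a `key_padding_mask` that masks all keys drives softmax to NaN, so one key is temporarily unmasked and the resulting context is zeroed afterwards. A numpy sketch of just the mask logic (the `safe_key_padding_mask` name is illustrative, not part of the repo):

```python
import numpy as np

def safe_key_padding_mask(tickets):
    """tickets: (B, T, D). Returns (mask, all_empty); True in mask = ignore slot."""
    is_padding = np.abs(tickets).sum(axis=-1) == 0  # zero rows are padding
    all_empty = is_padding.all(axis=-1)             # rows with no tickets at all
    safe = is_padding.copy()
    safe[all_empty, 0] = False  # leave one key attendable so softmax stays finite
    return safe, all_empty

tickets = np.zeros((2, 4, 3), dtype=np.float32)
tickets[0, 1, 0] = 1.0  # batch row 0 has one real ticket; row 1 has none
mask, all_empty = safe_key_padding_mask(tickets)
# The policy then overwrites context[all_empty] with zeros after attention,
# discarding the dummy output produced for the artificially unmasked key.
```

Zeroing the context afterwards matters: without it, empty-ticket rows would receive the encoding of a padding row as "context" instead of a clean no-threat signal.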
pyproject.toml ADDED
@@ -0,0 +1,33 @@
+ [build-system]
+ requires = ["setuptools>=45", "wheel"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "openenv-overflow-env"
+ version = "0.1.0"
+ description = "Overflow Environment for OpenEnv — autonomous vehicle fleet oversight on a 2D road grid"
+ requires-python = ">=3.10"
+ dependencies = [
+     "openenv-core[core]>=0.2.1",
+     "fastapi>=0.115.0",
+     "pydantic>=2.0.0",
+     "uvicorn[standard]>=0.24.0",
+     "requests>=2.31.0",
+     "torch>=2.5.0",
+     "numpy>=1.24.0",
+     "pillow>=10.4.0",
+     "gymnasium>=0.29.0",
+ ]
+
+ [project.optional-dependencies]
+ dev = [
+     "pytest>=8.0.0",
+     "pytest-cov>=4.0.0",
+ ]
+
+ [project.scripts]
+ server = "overflow_env.server.app:main"
+
+ [tool.setuptools]
+ packages = ["overflow_env", "overflow_env.server"]
+ package-dir = { "overflow_env" = ".", "overflow_env.server" = "server" }
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ --extra-index-url https://download.pytorch.org/whl/cpu
+ torch==2.5.1+cpu
+ numpy>=1.24.0
+ pillow==10.4.0
+ matplotlib>=3.8.0
+ pydantic>=2.0.0
+ requests>=2.31.0
+ gymnasium>=0.29.0
server/__init__.py ADDED
File without changes
server/__pycache__/__init__.cpython-314.pyc ADDED
Binary file (167 Bytes)
server/__pycache__/overflow_environment.cpython-314.pyc ADDED
Binary file (23.6 kB)
server/__pycache__/policy_adapter.cpython-314.pyc ADDED
Binary file (4.89 kB)
server/app.py ADDED
@@ -0,0 +1,46 @@
+ """
+ FastAPI application for the Overflow Environment.
+
+ Exposes the OverflowEnvironment over HTTP and WebSocket endpoints.
+
+ Usage:
+     uvicorn server.app:app --reload --host 0.0.0.0 --port 8000
+ """
+
+ import inspect
+
+ from openenv.core.env_server.http_server import create_app
+
+ from ..models import OverflowAction, OverflowObservation
+ from .overflow_environment import OverflowEnvironment
+
+
+ def _create_overflow_app():
+     """Build app across create_app variants that may expect a factory or an instance."""
+     try:
+         first_param = next(iter(inspect.signature(create_app).parameters.values()))
+         annotation_text = str(first_param.annotation)
+     except (StopIteration, TypeError, ValueError):
+         annotation_text = "typing.Callable"
+
+     expects_instance = (
+         "Environment" in annotation_text and "Callable" not in annotation_text
+     )
+     env_arg = OverflowEnvironment() if expects_instance else OverflowEnvironment
+     return create_app(
+         env_arg, OverflowAction, OverflowObservation, env_name="overflow_env"
+     )
+
+
+ app = _create_overflow_app()
+
+
+ def main():
+     """Entry point for direct execution via uv run or python -m."""
+     import uvicorn
+
+     uvicorn.run(app, host="0.0.0.0", port=8000)
+
+
+ if __name__ == "__main__":
+     main()
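The signature-introspection trick in `_create_overflow_app` generalizes: inspect the first parameter's annotation to decide whether a builder wants a ready-made instance or the factory itself. A self-contained sketch of the same decision rule (the names `pick_arg` and `create_app_v1`/`create_app_v2` are illustrative, not part of the repo or of openenv-core):

```python
import inspect

def pick_arg(builder, factory):
    # Same rule as _create_overflow_app: an "Environment" annotation (with no
    # "Callable" in it) means the builder wants an instance; anything else,
    # including a missing annotation, means it takes the factory directly.
    try:
        first = next(iter(inspect.signature(builder).parameters.values()))
        annotation_text = str(first.annotation)
    except (StopIteration, TypeError, ValueError):
        annotation_text = "typing.Callable"
    expects_instance = (
        "Environment" in annotation_text and "Callable" not in annotation_text
    )
    return factory() if expects_instance else factory

class Environment:
    pass

def create_app_v1(env: "Environment"):  # variant that takes an instance
    return env

def create_app_v2(env_factory):         # variant that takes a factory
    return env_factory
```

Matching on the annotation text is deliberately loose; it keeps the server importable against multiple openenv-core releases without pinning to one `create_app` signature.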
server/overflow_environment.py ADDED
@@ -0,0 +1,497 @@
+"""
+Overflow Environment Implementation.
+
+A 2D road grid with N cars. One car (Car 0) is the LLM agent, others follow
+scripted rules. An observer checks for collisions each step. The environment
+returns text observations describing the traffic scene and rewards based on safety.
+
+Observations carry both text (for the LLM) and structured data (for the frontend).
+"""
+
+import math
+import random
+import re
+from dataclasses import dataclass, field
+from typing import Any, List, Optional
+from uuid import uuid4
+
+try:
+    from openenv.core.env_server.interfaces import Environment
+    from openenv.core.env_server.types import State
+except ImportError:
+    class Environment:  # stub for training-only mode
+        pass
+    class State:
+        pass
+
+try:
+    from ..models import (
+        CarStateData, LaneOccupancyData, OverflowAction,
+        OverflowObservation, OverflowState, Position, ProximityData,
+    )
+    from ..policies.flat_mlp_policy import FlatMLPPolicy
+    from ..policies.ticket_attention_policy import TicketAttentionPolicy
+    from ..policies.policy_spec import OBS_DIM
+    from .policy_adapter import overflow_obs_to_policy_obs, policy_action_to_decision
+except ImportError:
+    from models import (
+        CarStateData, LaneOccupancyData, OverflowAction,
+        OverflowObservation, OverflowState, Position, ProximityData,
+    )
+    from policies.flat_mlp_policy import FlatMLPPolicy
+    from policies.ticket_attention_policy import TicketAttentionPolicy
+    from policies.policy_spec import OBS_DIM
+    from server.policy_adapter import overflow_obs_to_policy_obs, policy_action_to_decision
+
+# --- Constants ---
+NUM_LANES = 3
+ROAD_LENGTH = 200
+NUM_CARS = 5
+MAX_STEPS = 100
+CRASH_DISTANCE = 5.0
+NEAR_MISS_DISTANCE = 15.0
+LANE_WIDTH = 3.7  # metres — matches frontend's makeCar convention
+
+# Reward values
+REWARD_CRASH = -5.0
+REWARD_NEAR_MISS = -1.0
+REWARD_SAFE_STEP = 0.5
+REWARD_REACHED_GOAL = 3.0
+REWARD_REASONING_MAX = 0.3
+
+# Speed bounds
+MIN_SPEED = 20
+MAX_SPEED = 90
+SPEED_DELTA = 5
+
+
+@dataclass
+class Car:
+    """Represents a car on the road grid."""
+
+    car_id: int
+    lane: int  # 1-indexed: 1, 2, or 3
+    position: float
+    speed: float
+    goal_position: float
+    is_agent: bool = False
+    reached_goal: bool = False
+    prev_speed: float = 0.0  # speed last step, for acceleration calc
+
+    def distance_to(self, other: "Car") -> float:
+        """Euclidean-ish distance considering lane and position."""
+        lane_diff = abs(self.lane - other.lane) * 10.0  # lanes are ~10 units apart
+        pos_diff = abs(self.position - other.position)
+        return math.sqrt(lane_diff**2 + pos_diff**2)
+
+    @property
+    def acceleration(self) -> float:
+        """Speed delta since last step."""
+        return self.speed - self.prev_speed
+
+    def to_state_data(self) -> CarStateData:
+        """Convert to frontend-compatible CarStateData."""
+        return CarStateData(
+            carId=self.car_id,
+            lane=self.lane,
+            position=Position(x=self.position, y=self.lane * LANE_WIDTH),
+            speed=self.speed,
+            acceleration=self.acceleration,
+        )
+
+
+def _parse_decision(action: OverflowAction) -> str:
+    """Extract a valid decision from the action, being forgiving about format."""
+    valid = {"accelerate", "brake", "lane_change_left", "lane_change_right", "maintain"}
+
+    # Try the decision field directly
+    decision = action.decision.strip().lower().replace(" ", "_")
+    if decision in valid:
+        return decision
+
+    # Try to extract from free text (the LLM might wrap it in tags)
+    text = f"{action.decision} {action.reasoning}".lower()
+
+    # Check for <action>...</action> tags
+    match = re.search(r"<action>\s*(\w+)\s*</action>", text)
+    if match:
+        candidate = match.group(1).strip().replace(" ", "_")
+        if candidate in valid:
+            return candidate
+
+    # Check for keywords anywhere (ordered: most specific first to avoid ambiguity)
+    for v in ["lane_change_left", "lane_change_right", "accelerate", "brake", "maintain"]:
+        if v in text:
+            return v
+
+    return "maintain"
+
+
+def _compute_reasoning_bonus(reasoning: str) -> float:
+    """
+    Compute a small reasoning quality bonus (0.0 to 0.3).
+
+    Gives a minor reward for providing structured reasoning, kept low
+    so driving performance remains the dominant training signal.
+    """
+    if not reasoning:
+        return 0.0
+
+    score = 0.0
+    lower = reasoning.lower()
+
+    # Small bonus for providing any reasoning at all
+    if len(reasoning) > 20:
+        score += 0.1
+
+    # Bonus for structured reasoning (not just keyword stuffing)
+    if "<think>" in lower or "because" in lower:
+        score += 0.1
+    if any(word in lower for word in ["therefore", "so i should", "best option", "i will"]):
+        score += 0.1
+
+    return min(score, REWARD_REASONING_MAX)
+
+
+def _scripted_car_action(car: Car, all_cars: List[Car], rng: random.Random) -> str:
+    """
+    Simple scripted AI for non-agent cars.
+
+    Rules:
+    - If car ahead in same lane is close (< 20 units): brake
+    - If speed is low and random chance: accelerate
+    - Otherwise: maintain
+    """
+    # Find nearest car ahead in same lane
+    nearest_ahead_dist = float("inf")
+    for other in all_cars:
+        if other.car_id == car.car_id:
+            continue
+        if other.lane == car.lane and other.position > car.position:
+            dist = other.position - car.position
+            if dist < nearest_ahead_dist:
+                nearest_ahead_dist = dist
+
+    if nearest_ahead_dist < 20:
+        return "brake"
+
+    if car.speed < 60 and rng.random() < 0.1:
+        return "accelerate"
+
+    # Occasionally change lanes to make traffic more dynamic
+    if rng.random() < 0.05:
+        if car.lane > 1 and rng.random() < 0.5:
+            return "lane_change_left"
+        elif car.lane < NUM_LANES:
+            return "lane_change_right"
+
+    return "maintain"
+
+
+def _apply_action(car: Car, decision: str) -> None:
+    """Apply a driving decision to a car, mutating it in place."""
+    if decision == "accelerate":
+        car.speed = min(car.speed + SPEED_DELTA, MAX_SPEED)
+    elif decision == "brake":
+        car.speed = max(car.speed - SPEED_DELTA, MIN_SPEED)
+    elif decision == "lane_change_left":
+        if car.lane > 1:
+            car.lane -= 1
+    elif decision == "lane_change_right":
+        if car.lane < NUM_LANES:
+            car.lane += 1
+    # "maintain" — no change
+
+
+def _generate_scene_description(agent_car: Car, cars: List[Car]) -> str:
+    """Generate a text description of the current traffic scene."""
+    lines = [
+        f"You are Car 0 in lane {agent_car.lane}, position {agent_car.position:.0f}, speed {agent_car.speed:.0f}.",
+        f"Goal: reach position {agent_car.goal_position:.0f}.",
+        "Nearby cars:",
+    ]
+
+    for car in cars:
+        if car.car_id == agent_car.car_id:
+            continue
+
+        detail = f"- Car {car.car_id}: lane {car.lane}, position {car.position:.0f}, speed {car.speed:.0f}"
+
+        # Add context about relative position
+        if car.lane == agent_car.lane:
+            pos_diff = car.position - agent_car.position
+            if pos_diff > 0:
+                detail += f" [AHEAD IN YOUR LANE - {pos_diff:.0f} units away]"
+            else:
+                detail += f" [BEHIND IN YOUR LANE - {abs(pos_diff):.0f} units away]"
+
+        if car.reached_goal:
+            detail += " [REACHED GOAL]"
+
+        lines.append(detail)
+
+    return "\n".join(lines)
+
+
+def _build_structured_data(
+    cars: List[Car],
+    proximity_pairs: List[ProximityData],
+) -> tuple[List[CarStateData], List[LaneOccupancyData]]:
+    """Build structured arrays for the observation."""
+    cars_data = [c.to_state_data() for c in cars]
+
+    # Lane occupancies
+    lane_map: dict[int, list[int]] = {}
+    for car in cars:
+        if not car.reached_goal:
+            lane_map.setdefault(car.lane, []).append(car.car_id)
+    lane_occupancies = [
+        LaneOccupancyData(lane=lane, carIds=ids)
+        for lane, ids in sorted(lane_map.items())
+    ]
+
+    return cars_data, lane_occupancies
+
+
+class OverflowEnvironment(Environment):
+    """
+    Autonomous vehicle fleet oversight environment.
+
+    A 2D road grid with N cars. Car 0 is the LLM agent, others follow
+    scripted rules. The observer detects crashes and near-misses and
+    computes rewards based on safety.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self._state = OverflowState(episode_id=str(uuid4()))
+        self._cars: List[Car] = []
+        self._rng = random.Random()
+        self._done = False
+        self._last_obs: Optional[OverflowObservation] = None
+        self._policies = {
+            "flat_mlp": FlatMLPPolicy(obs_dim=OBS_DIM),
+            "ticket_attention": TicketAttentionPolicy(obs_dim=OBS_DIM),
+        }
+
+    def _build_observation(
+        self,
+        incident_report: str,
+        reward: float,
+        proximities: Optional[List[ProximityData]] = None,
+    ) -> OverflowObservation:
+        """Build a full observation with text + structured data."""
+        agent = self._cars[0]
+        scene = _generate_scene_description(agent, self._cars)
+        prox = proximities or []
+        cars_data, lane_occ = _build_structured_data(self._cars, prox)
+
+        return OverflowObservation(
+            scene_description=scene,
+            incident_report=incident_report,
+            done=self._done,
+            reward=reward,
+            cars=cars_data,
+            proximities=prox,
+            lane_occupancies=lane_occ,
+        )
+
+    def reset(
+        self,
+        seed: Optional[int] = None,
+        episode_id: Optional[str] = None,
+        **kwargs: Any,
+    ) -> OverflowObservation:
+        """Reset the environment: create road and spawn cars."""
+        if seed is not None:
+            self._rng = random.Random(seed)
+        else:
+            self._rng = random.Random()
+
+        self._state = OverflowState(
+            episode_id=episode_id or str(uuid4()),
+            step_count=0,
+            crash_count=0,
+            near_miss_count=0,
+            cars_reached_goal=0,
+            total_cars=NUM_CARS,
+        )
+        self._done = False
+
+        # Spawn cars with random positions, speeds, lanes, and goals
+        self._cars = []
+
+        for i in range(NUM_CARS):
+            # Ensure no two cars spawn within crash distance
+            for _attempt in range(100):
+                lane = self._rng.randint(1, NUM_LANES)
+                position = float(self._rng.randint(10, 80))
+                too_close = False
+                for existing in self._cars:
+                    lane_diff = abs(lane - existing.lane) * 10.0
+                    pos_diff = abs(position - existing.position)
+                    dist = math.sqrt(lane_diff**2 + pos_diff**2)
+                    if dist < CRASH_DISTANCE * 2:
+                        too_close = True
+                        break
+                if not too_close:
+                    break
+
+            speed = float(self._rng.randint(40, 70))
+            goal = float(self._rng.randint(160, 195))
+
+            self._cars.append(
+                Car(
+                    car_id=i,
+                    lane=lane,
+                    position=position,
+                    speed=speed,
+                    goal_position=goal,
+                    is_agent=(i == 0),
+                    prev_speed=speed,  # no delta on first step
+                )
+            )
+
+        self._last_obs = self._build_observation(incident_report="", reward=0.0)
+        return self._last_obs
+
+    def step(
+        self,
+        action: OverflowAction,
+        timeout_s: Optional[float] = None,
+        **kwargs: Any,
+    ) -> OverflowObservation:
+        """Execute one simulation step."""
+        if self._done:
+            return self._build_observation(
+                incident_report="Episode is over. Call reset() to start a new one.",
+                reward=0.0,
+            )
+
+        # Policy intercept: decision="policy:flat_mlp" or "policy:ticket_attention"
+        if action.decision.startswith("policy:") and self._last_obs is not None:
+            policy_name = action.decision.split(":", 1)[1].lower()
+            if policy_name in self._policies:
+                obs_vec = overflow_obs_to_policy_obs(self._last_obs)
+                act_vec = self._policies[policy_name].predict(obs_vec)
+                decision, reasoning = policy_action_to_decision(act_vec)
+                action = OverflowAction(
+                    decision=decision,
+                    reasoning=f"[{policy_name}] {reasoning}",
+                )
+
+        self._state.step_count += 1
+        reward = 0.0
+        incidents = []
+
+        # Snapshot previous speeds for acceleration tracking
+        for car in self._cars:
+            car.prev_speed = car.speed
+
+        # 1. Parse and apply the agent's action to Car 0
+        decision = _parse_decision(action)
+        _apply_action(self._cars[0], decision)
+
+        # 2. Compute and apply scripted actions for Cars 1-N
+        for car in self._cars[1:]:
+            if car.reached_goal:
+                continue
+            scripted_decision = _scripted_car_action(car, self._cars, self._rng)
+            _apply_action(car, scripted_decision)
+
+        # 3. Move all cars forward based on speed (speed is in units/step, scaled down)
+        for car in self._cars:
+            if car.reached_goal:
+                continue
+            car.position += car.speed * 0.1  # scale factor for reasonable movement
+
+        # 4. Collision detection (pairwise)
+        agent_crash = False
+        proximity_list: List[ProximityData] = []
+        active_cars = [c for c in self._cars if not c.reached_goal]
+        agent_id = self._cars[0].car_id
+        for i in range(len(active_cars)):
+            for j in range(i + 1, len(active_cars)):
+                dist = active_cars[i].distance_to(active_cars[j])
+                involves_agent = active_cars[i].car_id == agent_id or active_cars[j].car_id == agent_id
+                if dist < CRASH_DISTANCE:
+                    self._state.crash_count += 1
+                    proximity_list.append(
+                        ProximityData(
+                            carA=active_cars[i].car_id,
+                            carB=active_cars[j].car_id,
+                            distance=round(dist, 2),
+                        )
+                    )
+                    incidents.append(
+                        f"CRASH between Car {active_cars[i].car_id} and Car {active_cars[j].car_id}! "
+                        f"(distance: {dist:.1f})"
+                    )
+                    if involves_agent:
+                        agent_crash = True
+                elif dist < NEAR_MISS_DISTANCE:
+                    self._state.near_miss_count += 1
+                    # Only penalize near misses involving the agent
+                    if involves_agent:
+                        reward += REWARD_NEAR_MISS
+                    proximity_list.append(
+                        ProximityData(
+                            carA=active_cars[i].car_id,
+                            carB=active_cars[j].car_id,
+                            distance=round(dist, 2),
+                        )
+                    )
+                    incidents.append(
+                        f"NEAR MISS between Car {active_cars[i].car_id} and Car {active_cars[j].car_id} "
+                        f"(distance: {dist:.1f})"
+                    )
+
+        if agent_crash:
+            reward += REWARD_CRASH
+            self._done = True
+        else:
+            # 5. Goal check for agent car
+            agent = self._cars[0]
+            if agent.position >= agent.goal_position:
+                agent.reached_goal = True
+                self._state.cars_reached_goal += 1
+                reward += REWARD_REACHED_GOAL
+                incidents.append(
+                    f"Car 0 reached its goal at position {agent.goal_position:.0f}!"
+                )
+                self._done = True
+
+        # Check goal for scripted cars too (for state tracking)
+        for car in self._cars[1:]:
+            if not car.reached_goal and car.position >= car.goal_position:
+                car.reached_goal = True
+                self._state.cars_reached_goal += 1
+
+        # 6. Safe step bonus (no crash, agent still active)
+        if not self._done:
+            reward += REWARD_SAFE_STEP
+
+        # 7. Reasoning quality bonus
+        reasoning_bonus = _compute_reasoning_bonus(action.reasoning)
+        reward += reasoning_bonus
+
+        # 8. Max steps check
+        if self._state.step_count >= MAX_STEPS and not self._done:
+            self._done = True
+            incidents.append(f"Maximum steps ({MAX_STEPS}) reached.")
+
+        incident_report = (
+            "\n".join(incidents) if incidents else "Observer: No incidents this step."
+        )
+
+        self._last_obs = self._build_observation(
+            incident_report=incident_report,
+            reward=reward,
+            proximities=proximity_list,
+        )
+        return self._last_obs
+
+    @property
+    def state(self) -> OverflowState:
+        """Get the current environment state."""
+        return self._state
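The observer's crash/near-miss classification reduces to the lane-weighted distance from `Car.distance_to` plus the two thresholds `CRASH_DISTANCE` and `NEAR_MISS_DISTANCE`. A self-contained sketch of just that metric and classification:

```python
import math

CRASH_DISTANCE = 5.0
NEAR_MISS_DISTANCE = 15.0


def pair_distance(lane_a: int, pos_a: float, lane_b: int, pos_b: float) -> float:
    # Same metric as Car.distance_to: lanes count as ~10 units of lateral separation.
    lane_diff = abs(lane_a - lane_b) * 10.0
    pos_diff = abs(pos_a - pos_b)
    return math.sqrt(lane_diff ** 2 + pos_diff ** 2)


def classify(dist: float) -> str:
    # Mirrors the pairwise check in OverflowEnvironment.step.
    if dist < CRASH_DISTANCE:
        return "crash"
    if dist < NEAR_MISS_DISTANCE:
        return "near_miss"
    return "clear"
```

Note the consequence: two cars in adjacent lanes are always exactly 10 units apart laterally, so they can trigger a near miss but never a crash; crashes require same-lane proximity.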
server/policy_adapter.py ADDED
@@ -0,0 +1,80 @@
+"""
+Adapter between OverflowObservation (2D road grid) and the OpenENV policy
+observation format (ego state + ticket matrix).
+
+Nearby cars are converted to collision_risk tickets so TicketAttentionPolicy
+can reason about them using the same mechanism it was designed for.
+"""
+
+from __future__ import annotations
+
+import math
+import numpy as np
+
+try:
+    from ..policies.policy_spec import build_obs, build_ticket_vector, OBS_DIM
+except ImportError:
+    from policies.policy_spec import build_obs, build_ticket_vector, OBS_DIM
+
+
+def overflow_obs_to_policy_obs(obs) -> np.ndarray:
+    """OverflowObservation → 603-dim numpy vector for our policies."""
+    cars = obs.cars
+    if not cars:
+        return np.zeros(OBS_DIM, dtype=np.float32)
+
+    ego = next((c for c in cars if c.carId == 0), cars[0])
+    ego_speed_ms = ego.speed / 4.5  # OverflowEnv speed units → m/s
+    ego_x = ego.position.x
+    ego_y = (ego.lane - 2) * 3.7  # lane → lateral metres
+
+    ticket_vectors = []
+    for car in cars:
+        if car.carId == 0:
+            continue
+        rel_x = car.position.x - ego.position.x
+        rel_y = (car.lane - ego.lane) * 3.7
+        car_spd = car.speed / 4.5
+        distance = math.sqrt(rel_x ** 2 + rel_y ** 2)
+        if distance > 80:
+            continue
+        closing = max(ego_speed_ms - car_spd * math.copysign(1, max(rel_x, 0.01)), 0.1)
+        ttc = min(distance / closing, 30.0)
+        severity = 1.0 if distance < 8 else (0.75 if distance < 15 else 0.5)
+        ticket_vectors.append(build_ticket_vector(
+            severity_weight=severity, ttl=5.0,
+            pos_x=rel_x, pos_y=rel_y, pos_z=0.0,
+            vel_x=car_spd, vel_y=0.0, vel_z=0.0,
+            heading=0.0,
+            size_length=4.0, size_width=2.0, size_height=1.5,
+            distance=distance, time_to_collision=ttc,
+            bearing=math.atan2(rel_y, max(rel_x, 0.01)),
+            ticket_type="collision_risk", entity_type="vehicle", confidence=1.0,
+        ))
+
+    tv = np.array(ticket_vectors, dtype=np.float32) if ticket_vectors else None
+    return build_obs(
+        ego_x=ego_x, ego_y=ego_y, ego_z=0.0,
+        ego_vx=ego_speed_ms, ego_vy=0.0,
+        heading=0.0, speed=ego_speed_ms,
+        steer=0.0, throttle=0.5, brake=0.0,
+        ticket_vectors=tv,
+    )
+
+
+def policy_action_to_decision(action_vec: np.ndarray) -> tuple[str, str]:
+    """Continuous [steer, throttle, brake] → (text decision, reasoning)."""
+    steer, throttle, brake = float(action_vec[0]), float(action_vec[1]), float(action_vec[2])
+    if abs(steer) > 0.35:
+        decision = "lane_change_left" if steer < 0 else "lane_change_right"
+        reasoning = f"steer={steer:.2f}: lateral avoidance"
+    elif brake > 0.25:
+        decision = "brake"
+        reasoning = f"brake={brake:.2f}: closing gap"
+    elif throttle > 0.20:
+        decision = "accelerate"
+        reasoning = f"throttle={throttle:.2f}: clear ahead"
+    else:
+        decision = "maintain"
+        reasoning = f"s={steer:.2f} t={throttle:.2f} b={brake:.2f}: holding course"
+    return decision, reasoning
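The threshold mapping in `policy_action_to_decision` can be exercised standalone. A sketch that copies the same cutoffs (|steer| > 0.35, brake > 0.25, throttle > 0.20, checked in that order) and returns only the decision string (the helper name is ours, for illustration):

```python
def map_action_to_decision(action_vec) -> str:
    # Same precedence as policy_action_to_decision: steering wins,
    # then braking, then throttle; anything mild means "maintain".
    steer, throttle, brake = (
        float(action_vec[0]), float(action_vec[1]), float(action_vec[2])
    )
    if abs(steer) > 0.35:
        return "lane_change_left" if steer < 0 else "lane_change_right"
    if brake > 0.25:
        return "brake"
    if throttle > 0.20:
        return "accelerate"
    return "maintain"
```

Because steering is checked first, a vector like `[-0.6, 0.9, 0.9]` still produces a lane change, which matches the avoidance-first priority of the adapter.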
server/requirements.txt ADDED
@@ -0,0 +1,8 @@
+--extra-index-url https://download.pytorch.org/whl/cpu
+torch>=2.5.0
+gymnasium>=0.29.0
+openenv-core[core]>=0.2.1
+fastapi>=0.115.0
+pydantic>=2.0.0
+uvicorn[standard]>=0.24.0
+requests>=2.31.0
training/__init__.py ADDED
File without changes
training/__pycache__/__init__.cpython-314.pyc ADDED
Binary file (169 Bytes).

training/__pycache__/curriculum.cpython-314.pyc ADDED
Binary file (5.69 kB).

training/__pycache__/overflow_gym_env.cpython-314.pyc ADDED
Binary file (9.75 kB).

training/__pycache__/ppo_trainer.cpython-314.pyc ADDED
Binary file (20.6 kB).

training/__pycache__/reward.cpython-314.pyc ADDED
Binary file (3.47 kB).
 
training/curriculum.py ADDED
@@ -0,0 +1,99 @@
+"""
+CurriculumManager — ported from openenv/training/curriculum.py.
+
+Same 4-stage progression and same reward thresholds. Adapted for
+OverflowEnvironment: no ticket injection (the env has its own scripted
+NPCs), stages instead control training logging and advancement criteria.
+
+Stage 1  No extra pressure.  Goal: learn basic speed + lane keeping.
+Stage 2  Standard traffic.   Goal: survive without crashing.
+Stage 3  Evaluate more.      Goal: consistent goal-reaching.
+Stage 4  Full evaluation.    Goal: high mean reward over long window.
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import List
+
+
+@dataclass
+class StageConfig:
+    stage: int
+    name: str
+    description: str
+    advance_threshold: float  # mean episode reward to advance
+    advance_window: int  # consecutive episodes required
+
+
+STAGES: List[StageConfig] = [
+    StageConfig(
+        stage=1, name="Survival",
+        description="Learn basic speed control and lane keeping.",
+        advance_threshold=50.0, advance_window=8,
+    ),
+    StageConfig(
+        stage=2, name="Crash Avoidance",
+        description="Navigate traffic without colliding.",
+        advance_threshold=120.0, advance_window=15,
+    ),
+    StageConfig(
+        stage=3, name="Goal Reaching",
+        description="Consistently reach the goal position.",
+        advance_threshold=200.0, advance_window=15,
+    ),
+    StageConfig(
+        stage=4, name="Mastery",
+        description="High reward, smooth driving, minimal near-misses.",
+        advance_threshold=280.0, advance_window=15,
+    ),
+]
+
+
+class CurriculumManager:
+    """
+    Tracks stage progression based on episode rewards.
+    Same API as openenv CurriculumManager — PPOTrainer calls it unchanged.
+    """
+
+    def __init__(self, seed: int = 0):
+        self._stage_idx = 0
+        self._rewards: List[float] = []
+        self._auto_advance = True
+
+    @property
+    def current_stage(self) -> int:
+        return STAGES[self._stage_idx].stage
+
+    @property
+    def config(self) -> StageConfig:
+        return STAGES[self._stage_idx]
+
+    def step(self, sim_time: float) -> list:
+        """No ticket injection in OverflowEnvironment — always returns []."""
+        return []
+
+    def record_episode_reward(self, reward: float) -> bool:
+        """Record episode reward and advance stage if threshold met."""
+        self._rewards.append(reward)
+        cfg = self.config
+        window = self._rewards[-cfg.advance_window:]
+
+        if (
+            self._auto_advance
+            and len(window) >= cfg.advance_window
+            and sum(window) / len(window) >= cfg.advance_threshold
+            and self._stage_idx < len(STAGES) - 1
+        ):
+            self._stage_idx += 1
+            self._rewards = []
+            print(f"[Curriculum] Advanced to Stage {self.current_stage}: {self.config.name}")
+            return True
+        return False
+
+    def force_stage(self, stage: int) -> None:
+        idx = stage - 1
+        if 0 <= idx < len(STAGES):
+            self._stage_idx = idx
+            self._rewards = []
+            print(f"[Curriculum] Forced to Stage {stage}: {self.config.name}")
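The advancement rule in `record_episode_reward` is: once the last `advance_window` episodes average at or above `advance_threshold`, move up a stage and clear the reward history. A trimmed, self-contained sketch using Stage 1's numbers (threshold 50.0, window 8):

```python
class MiniCurriculum:
    """Reduced sketch of CurriculumManager's advancement rule (illustrative class)."""

    def __init__(self, threshold: float = 50.0, window: int = 8, n_stages: int = 4):
        self.stage = 1
        self.threshold = threshold
        self.window = window
        self.n_stages = n_stages
        self.rewards: list[float] = []

    def record(self, reward: float) -> bool:
        self.rewards.append(reward)
        recent = self.rewards[-self.window:]
        if (
            len(recent) >= self.window
            and sum(recent) / len(recent) >= self.threshold
            and self.stage < self.n_stages
        ):
            self.stage += 1
            self.rewards = []  # history resets, as in record_episode_reward
            return True
        return False
```

Clearing the history on advancement means a stage can never be passed on rewards earned under the previous stage's conditions.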
training/overflow_gym_env.py ADDED
@@ -0,0 +1,170 @@
+"""
+Gymnasium wrapper around OverflowEnvironment.
+
+Bridges the gap between OverflowEnvironment (text actions, structured obs)
+and our PPO trainer (continuous actions, numeric obs vector).
+
+Observation: 603-dim float32 vector (same layout as CarEnv3D — ego state +
+    collision-risk ticket matrix built from nearby cars)
+
+Action: [steer, throttle, brake] all in [-1, 1]
+    → mapped to text decision for OverflowEnvironment
+
+This makes OverflowEnvironment a drop-in replacement for CarEnv3D so that
+FlatMLPPolicy and TicketAttentionPolicy train with the exact same PPO loop.
+"""
+
+from __future__ import annotations
+
+import math
+from typing import Any, Dict, Optional, Tuple
+
+import numpy as np
+import gymnasium as gym
+from gymnasium import spaces
+
+from ..server.overflow_environment import OverflowEnvironment
+from ..models import OverflowAction
+from ..policies.policy_spec import (
+    build_obs, build_ticket_vector, OBS_DIM,
+)
+from .reward import compute_reward
+
+
+# ── Action mapping ────────────────────────────────────────────────────────────
+
+def _action_to_decision(action: np.ndarray) -> str:
+    steer, throttle, brake = float(action[0]), float(action[1]), float(action[2])
+    if abs(steer) > 0.35:
+        return "lane_change_left" if steer < 0 else "lane_change_right"
+    if brake > 0.25:
+        return "brake"
+    if throttle > 0.20:
+        return "accelerate"
+    return "maintain"
+
+
+# ── Observation extraction ────────────────────────────────────────────────────
+
+def _obs_to_vector(overflow_obs) -> np.ndarray:
+    """OverflowObservation → 603-dim numpy vector matching policy_spec layout."""
+    cars = overflow_obs.cars
+    if not cars:
+        return np.zeros(OBS_DIM, dtype=np.float32)
+
+    ego = next((c for c in cars if c.carId == 0), cars[0])
+    ego_speed_ms = ego.speed / 4.5
+    ego_x = ego.position.x
+    ego_y = (ego.lane - 2) * 3.7
+
+    ticket_vectors = []
+    for car in cars:
+        if car.carId == 0:
+            continue
+        rel_x = car.position.x - ego.position.x
+        rel_y = (car.lane - ego.lane) * 3.7
+        car_spd = car.speed / 4.5
+        distance = math.sqrt(rel_x ** 2 + rel_y ** 2)
+        if distance > 80:
+            continue
+        closing = max(ego_speed_ms - car_spd * math.copysign(1, max(rel_x, 0.01)), 0.1)
+        ttc = min(distance / closing, 30.0)
+        severity = 1.0 if distance < 8 else (0.75 if distance < 15 else 0.5)
+        ticket_vectors.append(build_ticket_vector(
+            severity_weight=severity, ttl=5.0,
+            pos_x=rel_x, pos_y=rel_y, pos_z=0.0,
+            vel_x=car_spd, vel_y=0.0, vel_z=0.0,
+            heading=0.0,
+            size_length=4.0, size_width=2.0, size_height=1.5,
+            distance=distance, time_to_collision=ttc,
+            bearing=math.atan2(rel_y, max(rel_x, 0.01)),
+            ticket_type="collision_risk", entity_type="vehicle", confidence=1.0,
+        ))
+
+    tv = np.array(ticket_vectors, dtype=np.float32) if ticket_vectors else None
+    return build_obs(
+        ego_x=ego_x, ego_y=ego_y, ego_z=0.0,
+        ego_vx=ego_speed_ms, ego_vy=0.0,
+        heading=0.0, speed=ego_speed_ms,
+        steer=0.0, throttle=0.5, brake=0.0,
+        ticket_vectors=tv,
+    )
+
+
+# ── Gymnasium wrapper ─────────────────────────────────────────────────────────
+
+class OverflowGymEnv(gym.Env):
+    """
+    Gymnasium-compatible wrapper around OverflowEnvironment.
+
+    Provides the same interface as CarEnv3D so PPOTrainer works unchanged.
+    """
+
+    metadata = {"render_modes": []}
+
+    def __init__(self):
+        super().__init__()
+        self._env = OverflowEnvironment()
+        self._last_overflow_obs = None
+        self._prev_action = np.zeros(3, dtype=np.float32)
+        self._sim_time = 0.0  # incremented each step (mirrors CarEnv3D._sim_time)
+        self._step_dt = 0.1  # seconds per step
+
+        self.observation_space = spaces.Box(
+            low=-1.0, high=1.0, shape=(OBS_DIM,), dtype=np.float32
+        )
+        self.action_space = spaces.Box(
+            low=-1.0, high=1.0, shape=(3,), dtype=np.float32
+        )
+
+    def reset(
+        self,
+        seed: Optional[int] = None,
+        options: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[np.ndarray, Dict]:
+        super().reset(seed=seed)
+        self._last_overflow_obs = self._env.reset(seed=seed)
+        self._prev_action = np.zeros(3, dtype=np.float32)
+        self._sim_time = 0.0
+        return _obs_to_vector(self._last_overflow_obs), {}
+
+    def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, bool, Dict]:
+        decision = _action_to_decision(action)
+        overflow_action = OverflowAction(decision=decision, reasoning="")
+        overflow_obs = self._env.step(overflow_action)
+        self._last_overflow_obs = overflow_obs
+        self._sim_time += self._step_dt
+
+        obs_vec = _obs_to_vector(overflow_obs)
+
+        # Extract signals for reward shaping
+        ego = next((c for c in overflow_obs.cars if c.carId == 0), None)
+        ego_speed_ms = (ego.speed / 4.5) if ego else 0.0
+        ego_y = ((ego.lane - 2) * 3.7) if ego else 0.0
+
+        collision = any("CRASH" in p for p in (overflow_obs.incident_report or "").split("\n")
+                        if "Car 0" in p)
+        goal_reached = overflow_obs.done and not collision
+
+        reward = compute_reward(
+            ego_speed=ego_speed_ms,
+            ego_y=ego_y,
+            action=action,
+            prev_action=self._prev_action,
+            collision=collision,
+            goal_reached=goal_reached,
+            near_miss="NEAR MISS" in (overflow_obs.incident_report or ""),
+            raw_reward=overflow_obs.reward or 0.0,
+        )
+
+        self._prev_action = action.copy()
+
+        terminated = overflow_obs.done
+        truncated = False
+        info: Dict[str, Any] = {
+            "collision": collision,
+            "goal_reached": goal_reached,
+            "incident": overflow_obs.incident_report,
+        }
+
+        return obs_vec, reward, terminated, truncated, info
training/ppo_trainer.py ADDED
@@ -0,0 +1,329 @@
+ """
+ PPO trainer — ported directly from openenv/training/ppo_trainer.py.
+
+ Same algorithm, same hyperparameters, same GAE implementation.
+ Only change: uses OverflowGymEnv instead of CarEnv3D.
+
+ Usage:
+     from overflow_env.training.ppo_trainer import run_training
+     run_training(policy_type="attention", total_steps=2_000_000)
+ """
+
+ from __future__ import annotations
+
+ import time
+ from collections import deque
+ from pathlib import Path
+ from typing import Optional
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+
+ from .overflow_gym_env import OverflowGymEnv
+ from .curriculum import CurriculumManager
+ from .reward import compute_episode_bonus
+ from ..policies.base_policy import BasePolicy
+ from ..policies.policy_spec import OBS_DIM
+
+
+ # ── Rollout buffer ─────────────────────────────────────────────────────────────
+ # Identical to openenv/training/ppo_trainer.py
+
+ class RolloutBuffer:
+     def __init__(self, n_steps: int, obs_dim: int, device: torch.device):
+         self.n = n_steps
+         self.obs = torch.zeros(n_steps, obs_dim, device=device)
+         self.acts = torch.zeros(n_steps, 3, device=device)
+         self.rew = torch.zeros(n_steps, device=device)
+         self.val = torch.zeros(n_steps, device=device)
+         self.logp = torch.zeros(n_steps, device=device)
+         self.done = torch.zeros(n_steps, device=device)
+         self.ptr = 0
+
+     def add(self, obs, act, rew, val, logp, done):
+         i = self.ptr
+         self.obs[i] = torch.as_tensor(obs, dtype=torch.float32)
+         self.acts[i] = torch.as_tensor(act, dtype=torch.float32)
+         self.rew[i] = float(rew)
+         self.val[i] = float(val)
+         self.logp[i] = float(logp)
+         self.done[i] = float(done)
+         self.ptr += 1
+
+     def full(self) -> bool:
+         return self.ptr >= self.n
+
+     def reset(self):
+         self.ptr = 0
+
+     def compute_returns(self, last_val: float, gamma: float, gae_lambda: float):
+         """Generalized Advantage Estimation — identical to openenv."""
+         adv = torch.zeros_like(self.rew)
+         gae = 0.0
+         for t in reversed(range(self.n)):
+             next_val = last_val if t == self.n - 1 else float(self.val[t + 1])
+             delta = self.rew[t] + gamma * next_val * (1 - self.done[t]) - self.val[t]
+             gae = delta + gamma * gae_lambda * (1 - self.done[t]) * gae
+             adv[t] = gae
+         self.ret = adv + self.val
+
+
+ # ── PPO Trainer ────────────────────────────────────────────────────────────────
+
+ class PPOTrainer:
+     """
+     Identical to openenv PPOTrainer — same hyperparameters, same PPO update.
+     Environment is OverflowGymEnv instead of CarEnv3D.
+     """
+
+     def __init__(
+         self,
+         policy: BasePolicy,
+         env: OverflowGymEnv,
+         curriculum: Optional[CurriculumManager] = None,
+         # PPO hyperparameters — same defaults as openenv
+         lr: float = 3e-4,
+         gamma: float = 0.99,
+         gae_lambda: float = 0.95,
+         clip_range: float = 0.2,
+         clip_range_vf: float = 0.2,
+         ent_coef: float = 0.02,
+         vf_coef: float = 0.5,
+         max_grad_norm: float = 0.5,
+         n_steps: int = 2048,
+         batch_size: int = 256,
+         n_epochs: int = 10,
+         save_dir: str = "checkpoints",
+         log_interval: int = 10,
+         device: str = "auto",
+     ):
+         self.policy = policy
+         self.env = env
+         self.curriculum = curriculum or CurriculumManager()
+         self.gamma = gamma
+         self.gae_lambda = gae_lambda
+         self.clip = clip_range
+         self.clip_vf = clip_range_vf
+         self.ent_coef = ent_coef
+         self.vf_coef = vf_coef
+         self.max_grad = max_grad_norm
+         self.n_steps = n_steps
+         self.batch_size = batch_size
+         self.n_epochs = n_epochs
+         self.log_every = log_interval
+         self.save_dir = Path(save_dir)
+         self.save_dir.mkdir(parents=True, exist_ok=True)
+
+         if device == "auto":
+             device = "cuda" if torch.cuda.is_available() else \
+                 "mps" if torch.backends.mps.is_available() else "cpu"
+         self.device = torch.device(device)
+         self.policy.to(self.device)
+
+         self.optimizer = optim.Adam(policy.parameters(), lr=lr, eps=1e-5)
+         self.scheduler = optim.lr_scheduler.LinearLR(
+             self.optimizer, start_factor=1.0, end_factor=0.1, total_iters=500,
+         )
+
+         self.buffer = RolloutBuffer(n_steps, OBS_DIM, self.device)
+
+         self.ep_rewards = deque(maxlen=100)
+         self.ep_lengths = deque(maxlen=100)
+         self.total_steps = 0
+         self.n_updates = 0
+
+     # ── Main training loop ─────────────────────────────────────────────────────
+
+     def train(self, total_steps: int = 2_000_000) -> None:
+         print(f"\n{'='*70}", flush=True)
+         print(f" OpenENV PPO Training — policy={self.policy.__class__.__name__}", flush=True)
+         print(f" total_steps={total_steps} n_steps={self.n_steps} lr={self.optimizer.param_groups[0]['lr']:.0e}", flush=True)
+         print(f" gamma={self.gamma} gae_lambda={self.gae_lambda} clip={self.clip} ent_coef={self.ent_coef}", flush=True)
+         print(f"{'='*70}\n", flush=True)
+
+         obs, _ = self.env.reset()
+         ep_reward = 0.0
+         ep_steps = 0
+         t0 = time.time()
+
+         while self.total_steps < total_steps:
+             self.buffer.reset()
+             self.policy.eval()
+
+             # ── Collect rollout ──────────────────────────────────────────────
+             for _ in range(self.n_steps):
+                 # Curriculum step (returns [] for OverflowEnv — kept for API compat)
+                 self.curriculum.step(self.env._sim_time)
+
+                 obs_t = torch.as_tensor(obs, dtype=torch.float32, device=self.device)
+                 with torch.no_grad():
+                     act_mean, val = self.policy(obs_t.unsqueeze(0))
+                     act_mean = act_mean.squeeze(0)
+                     val = val.squeeze(0)
+
+                 dist = torch.distributions.Normal(act_mean, torch.ones_like(act_mean) * 0.3)
+                 action = dist.sample().clamp(-1, 1)
+                 logp = dist.log_prob(action).sum()
+
+                 next_obs, reward, term, trunc, info = self.env.step(action.cpu().numpy())
+
+                 self.buffer.add(
+                     obs, action.cpu().numpy(), reward,
+                     float(val), float(logp), float(term or trunc),
+                 )
+
+                 obs = next_obs
+                 ep_reward += reward
+                 ep_steps += 1
+                 self.total_steps += 1
+
+                 if term or trunc:
+                     bonus = compute_episode_bonus(
+                         total_steps=ep_steps,
+                         survived=not info.get("collision", False),
+                     )
+                     ep_reward += bonus
+                     self.ep_rewards.append(ep_reward)
+                     self.ep_lengths.append(ep_steps)
+                     advanced = self.curriculum.record_episode_reward(ep_reward)
+
+                     outcome = "CRASH" if info.get("collision") else ("GOAL" if info.get("goal_reached") else "timeout")
+                     print(
+                         f" ep#{len(self.ep_rewards):>4d} | "
+                         f"steps={ep_steps:>3d} | "
+                         f"reward={ep_reward:>8.2f} | "
+                         f"outcome={outcome:<8} | "
+                         f"stage={self.curriculum.current_stage} | "
+                         f"total_steps={self.total_steps}",
+                         flush=True,
+                     )
+
+                     obs, _ = self.env.reset()
+                     ep_reward = 0.0
+                     ep_steps = 0
+
+             # ── PPO update ───────────────────────────────────────────────────
+             with torch.no_grad():
+                 obs_t = torch.as_tensor(obs, dtype=torch.float32, device=self.device)
+                 _, last_val = self.policy(obs_t.unsqueeze(0))
+             self.buffer.compute_returns(float(last_val), self.gamma, self.gae_lambda)
+
+             self.policy.train()
+             self._ppo_update()
+             self.n_updates += 1
+             self.scheduler.step()
+
+             elapsed = time.time() - t0
+             sps = self.total_steps / max(elapsed, 1)
+             mean_r = np.mean(self.ep_rewards) if self.ep_rewards else 0.0
+             mean_l = np.mean(self.ep_lengths) if self.ep_lengths else 0.0
+             print(
+                 f"\n[PPO update #{self.n_updates}] "
+                 f"step={self.total_steps} "
+                 f"mean_reward={mean_r:.2f} "
+                 f"mean_ep_len={mean_l:.0f} "
+                 f"stage={self.curriculum.current_stage} "
+                 f"sps={sps:.0f}\n",
+                 flush=True,
+             )
+
+             # ── Checkpoint ───────────────────────────────────────────────────
+             if self.n_updates % 50 == 0:
+                 ckpt = self.save_dir / f"policy_step{self.total_steps}_stage{self.curriculum.current_stage}.pt"
+                 torch.save({
+                     "step": self.total_steps,
+                     "stage": self.curriculum.current_stage,
+                     "policy": self.policy.state_dict(),
+                     "optim": self.optimizer.state_dict(),
+                 }, ckpt)
+                 print(f"[PPO] Saved checkpoint → {ckpt}")
+
+     # ── PPO update pass — identical to openenv ─────────────────────────────────
+
+     def _ppo_update(self):
+         obs = self.buffer.obs
+         acts = self.buffer.acts
+         old_logp = self.buffer.logp
+         adv = self.buffer.ret - self.buffer.val
+         adv = (adv - adv.mean()) / (adv.std() + 1e-8)
+         ret = self.buffer.ret
+         old_val = self.buffer.val
+
+         indices = torch.randperm(self.n_steps, device=self.device)
+
+         for _ in range(self.n_epochs):
+             for start in range(0, self.n_steps, self.batch_size):
+                 idx = indices[start: start + self.batch_size]
+
+                 act_mean, val = self.policy(obs[idx])
+                 val = val.squeeze(-1)
+
+                 dist = torch.distributions.Normal(act_mean, torch.ones_like(act_mean) * 0.3)
+                 logp = dist.log_prob(acts[idx]).sum(dim=-1)
+                 entropy = dist.entropy().sum(dim=-1).mean()
+
+                 ratio = torch.exp(logp - old_logp[idx])
+                 pg_loss1 = -adv[idx] * ratio
+                 pg_loss2 = -adv[idx] * ratio.clamp(1 - self.clip, 1 + self.clip)
+                 pg_loss = torch.max(pg_loss1, pg_loss2).mean()
+
+                 val_unclipped = (val - ret[idx]) ** 2
+                 val_clipped = (
+                     old_val[idx]
+                     + (val - old_val[idx]).clamp(-self.clip_vf, self.clip_vf)
+                     - ret[idx]
+                 ) ** 2
+                 vf_loss = 0.5 * torch.max(val_unclipped, val_clipped).mean()
+
+                 loss = pg_loss + self.vf_coef * vf_loss - self.ent_coef * entropy
+
+                 self.optimizer.zero_grad()
+                 loss.backward()
+                 nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad)
+                 self.optimizer.step()
+
+
+ # ── Entry point ────────────────────────────────────────────────────────────────
+
+ def run_training(
+     policy_type: str = "attention",
+     total_steps: int = 2_000_000,
+     start_stage: int = 1,
+     checkpoint: Optional[str] = None,
+     device: str = "auto",
+ ) -> None:
+     from ..policies.ticket_attention_policy import TicketAttentionPolicy
+     from ..policies.flat_mlp_policy import FlatMLPPolicy
+
+     policy_map = {
+         "attention": lambda: TicketAttentionPolicy(obs_dim=OBS_DIM),
+         "mlp": lambda: FlatMLPPolicy(obs_dim=OBS_DIM),
+     }
+     policy = policy_map[policy_type]()
+
+     if checkpoint:
+         ckpt = torch.load(checkpoint, map_location="cpu")
+         policy.load_state_dict(ckpt["policy"])
+         print(f"[PPO] Loaded checkpoint from {checkpoint}")
+
+     env = OverflowGymEnv()
+     cm = CurriculumManager()
+     if start_stage > 1:
+         cm.force_stage(start_stage)
+
+     trainer = PPOTrainer(policy=policy, env=env, curriculum=cm, device=device, n_steps=512)
+     trainer.train(total_steps=total_steps)
+
+
+ if __name__ == "__main__":
+     import argparse
+     p = argparse.ArgumentParser()
+     p.add_argument("--policy", default="attention", choices=["attention", "mlp"])
+     p.add_argument("--steps", default=2_000_000, type=int)
+     p.add_argument("--stage", default=1, type=int)
+     p.add_argument("--checkpoint", default=None)
+     p.add_argument("--device", default="auto")
+     args = p.parse_args()
+     run_training(args.policy, args.steps, args.stage, args.checkpoint, args.device)
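`RolloutBuffer.compute_returns` above is a standard GAE recursion. A pure-Python mirror of the same loop is useful for sanity-checking its masking behavior; the numeric inputs below are made up for illustration:

```python
# Pure-Python mirror of RolloutBuffer.compute_returns:
# delta_t = r_t + gamma * V_{t+1} * (1 - done_t) - V_t, accumulated
# backward with the gamma * lambda discount.
def gae(rewards, values, dones, last_val, gamma=0.99, lam=0.95):
    n = len(rewards)
    adv = [0.0] * n
    g = 0.0
    for t in reversed(range(n)):
        next_val = last_val if t == n - 1 else values[t + 1]
        delta = rewards[t] + gamma * next_val * (1 - dones[t]) - values[t]
        g = delta + gamma * lam * (1 - dones[t]) * g
        adv[t] = g
    returns = [a + v for a, v in zip(adv, values)]
    return adv, returns

# done=1 at the last step masks the bootstrap value entirely:
adv, ret = gae([1.0], [0.0], [1.0], last_val=5.0)
print(adv)  # -> [1.0]
```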
training/reward.py ADDED
@@ -0,0 +1,94 @@
+ """
+ Reward shaping for OverflowEnvironment — ported from openenv/training/reward.py.
+
+ Same core principle: BASE + THREAT_RESPONSE with clear gradient direction.
+ Adapted to OverflowEnvironment's signals (no EventTicket objects — uses
+ collision/near-miss flags and raw reward from the environment).
+
+ BASE:           survival + speed + lane  ~+0.4/step
+ COLLISION:      -50 (terminal)
+ NEAR MISS:      -0.8 per event
+ GOAL REACHED:   +5.0 (terminal bonus)
+ SMOOTH DRIVING: small bonus when no threats
+ """
+
+ from __future__ import annotations
+
+ import numpy as np
+
+ # ── Same weights as openenv/training/reward.py ────────────────────────────────
+
+ W_ALIVE = 0.40
+ W_SPEED = 0.10
+ W_LANE = 0.15
+ W_SMOOTH = 0.03
+ TARGET_SPEED = 11.0  # m/s (~40 km/h)
+ TARGET_SPEED_TOL = 3.0
+
+ W_COLLISION = -50.0
+ W_NEAR_MISS = -0.8
+ W_GOAL = 5.0
+ W_SURVIVE_BONUS = 5.0
+
+ ROAD_HALF_WIDTH = 3.7 * 1.5  # 1.5 lane widths (3.7 m each) of lateral tolerance
+
+
+ def compute_reward(
+     ego_speed: float,
+     ego_y: float,
+     action: np.ndarray,
+     prev_action: np.ndarray,
+     collision: bool,
+     goal_reached: bool,
+     near_miss: bool,
+     raw_reward: float,  # OverflowEnvironment's built-in reward (kept for parity; not used in the shaped sum)
+ ) -> float:
+     """
+     Shaped reward. Mirrors openenv reward structure:
+       - collision → large terminal penalty
+       - base survival + speed + lane keeping
+       - near-miss penalty
+       - goal bonus
+       - smooth driving bonus when clear
+     """
+     if collision:
+         return W_COLLISION
+
+     reward = 0.0
+
+     # 1. Survival
+     reward += W_ALIVE
+
+     # 2. Speed maintenance (same formula as openenv)
+     speed_err = abs(ego_speed - TARGET_SPEED)
+     if speed_err < TARGET_SPEED_TOL:
+         reward += W_SPEED * (1.0 - speed_err / TARGET_SPEED_TOL)
+     else:
+         reward -= 0.03 * min(speed_err - TARGET_SPEED_TOL, 5.0)
+
+     # 3. Lane keeping
+     norm_y = abs(ego_y) / ROAD_HALF_WIDTH
+     reward += W_LANE * max(0.0, 1.0 - norm_y ** 2)
+
+     # 4. Near-miss penalty (W_NEAR_MISS is negative)
+     if near_miss:
+         reward += W_NEAR_MISS
+
+     # 5. Goal bonus
+     if goal_reached:
+         reward += W_GOAL
+
+     # 6. Smooth driving
+     action_delta = np.abs(action - prev_action).sum()
+     reward += W_SMOOTH * max(0.0, 1.0 - action_delta * 3.0)
+
+     return float(reward)
+
+
+ def compute_episode_bonus(total_steps: int, survived: bool) -> float:
+     """End-of-episode bonus — same as openenv."""
+     if not survived:
+         return 0.0
+     bonus = W_SURVIVE_BONUS
+     bonus += min(total_steps, 500) * 0.02  # longevity reward
+     return float(bonus)
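As a quick sanity check on the weights above: a non-terminal step at target speed, lane center, no events, and an unchanged action earns the sum of the four base terms, and a surviving long episode additionally earns the capped longevity bonus. Pure arithmetic, with the constants copied from this file:

```python
# Constants copied from reward.py; arithmetic check only.
W_ALIVE, W_SPEED, W_LANE, W_SMOOTH = 0.40, 0.10, 0.15, 0.03
W_SURVIVE_BONUS = 5.0

# Best-case non-terminal step: every base term at its maximum.
best_step = W_ALIVE + W_SPEED + W_LANE + W_SMOOTH  # 0.68

# Episode bonus for surviving 600 steps; longevity reward caps at 500 steps.
episode_bonus = W_SURVIVE_BONUS + min(600, 500) * 0.02  # 15.0

print(best_step, episode_bonus)
```

So the "~+0.4/step" figure in the module docstring is the survival floor; a well-driven step can earn up to 0.68 before any event penalties.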