aparekh02 commited on
Commit
69d4a95
·
verified ·
1 Parent(s): 055871c

bundle overflow_env locally, drop openenv-core git dep (websockets conflict fix)

Browse files
app.py CHANGED
@@ -12,7 +12,7 @@ import torch
12
  import torch.optim as optim
13
  import gradio as gr
14
 
15
- from overflow_env.server.overflow_environment import OverflowEnvironment
16
  from overflow_env.models import OverflowAction
17
  from policies.flat_mlp_policy import FlatMLPPolicy
18
  from policies.ticket_attention_policy import TicketAttentionPolicy
 
12
  import torch.optim as optim
13
  import gradio as gr
14
 
15
+ from overflow_env.environment import OverflowEnvironment
16
  from overflow_env.models import OverflowAction
17
  from policies.flat_mlp_policy import FlatMLPPolicy
18
  from policies.ticket_attention_policy import TicketAttentionPolicy
overflow_env/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .environment import OverflowEnvironment
2
+ from .models import OverflowAction, OverflowObservation
overflow_env/environment.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Overflow Environment — standalone bundled version (no openenv.core dependency).
3
+ 2D road grid, 5 cars, 3 lanes. Car 0 is the RL agent.
4
+ """
5
+
6
+ import math
7
+ import random
8
+ import re
9
+ from dataclasses import dataclass
10
+ from typing import Any, List, Optional
11
+ from uuid import uuid4
12
+
13
+ from .models import (
14
+ CarStateData, LaneOccupancyData, OverflowAction,
15
+ OverflowObservation, OverflowState, Position, ProximityData,
16
+ )
17
+
18
+ NUM_LANES = 3
19
+ ROAD_LENGTH = 200
20
+ NUM_CARS = 5
21
+ MAX_STEPS = 100
22
+ CRASH_DISTANCE = 5.0
23
+ NEAR_MISS_DISTANCE = 15.0
24
+ LANE_WIDTH = 3.7
25
+
26
+ REWARD_CRASH = -5.0
27
+ REWARD_NEAR_MISS = -1.0
28
+ REWARD_SAFE_STEP = 0.5
29
+ REWARD_REACHED_GOAL = 3.0
30
+ REWARD_REASONING_MAX = 0.3
31
+
32
+ MIN_SPEED = 20
33
+ MAX_SPEED = 90
34
+ SPEED_DELTA = 5
35
+
36
+
37
@dataclass
class Car:
    """Mutable state of one vehicle on the road grid.

    Lanes are 1-based; ``position`` and ``goal_position`` are distances
    along the road.
    """
    car_id: int
    lane: int
    position: float
    speed: float
    goal_position: float
    is_agent: bool = False
    reached_goal: bool = False
    prev_speed: float = 0.0  # speed at the previous tick, used for acceleration

    def distance_to(self, other: "Car") -> float:
        """Euclidean distance where one lane of separation counts as 10 units."""
        dy = abs(self.lane - other.lane) * 10.0
        dx = abs(self.position - other.position)
        return math.sqrt(dx * dx + dy * dy)

    @property
    def acceleration(self) -> float:
        """Speed change since the previous tick."""
        return self.speed - self.prev_speed

    def to_state_data(self) -> CarStateData:
        """Serialize this car into the wire-format ``CarStateData`` model."""
        return CarStateData(
            carId=self.car_id,
            lane=self.lane,
            speed=self.speed,
            acceleration=self.acceleration,
            # y maps the 1-based lane index onto a lateral offset in meters.
            position=Position(x=self.position, y=self.lane * LANE_WIDTH),
        )
65
+
66
+
67
def _parse_decision(action: OverflowAction) -> str:
    """Extract one of the five canonical decisions from a free-form action.

    Resolution order: the normalized ``decision`` field itself, then an
    ``<action>...</action>`` tag in decision+reasoning, then the first known
    keyword found in that combined text; falls back to ``"maintain"``.
    """
    valid = {"accelerate", "brake", "lane_change_left", "lane_change_right", "maintain"}
    normalized = action.decision.strip().lower().replace(" ", "_")
    if normalized in valid:
        return normalized
    haystack = f"{action.decision} {action.reasoning}".lower()
    tagged = re.search(r"<action>\s*(\w+)\s*</action>", haystack)
    if tagged:
        tag_value = tagged.group(1).strip().replace(" ", "_")
        if tag_value in valid:
            return tag_value
    # Keywords are checked in priority order: lane changes before speed changes.
    for keyword in ["lane_change_left", "lane_change_right", "accelerate", "brake", "maintain"]:
        if keyword in haystack:
            return keyword
    return "maintain"
82
+
83
+
84
def _scripted_car_action(car: Car, all_cars: List[Car], rng: random.Random) -> str:
    """Pick a heuristic action for a non-agent car.

    Brakes when another car is less than 20 units ahead in the same lane;
    otherwise occasionally accelerates (when below speed 60) or changes lane.
    NOTE: the order and count of ``rng.random()`` draws below must not change,
    or seeded episodes stop being reproducible.
    """
    gap_ahead = float("inf")  # distance to the closest same-lane car ahead
    for neighbor in all_cars:
        if neighbor.car_id == car.car_id or neighbor.lane != car.lane:
            continue
        if neighbor.position > car.position:
            gap_ahead = min(gap_ahead, neighbor.position - car.position)
    if gap_ahead < 20:
        return "brake"
    if car.speed < 60 and rng.random() < 0.1:
        return "accelerate"
    if rng.random() < 0.05:
        # Left only when not already in lane 1 (no rng draw when lane == 1,
        # matching the short-circuit in the original condition).
        if car.lane > 1 and rng.random() < 0.5:
            return "lane_change_left"
        if car.lane < NUM_LANES:
            return "lane_change_right"
    return "maintain"
103
+
104
+
105
def _apply_action(car: Car, decision: str) -> None:
    """Mutate *car* in place according to *decision*.

    Speed is clamped to [MIN_SPEED, MAX_SPEED]; lane changes are ignored at
    the road edges. Unknown decisions (including "maintain") are no-ops.
    """
    if decision == "accelerate":
        car.speed = min(MAX_SPEED, car.speed + SPEED_DELTA)
        return
    if decision == "brake":
        car.speed = max(MIN_SPEED, car.speed - SPEED_DELTA)
        return
    if decision == "lane_change_left" and car.lane > 1:
        car.lane -= 1
        return
    if decision == "lane_change_right" and car.lane < NUM_LANES:
        car.lane += 1
116
+
117
+
118
def _generate_scene_description(agent_car: Car, cars: List[Car]) -> str:
    """Render a plain-text summary of the road state from the agent's viewpoint.

    Lists the agent's own state and goal, then one line per other car with
    same-lane ahead/behind annotations and a reached-goal marker.
    """
    parts = [
        f"You are Car 0 in lane {agent_car.lane}, position {agent_car.position:.0f}, speed {agent_car.speed:.0f}.",
        f"Goal: reach position {agent_car.goal_position:.0f}.",
        "Nearby cars:",
    ]
    for other in cars:
        if other.car_id == agent_car.car_id:
            continue
        entry = f"- Car {other.car_id}: lane {other.lane}, position {other.position:.0f}, speed {other.speed:.0f}"
        if other.lane == agent_car.lane:
            gap = other.position - agent_car.position
            if gap > 0:
                entry += f" [AHEAD IN YOUR LANE - {gap:.0f} units away]"
            else:
                entry += f" [BEHIND IN YOUR LANE - {abs(gap):.0f} units away]"
        if other.reached_goal:
            entry += " [REACHED GOAL]"
        parts.append(entry)
    return "\n".join(parts)
138
+
139
+
140
def _build_structured_data(cars: List[Car], proximity_pairs: List[ProximityData]):
    """Build wire-format car states plus per-lane occupancy lists.

    Cars that already reached their goal are excluded from lane occupancy.
    NOTE: *proximity_pairs* is accepted for signature compatibility but is
    not used by this function.
    """
    cars_data = [car.to_state_data() for car in cars]
    occupancy: dict = {}
    for car in cars:
        if car.reached_goal:
            continue
        occupancy.setdefault(car.lane, []).append(car.car_id)
    lane_occupancies = [
        LaneOccupancyData(lane=lane_no, carIds=ids)
        for lane_no, ids in sorted(occupancy.items())
    ]
    return cars_data, lane_occupancies
151
+
152
+
153
class OverflowEnvironment:
    """Standalone single-agent driving environment (no openenv.core dependency).

    Car 0 is the RL agent; the remaining cars follow the scripted policy in
    `_scripted_car_action`. State lives in `self._state` / `self._cars` and is
    reinitialized by `reset()`.
    """

    def __init__(self):
        self._state = OverflowState(episode_id=str(uuid4()))
        self._cars: List[Car] = []
        self._rng = random.Random()
        self._done = False
        self._last_obs: Optional[OverflowObservation] = None

    def _build_observation(self, incident_report: str, reward: float,
                           proximities: Optional[List[ProximityData]] = None) -> OverflowObservation:
        """Assemble an OverflowObservation from the current car/episode state.

        `proximities` defaults to an empty list (note `or` also replaces an
        explicitly passed empty list — same result either way).
        """
        agent = self._cars[0]
        scene = _generate_scene_description(agent, self._cars)
        prox = proximities or []
        cars_data, lane_occ = _build_structured_data(self._cars, prox)
        return OverflowObservation(
            scene_description=scene,
            incident_report=incident_report,
            done=self._done,
            reward=reward,
            cars=cars_data,
            proximities=prox,
            lane_occupancies=lane_occ,
        )

    def reset(self, seed: Optional[int] = None, **kwargs: Any) -> OverflowObservation:
        """Start a new episode and return the initial observation.

        A seed makes car placement/speeds reproducible; without one a fresh
        unseeded Random is used.
        """
        if seed is not None:
            self._rng = random.Random(seed)
        else:
            self._rng = random.Random()
        self._state = OverflowState(
            episode_id=str(uuid4()), step_count=0,
            crash_count=0, near_miss_count=0, cars_reached_goal=0, total_cars=NUM_CARS,
        )
        self._done = False
        self._cars = []
        for i in range(NUM_CARS):
            # Rejection-sample a spawn point that keeps at least twice the
            # crash distance from every already-placed car (up to 100 tries;
            # after that the last candidate is used as-is).
            for _attempt in range(100):
                lane = self._rng.randint(1, NUM_LANES)
                position = float(self._rng.randint(10, 80))
                too_close = False
                for existing in self._cars:
                    lane_diff = abs(lane - existing.lane) * 10.0
                    pos_diff = abs(position - existing.position)
                    if math.sqrt(lane_diff ** 2 + pos_diff ** 2) < CRASH_DISTANCE * 2:
                        too_close = True
                        break
                if not too_close:
                    break
            speed = float(self._rng.randint(40, 70))
            goal = float(self._rng.randint(160, 195))
            self._cars.append(Car(
                car_id=i, lane=lane, position=position, speed=speed,
                goal_position=goal, is_agent=(i == 0), prev_speed=speed,
            ))
        self._last_obs = self._build_observation(incident_report="", reward=0.0)
        return self._last_obs

    def step(self, action: OverflowAction, **kwargs: Any) -> OverflowObservation:
        """Advance the simulation one tick with the agent's action.

        Order: apply agent + scripted actions, move cars, detect crashes and
        near misses among active cars, then settle rewards/termination.
        Stepping a finished episode returns a zero-reward notice observation.
        """
        if self._done:
            return self._build_observation(
                incident_report="Episode is over. Call reset() to start a new one.", reward=0.0
            )
        self._state.step_count += 1
        reward = 0.0
        incidents = []

        # Snapshot speeds so Car.acceleration reflects this tick's change.
        for car in self._cars:
            car.prev_speed = car.speed

        decision = _parse_decision(action)
        _apply_action(self._cars[0], decision)

        for car in self._cars[1:]:
            if car.reached_goal:
                continue
            _apply_action(car, _scripted_car_action(car, self._cars, self._rng))

        # Kinematics: position advances by speed * 0.1 per tick.
        for car in self._cars:
            if not car.reached_goal:
                car.position += car.speed * 0.1

        agent_crash = False
        proximity_list: List[ProximityData] = []
        active_cars = [c for c in self._cars if not c.reached_goal]
        agent_id = self._cars[0].car_id

        # Pairwise proximity check over cars still on the road. Crash/near-miss
        # counters track all pairs; only agent-involved events affect reward.
        for i in range(len(active_cars)):
            for j in range(i + 1, len(active_cars)):
                dist = active_cars[i].distance_to(active_cars[j])
                involves_agent = (active_cars[i].car_id == agent_id or
                                  active_cars[j].car_id == agent_id)
                if dist < CRASH_DISTANCE:
                    self._state.crash_count += 1
                    proximity_list.append(ProximityData(
                        carA=active_cars[i].car_id, carB=active_cars[j].car_id,
                        distance=round(dist, 2),
                    ))
                    incidents.append(
                        f"CRASH between Car {active_cars[i].car_id} and Car {active_cars[j].car_id}! "
                        f"(distance: {dist:.1f})"
                    )
                    if involves_agent:
                        agent_crash = True
                elif dist < NEAR_MISS_DISTANCE:
                    self._state.near_miss_count += 1
                    if involves_agent:
                        reward += REWARD_NEAR_MISS
                    proximity_list.append(ProximityData(
                        carA=active_cars[i].car_id, carB=active_cars[j].car_id,
                        distance=round(dist, 2),
                    ))
                    incidents.append(
                        f"NEAR MISS between Car {active_cars[i].car_id} and Car {active_cars[j].car_id} "
                        f"(distance: {dist:.1f})"
                    )

        if agent_crash:
            # An agent crash ends the episode immediately (crash penalty only;
            # no safe-step bonus this tick).
            reward += REWARD_CRASH
            self._done = True
        else:
            agent = self._cars[0]
            if agent.position >= agent.goal_position:
                agent.reached_goal = True
                self._state.cars_reached_goal += 1
                reward += REWARD_REACHED_GOAL
                incidents.append(f"Car 0 reached its goal at position {agent.goal_position:.0f}!")
                self._done = True
            # Scripted cars can also finish; they just stop being simulated.
            for car in self._cars[1:]:
                if not car.reached_goal and car.position >= car.goal_position:
                    car.reached_goal = True
                    self._state.cars_reached_goal += 1
            if not self._done:
                reward += REWARD_SAFE_STEP

        if self._state.step_count >= MAX_STEPS and not self._done:
            self._done = True
            incidents.append(f"Maximum steps ({MAX_STEPS}) reached.")

        incident_report = "\n".join(incidents) if incidents else "Observer: No incidents this step."
        self._last_obs = self._build_observation(
            incident_report=incident_report, reward=reward, proximities=proximity_list,
        )
        return self._last_obs
overflow_env/models.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional
2
+ from pydantic import BaseModel, Field
3
+
4
+
5
class Position(BaseModel):
    """2D position: x is distance along the road, y the lateral (lane) offset."""
    x: float = 0.0
    y: float = 0.0
8
+
9
+
10
class CarStateData(BaseModel):
    """Wire-format snapshot of one car (camelCase keys kept for the client)."""
    carId: int
    lane: int
    position: Position
    speed: float
    acceleration: float = 0.0
16
+
17
+
18
class ProximityData(BaseModel):
    """A crash/near-miss event between the two cars, with their distance."""
    carA: int
    carB: int
    distance: float
22
+
23
+
24
class LaneOccupancyData(BaseModel):
    """IDs of the cars currently occupying one lane."""
    lane: int
    carIds: List[int]
27
+
28
+
29
class OverflowAction(BaseModel):
    """Agent action: a decision string plus optional free-form reasoning."""
    decision: str = Field(default="maintain")
    reasoning: str = Field(default="")
32
+
33
+
34
class OverflowObservation(BaseModel):
    """Per-step observation: reward/done flags plus textual and structured views."""
    done: bool = False
    reward: float = 0.0
    scene_description: str = ""
    incident_report: str = ""
    cars: List[CarStateData] = Field(default_factory=list)
    proximities: List[ProximityData] = Field(default_factory=list)
    lane_occupancies: List[LaneOccupancyData] = Field(default_factory=list)
42
+
43
+
44
class OverflowState(BaseModel):
    """Episode-level counters tracked by the environment across steps."""
    episode_id: str = ""
    step_count: int = 0
    crash_count: int = 0
    near_miss_count: int = 0
    cars_reached_goal: int = 0
    total_cars: int = 5
requirements.txt CHANGED
@@ -2,7 +2,5 @@
2
  torch==2.5.1+cpu
3
  numpy>=1.24.0
4
  pillow==10.4.0
5
- gradio>=4.44.0
6
  pydantic>=2.0.0
7
  requests>=2.31.0
8
- openenv-overflow-env @ git+https://huggingface.co/spaces/SteveDusty/overflow_env
 
2
  torch==2.5.1+cpu
3
  numpy>=1.24.0
4
  pillow==10.4.0
 
5
  pydantic>=2.0.0
6
  requests>=2.31.0