File size: 17,258 Bytes
cb054fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
"""
Overflow Environment Implementation.

A 2D road grid with N cars. One car (Car 0) is the LLM agent, others follow
scripted rules. An observer checks for collisions each step. The environment
returns text observations describing the traffic scene and rewards based on safety.

Observations carry both text (for the LLM) and structured data (for the frontend).
"""

import math
import random
import re
from dataclasses import dataclass, field
from typing import Any, List, Optional
from uuid import uuid4

try:
    from openenv.core.env_server.interfaces import Environment
    from openenv.core.env_server.types import State
except ImportError:
    class Environment:  # stub for training-only mode
        pass
    class State:
        pass

try:
    from ..models import (
        CarStateData, LaneOccupancyData, OverflowAction,
        OverflowObservation, OverflowState, Position, ProximityData,
    )
    from ..policies.flat_mlp_policy import FlatMLPPolicy
    from ..policies.ticket_attention_policy import TicketAttentionPolicy
    from ..policies.policy_spec import OBS_DIM
    from .policy_adapter import overflow_obs_to_policy_obs, policy_action_to_decision
except ImportError:
    from models import (
        CarStateData, LaneOccupancyData, OverflowAction,
        OverflowObservation, OverflowState, Position, ProximityData,
    )
    from policies.flat_mlp_policy import FlatMLPPolicy
    from policies.ticket_attention_policy import TicketAttentionPolicy
    from policies.policy_spec import OBS_DIM
    from server.policy_adapter import overflow_obs_to_policy_obs, policy_action_to_decision

# --- Constants ---
NUM_LANES = 3  # lanes are 1-indexed: 1..NUM_LANES
ROAD_LENGTH = 200  # NOTE(review): not referenced in this file — goals come from randint(160, 195); confirm intent
NUM_CARS = 5  # total cars per episode; Car 0 is the LLM agent
MAX_STEPS = 100  # episode is forced done after this many step() calls
CRASH_DISTANCE = 5.0  # pairwise distance below which a crash is declared
NEAR_MISS_DISTANCE = 15.0  # pairwise distance below which a near miss is logged
LANE_WIDTH = 3.7  # metres — matches frontend's makeCar convention

# Reward values (per step; see OverflowEnvironment.step for how they combine)
REWARD_CRASH = -5.0  # applied once when the agent is in a crash; ends the episode
REWARD_NEAR_MISS = -1.0  # applied per near-miss pair involving the agent
REWARD_SAFE_STEP = 0.5  # small bonus for each incident-free, still-active step
REWARD_REACHED_GOAL = 3.0  # applied once when Car 0 crosses its goal position
REWARD_REASONING_MAX = 0.3  # cap for _compute_reasoning_bonus

# Speed bounds (units/step before the 0.1 movement scale factor in step())
MIN_SPEED = 20
MAX_SPEED = 90
SPEED_DELTA = 5  # speed change per accelerate/brake decision


@dataclass
class Car:
    """A single vehicle in the simulated road grid."""

    car_id: int
    lane: int  # 1-indexed: 1, 2, or 3
    position: float
    speed: float
    goal_position: float
    is_agent: bool = False
    reached_goal: bool = False
    prev_speed: float = 0.0  # speed last step, for acceleration calc

    def distance_to(self, other: "Car") -> float:
        """Pseudo-Euclidean separation; adjacent lanes count as 10 units apart."""
        dy = abs(self.lane - other.lane) * 10.0
        dx = abs(self.position - other.position)
        return math.sqrt(dy ** 2 + dx ** 2)

    @property
    def acceleration(self) -> float:
        """Change in speed since the previous simulation step."""
        return self.speed - self.prev_speed

    def to_state_data(self) -> "CarStateData":
        """Serialize this car into the frontend's CarStateData schema."""
        return CarStateData(
            carId=self.car_id,
            lane=self.lane,
            # y is derived from the lane index using the frontend's lane width.
            position=Position(x=self.position, y=self.lane * LANE_WIDTH),
            speed=self.speed,
            acceleration=self.acceleration,
        )


def _parse_decision(action: OverflowAction) -> str:
    """Extract a valid decision from the action, being forgiving about format."""
    valid = {"accelerate", "brake", "lane_change_left", "lane_change_right", "maintain"}

    # Try the decision field directly
    decision = action.decision.strip().lower().replace(" ", "_")
    if decision in valid:
        return decision

    # Try to extract from free text (the LLM might wrap it in tags)
    text = f"{action.decision} {action.reasoning}".lower()

    # Check for <action>...</action> tags
    match = re.search(r"<action>\s*(\w+)\s*</action>", text)
    if match:
        candidate = match.group(1).strip().replace(" ", "_")
        if candidate in valid:
            return candidate

    # Check for keywords anywhere (ordered: most specific first to avoid ambiguity)
    for v in ["lane_change_left", "lane_change_right", "accelerate", "brake", "maintain"]:
        if v in text:
            return v

    return "maintain"


def _compute_reasoning_bonus(reasoning: str) -> float:
    """
    Compute a small reasoning quality bonus (0.0 to 0.3).

    Gives a minor reward for providing structured reasoning, kept low
    so driving performance remains the dominant training signal.
    """
    if not reasoning:
        return 0.0

    score = 0.0
    lower = reasoning.lower()

    # Small bonus for providing any reasoning at all
    if len(reasoning) > 20:
        score += 0.1

    # Bonus for structured reasoning (not just keyword stuffing)
    if "<think>" in lower or "because" in lower:
        score += 0.1
    if any(word in lower for word in ["therefore", "so i should", "best option", "i will"]):
        score += 0.1

    return min(score, REWARD_REASONING_MAX)


def _scripted_car_action(car: Car, all_cars: List[Car], rng: random.Random) -> str:
    """
    Simple scripted AI for non-agent cars.

    Rules:
    - If car ahead in same lane is close (< 20 units): brake
    - If speed is low and random chance: accelerate
    - Otherwise: maintain
    """
    # Find nearest car ahead in same lane
    nearest_ahead_dist = float("inf")
    for other in all_cars:
        if other.car_id == car.car_id:
            continue
        if other.lane == car.lane and other.position > car.position:
            dist = other.position - car.position
            if dist < nearest_ahead_dist:
                nearest_ahead_dist = dist

    if nearest_ahead_dist < 20:
        return "brake"

    if car.speed < 60 and rng.random() < 0.1:
        return "accelerate"

    # Occasionally change lanes to make traffic more dynamic
    if rng.random() < 0.05:
        if car.lane > 1 and rng.random() < 0.5:
            return "lane_change_left"
        elif car.lane < NUM_LANES:
            return "lane_change_right"

    return "maintain"


def _apply_action(car: Car, decision: str) -> None:
    """Apply a driving decision to a car, mutating it in place."""
    if decision == "accelerate":
        car.speed = min(car.speed + SPEED_DELTA, MAX_SPEED)
    elif decision == "brake":
        car.speed = max(car.speed - SPEED_DELTA, MIN_SPEED)
    elif decision == "lane_change_left":
        if car.lane > 1:
            car.lane -= 1
    elif decision == "lane_change_right":
        if car.lane < NUM_LANES:
            car.lane += 1
    # "maintain" — no change


def _generate_scene_description(agent_car: Car, cars: List[Car]) -> str:
    """Generate a text description of the current traffic scene."""
    lines = [
        f"You are Car 0 in lane {agent_car.lane}, position {agent_car.position:.0f}, speed {agent_car.speed:.0f}.",
        f"Goal: reach position {agent_car.goal_position:.0f}.",
        "Nearby cars:",
    ]

    for car in cars:
        if car.car_id == agent_car.car_id:
            continue

        detail = f"- Car {car.car_id}: lane {car.lane}, position {car.position:.0f}, speed {car.speed:.0f}"

        # Add context about relative position
        if car.lane == agent_car.lane:
            pos_diff = car.position - agent_car.position
            if pos_diff > 0:
                detail += f" [AHEAD IN YOUR LANE - {pos_diff:.0f} units away]"
            else:
                detail += f" [BEHIND IN YOUR LANE - {abs(pos_diff):.0f} units away]"

        if car.reached_goal:
            detail += " [REACHED GOAL]"

        lines.append(detail)

    return "\n".join(lines)


def _build_structured_data(
    cars: List[Car],
    proximity_pairs: List[ProximityData],
) -> tuple[List[CarStateData], List[LaneOccupancyData]]:
    """Build structured arrays for the observation."""
    cars_data = [c.to_state_data() for c in cars]

    # Lane occupancies
    lane_map: dict[int, list[int]] = {}
    for car in cars:
        if not car.reached_goal:
            lane_map.setdefault(car.lane, []).append(car.car_id)
    lane_occupancies = [
        LaneOccupancyData(lane=lane, carIds=ids)
        for lane, ids in sorted(lane_map.items())
    ]

    return cars_data, lane_occupancies


class OverflowEnvironment(Environment):
    """
    Autonomous vehicle fleet oversight environment.

    A 2D road grid with N cars. Car 0 is the LLM agent, others follow
    scripted rules. The observer detects crashes and near-misses and
    computes rewards based on safety.
    """

    def __init__(self):
        super().__init__()
        # Episode-level bookkeeping (step/crash/near-miss counters).
        self._state = OverflowState(episode_id=str(uuid4()))
        self._cars: List[Car] = []
        self._rng = random.Random()
        self._done = False
        # Cached last observation; required by the "policy:" intercept in step().
        self._last_obs: Optional[OverflowObservation] = None
        # Neural policies selectable at runtime via action.decision ==
        # "policy:<name>" instead of a literal driving decision.
        self._policies = {
            "flat_mlp":         FlatMLPPolicy(obs_dim=OBS_DIM),
            "ticket_attention":  TicketAttentionPolicy(obs_dim=OBS_DIM),
        }

    def _build_observation(
        self,
        incident_report: str,
        reward: float,
        proximities: Optional[List[ProximityData]] = None,
    ) -> OverflowObservation:
        """Build a full observation with text + structured data.

        Args:
            incident_report: Observer text describing crashes/near-misses.
            reward: Reward computed for the step producing this observation.
            proximities: Close-pair records from collision detection; an
                explicitly-passed empty list is treated the same as None.
        """
        agent = self._cars[0]
        scene = _generate_scene_description(agent, self._cars)
        prox = proximities or []
        cars_data, lane_occ = _build_structured_data(self._cars, prox)

        return OverflowObservation(
            scene_description=scene,
            incident_report=incident_report,
            done=self._done,
            reward=reward,
            cars=cars_data,
            proximities=prox,
            lane_occupancies=lane_occ,
        )

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> OverflowObservation:
        """Reset the environment: create road and spawn cars.

        Args:
            seed: Optional RNG seed for reproducible episodes; when None a
                fresh unseeded RNG is used.
            episode_id: Optional external episode identifier; a UUID is
                generated when omitted.

        Returns:
            The initial observation (reward 0.0, empty incident report).
        """
        if seed is not None:
            self._rng = random.Random(seed)
        else:
            self._rng = random.Random()

        self._state = OverflowState(
            episode_id=episode_id or str(uuid4()),
            step_count=0,
            crash_count=0,
            near_miss_count=0,
            cars_reached_goal=0,
            total_cars=NUM_CARS,
        )
        self._done = False

        # Spawn cars with random positions, speeds, lanes, and goals
        self._cars = []

        for i in range(NUM_CARS):
            # Ensure no two cars spawn within crash distance.
            # NOTE: if all 100 attempts are too close, the last candidate
            # lane/position is used anyway (best-effort placement).
            for _attempt in range(100):
                lane = self._rng.randint(1, NUM_LANES)
                position = float(self._rng.randint(10, 80))
                too_close = False
                for existing in self._cars:
                    # Same distance metric as Car.distance_to (lanes ~10 apart).
                    lane_diff = abs(lane - existing.lane) * 10.0
                    pos_diff = abs(position - existing.position)
                    dist = math.sqrt(lane_diff**2 + pos_diff**2)
                    if dist < CRASH_DISTANCE * 2:
                        too_close = True
                        break
                if not too_close:
                    break

            speed = float(self._rng.randint(40, 70))
            goal = float(self._rng.randint(160, 195))

            self._cars.append(
                Car(
                    car_id=i,
                    lane=lane,
                    position=position,
                    speed=speed,
                    goal_position=goal,
                    is_agent=(i == 0),
                    prev_speed=speed,  # no delta on first step
                )
            )

        self._last_obs = self._build_observation(incident_report="", reward=0.0)
        return self._last_obs

    def step(
        self,
        action: OverflowAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> OverflowObservation:
        """Execute one simulation step.

        Order of operations: apply agent action, apply scripted actions,
        move cars, detect collisions, check goals, then assemble the reward
        (crash/near-miss penalties, safe-step and goal bonuses, plus a small
        reasoning-quality bonus).
        """
        if self._done:
            # Terminal no-op: keep returning a benign observation until reset().
            return self._build_observation(
                incident_report="Episode is over. Call reset() to start a new one.",
                reward=0.0,
            )

        # Policy intercept: decision="policy:flat_mlp" or "policy:ticket_attention".
        # An unknown policy name falls through and is parsed as a literal decision.
        if action.decision.startswith("policy:") and self._last_obs is not None:
            policy_name = action.decision.split(":", 1)[1].lower()
            if policy_name in self._policies:
                obs_vec = overflow_obs_to_policy_obs(self._last_obs)
                act_vec = self._policies[policy_name].predict(obs_vec)
                decision, reasoning = policy_action_to_decision(act_vec)
                action = OverflowAction(
                    decision=decision,
                    reasoning=f"[{policy_name}] {reasoning}",
                )

        self._state.step_count += 1
        reward = 0.0
        incidents = []

        # Snapshot previous speeds for acceleration tracking
        for car in self._cars:
            car.prev_speed = car.speed

        # 1. Parse and apply the agent's action to Car 0
        decision = _parse_decision(action)
        _apply_action(self._cars[0], decision)

        # 2. Compute and apply scripted actions for Cars 1-N
        for car in self._cars[1:]:
            if car.reached_goal:
                continue
            scripted_decision = _scripted_car_action(car, self._cars, self._rng)
            _apply_action(car, scripted_decision)

        # 3. Move all cars forward based on speed (speed is in units/step, scaled down)
        for car in self._cars:
            if car.reached_goal:
                continue
            car.position += car.speed * 0.1  # scale factor for reasonable movement

        # 4. Collision detection (pairwise, active cars only).
        # NOTE(review): a pair that stays within the threshold increments
        # crash_count / near_miss_count again on every step — confirm this
        # repeated counting is intended.
        agent_crash = False
        proximity_list: List[ProximityData] = []
        active_cars = [c for c in self._cars if not c.reached_goal]
        agent_id = self._cars[0].car_id
        for i in range(len(active_cars)):
            for j in range(i + 1, len(active_cars)):
                dist = active_cars[i].distance_to(active_cars[j])
                involves_agent = active_cars[i].car_id == agent_id or active_cars[j].car_id == agent_id
                if dist < CRASH_DISTANCE:
                    self._state.crash_count += 1
                    proximity_list.append(
                        ProximityData(
                            carA=active_cars[i].car_id,
                            carB=active_cars[j].car_id,
                            distance=round(dist, 2),
                        )
                    )
                    incidents.append(
                        f"CRASH between Car {active_cars[i].car_id} and Car {active_cars[j].car_id}! "
                        f"(distance: {dist:.1f})"
                    )
                    if involves_agent:
                        agent_crash = True
                elif dist < NEAR_MISS_DISTANCE:
                    self._state.near_miss_count += 1
                    # Only penalize near misses involving the agent
                    if involves_agent:
                        reward += REWARD_NEAR_MISS
                    proximity_list.append(
                        ProximityData(
                            carA=active_cars[i].car_id,
                            carB=active_cars[j].car_id,
                            distance=round(dist, 2),
                        )
                    )
                    incidents.append(
                        f"NEAR MISS between Car {active_cars[i].car_id} and Car {active_cars[j].car_id} "
                        f"(distance: {dist:.1f})"
                    )

        if agent_crash:
            # An agent crash ends the episode; goal/safe-step bonuses are skipped.
            reward += REWARD_CRASH
            self._done = True
        else:
            # 5. Goal check for agent car
            agent = self._cars[0]
            if agent.position >= agent.goal_position:
                agent.reached_goal = True
                self._state.cars_reached_goal += 1
                reward += REWARD_REACHED_GOAL
                incidents.append(
                    f"Car 0 reached its goal at position {agent.goal_position:.0f}!"
                )
                self._done = True

            # Check goal for scripted cars too (for state tracking)
            for car in self._cars[1:]:
                if not car.reached_goal and car.position >= car.goal_position:
                    car.reached_goal = True
                    self._state.cars_reached_goal += 1

            # 6. Safe step bonus (no crash, agent still active)
            if not self._done:
                reward += REWARD_SAFE_STEP

        # 7. Reasoning quality bonus (0.0-0.3; awarded even on a crash step)
        reasoning_bonus = _compute_reasoning_bonus(action.reasoning)
        reward += reasoning_bonus

        # 8. Max steps check
        if self._state.step_count >= MAX_STEPS and not self._done:
            self._done = True
            incidents.append(f"Maximum steps ({MAX_STEPS}) reached.")

        incident_report = (
            "\n".join(incidents) if incidents else "Observer: No incidents this step."
        )

        self._last_obs = self._build_observation(
            incident_report=incident_report,
            reward=reward,
            proximities=proximity_list,
        )
        return self._last_obs

    @property
    def state(self) -> OverflowState:
        """Get the current environment state (counters and episode metadata)."""
        return self._state