File size: 3,885 Bytes
cb054fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
"""
Data models for the Overflow Environment.

An autonomous vehicle fleet oversight environment where an LLM agent
controls one car on a 2D road grid while other cars follow scripted rules.

Structured observation fields (cars, proximities, lane_occupancies) are
compatible with the Overflow frontend's CarState / AnomalyObservation types.
"""

from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Field

try:
    from openenv.core.env_server.types import Action, Observation, State
except ImportError:
    # The openenv package could not be imported: declare minimal local
    # stand-ins so the models in this module can still be defined.

    class Action(BaseModel):
        """Fallback action base class; declares no fields of its own."""

    class Observation(BaseModel):
        """Fallback observation base class with ``done`` and ``reward``."""

        done: bool = False
        reward: float = 0.0

    class State(BaseModel):
        """Fallback state base class with an episode id and step counter."""

        episode_id: str = ""
        step_count: int = 0

# ── Structured sub-models (frontend-compatible) ─────────────────────────


class Position(BaseModel):
    """
    A point on the 2D road.

    ``x`` is the longitudinal coordinate and ``y`` is the lateral one.
    Both default to ``0.0`` when not supplied.
    """

    x: float = Field(default=0.0)
    y: float = Field(default=0.0)


class CarStateData(BaseModel):
    """
    Snapshot of a single car — mirrors the frontend ``CarState`` interface.

    The field names are camelCase on purpose so they line up with the
    frontend type:

        interface CarState {
            carId: number;
            lane: number;
            position: { x: number; y: number };
            speed: number;
            acceleration: number;
        }

    ``acceleration`` is optional here and falls back to ``0.0``; every other
    field must be provided.
    """

    carId: int
    lane: int
    position: Position
    speed: float
    acceleration: float = Field(default=0.0)


class ProximityData(BaseModel):
    """
    Distance measurement for one pair of cars.

    All three fields are required.

    Attributes:
        carA: Integer id of the first car in the pair (presumably a
            ``CarStateData.carId`` -- confirm against the producer).
        carB: Integer id of the second car in the pair.
        distance: Distance between the two cars; units are not specified
            in this module.
    """

    carA: int
    carB: int
    distance: float


class LaneOccupancyData(BaseModel):
    """
    Occupancy record for a single lane.

    Both fields are required.

    Attributes:
        lane: Integer lane number this record describes.
        carIds: Ids of the cars located in that lane (presumably
            ``CarStateData.carId`` values -- confirm against the producer).
    """

    lane: int
    carIds: List[int]


# ── OpenEnv core models ─────────────────────────────────────────────────


class OverflowAction(Action):
    """
    One step's action submitted to the Overflow environment.

    Carries the agent's driving decision plus free-text reasoning. Both
    fields are optional: ``decision`` falls back to ``"maintain"`` and
    ``reasoning`` to the empty string.

    Note that ``decision`` is typed as a plain ``str``; this model does not
    itself restrict it to the values listed in the field description.
    """

    decision: str = Field(
        "maintain",
        description=(
            "Driving decision: accelerate, brake, lane_change_left, "
            "lane_change_right, maintain"
        ),
    )
    reasoning: str = Field(
        "",
        description="The LLM's chain-of-thought reasoning for this decision",
    )


class OverflowObservation(Observation):
    """
    What the Overflow environment reports back after a step.

    Two groups of fields are carried side by side:

    - Text (``scene_description``, ``incident_report``) intended to be read
      by the LLM.
    - Structured data (``cars``, ``proximities``, ``lane_occupancies``)
      intended for the frontend to render, following the Overflow frontend
      AnomalyObservation shape.

    Every field has a default (empty string or empty list), so an
    observation can be built without supplying any of them.
    """

    # Text fields, read by the LLM.
    scene_description: str = Field(
        "",
        description="Text description of the traffic scene",
    )
    incident_report: str = Field(
        "",
        description="Observer's incident report, empty if no incident",
    )

    # Structured fields, consumed by the frontend / visualisation.
    # default_factory gives each instance its own fresh list.
    cars: List[CarStateData] = Field(
        description="Structured state of every car",
        default_factory=list,
    )
    proximities: List[ProximityData] = Field(
        description="Pairwise proximity measurements",
        default_factory=list,
    )
    lane_occupancies: List[LaneOccupancyData] = Field(
        description="Per-lane vehicle occupancy",
        default_factory=list,
    )


class OverflowState(State):
    """
    Per-episode bookkeeping kept by the Overflow environment.

    Holds integer counters for crashes, near misses and cars that reached
    their goal (each starting at ``0``), plus the total number of cars in
    the simulation (``5`` unless overridden).
    """

    crash_count: int = Field(
        0,
        description="Number of crashes this episode",
    )
    near_miss_count: int = Field(
        0,
        description="Number of near misses this episode",
    )
    cars_reached_goal: int = Field(
        0,
        description="Number of cars that reached their goal",
    )
    total_cars: int = Field(
        5,
        description="Total number of cars in the simulation",
    )