Spaces:
Sleeping
Sleeping
"""
models.py — Typed data contracts for the Traffic Flow environment.

Every piece of data the environment sends or receives is defined here
as a Pydantic model. This gives us:
  - Automatic validation (bad data raises an error immediately)
  - Auto-generated JSON schema (FastAPI uses this for /docs)
  - Clean type hints throughout the codebase

Three models you MUST have for OpenEnv spec compliance:
  TrafficAction      — what the agent sends IN (input)
  TrafficObservation — what the environment sends OUT after each step
  TrafficState       — episode-level metadata returned by state()
"""
| from typing import List, Optional, Dict | |
| from pydantic import BaseModel, Field | |
| from enum import Enum | |
# ----------------------------
# ACTION SPACE
# What the agent can do. Kept deliberately small so the
# agent has fewer choices to learn — best practice for RL.
# ----------------------------
class ActionType(str, Enum):
    """Discrete signal-control actions available to the agent.

    Because the enum mixes in ``str``, each member compares equal to its
    lowercase string value (e.g. ``ActionType.NEXT_PHASE == "next_phase"``).
    """

    # Keep the current green phase for 5 more seconds.
    EXTEND_GREEN = "extend_green"

    # Advance to the next signal phase right away.
    NEXT_PHASE = "next_phase"
class TrafficAction(BaseModel):
    """
    A single control decision submitted by the agent on each step.

    Fields:
        action_type     -- ``extend_green`` or ``next_phase``.
        intersection_id -- 0-indexed intersection to act on
                           (easy: 0 only; medium: 0-2; hard: 0-8).

    Only the lower bound (>= 0) of ``intersection_id`` is validated here;
    the per-task upper bound is not checked by this model.
    """

    # Required: there is no default, so every action must name its type.
    action_type: ActionType = Field(
        ...,
        description="extend_green = +5s on current green; next_phase = switch now",
    )

    # Defaults to the first intersection, which is the only one on "easy".
    intersection_id: int = Field(
        description="Index of the intersection to control",
        default=0,
        ge=0,
    )
# ----------------------------
# OBSERVATION SPACE
# What the agent sees. Mirrors real road-sensor output:
# per-lane queue lengths, wait times and flow rates, plus each
# intersection's current signal phase — nothing more.
# ----------------------------
class LaneState(BaseModel):
    """Sensor snapshot for a single lane (one road camera / loop detector).

    Every numeric reading is validated to be non-negative.
    """

    lane_id: str

    queue_length: int = Field(
        ..., description="Cars waiting at red", ge=0
    )
    avg_wait_time: float = Field(
        ..., description="Average wait time (seconds)", ge=0.0
    )
    flow_rate: float = Field(
        ..., description="Max possible veh/min on green", ge=0.0
    )
class IntersectionState(BaseModel):
    """All sensor readings for one intersection, plus its signal status.

    Attributes:
        intersection_id: 0-indexed id of this intersection. Validated as
            non-negative, like ``TrafficAction.intersection_id``.
        current_phase: Index of the active signal phase. Validated as
            non-negative.
        phase_elapsed: Seconds spent in the current phase so far.
            Validated as non-negative.
        lanes: Per-lane sensor readings for this intersection.
    """

    intersection_id: int = Field(..., ge=0)

    # NOTE(review): the description documents a 0-3 range, but only the
    # lower bound is enforced. Add ``le=3`` once it is confirmed that every
    # task level uses exactly four phases.
    current_phase: int = Field(
        ..., ge=0, description="Active signal phase (0-3)"
    )

    # A negative elapsed time can only come from a simulator bug, so
    # reject it at the boundary, as LaneState does for its readings.
    phase_elapsed: float = Field(
        ..., ge=0.0, description="Seconds spent in this phase so far"
    )

    lanes: List[LaneState]
class TrafficObservation(BaseModel):
    """
    Returned by reset() and step().

    Everything the agent needs to make its next decision: per-intersection
    sensor data, network-wide aggregates, the step reward and the ``done``
    flag.

    The count and duration aggregates are validated as non-negative,
    matching the constraints on the per-lane ``LaneState`` readings. A
    simulator bug therefore surfaces as a validation error instead of a
    silently corrupted observation.
    """

    intersections: List[IntersectionState]

    total_waiting_vehicles: int = Field(..., ge=0)
    total_avg_wait: float = Field(
        ..., ge=0.0, description="Fleet-wide average wait (seconds)"
    )
    throughput_last_step: int = Field(
        ..., ge=0, description="Vehicles cleared this step"
    )

    # Deliberately left without a sign constraint; the reward is not
    # restricted to non-negative values.
    reward: float = Field(..., description="Step reward signal")

    done: bool = False

    # default_factory gives each instance its own, unshared empty dict.
    info: Dict[str, float] = Field(default_factory=dict)
# ----------------------------
# EPISODE STATE
# Returned by state(). NOT the same as the observation —
# this is bookkeeping metadata for the training framework.
# ----------------------------
class TrafficState(BaseModel):
    """Bookkeeping for the current episode, as reported by state().

    Unlike an observation, this is metadata for the training framework:
    an episode identifier, progress counters and cumulative totals.
    Every field has a default, so ``TrafficState()`` is valid on its own.
    """

    # --- identity and progress ---
    episode_id: Optional[str] = Field(default=None)
    step_count: int = Field(default=0)
    task_level: str = Field(default="easy", description="easy | medium | hard")

    # --- cumulative totals (all start at zero) ---
    cumulative_reward: float = Field(default=0.0)
    cumulative_throughput: int = Field(default=0)
    cumulative_wait: float = Field(default=0.0)