traffic-env / models.py
anidoesdev's picture
Upload folder using huggingface_hub
efe18c4 verified
Raw
History Blame Contribute Delete
3.72 kB
"""
models.py β€” Typed data contracts for the Traffic Flow environment.
Every piece of data the environment sends or receives is defined here
as a Pydantic model. This gives us:
- Automatic validation (bad data raises an error immediately)
- Auto-generated JSON schema (FastAPI uses this for /docs)
- Clean type hints throughout the codebase
Three models you MUST have for OpenEnv spec compliance:
TrafficAction β†’ what the agent sends IN (input)
TrafficObservation β†’ what the environment sends OUT after each step
TrafficState β†’ episode-level metadata returned by state()
"""
from typing import List, Optional, Dict
from pydantic import BaseModel, Field
from enum import Enum
# ----------------------------
# ACTION SPACE
# What the agent can do. Kept deliberately small so the
# agent has fewer choices to learn β€” best practice for RL.
# ----------------------------
class ActionType(str, Enum):
EXTEND_GREEN = "extend_green" # Hold current green phase 5 extra seconds
NEXT_PHASE = "next_phase" # Switch to next phase immediately
class TrafficAction(BaseModel):
"""
One action the agent takes per step.
action_type : extend_green | next_phase
intersection_id : which intersection to control (0-indexed)
easy=0 only, medium=0-2, hard=0-8
"""
action_type: ActionType = Field(
...,
description="extend_green = +5s on current green; next_phase = switch now",
)
intersection_id: int = Field(
default=0,
ge=0,
description="Index of the intersection to control",
)
# ----------------------------
# OBSERVATION SPACE
# What the agent sees. Mirrors real road-sensor output:
# queue lengths and wait times per lane β€” nothing more.
# ----------------------------
class LaneState(BaseModel):
"""Per-lane sensor reading (one road-camera / loop-detector)."""
lane_id: str
queue_length: int = Field(..., ge=0, description="Cars waiting at red")
avg_wait_time: float = Field(..., ge=0.0, description="Average wait time (seconds)")
flow_rate: float = Field(..., ge=0.0, description="Max possible veh/min on green")
class IntersectionState(BaseModel):
"""All sensor readings for one intersection."""
intersection_id: int
current_phase: int = Field(..., description="Active signal phase (0–3)")
phase_elapsed: float = Field(..., description="Seconds spent in this phase so far")
lanes: List[LaneState]
class TrafficObservation(BaseModel):
"""
Returned by reset() and step().
Everything the agent needs to make its next decision.
"""
intersections: List[IntersectionState]
total_waiting_vehicles: int
total_avg_wait: float = Field(..., description="Fleet-wide average wait (seconds)")
throughput_last_step: int = Field(..., description="Vehicles cleared this step")
reward: float = Field(..., description="Step reward signal")
done: bool = False
info: Dict[str, float] = Field(default_factory=dict)
# ----------------------------
# EPISODE STATE
# Returned by state(). NOT the same as observation β€”
# this is bookkeeping metadata for the training framework.
# ----------------------------
class TrafficState(BaseModel):
"""Episode-level metadata returned by state()."""
episode_id: Optional[str] = None
step_count: int = 0
task_level: str = Field("easy", description="easy | medium | hard")
cumulative_reward: float = 0.0
cumulative_throughput: int = 0
cumulative_wait: float = 0.0