traffic_light_env / models.py
rishabh16196's picture
Upload folder using huggingface_hub
5920175 verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Data models for the Traffic Light Environment.
A single 4-way intersection with four traffic-flow directions:
NS (north-to-south), SN (south-to-north),
EW (east-to-west), WE (west-to-east).
Each direction has 2 lanes, giving 8 lanes total.
Each lane has two observation zones: 100 m and 500 m from the intersection.
Traffic lights have 3 states per direction: red (0), yellow (1), green (2).
The agent selects one of 6 phases that determine which directions get green.
"""
from typing import Any, Dict, List, Optional
from openenv.core.env_server.types import Action, Observation
from pydantic import Field
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

# Light states are consecutive integers: red=0, yellow=1, green=2.
LIGHT_RED, LIGHT_YELLOW, LIGHT_GREEN = range(3)

# Traffic-flow directions; a direction's index is its position in this list.
DIRECTION_NAMES = ["ns", "sn", "ew", "we"]
(
    DIR_NS,  # 0: North → South
    DIR_SN,  # 1: South → North
    DIR_EW,  # 2: East → West
    DIR_WE,  # 3: West → East
) = range(len(DIRECTION_NAMES))

NUM_DIRECTIONS = len(DIRECTION_NAMES)  # 4
LANES_PER_DIRECTION = 2
NUM_LANES = LANES_PER_DIRECTION * NUM_DIRECTIONS  # 8

# Phase definitions — which directions get green.
# A phase's id is its position in this list.
PHASE_NAMES = [
    "ns_sn_corridor",  # 0: full north-south corridor
    "ew_we_corridor",  # 1: full east-west corridor
    "ns_only",  # 2: north-to-south only
    "sn_only",  # 3: south-to-north only
    "ew_only",  # 4: east-to-west only
    "we_only",  # 5: west-to-east only
]
NUM_PHASES = len(PHASE_NAMES)  # 6
# ---------------------------------------------------------------------------
# Vehicle types with real-world physics
# ---------------------------------------------------------------------------
# Stopping distance = d_reaction + d_braking
#   d_reaction = speed_ms × reaction_time_s
#   d_braking  = speed_ms² / (2 × deceleration_ms2)
# Assumes dry urban road (friction coefficient ≈ 0.7).

# Property names, in the column order used by the rows below.
_VEHICLE_FIELDS = (
    "speed_kmh",
    "reaction_time_s",
    "deceleration_ms2",
    "spawn_weight",
)

# One row per vehicle type: (name, speed_kmh, reaction_time_s,
# deceleration_ms2, spawn_weight).
_VEHICLE_ROWS = (
    ("car", 50.0, 1.0, 6.8, 0.40),  # standard passenger car, dry road
    ("suv", 50.0, 1.2, 6.0, 0.25),  # higher center of gravity, slightly worse braking
    ("bus", 40.0, 1.5, 4.5, 0.10),  # heavy, pneumatic brakes, longer reaction
    ("truck", 45.0, 1.4, 4.0, 0.15),  # heaviest, longest stopping distance
    ("motorcycle", 55.0, 0.8, 7.5, 0.10),  # light, good brakes, alert rider
)

# Maps vehicle type name -> {property name -> value}; insertion order follows
# the row order above.
VEHICLE_TYPES: Dict[str, Dict[str, float]] = {
    name: dict(zip(_VEHICLE_FIELDS, values)) for name, *values in _VEHICLE_ROWS
}
VEHICLE_TYPE_NAMES: List[str] = [*VEHICLE_TYPES]
def stopping_distance(vtype: str) -> float:
    """Return the total stopping distance, in metres, for a vehicle type.

    The result is the sum of the reaction distance (speed multiplied by the
    reaction time) and the braking distance (speed squared over twice the
    deceleration).

    ``vtype`` is looked up in ``VEHICLE_TYPES``; a name that is not a key of
    that mapping raises ``KeyError``.
    """
    spec = VEHICLE_TYPES[vtype]
    speed_ms = spec["speed_kmh"] / 3.6  # km/h → m/s
    reaction_m = speed_ms * spec["reaction_time_s"]
    braking_m = speed_ms ** 2 / (2.0 * spec["deceleration_ms2"])
    return reaction_m + braking_m
# Pre-computed stopping distance per vehicle type, in metres, rounded to one
# decimal place.
STOPPING_DISTANCES: Dict[str, float] = {
    name: round(stopping_distance(name), 1) for name in VEHICLE_TYPE_NAMES
}

# Dilemma fraction = stopping distance / 100 m, capped at 1.0.
# Vehicles within this fraction of the 100 m zone can't stop safely.
DILEMMA_FRACTIONS: Dict[str, float] = {
    name: min(metres / 100.0, 1.0) for name, metres in STOPPING_DISTANCES.items()
}
# Names of the available tasks (scenarios).
TASK_NAMES = [
    "balanced",
    "rush_hour_ns",
    "rush_hour_ew",
    "alternating_surge",
    "random_spikes",
    "gridlock",
    "emergency_vehicle",
]
# ---------------------------------------------------------------------------
# Action
# ---------------------------------------------------------------------------
class TrafficLightAction(Action):
    """Agent action: request the green phase for the intersection.

    ``phase`` is an integer phase id:

        0 = NS+SN corridor (both north-south directions green)
        1 = EW+WE corridor (both east-west directions green)
        2 = NS only (north-to-south green)
        3 = SN only (south-to-north green)
        4 = EW only (east-to-west green)
        5 = WE only (west-to-east green)

    Requesting a phase different from the current one triggers a mandatory
    yellow transition period before the new phase activates.
    """

    # Required field; validation bounds it to the valid phase ids
    # 0 .. NUM_PHASES - 1 inclusive.
    phase: int = Field(
        ...,
        ge=0,
        le=NUM_PHASES - 1,
        description=(
            "Desired green phase: 0=NS+SN corridor, 1=EW+WE corridor, "
            "2=NS only, 3=SN only, 4=EW only, 5=WE only"
        ),
    )
# ---------------------------------------------------------------------------
# Observation
# ---------------------------------------------------------------------------
class TrafficLightObservation(Observation):
    """Observation from the intersection with per-direction queue totals.

    Fields are grouped as: task info, per-direction zone totals (100 m and
    100-500 m), per-direction light states, emergency-vehicle info, phase and
    timing, aggregate stats, per-lane detail, per-vehicle-type composition,
    dilemma-zone metrics, and final grading. Every field has a default, so an
    observation can be constructed with only the values that are known.

    Per-direction lists are ordered [NS, SN, EW, WE]; per-lane lists are
    ordered NS_L0, NS_L1, SN_L0, SN_L1, EW_L0, EW_L1, WE_L0, WE_L1.
    """

    # Task info
    task_name: str = Field(default="balanced", description="Current task/scenario name")

    # Per-direction 100 m zone totals (sum of both lanes in each direction)
    ns_100m: int = Field(
        default=0, description="Vehicles within 100 m — NS direction (2 lanes)"
    )
    sn_100m: int = Field(
        default=0, description="Vehicles within 100 m — SN direction (2 lanes)"
    )
    ew_100m: int = Field(
        default=0, description="Vehicles within 100 m — EW direction (2 lanes)"
    )
    we_100m: int = Field(
        default=0, description="Vehicles within 100 m — WE direction (2 lanes)"
    )

    # Per-direction 500 m zone totals
    ns_500m: int = Field(
        default=0, description="Vehicles 100-500 m — NS direction (2 lanes)"
    )
    sn_500m: int = Field(
        default=0, description="Vehicles 100-500 m — SN direction (2 lanes)"
    )
    ew_500m: int = Field(
        default=0, description="Vehicles 100-500 m — EW direction (2 lanes)"
    )
    we_500m: int = Field(
        default=0, description="Vehicles 100-500 m — WE direction (2 lanes)"
    )

    # Per-direction light state: 0=red, 1=yellow, 2=green
    light_ns: int = Field(
        default=0, description="NS direction light: 0=red, 1=yellow, 2=green"
    )
    light_sn: int = Field(
        default=0, description="SN direction light: 0=red, 1=yellow, 2=green"
    )
    light_ew: int = Field(
        default=0, description="EW direction light: 0=red, 1=yellow, 2=green"
    )
    light_we: int = Field(
        default=0, description="WE direction light: 0=red, 1=yellow, 2=green"
    )

    # Emergency vehicle info (-1 means no emergency vehicle is present)
    emergency_direction: int = Field(
        default=-1,
        description="Direction with emergency vehicle: 0=NS, 1=SN, 2=EW, 3=WE, -1=none",
    )
    emergency_lane: int = Field(
        default=-1,
        description="Specific lane with emergency vehicle (0-7), -1=none",
    )
    emergency_wait: int = Field(
        default=0, description="Steps the emergency vehicle has been waiting"
    )

    # Phase and timing
    active_phase: int = Field(
        default=0,
        description="Active phase 0-5 (-1 during yellow transition)",
    )
    yellow_remaining: int = Field(
        default=0,
        description="Steps remaining in yellow transition (0 = not transitioning)",
    )
    time_in_phase: int = Field(
        default=0, description="Steps since last phase change completed"
    )
    step_number: int = Field(default=0, description="Current simulation step")

    # Aggregate stats
    total_waiting: int = Field(
        default=0, description="Total vehicles across all lanes and zones"
    )
    total_throughput: int = Field(
        default=0, description="Cumulative vehicles that have passed through"
    )
    # default_factory gives each instance its own fresh list.
    arrivals: List[int] = Field(
        default_factory=lambda: [0] * NUM_DIRECTIONS,
        description="Vehicles arrived this step per direction [NS, SN, EW, WE]",
    )
    departures: List[int] = Field(
        default_factory=lambda: [0] * NUM_DIRECTIONS,
        description="Vehicles departed this step per direction [NS, SN, EW, WE]",
    )

    # Per-lane detail (8 values: NS_L0, NS_L1, SN_L0, SN_L1, EW_L0, EW_L1, WE_L0, WE_L1)
    lanes_100m: List[int] = Field(
        default_factory=lambda: [0] * NUM_LANES,
        description="Per-lane 100 m queue counts (8 lanes)",
    )
    lanes_500m: List[int] = Field(
        default_factory=lambda: [0] * NUM_LANES,
        description="Per-lane 500 m queue counts (8 lanes)",
    )

    # Vehicle type composition per direction
    # Keys: vehicle type name, Values: list of 4 ints [NS, SN, EW, WE]
    vehicles_100m: Dict[str, List[int]] = Field(
        default_factory=lambda: {vt: [0] * NUM_DIRECTIONS for vt in VEHICLE_TYPE_NAMES},
        description="Per-type, per-direction vehicle counts at 100 m zone",
    )
    vehicles_500m: Dict[str, List[int]] = Field(
        default_factory=lambda: {vt: [0] * NUM_DIRECTIONS for vt in VEHICLE_TYPE_NAMES},
        description="Per-type, per-direction vehicle counts at 500 m zone",
    )

    # Dilemma zone — physics-based safety metric
    dilemma_risk: float = Field(
        default=0.0,
        description=(
            "Number of vehicles in the dilemma zone this step (can't stop safely "
            "after phase switch). 0.0 when no switch occurred."
        ),
    )
    total_dilemma_vehicles: float = Field(
        default=0.0,
        description="Cumulative dilemma-zone vehicles across the episode",
    )

    # Grading (populated on final step when done=True)
    grade_score: Optional[float] = Field(
        default=None,
        description="Final grade 0.0-1.0 (only set on terminal observation when done=True)",
    )
    grade_details: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Breakdown of grading components (only set on terminal observation when done=True)",
    )