# trafficops/server/sim/world.py
# Uploaded by Kunalsinghh using huggingface_hub (commit 92107a5, verified).
from dataclasses import dataclass, field
from typing import Any, Literal, Optional
import numpy as np
# Compass directions; used for road approach directions and signal phases.
Direction = Literal["N", "S", "E", "W"]
# Allowed values for a vehicle's type.
VehicleType = Literal["civilian", "ambulance", "fire", "police", "bus"]
# Allowed kinds of road incident.
IncidentKind = Literal["accident", "construction", "debris"]

# Vehicle types treated as emergency vehicles. A frozenset so this shared
# module-level constant cannot be mutated by accident at runtime; membership
# tests, iteration and set unions behave exactly as with a plain set.
EMERGENCY_TYPES: frozenset[str] = frozenset({"ambulance", "fire", "police"})

# Maps each direction to its opposite (N<->S, E<->W).
OPPOSITE: dict[Direction, Direction] = {"N": "S", "S": "N", "E": "W", "W": "E"}
@dataclass
class Road:
    """A one-way road segment discretised into ``length`` cells.

    Each entry of ``cells`` is either a vehicle id (``str``) or ``None`` for an
    empty cell.

    Attributes:
        id: Road identifier.
        from_node: Id of the node the road starts at.
        to_node: Id of the node the road ends at.
        approach_direction: Direction tag for the road; presumably the side of
            ``to_node`` it approaches from -- confirm against the world builder.
        length: Number of cells; must be >= 0.
        cells: Per-cell occupancy. When omitted (or empty) it is initialised to
            ``[None] * length``; when supplied it must have exactly ``length``
            entries.
        blocked: Blocked flag; not read or written by this class.

    Raises:
        ValueError: if ``length`` is negative, or a non-empty ``cells`` list
            does not have ``length`` entries.
    """

    id: str
    from_node: str
    to_node: str
    approach_direction: Direction
    length: int
    cells: list[Optional[str]] = field(default_factory=list)
    blocked: bool = False

    def __post_init__(self) -> None:
        if self.length < 0:
            raise ValueError(
                f"Road {self.id!r}: length must be >= 0, got {self.length}"
            )
        if not self.cells:
            self.cells = [None] * self.length
        elif len(self.cells) != self.length:
            # A mismatch would make queue_at_tail and occupancy disagree
            # about the road's extent, so reject it at construction time.
            raise ValueError(
                f"Road {self.id!r}: expected {self.length} cells, "
                f"got {len(self.cells)}"
            )

    def occupancy(self) -> int:
        """Return the number of occupied (non-``None``) cells."""
        return sum(1 for c in self.cells if c is not None)

    def queue_at_tail(self) -> int:
        """Return the length of the unbroken run of occupied cells ending at
        the last cell (index ``length - 1``); 0 if the last cell is empty.

        Presumably the last cell is the one adjacent to ``to_node`` -- confirm
        against the movement code.
        """
        n = 0
        for cell in reversed(self.cells):
            if cell is None:
                break
            n += 1
        return n
@dataclass
class Intersection:
    """A signalised junction that holds an ordered list of green phases.

    Each phase is a frozenset of the directions that are green together;
    ``current_phase_idx`` selects the active one.

    Attributes:
        id: Intersection identifier.
        position: Integer ``(x, y)``-style grid position.
        phases: Ordered list of direction groups.
        current_phase_idx: Index into ``phases`` of the active phase.
        phase_timer: Tick counter for the current phase (not advanced here).
        min_phase_ticks / max_phase_ticks: Phase-duration bounds in ticks;
            not enforced by this class -- presumably the controller does.
        incoming / outgoing: Direction -> road id maps.
        bias: Per-direction weight, 1.0 for every direction by default.
        preempt_direction / preempt_expires_tick: Optional pre-emption
            request and the tick it lapses (presumably for emergency vehicles).
        neighbors: Ids of adjacent intersections.
    """

    id: str
    position: tuple[int, int]
    phases: list[frozenset[Direction]]
    current_phase_idx: int = 0
    phase_timer: int = 0
    min_phase_ticks: int = 6
    max_phase_ticks: int = 45
    incoming: dict[Direction, str] = field(default_factory=dict)
    outgoing: dict[Direction, str] = field(default_factory=dict)
    bias: dict[Direction, float] = field(
        default_factory=lambda: dict.fromkeys(("N", "S", "E", "W"), 1.0)
    )
    preempt_direction: Optional[Direction] = None
    preempt_expires_tick: Optional[int] = None
    neighbors: list[str] = field(default_factory=list)

    def current_phase(self) -> frozenset[Direction]:
        """Return the direction group of the active phase."""
        active = self.current_phase_idx
        return self.phases[active]

    def phase_name(self) -> str:
        """Return the active phase as its sorted directions joined by '+'."""
        ordered = sorted(self.current_phase())
        return "+".join(ordered)

    def phase_idx_containing(self, direction: Direction) -> Optional[int]:
        """Return the index of the first phase that includes ``direction``,
        or ``None`` when no phase does."""
        return next(
            (idx for idx, group in enumerate(self.phases) if direction in group),
            None,
        )
@dataclass
class Vehicle:
    """A single simulated vehicle following a route.

    Field order defines the positional ``__init__`` signature; do not reorder.
    """

    id: str
    type: VehicleType
    # Presumably an ordered list of road ids -- confirm against the spawner.
    route: list[str]
    # Index into ``route``; presumably the road currently being travelled.
    route_idx: int = 0
    # Cell index within the current road.
    position_in_road: int = 0
    spawn_tick: int = 0
    # Presumably accumulated ticks spent not moving -- confirm.
    wait_ticks: int = 0
    cleared: bool = False
    # Tick at which the vehicle was cleared; None until then.
    clear_tick: Optional[int] = None

    def is_emergency(self) -> bool:
        """Return True if this vehicle's type is one of EMERGENCY_TYPES
        (ambulance, fire, police)."""
        return self.type in EMERGENCY_TYPES
@dataclass
class Incident:
    """An incident of a given kind on one road, with start/end tick fields.

    Field order defines the positional ``__init__`` signature; do not reorder.
    """

    id: str
    # Id of the affected Road.
    road_id: str
    kind: IncidentKind
    start_tick: int
    # NOTE(review): None presumably means the incident has no scheduled end -- confirm.
    end_tick: Optional[int]
    # Not toggled in this class; presumably set by the simulation step.
    active: bool = False
    # Presumably whether the incident has already been reported -- confirm.
    described: bool = False
@dataclass
class Plan:
    """A control intervention with an operation name, targets and parameters.

    Field order defines the positional ``__init__`` signature; do not reorder.
    """

    id: str
    # Operation name; the set of valid ops is not defined in this class.
    op: str
    created_tick: int
    # NOTE(review): None presumably means the plan never expires -- confirm.
    expires_tick: Optional[int]
    # Ids the plan acts on (presumably intersection/road/corridor ids).
    targets: list[str]
    # Op-specific parameters; schema not visible here.
    params: dict
    reason: str = ""
    # Presumably state captured before applying the plan so it can be undone -- confirm.
    snapshot: dict = field(default_factory=dict)
@dataclass
class Corridor:
    """An ordered chain of intersections along one direction.

    Field order defines the positional ``__init__`` signature; do not reorder.
    """

    id: str
    # Intersection ids along the corridor.
    intersections: list[str]
    direction: Direction
    coordinated: bool = False
    # Presumably the id of the Plan that coordinated this corridor -- confirm.
    plan_id: Optional[str] = None
    target_speed: Optional[float] = None
    # Presumably intersection id -> phase offset in ticks -- confirm.
    phase_offsets: dict[str, int] = field(default_factory=dict)
@dataclass
class SpawnEvent:
    """A scheduled vehicle spawn: which vehicle, of what type, on which route,
    at which tick."""

    tick: int
    vehicle_id: str
    vehicle_type: VehicleType
    # Presumably an ordered list of road ids -- confirm against the spawner.
    route: list[str]
@dataclass
class IncidentEvent:
    """A scheduled incident: the Incident record paired with the tick it is
    due at."""

    tick: int
    incident: Incident
@dataclass
class Metrics:
    """Running counters for a simulation run; every field starts at zero/empty.

    Nothing in this class updates them -- presumably the step/scoring code
    does (confirm).
    """

    cleared_civilian: int = 0
    cleared_emergency: int = 0
    spawned_civilian: int = 0
    spawned_emergency: int = 0
    # Presumably ticks where a green phase moved no vehicle -- confirm.
    wasted_green_ticks: int = 0
    gridlock_events: int = 0
    # Presumably one clear duration (in ticks) per emergency vehicle -- confirm.
    emergency_clear_times: list[int] = field(default_factory=list)
    max_wait_ticks_seen: int = 0
    invalid_actions: int = 0
    # Presumably consecutive ticks with no movement -- confirm.
    stalled_streak: int = 0
@dataclass
class TickStats:
    """Counters for a single tick; presumably created per tick and folded into
    Metrics by the step code -- confirm."""

    cleared_civ: int = 0
    cleared_em: int = 0
    wasted_green: int = 0
    gridlock: int = 0
    # Presumably whether any vehicle moved during the tick.
    moved_any: bool = False
# Maximum number of entries World.log retains in World.event_log.
EVENT_LOG_MAX_ENTRIES = 200


@dataclass
class World:
    """Complete mutable state of one simulation run.

    Holds the road network (``roads``, ``intersections``, ``corridors``), live
    entities (``vehicles``, ``incidents``, ``active_plans``), pre-generated
    schedules, metrics, a bounded event log and controller configuration.

    Field order defines the positional ``__init__`` signature; do not reorder.
    """

    tick: int
    horizon: int
    task: str
    seed: int
    rng: np.random.Generator
    roads: dict[str, Road]
    intersections: dict[str, Intersection]
    corridors: dict[str, Corridor]
    vehicles: dict[str, Vehicle] = field(default_factory=dict)
    incidents: list[Incident] = field(default_factory=list)
    active_plans: dict[str, Plan] = field(default_factory=dict)
    spawn_schedule: list[SpawnEvent] = field(default_factory=list)
    incident_schedule: list[IncidentEvent] = field(default_factory=list)
    metrics: Metrics = field(default_factory=Metrics)
    # Most recent log lines, capped at EVENT_LOG_MAX_ENTRIES by log().
    event_log: list[str] = field(default_factory=list)
    interventions_used: int = 0
    interventions_budget: int = 0
    last_action_error: Optional[str] = None
    # Sequence counter consumed by new_plan_id().
    next_plan_seq: int = 0
    reroute_overrides: dict[str, list[str]] = field(default_factory=dict)
    controller_mode: str = "dqn"  # "fixed", "max_pressure", "dqn"
    # Opaque controller object; presumably the DQN agent used in "dqn" mode.
    rl_controller: Any = None
    dqn_decision_interval: int = 5
    dqn_tick_counter: int = 0

    def log(self, msg: str) -> None:
        """Append ``msg`` to ``event_log`` prefixed with ``t=<tick> ``.

        Only the most recent EVENT_LOG_MAX_ENTRIES entries are kept. Trimming
        happens in place, so the list object never changes identity and any
        other holder of ``event_log`` keeps seeing new entries.
        """
        self.event_log.append(f"t={self.tick} {msg}")
        overflow = len(self.event_log) - EVENT_LOG_MAX_ENTRIES
        if overflow > 0:
            del self.event_log[:overflow]

    def new_plan_id(self) -> str:
        """Advance ``next_plan_seq`` and return it as ``plan_NNNN``
        (zero-padded to at least four digits)."""
        self.next_plan_seq += 1
        return f"plan_{self.next_plan_seq:04d}"