Spaces:
Sleeping
Sleeping
| from dataclasses import dataclass, field | |
| from typing import Any, Literal, Optional | |
| import numpy as np | |
# Compass approach directions used by roads and signal phases.
Direction = Literal["N", "S", "E", "W"]
# Categories of vehicle that can appear in the simulation.
VehicleType = Literal["civilian", "ambulance", "fire", "police", "bus"]
# Kinds of incident that can be placed on a road.
IncidentKind = Literal["accident", "construction", "debris"]

# Vehicle types treated as emergency traffic.
EMERGENCY_TYPES: set[str] = {
    "ambulance",
    "fire",
    "police",
}

# Maps each direction to the direction facing it.
OPPOSITE: dict[Direction, Direction] = dict(N="S", S="N", E="W", W="E")
@dataclass
class Road:
    """A road segment from ``from_node`` to ``to_node``, modelled as a row of cells.

    Each cell holds the id of the vehicle occupying it, or ``None`` when empty.
    If no ``cells`` are supplied, ``__post_init__`` creates ``length`` empty cells.
    """

    id: str
    from_node: str
    to_node: str
    approach_direction: Direction
    length: int
    cells: list[Optional[str]] = field(default_factory=list)
    blocked: bool = False

    def __post_init__(self) -> None:
        # An empty (default) cell list means "start with an empty road of `length` cells".
        if not self.cells:
            self.cells = [None] * self.length

    def occupancy(self) -> int:
        """Return the number of occupied cells."""
        return sum(1 for c in self.cells if c is not None)

    def queue_at_tail(self) -> int:
        """Return the length of the unbroken run of occupied cells ending at the last cell.

        Walks backwards from the highest-index cell and stops at the first empty one.
        Iterates the actual ``cells`` list, so an explicitly supplied list whose
        length differs from ``length`` cannot cause an IndexError.
        """
        n = 0
        for c in reversed(self.cells):
            if c is None:
                break
            n += 1
        return n
@dataclass
class Intersection:
    """A signalised intersection that selects one of several phases.

    Each phase is a frozenset of approach directions; ``current_phase_idx``
    selects the active phase.
    """

    id: str
    position: tuple[int, int]
    phases: list[frozenset[Direction]]
    current_phase_idx: int = 0
    # Ticks spent in the current phase (presumably advanced by the simulator).
    phase_timer: int = 0
    # Bounds on phase duration in ticks; enforcement is not visible in this class.
    min_phase_ticks: int = 6
    max_phase_ticks: int = 45
    # Direction -> road id; presumably keys into World.roads — confirm.
    incoming: dict[Direction, str] = field(default_factory=dict)
    outgoing: dict[Direction, str] = field(default_factory=dict)
    # Per-direction weighting, 1.0 for every direction by default.
    bias: dict[Direction, float] = field(default_factory=lambda: {"N": 1.0, "S": 1.0, "E": 1.0, "W": 1.0})
    # Pre-emption request (by name, presumably for emergency vehicles) and its expiry tick.
    preempt_direction: Optional[Direction] = None
    preempt_expires_tick: Optional[int] = None
    neighbors: list[str] = field(default_factory=list)

    def current_phase(self) -> frozenset[Direction]:
        """Return the set of directions in the active phase."""
        return self.phases[self.current_phase_idx]

    def phase_name(self) -> str:
        """Return the active phase as a label of sorted directions joined by '+', e.g. ``"N+S"``."""
        return "+".join(sorted(self.current_phase()))

    def phase_idx_containing(self, direction: Direction) -> Optional[int]:
        """Return the index of the first phase that includes ``direction``, or ``None``."""
        return next((i for i, ph in enumerate(self.phases) if direction in ph), None)
@dataclass
class Vehicle:
    """A vehicle following ``route`` through the network.

    ``route_idx`` and ``position_in_road`` track progress along the route;
    ``cleared``/``clear_tick`` record completion (presumably set by the simulator).
    """

    id: str
    type: VehicleType
    # Presumably a sequence of road ids — confirm against the simulator.
    route: list[str]
    route_idx: int = 0
    position_in_road: int = 0
    spawn_tick: int = 0
    wait_ticks: int = 0
    cleared: bool = False
    clear_tick: Optional[int] = None

    def is_emergency(self) -> bool:
        """Return True if this vehicle's type is in ``EMERGENCY_TYPES``."""
        return self.type in EMERGENCY_TYPES
@dataclass
class Incident:
    """An incident of ``kind`` on road ``road_id``, scheduled from ``start_tick``.

    ``end_tick`` may be ``None``; presumably that means open-ended — confirm.
    ``active``/``described`` are status flags, both False initially.
    """

    id: str
    road_id: str
    kind: IncidentKind
    start_tick: int
    end_tick: Optional[int]
    active: bool = False
    described: bool = False
@dataclass
class Plan:
    """A record of an operation ``op`` applied to ``targets`` with ``params``.

    ``expires_tick`` of ``None`` presumably means the plan does not expire.
    ``snapshot`` is an independent dict per plan; presumably it stores prior
    state so the plan can be reverted — confirm with the caller.
    """

    id: str
    op: str
    created_tick: int
    expires_tick: Optional[int]
    targets: list[str]
    params: dict
    reason: str = ""
    snapshot: dict = field(default_factory=dict)
@dataclass
class Corridor:
    """An ordered group of intersection ids sharing a travel ``direction``.

    ``coordinated``, ``plan_id``, ``target_speed`` and ``phase_offsets`` hold
    coordination state; presumably ``plan_id`` refers to a ``Plan`` id — confirm.
    """

    id: str
    intersections: list[str]
    direction: Direction
    coordinated: bool = False
    plan_id: Optional[str] = None
    target_speed: Optional[float] = None
    # Intersection id -> offset (int); units presumably ticks — confirm.
    phase_offsets: dict[str, int] = field(default_factory=dict)
@dataclass
class SpawnEvent:
    """A scheduled vehicle spawn: at ``tick``, create ``vehicle_id`` of ``vehicle_type`` on ``route``."""

    tick: int
    vehicle_id: str
    vehicle_type: VehicleType
    route: list[str]
@dataclass
class IncidentEvent:
    """A scheduled incident: ``incident`` is due at ``tick``."""

    tick: int
    incident: Incident
@dataclass
class Metrics:
    """Aggregate counters for a simulation run; every field defaults to zero/empty."""

    cleared_civilian: int = 0
    cleared_emergency: int = 0
    spawned_civilian: int = 0
    spawned_emergency: int = 0
    wasted_green_ticks: int = 0
    gridlock_events: int = 0
    # Fresh list per instance, so separate runs never share recorded times.
    emergency_clear_times: list[int] = field(default_factory=list)
    max_wait_ticks_seen: int = 0
    invalid_actions: int = 0
    # Presumably consecutive ticks without vehicle movement — confirm with the simulator.
    stalled_streak: int = 0
@dataclass
class TickStats:
    """Counters for a single tick; presumably folded into ``Metrics`` by the simulator."""

    cleared_civ: int = 0
    cleared_em: int = 0
    wasted_green: int = 0
    gridlock: int = 0
    moved_any: bool = False
@dataclass
class World:
    """Complete simulation state.

    Holds the road network, vehicles, incidents, active plans, event schedules,
    metrics, a bounded event log and controller configuration.
    """

    tick: int
    horizon: int
    task: str
    seed: int
    rng: np.random.Generator
    roads: dict[str, Road]
    intersections: dict[str, Intersection]
    corridors: dict[str, Corridor]
    vehicles: dict[str, Vehicle] = field(default_factory=dict)
    incidents: list[Incident] = field(default_factory=list)
    active_plans: dict[str, Plan] = field(default_factory=dict)
    spawn_schedule: list[SpawnEvent] = field(default_factory=list)
    incident_schedule: list[IncidentEvent] = field(default_factory=list)
    metrics: Metrics = field(default_factory=Metrics)
    event_log: list[str] = field(default_factory=list)
    interventions_used: int = 0
    interventions_budget: int = 0
    last_action_error: Optional[str] = None
    # Sequence counter consumed by new_plan_id().
    next_plan_seq: int = 0
    # Presumably vehicle id -> replacement route — confirm with the caller.
    reroute_overrides: dict[str, list[str]] = field(default_factory=dict)
    controller_mode: str = "dqn"  # "fixed", "max_pressure", "dqn"
    rl_controller: Any = None
    dqn_decision_interval: int = 5
    dqn_tick_counter: int = 0

    def log(self, msg: str) -> None:
        """Append ``msg`` to the event log, prefixed with ``t=<tick>``.

        Only the most recent 200 entries are kept.
        """
        self.event_log.append(f"t={self.tick} {msg}")
        # Trim in place (rather than rebinding) so any outside reference to
        # the list sees the same, bounded contents.
        if len(self.event_log) > 200:
            del self.event_log[:-200]

    def new_plan_id(self) -> str:
        """Return the next sequential plan id: ``plan_0001``, ``plan_0002``, ..."""
        self.next_plan_seq += 1
        return f"plan_{self.next_plan_seq:04d}"