from __future__ import annotations import json from dataclasses import dataclass, field from typing import Any DISTRICT_STRATEGIES: tuple[str, ...] = ( "hold", "favor_NS", "favor_EW", "drain_inbound", "drain_outbound", "clear_spillback", "incident_response", "arterial_priority", ) PHASE_BIASES: tuple[str, ...] = ("NONE", "NS", "EW") PRIORITY_CORRIDORS: tuple[str, ...] = ( "NS", "EW", "inbound", "outbound", "arterial", ) DOMINANT_FLOWS: tuple[str, ...] = ("NS", "EW", "BALANCED") CANDIDATE_REASON_TAGS: tuple[str, ...] = ( "congested", "boundary", "spillback", "incident", "outgoing", "overload", "event", ) def _round_float(value: float, digits: int = 3) -> float: return round(float(value), digits) def _dedupe_string_list(values: list[str] | tuple[str, ...] | None, limit: int | None = None) -> list[str]: normalized: list[str] = [] seen: set[str] = set() for item in values or []: value = str(item).strip() if not value or value in seen: continue normalized.append(value) seen.add(value) if limit is not None and len(normalized) >= limit: break return normalized def _stable_reason_list(values: list[str] | tuple[str, ...] | None) -> list[str]: present = {str(item).strip() for item in (values or []) if str(item).strip()} return [item for item in CANDIDATE_REASON_TAGS if item in present] def candidate_priority_score(candidate: "CandidateIntersection | dict[str, Any]") -> float: item = candidate.to_dict() if hasattr(candidate, "to_dict") else dict(candidate) queue_total = float(item.get("queue_total", 0.0)) wait_total = float(item.get("wait_total", 0.0)) outgoing_load = float(item.get("outgoing_load", 0.0)) score = queue_total + 1.5 * wait_total + 0.5 * outgoing_load score += 2.0 * float(bool(item.get("spillback_risk", False))) score += 1.5 * float(bool(item.get("incident_proximity", False))) score += 1.0 * float(bool(item.get("is_boundary", False))) score += 0.75 * float(bool(item.get("event_proximity", False))) score += 0.75 * float(bool(item.get("overload_marker", False))) return score def candidate_priority_tuple(candidate: "CandidateIntersection | dict[str, Any]") -> tuple[float, float, float, float, str]: item = candidate.to_dict() if hasattr(candidate, "to_dict") else dict(candidate) return ( candidate_priority_score(item), float(item.get("queue_total", 0.0)), float(item.get("wait_total", 0.0)), float(item.get("outgoing_load", 0.0)), str(item.get("intersection_id", "")), ) def canonicalize_target_intersections( targets: list[str] | tuple[str, ...] | None, candidates: list["CandidateIntersection | dict[str, Any]"] | None = None, limit: int | None = None, ) -> list[str]: normalized = _dedupe_string_list(targets, limit=None) if not candidates: return normalized[:limit] if limit is not None else normalized candidate_order = { str(candidate.to_dict()["intersection_id"] if hasattr(candidate, "to_dict") else candidate["intersection_id"]): ( -candidate_priority_tuple(candidate)[0], -candidate_priority_tuple(candidate)[1], -candidate_priority_tuple(candidate)[2], -candidate_priority_tuple(candidate)[3], candidate_priority_tuple(candidate)[4], ) for candidate in candidates } normalized.sort(key=lambda item: candidate_order.get(item, (1.0, 1.0, 1.0, 1.0, item))) if limit is not None: normalized = normalized[:limit] return normalized @dataclass class CongestedIntersection: intersection_id: str queue_total: float wait_total: float outgoing_load: float current_phase: int is_boundary: bool def to_dict(self) -> dict[str, Any]: return { "intersection_id": self.intersection_id, "queue_total": _round_float(self.queue_total), "wait_total": _round_float(self.wait_total), "outgoing_load": _round_float(self.outgoing_load), "current_phase": int(self.current_phase), "is_boundary": bool(self.is_boundary), } def to_prompt_line(self) -> str: return ( f"- {self.intersection_id} " f"q={self.queue_total:.2f} " f"w={self.wait_total:.2f} " f"out={self.outgoing_load:.2f} " f"phase={self.current_phase} " f"boundary={int(self.is_boundary)}" ) @dataclass class CandidateIntersection: intersection_id: str queue_total: float wait_total: float outgoing_load: float current_phase: int is_boundary: bool spillback_risk: bool = False incident_proximity: bool = False overload_marker: bool = False event_proximity: bool = False corridor_alignment: str = "BALANCED" selection_reasons: list[str] = field(default_factory=list) def validate(self) -> "CandidateIntersection": if self.corridor_alignment not in DOMINANT_FLOWS: raise ValueError( f"Invalid corridor_alignment '{self.corridor_alignment}'. Expected one of {DOMINANT_FLOWS}." ) self.selection_reasons = _stable_reason_list(self.selection_reasons) return self def to_dict(self) -> dict[str, Any]: self.validate() return { "intersection_id": self.intersection_id, "queue_total": _round_float(self.queue_total), "wait_total": _round_float(self.wait_total), "outgoing_load": _round_float(self.outgoing_load), "current_phase": int(self.current_phase), "is_boundary": bool(self.is_boundary), "spillback_risk": bool(self.spillback_risk), "incident_proximity": bool(self.incident_proximity), "overload_marker": bool(self.overload_marker), "event_proximity": bool(self.event_proximity), "corridor_alignment": self.corridor_alignment, "selection_reasons": list(self.selection_reasons), } def to_prompt_line(self) -> str: self.validate() reasons = "|".join(self.selection_reasons) if self.selection_reasons else "none" return ( f"- {self.intersection_id} " f"q={self.queue_total:.2f} " f"w={self.wait_total:.2f} " f"out={self.outgoing_load:.2f} " f"phase={self.current_phase} " f"boundary={int(self.is_boundary)} " f"spillback={int(self.spillback_risk)} " f"incident={int(self.incident_proximity)} " f"overload={int(self.overload_marker)} " f"event={int(self.event_proximity)} " f"align={self.corridor_alignment} " f"reasons={reasons}" ) @dataclass class DistrictAction: strategy: str = "hold" priority_corridor: str | None = None target_intersections: list[str] = field(default_factory=list) phase_bias: str = "NONE" duration_steps: int = 1 def validate(self) -> "DistrictAction": if self.strategy not in DISTRICT_STRATEGIES: raise ValueError( f"Invalid strategy '{self.strategy}'. Expected one of {DISTRICT_STRATEGIES}." ) if self.priority_corridor is not None and self.priority_corridor not in PRIORITY_CORRIDORS: raise ValueError( f"Invalid priority_corridor '{self.priority_corridor}'. " f"Expected one of {PRIORITY_CORRIDORS} or None." ) if self.phase_bias not in PHASE_BIASES: raise ValueError( f"Invalid phase_bias '{self.phase_bias}'. Expected one of {PHASE_BIASES}." ) if not isinstance(self.duration_steps, int): raise ValueError("duration_steps must be an integer.") if not 1 <= self.duration_steps <= 20: raise ValueError("duration_steps must be between 1 and 20.") self.target_intersections = _dedupe_string_list(self.target_intersections, limit=8) return self @classmethod def default_hold(cls, duration_steps: int = 1) -> "DistrictAction": return cls( strategy="hold", priority_corridor=None, target_intersections=[], phase_bias="NONE", duration_steps=max(1, min(int(duration_steps), 20)), ) @classmethod def from_dict(cls, payload: dict[str, Any]) -> "DistrictAction": return cls( strategy=str(payload.get("strategy", "hold")), priority_corridor=payload.get("priority_corridor"), target_intersections=list(payload.get("target_intersections", [])), phase_bias=str(payload.get("phase_bias", "NONE")), duration_steps=int(payload.get("duration_steps", 1)), ).validate() @classmethod def from_json(cls, payload: str) -> "DistrictAction": return cls.from_dict(json.loads(payload)) def to_dict(self) -> dict[str, Any]: self.validate() return { "strategy": self.strategy, "priority_corridor": self.priority_corridor, "target_intersections": list(self.target_intersections), "phase_bias": self.phase_bias, "duration_steps": int(self.duration_steps), } def to_json(self) -> str: return json.dumps(self.to_dict(), sort_keys=True, separators=(",", ":")) def to_pretty_json(self) -> str: return json.dumps(self.to_dict(), sort_keys=True, indent=2) def to_rl_context(self) -> dict[str, Any]: payload = self.to_dict() payload["district_strategy"] = payload.pop("strategy") payload["district_duration_steps"] = payload.pop("duration_steps") return payload @dataclass class DistrictStateSummary: city_id: str district_id: str district_type: str scenario_name: str scenario_type: str decision_step: int sim_time: int intersection_count: int avg_queue: float max_queue: float total_queue: float avg_wait: float max_wait: float total_wait: float avg_outgoing_load: float max_outgoing_load: float total_outgoing_load: float recent_throughput: float queue_change: float wait_change: float throughput_change: float ns_queue: float ew_queue: float ns_wait: float ew_wait: float dominant_flow: str boundary_queue_total: float boundary_wait_total: float spillback_risk: bool incident_flag: bool construction_flag: bool overload_flag: bool event_flag: bool top_congested_intersections: list[CongestedIntersection] = field(default_factory=list) candidate_intersections: list[CandidateIntersection] = field(default_factory=list) def validate(self) -> "DistrictStateSummary": if self.dominant_flow not in DOMINANT_FLOWS: raise ValueError( f"Invalid dominant_flow '{self.dominant_flow}'. Expected one of {DOMINANT_FLOWS}." ) self.top_congested_intersections = list(self.top_congested_intersections[:5]) self.candidate_intersections = list(self.candidate_intersections[:8]) return self def candidate_ids(self) -> list[str]: self.validate() return [item.intersection_id for item in self.candidate_intersections] def candidate_lookup(self) -> dict[str, CandidateIntersection]: self.validate() return { item.intersection_id: item for item in self.candidate_intersections } def to_dict(self) -> dict[str, Any]: self.validate() return { "city_id": self.city_id, "district_id": self.district_id, "district_type": self.district_type, "scenario_name": self.scenario_name, "scenario_type": self.scenario_type, "decision_step": int(self.decision_step), "sim_time": int(self.sim_time), "intersection_count": int(self.intersection_count), "avg_queue": _round_float(self.avg_queue), "max_queue": _round_float(self.max_queue), "total_queue": _round_float(self.total_queue), "avg_wait": _round_float(self.avg_wait), "max_wait": _round_float(self.max_wait), "total_wait": _round_float(self.total_wait), "avg_outgoing_load": _round_float(self.avg_outgoing_load), "max_outgoing_load": _round_float(self.max_outgoing_load), "total_outgoing_load": _round_float(self.total_outgoing_load), "recent_throughput": _round_float(self.recent_throughput), "queue_change": _round_float(self.queue_change), "wait_change": _round_float(self.wait_change), "throughput_change": _round_float(self.throughput_change), "ns_queue": _round_float(self.ns_queue), "ew_queue": _round_float(self.ew_queue), "ns_wait": _round_float(self.ns_wait), "ew_wait": _round_float(self.ew_wait), "dominant_flow": self.dominant_flow, "boundary_queue_total": _round_float(self.boundary_queue_total), "boundary_wait_total": _round_float(self.boundary_wait_total), "spillback_risk": bool(self.spillback_risk), "incident_flag": bool(self.incident_flag), "construction_flag": bool(self.construction_flag), "overload_flag": bool(self.overload_flag), "event_flag": bool(self.event_flag), "top_congested_intersections": [ item.to_dict() for item in self.top_congested_intersections ], "candidate_intersections": [ item.to_dict() for item in self.candidate_intersections ], } def to_json(self) -> str: return json.dumps(self.to_dict(), sort_keys=True, separators=(",", ":")) def to_prompt_text(self) -> str: self.validate() top_lines = [item.to_prompt_line() for item in self.top_congested_intersections] candidate_lines = [item.to_prompt_line() for item in self.candidate_intersections] if not top_lines: top_lines = ["- none"] if not candidate_lines: candidate_lines = ["- none"] return "\n".join( [ f"city_id: {self.city_id}", f"district_id: {self.district_id}", f"district_type: {self.district_type}", f"scenario: {self.scenario_name}", f"scenario_type: {self.scenario_type}", f"decision_step: {self.decision_step}", f"sim_time: {self.sim_time}", f"intersection_count: {self.intersection_count}", f"avg_queue: {self.avg_queue:.2f}", f"max_queue: {self.max_queue:.2f}", f"total_queue: {self.total_queue:.2f}", f"avg_wait: {self.avg_wait:.2f}", f"max_wait: {self.max_wait:.2f}", f"total_wait: {self.total_wait:.2f}", f"avg_outgoing_load: {self.avg_outgoing_load:.2f}", f"max_outgoing_load: {self.max_outgoing_load:.2f}", f"total_outgoing_load: {self.total_outgoing_load:.2f}", f"recent_throughput: {self.recent_throughput:.2f}", f"queue_change: {self.queue_change:.2f}", f"wait_change: {self.wait_change:.2f}", f"throughput_change: {self.throughput_change:.2f}", f"ns_queue: {self.ns_queue:.2f}", f"ew_queue: {self.ew_queue:.2f}", f"ns_wait: {self.ns_wait:.2f}", f"ew_wait: {self.ew_wait:.2f}", f"dominant_flow: {self.dominant_flow}", f"boundary_queue_total: {self.boundary_queue_total:.2f}", f"boundary_wait_total: {self.boundary_wait_total:.2f}", f"spillback_risk: {int(self.spillback_risk)}", f"incident_flag: {int(self.incident_flag)}", f"construction_flag: {int(self.construction_flag)}", f"overload_flag: {int(self.overload_flag)}", f"event_flag: {int(self.event_flag)}", "top_congested_intersections:", *top_lines, "candidate_intersections:", *candidate_lines, ] )