| from __future__ import annotations |
|
|
| import json |
| from dataclasses import dataclass, field |
| from typing import Any |
|
|
|
|
# Strategy names accepted for ``DistrictAction.strategy``.
DISTRICT_STRATEGIES: tuple[str, ...] = (
    "hold",
    "favor_NS", "favor_EW",
    "drain_inbound", "drain_outbound",
    "clear_spillback",
    "incident_response",
    "arterial_priority",
)

# Values accepted for ``DistrictAction.phase_bias``.
PHASE_BIASES: tuple[str, ...] = ("NONE", "NS", "EW")

# Values accepted for ``DistrictAction.priority_corridor`` (``None`` is also allowed there).
PRIORITY_CORRIDORS: tuple[str, ...] = ("NS", "EW", "inbound", "outbound", "arterial")

# Values accepted for ``dominant_flow`` and ``corridor_alignment``.
DOMINANT_FLOWS: tuple[str, ...] = ("NS", "EW", "BALANCED")

# Recognised selection-reason tags. The order is significant: ``_stable_reason_list``
# emits tags in exactly this sequence.
CANDIDATE_REASON_TAGS: tuple[str, ...] = (
    "congested", "boundary", "spillback", "incident",
    "outgoing", "overload", "event",
)
|
|
|
|
| def _round_float(value: float, digits: int = 3) -> float: |
| return round(float(value), digits) |
|
|
|
|
| def _dedupe_string_list(values: list[str] | tuple[str, ...] | None, limit: int | None = None) -> list[str]: |
| normalized: list[str] = [] |
| seen: set[str] = set() |
| for item in values or []: |
| value = str(item).strip() |
| if not value or value in seen: |
| continue |
| normalized.append(value) |
| seen.add(value) |
| if limit is not None and len(normalized) >= limit: |
| break |
| return normalized |
|
|
|
|
def _stable_reason_list(values: list[str] | tuple[str, ...] | None) -> list[str]:
    """Filter *values* down to known reason tags, in canonical order.

    Items are converted with ``str()`` and stripped; blank results are
    ignored. Anything that is not listed in ``CANDIDATE_REASON_TAGS`` is
    discarded, and duplicates collapse to a single entry.

    Args:
        values: Reason tags to normalise. ``None`` is treated as empty.

    Returns:
        The recognised tags, ordered as they appear in
        ``CANDIDATE_REASON_TAGS`` regardless of the input order.
    """
    supplied: set[str] = set()
    for raw in values or []:
        tag = str(raw).strip()
        if tag:
            supplied.add(tag)

    ordered: list[str] = []
    for known_tag in CANDIDATE_REASON_TAGS:
        if known_tag in supplied:
            ordered.append(known_tag)
    return ordered
|
|
|
|
def candidate_priority_score(candidate: "CandidateIntersection | dict[str, Any]") -> float:
    """Compute the weighted priority score of a candidate intersection.

    The score is ``queue_total + 1.5 * wait_total + 0.5 * outgoing_load`` plus
    a fixed bonus for each truthy flag: ``spillback_risk`` (2.0),
    ``incident_proximity`` (1.5), ``is_boundary`` (1.0), ``event_proximity``
    (0.75) and ``overload_marker`` (0.75). Missing keys count as ``0.0`` or
    ``False``.

    Args:
        candidate: Either an object exposing ``to_dict()`` or a mapping with
            the keys named above.

    Returns:
        The score as a float.
    """
    if hasattr(candidate, "to_dict"):
        fields = candidate.to_dict()
    else:
        fields = dict(candidate)

    queue = float(fields.get("queue_total", 0.0))
    wait = float(fields.get("wait_total", 0.0))
    outgoing = float(fields.get("outgoing_load", 0.0))
    total = queue + 1.5 * wait + 0.5 * outgoing

    # Bonuses are accumulated one at a time, in this order, so the floating
    # point result does not depend on how the table is traversed.
    flag_bonuses = (
        ("spillback_risk", 2.0),
        ("incident_proximity", 1.5),
        ("is_boundary", 1.0),
        ("event_proximity", 0.75),
        ("overload_marker", 0.75),
    )
    for key, bonus in flag_bonuses:
        total += bonus * float(bool(fields.get(key, False)))
    return total
|
|
|
|
def candidate_priority_tuple(candidate: "CandidateIntersection | dict[str, Any]") -> tuple[float, float, float, float, str]:
    """Build the ranking tuple used to compare candidate intersections.

    Args:
        candidate: Either an object exposing ``to_dict()`` or a mapping of
            candidate fields.

    Returns:
        ``(score, queue_total, wait_total, outgoing_load, intersection_id)``
        where ``score`` comes from ``candidate_priority_score``. Missing
        numeric fields default to ``0.0`` and a missing id to ``""``.
    """
    if hasattr(candidate, "to_dict"):
        fields = candidate.to_dict()
    else:
        fields = dict(candidate)

    score = candidate_priority_score(fields)
    queue = float(fields.get("queue_total", 0.0))
    wait = float(fields.get("wait_total", 0.0))
    outgoing = float(fields.get("outgoing_load", 0.0))
    identifier = str(fields.get("intersection_id", ""))
    return (score, queue, wait, outgoing, identifier)
|
|
|
|
def canonicalize_target_intersections(
    targets: list[str] | tuple[str, ...] | None,
    candidates: list["CandidateIntersection | dict[str, Any]"] | None = None,
    limit: int | None = None,
) -> list[str]:
    """Deduplicate target ids and order them by candidate priority.

    Targets are first normalised with ``_dedupe_string_list``. When
    *candidates* is given, targets that match a candidate id are ordered by
    descending priority score, then descending queue, wait and outgoing load,
    then ascending id. Targets that match no candidate always come after the
    ranked ones, ordered by their own string value. If several candidates
    share an id, the last one in *candidates* determines the rank.

    Args:
        targets: Requested intersection ids. ``None`` is treated as empty.
        candidates: Optional candidates (objects exposing ``to_dict()`` or
            mappings containing ``"intersection_id"``) used for ranking.
            When empty or ``None`` the deduplicated input order is kept.
        limit: Maximum number of ids to return. ``None`` means no cap; a
            negative value is treated as zero.

    Returns:
        The canonical list of target ids.

    Raises:
        KeyError: If a mapping candidate has no ``"intersection_id"`` key.
    """
    normalized = _dedupe_string_list(targets, limit=None)

    if candidates:
        sort_keys: dict[str, tuple[float, float, float, float, str]] = {}
        for candidate in candidates:
            # Serialise each candidate once and reuse the result for both the
            # id and the ranking tuple.
            fields = candidate.to_dict() if hasattr(candidate, "to_dict") else candidate
            identifier = str(fields["intersection_id"])
            score, queue_total, wait_total, outgoing_load, tie_breaker = candidate_priority_tuple(fields)
            # Negate the numeric parts so an ascending sort puts the highest
            # priority first while ids still break ties in ascending order.
            sort_keys[identifier] = (-score, -queue_total, -wait_total, -outgoing_load, tie_breaker)

        # +inf keeps unknown targets behind every ranked candidate, whatever
        # the candidates' scores are.
        unranked = float("inf")
        normalized.sort(
            key=lambda target: sort_keys.get(target, (unranked, unranked, unranked, unranked, target))
        )

    if limit is not None:
        # Clamp so a negative limit means "nothing" instead of slicing from the end.
        normalized = normalized[: max(limit, 0)]
    return normalized
|
|
|
|
@dataclass
class CongestedIntersection:
    """One entry of a district's most-congested intersection list.

    Holds the intersection id, its queue / wait / outgoing-load totals, the
    current signal phase index and whether the intersection is flagged as a
    boundary one. Units of the numeric fields are not defined in this class.
    """

    intersection_id: str
    queue_total: float
    wait_total: float
    outgoing_load: float
    current_phase: int
    is_boundary: bool

    def to_dict(self) -> dict[str, Any]:
        """Serialise to a plain dict with rounded floats and coerced int/bool."""
        payload: dict[str, Any] = {"intersection_id": self.intersection_id}
        for name in ("queue_total", "wait_total", "outgoing_load"):
            payload[name] = _round_float(getattr(self, name))
        payload["current_phase"] = int(self.current_phase)
        payload["is_boundary"] = bool(self.is_boundary)
        return payload

    def to_prompt_line(self) -> str:
        """Render the entry as one compact, space-separated bullet line."""
        parts = [
            f"- {self.intersection_id}",
            f"q={self.queue_total:.2f}",
            f"w={self.wait_total:.2f}",
            f"out={self.outgoing_load:.2f}",
            f"phase={self.current_phase}",
            f"boundary={int(self.is_boundary)}",
        ]
        return " ".join(parts)
|
|
|
|
@dataclass
class CandidateIntersection:
    """An intersection that a district action may target.

    Carries the same core measurements as a congested-intersection entry plus
    boolean risk markers, a corridor alignment (one of ``DOMINANT_FLOWS``)
    and the list of reason tags explaining why it was selected.
    """

    intersection_id: str
    queue_total: float
    wait_total: float
    outgoing_load: float
    current_phase: int
    is_boundary: bool
    spillback_risk: bool = False
    incident_proximity: bool = False
    overload_marker: bool = False
    event_proximity: bool = False
    corridor_alignment: str = "BALANCED"
    selection_reasons: list[str] = field(default_factory=list)

    def validate(self) -> "CandidateIntersection":
        """Check the alignment and canonicalise ``selection_reasons`` in place.

        Returns:
            ``self``, so calls can be chained.

        Raises:
            ValueError: If ``corridor_alignment`` is not in ``DOMINANT_FLOWS``.
        """
        alignment = self.corridor_alignment
        if alignment in DOMINANT_FLOWS:
            # Replaces the list with the known tags in canonical order.
            self.selection_reasons = _stable_reason_list(self.selection_reasons)
            return self
        raise ValueError(
            f"Invalid corridor_alignment '{alignment}'. Expected one of {DOMINANT_FLOWS}."
        )

    def to_dict(self) -> dict[str, Any]:
        """Validate, then serialise to a plain dict with rounded floats."""
        self.validate()
        payload: dict[str, Any] = {"intersection_id": self.intersection_id}
        for name in ("queue_total", "wait_total", "outgoing_load"):
            payload[name] = _round_float(getattr(self, name))
        payload["current_phase"] = int(self.current_phase)
        for flag in (
            "is_boundary",
            "spillback_risk",
            "incident_proximity",
            "overload_marker",
            "event_proximity",
        ):
            payload[flag] = bool(getattr(self, flag))
        payload["corridor_alignment"] = self.corridor_alignment
        payload["selection_reasons"] = list(self.selection_reasons)
        return payload

    def to_prompt_line(self) -> str:
        """Validate, then render the candidate as one compact bullet line."""
        self.validate()
        if self.selection_reasons:
            reason_text = "|".join(self.selection_reasons)
        else:
            reason_text = "none"
        parts = [
            f"- {self.intersection_id}",
            f"q={self.queue_total:.2f}",
            f"w={self.wait_total:.2f}",
            f"out={self.outgoing_load:.2f}",
            f"phase={self.current_phase}",
            f"boundary={int(self.is_boundary)}",
            f"spillback={int(self.spillback_risk)}",
            f"incident={int(self.incident_proximity)}",
            f"overload={int(self.overload_marker)}",
            f"event={int(self.event_proximity)}",
            f"align={self.corridor_alignment}",
            f"reasons={reason_text}",
        ]
        return " ".join(parts)
|
|
|
|
@dataclass
class DistrictAction:
    """A district-level control decision and its (de)serialisation helpers.

    Attributes are checked by ``validate``: ``strategy`` must be one of
    ``DISTRICT_STRATEGIES``, ``priority_corridor`` one of
    ``PRIORITY_CORRIDORS`` or ``None``, ``phase_bias`` one of
    ``PHASE_BIASES`` and ``duration_steps`` an integer from 1 to 20.
    ``target_intersections`` is deduplicated and capped at 8 entries.
    """

    strategy: str = "hold"
    priority_corridor: str | None = None
    target_intersections: list[str] = field(default_factory=list)
    phase_bias: str = "NONE"
    duration_steps: int = 1

    def validate(self) -> "DistrictAction":
        """Check every field and normalise ``target_intersections`` in place.

        Returns:
            ``self``, so calls can be chained.

        Raises:
            ValueError: If any field holds a value outside its allowed set
                or range.
        """
        if self.strategy not in DISTRICT_STRATEGIES:
            raise ValueError(
                f"Invalid strategy '{self.strategy}'. Expected one of {DISTRICT_STRATEGIES}."
            )
        if self.priority_corridor is not None and self.priority_corridor not in PRIORITY_CORRIDORS:
            raise ValueError(
                f"Invalid priority_corridor '{self.priority_corridor}'. "
                f"Expected one of {PRIORITY_CORRIDORS} or None."
            )
        if self.phase_bias not in PHASE_BIASES:
            raise ValueError(
                f"Invalid phase_bias '{self.phase_bias}'. Expected one of {PHASE_BIASES}."
            )
        if not isinstance(self.duration_steps, int):
            raise ValueError("duration_steps must be an integer.")
        if not 1 <= self.duration_steps <= 20:
            raise ValueError("duration_steps must be between 1 and 20.")
        # Strip, deduplicate and keep at most 8 targets.
        self.target_intersections = _dedupe_string_list(self.target_intersections, limit=8)
        return self

    @classmethod
    def default_hold(cls, duration_steps: int = 1) -> "DistrictAction":
        """Build a neutral ``hold`` action.

        Args:
            duration_steps: Requested duration; clamped into the 1..20 range.

        Returns:
            A ``hold`` action with no corridor, no targets and no phase bias.
        """
        return cls(
            strategy="hold",
            priority_corridor=None,
            target_intersections=[],
            phase_bias="NONE",
            duration_steps=max(1, min(int(duration_steps), 20)),
        )

    @classmethod
    def from_dict(cls, payload: dict[str, Any]) -> "DistrictAction":
        """Build and validate an action from a mapping.

        Missing keys fall back to the field defaults. A
        ``target_intersections`` value of ``None`` is treated as an empty
        list, and a bare string is treated as a single id rather than being
        split into characters.

        Args:
            payload: Mapping with any of the action's field names as keys.

        Returns:
            The validated action.

        Raises:
            ValueError: If *payload* is not a mapping or a field fails
                validation.
        """
        from collections.abc import Mapping

        if not isinstance(payload, Mapping):
            # Decoded JSON may be a list, string or number; reject it with the
            # same exception type that validation uses.
            raise ValueError(
                f"DistrictAction payload must be a mapping, got {type(payload).__name__}."
            )

        targets = payload.get("target_intersections")
        if targets is None:
            targets = []
        elif isinstance(targets, str):
            # list("I_12") would yield individual characters.
            targets = [targets]

        return cls(
            strategy=str(payload.get("strategy", "hold")),
            priority_corridor=payload.get("priority_corridor"),
            target_intersections=list(targets),
            phase_bias=str(payload.get("phase_bias", "NONE")),
            duration_steps=int(payload.get("duration_steps", 1)),
        ).validate()

    @classmethod
    def from_json(cls, payload: str) -> "DistrictAction":
        """Parse a JSON document and build a validated action from it.

        Raises:
            ValueError: If the text is not valid JSON (``json.JSONDecodeError``
                is a ``ValueError``), does not decode to an object, or fails
                validation.
        """
        return cls.from_dict(json.loads(payload))

    def to_dict(self) -> dict[str, Any]:
        """Validate, then serialise to a plain dict keyed by field name."""
        self.validate()
        return {
            "strategy": self.strategy,
            "priority_corridor": self.priority_corridor,
            "target_intersections": list(self.target_intersections),
            "phase_bias": self.phase_bias,
            "duration_steps": int(self.duration_steps),
        }

    def to_json(self) -> str:
        """Serialise to compact JSON with sorted keys."""
        return json.dumps(self.to_dict(), sort_keys=True, separators=(",", ":"))

    def to_pretty_json(self) -> str:
        """Serialise to JSON with sorted keys and a two-space indent."""
        return json.dumps(self.to_dict(), sort_keys=True, indent=2)

    def to_rl_context(self) -> dict[str, Any]:
        """Serialise with ``strategy`` and ``duration_steps`` renamed.

        Returns:
            The ``to_dict`` payload where ``strategy`` becomes
            ``district_strategy`` and ``duration_steps`` becomes
            ``district_duration_steps``.
        """
        payload = self.to_dict()
        payload["district_strategy"] = payload.pop("strategy")
        payload["district_duration_steps"] = payload.pop("duration_steps")
        return payload
|
|
|
|
@dataclass
class DistrictStateSummary:
    """Aggregated state of one district at a single decision step.

    Bundles identifying metadata, queue / wait / outgoing-load statistics,
    change indicators, per-direction (NS / EW) figures, boundary totals,
    condition flags and two intersection lists. Units of the numeric fields
    are not defined in this class. ``validate`` keeps at most 5 congested
    entries and 8 candidates.
    """

    city_id: str
    district_id: str
    district_type: str
    scenario_name: str
    scenario_type: str
    decision_step: int
    sim_time: int
    intersection_count: int
    avg_queue: float
    max_queue: float
    total_queue: float
    avg_wait: float
    max_wait: float
    total_wait: float
    avg_outgoing_load: float
    max_outgoing_load: float
    total_outgoing_load: float
    recent_throughput: float
    queue_change: float
    wait_change: float
    throughput_change: float
    ns_queue: float
    ew_queue: float
    ns_wait: float
    ew_wait: float
    dominant_flow: str
    boundary_queue_total: float
    boundary_wait_total: float
    spillback_risk: bool
    incident_flag: bool
    construction_flag: bool
    overload_flag: bool
    event_flag: bool
    top_congested_intersections: list[CongestedIntersection] = field(default_factory=list)
    candidate_intersections: list[CandidateIntersection] = field(default_factory=list)

    def validate(self) -> "DistrictStateSummary":
        """Check ``dominant_flow`` and truncate both intersection lists in place.

        Returns:
            ``self``, so calls can be chained.

        Raises:
            ValueError: If ``dominant_flow`` is not in ``DOMINANT_FLOWS``.
        """
        flow = self.dominant_flow
        if flow not in DOMINANT_FLOWS:
            raise ValueError(
                f"Invalid dominant_flow '{flow}'. Expected one of {DOMINANT_FLOWS}."
            )
        congested_head = self.top_congested_intersections[:5]
        candidate_head = self.candidate_intersections[:8]
        self.top_congested_intersections = list(congested_head)
        self.candidate_intersections = list(candidate_head)
        return self

    def candidate_ids(self) -> list[str]:
        """Return the ids of the (validated) candidate intersections, in order."""
        self.validate()
        identifiers: list[str] = []
        for entry in self.candidate_intersections:
            identifiers.append(entry.intersection_id)
        return identifiers

    def candidate_lookup(self) -> dict[str, CandidateIntersection]:
        """Map each candidate id to its candidate; a later duplicate id wins."""
        self.validate()
        by_id: dict[str, CandidateIntersection] = {}
        for entry in self.candidate_intersections:
            by_id[entry.intersection_id] = entry
        return by_id

    def to_dict(self) -> dict[str, Any]:
        """Validate, then serialise to a plain dict in field-declaration order.

        Floats are rounded with ``_round_float``; counters are coerced with
        ``int`` and flags with ``bool``; nested intersections are serialised
        through their own ``to_dict``.
        """
        self.validate()
        payload: dict[str, Any] = {}

        for name in ("city_id", "district_id", "district_type", "scenario_name", "scenario_type"):
            payload[name] = getattr(self, name)

        for name in ("decision_step", "sim_time", "intersection_count"):
            payload[name] = int(getattr(self, name))

        for name in (
            "avg_queue",
            "max_queue",
            "total_queue",
            "avg_wait",
            "max_wait",
            "total_wait",
            "avg_outgoing_load",
            "max_outgoing_load",
            "total_outgoing_load",
            "recent_throughput",
            "queue_change",
            "wait_change",
            "throughput_change",
            "ns_queue",
            "ew_queue",
            "ns_wait",
            "ew_wait",
        ):
            payload[name] = _round_float(getattr(self, name))

        payload["dominant_flow"] = self.dominant_flow

        for name in ("boundary_queue_total", "boundary_wait_total"):
            payload[name] = _round_float(getattr(self, name))

        for name in (
            "spillback_risk",
            "incident_flag",
            "construction_flag",
            "overload_flag",
            "event_flag",
        ):
            payload[name] = bool(getattr(self, name))

        payload["top_congested_intersections"] = [
            entry.to_dict() for entry in self.top_congested_intersections
        ]
        payload["candidate_intersections"] = [
            entry.to_dict() for entry in self.candidate_intersections
        ]
        return payload

    def to_json(self) -> str:
        """Serialise to compact JSON with sorted keys."""
        payload = self.to_dict()
        return json.dumps(payload, sort_keys=True, separators=(",", ":"))

    def to_prompt_text(self) -> str:
        """Validate, then render the summary as newline-separated ``key: value`` text.

        Floats are shown with two decimals and flags as ``0``/``1``. Each
        intersection list follows its own header line; an empty list is
        rendered as a single ``- none`` line.
        """
        self.validate()
        congested_lines = [
            entry.to_prompt_line() for entry in self.top_congested_intersections
        ] or ["- none"]
        candidate_lines = [
            entry.to_prompt_line() for entry in self.candidate_intersections
        ] or ["- none"]

        lines: list[str] = []

        # (label, attribute) pairs; only scenario_name is printed under a shorter label.
        for label, attribute in (
            ("city_id", "city_id"),
            ("district_id", "district_id"),
            ("district_type", "district_type"),
            ("scenario", "scenario_name"),
            ("scenario_type", "scenario_type"),
            ("decision_step", "decision_step"),
            ("sim_time", "sim_time"),
            ("intersection_count", "intersection_count"),
        ):
            lines.append(f"{label}: {getattr(self, attribute)}")

        for name in (
            "avg_queue",
            "max_queue",
            "total_queue",
            "avg_wait",
            "max_wait",
            "total_wait",
            "avg_outgoing_load",
            "max_outgoing_load",
            "total_outgoing_load",
            "recent_throughput",
            "queue_change",
            "wait_change",
            "throughput_change",
            "ns_queue",
            "ew_queue",
            "ns_wait",
            "ew_wait",
        ):
            lines.append(f"{name}: {getattr(self, name):.2f}")

        lines.append(f"dominant_flow: {self.dominant_flow}")

        for name in ("boundary_queue_total", "boundary_wait_total"):
            lines.append(f"{name}: {getattr(self, name):.2f}")

        for name in (
            "spillback_risk",
            "incident_flag",
            "construction_flag",
            "overload_flag",
            "event_flag",
        ):
            lines.append(f"{name}: {int(getattr(self, name))}")

        lines.append("top_congested_intersections:")
        lines.extend(congested_lines)
        lines.append("candidate_intersections:")
        lines.extend(candidate_lines)
        return "\n".join(lines)
|
|