from __future__ import annotations from dataclasses import dataclass from pathlib import Path from typing import Any import numpy as np from district_llm.schema import CandidateIntersection, CongestedIntersection, DistrictStateSummary, candidate_priority_score from env.utils import load_json @dataclass class _SummaryContext: previous_summaries: dict[str, DistrictStateSummary] previous_finished_vehicles: int class DistrictStateSummaryBuilder: def __init__(self, top_k: int = 3, candidate_limit: int = 6): self.top_k = int(top_k) self.candidate_limit = int(candidate_limit) self._context = _SummaryContext(previous_summaries={}, previous_finished_vehicles=0) self._scenario_metadata: dict[str, Any] | None = None self._road_endpoints: dict[str, tuple[str, str]] | None = None self._incident_intersections: set[str] = set() def reset(self) -> None: self._context = _SummaryContext(previous_summaries={}, previous_finished_vehicles=0) self._scenario_metadata = None self._road_endpoints = None self._incident_intersections = set() def build_all(self, env, observation_batch: dict[str, Any]) -> dict[str, DistrictStateSummary]: if self._scenario_metadata is None: metadata_path = Path(env.scenario_dir) / "scenario_metadata.json" self._scenario_metadata = load_json(metadata_path) if metadata_path.exists() else {} self._road_endpoints = self._load_road_endpoints(Path(env.roadnet_path)) self._incident_intersections = self._derive_incident_intersections() lane_vehicle_count = env.adapter.get_lane_vehicle_count() finished_vehicles = int(env.adapter.get_finished_vehicle_count()) district_summaries: dict[str, DistrictStateSummary] = {} for district_id in env.districts: district_summaries[district_id] = self._build_single( env=env, observation_batch=observation_batch, lane_vehicle_count=lane_vehicle_count, district_id=district_id, finished_vehicles=finished_vehicles, ) self._context.previous_summaries = district_summaries self._context.previous_finished_vehicles = finished_vehicles return district_summaries def _build_single( self, env, observation_batch: dict[str, Any], lane_vehicle_count: dict[str, int], district_id: str, finished_vehicles: int, ) -> DistrictStateSummary: district = env.districts[district_id] scenario_metadata = self._scenario_metadata or {} intersection_ids = observation_batch["intersection_ids"] district_ids = observation_batch["district_ids"] incoming_counts = observation_batch["incoming_counts"] incoming_waiting = observation_batch["incoming_waiting"] current_phase = observation_batch["current_phase"] queue_totals: list[float] = [] wait_totals: list[float] = [] outgoing_loads: list[float] = [] ns_queue = 0.0 ew_queue = 0.0 ns_wait = 0.0 ew_wait = 0.0 boundary_queue_total = 0.0 boundary_wait_total = 0.0 congestion_items: list[CongestedIntersection] = [] candidate_seed_items: list[dict[str, Any]] = [] for index, intersection_id in enumerate(intersection_ids): if district_ids[index] != district_id: continue queue_total = float(np.asarray(incoming_counts[index], dtype=np.float32).sum()) wait_total = float(np.asarray(incoming_waiting[index], dtype=np.float32).sum()) outgoing_load = self._compute_outgoing_load( env=env, lane_vehicle_count=lane_vehicle_count, intersection_id=intersection_id, ) queue_totals.append(queue_total) wait_totals.append(wait_total) outgoing_loads.append(outgoing_load) midpoint = incoming_counts.shape[1] // 2 ns_queue_local = float(np.asarray(incoming_counts[index][:midpoint], dtype=np.float32).sum()) ew_queue_local = float(np.asarray(incoming_counts[index][midpoint:], dtype=np.float32).sum()) ns_wait_local = float(np.asarray(incoming_waiting[index][:midpoint], dtype=np.float32).sum()) ew_wait_local = float(np.asarray(incoming_waiting[index][midpoint:], dtype=np.float32).sum()) ns_queue += ns_queue_local ew_queue += ew_queue_local ns_wait += ns_wait_local ew_wait += ew_wait_local intersection_config = env.intersections[intersection_id] if intersection_config.is_boundary: boundary_queue_total += queue_total boundary_wait_total += wait_total congestion_items.append( CongestedIntersection( intersection_id=intersection_id, queue_total=queue_total, wait_total=wait_total, outgoing_load=outgoing_load, current_phase=int(current_phase[index]), is_boundary=bool(intersection_config.is_boundary), ) ) candidate_seed_items.append( { "intersection_id": intersection_id, "queue_total": queue_total, "wait_total": wait_total, "outgoing_load": outgoing_load, "current_phase": int(current_phase[index]), "is_boundary": bool(intersection_config.is_boundary), "spillback_risk": bool( outgoing_load >= max(6.0, queue_total * 0.6) or ( intersection_config.is_boundary and outgoing_load >= max(4.0, queue_total * 0.4) ) ), "incident_proximity": intersection_id in self._incident_intersections, "corridor_alignment": self._compute_corridor_alignment( ns_queue=ns_queue_local, ew_queue=ew_queue_local, ns_wait=ns_wait_local, ew_wait=ew_wait_local, ), } ) queue_array = np.asarray(queue_totals or [0.0], dtype=np.float32) wait_array = np.asarray(wait_totals or [0.0], dtype=np.float32) outgoing_array = np.asarray(outgoing_loads or [0.0], dtype=np.float32) previous_summary = self._context.previous_summaries.get(district_id) recent_throughput = float( finished_vehicles - self._context.previous_finished_vehicles if self._context.previous_finished_vehicles else 0.0 ) queue_change = 0.0 if previous_summary is None else float(queue_array.sum() - previous_summary.total_queue) wait_change = 0.0 if previous_summary is None else float(wait_array.sum() - previous_summary.total_wait) throughput_change = ( 0.0 if previous_summary is None else recent_throughput - previous_summary.recent_throughput ) directional_ns = ns_queue + 1.5 * ns_wait directional_ew = ew_queue + 1.5 * ew_wait if directional_ns > directional_ew * 1.1: dominant_flow = "NS" elif directional_ew > directional_ns * 1.1: dominant_flow = "EW" else: dominant_flow = "BALANCED" boundary_share = boundary_queue_total / max(1.0, float(queue_array.sum())) spillback_risk = bool( outgoing_array.max() >= max(8.0, queue_array.max() * 0.5) or (boundary_share >= 0.6 and queue_change >= 0.0) ) top_intersections = sorted( congestion_items, key=lambda item: (item.queue_total + 1.5 * item.wait_total + 0.5 * item.outgoing_load), reverse=True, )[: self.top_k] overload_flag = bool( scenario_metadata.get("overload_district") == district_id or (scenario_metadata.get("name") == "district_overload" and queue_array.sum() >= 25.0) ) event_flag = bool(scenario_metadata.get("event_district") == district_id) incident_flag = bool( scenario_metadata.get("name") in {"accident", "construction"} or bool(scenario_metadata.get("blocked_roads")) ) construction_flag = bool(scenario_metadata.get("name") == "construction") candidate_intersections = self._build_candidate_intersections( candidate_seed_items=candidate_seed_items, overload_flag=overload_flag, event_flag=event_flag, ) return DistrictStateSummary( city_id=env.city_id, district_id=district_id, district_type=district.district_type, scenario_name=env.scenario_name, scenario_type=str(scenario_metadata.get("intensity", env.scenario_name)), decision_step=int(observation_batch["decision_step"]), sim_time=int(observation_batch["sim_time"]), intersection_count=int(len(district.intersection_ids)), avg_queue=float(queue_array.mean()), max_queue=float(queue_array.max()), total_queue=float(queue_array.sum()), avg_wait=float(wait_array.mean()), max_wait=float(wait_array.max()), total_wait=float(wait_array.sum()), avg_outgoing_load=float(outgoing_array.mean()), max_outgoing_load=float(outgoing_array.max()), total_outgoing_load=float(outgoing_array.sum()), recent_throughput=recent_throughput, queue_change=queue_change, wait_change=wait_change, throughput_change=throughput_change, ns_queue=ns_queue, ew_queue=ew_queue, ns_wait=ns_wait, ew_wait=ew_wait, dominant_flow=dominant_flow, boundary_queue_total=boundary_queue_total, boundary_wait_total=boundary_wait_total, spillback_risk=spillback_risk, incident_flag=incident_flag, construction_flag=construction_flag, overload_flag=overload_flag, event_flag=event_flag, top_congested_intersections=top_intersections, candidate_intersections=candidate_intersections, ).validate() @staticmethod def _compute_outgoing_load(env, lane_vehicle_count: dict[str, int], intersection_id: str) -> float: intersection_config = env.intersections[intersection_id] if not intersection_config.outgoing_lanes: return 0.0 return float( sum(float(lane_vehicle_count.get(lane_id, 0)) for lane_id in intersection_config.outgoing_lanes) ) @staticmethod def _compute_corridor_alignment( ns_queue: float, ew_queue: float, ns_wait: float, ew_wait: float, ) -> str: ns_pressure = ns_queue + 1.5 * ns_wait ew_pressure = ew_queue + 1.5 * ew_wait if ns_pressure > ew_pressure * 1.1: return "NS" if ew_pressure > ns_pressure * 1.1: return "EW" return "BALANCED" @staticmethod def _load_road_endpoints(roadnet_path: Path) -> dict[str, tuple[str, str]]: roadnet = load_json(roadnet_path) return { str(road["id"]): ( str(road["startIntersection"]), str(road["endIntersection"]), ) for road in roadnet.get("roads", []) } def _derive_incident_intersections(self) -> set[str]: if not self._road_endpoints: return set() scenario_metadata = self._scenario_metadata or {} details = scenario_metadata.get("details", {}) incident_roads = list(scenario_metadata.get("blocked_roads", [])) incident_roads.extend(details.get("accident_roads", [])) incident_roads.extend(details.get("construction_roads", [])) if not incident_roads: incident_roads.extend(list((scenario_metadata.get("penalized_roads") or {}).keys())) intersections: set[str] = set() for road_id in incident_roads: endpoints = self._road_endpoints.get(str(road_id)) if endpoints is None: continue intersections.update(endpoints) return intersections def _build_candidate_intersections( self, candidate_seed_items: list[dict[str, Any]], overload_flag: bool, event_flag: bool, ) -> list[CandidateIntersection]: if not candidate_seed_items or self.candidate_limit <= 0: return [] def severity_key(item: dict[str, Any]) -> tuple[float, float, float, float, str]: candidate = CandidateIntersection( intersection_id=str(item["intersection_id"]), queue_total=float(item["queue_total"]), wait_total=float(item["wait_total"]), outgoing_load=float(item["outgoing_load"]), current_phase=int(item["current_phase"]), is_boundary=bool(item["is_boundary"]), spillback_risk=bool(item["spillback_risk"]), incident_proximity=bool(item["incident_proximity"]), overload_marker=overload_flag, event_proximity=event_flag, corridor_alignment=str(item["corridor_alignment"]), selection_reasons=[], ) return ( candidate_priority_score(candidate), float(item["queue_total"]), float(item["wait_total"]), float(item["outgoing_load"]), str(item["intersection_id"]), ) overall_sorted = sorted( candidate_seed_items, key=lambda item: ( -severity_key(item)[0], -severity_key(item)[1], -severity_key(item)[2], -severity_key(item)[3], severity_key(item)[4], ), ) boundary_sorted = [item for item in overall_sorted if item["is_boundary"]] spillback_sorted = [item for item in overall_sorted if item["spillback_risk"]] incident_sorted = [item for item in overall_sorted if item["incident_proximity"]] outgoing_sorted = sorted( candidate_seed_items, key=lambda item: ( -float(item["outgoing_load"]), -float(item["queue_total"]), -float(item["wait_total"]), str(item["intersection_id"]), ), ) reason_tags: dict[str, set[str]] = {} selected_ids: list[str] = [] def mark(items: list[dict[str, Any]], tag: str, limit: int) -> None: for item in items[:limit]: intersection_id = str(item["intersection_id"]) reason_tags.setdefault(intersection_id, set()).add(tag) if intersection_id not in selected_ids: selected_ids.append(intersection_id) mark(overall_sorted, "congested", max(1, min(self.top_k, self.candidate_limit))) mark(boundary_sorted, "boundary", min(2, self.candidate_limit)) mark(spillback_sorted, "spillback", min(2, self.candidate_limit)) mark(incident_sorted, "incident", min(2, self.candidate_limit)) mark(outgoing_sorted, "outgoing", min(2, self.candidate_limit)) if overload_flag: mark(overall_sorted, "overload", min(2, self.candidate_limit)) if event_flag: event_seed = boundary_sorted if boundary_sorted else outgoing_sorted mark(event_seed, "event", min(2, self.candidate_limit)) for item in overall_sorted: if len(selected_ids) >= self.candidate_limit: break intersection_id = str(item["intersection_id"]) if intersection_id in selected_ids: continue selected_ids.append(intersection_id) reason_tags.setdefault(intersection_id, {"congested"}) seed_lookup = { str(item["intersection_id"]): item for item in candidate_seed_items } candidates = [ CandidateIntersection( intersection_id=intersection_id, queue_total=float(seed_lookup[intersection_id]["queue_total"]), wait_total=float(seed_lookup[intersection_id]["wait_total"]), outgoing_load=float(seed_lookup[intersection_id]["outgoing_load"]), current_phase=int(seed_lookup[intersection_id]["current_phase"]), is_boundary=bool(seed_lookup[intersection_id]["is_boundary"]), spillback_risk=bool(seed_lookup[intersection_id]["spillback_risk"]), incident_proximity=bool(seed_lookup[intersection_id]["incident_proximity"]), overload_marker=overload_flag, event_proximity=event_flag, corridor_alignment=str(seed_lookup[intersection_id]["corridor_alignment"]), selection_reasons=sorted(reason_tags.get(intersection_id, {"congested"})), ).validate() for intersection_id in selected_ids[: self.candidate_limit] ] return sorted( candidates, key=lambda item: ( -candidate_priority_score(item), -item.queue_total, -item.wait_total, -item.outgoing_load, item.intersection_id, ), )