from __future__ import annotations

import json
import logging
from abc import ABC, abstractmethod
from typing import Any, Callable

from agents.message_protocol import DistrictDirective, parse_district_directive
|
|
|
|
class BaseDistrictCoordinator(ABC):
    """Abstract interface shared by every district coordinator.

    A coordinator receives a district summary mapping and answers with a
    directive expressed as a plain dict.
    """

    @abstractmethod
    def decide(self, district_summary: dict[str, Any]) -> dict[str, Any]:
        """Return the directive dict chosen for ``district_summary``.

        Concrete subclasses must override this; the base version only raises
        ``NotImplementedError``.
        """
        raise NotImplementedError()
|
|
|
|
class RuleBasedDistrictCoordinator(BaseDistrictCoordinator):
    """
    Fast, deterministic, and robust.
    Good first coordinator and good fallback if the LLM output fails.

    Rules are checked in order and the first match wins:

    1. ``emergency_vehicle.present`` is truthy -> ``emergency_route``.
    2. East-west load exceeds north-south by more than
       ``imbalance_threshold`` -> ``prioritize_ew``.
    3. North-south load exceeds east-west by more than
       ``imbalance_threshold`` -> ``prioritize_ns``.
    4. The largest border-pressure value is at least
       ``border_pressure_threshold`` -> ``damp_border_inflow``.
    5. Otherwise -> ``none``.

    Summary sections or values that are present but ``None`` (e.g. a JSON
    ``null``) are treated as missing instead of raising, because this class
    is meant to be the fallback that still answers when other inputs are
    imperfect.
    """

    def __init__(
        self,
        imbalance_threshold: float = 0.15,
        border_pressure_threshold: float = 0.65,
        default_duration: int = 2,
    ):
        """
        Args:
            imbalance_threshold: Minimum difference between the two corridor
                loads before one corridor is prioritized.
            border_pressure_threshold: Largest border-pressure value at or
                above which border pressure is considered high.
            default_duration: Duration given to corridor-priority directives.
        """
        self.imbalance_threshold = imbalance_threshold
        self.border_pressure_threshold = border_pressure_threshold
        self.default_duration = default_duration

    def decide(self, district_summary: dict[str, Any]) -> dict[str, Any]:
        """
        Choose a directive for one district.

        Args:
            district_summary: Mapping read through the keys ``district_id``,
                ``intersection_ids``, ``emergency_vehicle`` (``present``,
                ``route``, ``corridor``), ``corridor_loads`` (``ns`` or
                ``north_south``, ``ew`` or ``east_west``) and
                ``border_pressure`` (a mapping of numeric values).

        Returns:
            The validated directive, serialized with ``to_dict()``.
        """
        district_id = district_summary.get("district_id", "unknown")
        intersection_ids = district_summary.get("intersection_ids")
        if intersection_ids is None:
            intersection_ids = []

        # ``or {}`` also covers an explicit None; any other non-dict value
        # still fails loudly rather than silently hiding an emergency.
        emergency = district_summary.get("emergency_vehicle") or {}
        if emergency.get("present", False):
            route = emergency.get("route")
            return self._finalize(
                DistrictDirective(
                    mode="emergency_route",
                    # Without an explicit route, target every intersection
                    # listed for the district.
                    target_intersections=(
                        intersection_ids if route is None else route
                    ),
                    duration=2,
                    rationale=f"Emergency vehicle detected in district {district_id}.",
                    corridor=emergency.get("corridor"),
                    district_weight=1.0,
                )
            )

        corridor_loads = district_summary.get("corridor_loads") or {}
        ns = self._corridor_load(corridor_loads, "ns", "north_south")
        ew = self._corridor_load(corridor_loads, "ew", "east_west")
        border_max = self._max_border_pressure(
            district_summary.get("border_pressure")
        )

        # Corridor-priority directives get a higher weight when border
        # pressure is high as well.
        corridor_weight = (
            0.7 if border_max < self.border_pressure_threshold else 0.9
        )

        if ew - ns > self.imbalance_threshold:
            return self._finalize(
                DistrictDirective(
                    mode="prioritize_ew",
                    target_intersections=intersection_ids,
                    duration=self.default_duration,
                    rationale="East-west corridor is currently more congested than north-south.",
                    corridor="ew",
                    district_weight=corridor_weight,
                )
            )

        if ns - ew > self.imbalance_threshold:
            return self._finalize(
                DistrictDirective(
                    mode="prioritize_ns",
                    target_intersections=intersection_ids,
                    duration=self.default_duration,
                    rationale="North-south corridor is currently more congested than east-west.",
                    corridor="ns",
                    district_weight=corridor_weight,
                )
            )

        if border_max >= self.border_pressure_threshold:
            return self._finalize(
                DistrictDirective(
                    mode="damp_border_inflow",
                    target_intersections=intersection_ids,
                    duration=2,
                    rationale="Border pressure is high; reduce spill-in and smooth cross-district flow.",
                    district_weight=0.8,
                )
            )

        return self._finalize(
            DistrictDirective(
                mode="none",
                target_intersections=[],
                duration=1,
                rationale="District is reasonably balanced.",
                district_weight=0.5,
            )
        )

    @staticmethod
    def _finalize(directive: DistrictDirective) -> dict[str, Any]:
        """Validate ``directive`` and return its dict form."""
        return directive.validate().to_dict()

    @staticmethod
    def _corridor_load(loads: dict[str, Any], primary: str, alias: str) -> float:
        """Read a load under ``primary``, else ``alias``; 0.0 if both are missing or None."""
        value = loads.get(primary)
        if value is None:
            value = loads.get(alias)
        return 0.0 if value is None else float(value)

    @staticmethod
    def _max_border_pressure(border_pressure: Any) -> float:
        """Largest non-None value of a border-pressure dict, or 0.0 if there is none."""
        if not isinstance(border_pressure, dict):
            return 0.0
        values = [float(v) for v in border_pressure.values() if v is not None]
        return max(values) if values else 0.0
|
|
|
|
class LLMDistrictCoordinator(BaseDistrictCoordinator):
    """
    LLM-backed coordinator.

    `generator_fn` should accept a prompt string and return either:
    - a JSON string, or
    - a dict

    If calling the model, parsing its output, or converting the parsed
    directive to a dict raises, the failure is logged (with traceback) and
    the ``fallback`` coordinator's decision is returned instead.

    Example:
        coordinator = LLMDistrictCoordinator(generator_fn=my_model_call)
    """

    def __init__(
        self,
        generator_fn: Callable[[str], str | dict[str, Any]],
        fallback: BaseDistrictCoordinator | None = None,
        max_prompt_chars: int = 4000,
    ):
        """
        Args:
            generator_fn: Model call mapping a prompt to a JSON string or dict.
            fallback: Coordinator used when the model path fails; defaults to
                ``RuleBasedDistrictCoordinator()``.
            max_prompt_chars: Character cap applied to the serialized district
                summary embedded in the prompt (not to the whole prompt).
        """
        self.generator_fn = generator_fn
        # ``is None`` so a caller-supplied coordinator is never replaced just
        # because it happens to be falsy.
        self.fallback = (
            fallback if fallback is not None else RuleBasedDistrictCoordinator()
        )
        self.max_prompt_chars = max_prompt_chars

    def decide(self, district_summary: dict[str, Any]) -> dict[str, Any]:
        """
        Ask the model for a directive, falling back on any failure.

        Returns:
            The parsed directive as a dict, or ``self.fallback.decide(...)``
            when the model call, parsing, or ``to_dict()`` raised.
        """
        prompt = self.build_prompt(district_summary)
        try:
            raw = self.generator_fn(prompt)
            return parse_district_directive(raw).to_dict()
        except Exception:
            # Boundary with an external model: any failure means "use the
            # fallback", but record why instead of swallowing it silently.
            logging.getLogger(__name__).warning(
                "LLM district directive failed for district %r; using %s instead.",
                district_summary.get("district_id", "unknown"),
                type(self.fallback).__name__,
                exc_info=True,
            )
            return self.fallback.decide(district_summary)

    def build_prompt(self, district_summary: dict[str, Any]) -> str:
        """
        Render the instruction prompt for ``district_summary``.

        The summary is serialized as JSON, matching the JSON the model is asked
        to return; values JSON cannot encode are stringified, and if encoding
        still fails (e.g. non-string dict keys or a circular reference) the
        Python ``repr`` is used instead. Only this serialized summary is capped
        at ``max_prompt_chars`` characters (plus a truncation marker); the
        fixed instruction text is not counted.
        """
        try:
            summary_text = json.dumps(
                district_summary, default=str, ensure_ascii=False
            )
        except (TypeError, ValueError):
            summary_text = repr(district_summary)

        # Clamp so a negative setting cannot turn into a from-the-end slice.
        limit = max(self.max_prompt_chars, 0)
        if len(summary_text) > limit:
            summary_text = summary_text[:limit] + " ...[truncated]"

        return f"""You are a district-level traffic coordinator.

Your job is to choose a single strategic directive for the next few cycles.

Allowed modes:
- none
- prioritize_ns
- prioritize_ew
- green_wave
- emergency_route
- damp_border_inflow

Return ONLY valid JSON with these fields:
{{
"mode": string,
"target_intersections": list[string],
"duration": int,
"rationale": string,
"corridor": string or null,
"district_weight": float
}}

Guidelines:
- Use emergency_route if an emergency vehicle is present.
- Use prioritize_ns or prioritize_ew when one corridor is clearly more congested.
- Use damp_border_inflow when cross-district border pressure is high.
- Keep duration between 1 and 5.
- district_weight should be between 0.0 and 1.0.

District summary:
{summary_text}
"""
|
|