Spaces:
Running
Running
from __future__ import annotations

import logging
from abc import ABC, abstractmethod
from typing import Any, Callable

from agents.message_protocol import DistrictDirective, parse_district_directive
class BaseDistrictCoordinator(ABC):
    """Interface for district-level coordinators.

    A coordinator maps a district summary dict to a directive dict.
    Subclasses must implement :meth:`decide`; the base class itself cannot
    be instantiated.
    """

    @abstractmethod
    def decide(self, district_summary: dict[str, Any]) -> dict[str, Any]:
        """Return a directive dict for the given district summary.

        Args:
            district_summary: Per-district state. Keys read by the rule-based
                coordinator in this module include ``district_id``,
                ``intersection_ids``, ``emergency_vehicle``,
                ``corridor_loads`` and ``border_pressure``.

        Returns:
            A directive serialized as a dict.
        """
        raise NotImplementedError
class RuleBasedDistrictCoordinator(BaseDistrictCoordinator):
    """
    Fast, deterministic, and robust.

    Good first coordinator and good fallback if the LLM output fails.

    Rules are checked in priority order and the first match wins:

    1. An emergency vehicle is present -> ``emergency_route``.
    2. East-west load exceeds north-south load by more than
       ``imbalance_threshold`` -> ``prioritize_ew``.
    3. North-south load exceeds east-west load by more than
       ``imbalance_threshold`` -> ``prioritize_ns``.
    4. The highest border pressure is at or above
       ``border_pressure_threshold`` -> ``damp_border_inflow``.
    5. Otherwise -> ``none``.

    Missing, ``None`` or non-dict summary fields are treated as empty so
    that this coordinator does not itself fail when used as a fallback.
    """

    def __init__(
        self,
        imbalance_threshold: float = 0.15,
        border_pressure_threshold: float = 0.65,
        default_duration: int = 2,
    ):
        """
        Args:
            imbalance_threshold: Minimum difference between the two corridor
                loads before one corridor is prioritized.
            border_pressure_threshold: Border pressure at or above which the
                district is treated as under cross-district pressure.
            default_duration: Duration assigned to corridor-priority
                directives.
        """
        self.imbalance_threshold = imbalance_threshold
        self.border_pressure_threshold = border_pressure_threshold
        self.default_duration = default_duration

    @staticmethod
    def _as_dict(value: Any) -> dict[str, Any]:
        """Return ``value`` if it is a dict, otherwise an empty dict."""
        return value if isinstance(value, dict) else {}

    @staticmethod
    def _corridor_load(loads: dict[str, Any], short_key: str, long_key: str) -> float:
        """Read a corridor load under its short or long key; missing/None -> 0.0."""
        value = loads.get(short_key)
        if value is None:
            value = loads.get(long_key)
        return 0.0 if value is None else float(value)

    @staticmethod
    def _emit(**fields: Any) -> dict[str, Any]:
        """Build, validate and serialize a DistrictDirective from ``fields``."""
        return DistrictDirective(**fields).validate().to_dict()

    def decide(self, district_summary: dict[str, Any]) -> dict[str, Any]:
        """Pick a directive for the district using the fixed rule order.

        Args:
            district_summary: District state; reads ``district_id``,
                ``intersection_ids``, ``emergency_vehicle``,
                ``corridor_loads`` (``ns``/``north_south``,
                ``ew``/``east_west``) and ``border_pressure``.

        Returns:
            The validated directive serialized as a dict.
        """
        district_id = district_summary.get("district_id", "unknown")
        intersection_ids = district_summary.get("intersection_ids") or []

        # Rule 1: emergency vehicles override every other consideration.
        emergency = self._as_dict(district_summary.get("emergency_vehicle"))
        if emergency.get("present", False):
            route = emergency.get("route")
            return self._emit(
                mode="emergency_route",
                target_intersections=intersection_ids if route is None else route,
                duration=2,
                rationale=f"Emergency vehicle detected in district {district_id}.",
                corridor=emergency.get("corridor"),
                district_weight=1.0,
            )

        loads = self._as_dict(district_summary.get("corridor_loads"))
        ns = self._corridor_load(loads, "ns", "north_south")
        ew = self._corridor_load(loads, "ew", "east_west")

        pressures = self._as_dict(district_summary.get("border_pressure"))
        border_max = max(
            (float(v) for v in pressures.values() if v is not None),
            default=0.0,
        )
        # Corridor directives carry more weight when border pressure is high.
        corridor_weight = 0.7 if border_max < self.border_pressure_threshold else 0.9

        # Rules 2 and 3: favour the clearly more congested corridor.
        if ew - ns > self.imbalance_threshold:
            return self._emit(
                mode="prioritize_ew",
                target_intersections=intersection_ids,
                duration=self.default_duration,
                rationale="East-west corridor is currently more congested than north-south.",
                corridor="ew",
                district_weight=corridor_weight,
            )
        if ns - ew > self.imbalance_threshold:
            return self._emit(
                mode="prioritize_ns",
                target_intersections=intersection_ids,
                duration=self.default_duration,
                rationale="North-south corridor is currently more congested than east-west.",
                corridor="ns",
                district_weight=corridor_weight,
            )

        # Rule 4: balanced corridors but heavy pressure at the district edge.
        if border_max >= self.border_pressure_threshold:
            return self._emit(
                mode="damp_border_inflow",
                target_intersections=intersection_ids,
                duration=2,
                rationale="Border pressure is high; reduce spill-in and smooth cross-district flow.",
                district_weight=0.8,
            )

        # Rule 5: nothing stands out.
        return self._emit(
            mode="none",
            target_intersections=[],
            duration=1,
            rationale="District is reasonably balanced.",
            district_weight=0.5,
        )
class LLMDistrictCoordinator(BaseDistrictCoordinator):
    """
    LLM-backed coordinator.

    `generator_fn` should accept a prompt string and return either:
      - a JSON string, or
      - a dict

    If generating or parsing the response raises, the failure is logged and
    the decision is delegated to `fallback` (a `RuleBasedDistrictCoordinator`
    by default).

    Example:
        coordinator = LLMDistrictCoordinator(generator_fn=my_model_call)
    """

    _logger = logging.getLogger(__name__)

    def __init__(
        self,
        generator_fn: Callable[[str], str | dict[str, Any]],
        fallback: BaseDistrictCoordinator | None = None,
        max_prompt_chars: int = 4000,
    ):
        """
        Args:
            generator_fn: Model call mapping a prompt to a JSON string or dict.
            fallback: Coordinator used when the LLM path raises. Defaults to a
                new `RuleBasedDistrictCoordinator`.
            max_prompt_chars: Maximum number of characters of the rendered
                district summary embedded in the prompt.
        """
        self.generator_fn = generator_fn
        # Compare against None explicitly: `fallback or ...` would silently
        # replace a fallback object that happens to be falsy.
        self.fallback = (
            fallback if fallback is not None else RuleBasedDistrictCoordinator()
        )
        self.max_prompt_chars = max_prompt_chars

    def decide(self, district_summary: dict[str, Any]) -> dict[str, Any]:
        """Ask the LLM for a directive, falling back on any failure.

        Returns:
            ``parse_district_directive(raw).to_dict()`` for the model output,
            or ``self.fallback.decide(district_summary)`` if generation or
            parsing raised.
        """
        prompt = self.build_prompt(district_summary)
        try:
            raw = self.generator_fn(prompt)
            # NOTE(review): relies on parse_district_directive to validate and
            # normalize model output (malformed content, no-op modes);
            # confirm in agents.message_protocol.
            return parse_district_directive(raw).to_dict()
        except Exception:
            # Best-effort by design (model/network/parse errors of any type),
            # but record why the fallback was used instead of hiding it.
            self._logger.warning(
                "LLM district coordinator failed for district %s; using fallback.",
                district_summary.get("district_id", "unknown"),
                exc_info=True,
            )
            return self.fallback.decide(district_summary)

    def build_prompt(self, district_summary: dict[str, Any]) -> str:
        """Render the instruction prompt passed to `generator_fn`.

        The summary is rendered with ``repr`` and, if longer than
        ``max_prompt_chars`` (clamped at zero), cut to that length and
        suffixed with " ...[truncated]" to bound prompt size.
        """
        limit = max(self.max_prompt_chars, 0)
        summary_text = repr(district_summary)
        if len(summary_text) > limit:
            summary_text = summary_text[:limit] + " ...[truncated]"
        return f"""You are a district-level traffic coordinator.
Your job is to choose a single strategic directive for the next few cycles.
Allowed modes:
- none
- prioritize_ns
- prioritize_ew
- green_wave
- emergency_route
- damp_border_inflow
Return ONLY valid JSON with these fields:
{{
"mode": string,
"target_intersections": list[string],
"duration": int,
"rationale": string,
"corridor": string or null,
"district_weight": float
}}
Guidelines:
- Use emergency_route if an emergency vehicle is present.
- Use prioritize_ns or prioritize_ew when one corridor is clearly more congested.
- Use damp_border_inflow when cross-district border pressure is high.
- Keep duration between 1 and 5.
- district_weight should be between 0.0 and 1.0.
District summary:
{summary_text}
"""