# traffic-visualizer/agents/district_controller.py
# Commit 5893134 (verified) by tokev: "Add files using upload-large-folder tool"
from __future__ import annotations

import json
import logging
from abc import ABC, abstractmethod
from typing import Any, Callable

from agents.message_protocol import DistrictDirective, parse_district_directive
class BaseDistrictCoordinator(ABC):
    """Abstract interface shared by every district-level coordinator.

    A coordinator turns a district summary dict into a directive dict.
    This class cannot be instantiated; subclasses must implement ``decide``.
    """

    @abstractmethod
    def decide(self, district_summary: dict[str, Any]) -> dict[str, Any]:
        """Choose a directive for the district described by ``district_summary``.

        Abstract: concrete subclasses override this. Calling the base
        implementation (e.g. through ``super()``) always raises.
        """
        raise NotImplementedError()
class RuleBasedDistrictCoordinator(BaseDistrictCoordinator):
    """
    Fast, deterministic, and robust.
    Good first coordinator and good fallback if the LLM output fails.

    Decision order (first match wins):
      1. ``emergency_vehicle.present`` is truthy  -> ``emergency_route``
      2. east-west load exceeds north-south by more than
         ``imbalance_threshold``                 -> ``prioritize_ew``
      3. north-south load exceeds east-west by more than
         ``imbalance_threshold``                 -> ``prioritize_ns``
      4. largest ``border_pressure`` value is at least
         ``border_pressure_threshold``           -> ``damp_border_inflow``
      5. otherwise                               -> ``none``

    Summary keys that are missing and keys explicitly set to ``None`` (e.g.
    JSON ``null``) are treated the same, so a sparse summary never crashes
    the coordinator that is meant to be the safe fallback.
    """

    def __init__(
        self,
        imbalance_threshold: float = 0.15,
        border_pressure_threshold: float = 0.65,
        default_duration: int = 2,
    ):
        """
        Args:
            imbalance_threshold: Minimum difference between the two corridor
                loads before one corridor is prioritized.
            border_pressure_threshold: Border-pressure value at or above which
                the district counts as under cross-border pressure.
            default_duration: Duration used for the corridor-priority directives.
        """
        self.imbalance_threshold = imbalance_threshold
        self.border_pressure_threshold = border_pressure_threshold
        self.default_duration = default_duration

    def decide(self, district_summary: dict[str, Any]) -> dict[str, Any]:
        """Return a validated directive dict for ``district_summary``.

        Keys read: ``district_id``, ``intersection_ids``, ``emergency_vehicle``
        (``present``, ``route``, ``corridor``), ``corridor_loads`` (``ns`` /
        ``north_south``, ``ew`` / ``east_west``) and ``border_pressure``
        (a dict whose values are compared against the threshold).
        """
        district_id = district_summary.get("district_id", "unknown")
        intersection_ids = district_summary.get("intersection_ids") or []

        # `or {}` so an explicit None does not blow up on `.get` below.
        emergency = district_summary.get("emergency_vehicle") or {}
        if emergency.get("present", False):
            route = emergency.get("route")
            return self._directive(
                mode="emergency_route",
                # Fall back to the whole district only when no route is given.
                target_intersections=intersection_ids if route is None else route,
                duration=2,
                rationale=f"Emergency vehicle detected in district {district_id}.",
                corridor=emergency.get("corridor"),
                district_weight=1.0,
            )

        corridor_loads = district_summary.get("corridor_loads") or {}
        ns = self._corridor_load(corridor_loads, "ns", "north_south")
        ew = self._corridor_load(corridor_loads, "ew", "east_west")
        border_max = self._border_max(district_summary.get("border_pressure", {}))

        if ew - ns > self.imbalance_threshold:
            return self._directive(
                mode="prioritize_ew",
                target_intersections=intersection_ids,
                duration=self.default_duration,
                rationale="East-west corridor is currently more congested than north-south.",
                corridor="ew",
                district_weight=self._corridor_weight(border_max),
            )
        if ns - ew > self.imbalance_threshold:
            return self._directive(
                mode="prioritize_ns",
                target_intersections=intersection_ids,
                duration=self.default_duration,
                rationale="North-south corridor is currently more congested than east-west.",
                corridor="ns",
                district_weight=self._corridor_weight(border_max),
            )
        if border_max >= self.border_pressure_threshold:
            return self._directive(
                mode="damp_border_inflow",
                target_intersections=intersection_ids,
                duration=2,
                rationale="Border pressure is high; reduce spill-in and smooth cross-district flow.",
                district_weight=0.8,
            )
        return self._directive(
            mode="none",
            target_intersections=[],
            duration=1,
            rationale="District is reasonably balanced.",
            district_weight=0.5,
        )

    def _corridor_weight(self, border_max: float) -> float:
        """Weight for a corridor-priority directive: higher under border pressure."""
        return 0.7 if border_max < self.border_pressure_threshold else 0.9

    @staticmethod
    def _corridor_load(loads: dict[str, Any], short_key: str, long_key: str) -> float:
        """Return the load under ``short_key``, else ``long_key``, else 0.0 (None = absent)."""
        value = loads.get(short_key)
        if value is None:
            value = loads.get(long_key)
        return 0.0 if value is None else float(value)

    @staticmethod
    def _border_max(border_pressure: Any) -> float:
        """Largest non-None border-pressure value; 0.0 if not a dict or nothing usable."""
        if not isinstance(border_pressure, dict):
            return 0.0
        values = [float(v) for v in border_pressure.values() if v is not None]
        return max(values, default=0.0)

    @staticmethod
    def _directive(**fields: Any) -> dict[str, Any]:
        """Build a ``DistrictDirective`` from ``fields``, validate it, and return its dict form."""
        return DistrictDirective(**fields).validate().to_dict()
class LLMDistrictCoordinator(BaseDistrictCoordinator):
    """
    LLM-backed coordinator.

    `generator_fn` should accept a prompt string and return either:
    - a JSON string, or
    - a dict

    Any exception raised while building the prompt, calling `generator_fn`,
    or parsing its output is logged (WARNING, with traceback) and the
    decision is delegated to `fallback` (a `RuleBasedDistrictCoordinator`
    unless another coordinator is supplied).

    Example:
        coordinator = LLMDistrictCoordinator(generator_fn=my_model_call)
    """

    def __init__(
        self,
        generator_fn: Callable[[str], str | dict[str, Any]],
        fallback: BaseDistrictCoordinator | None = None,
        max_prompt_chars: int = 4000,
    ):
        """
        Args:
            generator_fn: Model call taking the prompt and returning a JSON
                string or a dict.
            fallback: Coordinator used whenever the LLM path fails.
            max_prompt_chars: Maximum length of the serialized district summary
                embedded in the prompt. The fixed instruction text is not
                counted, so the full prompt is longer than this.
        """
        self.generator_fn = generator_fn
        self.fallback = fallback if fallback is not None else RuleBasedDistrictCoordinator()
        self.max_prompt_chars = max_prompt_chars

    def decide(self, district_summary: dict[str, Any]) -> dict[str, Any]:
        """Ask the LLM for a directive; use the fallback coordinator on any failure."""
        try:
            prompt = self.build_prompt(district_summary)
            raw = self.generator_fn(prompt)
            # parse_district_directive (agents.message_protocol) is relied on to
            # make no-op or malformed model output safe; its result is used as-is.
            return parse_district_directive(raw).to_dict()
        except Exception:
            # Deliberate catch-all: a misbehaving model or generator must never
            # stop coordination. Log it so failures are visible, then fall back.
            logging.getLogger(__name__).warning(
                "LLM district coordination failed; using fallback coordinator.",
                exc_info=True,
            )
            return self.fallback.decide(district_summary)

    def build_prompt(self, district_summary: dict[str, Any]) -> str:
        """Render the coordinator prompt for ``district_summary``.

        The summary is serialized as JSON, to match the JSON the model is asked
        to return. Values that are not JSON-serializable are stringified. If
        the summary cannot be JSON-encoded at all (for example unsupported key
        types or circular references), ``repr`` is used instead. Only the
        serialized summary is truncated to ``max_prompt_chars``; a negative
        limit is treated as 0.
        """
        try:
            summary_text = json.dumps(district_summary, default=str, ensure_ascii=False)
        except (TypeError, ValueError):
            summary_text = repr(district_summary)
        limit = max(0, self.max_prompt_chars)
        if len(summary_text) > limit:
            summary_text = summary_text[:limit] + " ...[truncated]"
        return f"""You are a district-level traffic coordinator.
Your job is to choose a single strategic directive for the next few cycles.
Allowed modes:
- none
- prioritize_ns
- prioritize_ew
- green_wave
- emergency_route
- damp_border_inflow
Return ONLY valid JSON with these fields:
{{
"mode": string,
"target_intersections": list[string],
"duration": int,
"rationale": string,
"corridor": string or null,
"district_weight": float
}}
Guidelines:
- Use emergency_route if an emergency vehicle is present.
- Use prioritize_ns or prioritize_ew when one corridor is clearly more congested.
- Use damp_border_inflow when cross-district border pressure is high.
- Keep duration between 1 and 5.
- district_weight should be between 0.0 and 1.0.
District summary:
{summary_text}
"""