Spaces:
Sleeping
Sleeping
File size: 6,606 Bytes
5893134 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 | from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Callable
from agents.message_protocol import DistrictDirective, parse_district_directive
class BaseDistrictCoordinator(ABC):
    """Abstract interface for district-level coordinators.

    A coordinator turns a district summary (a plain dict) into a directive
    (also a plain dict). Subclasses must override :meth:`decide`; the class
    itself cannot be instantiated.
    """

    @abstractmethod
    def decide(self, district_summary: dict[str, Any]) -> dict[str, Any]:
        """Return a directive dict for ``district_summary``.

        Raises:
            NotImplementedError: if a subclass delegates here via ``super()``.
        """
        raise NotImplementedError()
class RuleBasedDistrictCoordinator(BaseDistrictCoordinator):
    """
    Fast, deterministic, and robust.
    Good first coordinator and good fallback if the LLM output fails.

    Decision priority (first matching rule wins):
      1. emergency vehicle present                    -> ``emergency_route``
      2. EW load - NS load > ``imbalance_threshold``  -> ``prioritize_ew``
      3. NS load - EW load > ``imbalance_threshold``  -> ``prioritize_ns``
      4. max border pressure >= ``border_pressure_threshold``
                                                      -> ``damp_border_inflow``
      5. otherwise                                    -> ``none``

    Because this class serves as the fallback when LLM output fails, summary
    fields that are missing *or* explicitly ``None`` are treated as absent
    rather than raising.
    """

    def __init__(
        self,
        imbalance_threshold: float = 0.15,
        border_pressure_threshold: float = 0.65,
        default_duration: int = 2,
    ):
        """
        Args:
            imbalance_threshold: minimum raw difference between the two
                corridor loads before one corridor is prioritized. Units are
                whatever the summary producer emits (NOTE(review): presumably
                normalized loads -- confirm).
            border_pressure_threshold: maximum border-pressure value at or
                above which the district counts as under cross-border pressure.
            default_duration: duration for the corridor-priority directives.
                Emergency and border-damping directives use a fixed duration of
                2, and the balanced/no-op directive uses 1.
        """
        self.imbalance_threshold = imbalance_threshold
        self.border_pressure_threshold = border_pressure_threshold
        self.default_duration = default_duration

    def decide(self, district_summary: dict[str, Any]) -> dict[str, Any]:
        """
        Choose one directive for a district.

        Summary keys read (all optional):
            district_id: used only in the emergency rationale text.
            intersection_ids: default target list for directives.
            emergency_vehicle: mapping with ``present``, ``route``, ``corridor``.
            corridor_loads: mapping with ``ns``/``north_south`` and
                ``ew``/``east_west`` load values.
            border_pressure: dict whose values are pressure levels; non-dict
                values are ignored.

        Returns:
            The dict produced by ``DistrictDirective(...).validate().to_dict()``.
        """
        district_id = district_summary.get("district_id", "unknown")
        # `or []` / `or {}` also cover keys that are present with a None value,
        # which a plain `.get(key, default)` would pass straight through.
        intersection_ids = district_summary.get("intersection_ids") or []
        emergency = district_summary.get("emergency_vehicle") or {}

        if emergency.get("present", False):
            route = emergency.get("route")
            return self._build(
                mode="emergency_route",
                # Fall back to the whole district when no route is supplied.
                target_intersections=intersection_ids if route is None else route,
                duration=2,
                rationale=f"Emergency vehicle detected in district {district_id}.",
                corridor=emergency.get("corridor"),
                district_weight=1.0,
            )

        corridor_loads = district_summary.get("corridor_loads") or {}
        ns = self._load(corridor_loads, "ns", "north_south")
        ew = self._load(corridor_loads, "ew", "east_west")

        border_pressure = district_summary.get("border_pressure")
        border_max = 0.0
        if isinstance(border_pressure, dict) and border_pressure:
            border_max = max(
                (float(v) for v in border_pressure.values() if v is not None),
                default=0.0,
            )

        if ew - ns > self.imbalance_threshold:
            return self._build(
                mode="prioritize_ew",
                target_intersections=intersection_ids,
                duration=self.default_duration,
                rationale="East-west corridor is currently more congested than north-south.",
                corridor="ew",
                district_weight=self._corridor_weight(border_max),
            )
        if ns - ew > self.imbalance_threshold:
            return self._build(
                mode="prioritize_ns",
                target_intersections=intersection_ids,
                duration=self.default_duration,
                rationale="North-south corridor is currently more congested than east-west.",
                corridor="ns",
                district_weight=self._corridor_weight(border_max),
            )
        if border_max >= self.border_pressure_threshold:
            return self._build(
                mode="damp_border_inflow",
                target_intersections=intersection_ids,
                duration=2,
                rationale="Border pressure is high; reduce spill-in and smooth cross-district flow.",
                district_weight=0.8,
            )
        return self._build(
            mode="none",
            target_intersections=[],
            duration=1,
            rationale="District is reasonably balanced.",
            district_weight=0.5,
        )

    @staticmethod
    def _load(loads: Any, short_key: str, long_key: str) -> float:
        """Read a corridor load by short key, then long key; missing/None -> 0.0."""
        value = loads.get(short_key)
        if value is None:
            value = loads.get(long_key)
        return 0.0 if value is None else float(value)

    def _corridor_weight(self, border_max: float) -> float:
        """Weight for corridor-priority directives, raised when border pressure is high."""
        return 0.7 if border_max < self.border_pressure_threshold else 0.9

    @staticmethod
    def _build(**fields: Any) -> dict[str, Any]:
        """Construct, validate, and serialize a DistrictDirective."""
        return DistrictDirective(**fields).validate().to_dict()
class LLMDistrictCoordinator(BaseDistrictCoordinator):
    """
    LLM-backed coordinator.

    `generator_fn` should accept a prompt string and return either:
      - a JSON string, or
      - a dict
    Its output is passed to `parse_district_directive`. If building the
    prompt, calling the model, or parsing its output raises, the failure is
    logged (with traceback) and `fallback` decides instead.

    Example:
        coordinator = LLMDistrictCoordinator(generator_fn=my_model_call)
    """

    def __init__(
        self,
        generator_fn: Callable[[str], str | dict[str, Any]],
        fallback: BaseDistrictCoordinator | None = None,
        max_prompt_chars: int = 4000,
    ):
        """
        Args:
            generator_fn: model call mapping a prompt to a JSON string or dict.
            fallback: coordinator used when the LLM path fails; defaults to a
                fresh RuleBasedDistrictCoordinator when None.
            max_prompt_chars: character cap applied to the rendered district
                summary only (the fixed instruction text is not counted).
        """
        self.generator_fn = generator_fn
        # Explicit None check: a caller-supplied fallback should never be
        # replaced just because the object happens to be falsy.
        self.fallback = (
            fallback if fallback is not None else RuleBasedDistrictCoordinator()
        )
        self.max_prompt_chars = max_prompt_chars

    def decide(self, district_summary: dict[str, Any]) -> dict[str, Any]:
        """
        Ask the LLM for a directive, falling back to `self.fallback` on failure.

        Returns:
            `parse_district_directive(raw).to_dict()` on success, otherwise
            `self.fallback.decide(district_summary)`.
        """
        try:
            prompt = self.build_prompt(district_summary)
            raw = self.generator_fn(prompt)
            # If the LLM returns a no-op too often or malformed content,
            # the parser still makes it safe. We keep that behavior.
            return parse_district_directive(raw).to_dict()
        except Exception:
            # Deliberate best-effort boundary: any model/transport/parse error
            # must still produce a directive. Log the cause instead of silently
            # discarding it so persistent LLM failures stay visible.
            import logging

            logging.getLogger(__name__).warning(
                "LLM district coordinator failed for district %r; using fallback.",
                district_summary.get("district_id", "unknown"),
                exc_info=True,
            )
            return self.fallback.decide(district_summary)

    def build_prompt(self, district_summary: dict[str, Any]) -> str:
        """
        Render the instruction prompt with the summary's `repr` embedded.

        The summary text is cut to `max_prompt_chars` characters (a negative
        cap is treated as 0) and marked with " ...[truncated]" when cut.
        """
        summary_text = repr(district_summary)
        limit = max(0, self.max_prompt_chars)
        if len(summary_text) > limit:
            summary_text = summary_text[:limit] + " ...[truncated]"
        return f"""You are a district-level traffic coordinator.
Your job is to choose a single strategic directive for the next few cycles.
Allowed modes:
- none
- prioritize_ns
- prioritize_ew
- green_wave
- emergency_route
- damp_border_inflow
Return ONLY valid JSON with these fields:
{{
"mode": string,
"target_intersections": list[string],
"duration": int,
"rationale": string,
"corridor": string or null,
"district_weight": float
}}
Guidelines:
- Use emergency_route if an emergency vehicle is present.
- Use prioritize_ns or prioritize_ew when one corridor is clearly more congested.
- Use damp_border_inflow when cross-district border pressure is high.
- Keep duration between 1 and 5.
- district_weight should be between 0.0 and 1.0.
District summary:
{summary_text}
"""
|