File size: 6,606 Bytes
5893134
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, Callable

from agents.message_protocol import DistrictDirective, parse_district_directive


class BaseDistrictCoordinator(ABC):
    """Abstract interface shared by all district-level coordinators.

    A concrete coordinator implements :meth:`decide`, which receives a
    district summary dict and returns a directive dict.
    """

    @abstractmethod
    def decide(self, district_summary: dict[str, Any]) -> dict[str, Any]:
        """Choose a directive for the district described by ``district_summary``.

        Subclasses must override this method; the base implementation always
        raises ``NotImplementedError``.
        """
        raise NotImplementedError


class RuleBasedDistrictCoordinator(BaseDistrictCoordinator):
    """Deterministic, threshold-driven district coordinator.

    Fast and predictable: a good first coordinator and a good fallback when
    LLM output cannot be used.

    Rules are evaluated in this order and the first match wins:

    1. An emergency vehicle is present -> ``emergency_route``.
    2. East-west load exceeds north-south load by more than
       ``imbalance_threshold`` -> ``prioritize_ew``.
    3. North-south load exceeds east-west load by more than
       ``imbalance_threshold`` -> ``prioritize_ns``.
    4. The largest border pressure is at least
       ``border_pressure_threshold`` -> ``damp_border_inflow``.
    5. Otherwise -> ``none``.
    """

    def __init__(
        self,
        imbalance_threshold: float = 0.15,
        border_pressure_threshold: float = 0.65,
        default_duration: int = 2,
    ) -> None:
        """Store the rule thresholds.

        Args:
            imbalance_threshold: Minimum difference between the two corridor
                loads before one corridor is prioritized.
            border_pressure_threshold: Border pressure at or above which
                inflow damping applies and corridor directives receive the
                higher district weight.
            default_duration: Duration used for corridor-priority directives.
        """
        self.imbalance_threshold = imbalance_threshold
        self.border_pressure_threshold = border_pressure_threshold
        self.default_duration = default_duration

    def decide(self, district_summary: dict[str, Any]) -> dict[str, Any]:
        """Return a directive dict for ``district_summary``.

        Keys read from the summary (all optional): ``district_id``,
        ``intersection_ids``, ``emergency_vehicle`` (``present``, ``route``,
        ``corridor``), ``corridor_loads`` (``ns``/``north_south`` and
        ``ew``/``east_west``) and ``border_pressure``.

        A key whose value is ``None`` is treated exactly like a missing key,
        so a partially populated summary does not raise.

        Returns:
            The ``to_dict()`` form of a validated ``DistrictDirective``.
        """
        district_id = district_summary.get("district_id", "unknown")
        intersection_ids = self._get_or(district_summary, "intersection_ids", [])

        emergency = self._get_or(district_summary, "emergency_vehicle", {})
        if emergency.get("present", False):
            # Emergency routing uses a fixed duration and full weight,
            # independent of the configured defaults.
            return self._emit(
                mode="emergency_route",
                target_intersections=self._get_or(
                    emergency, "route", intersection_ids
                ),
                duration=2,
                rationale=f"Emergency vehicle detected in district {district_id}.",
                corridor=emergency.get("corridor"),
                district_weight=1.0,
            )

        corridor_loads = self._get_or(district_summary, "corridor_loads", {})
        ns = self._corridor_load(corridor_loads, "ns", "north_south")
        ew = self._corridor_load(corridor_loads, "ew", "east_west")

        border_pressure = district_summary.get("border_pressure")
        border_max = 0.0
        if isinstance(border_pressure, dict):
            # Skip None entries; an empty (or all-None) mapping yields 0.0.
            border_max = max(
                (float(v) for v in border_pressure.values() if v is not None),
                default=0.0,
            )

        # Corridor directives carry more weight while the border is under pressure.
        corridor_weight = 0.7 if border_max < self.border_pressure_threshold else 0.9

        if ew - ns > self.imbalance_threshold:
            return self._emit(
                mode="prioritize_ew",
                target_intersections=intersection_ids,
                duration=self.default_duration,
                rationale="East-west corridor is currently more congested than north-south.",
                corridor="ew",
                district_weight=corridor_weight,
            )

        if ns - ew > self.imbalance_threshold:
            return self._emit(
                mode="prioritize_ns",
                target_intersections=intersection_ids,
                duration=self.default_duration,
                rationale="North-south corridor is currently more congested than east-west.",
                corridor="ns",
                district_weight=corridor_weight,
            )

        if border_max >= self.border_pressure_threshold:
            return self._emit(
                mode="damp_border_inflow",
                target_intersections=intersection_ids,
                duration=2,
                rationale="Border pressure is high; reduce spill-in and smooth cross-district flow.",
                district_weight=0.8,
            )

        return self._emit(
            mode="none",
            target_intersections=[],
            duration=1,
            rationale="District is reasonably balanced.",
            district_weight=0.5,
        )

    @staticmethod
    def _get_or(mapping: Any, key: str, default: Any) -> Any:
        """Return ``mapping.get(key)``, or ``default`` when it is missing or None."""
        value = mapping.get(key)
        return default if value is None else value

    @classmethod
    def _corridor_load(cls, loads: Any, short_key: str, long_key: str) -> float:
        """Read a corridor load under either key spelling; default to 0.0."""
        fallback = cls._get_or(loads, long_key, 0.0)
        return float(cls._get_or(loads, short_key, fallback))

    @staticmethod
    def _emit(**fields: Any) -> dict[str, Any]:
        """Build a ``DistrictDirective`` from ``fields``, validate it, return its dict."""
        return DistrictDirective(**fields).validate().to_dict()


class LLMDistrictCoordinator(BaseDistrictCoordinator):
    """LLM-backed coordinator with a fallback coordinator.

    ``generator_fn`` should accept a prompt string and return either:

      - a JSON string, or
      - a dict

    The returned value is passed to ``parse_district_directive``. If the
    generator or the parser raises, the failure is logged and the decision is
    delegated to ``fallback``.

    Example:
        coordinator = LLMDistrictCoordinator(generator_fn=my_model_call)
    """

    def __init__(
        self,
        generator_fn: Callable[[str], str | dict[str, Any]],
        fallback: BaseDistrictCoordinator | None = None,
        max_prompt_chars: int = 4000,
    ) -> None:
        """Configure the coordinator.

        Args:
            generator_fn: Callable that turns a prompt into raw model output.
            fallback: Coordinator used when generation or parsing fails; a
                ``RuleBasedDistrictCoordinator`` is created when omitted.
            max_prompt_chars: Maximum number of characters of the summary's
                ``repr`` embedded in the prompt (the fixed instructions are
                not counted against this limit).
        """
        self.generator_fn = generator_fn
        # Compare against None explicitly so a caller-supplied coordinator
        # that happens to evaluate falsy is still honored.
        self.fallback = (
            RuleBasedDistrictCoordinator() if fallback is None else fallback
        )
        self.max_prompt_chars = max_prompt_chars

    def decide(self, district_summary: dict[str, Any]) -> dict[str, Any]:
        """Ask the generator for a directive; fall back if that fails.

        Returns:
            The ``to_dict()`` form of the parsed directive, or whatever the
            fallback coordinator returns when the generator or parser raises.
        """
        prompt = self.build_prompt(district_summary)
        try:
            raw = self.generator_fn(prompt)
            # The parser's result is returned unchanged, including no-op
            # directives; the parser is relied on to sanitize the raw output.
            return parse_district_directive(raw).to_dict()
        except Exception:
            # Best-effort boundary: never propagate generator/parser failures,
            # but record them so repeated fallbacks are visible in the logs.
            import logging

            logging.getLogger(__name__).warning(
                "LLM directive generation failed; using fallback coordinator.",
                exc_info=True,
            )
            return self.fallback.decide(district_summary)

    def build_prompt(self, district_summary: dict[str, Any]) -> str:
        """Render the instruction prompt with the summary's ``repr`` embedded.

        The summary text is cut to ``max_prompt_chars`` characters and marked
        with `` ...[truncated]`` when it is longer than that.
        """
        summary_text = repr(district_summary)
        if len(summary_text) > self.max_prompt_chars:
            summary_text = summary_text[: self.max_prompt_chars] + " ...[truncated]"

        return f"""You are a district-level traffic coordinator.



Your job is to choose a single strategic directive for the next few cycles.



Allowed modes:

- none

- prioritize_ns

- prioritize_ew

- green_wave

- emergency_route

- damp_border_inflow



Return ONLY valid JSON with these fields:

{{

  "mode": string,

  "target_intersections": list[string],

  "duration": int,

  "rationale": string,

  "corridor": string or null,

  "district_weight": float

}}



Guidelines:

- Use emergency_route if an emergency vehicle is present.

- Use prioritize_ns or prioritize_ew when one corridor is clearly more congested.

- Use damp_border_inflow when cross-district border pressure is high.

- Keep duration between 1 and 5.

- district_weight should be between 0.0 and 1.0.



District summary:

{summary_text}

"""