# tokev's picture
# Add files using upload-large-folder tool
# 5893134 verified
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
from district_llm.repair import fallback_target_intersections
from district_llm.schema import DistrictAction, DistrictStateSummary
@dataclass
class LocalIntersectionAction:
    """One local-controller decision at a single intersection.

    Records what the controller chose (``action``), the phase before and after
    the decision, and load measurements for that intersection.
    """

    # Identifier of the intersection this record describes.
    intersection_id: str
    # Identifier of the district the intersection belongs to.
    district_id: str
    # Raw controller action; the value 1 is treated as a switch request (see `switched`).
    action: int
    # Phase active before the action was applied.
    current_phase: int
    # Phase active after the action was applied.
    next_phase: int
    # Queue measurement at the intersection (units not established here).
    queue_total: float
    # Waiting-time measurement at the intersection (units not established here).
    wait_total: float
    # Outgoing load measurement at the intersection.
    outgoing_load: float
    # Whether the intersection lies on the district boundary.
    is_boundary: bool

    @property
    def switched(self) -> bool:
        """Return True when the action is 1 AND the phase actually changed."""
        if int(self.action) != 1:
            return False
        return self.next_phase != self.current_phase
@dataclass
class DistrictWindowData:
    """Start/end district summaries plus the local actions recorded in between."""

    # District the window belongs to.
    district_id: str
    # Summary at the start of the window.
    start_summary: DistrictStateSummary
    # Summary at the end of the window.
    end_summary: DistrictStateSummary
    # Local-controller decisions observed during the window (fresh list per instance).
    controller_actions: list[LocalIntersectionAction] = field(default_factory=list)
    # Number of steps the window covers.
    step_count: int = 0

    def to_dict(self) -> dict[str, Any]:
        """Return a compact dict of the window's end-minus-start changes.

        Queue, wait and throughput deltas are rounded to 3 decimals;
        ``controller_actions`` is not included.
        """
        before = self.start_summary
        after = self.end_summary
        payload: dict[str, Any] = {
            "district_id": self.district_id,
            "step_count": int(self.step_count),
        }
        payload["queue_delta"] = round(after.total_queue - before.total_queue, 3)
        payload["wait_delta"] = round(after.total_wait - before.total_wait, 3)
        payload["throughput_delta"] = round(
            after.recent_throughput - before.recent_throughput, 3
        )
        return payload
def derive_district_action(
    window_data: DistrictWindowData,
    controller_actions: list[LocalIntersectionAction] | None = None,
    district_state: DistrictStateSummary | None = None,
    max_target_intersections: int = 3,
) -> DistrictAction:
    """
    Label one district window with a rule-based ``DistrictAction``.

    The label is derived deterministically from the local controllers' recorded
    actions and from the district summaries at the start and end of the window.

    Args:
        window_data: Supplies the end summary and the window length. Unless
            overridden, it also supplies the action records and start summary.
        controller_actions: Optional replacement for
            ``window_data.controller_actions``.
        district_state: Optional replacement for ``window_data.start_summary``.
            The end summary is always ``window_data.end_summary``.
        max_target_intersections: Forwarded to ``fallback_target_intersections``.

    Returns:
        The ``.validate()`` result of the ``DistrictAction`` built for the first
        rule that fires, or ``DistrictAction.default_hold(...)`` if none fires.

    Rules are checked in order; the first match wins:
        1. Incident flag on either summary -> ``incident_response``.
        2. Spillback risk on either summary, or boundary share >= 0.55 together
           with outgoing pressure >= 0.45 -> ``clear_spillback``.
        3. Boundary share >= 0.55 while queue or wait is not decreasing ->
           ``drain_inbound``.
        4. Outgoing pressure >= 0.65 with the end queue at least 90% of the
           start queue -> ``drain_outbound``.
        5. Event flag (start), overload flag (either summary), or
           boundary-dominated switching -> ``arterial_priority``.
        6. NS vs EW queue+wait pressure imbalance -> ``favor_NS`` / ``favor_EW``.
        7. Otherwise -> ``DistrictAction.default_hold``.
    """
    records = window_data.controller_actions if controller_actions is None else controller_actions
    start = window_data.start_summary if district_state is None else district_state
    end = window_data.end_summary
    # Window length clamped into [1, 20]; a falsy step_count counts as 1.
    horizon = max(1, min(int(window_data.step_count or 1), 20))

    # Single pass over the records: phase tallies, boundary/switch counts, and
    # per-intersection focus scores (dict insertion order = first appearance).
    ns_count = 0
    ew_count = 0
    boundary_count = 0
    switches = 0
    focus: dict[str, float] = {}
    for rec in records:
        did_switch = rec.switched
        # Phase 0 is counted as NS; every other phase index is counted as EW.
        if int(rec.next_phase) == 0:
            ns_count += 1
        else:
            ew_count += 1
        switches += int(did_switch)
        if rec.is_boundary:
            boundary_count += 1
        # Queue plus 1.5x wait, with a +2.0 bonus for each record that switched.
        focus[rec.intersection_id] = focus.get(rec.intersection_id, 0.0) + (
            rec.queue_total + 1.5 * rec.wait_total + 2.0 * float(did_switch)
        )

    denom = float(max(1, len(records)))
    ns_ratio = ns_count / denom
    ew_ratio = ew_count / denom
    boundary_ratio = boundary_count / denom

    queue_delta = end.total_queue - start.total_queue
    wait_delta = end.total_wait - start.total_wait
    # Boundary share uses the start summary; outgoing pressure uses the end summary.
    boundary_share = start.boundary_queue_total / max(1.0, start.total_queue)
    outgoing_pressure = end.total_outgoing_load / max(1.0, end.total_queue)

    # Phase bias: the action tallies decide when one side leads by more than 0.1;
    # otherwise fall back to the start summary's dominant flow, else "NONE".
    if ns_ratio > ew_ratio + 0.1:
        bias = "NS"
    elif ew_ratio > ns_ratio + 0.1:
        bias = "EW"
    elif start.dominant_flow in {"NS", "EW"}:
        bias = start.dominant_flow
    else:
        bias = "NONE"
    directional = bias in {"NS", "EW"}
    axis_or_arterial = bias if directional else "arterial"
    heavy_boundary = boundary_share >= 0.55

    def emit(strategy: str, corridor: str | None, target_bias: str) -> DistrictAction:
        # Pick target intersections for this rule and build the validated action.
        targets = fallback_target_intersections(
            summary=start,
            max_target_intersections=max_target_intersections,
            strategy=strategy,
            priority_corridor=corridor,
            phase_bias=target_bias,
            focus_scores=focus,
        )
        return DistrictAction(
            strategy=strategy,
            priority_corridor=corridor,
            target_intersections=targets,
            phase_bias=target_bias,
            duration_steps=horizon,
        ).validate()

    # 1. Incidents take precedence over every other rule.
    if start.incident_flag or end.incident_flag:
        return emit("incident_response", axis_or_arterial, bias)

    # 2. Spillback: explicit risk flag, or a heavy boundary queue with high outgoing load.
    if start.spillback_risk or end.spillback_risk or (heavy_boundary and outgoing_pressure >= 0.45):
        if heavy_boundary:
            corridor = "inbound"
        elif directional:
            corridor = bias
        else:
            corridor = None
        return emit("clear_spillback", corridor, bias)

    # 3. Boundary-heavy queue that is not shrinking.
    if heavy_boundary and (queue_delta >= 0.0 or wait_delta >= 0.0):
        return emit("drain_inbound", "inbound", bias)

    # 4. High outgoing load while the queue has not dropped below 90% of its start value.
    if outgoing_pressure >= 0.65 and end.total_queue >= start.total_queue * 0.9:
        return emit("drain_outbound", "outbound", bias)

    # 5. Events/overload, or mostly-boundary records with frequent switching.
    busy_boundary = boundary_ratio >= 0.6 and switches >= max(2, horizon)
    if start.event_flag or start.overload_flag or end.overload_flag or busy_boundary:
        return emit("arterial_priority", axis_or_arterial, bias)

    # 6. Directional imbalance on the start summary (queue + 1.5x wait per axis).
    ns_pressure = start.ns_queue + 1.5 * start.ns_wait
    ew_pressure = start.ew_queue + 1.5 * start.ew_wait
    threshold = max(5.0, 0.15 * max(1.0, ns_pressure + ew_pressure))
    if ns_pressure - ew_pressure >= threshold:
        return emit("favor_NS", "NS", "NS")
    if ew_pressure - ns_pressure >= threshold:
        return emit("favor_EW", "EW", "EW")

    # 7. Nothing stood out.
    return DistrictAction.default_hold(duration_steps=horizon)