Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files- inference.py +227 -67
inference.py
CHANGED
|
@@ -38,7 +38,13 @@ from typing import Any, Dict, List, Optional
|
|
| 38 |
from openai import OpenAI
|
| 39 |
|
| 40 |
from traffic_light_env import TrafficLightAction, TrafficLightEnv
|
| 41 |
-
from traffic_light_env.models import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
IMAGE_NAME = os.getenv("IMAGE_NAME")
|
| 44 |
API_KEY = os.getenv("OPENAI_API_KEY") or os.getenv("HF_TOKEN") or os.getenv("API_KEY")
|
|
@@ -47,51 +53,45 @@ API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
|
|
| 47 |
MODEL_NAME = os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
|
| 48 |
BENCHMARK = "traffic_light_env"
|
| 49 |
MAX_STEPS = 200
|
| 50 |
-
TEMPERATURE = 0.
|
| 51 |
-
MAX_TOKENS =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
|
| 53 |
# Tasks to run. Override with TRAFFIC_LIGHT_TASKS env var (comma-separated).
|
| 54 |
TASKS = os.getenv("TRAFFIC_LIGHT_TASKS", ",".join(TASK_NAMES)).split(",")
|
| 55 |
|
| 56 |
SYSTEM_PROMPT = textwrap.dedent(
|
| 57 |
"""
|
| 58 |
-
You are
|
| 59 |
-
(NS
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
Each dilemma-zone vehicle incurs a -1.5 reward penalty. Avoid switching when
|
| 84 |
-
many heavy vehicles (trucks, buses) are in the green lanes' 100m zones.
|
| 85 |
-
|
| 86 |
-
Strategy tips:
|
| 87 |
-
- Corridor phases (0, 1) green 4 lanes at once — high throughput.
|
| 88 |
-
- Single-direction phases (2-5) useful when one direction is much busier.
|
| 89 |
-
- Consider 500m vehicles: they migrate to 100m soon.
|
| 90 |
-
- For emergency vehicles, prioritize the direction containing the emergency.
|
| 91 |
-
- Avoid switching when trucks/buses are in the 100m zone (high dilemma risk).
|
| 92 |
-
- Minimize total switches — each costs yellow time + dilemma risk + penalty.
|
| 93 |
-
|
| 94 |
-
Respond with ONLY a single digit: 0, 1, 2, 3, 4, or 5
|
| 95 |
"""
|
| 96 |
).strip()
|
| 97 |
|
|
@@ -121,6 +121,127 @@ def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> No
|
|
| 121 |
)
|
| 122 |
|
| 123 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
# ---------------------------------------------------------------------------
|
| 125 |
# Observation → LLM prompt
|
| 126 |
# ---------------------------------------------------------------------------
|
|
@@ -142,15 +263,10 @@ def obs_to_summary(obs: Any) -> str:
|
|
| 142 |
f"Total waiting: {obs.total_waiting}",
|
| 143 |
f"Throughput so far: {obs.total_throughput}",
|
| 144 |
]
|
| 145 |
-
#
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
for vt in ("truck", "bus", "suv"):
|
| 150 |
-
for d in range(4):
|
| 151 |
-
heavy[d] += v100.get(vt, [0, 0, 0, 0])[d]
|
| 152 |
-
heavy_str = " ".join(f"{dir_labels[d]}:{heavy[d]}" for d in range(4))
|
| 153 |
-
lines.append(f"Heavy vehicles (truck+bus+suv) at 100m — {heavy_str}")
|
| 154 |
lines.append(f"Cumulative dilemma-zone vehicles: {obs.total_dilemma_vehicles:.1f}")
|
| 155 |
|
| 156 |
if obs.emergency_direction >= 0:
|
|
@@ -163,6 +279,11 @@ def obs_to_summary(obs: Any) -> str:
|
|
| 163 |
f"EMERGENCY vehicle in {dir_name} direction (use {phases_help}), "
|
| 164 |
f"waiting {obs.emergency_wait} steps"
|
| 165 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
return "\n".join(lines)
|
| 167 |
|
| 168 |
|
|
@@ -178,7 +299,7 @@ def get_phase_from_llm(
|
|
| 178 |
"""Ask the LLM which phase to choose. Falls back to heuristic on failure."""
|
| 179 |
user_prompt = obs_to_summary(obs)
|
| 180 |
if history:
|
| 181 |
-
user_prompt += "\n\nRecent
|
| 182 |
user_prompt += "\n\nChoose phase (0-5):"
|
| 183 |
|
| 184 |
try:
|
|
@@ -199,24 +320,41 @@ def get_phase_from_llm(
|
|
| 199 |
except Exception as exc:
|
| 200 |
print(f"[DEBUG] Model request failed: {exc}", flush=True)
|
| 201 |
|
| 202 |
-
return
|
| 203 |
|
| 204 |
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 208 |
if obs.emergency_direction >= 0:
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
return 1 # EW+WE corridor
|
| 215 |
|
| 216 |
-
#
|
| 217 |
-
|
| 218 |
-
ew_we = (obs.ew_100m + obs.we_100m) + 0.3 * (obs.ew_500m + obs.we_500m)
|
| 219 |
-
return 0 if ns_sn >= ew_we else 1
|
| 220 |
|
| 221 |
|
| 222 |
# ---------------------------------------------------------------------------
|
|
@@ -236,14 +374,26 @@ async def run_task(client: OpenAI, env: TrafficLightEnv, task: str) -> Dict[str,
|
|
| 236 |
try:
|
| 237 |
result = await env.reset(task=task)
|
| 238 |
obs = result.observation
|
|
|
|
|
|
|
| 239 |
|
| 240 |
for step in range(1, MAX_STEPS + 1):
|
| 241 |
if result.done:
|
| 242 |
break
|
| 243 |
|
| 244 |
-
phase =
|
| 245 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 246 |
|
|
|
|
| 247 |
result = await env.step(action)
|
| 248 |
obs = result.observation
|
| 249 |
|
|
@@ -263,7 +413,8 @@ async def run_task(client: OpenAI, env: TrafficLightEnv, task: str) -> Dict[str,
|
|
| 263 |
)
|
| 264 |
|
| 265 |
history.append(
|
| 266 |
-
f"Step {step}: phase={phase}, waiting={obs.total_waiting},
|
|
|
|
| 267 |
)
|
| 268 |
|
| 269 |
if done:
|
|
@@ -315,6 +466,15 @@ async def main() -> None:
|
|
| 315 |
f" [{status}] {r['task']:22s} score={r['score']:.4f} steps={r['steps']}",
|
| 316 |
flush=True,
|
| 317 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 318 |
avg_score = (
|
| 319 |
sum(r["score"] for r in all_results) / len(all_results)
|
| 320 |
if all_results else 0.0
|
|
|
|
| 38 |
from openai import OpenAI
|
| 39 |
|
| 40 |
from traffic_light_env import TrafficLightAction, TrafficLightEnv
|
| 41 |
+
from traffic_light_env.models import (
|
| 42 |
+
DILEMMA_FRACTIONS,
|
| 43 |
+
DIRECTION_NAMES,
|
| 44 |
+
NUM_PHASES,
|
| 45 |
+
TASK_NAMES,
|
| 46 |
+
VEHICLE_TYPE_NAMES,
|
| 47 |
+
)
|
| 48 |
|
| 49 |
IMAGE_NAME = os.getenv("IMAGE_NAME")
|
| 50 |
API_KEY = os.getenv("OPENAI_API_KEY") or os.getenv("HF_TOKEN") or os.getenv("API_KEY")
|
|
|
|
| 53 |
MODEL_NAME = os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
|
| 54 |
BENCHMARK = "traffic_light_env"
|
| 55 |
MAX_STEPS = 200
|
| 56 |
+
TEMPERATURE = 0.2
|
| 57 |
+
MAX_TOKENS = 128
|
| 58 |
+
|
| 59 |
+
# Strategy parameters
|
| 60 |
+
MIN_HOLD_TIME = 8 # Minimum steps to hold a phase before considering switch
|
| 61 |
+
SWITCH_THRESHOLD = 1.8 # Opposing axis must be this many times busier to switch
|
| 62 |
+
LLM_CONSULT_INTERVAL = 10 # Ask LLM every N steps for strategic guidance
|
| 63 |
+
EMERGENCY_OVERRIDE = True # Immediately switch for emergency vehicles
|
| 64 |
|
| 65 |
# Tasks to run. Override with TRAFFIC_LIGHT_TASKS env var (comma-separated).
|
| 66 |
TASKS = os.getenv("TRAFFIC_LIGHT_TASKS", ",".join(TASK_NAMES)).split(",")
|
| 67 |
|
| 68 |
SYSTEM_PROMPT = textwrap.dedent(
|
| 69 |
"""
|
| 70 |
+
You are a traffic light controller at a 4-way intersection. 4 directions
|
| 71 |
+
(NS, SN, EW, WE) with 2 lanes each (8 total). You pick one of 6 phases:
|
| 72 |
+
|
| 73 |
+
0 = NS+SN corridor (4 lanes green — best throughput for N-S axis)
|
| 74 |
+
1 = EW+WE corridor (4 lanes green — best throughput for E-W axis)
|
| 75 |
+
2 = NS only 3 = SN only 4 = EW only 5 = WE only
|
| 76 |
+
|
| 77 |
+
CRITICAL RULES — switching phases costs 2 dead steps (yellow) + dilemma-zone
|
| 78 |
+
risk (vehicles that can't stop safely). Every unnecessary switch HURTS your score.
|
| 79 |
+
|
| 80 |
+
DECISION FRAMEWORK:
|
| 81 |
+
1. If currently in yellow transition → keep the pending phase (no choice).
|
| 82 |
+
2. If emergency vehicle present → switch to its corridor ONCE, then hold.
|
| 83 |
+
3. If held current phase < 8 steps → KEEP current phase (too early to switch).
|
| 84 |
+
4. Only switch if opposing axis queue is >1.8× current axis queue.
|
| 85 |
+
5. Prefer corridor phases (0 or 1) for maximum throughput.
|
| 86 |
+
6. Use single-direction phases (2-5) ONLY if one direction has >3× its opposite.
|
| 87 |
+
|
| 88 |
+
Scoring: 40% waiting (lower=better), 40% throughput (higher=better), 20% safety
|
| 89 |
+
(fewer dilemma vehicles=better). The fixed-timer baseline scores 0.81 by switching
|
| 90 |
+
every 10 steps. You should switch LESS often than that on balanced traffic.
|
| 91 |
+
|
| 92 |
+
Respond: one line with the phase digit (0-5), then a brief reason.
|
| 93 |
+
Format: <digit> <reason>
|
| 94 |
+
Example: 0 NS+SN corridor has more vehicles, hold current phase
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
"""
|
| 96 |
).strip()
|
| 97 |
|
|
|
|
| 121 |
)
|
| 122 |
|
| 123 |
|
| 124 |
+
# ---------------------------------------------------------------------------
|
| 125 |
+
# Dilemma risk estimation
|
| 126 |
+
# ---------------------------------------------------------------------------
|
| 127 |
+
|
| 128 |
+
def estimate_dilemma_risk(obs: Any, green_dirs: List[int]) -> float:
    """Estimate how many vehicles would be in the dilemma zone if we switch now.

    For every direction in ``green_dirs`` and every vehicle type in
    ``VEHICLE_TYPE_NAMES``, the count read from ``obs.vehicles_100m`` is
    multiplied by that type's ``DILEMMA_FRACTIONS`` entry and added up.

    Args:
        obs: Observation exposing ``vehicles_100m``, a mapping from vehicle
            type name to a per-direction sequence of counts. A missing type
            is treated as four zero counts.
        green_dirs: Direction indices whose lanes are currently green.

    Returns:
        The weighted vehicle count; ``0.0`` when ``green_dirs`` is empty or
        no positive counts are found.
    """
    v100 = obs.vehicles_100m
    risk = 0.0
    for d in green_dirs:
        for vt in VEHICLE_TYPE_NAMES:
            count = v100.get(vt, [0, 0, 0, 0])[d]
            # Only positive counts contribute; this also avoids looking up a
            # fraction for a type that has no vehicles present.
            if count > 0:
                risk += count * DILEMMA_FRACTIONS[vt]
    return risk
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def get_green_dirs(phase: int) -> List[int]:
    """Return the direction indices that receive a green for ``phase``.

    Phases 0 and 1 each cover two directions ([0, 1] and [2, 3]); phases
    2 through 5 each cover a single direction (0, 1, 2 and 3 respectively).
    Any other phase value yields an empty list.
    """
    # Position in this tuple is the phase number.
    greens_in_phase_order = ([0, 1], [2, 3], [0], [1], [2], [3])
    lookup = dict(enumerate(greens_in_phase_order))
    if phase in lookup:
        return lookup[phase]
    return []
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
# ---------------------------------------------------------------------------
|
| 148 |
+
# Smart heuristic (primary decision maker)
|
| 149 |
+
# ---------------------------------------------------------------------------
|
| 150 |
+
|
| 151 |
+
def smart_heuristic(obs: Any, current_phase: int, time_in_phase: int) -> int:
    """Choose the next signal phase while switching as rarely as practical.

    Decision order:

    1. Yellow in progress: return ``obs.active_phase`` (or ``current_phase``
       when the active phase is negative).
    2. Emergency vehicle present: hold the current phase if it already gives
       the emergency direction a green; otherwise go to that axis' corridor
       phase (0 for directions 0/1, 1 for directions 2/3).
    3. Phase held for fewer than ``MIN_HOLD_TIME`` steps: hold.
    4. Opposing-axis load divided by current-axis load reaches
       ``SWITCH_THRESHOLD`` (raised by 0.1 per estimated dilemma-zone
       vehicle): switch axis, picking a single-direction phase when one
       direction has more than 3x the other and more than 10 vehicles at
       100 m, else the corridor phase.
    5. Corridor phase held for at least ``MIN_HOLD_TIME + 4`` steps with one
       direction above 3x the other and above 15 vehicles at 100 m: narrow
       to that single direction.
    6. Otherwise hold.

    Axis load is the 100 m count plus 0.3 times the 500 m count of both
    directions on the axis.

    Args:
        obs: Observation exposing ``yellow_remaining``, ``active_phase``,
            ``emergency_direction``, the ``*_100m`` / ``*_500m`` counts for
            ns, sn, ew and we, and ``vehicles_100m``.
        current_phase: Phase the caller considers currently selected.
        time_in_phase: Number of steps the caller has held that phase.

    Returns:
        The phase number to request next.
    """
    # During yellow nothing can be changed; echo the environment's phase.
    if obs.yellow_remaining > 0:
        return obs.active_phase if obs.active_phase >= 0 else current_phase

    current_green_dirs = get_green_dirs(current_phase)

    # Emergency override. Holding a phase that already serves the emergency
    # direction avoids a needless phase change right when the vehicle should
    # be passing; only otherwise jump to the corridor of its axis.
    if obs.emergency_direction >= 0:
        d = obs.emergency_direction
        if d in current_green_dirs:
            return current_phase
        return 0 if d <= 1 else 1

    # Too early to reconsider the current phase.
    if time_in_phase < MIN_HOLD_TIME:
        return current_phase

    # Axis loads: 100 m queue plus discounted 500 m traffic as future pressure.
    ns_sn_load = (obs.ns_100m + obs.sn_100m) + 0.3 * (obs.ns_500m + obs.sn_500m)
    ew_we_load = (obs.ew_100m + obs.we_100m) + 0.3 * (obs.ew_500m + obs.we_500m)

    serves_ns = any(d in (0, 1) for d in current_green_dirs)
    serves_ew = any(d in (2, 3) for d in current_green_dirs)

    # A phase that serves neither (or both) axes is treated as N-S.
    if serves_ew and not serves_ns:
        current_load, opposing_load = ew_we_load, ns_sn_load
    else:
        current_load, opposing_load = ns_sn_load, ew_we_load

    if opposing_load <= 0:
        ratio = 0.0  # Opposing axis is empty
    elif current_load > 0:
        ratio = opposing_load / max(current_load, 1.0)
    else:
        ratio = 10.0  # Current axis is empty

    # Require a larger imbalance when switching would strand many vehicles
    # in the green lanes' dilemma zone.
    dilemma_risk = estimate_dilemma_risk(obs, current_green_dirs)
    effective_threshold = SWITCH_THRESHOLD + (dilemma_risk * 0.1)

    if ratio >= effective_threshold:
        if serves_ns or (not serves_ew and ns_sn_load < ew_we_load):
            # Move to the E-W axis; narrow to one direction if it dominates.
            if obs.ew_100m > 3 * obs.we_100m and obs.ew_100m > 10:
                return 4  # EW only
            if obs.we_100m > 3 * obs.ew_100m and obs.we_100m > 10:
                return 5  # WE only
            return 1  # EW+WE corridor
        # Move to the N-S axis; narrow to one direction if it dominates.
        if obs.ns_100m > 3 * obs.sn_100m and obs.ns_100m > 10:
            return 2  # NS only
        if obs.sn_100m > 3 * obs.ns_100m and obs.sn_100m > 10:
            return 3  # SN only
        return 0  # NS+SN corridor

    # Staying on the same axis: after a longer hold, narrow a corridor phase
    # to a single direction when that direction heavily dominates.
    # NOTE(review): phases 2-5 are only left via the axis switch above, so the
    # unserved direction on the same axis may wait a long time - confirm this
    # is acceptable.
    if time_in_phase >= MIN_HOLD_TIME + 4:
        if current_phase == 0:
            if obs.ns_100m > 3 * obs.sn_100m and obs.ns_100m > 15:
                return 2  # Focus on NS only
            if obs.sn_100m > 3 * obs.ns_100m and obs.sn_100m > 15:
                return 3  # Focus on SN only
        elif current_phase == 1:
            if obs.ew_100m > 3 * obs.we_100m and obs.ew_100m > 15:
                return 4  # Focus on EW only
            if obs.we_100m > 3 * obs.ew_100m and obs.we_100m > 15:
                return 5  # Focus on WE only

    return current_phase
|
| 243 |
+
|
| 244 |
+
|
| 245 |
# ---------------------------------------------------------------------------
|
| 246 |
# Observation → LLM prompt
|
| 247 |
# ---------------------------------------------------------------------------
|
|
|
|
| 263 |
f"Total waiting: {obs.total_waiting}",
|
| 264 |
f"Throughput so far: {obs.total_throughput}",
|
| 265 |
]
|
| 266 |
+
# Dilemma risk info
|
| 267 |
+
green_dirs = get_green_dirs(obs.active_phase)
|
| 268 |
+
dilemma = estimate_dilemma_risk(obs, green_dirs)
|
| 269 |
+
lines.append(f"Dilemma risk if switching now: {dilemma:.1f} vehicles")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 270 |
lines.append(f"Cumulative dilemma-zone vehicles: {obs.total_dilemma_vehicles:.1f}")
|
| 271 |
|
| 272 |
if obs.emergency_direction >= 0:
|
|
|
|
| 279 |
f"EMERGENCY vehicle in {dir_name} direction (use {phases_help}), "
|
| 280 |
f"waiting {obs.emergency_wait} steps"
|
| 281 |
)
|
| 282 |
+
|
| 283 |
+
# Heuristic recommendation
|
| 284 |
+
heuristic_rec = smart_heuristic(obs, obs.active_phase, obs.time_in_phase)
|
| 285 |
+
lines.append(f"\nHeuristic recommends: phase {heuristic_rec} ({phase_desc.get(heuristic_rec, '?')})")
|
| 286 |
+
|
| 287 |
return "\n".join(lines)
|
| 288 |
|
| 289 |
|
|
|
|
| 299 |
"""Ask the LLM which phase to choose. Falls back to heuristic on failure."""
|
| 300 |
user_prompt = obs_to_summary(obs)
|
| 301 |
if history:
|
| 302 |
+
user_prompt += "\n\nRecent actions:\n" + "\n".join(history[-5:])
|
| 303 |
user_prompt += "\n\nChoose phase (0-5):"
|
| 304 |
|
| 305 |
try:
|
|
|
|
| 320 |
except Exception as exc:
|
| 321 |
print(f"[DEBUG] Model request failed: {exc}", flush=True)
|
| 322 |
|
| 323 |
+
return smart_heuristic(obs, obs.active_phase, obs.time_in_phase)
|
| 324 |
|
| 325 |
|
| 326 |
+
# ---------------------------------------------------------------------------
|
| 327 |
+
# Hybrid decision: heuristic + periodic LLM consultation
|
| 328 |
+
# ---------------------------------------------------------------------------
|
| 329 |
+
|
| 330 |
+
def decide_phase(
    client: OpenAI,
    obs: Any,
    history: List[str],
    step: int,
    current_phase: int,
    time_in_phase: int,
) -> int:
    """Pick the phase for this step from the heuristic or, periodically, the LLM.

    - While a yellow transition is running, ``current_phase`` is returned
      unchanged.
    - When an emergency vehicle is present the heuristic always decides.
    - Otherwise the LLM is asked on every step that is a multiple of
      ``LLM_CONSULT_INTERVAL``, provided the phase has been held for at least
      ``MIN_HOLD_TIME`` steps.
    - Every remaining step is decided by the heuristic.
    """
    # Nothing to decide mid-transition.
    if obs.yellow_remaining > 0:
        return current_phase

    emergency_active = obs.emergency_direction >= 0
    if (
        not emergency_active
        and step % LLM_CONSULT_INTERVAL == 0
        and time_in_phase >= MIN_HOLD_TIME
    ):
        # Strategic checkpoint: let the model weigh in.
        return get_phase_from_llm(client, obs, history)

    # Emergencies and all ordinary steps use the heuristic.
    return smart_heuristic(obs, current_phase, time_in_phase)
|
|
|
|
|
|
|
| 358 |
|
| 359 |
|
| 360 |
# ---------------------------------------------------------------------------
|
|
|
|
| 374 |
try:
|
| 375 |
result = await env.reset(task=task)
|
| 376 |
obs = result.observation
|
| 377 |
+
current_phase = 0 # Start at NS+SN corridor
|
| 378 |
+
time_in_phase = 0
|
| 379 |
|
| 380 |
for step in range(1, MAX_STEPS + 1):
|
| 381 |
if result.done:
|
| 382 |
break
|
| 383 |
|
| 384 |
+
phase = decide_phase(
|
| 385 |
+
client, obs, history, step,
|
| 386 |
+
current_phase, time_in_phase,
|
| 387 |
+
)
|
| 388 |
+
|
| 389 |
+
# Track phase timing locally
|
| 390 |
+
if phase != current_phase:
|
| 391 |
+
time_in_phase = 0
|
| 392 |
+
current_phase = phase
|
| 393 |
+
else:
|
| 394 |
+
time_in_phase += 1
|
| 395 |
|
| 396 |
+
action = TrafficLightAction(phase=phase)
|
| 397 |
result = await env.step(action)
|
| 398 |
obs = result.observation
|
| 399 |
|
|
|
|
| 413 |
)
|
| 414 |
|
| 415 |
history.append(
|
| 416 |
+
f"Step {step}: phase={phase}, waiting={obs.total_waiting}, "
|
| 417 |
+
f"throughput={obs.total_throughput}, reward={reward:+.2f}"
|
| 418 |
)
|
| 419 |
|
| 420 |
if done:
|
|
|
|
| 466 |
f" [{status}] {r['task']:22s} score={r['score']:.4f} steps={r['steps']}",
|
| 467 |
flush=True,
|
| 468 |
)
|
| 469 |
+
if r.get("grade_details"):
|
| 470 |
+
d = r["grade_details"]
|
| 471 |
+
print(
|
| 472 |
+
f" waiting={d.get('waiting_score', 0):.3f} "
|
| 473 |
+
f"throughput={d.get('throughput_score', 0):.3f} "
|
| 474 |
+
f"safety={d.get('safety_score', 0):.3f} "
|
| 475 |
+
f"dilemma={d.get('total_dilemma_vehicles', 0):.1f}",
|
| 476 |
+
flush=True,
|
| 477 |
+
)
|
| 478 |
avg_score = (
|
| 479 |
sum(r["score"] for r in all_results) / len(all_results)
|
| 480 |
if all_results else 0.0
|