# traffic_light_env/inference.py
# Last commit baedb36 (verified) by rishabh16196: "Upload folder using huggingface_hub".
"""
Inference Script — Traffic Light Environment
===================================
MANDATORY
- Before submitting, ensure the following variables are defined in your environment configuration:
OPENAI_API_KEY Your API key (also accepts HF_TOKEN or API_KEY as fallbacks).
API_BASE_URL The API endpoint for the LLM.
MODEL_NAME The model identifier to use for inference.
IMAGE_NAME The name of the local image to use for the environment if you are using from_docker_image()
- Defaults are set only for API_BASE_URL and MODEL_NAME:
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
- The inference script must be named `inference.py` and placed in the root directory of the project
- Participants must use OpenAI Client for all LLM calls using above variables
STDOUT FORMAT
- The script must emit exactly three line types to stdout, in this order:
[START] task=<task_name> env=<benchmark> model=<model_name>
[STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
[END] success=<true|false> steps=<n> score=<0.000> rewards=<r1,r2,...,rn>
Example:
[START] task=balanced env=traffic_light_env model=Qwen2.5-72B-Instruct
[STEP] step=1 action=phase(0) reward=-2.40 done=false error=null
[STEP] step=2 action=phase(1) reward=-5.10 done=false error=null
...
[END] success=true steps=200 score=0.624 rewards=-2.40,-5.10,...
"""
import asyncio
import os
import re
import sys
import textwrap
from typing import Any, Dict, List, Optional

from openai import OpenAI

from traffic_light_env import TrafficLightAction, TrafficLightEnv
from traffic_light_env.models import (
    DILEMMA_FRACTIONS,
    DIRECTION_NAMES,
    NUM_PHASES,
    TASK_NAMES,
    VEHICLE_TYPE_NAMES,
)
# Environment-driven configuration (``os.environ.get`` is what ``os.getenv`` wraps).
IMAGE_NAME = os.environ.get("IMAGE_NAME")
# Key lookup order: OPENAI_API_KEY, then HF_TOKEN, then API_KEY; the first truthy one wins.
API_KEY = (
    os.environ.get("OPENAI_API_KEY")
    or os.environ.get("HF_TOKEN")
    or os.environ.get("API_KEY")
)
# Using ``or`` (not a getenv default) also replaces a set-but-empty variable.
API_BASE_URL = os.environ.get("API_BASE_URL") or "https://router.huggingface.co/v1"
MODEL_NAME = os.environ.get("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"

BENCHMARK = "traffic_light_env"  # reported as env= in the [START] line
MAX_STEPS = 200  # upper bound on steps per episode
TEMPERATURE = 0.2  # sampling temperature for LLM calls
MAX_TOKENS = 128  # completion-length cap for LLM calls
# Per-task tuning parameters, keyed by task name. Each entry is a dict with:
#   min_hold      -- steps a phase must be held before _balanced_strategy will
#                    consider switching; decide_phase also requires it before
#                    consulting the LLM.
#   switch_thresh -- the opposing axis load must reach this multiple of the
#                    current axis load (plus a dilemma-risk penalty) to switch.
#                    Only read by _balanced_strategy.
#   llm_interval  -- consult the LLM every N steps (0 = never for this task).
# NOTE: rush_hour_*, alternating_surge and gridlock use their own hard-coded
# timings, and emergency_vehicle delegates to the "balanced" entry; for those
# tasks only llm_interval and min_hold (LLM gating in decide_phase) are read here.
TASK_PARAMS: Dict[str, Dict[str, Any]] = {
    "balanced": {"min_hold": 8, "switch_thresh": 1.6, "llm_interval": 15},
    "rush_hour_ns": {"min_hold": 8, "switch_thresh": 1.8, "llm_interval": 0},
    "rush_hour_ew": {"min_hold": 8, "switch_thresh": 1.8, "llm_interval": 0},
    "alternating_surge": {"min_hold": 6, "switch_thresh": 1.4, "llm_interval": 0},  # pattern-based
    "random_spikes": {"min_hold": 8, "switch_thresh": 1.5, "llm_interval": 15},
    "gridlock": {"min_hold": 8, "switch_thresh": 1.3, "llm_interval": 0},  # fixed timer
    "emergency_vehicle": {"min_hold": 8, "switch_thresh": 1.6, "llm_interval": 0},  # heuristic only
}
# Fallback for task names missing from TASK_PARAMS.
DEFAULT_PARAMS: Dict[str, Any] = {"min_hold": 8, "switch_thresh": 1.8, "llm_interval": 10}
# Tasks to run, in order. Defaults to every name in TASK_NAMES; override with a
# comma-separated TRAFFIC_LIGHT_TASKS environment variable. Entries are not
# stripped or validated here.
TASKS = os.environ.get(
    "TRAFFIC_LIGHT_TASKS",
    ",".join(TASK_NAMES),
).split(",")
# System message sent with every LLM consultation (see get_phase_from_llm).
# textwrap.dedent(...).strip() removes common indentation and outer blank lines.
# NOTE(review): rule 4 tells the model to switch only above a 1.8x queue ratio,
# but TASK_PARAMS configures 1.6 for "balanced" and 1.5 for "random_spikes" (the
# only configured tasks with llm_interval > 0) -- confirm which the LLM should use.
SYSTEM_PROMPT = textwrap.dedent(
"""
You are a traffic light controller at a 4-way intersection. 4 directions
(NS, SN, EW, WE) with 2 lanes each (8 total). You pick one of 6 phases:
0 = NS+SN corridor (4 lanes green — best throughput for N-S axis)
1 = EW+WE corridor (4 lanes green — best throughput for E-W axis)
2 = NS only 3 = SN only 4 = EW only 5 = WE only
CRITICAL RULES — switching phases costs 2 dead steps (yellow) + dilemma-zone
risk (vehicles that can't stop safely). Every unnecessary switch HURTS your score.
DECISION FRAMEWORK:
1. If currently in yellow transition → keep the pending phase (no choice).
2. If emergency vehicle present → switch to its corridor ONCE, then hold.
3. If held current phase < 8 steps → KEEP current phase (too early to switch).
4. Only switch if opposing axis queue is >1.8× current axis queue.
5. Prefer corridor phases (0 or 1) for maximum throughput.
6. Use single-direction phases (2-5) ONLY if one direction has >3× its opposite.
Scoring: 40% waiting (lower=better), 40% throughput (higher=better), 20% safety
(fewer dilemma vehicles=better). The fixed-timer baseline scores 0.81 by switching
every 10 steps. You should switch LESS often than that on balanced traffic.
Respond: one line with the phase digit (0-5), then a brief reason.
Format: <digit> <reason>
Example: 0 NS+SN corridor has more vehicles, hold current phase
"""
).strip()
# ---------------------------------------------------------------------------
# Logging helpers (mandatory format)
# ---------------------------------------------------------------------------
def log_start(task: str, env: str, model: str) -> None:
    """Print the mandatory ``[START]`` line for one episode to stdout."""
    line = " ".join(("[START]", f"task={task}", f"env={env}", f"model={model}"))
    print(line, flush=True)
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print one mandatory ``[STEP]`` line to stdout.

    ``reward`` is shown with two decimals, ``done`` as lowercase ``true``/``false``,
    and a falsy ``error`` (``None`` or ``""``) as ``null``.
    """
    fields = [
        f"step={step}",
        f"action={action}",
        f"reward={reward:.2f}",
        f"done={str(done).lower()}",
        f"error={error or 'null'}",
    ]
    print("[STEP] " + " ".join(fields), flush=True)
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the mandatory ``[END]`` line to stdout.

    The score is shown with three decimals; rewards are comma-joined with two
    decimals each (an empty list gives ``rewards=``).
    """
    joined = ",".join(map("{:.2f}".format, rewards))
    parts = (
        "[END]",
        f"success={str(success).lower()}",
        f"steps={steps}",
        f"score={score:.3f}",
        f"rewards={joined}",
    )
    print(" ".join(parts), flush=True)
# ---------------------------------------------------------------------------
# Dilemma risk estimation
# ---------------------------------------------------------------------------
def estimate_dilemma_risk(obs: Any, green_dirs: List[int]) -> float:
    """Estimate how many vehicles would be caught in the dilemma zone by a switch now.

    For every green direction and every vehicle type, the count of vehicles
    within 100 m is weighted by that type's ``DILEMMA_FRACTIONS`` entry and summed.

    Args:
        obs: Observation exposing ``vehicles_100m``, a mapping from vehicle-type
            name to per-direction counts indexed by direction (0-3). A type
            missing from the mapping counts as zero in every direction.
        green_dirs: Direction indices that are currently green
            (as returned by ``get_green_dirs``).

    Returns:
        The weighted vehicle count; 0.0 when ``green_dirs`` is empty.
    """
    per_type_counts = obs.vehicles_100m
    risk = 0.0
    for d in green_dirs:
        for vt in VEHICLE_TYPE_NAMES:
            count = per_type_counts.get(vt, [0, 0, 0, 0])[d]
            if count > 0:
                risk += count * DILEMMA_FRACTIONS[vt]
    return risk
def get_green_dirs(phase: int) -> List[int]:
    """Return a fresh list of the direction indices that are green in ``phase``.

    Phases 0/1 are the NS+SN / EW+WE corridors; 2-5 each green a single
    direction. Any other value (e.g. a yellow transition) yields ``[]``.
    """
    green_by_phase = {
        0: (0, 1),
        1: (2, 3),
        2: (0,),
        3: (1,),
        4: (2,),
        5: (3,),
    }
    return list(green_by_phase.get(phase, ()))
# ---------------------------------------------------------------------------
# Task-specific strategies
# ---------------------------------------------------------------------------
def _alternating_surge_strategy(obs: Any, current_phase: int, time_in_phase: int) -> int:
"""
Surge pattern: NS/SN surge when (step//30)%2==0, EW/WE surge otherwise.
Pre-emptively switch 2 steps before surge boundary to absorb yellow transition.
"""
step = obs.step_number
period = 30
# Which surge are we in now?
ns_surge = (step // period) % 2 == 0
# When does the next surge boundary hit?
next_boundary = ((step // period) + 1) * period
steps_to_boundary = next_boundary - step
# Target corridor for current surge
target = 0 if ns_surge else 1
# Pre-emptive switch: 2 steps before boundary, switch to upcoming corridor
if steps_to_boundary <= 2:
upcoming_target = 1 if ns_surge else 0 # opposite of current surge
if current_phase != upcoming_target:
return upcoming_target
return current_phase
# During surge, ensure we're on the right corridor
if current_phase != target and time_in_phase >= 6:
return target
# If we're on the right corridor, check for load imbalance within the axis
if current_phase == target and time_in_phase >= 10:
if target == 0: # NS/SN corridor
ns_sn_100 = obs.ns_100m + obs.sn_100m
ew_we_100 = obs.ew_100m + obs.we_100m
# If EW/WE is building up massively despite NS surge, give it some time
if ew_we_100 > ns_sn_100 * 2.5 and ew_we_100 > 20:
return 1
else: # EW/WE corridor
ns_sn_100 = obs.ns_100m + obs.sn_100m
ew_we_100 = obs.ew_100m + obs.we_100m
if ns_sn_100 > ew_we_100 * 2.5 and ns_sn_100 > 20:
return 0
return current_phase
def _gridlock_strategy(obs: Any, current_phase: int, time_in_phase: int) -> int:
"""
Gridlock: all directions have equal rate 2.0.
Use fixed timer (~10 steps) switching between corridor 0 and 1.
Matches the fixed-timer baseline approach which scores 0.848.
Only use corridor phases for maximum throughput.
"""
GRIDLOCK_CYCLE = 10
# Ensure we only use corridor phases
if current_phase not in (0, 1):
return 0 # Reset to corridor
if time_in_phase >= GRIDLOCK_CYCLE:
# Check dilemma risk before switching
green_dirs = get_green_dirs(current_phase)
dilemma = estimate_dilemma_risk(obs, green_dirs)
# Delay switch by 1-2 steps if dilemma risk is very high
if dilemma > 8 and time_in_phase < GRIDLOCK_CYCLE + 2:
return current_phase
# Alternate between corridors
return 1 if current_phase == 0 else 0
return current_phase
def _emergency_strategy(obs: Any, current_phase: int, time_in_phase: int,
emergency_handled: bool) -> int:
"""
Emergency vehicle task: prioritize clearing the emergency ASAP.
Emergency clearance is 40% of the grade — must be within 3 steps for 1.0 score.
Strategy: use corridor phase covering the emergency direction (greens 4 lanes,
including the emergency lane, while maintaining throughput).
"""
if obs.emergency_direction >= 0:
d = obs.emergency_direction
# Use corridor phase — it greens the emergency direction AND its opposite
# for better throughput, while still clearing the emergency
target = 0 if d <= 1 else 1
if current_phase != target:
return target
return current_phase
# Before emergency appears (step < 10), use balanced strategy but
# position on phase 0 (NS+SN) to be ready for 50% of emergencies
if not emergency_handled and obs.step_number < 10:
# Pre-position: stay on phase 0 — if emergency is NS/SN, we're ready
return _balanced_strategy(obs, current_phase, time_in_phase, "balanced")
# After emergency cleared, use standard balanced strategy
return _balanced_strategy(obs, current_phase, time_in_phase, "balanced")
def _rush_hour_strategy(obs: Any, current_phase: int, time_in_phase: int,
                        task_name: str) -> int:
    """
    Choose a phase for the rush-hour tasks, where one axis is much busier.

    ``rush_hour_ns`` treats corridor 0 (NS+SN) as busy; any other name treats
    corridor 1 (EW+WE) as busy. The busy corridor is held for at least 8 steps
    and handed over only if the quiet axis has built up or after 12+ steps, and
    only while dilemma risk is below 6. Any other phase returns to the busy
    corridor after 5-7 steps.
    """
    busy = 0 if task_name == "rush_hour_ns" else 1
    quiet = 1 - busy

    # Index 0 = NS/SN axis, index 1 = EW/WE axis.
    near = (obs.ns_100m + obs.sn_100m, obs.ew_100m + obs.we_100m)
    # Load = 100 m queue plus 0.3 x the 500 m approach count.
    load = (
        near[0] + 0.3 * (obs.ns_500m + obs.sn_500m),
        near[1] + 0.3 * (obs.ew_500m + obs.we_500m),
    )
    busy_load, quiet_load = load[busy], load[quiet]
    busy_100, quiet_100 = near[busy], near[quiet]

    dilemma = estimate_dilemma_risk(obs, get_green_dirs(current_phase))

    if current_phase != busy:
        # Serving the quiet axis (or a non-corridor phase): keep the window short.
        if time_in_phase < 5:
            return current_phase
        quiet_drained = quiet_100 <= 4
        busy_building = busy_100 > quiet_100 * 1.5
        if quiet_drained or busy_building or time_in_phase >= 7:
            return busy
        return current_phase

    if time_in_phase < 8:
        return current_phase
    quiet_building = quiet_100 > 15 and quiet_load > busy_load * 0.6
    starvation_relief = time_in_phase >= 12 and quiet_100 > 8
    if (quiet_building or starvation_relief) and dilemma < 6:
        return quiet
    return current_phase
def _balanced_strategy(obs: Any, current_phase: int, time_in_phase: int,
                       task_name: str) -> int:
    """
    General adaptive strategy for balanced/random tasks.

    Per-axis load is the 100 m queue plus 0.3 x the 500 m count. After
    ``min_hold`` steps, switch to the busier corridor once the opposing/current
    load ratio reaches ``switch_thresh + 0.08 * dilemma_risk``. After a maximum
    hold (14 steps for random_spikes, 12 otherwise), force a switch if the
    opposing load exceeds 5 and dilemma risk is below 6. Parameters come from
    TASK_PARAMS, falling back to DEFAULT_PARAMS.
    """
    cfg = TASK_PARAMS.get(task_name, DEFAULT_PARAMS)
    min_hold, thresh = cfg["min_hold"], cfg["switch_thresh"]

    ns_near = obs.ns_100m + obs.sn_100m
    ew_near = obs.ew_100m + obs.we_100m
    ns_load = ns_near + 0.3 * (obs.ns_500m + obs.sn_500m)
    ew_load = ew_near + 0.3 * (obs.ew_500m + obs.we_500m)

    green = get_green_dirs(current_phase)
    serves_ns = any(d in (0, 1) for d in green)
    serves_ew = any(d in (2, 3) for d in green)

    # Only a purely EW/WE phase measures EW as "current"; everything else
    # (NS phases, or no green at all) measures NS as "current".
    if serves_ew and not serves_ns:
        current_load, opposing_load = ew_load, ns_load
    else:
        current_load, opposing_load = ns_load, ew_load

    if time_in_phase < min_hold:
        return current_phase

    if current_load > 0:
        ratio = opposing_load / max(current_load, 1.0)
    elif opposing_load > 0:
        ratio = 10.0  # current axis empty, opposing axis waiting
    else:
        ratio = 0.0

    dilemma = estimate_dilemma_risk(obs, green)
    # Each expected dilemma-zone vehicle raises the bar for switching by 0.08.
    if ratio >= thresh + (dilemma * 0.08):
        return 1 if ns_load < ew_load else 0

    # Anti-starvation: force a switch after a long hold.
    max_hold = 14 if task_name == "random_spikes" else 12
    if time_in_phase >= max_hold and opposing_load > 5 and dilemma < 6:
        return 1 if serves_ns else 0
    return current_phase
# ---------------------------------------------------------------------------
# Smart heuristic (primary decision maker)
# ---------------------------------------------------------------------------
def smart_heuristic(obs: Any, current_phase: int, time_in_phase: int,
                    task_name: str = "balanced",
                    emergency_handled: bool = False) -> int:
    """
    Task-aware heuristic: the primary phase decision maker.

    Rules, in order:
    1. During a yellow transition, return the env's active phase if it reports
       one (>= 0), else ``current_phase``.
    2. If an emergency vehicle is present, go to its corridor
       (0 for directions 0-1, 1 for directions 2-3).
    3. Otherwise dispatch on ``task_name`` to the matching strategy, using the
       balanced strategy for anything unrecognised.
    """
    if obs.yellow_remaining > 0:
        active = obs.active_phase
        return current_phase if active < 0 else active

    if obs.emergency_direction >= 0:
        target = 0 if obs.emergency_direction <= 1 else 1
        return current_phase if current_phase == target else target

    if task_name == "alternating_surge":
        return _alternating_surge_strategy(obs, current_phase, time_in_phase)
    if task_name == "gridlock":
        return _gridlock_strategy(obs, current_phase, time_in_phase)
    if task_name == "emergency_vehicle":
        return _emergency_strategy(obs, current_phase, time_in_phase, emergency_handled)
    if task_name in ("rush_hour_ns", "rush_hour_ew"):
        return _rush_hour_strategy(obs, current_phase, time_in_phase, task_name)
    return _balanced_strategy(obs, current_phase, time_in_phase, task_name)
# ---------------------------------------------------------------------------
# Observation → LLM prompt
# ---------------------------------------------------------------------------
def obs_to_summary(obs: Any) -> str:
    """Render the observation as the plain-text user prompt for the LLM.

    Covers step/phase status, the 100 m and 500 m per-direction queues, totals,
    the estimated dilemma risk of leaving the active phase, any emergency
    vehicle, and the heuristic's own recommendation.
    """
    labels = {
        0: "NS+SN corridor",
        1: "EW+WE corridor",
        2: "NS only",
        3: "SN only",
        4: "EW only",
        5: "WE only",
    }
    active = obs.active_phase
    near = f"NS:{obs.ns_100m} SN:{obs.sn_100m} EW:{obs.ew_100m} WE:{obs.we_100m}"
    far = f"NS:{obs.ns_500m} SN:{obs.sn_500m} EW:{obs.ew_500m} WE:{obs.we_500m}"
    lines = [
        f"Step: {obs.step_number}/{MAX_STEPS}",
        f"Task: {obs.task_name}",
        f"Active phase: {active} ({labels.get(active, 'yellow transition')})",
        f"Yellow remaining: {obs.yellow_remaining}",
        f"Time in phase: {obs.time_in_phase}",
        f"100m queues — {near}",
        f"500m queues — {far}",
        f"Total waiting: {obs.total_waiting}",
        f"Throughput so far: {obs.total_throughput}",
    ]

    risk = estimate_dilemma_risk(obs, get_green_dirs(active))
    lines += [
        f"Dilemma risk if switching now: {risk:.1f} vehicles",
        f"Cumulative dilemma-zone vehicles: {obs.total_dilemma_vehicles:.1f}",
    ]

    direction = obs.emergency_direction
    if direction >= 0:
        corridor = 0 if direction <= 1 else 1
        # The single-direction phase for direction d is d + 2 (NS=2, SN=3, EW=4, WE=5).
        phases_help = f"phase {corridor} (corridor) or phase {direction + 2}"
        lines.append(
            f"EMERGENCY vehicle in {DIRECTION_NAMES[direction].upper()} direction "
            f"(use {phases_help}), waiting {obs.emergency_wait} steps"
        )

    recommended = smart_heuristic(obs, active, obs.time_in_phase, obs.task_name)
    lines.append(f"\nHeuristic recommends: phase {recommended} ({labels.get(recommended, '?')})")
    return "\n".join(lines)
# ---------------------------------------------------------------------------
# LLM decision
# ---------------------------------------------------------------------------
# A standalone phase digit 0-5. It must not be part of a larger integer or a
# decimal such as "2.5" or "1.8"; a trailing sentence period is still accepted.
_PHASE_DIGIT_RE = re.compile(r"(?<![\d.])([0-5])(?!\d|\.\d)")


def get_phase_from_llm(
    client: OpenAI,
    obs: Any,
    history: List[str],
) -> int:
    """Ask the LLM which phase to choose.

    The prompt is the observation summary plus up to the last 5 ``history``
    entries. The reply should start with the phase digit. The first standalone
    0-5 digit is used, so numbers in the reasoning (e.g. "2.5x busier") are not
    misread as a phase.

    Falls back to ``smart_heuristic`` when the request fails or the reply has
    no usable digit. Failures are reported on stderr so stdout carries only the
    mandated [START]/[STEP]/[END] lines.
    """
    user_prompt = obs_to_summary(obs)
    if history:
        user_prompt += "\n\nRecent actions:\n" + "\n".join(history[-5:])
    user_prompt += "\n\nChoose phase (0-5):"
    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_prompt},
            ],
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
            stream=False,
        )
        # Kept inside the try: a malformed response (e.g. no choices) also
        # falls back to the heuristic.
        text = (completion.choices[0].message.content or "").strip()
    except Exception as exc:
        print(f"[DEBUG] Model request failed: {exc}", file=sys.stderr, flush=True)
    else:
        match = _PHASE_DIGIT_RE.search(text)
        if match is not None:
            return int(match.group(1))
    return smart_heuristic(obs, obs.active_phase, obs.time_in_phase, obs.task_name)
# ---------------------------------------------------------------------------
# Hybrid decision: heuristic + periodic LLM consultation
# ---------------------------------------------------------------------------
def decide_phase(
    client: OpenAI,
    obs: Any,
    history: List[str],
    step: int,
    current_phase: int,
    time_in_phase: int,
    task_name: str = "balanced",
    emergency_handled: bool = False,
) -> int:
    """
    Hybrid decision: task heuristic by default, with periodic LLM consultation.

    - During a yellow transition, hold ``current_phase``.
    - The LLM is asked only when no emergency is present, the task's
      ``llm_interval`` is positive and divides ``step``, and the current phase
      has been held for at least ``min_hold`` steps.
    - Every other case (emergencies included) goes to ``smart_heuristic``.
    """
    cfg = TASK_PARAMS.get(task_name, DEFAULT_PARAMS)
    llm_interval = cfg["llm_interval"]
    min_hold = cfg["min_hold"]

    if obs.yellow_remaining > 0:
        return current_phase

    llm_due = (
        obs.emergency_direction < 0
        and llm_interval > 0
        and step % llm_interval == 0
        and time_in_phase >= min_hold
    )
    if llm_due:
        return get_phase_from_llm(client, obs, history)
    return smart_heuristic(obs, current_phase, time_in_phase, task_name, emergency_handled)
# ---------------------------------------------------------------------------
# Episode runner
# ---------------------------------------------------------------------------
async def run_task(client: OpenAI, env: TrafficLightEnv, task: str) -> Dict[str, Any]:
    """Run one episode of ``task`` and return its summary dict.

    Emits [START], one [STEP] per environment step, and always an [END] line
    (from ``finally``). The episode is scored from ``obs.grade_score`` when the
    env reports ``done``; success means a score of at least 0.5.

    Returns:
        Dict with keys ``task``, ``success``, ``score``, ``steps`` and
        ``grade_details`` (None when the observation has no such attribute).
    """
    history: List[str] = []
    rewards: List[float] = []
    steps_taken = 0
    score = 0.0
    success = False

    log_start(task=task, env=BENCHMARK, model=MODEL_NAME)
    try:
        result = await env.reset(task=task)
        obs = result.observation
        # Phase state tracked locally; every episode begins on phase 0 (NS+SN).
        current_phase, time_in_phase = 0, 0
        emergency_handled = False

        for step in range(1, MAX_STEPS + 1):
            if result.done:
                break

            phase = decide_phase(
                client,
                obs,
                history,
                step,
                current_phase,
                time_in_phase,
                task_name=task,
                emergency_handled=emergency_handled,
            )
            # Latch once an emergency has been observed (pre-step observation).
            emergency_handled = emergency_handled or obs.emergency_direction >= 0

            if phase == current_phase:
                time_in_phase += 1
            else:
                current_phase, time_in_phase = phase, 0

            result = await env.step(TrafficLightAction(phase=phase))
            obs = result.observation
            reward = result.reward or 0.0
            done = result.done

            rewards.append(reward)
            steps_taken = step
            log_step(step=step, action=f"phase({phase})", reward=reward, done=done, error=None)
            history.append(
                f"Step {step}: phase={phase}, waiting={obs.total_waiting}, "
                f"throughput={obs.total_throughput}, reward={reward:+.2f}"
            )

            if done:
                graded = obs.grade_score
                score = 0.0 if graded is None else graded
                success = score >= 0.5
                break
    finally:
        log_end(success=success, steps=steps_taken, score=score, rewards=rewards)

    return {
        "task": task,
        "success": success,
        "score": score,
        "steps": steps_taken,
        "grade_details": getattr(obs, "grade_details", None),
    }
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def _print_summary(all_results: List[Dict[str, Any]]) -> None:
    """Write a per-task summary and the average score to stderr (not stdout)."""
    out = sys.stderr
    print("\n=== SUMMARY ===", file=out, flush=True)
    for r in all_results:
        status = "PASS" if r["success"] else "FAIL"
        print(
            f" [{status}] {r['task']:22s} score={r['score']:.4f} steps={r['steps']}",
            file=out,
            flush=True,
        )
        details = r.get("grade_details")
        if details:
            print(
                f" waiting={details.get('waiting_score', 0):.3f} "
                f"throughput={details.get('throughput_score', 0):.3f} "
                f"safety={details.get('safety_score', 0):.3f} "
                f"dilemma={details.get('total_dilemma_vehicles', 0):.1f}",
                file=out,
                flush=True,
            )
    avg_score = (
        sum(r["score"] for r in all_results) / len(all_results)
        if all_results else 0.0
    )
    print(f" Average score: {avg_score:.4f}", file=out, flush=True)


async def main() -> None:
    """Run every configured task against the environment, then summarise.

    Uses the Docker image named by IMAGE_NAME if set; otherwise connects to
    ENV_BASE_URL (default http://localhost:8000). Only the mandated
    [START]/[STEP]/[END] lines are written to stdout. Debug messages and the
    human-readable summary go to stderr so stdout stays machine-parseable. The
    environment is always closed, and close errors are reported but not raised.
    """
    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
    if IMAGE_NAME:
        env = await TrafficLightEnv.from_docker_image(IMAGE_NAME)
    else:
        base_url = os.getenv("ENV_BASE_URL", "http://localhost:8000")
        env = TrafficLightEnv(base_url=base_url)
        await env.connect()
    try:
        all_results: List[Dict[str, Any]] = []
        for raw_task in TASKS:
            task = raw_task.strip()
            if task not in TASK_NAMES:
                print(f"[DEBUG] Skipping unknown task: {task}", file=sys.stderr, flush=True)
                continue
            all_results.append(await run_task(client, env, task))
        _print_summary(all_results)
    finally:
        try:
            await env.close()
        except Exception as e:
            print(f"[DEBUG] env.close() error: {e}", file=sys.stderr, flush=True)
# Script entry point: run main() via asyncio.run().
if __name__ == "__main__":
    asyncio.run(main())