"""
inference.py — Traffic Signal Optimization · OpenEnv Hackathon Submission
============================================================================
Env variables expected by the evaluator
----------------------------------------
API_BASE_URL Base URL of the LLM endpoint (e.g. https://router.huggingface.co/v1)
MODEL_NAME Model identifier (e.g. meta-llama/Llama-3.2-3B-Instruct)
HF_TOKEN HuggingFace / API key
stdout log format (parsed by the OpenEnv validator)
-----------------------------------------------------
[START]
[STEP] step=0, score=0.512300, reward=0.024600, done=False
...
[END]
HTTP endpoints (OpenEnv spec: reset / step / state)
----------------------------------------------------
GET / — UI
GET /health — liveness probe ← returns {"status": "healthy"}
GET /metadata — env name/description ← required by validator
GET /schema — action/obs/state ← required by validator
POST /mcp — JSON-RPC 2.0 stub ← required by validator
GET /state — current env state (required by OpenEnv spec)
GET /tasks — enumerate tasks (required by validator)
POST /reset — start new episode
POST /step — advance one step
POST /auto_step — agent picks + steps
POST /grader — run baseline on all tasks, return scores
"""
import os
import sys
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
from env import TrafficEnv
from tasks import get_config
from baseline_agent import RuleBasedAgent
import openai
# ---------------------------------------------------------------------------
# LLM Agent
# ---------------------------------------------------------------------------
class LLMAgent:
    """
    OpenAI-compatible LLM agent with a rule-based fallback.

    Reads API_BASE_URL / MODEL_NAME / HF_TOKEN from the environment.
    Any failure — missing endpoint, client construction error, API error,
    or an empty completion — degrades gracefully to the RuleBasedAgent.
    """

    def __init__(self) -> None:
        api_base = os.environ.get("API_BASE_URL", "").strip()
        api_key = os.environ.get("HF_TOKEN", "not-needed")
        self.model = os.environ.get("MODEL_NAME", "gpt-3.5-turbo")
        self.client = None
        if api_base:
            try:
                self.client = openai.OpenAI(base_url=api_base, api_key=api_key)
            except Exception:
                # Client construction failed (bad URL, bad key format, ...):
                # run entirely on the rule-based fallback.
                self.client = None
        self.fallback = RuleBasedAgent()

    def select_action(self, state: dict) -> int:
        """Return 0 (keep current phase) or 1 (switch phase) for *state*."""
        if self.client is not None:
            prompt = (
                f"Traffic intersection state:\n{state}\n\n"
                "You control the traffic signal. Reply with ONLY 0 or 1.\n"
                "0 = keep current green phase\n"
                "1 = switch to the other phase"
            )
            try:
                resp = self.client.chat.completions.create(
                    model=self.model,
                    messages=[
                        {"role": "system", "content": "You are a traffic signal controller. Output only 0 or 1."},
                        {"role": "user", "content": prompt},
                    ],
                    max_tokens=5,
                    temperature=0.0,
                )
                content = resp.choices[0].message.content
                # FIX: `content` may legitimately be None; the old code
                # relied on `.strip()` raising AttributeError into the
                # broad except below to reach the fallback. Handle it
                # explicitly instead.
                if content is not None:
                    self.fallback.select_action(state)  # keep step counter in sync
                    return 1 if "1" in content else 0
            except Exception:
                # Network / API errors: fall through to the baseline policy.
                pass
        return self.fallback.select_action(state)

    def reset(self) -> None:
        """Reset the per-episode state of the fallback policy."""
        self.fallback.reset()
# ---------------------------------------------------------------------------
# Shared server-level env / agent (used by HTTP endpoints)
# ---------------------------------------------------------------------------
# Single process-wide environment/agent pair shared by every HTTP handler.
# NOTE(review): concurrent requests would mutate this shared state —
# presumably the validator drives one episode at a time; confirm before
# serving parallel clients.
_env = TrafficEnv(get_config("medium"))
_agent = LLMAgent()
# ---------------------------------------------------------------------------
# FastAPI application
# ---------------------------------------------------------------------------
# ASGI application exposing the OpenEnv HTTP surface
# (reset / step / state plus the validator's metadata endpoints).
app = FastAPI(
    title="Traffic Signal Optimization — OpenEnv",
    description="4-way intersection RL environment · Meta × PyTorch OpenEnv Hackathon",
    version="1.0.0",
)
# ── Meta / liveness ─────────────────────────────────────────────────────────
@app.get("/", response_class=HTMLResponse)
def root() -> str:
    """Serve the bundled single-page UI."""
    with open("index.html", "r", encoding="utf-8") as page:
        html = page.read()
    return html
# ── FIX 1: /health must return "healthy", not "ok" ──────────────────────────
@app.get("/health")
def health() -> dict:
    """Liveness probe — the validator strictly checks status == 'healthy'."""
    payload = {"status": "healthy"}
    return payload
# ── FIX 2: /metadata endpoint (required by openenv-core validator) ───────────
@app.get("/metadata")
def metadata() -> dict:
    """Environment metadata — validator checks for 'name' and 'description' fields."""
    blurb = (
        "AI-driven Traffic Signal Optimization for a 4-way urban intersection. "
        "An RL environment that minimises congestion, reduces average waiting time, "
        "responds to emergency vehicles, and maintains signal stability across "
        "three difficulty tiers: easy, medium, and hard."
    )
    return {"name": "TrafficSignalOptimization-v1", "description": blurb}
# ── FIX 3: /schema endpoint (required by openenv-core validator) ─────────────
@app.get("/schema")
def schema() -> dict:
    """Action / observation / state schemas — all three keys required by validator."""
    # Observation and state expose the same dict layout, so declare it once.
    state_keys = [
        "north_cars", "south_cars", "east_cars", "west_cars",
        "waiting_times", "phase", "emergency_flags", "step_count",
    ]
    return {
        "action": {
            "type": "Discrete",
            "n": 2,
            "description": "0 = keep current phase, 1 = switch phase",
        },
        "observation": {"type": "Dict", "keys": list(state_keys)},
        "state": {"type": "Dict", "keys": list(state_keys)},
    }
# ── FIX 4: /mcp endpoint (required by openenv-core validator) ────────────────
@app.post("/mcp")
def mcp(request: dict = {}) -> dict:  # noqa: B006 — default is read-only, never mutated
    """
    JSON-RPC 2.0 stub — validator checks jsonrpc == '2.0'.

    FIX: per the JSON-RPC 2.0 spec the response "id" MUST mirror the
    request id; the previous version hard-coded None, which breaks
    request/response correlation for clients that send an id.
    """
    return {"jsonrpc": "2.0", "id": request.get("id"), "result": {"status": "ok"}}
@app.get("/tasks")
def list_tasks() -> dict:
    """Enumerate the 3 difficulty tasks for the validator."""
    # (id, description, max_steps, arrival_rate, emergency_prob)
    catalogue = (
        ("easy", "Stable low-volume traffic, rare emergencies (1%)",
         50, [0, 1], 0.01),
        ("medium", "Moderate traffic with 10% burst events, 5% emergency",
         100, [1, 3], 0.05),
        ("hard", "High-intensity traffic, 20% bursts, 15% emergency, strict fairness",
         200, [2, 5], 0.15),
    )
    return {
        "tasks": [
            {
                "id": task_id,
                "description": text,
                "max_steps": limit,
                "arrival_rate": rate,
                "emergency_prob": prob,
            }
            for task_id, text, limit, rate, prob in catalogue
        ]
    }
# ── Core OpenEnv API ─────────────────────────────────────────────────────────
@app.post("/reset")
def reset_env() -> dict:
    """Start a fresh episode on the shared environment and agent."""
    initial_state = _env.reset()
    _agent.reset()
    return {"state": initial_state}
class Action(BaseModel):
    """Request body for POST /step."""

    # 0 = keep current green phase, 1 = switch to the other phase.
    action: int
@app.post("/step")
def step_env(data: Action) -> dict:
    """Advance the shared environment by one action and report the outcome."""
    obs, reward, done, info = _env.step(data.action)
    # Map reward into the open interval (0, 1) as the validator expects.
    normalised = (reward + 1.0) / 2.0
    score = round(min(0.999, max(0.001, normalised)), 6)
    return {"state": obs, "reward": reward, "score": score, "done": done, "info": info}
@app.get("/state")
def get_state() -> dict:
    """Snapshot of the current environment state.

    Required by the OpenEnv spec (the reset / step / state triple).
    """
    return {"state": _env.get_state()}
# ── Convenience endpoints ────────────────────────────────────────────────────
@app.post("/auto_step")
def auto_step() -> dict:
    """Let the server-side agent pick an action, then step the environment."""
    chosen = _agent.select_action(_env.get_state())
    obs, reward, done, info = _env.step(chosen)
    # Same (0, 1) normalisation as /step.
    score = round(min(0.999, max(0.001, (reward + 1.0) / 2.0)), 6)
    return {"state": obs, "reward": reward, "score": score,
            "done": done, "info": info, "action_taken": chosen}
@app.post("/grader")
def grader() -> dict:
    """
    Run the rule-based baseline on all 3 tasks and return per-task scores
    normalised to the open interval (0, 1) as required by the validator.
    """
    report: dict = {}
    for difficulty in ("easy", "medium", "hard"):
        episode_env = TrafficEnv(get_config(difficulty))
        baseline = RuleBasedAgent()
        obs = episode_env.reset()
        baseline.reset()
        cumulative = 0.0
        n_steps = 0
        finished = False
        # Episodes start with done=False, so the loop runs at least once
        # and `info` is always bound afterwards.
        while not finished:
            obs, reward, finished, info = episode_env.step(baseline.select_action(obs))
            cumulative += reward
            n_steps += 1
        mean_reward = cumulative / max(1, n_steps)
        report[difficulty] = {
            "score": round(min(0.999, max(0.001, (mean_reward + 1.0) / 2.0)), 6),
            "steps": n_steps,
            "total_reward": round(cumulative, 4),
            "info": info,
        }
    return report
# ---------------------------------------------------------------------------
# CLI entry-point — produces structured stdout for the OpenEnv validator
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # CLI entry point — emits the [START]/[STEP]/[END] stdout format
    # parsed by the OpenEnv validator.
    selected = ["easy", "medium", "hard"]
    if len(sys.argv) > 1:
        # Accept "--task=NAME", "--taskNAME", or a bare task name.
        candidate = sys.argv[1].replace("--task=", "").replace("--task", "").strip()
        if candidate in selected:
            selected = [candidate]
    for task_name in selected:
        episode_env = TrafficEnv(get_config(task_name))
        episode_agent = LLMAgent()
        obs = episode_env.reset()
        episode_agent.reset()
        print("[START]", flush=True)
        finished = False
        step_idx = 0
        cumulative = 0.0
        while not finished:
            obs, reward, finished, info = episode_env.step(
                episode_agent.select_action(obs)
            )
            cumulative += reward
            # score: reward normalised to the open interval (0, 1)
            score = round(min(0.999, max(0.001, (reward + 1.0) / 2.0)), 6)
            print(
                f"[STEP] step={step_idx}, score={score}, "
                f"reward={round(reward, 6)}, done={finished}",
                flush=True,
            )
            step_idx += 1
        print("[END]", flush=True)