""" inference.py — Traffic Signal Optimization · OpenEnv Hackathon Submission ============================================================================ Env variables expected by the evaluator ---------------------------------------- API_BASE_URL Base URL of the LLM endpoint (e.g. https://router.huggingface.co/v1) MODEL_NAME Model identifier (e.g. meta-llama/Llama-3.2-3B-Instruct) HF_TOKEN HuggingFace / API key stdout log format (parsed by the OpenEnv validator) ----------------------------------------------------- [START] [STEP] step=0, score=0.512300, reward=0.024600, done=False ... [END] HTTP endpoints (OpenEnv spec: reset / step / state) ---------------------------------------------------- GET / — UI GET /health — liveness probe ← returns {"status": "healthy"} GET /metadata — env name/description ← required by validator GET /schema — action/obs/state ← required by validator POST /mcp — JSON-RPC 2.0 stub ← required by validator GET /state — current env state (required by OpenEnv spec) GET /tasks — enumerate tasks (required by validator) POST /reset — start new episode POST /step — advance one step POST /auto_step — agent picks + steps POST /grader — run baseline on all tasks, return scores """ import os import sys from fastapi import FastAPI from fastapi.responses import HTMLResponse from pydantic import BaseModel from env import TrafficEnv from tasks import get_config from baseline_agent import RuleBasedAgent import openai # --------------------------------------------------------------------------- # LLM Agent # --------------------------------------------------------------------------- class LLMAgent: """ OpenAI-compatible LLM agent with a rule-based fallback. Reads API_BASE_URL / MODEL_NAME / HF_TOKEN from the environment. """ def __init__(self) -> None: api_base = os.environ.get("API_BASE_URL", "").strip() api_key = os.environ.get("HF_TOKEN", "not-needed") self.model = os.environ.get("MODEL_NAME", "gpt-3.5-turbo") self.client = None if api_base: try: self.client = openai.OpenAI(base_url=api_base, api_key=api_key) except Exception: self.client = None self.fallback = RuleBasedAgent() def select_action(self, state: dict) -> int: if self.client is not None: prompt = ( f"Traffic intersection state:\n{state}\n\n" "You control the traffic signal. Reply with ONLY 0 or 1.\n" "0 = keep current green phase\n" "1 = switch to the other phase" ) try: resp = self.client.chat.completions.create( model=self.model, messages=[ {"role": "system", "content": "You are a traffic signal controller. Output only 0 or 1."}, {"role": "user", "content": prompt}, ], max_tokens=5, temperature=0.0, ) content = resp.choices[0].message.content.strip() self.fallback.select_action(state) # keep step counter in sync return 1 if "1" in content else 0 except Exception: pass return self.fallback.select_action(state) def reset(self) -> None: self.fallback.reset() # --------------------------------------------------------------------------- # Shared server-level env / agent (used by HTTP endpoints) # --------------------------------------------------------------------------- _env = TrafficEnv(get_config("medium")) _agent = LLMAgent() # --------------------------------------------------------------------------- # FastAPI application # --------------------------------------------------------------------------- app = FastAPI( title="Traffic Signal Optimization — OpenEnv", description="4-way intersection RL environment · Meta × PyTorch OpenEnv Hackathon", version="1.0.0", ) # ── Meta / liveness ───────────────────────────────────────────────────────── @app.get("/", response_class=HTMLResponse) def root() -> str: with open("index.html", "r", encoding="utf-8") as fh: return fh.read() # ── FIX 1: /health must return "healthy", not "ok" ────────────────────────── @app.get("/health") def health() -> dict: """Liveness probe — validator strictly checks status == 'healthy'.""" return {"status": "healthy"} # ── FIX 2: /metadata endpoint (required by openenv-core validator) ─────────── @app.get("/metadata") def metadata() -> dict: """Environment metadata — validator checks for 'name' and 'description' fields.""" return { "name": "TrafficSignalOptimization-v1", "description": ( "AI-driven Traffic Signal Optimization for a 4-way urban intersection. " "An RL environment that minimises congestion, reduces average waiting time, " "responds to emergency vehicles, and maintains signal stability across " "three difficulty tiers: easy, medium, and hard." ), } # ── FIX 3: /schema endpoint (required by openenv-core validator) ───────────── @app.get("/schema") def schema() -> dict: """Action / observation / state schemas — all three keys required by validator.""" return { "action": { "type": "Discrete", "n": 2, "description": "0 = keep current phase, 1 = switch phase", }, "observation": { "type": "Dict", "keys": [ "north_cars", "south_cars", "east_cars", "west_cars", "waiting_times", "phase", "emergency_flags", "step_count", ], }, "state": { "type": "Dict", "keys": [ "north_cars", "south_cars", "east_cars", "west_cars", "waiting_times", "phase", "emergency_flags", "step_count", ], }, } # ── FIX 4: /mcp endpoint (required by openenv-core validator) ──────────────── @app.post("/mcp") def mcp(request: dict = {}) -> dict: """JSON-RPC 2.0 stub — validator checks jsonrpc == '2.0'.""" return {"jsonrpc": "2.0", "id": None, "result": {"status": "ok"}} @app.get("/tasks") def list_tasks() -> dict: """Enumerate the 3 difficulty tasks for the validator.""" return { "tasks": [ { "id": "easy", "description": "Stable low-volume traffic, rare emergencies (1%)", "max_steps": 50, "arrival_rate": [0, 1], "emergency_prob": 0.01, }, { "id": "medium", "description": "Moderate traffic with 10% burst events, 5% emergency", "max_steps": 100, "arrival_rate": [1, 3], "emergency_prob": 0.05, }, { "id": "hard", "description": "High-intensity traffic, 20% bursts, 15% emergency, strict fairness", "max_steps": 200, "arrival_rate": [2, 5], "emergency_prob": 0.15, }, ] } # ── Core OpenEnv API ───────────────────────────────────────────────────────── @app.post("/reset") def reset_env() -> dict: state = _env.reset() _agent.reset() return {"state": state} class Action(BaseModel): action: int @app.post("/step") def step_env(data: Action) -> dict: state, reward, done, info = _env.step(data.action) score = round(max(0.001, min(0.999, (reward + 1.0) / 2.0)), 6) return {"state": state, "reward": reward, "score": score, "done": done, "info": info} @app.get("/state") def get_state() -> dict: """ Return current environment state. Required by OpenEnv spec (the reset / step / state triple). """ return {"state": _env.get_state()} # ── Convenience endpoints ──────────────────────────────────────────────────── @app.post("/auto_step") def auto_step() -> dict: state_dict = _env.get_state() action = _agent.select_action(state_dict) state, reward, done, info = _env.step(action) score = round(max(0.001, min(0.999, (reward + 1.0) / 2.0)), 6) return {"state": state, "reward": reward, "score": score, "done": done, "info": info, "action_taken": action} @app.post("/grader") def grader() -> dict: """ Run the rule-based baseline on all 3 tasks and return per-task scores normalised to open interval (0, 1) as required by the validator. """ results: dict = {} for task_id in ("easy", "medium", "hard"): cfg = get_config(task_id) eval_env = TrafficEnv(cfg) agent = RuleBasedAgent() state = eval_env.reset() agent.reset() total_reward = 0.0 steps = 0 done = False while not done: action = agent.select_action(state) state, reward, done, info = eval_env.step(action) total_reward += reward steps += 1 mean_reward = total_reward / max(1, steps) score = round(max(0.001, min(0.999, (mean_reward + 1.0) / 2.0)), 6) results[task_id] = { "score": score, "steps": steps, "total_reward": round(total_reward, 4), "info": info, } return results # --------------------------------------------------------------------------- # CLI entry-point — produces structured stdout for the OpenEnv validator # --------------------------------------------------------------------------- if __name__ == "__main__": tasks_to_run = ["easy", "medium", "hard"] if len(sys.argv) > 1: raw = sys.argv[1].replace("--task=", "").replace("--task", "").strip() if raw in tasks_to_run: tasks_to_run = [raw] for task_name in tasks_to_run: config = get_config(task_name) eval_env = TrafficEnv(config) eval_agent = LLMAgent() state = eval_env.reset() eval_agent.reset() print("[START]", flush=True) done = False step_idx = 0 total_reward = 0.0 while not done: action = eval_agent.select_action(state) state, reward, done, info = eval_env.step(action) total_reward += reward # score: reward normalised to open interval (0, 1) score = round(max(0.001, min(0.999, (reward + 1.0) / 2.0)), 6) print( f"[STEP] step={step_idx}, score={score}, " f"reward={round(reward, 6)}, done={done}", flush=True, ) step_idx += 1 print("[END]", flush=True)