# open_ENV / inference.py
# Source: Hugging Face Hub file page (author: arrow072; commit 7d4de56, "Update inference.py").
"""
inference.py — Traffic Signal Optimization · OpenEnv Hackathon Submission
============================================================================
Env variables expected by the evaluator
----------------------------------------
API_BASE_URL Base URL of the LLM endpoint (e.g. https://router.huggingface.co/v1)
MODEL_NAME Model identifier (e.g. meta-llama/Llama-3.2-3B-Instruct)
HF_TOKEN HuggingFace / API key
stdout log format (parsed by the OpenEnv validator)
-----------------------------------------------------
[START]
[STEP] step=0, score=0.5123, reward=0.0246, done=False
...
[END]
HTTP endpoints (OpenEnv spec: reset / step / state)
----------------------------------------------------
GET / — UI
GET /health — liveness probe ← returns {"status": "healthy"}
GET /metadata — env name/description ← required by validator
GET /schema — action/obs/state ← required by validator
POST /mcp — JSON-RPC 2.0 stub ← required by validator
GET /state — current env state (required by OpenEnv spec)
GET /tasks — enumerate tasks (required by validator)
POST /reset — start new episode
POST /step — advance one step
POST /auto_step — agent picks + steps
POST /grader — run baseline on all tasks, return scores
"""
import os
import sys
from typing import Optional

import openai
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from pydantic import BaseModel

from env import TrafficEnv
from tasks import get_config
from baseline_agent import RuleBasedAgent
# ---------------------------------------------------------------------------
# LLM Agent
# ---------------------------------------------------------------------------
class LLMAgent:
    """
    OpenAI-compatible LLM agent with a rule-based fallback.

    Reads API_BASE_URL / MODEL_NAME / HF_TOKEN from the environment. When
    API_BASE_URL is unset, or the client cannot be constructed, every action
    comes from ``RuleBasedAgent``. Otherwise the LLM is queried each step, and
    the rule-based action is used whenever the call fails or the reply does
    not contain a usable ``0``/``1``.
    """

    def __init__(self) -> None:
        api_base = os.environ.get("API_BASE_URL", "").strip()
        api_key = os.environ.get("HF_TOKEN", "not-needed")
        self.model = os.environ.get("MODEL_NAME", "gpt-3.5-turbo")
        self.client = None
        if api_base:
            try:
                self.client = openai.OpenAI(base_url=api_base, api_key=api_key)
            except Exception as exc:
                # Best effort: run purely rule-based rather than crash. Log to
                # stderr because stdout is parsed by the OpenEnv validator.
                print(f"[LLMAgent] could not create LLM client: {exc!r}",
                      file=sys.stderr, flush=True)
                self.client = None
        self.fallback = RuleBasedAgent()

    @staticmethod
    def _parse_action(content: Optional[str]) -> Optional[int]:
        """Return the first ``0``/``1`` digit in an LLM reply, or None if absent.

        Scanning for the first decision digit (instead of testing ``"1" in
        content``) means a reply such as ``"0 (do not switch to 1)"`` is read
        as 0. An empty, None, or digit-free reply yields None.
        """
        if not content:
            return None
        for ch in content:
            if ch in ("0", "1"):
                return int(ch)
        return None

    def select_action(self, state: dict) -> int:
        """Choose 0 (keep phase) or 1 (switch phase) for ``state``.

        The rule-based agent is consulted exactly once per call, even when
        the LLM answers, so any per-step state it keeps advances in lockstep
        with the environment. Its answer is returned when there is no LLM
        client, the request raises, or the reply cannot be parsed.
        """
        fallback_action = self.fallback.select_action(state)
        if self.client is None:
            return fallback_action

        prompt = (
            f"Traffic intersection state:\n{state}\n\n"
            "You control the traffic signal. Reply with ONLY 0 or 1.\n"
            "0 = keep current green phase\n"
            "1 = switch to the other phase"
        )
        try:
            resp = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": "You are a traffic signal controller. Output only 0 or 1."},
                    {"role": "user", "content": prompt},
                ],
                max_tokens=5,
                temperature=0.0,
            )
            content = resp.choices[0].message.content
        except Exception as exc:
            # Network/API errors must not abort the episode; log to stderr
            # (stdout is reserved for validator lines) and use the fallback.
            print(f"[LLMAgent] LLM call failed, using rule-based action: {exc!r}",
                  file=sys.stderr, flush=True)
            return fallback_action

        action = self._parse_action(content)
        return fallback_action if action is None else action

    def reset(self) -> None:
        """Reset the rule-based fallback at the start of a new episode."""
        self.fallback.reset()
# ---------------------------------------------------------------------------
# Shared server-level env / agent (used by HTTP endpoints)
# ---------------------------------------------------------------------------
# Module-level singletons backing /reset, /step, /state and /auto_step.
# The env is built once from the "medium" task config; /reset only calls
# _env.reset(), and no endpoint in this file rebuilds it with another task.
# NOTE(review): these objects are shared by all clients and are not guarded by
# a lock — confirm whether concurrent requests can interleave env steps.
_env = TrafficEnv(get_config("medium"))
# Built at import time, so API_BASE_URL / MODEL_NAME / HF_TOKEN are read once,
# when this module is first imported.
_agent = LLMAgent()
# ---------------------------------------------------------------------------
# FastAPI application
# ---------------------------------------------------------------------------
app = FastAPI(
    title="Traffic Signal Optimization — OpenEnv",
    description="4-way intersection RL environment · Meta × PyTorch OpenEnv Hackathon",
    version="1.0.0",
)
# ── Meta / liveness ─────────────────────────────────────────────────────────
@app.get("/", response_class=HTMLResponse)
def root() -> str:
    """Serve the bundled ``index.html`` UI.

    The file is looked up next to this module first, so the page loads no
    matter which directory the server was started from. The bare relative
    path (i.e. the current working directory) is kept as a fallback, which
    preserves the previous behaviour, including the FileNotFoundError →
    HTTP 500 when the file is missing everywhere.
    """
    beside_module = os.path.join(os.path.dirname(os.path.abspath(__file__)), "index.html")
    path = beside_module if os.path.isfile(beside_module) else "index.html"
    with open(path, "r", encoding="utf-8") as fh:
        return fh.read()
# ── FIX 1: /health must return "healthy", not "ok" ──────────────────────────
@app.get("/health")
def health() -> dict:
    """Report liveness; the validator requires the status value to be 'healthy'."""
    status = "healthy"
    return {"status": status}
# ── FIX 2: /metadata endpoint (required by openenv-core validator) ───────────
@app.get("/metadata")
def metadata() -> dict:
    """Describe the environment; the validator reads the 'name' and 'description' keys."""
    env_name = "TrafficSignalOptimization-v1"
    summary = (
        "AI-driven Traffic Signal Optimization for a 4-way urban intersection. "
        "An RL environment that minimises congestion, reduces average waiting time, "
        "responds to emergency vehicles, and maintains signal stability across "
        "three difficulty tiers: easy, medium, and hard."
    )
    return {"name": env_name, "description": summary}
# ── FIX 3: /schema endpoint (required by openenv-core validator) ─────────────
@app.get("/schema")
def schema() -> dict:
    """Return the action, observation and state schemas (the validator needs all three keys)."""
    # Observation and state advertise the same field list; each entry gets
    # its own list copy so the two sub-dicts never share a mutable object.
    field_names = (
        "north_cars", "south_cars", "east_cars", "west_cars",
        "waiting_times", "phase", "emergency_flags", "step_count",
    )
    action_space = {
        "type": "Discrete",
        "n": 2,
        "description": "0 = keep current phase, 1 = switch phase",
    }
    return {
        "action": action_space,
        "observation": {"type": "Dict", "keys": list(field_names)},
        "state": {"type": "Dict", "keys": list(field_names)},
    }
# ── FIX 4: /mcp endpoint (required by openenv-core validator) ────────────────
@app.post("/mcp")
def mcp(request: Optional[dict] = None) -> dict:
    """JSON-RPC 2.0 stub endpoint; the validator checks ``jsonrpc == '2.0'``.

    Per the JSON-RPC 2.0 spec, a response must carry the same ``id`` as the
    request it answers, so the incoming ``id`` is echoed back. It is None when
    the body is absent or has no ``id``. The body is optional and defaults to
    None rather than a shared mutable ``{}``.
    """
    request_id = request.get("id") if isinstance(request, dict) else None
    return {"jsonrpc": "2.0", "id": request_id, "result": {"status": "ok"}}
@app.get("/tasks")
def list_tasks() -> dict:
    """List the easy / medium / hard tasks with their advertised parameters."""
    # One row per task: (id, description, max_steps, arrival_rate, emergency_prob)
    catalogue = (
        ("easy", "Stable low-volume traffic, rare emergencies (1%)", 50, (0, 1), 0.01),
        ("medium", "Moderate traffic with 10% burst events, 5% emergency", 100, (1, 3), 0.05),
        ("hard", "High-intensity traffic, 20% bursts, 15% emergency, strict fairness", 200, (2, 5), 0.15),
    )
    entries = [
        {
            "id": task_id,
            "description": blurb,
            "max_steps": step_limit,
            "arrival_rate": list(arrivals),
            "emergency_prob": p_emergency,
        }
        for task_id, blurb, step_limit, arrivals, p_emergency in catalogue
    ]
    return {"tasks": entries}
# ── Core OpenEnv API ─────────────────────────────────────────────────────────
@app.post("/reset")
def reset_env() -> dict:
    """Start a fresh episode on the shared env and reset the shared agent."""
    initial_state = _env.reset()
    _agent.reset()
    return {"state": initial_state}
class Action(BaseModel):
    """Request body for POST /step: the action to apply to the shared env."""

    # Passed unchanged to _env.step(). /schema advertises a Discrete(2) space
    # (0 = keep current phase, 1 = switch phase).
    # NOTE(review): values outside {0, 1} are not rejected here — confirm how
    # TrafficEnv.step handles them.
    action: int
@app.post("/step")
def step_env(data: Action) -> dict:
    """Apply the client's action to the shared env and report the transition.

    ``score`` is the step reward passed through (r + 1) / 2, clipped to
    [0.001, 0.999] and rounded to 6 decimal places.
    """
    next_state, reward, done, info = _env.step(data.action)
    shifted = (reward + 1.0) / 2.0
    score = round(max(0.001, min(0.999, shifted)), 6)
    return {
        "state": next_state,
        "reward": reward,
        "score": score,
        "done": done,
        "info": info,
    }
@app.get("/state")
def get_state() -> dict:
    """Return the shared env's current state (the 'state' leg of reset/step/state)."""
    snapshot = _env.get_state()
    return {"state": snapshot}
# ── Convenience endpoints ────────────────────────────────────────────────────
@app.post("/auto_step")
def auto_step() -> dict:
    """Let the shared agent pick an action for the current state, then apply it.

    The response mirrors /step and additionally reports ``action_taken``.
    """
    chosen = _agent.select_action(_env.get_state())
    next_state, reward, done, info = _env.step(chosen)
    shifted = (reward + 1.0) / 2.0
    score = round(max(0.001, min(0.999, shifted)), 6)
    return {
        "state": next_state,
        "reward": reward,
        "score": score,
        "done": done,
        "info": info,
        "action_taken": chosen,
    }
def _run_baseline_episode(task_id: str) -> dict:
    """Play one complete RuleBasedAgent episode on ``task_id`` and summarise it.

    ``score`` is the mean per-step reward passed through (r + 1) / 2, clipped
    to [0.001, 0.999] and rounded to 6 places. ``info`` is the info value
    returned by the final env step.
    """
    cfg = get_config(task_id)
    episode_env = TrafficEnv(cfg)
    baseline = RuleBasedAgent()
    obs = episode_env.reset()
    baseline.reset()

    reward_sum = 0.0
    step_count = 0
    finished = False
    last_info = None
    while not finished:
        obs, reward, finished, last_info = episode_env.step(baseline.select_action(obs))
        reward_sum += reward
        step_count += 1

    mean_reward = reward_sum / max(1, step_count)
    return {
        "score": round(max(0.001, min(0.999, (mean_reward + 1.0) / 2.0)), 6),
        "steps": step_count,
        "total_reward": round(reward_sum, 4),
        "info": last_info,
    }


@app.post("/grader")
def grader() -> dict:
    """
    Run the rule-based baseline on every task, each in a fresh env, and
    return per-task scores normalised into (0, 1) as the validator requires.
    """
    return {task_id: _run_baseline_episode(task_id) for task_id in ("easy", "medium", "hard")}
# ---------------------------------------------------------------------------
# CLI entry-point — produces structured stdout for the OpenEnv validator
# ---------------------------------------------------------------------------
def _select_tasks(argv: list) -> list:
    """Return the task ids to evaluate, given ``sys.argv``-style arguments.

    Only the first argument is inspected. Accepted forms are ``easy``,
    ``--task=easy`` and ``--task easy``; the last was previously ignored.
    No argument, or an unrecognised task name, selects all three tasks.
    """
    all_tasks = ["easy", "medium", "hard"]
    if len(argv) < 2:
        return all_tasks
    first = argv[1]
    if first.strip() == "--task" and len(argv) > 2:
        # Space-separated form: the task name is the next argument.
        raw = argv[2].strip()
    else:
        raw = first.replace("--task=", "").replace("--task", "").strip()
    return [raw] if raw in all_tasks else all_tasks


def _run_task(task_name: str) -> float:
    """Run one LLM-driven episode on ``task_name`` and print validator log lines.

    Writes ``[START]``, one ``[STEP]`` line per env step and ``[END]`` to
    stdout, flushing each line. ``[END]`` is emitted even if the episode
    raises, so the validator always sees a terminated block; the exception
    still propagates afterwards. Returns the sum of the per-step rewards.
    """
    eval_env = TrafficEnv(get_config(task_name))
    eval_agent = LLMAgent()
    state = eval_env.reset()
    eval_agent.reset()
    print("[START]", flush=True)
    total_reward = 0.0
    try:
        done = False
        step_idx = 0
        while not done:
            action = eval_agent.select_action(state)
            state, reward, done, _info = eval_env.step(action)
            total_reward += reward
            # score: reward normalised to open interval (0, 1)
            score = round(max(0.001, min(0.999, (reward + 1.0) / 2.0)), 6)
            print(
                f"[STEP] step={step_idx}, score={score}, "
                f"reward={round(reward, 6)}, done={done}",
                flush=True,
            )
            step_idx += 1
    finally:
        print("[END]", flush=True)
    return total_reward


if __name__ == "__main__":
    for task_name in _select_tasks(sys.argv):
        _run_task(task_name)