File size: 11,535 Bytes
c8f6f13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d4de56
 
 
 
 
 
c8f6f13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d4de56
c8f6f13
 
7d4de56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c8f6f13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d4de56
c8f6f13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d4de56
c8f6f13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d4de56
 
 
 
 
 
 
c8f6f13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d4de56
c8f6f13
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
"""
inference.py  —  Traffic Signal Optimization · OpenEnv Hackathon Submission
============================================================================

Env variables expected by the evaluator
----------------------------------------
  API_BASE_URL   Base URL of the LLM endpoint (e.g. https://router.huggingface.co/v1)
  MODEL_NAME     Model identifier          (e.g. meta-llama/Llama-3.2-3B-Instruct)
  HF_TOKEN       HuggingFace / API key

stdout log format  (parsed by the OpenEnv validator)
-----------------------------------------------------
  [START]
  [STEP] step=0, score=0.512300, reward=0.024600, done=False
  ...
  [END]

HTTP endpoints  (OpenEnv spec: reset / step / state)
----------------------------------------------------
  GET  /           — UI
  GET  /health     — liveness probe        ← returns {"status": "healthy"}
  GET  /metadata   — env name/description  ← required by validator
  GET  /schema     — action/obs/state      ← required by validator
  POST /mcp        — JSON-RPC 2.0 stub     ← required by validator
  GET  /state      — current env state     (required by OpenEnv spec)
  GET  /tasks      — enumerate tasks       (required by validator)
  POST /reset      — start new episode
  POST /step       — advance one step
  POST /auto_step  — agent picks + steps
  POST /grader     — run baseline on all tasks, return scores
"""

import os
import sys
from typing import Optional

import openai
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from pydantic import BaseModel

from env import TrafficEnv
from tasks import get_config
from baseline_agent import RuleBasedAgent


# ---------------------------------------------------------------------------
# LLM Agent
# ---------------------------------------------------------------------------

class LLMAgent:
    """
    OpenAI-compatible LLM agent with a rule-based fallback.

    Reads API_BASE_URL / MODEL_NAME / HF_TOKEN from the environment. An LLM
    client is only built when API_BASE_URL is non-empty. Whenever there is no
    client, the request fails, or the reply contains no ``0``/``1`` digit, the
    decision comes from the wrapped ``RuleBasedAgent`` instead.

    Diagnostics are written to stderr so that stdout stays reserved for the
    [START]/[STEP]/[END] validator log.
    """

    def __init__(self) -> None:
        api_base   = os.environ.get("API_BASE_URL", "").strip()
        api_key    = os.environ.get("HF_TOKEN", "not-needed")
        self.model = os.environ.get("MODEL_NAME", "gpt-3.5-turbo")

        # Only the first LLM failure is reported, to avoid one warning per step.
        self._llm_error_reported = False

        self.client = None
        if api_base:
            try:
                self.client = openai.OpenAI(base_url=api_base, api_key=api_key)
            except Exception as exc:  # best-effort: run on the fallback instead of crashing
                self._report_llm_error(exc)
                self.client = None

        self.fallback = RuleBasedAgent()

    def _report_llm_error(self, exc: Exception) -> None:
        """Write the first LLM failure to stderr; later failures stay silent."""
        if self._llm_error_reported:
            return
        self._llm_error_reported = True
        print(
            f"[LLMAgent] LLM unavailable, using rule-based fallback: {exc!r}",
            file=sys.stderr,
            flush=True,
        )

    @staticmethod
    def _parse_action(content: Optional[str]) -> Optional[int]:
        """
        Extract an action from an LLM reply.

        Returns the first ``0`` or ``1`` character of *content* as an int, or
        ``None`` when *content* is not a string or contains neither digit.
        (The previous ``"1" in content`` check misread replies such as
        ``"0 (not 1)"`` as 1.)
        """
        if not isinstance(content, str):
            return None
        for ch in content:
            if ch in "01":
                return int(ch)
        return None

    def select_action(self, state: dict) -> int:
        """
        Choose 0 (keep current green phase) or 1 (switch to the other phase).

        The LLM is queried first when a client exists. The rule-based fallback
        is then advanced exactly once per call so its step counter stays in
        sync, and its choice is returned whenever the LLM produced no usable
        answer.
        """
        reply = None
        if self.client is not None:
            prompt = (
                f"Traffic intersection state:\n{state}\n\n"
                "You control the traffic signal. Reply with ONLY 0 or 1.\n"
                "0 = keep current green phase\n"
                "1 = switch to the other phase"
            )
            try:
                resp = self.client.chat.completions.create(
                    model=self.model,
                    messages=[
                        {"role": "system", "content": "You are a traffic signal controller. Output only 0 or 1."},
                        {"role": "user",   "content": prompt},
                    ],
                    max_tokens=5,
                    temperature=0.0,
                )
                reply = resp.choices[0].message.content
            except Exception as exc:  # network / auth / rate-limit / malformed response
                self._report_llm_error(exc)

        # Always advance the fallback once per decision (keeps its step counter
        # in sync); use its action when the LLM reply could not be parsed.
        fallback_action = self.fallback.select_action(state)
        llm_action = self._parse_action(reply)
        return fallback_action if llm_action is None else llm_action

    def reset(self) -> None:
        """Reset the wrapped rule-based fallback agent."""
        self.fallback.reset()


# ---------------------------------------------------------------------------
# Shared server-level env / agent  (used by HTTP endpoints)
# ---------------------------------------------------------------------------

# Process-wide environment and agent backing the stateful HTTP endpoints
# (/reset, /step, /state, /auto_step); the server instance is built on the
# "medium" task configuration.
_env: TrafficEnv = TrafficEnv(get_config("medium"))
_agent: LLMAgent = LLMAgent()


# ---------------------------------------------------------------------------
# FastAPI application
# ---------------------------------------------------------------------------

app = FastAPI(
    title="Traffic Signal Optimization — OpenEnv",
    version="1.0.0",
    description="4-way intersection RL environment · Meta × PyTorch OpenEnv Hackathon",
)


# ── Meta / liveness ─────────────────────────────────────────────────────────

@app.get("/", response_class=HTMLResponse)
def root() -> str:
    """
    Serve the single-page UI from ``index.html``.

    The copy next to this module is preferred, so the UI loads no matter which
    working directory the server was started from. Otherwise the original
    CWD-relative ``index.html`` is used; if that is missing too, the resulting
    FileNotFoundError propagates exactly as before.
    """
    beside_module = os.path.join(os.path.dirname(os.path.abspath(__file__)), "index.html")
    path = beside_module if os.path.isfile(beside_module) else "index.html"
    with open(path, "r", encoding="utf-8") as fh:
        return fh.read()


# ── FIX 1: /health must return "healthy", not "ok" ──────────────────────────
@app.get("/health")
def health() -> dict:
    """Report liveness; the validator requires ``status`` to be exactly ``"healthy"``."""
    return dict(status="healthy")


# ── FIX 2: /metadata endpoint (required by openenv-core validator) ───────────
@app.get("/metadata")
def metadata() -> dict:
    """Describe the environment; the validator looks for ``name`` and ``description``."""
    summary = " ".join((
        "AI-driven Traffic Signal Optimization for a 4-way urban intersection.",
        "An RL environment that minimises congestion, reduces average waiting time,",
        "responds to emergency vehicles, and maintains signal stability across",
        "three difficulty tiers: easy, medium, and hard.",
    ))
    return {"name": "TrafficSignalOptimization-v1", "description": summary}


# ── FIX 3: /schema endpoint (required by openenv-core validator) ─────────────
@app.get("/schema")
def schema() -> dict:
    """
    Publish the action, observation and state schemas.

    The validator requires all three top-level keys. Observation and state
    advertise the same field list.
    """
    fields = [
        "north_cars", "south_cars", "east_cars", "west_cars",
        "waiting_times", "phase", "emergency_flags", "step_count",
    ]
    action_spec = {
        "type": "Discrete",
        "n": 2,
        "description": "0 = keep current phase, 1 = switch phase",
    }
    return {
        "action": action_spec,
        "observation": {"type": "Dict", "keys": list(fields)},
        "state": {"type": "Dict", "keys": list(fields)},
    }


# ── FIX 4: /mcp endpoint (required by openenv-core validator) ────────────────
@app.post("/mcp")
def mcp(request: Optional[dict] = None) -> dict:
    """
    Minimal JSON-RPC 2.0 stub; the validator checks ``jsonrpc == "2.0"``.

    Per the JSON-RPC 2.0 spec the response ``id`` echoes the request's ``id``.
    It is ``None`` when the body is absent or carries no id. ``None`` replaces
    the previous mutable ``{}`` default, and a body-less POST is still accepted.
    """
    request_id = (request or {}).get("id")
    return {"jsonrpc": "2.0", "id": request_id, "result": {"status": "ok"}}


@app.get("/tasks")
def list_tasks() -> dict:
    """List the three difficulty tiers (easy / medium / hard) for the validator."""
    # (id, description, max_steps, arrival_rate, emergency_prob)
    catalogue = (
        ("easy", "Stable low-volume traffic, rare emergencies (1%)", 50, [0, 1], 0.01),
        ("medium", "Moderate traffic with 10% burst events, 5% emergency", 100, [1, 3], 0.05),
        ("hard", "High-intensity traffic, 20% bursts, 15% emergency, strict fairness", 200, [2, 5], 0.15),
    )
    return {
        "tasks": [
            {
                "id": task_id,
                "description": blurb,
                "max_steps": horizon,
                "arrival_rate": rate,
                "emergency_prob": p_emergency,
            }
            for task_id, blurb, horizon, rate, p_emergency in catalogue
        ]
    }


# ── Core OpenEnv API ─────────────────────────────────────────────────────────

@app.post("/reset")
def reset_env() -> dict:
    """Start a new episode on the shared env and reset the shared agent."""
    initial = _env.reset()
    _agent.reset()
    return dict(state=initial)


class Action(BaseModel):
    """Request body for ``POST /step``."""

    action: int  # 0 = keep current phase, 1 = switch phase (see /schema)


@app.post("/step")
def step_env(data: Action) -> dict:
    """
    Apply the client's action to the shared env.

    ``score`` is ``(reward + 1) / 2`` clamped to [0.001, 0.999] and rounded to
    six decimals.
    """
    next_state, reward, done, info = _env.step(data.action)
    clipped = min(0.999, (reward + 1.0) / 2.0)
    score = round(max(0.001, clipped), 6)
    return {
        "state": next_state,
        "reward": reward,
        "score": score,
        "done": done,
        "info": info,
    }


@app.get("/state")
def get_state() -> dict:
    """Return the shared env's current state (the ``state`` leg of OpenEnv's reset/step/state API)."""
    current = _env.get_state()
    return {"state": current}


# ── Convenience endpoints ────────────────────────────────────────────────────

@app.post("/auto_step")
def auto_step() -> dict:
    """Let the shared agent choose an action for the current state, apply it, and report it."""
    chosen = _agent.select_action(_env.get_state())
    next_state, reward, done, info = _env.step(chosen)
    clipped = min(0.999, (reward + 1.0) / 2.0)
    score = round(max(0.001, clipped), 6)
    return {
        "state": next_state,
        "reward": reward,
        "score": score,
        "done": done,
        "info": info,
        "action_taken": chosen,
    }


@app.post("/grader")
def grader() -> dict:
    """
    Play the rule-based baseline once on each task and report per-task scores.

    Each task gets a fresh ``TrafficEnv`` and ``RuleBasedAgent`` and runs until
    the env reports ``done``. ``score`` is ``(mean_reward + 1) / 2`` clamped to
    [0.001, 0.999], i.e. inside the open interval (0, 1) the validator expects.
    """
    def play(task_id: str) -> dict:
        episode_env = TrafficEnv(get_config(task_id))
        baseline = RuleBasedAgent()
        observation = episode_env.reset()
        baseline.reset()

        reward_sum = 0.0
        n_steps = 0
        while True:
            move = baseline.select_action(observation)
            observation, reward, done, info = episode_env.step(move)
            reward_sum += reward
            n_steps += 1
            if done:
                break

        mean_reward = reward_sum / max(1, n_steps)
        return {
            "score": round(max(0.001, min(0.999, (mean_reward + 1.0) / 2.0)), 6),
            "steps": n_steps,
            "total_reward": round(reward_sum, 4),
            "info": info,
        }

    return {task_id: play(task_id) for task_id in ("easy", "medium", "hard")}


# ---------------------------------------------------------------------------
# CLI entry-point — produces structured stdout for the OpenEnv validator
# ---------------------------------------------------------------------------

def reward_to_score(reward: float) -> float:
    """
    Map a step reward onto the validator's open score interval.

    Computes ``(reward + 1) / 2`` clamped to [0.001, 0.999] and rounded to six
    decimals.
    """
    return round(max(0.001, min(0.999, (reward + 1.0) / 2.0)), 6)


def parse_task_args(argv: list) -> list:
    """
    Choose which tasks to run from the CLI arguments (``sys.argv[1:]``).

    Accepted forms are ``NAME``, ``--task=NAME`` and ``--task NAME``, where NAME
    is ``easy``, ``medium`` or ``hard``. Anything else, including no arguments,
    selects all three tasks.
    """
    all_tasks = ["easy", "medium", "hard"]
    if not argv:
        return all_tasks
    first = argv[0].strip()
    if first == "--task":
        # Space-separated form; previously the value was dropped here and
        # every task ran.
        candidate = argv[1].strip() if len(argv) > 1 else ""
    elif first.startswith("--task="):
        candidate = first[len("--task="):].strip()
    else:
        candidate = first
    return [candidate] if candidate in all_tasks else all_tasks


def run_task(task_name: str) -> None:
    """
    Play one episode of *task_name* with an ``LLMAgent`` and stream the log.

    Prints ``[START]``, one ``[STEP]`` line per step, and ``[END]`` to stdout,
    each flushed immediately. ``[END]`` is printed even if the episode raises,
    so every ``[START]`` is closed.
    """
    eval_env   = TrafficEnv(get_config(task_name))
    eval_agent = LLMAgent()

    state = eval_env.reset()
    eval_agent.reset()

    print("[START]", flush=True)
    try:
        done     = False
        step_idx = 0
        while not done:
            action = eval_agent.select_action(state)
            state, reward, done, _info = eval_env.step(action)
            # Fixed-point formatting: round() alone can print e.g. "1e-05",
            # which a plain decimal-number parser may reject.
            print(
                f"[STEP] step={step_idx}, score={reward_to_score(reward):.6f}, "
                f"reward={reward:.6f}, done={done}",
                flush=True,
            )
            step_idx += 1
    finally:
        print("[END]", flush=True)


def main(argv=None) -> None:
    """CLI entry point: run the selected tasks in order (``argv`` defaults to ``sys.argv[1:]``)."""
    for task_name in parse_task_args(sys.argv[1:] if argv is None else argv):
        run_task(task_name)


if __name__ == "__main__":
    main()