import os
import sys

import openai
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from pydantic import BaseModel

from baseline_agent import RuleBasedAgent
from env import TrafficEnv
from tasks import get_config

# Tasks evaluated by the command-line runner when no (valid) task is given.
ALL_TASKS = ("easy", "medium", "hard")


class LLMAgent:
    """Traffic-signal agent that asks an OpenAI-compatible LLM for each action.

    Falls back to ``RuleBasedAgent`` whenever the client could not be built
    or an LLM request fails.
    """

    def __init__(self, model=None):
        """Create the LLM client from ``API_BASE_URL`` / ``API_KEY``.

        Args:
            model: Chat model name. Defaults to the ``MODEL_NAME`` environment
                variable, or ``"gpt-3.5-turbo"`` when that is unset.
        """
        self.model = model or os.environ.get("MODEL_NAME", "gpt-3.5-turbo")
        try:
            self.client = openai.OpenAI(
                base_url=os.environ["API_BASE_URL"],
                api_key=os.environ["API_KEY"],
            )
        except Exception:
            # Missing env vars (KeyError) or client construction failure:
            # the agent then runs purely on the rule-based fallback.
            self.client = None
        self.fallback = RuleBasedAgent()

    def select_action(self, state):
        """Return 1 to switch the signal phase, 0 to keep it.

        The fallback agent is consulted exactly once per call: either as the
        decision-maker (no client / LLM error) or only to advance its internal
        step counter when the LLM answered.
        """
        if self.client is None:
            return self.fallback.select_action(state)
        try:
            content = self._query_llm(state)
        except Exception:
            return self.fallback.select_action(state)
        # Still call fallback to maintain its internal step counter. Done
        # outside the try so a fallback error is neither swallowed nor retried.
        self.fallback.select_action(state)
        return self._parse_action(content)

    def _query_llm(self, state):
        """Send ``state`` to the LLM and return its stripped text reply."""
        prompt = f"Traffic state: {state}. Reply with 1 to switch phase, 0 to keep phase. Output only 1 or 0."
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a traffic signal controller."},
                {"role": "user", "content": prompt},
            ],
            max_tokens=5,
            temperature=0.0,
        )
        # If the reply content is None, .strip() raises and the caller falls back.
        return response.choices[0].message.content.strip()

    @staticmethod
    def _parse_action(content):
        """Map a reply to an action: the first '0' or '1' character wins, else 0."""
        for ch in content:
            if ch in ("0", "1"):
                return int(ch)
        return 0

    def reset(self):
        """Reset per-episode state (delegated to the fallback agent)."""
        self.fallback.reset()


app = FastAPI()
env = TrafficEnv(get_config("medium"))
agent = LLMAgent()


class Action(BaseModel):
    action: int


def _to_jsonable(state):
    """Convert array-like observations (anything with ``tolist``) to plain lists."""
    try:
        return state.tolist()
    except AttributeError:
        return state


def _step_payload(state, reward, done, info):
    """Build the JSON body shared by the step endpoints."""
    return {
        "state": _to_jsonable(state),
        "reward": reward,
        "done": done,
        "info": info,
    }


@app.get("/", response_class=HTMLResponse)
def root():
    with open("index.html", "r", encoding="utf-8") as f:
        return f.read()


@app.post("/reset")
def reset():
    """Reset the shared environment and agent; return the initial state."""
    state = env.reset()
    agent.reset()
    return {"state": _to_jsonable(state)}


@app.post("/step")
def step(data: Action):
    """Advance the shared environment with a client-chosen action."""
    state, reward, done, info = env.step(data.action)
    return _step_payload(state, reward, done, info)


@app.post("/auto_step")
def auto_step():
    """Let the agent choose an action from ``env.get_state()`` and apply it."""
    state_dict = env.get_state()
    action = agent.select_action(state_dict)
    state, reward, done, info = env.step(action)
    payload = _step_payload(state, reward, done, info)
    payload["action_taken"] = action
    return payload


def _select_tasks(argv):
    """Return the task names to run.

    Accepts ``<task>``, ``--task=<task>`` or ``--task <task>``. An unknown or
    missing task name runs every task in ``ALL_TASKS``.
    """
    tasks = list(ALL_TASKS)
    if len(argv) < 2:
        return tasks
    if argv[1] == "--task" and len(argv) > 2:
        # e.g. validator passes the task name as a separate argument
        task_arg = argv[2]
    else:
        task_arg = argv[1].replace("--task=", "").replace("--task", "")
    return [task_arg] if task_arg in tasks else tasks


def _run_task(task_name):
    """Run one full episode of ``task_name``, printing [START]/[STEP]/[END] lines."""
    eval_env = TrafficEnv(get_config(task_name))
    eval_agent = LLMAgent()
    eval_env.reset()
    eval_agent.reset()
    print("[START]", flush=True)
    try:
        done = False
        step_idx = 0
        while not done:
            # Feed the agent the same view /auto_step uses (env.get_state()),
            # not the raw observation returned by reset()/step().
            action = eval_agent.select_action(eval_env.get_state())
            _, reward, done, _ = eval_env.step(action)
            print(f"[STEP] step={step_idx}, reward={reward}, done={done}", flush=True)
            step_idx += 1
    finally:
        # Always close the episode block, even if a step raised.
        print("[END]", flush=True)


if __name__ == "__main__":
    for task_name in _select_tasks(sys.argv):
        _run_task(task_name)