| from fastapi import FastAPI |
| from fastapi.responses import HTMLResponse |
| from pydantic import BaseModel |
| from env import TrafficEnv |
| from tasks import get_config |
| from baseline_agent import RuleBasedAgent |
| import os |
| import openai |
|
|
class LLMAgent:
    """Traffic-signal controller that asks an LLM for phase decisions.

    Falls back to a RuleBasedAgent whenever the OpenAI client cannot be
    constructed (missing credentials) or a completion request fails.
    """

    def __init__(self):
        # API_BASE_URL / API_KEY may be absent from the environment; in that
        # case run in fallback-only mode instead of crashing at import time.
        try:
            self.client = openai.OpenAI(
                base_url=os.environ["API_BASE_URL"],
                api_key=os.environ["API_KEY"]
            )
        except Exception:
            self.client = None
        self.fallback = RuleBasedAgent()

    def select_action(self, state):
        """Return 1 to switch the signal phase, 0 to keep it.

        Any failure in the LLM round-trip is answered by the rule-based
        fallback policy rather than propagated to the caller.
        """
        # Bug fix: the original relied on `None.chat` raising AttributeError
        # on every call to reach the fallback; branch explicitly instead.
        if self.client is None:
            return self.fallback.select_action(state)
        prompt = f"Traffic state: {state}. Reply with 1 to switch phase, 0 to keep phase. Output only 1 or 0."
        try:
            response = self.client.chat.completions.create(
                model="gpt-3.5-turbo",
                messages=[
                    {"role": "system", "content": "You are a traffic signal controller."},
                    {"role": "user", "content": prompt}
                ],
                max_tokens=5,
                temperature=0.0
            )
            content = response.choices[0].message.content.strip()

            # NOTE(review): this call discards its result. It may exist to keep
            # a stateful fallback agent in step with the episode — confirm
            # whether RuleBasedAgent keeps internal state before removing it.
            self.fallback.select_action(state)

            # Tolerant parse: any "1" in the reply means "switch phase".
            return 1 if "1" in content else 0
        except Exception:
            # Network/API/parse failure: defer to the rule-based policy.
            return self.fallback.select_action(state)

    def reset(self):
        """Reset the fallback agent's per-episode state."""
        self.fallback.reset()
|
|
# Module-level singletons shared by every request handler below: one FastAPI
# app, one simulation environment (fixed to the "medium" task config), and
# one agent. All HTTP requests operate on this single simulation instance.
app = FastAPI()
env = TrafficEnv(get_config("medium"))
agent = LLMAgent()
|
|
class Action(BaseModel):
    """Request body for POST /step: the signal-phase action to apply."""

    # 1 = switch phase, 0 = keep phase (matches the LLM prompt contract).
    action: int
|
|
@app.get("/", response_class=HTMLResponse)
def root():
    """Serve the dashboard page, re-read from disk on every request."""
    with open("index.html", "r", encoding="utf-8") as page:
        html = page.read()
    return html
|
|
@app.post("/reset")
def reset():
    """Reset the shared environment and agent; return the initial state."""
    state = env.reset()
    # Bug fix: was a bare `except: pass` around `tolist()`, which silently
    # swallowed every error. Convert numpy arrays (not JSON-serializable)
    # explicitly; plain Python states pass through untouched.
    if hasattr(state, "tolist"):
        state = state.tolist()
    agent.reset()
    return {"state": state}
|
|
@app.post("/step")
def step(data: Action):
    """Apply the client-supplied action and return the resulting transition."""
    state, reward, done, info = env.step(data.action)
    # Bug fix: was a bare `except: pass`; convert numpy states explicitly.
    if hasattr(state, "tolist"):
        state = state.tolist()
    return {
        "state": state,
        "reward": reward,
        "done": done,
        "info": info
    }
|
|
@app.post("/auto_step")
def auto_step():
    """Let the server-side agent pick the action, then step the environment."""
    state_dict = env.get_state()
    action = agent.select_action(state_dict)
    state, reward, done, info = env.step(action)
    # Bug fix: was a bare `except: pass`; convert numpy states explicitly.
    if hasattr(state, "tolist"):
        state = state.tolist()
    return {
        "state": state,
        "reward": reward,
        "done": done,
        "info": info,
        "action_taken": action
    }
|
|
if __name__ == "__main__":
    import sys

    # Headless evaluation mode: run the agent over one or all task configs,
    # emitting machine-parsable [START]/[STEP]/[TOTAL]/[END] markers.
    tasks_to_run = ["easy", "medium", "hard"]
    if len(sys.argv) > 1:
        # Accept "--task=easy" or "--task easy"; anything unrecognized
        # silently falls back to running all three tasks.
        task_arg = sys.argv[1].replace("--task=", "").replace("--task", "")
        if task_arg in tasks_to_run:
            tasks_to_run = [task_arg]

    for task_name in tasks_to_run:
        config = get_config(task_name)
        eval_env = TrafficEnv(config)
        eval_agent = LLMAgent()

        state = eval_env.reset()
        eval_agent.reset()

        print("[START]", flush=True)

        done = False
        step_idx = 0
        total_reward = 0.0

        while not done:
            action = eval_agent.select_action(state)
            state, reward, done, info = eval_env.step(action)
            print(f"[STEP] step={step_idx}, reward={reward}, done={done}", flush=True)
            step_idx += 1
            total_reward += reward

        # Bug fix: total_reward was accumulated but never reported. A new
        # [TOTAL] line keeps the existing [START]/[STEP]/[END] markers intact.
        print(f"[TOTAL] task={task_name}, steps={step_idx}, total_reward={total_reward}", flush=True)
        print("[END]", flush=True)
|