from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
from env import TrafficEnv
from tasks import get_config
from baseline_agent import RuleBasedAgent
import os
import openai
class LLMAgent:
    """Traffic-signal controller that asks an LLM whether to switch phase.

    A RuleBasedAgent is always kept as a fallback: it is used when the API
    client cannot be constructed or when a request/parse fails, so
    select_action always returns an action.
    """

    def __init__(self):
        try:
            # Raises if API_BASE_URL / API_KEY are missing or invalid.
            self.client = openai.OpenAI(
                base_url=os.environ["API_BASE_URL"],
                api_key=os.environ["API_KEY"]
            )
        except Exception:
            self.client = None
        self.fallback = RuleBasedAgent()

    def select_action(self, state):
        """Return 1 to switch the signal phase, 0 to keep the current phase.

        state: opaque traffic-state object; rendered into the prompt via str().
        """
        if self.client is None:
            # No API client available — delegate entirely to the rule-based
            # agent instead of relying on an AttributeError being swallowed.
            return self.fallback.select_action(state)
        prompt = f"Traffic state: {state}. Reply with 1 to switch phase, 0 to keep phase. Output only 1 or 0."
        try:
            response = self.client.chat.completions.create(
                model="gpt-3.5-turbo",
                messages=[
                    {"role": "system", "content": "You are a traffic signal controller."},
                    {"role": "user", "content": prompt}
                ],
                max_tokens=5,
                temperature=0.0
            )
            content = response.choices[0].message.content.strip()
            # Still call fallback to maintain its internal step counter
            self.fallback.select_action(state)
            return 1 if "1" in content else 0
        except Exception:
            # Best-effort: any API or parsing failure falls back to the
            # rule-based policy rather than crashing the episode.
            return self.fallback.select_action(state)

    def reset(self):
        """Reset the fallback agent's internal state for a new episode."""
        self.fallback.reset()
# FastAPI application serving the demo UI and the control endpoints below.
app = FastAPI()
# Single shared environment/agent pair; their state persists across requests.
env = TrafficEnv(get_config("medium"))
agent = LLMAgent()
class Action(BaseModel):
    # Request body for /step: 1 = switch phase, 0 = keep phase.
    action: int
@app.get("/", response_class=HTMLResponse)
def root():
    """Serve the dashboard page from index.html in the working directory."""
    with open("index.html", "r", encoding="utf-8") as page:
        html = page.read()
    return html
@app.post("/reset")
def reset():
    """Reset the shared environment and agent; return the initial state."""
    state = env.reset()
    try:
        # numpy arrays are not JSON-serializable; convert when possible.
        state = state.tolist()
    except AttributeError:
        # Already a plain Python structure — narrow except instead of the
        # original bare `except:` that swallowed every error.
        pass
    agent.reset()
    return {"state": state}
@app.post("/step")
def step(data: Action):
    """Advance the environment by one step using the caller-supplied action."""
    state, reward, done, info = env.step(data.action)
    try:
        # numpy arrays are not JSON-serializable; convert when possible.
        state = state.tolist()
    except AttributeError:
        # Already a plain Python structure — narrow except instead of the
        # original bare `except:` that swallowed every error.
        pass
    return {
        "state": state,
        "reward": reward,
        "done": done,
        "info": info
    }
@app.post("/auto_step")
def auto_step():
    """Let the shared agent pick the action, then advance the environment."""
    state_dict = env.get_state()
    action = agent.select_action(state_dict)
    state, reward, done, info = env.step(action)
    try:
        # numpy arrays are not JSON-serializable; convert when possible.
        state = state.tolist()
    except AttributeError:
        # Already a plain Python structure — narrow except instead of the
        # original bare `except:` that swallowed every error.
        pass
    return {
        "state": state,
        "reward": reward,
        "done": done,
        "info": info,
        "action_taken": action
    }
if __name__ == "__main__":
    import sys

    # Evaluate every difficulty by default; a single task may be requested
    # on the command line (either "--task=NAME" or a bare task name).
    tasks_to_run = ["easy", "medium", "hard"]
    if len(sys.argv) > 1:
        # e.g., if validator optionally passes a task name as argument
        task_arg = sys.argv[1].replace("--task=", "").replace("--task", "")
        if task_arg in tasks_to_run:
            tasks_to_run = [task_arg]

    for task_name in tasks_to_run:
        eval_env = TrafficEnv(get_config(task_name))
        eval_agent = LLMAgent()
        state = eval_env.reset()
        eval_agent.reset()
        print("[START]", flush=True)
        step_idx = 0
        total_reward = 0.0
        done = False
        # Run one full episode, logging each transition for the validator.
        while not done:
            chosen = eval_agent.select_action(state)
            state, reward, done, info = eval_env.step(chosen)
            print(f"[STEP] step={step_idx}, reward={reward}, done={done}", flush=True)
            step_idx += 1
            total_reward += reward
        print("[END]", flush=True)