meta_env / inference.py
arrow072's picture
Update inference.py
0a01302 verified
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
from env import TrafficEnv
from tasks import get_config
from baseline_agent import RuleBasedAgent
import os
import openai
class LLMAgent:
    """Traffic-signal agent that asks an OpenAI-compatible chat model for each action.

    The client is built from the ``API_BASE_URL`` and ``API_KEY`` environment
    variables. If either is missing (or client construction fails) the agent
    runs in fallback-only mode. Whenever the LLM call fails or its reply does
    not contain a ``0``/``1`` answer, the decision of a ``RuleBasedAgent`` is
    used instead.

    Actions (as stated in the prompt): ``1`` switches phase, ``0`` keeps phase.
    """

    def __init__(self, model="gpt-3.5-turbo"):
        """Create the LLM client (if configured) and the rule-based fallback.

        Args:
            model: Chat model name sent with every completion request.
        """
        self.model = model
        try:
            self.client = openai.OpenAI(
                base_url=os.environ["API_BASE_URL"],
                api_key=os.environ["API_KEY"],
            )
        except Exception:
            # Missing env vars (KeyError) or a client construction error:
            # degrade to fallback-only mode instead of failing at import time.
            self.client = None
        self.fallback = RuleBasedAgent()

    @staticmethod
    def _parse_action(content):
        """Return 0 or 1 from the first '0'/'1' character of *content*, else None."""
        if not isinstance(content, str):
            # message.content may be None (or a non-text payload).
            return None
        for ch in content:
            if ch in "01":
                return int(ch)
        return None

    def select_action(self, state):
        """Choose an action for *state*.

        Args:
            state: Environment state; it is interpolated into the prompt and
                also handed to the rule-based fallback.

        Returns:
            ``1`` (switch phase) or ``0`` (keep phase) as answered by the LLM,
            or the fallback agent's action when no usable answer is available.
        """
        if self.client is None:
            return self.fallback.select_action(state)
        prompt = f"Traffic state: {state}. Reply with 1 to switch phase, 0 to keep phase. Output only 1 or 0."
        # Advance the fallback exactly once per decision so its internal step
        # tracking stays in sync, and keep its answer ready in case the LLM fails.
        fallback_action = self.fallback.select_action(state)
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": "You are a traffic signal controller."},
                    {"role": "user", "content": prompt},
                ],
                max_tokens=5,
                temperature=0.0,
            )
            content = response.choices[0].message.content
        except Exception:
            # Best-effort boundary: any network/API/response-shape failure
            # falls back to the rule-based decision.
            return fallback_action
        action = self._parse_action(content)
        return fallback_action if action is None else action

    def reset(self):
        """Reset per-episode state (delegated to the rule-based fallback)."""
        self.fallback.reset()
# Shared server state: one FastAPI app, one environment built from the
# "medium" task config, and one LLM-backed agent used by /auto_step.
app = FastAPI()
_SERVER_TASK = "medium"
env = TrafficEnv(get_config(_SERVER_TASK))
agent = LLMAgent()
class Action(BaseModel):
    """Request body for ``POST /step``.

    ``action`` is forwarded unchanged to ``env.step``; the agents in this
    module use ``1`` to switch phase and ``0`` to keep the current phase.
    """

    action: int
@app.get("/", response_class=HTMLResponse)
def root():
    """Serve the single-page UI from ``index.html``.

    The file is resolved relative to this module's directory rather than the
    process working directory, so the page loads regardless of where the
    server was launched from.
    """
    html_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "index.html")
    with open(html_path, "r", encoding="utf-8") as f:
        return f.read()
@app.post("/reset")
def reset():
    """Start a new episode on the shared server environment.

    Also calls ``agent.reset()`` so the shared auto-step agent starts the
    episode fresh.

    Returns:
        ``{"state": ...}`` with the value returned by ``env.reset()``,
        converted via ``tolist()`` when it provides one (e.g. arrays).
    """
    state = env.reset()
    try:
        state = state.tolist()
    except AttributeError:
        # Not array-like (no tolist()); return the state unchanged.
        pass
    agent.reset()
    return {"state": state}
@app.post("/step")
def step(data: Action):
    """Apply a client-chosen action to the shared server environment.

    Args:
        data: Request body; ``data.action`` is passed unchanged to ``env.step``.

    Returns:
        A dict with the new ``state`` (converted via ``tolist()`` when it
        provides one) plus the ``reward``, ``done`` and ``info`` values
        returned by ``env.step``.
    """
    state, reward, done, info = env.step(data.action)
    try:
        state = state.tolist()
    except AttributeError:
        # Not array-like (no tolist()); return the state unchanged.
        pass
    return {
        "state": state,
        "reward": reward,
        "done": done,
        "info": info,
    }
@app.post("/auto_step")
def auto_step():
    """Let the server-side agent pick the next action and apply it.

    The agent decides from ``env.get_state()``; the chosen action is then
    passed to ``env.step``.

    Returns:
        The same fields as ``/step`` (``state``, ``reward``, ``done``,
        ``info``) plus ``action_taken``, the action the agent chose.
    """
    state_dict = env.get_state()
    action = agent.select_action(state_dict)
    state, reward, done, info = env.step(action)
    try:
        state = state.tolist()
    except AttributeError:
        # Not array-like (no tolist()); return the state unchanged.
        pass
    return {
        "state": state,
        "reward": reward,
        "done": done,
        "info": info,
        "action_taken": action,
    }
def resolve_tasks(argv, available=("easy", "medium", "hard")):
    """Return the task names to evaluate for a command line.

    The first argument after the program name may select one task as
    ``--task=NAME``, ``--task NAME`` or a bare ``NAME``. A missing or unknown
    selection runs every task in *available*, in order.

    Args:
        argv: ``sys.argv``-style argument vector (program name first).
        available: Task names that may be selected.

    Returns:
        A new list of task names.
    """
    args = list(argv[1:])
    if not args:
        return list(available)
    first = args[0]
    if first.startswith("--task="):
        selected = first[len("--task="):]
    elif first == "--task":
        # Two-token form: the task name is the next argument.
        selected = args[1] if len(args) > 1 else ""
    else:
        selected = first
    return [selected] if selected in available else list(available)


if __name__ == "__main__":
    import sys

    for task_name in resolve_tasks(sys.argv):
        eval_env = TrafficEnv(get_config(task_name))
        eval_agent = LLMAgent()
        eval_env.reset()
        eval_agent.reset()
        print("[START]", flush=True)
        done = False
        step_idx = 0
        while not done:
            # Feed the agent the same structured state /auto_step uses
            # (env.get_state()), not the raw observation from reset()/step().
            action = eval_agent.select_action(eval_env.get_state())
            _, reward, done, _ = eval_env.step(action)
            print(f"[STEP] step={step_idx}, reward={reward}, done={done}", flush=True)
            step_idx += 1
        print("[END]", flush=True)