File size: 3,581 Bytes
0a01302
 
 
d41e4fb
0a01302
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d41e4fb
0a01302
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94ed1c9
 
0a01302
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
from env import TrafficEnv
from tasks import get_config
from baseline_agent import RuleBasedAgent
import os
import openai

class LLMAgent:
    def __init__(self):
        try:
            self.client = openai.OpenAI(
                base_url=os.environ["API_BASE_URL"], 
                api_key=os.environ["API_KEY"]
            )
        except Exception:
            self.client = None
        self.fallback = RuleBasedAgent()
        
    def select_action(self, state):
        prompt = f"Traffic state: {state}. Reply with 1 to switch phase, 0 to keep phase. Output only 1 or 0."
        try:
            response = self.client.chat.completions.create(
                model="gpt-3.5-turbo",
                messages=[
                    {"role": "system", "content": "You are a traffic signal controller."},
                    {"role": "user", "content": prompt}
                ],
                max_tokens=5,
                temperature=0.0
            )
            content = response.choices[0].message.content.strip()
            # Still call fallback to maintain its internal step counter
            self.fallback.select_action(state)
            
            if "1" in content:
                return 1
            else:
                return 0
        except Exception as e:
            return self.fallback.select_action(state)

    def reset(self):
        self.fallback.reset()

app = FastAPI()
env = TrafficEnv(get_config("medium"))
agent = LLMAgent()

class Action(BaseModel):
    action: int

@app.get("/", response_class=HTMLResponse)
def root():
    with open("index.html", "r", encoding="utf-8") as f:
        return f.read()

@app.post("/reset")
def reset():
    state = env.reset()
    try:
        state = state.tolist()
    except:
        pass
    agent.reset()
    return {"state":state}

@app.post("/step")
def step(data:Action):
    state,reward,done,info = env.step(data.action)
    try:
        state = state.tolist()
    except:
        pass
    return {
        "state":state,
        "reward":reward,
        "done":done,
        "info":info
    }

@app.post("/auto_step")
def auto_step():
    state_dict = env.get_state()
    action = agent.select_action(state_dict)
    state,reward,done,info = env.step(action)
    try:
        state = state.tolist()
    except:
        pass
    return {
        "state":state,
        "reward":reward,
        "done":done,
        "info":info,
        "action_taken": action
    }

if __name__ == "__main__":
    import sys
    tasks_to_run = ["easy", "medium", "hard"]
    if len(sys.argv) > 1:
        # e.g., if validator optionally passes a task name as argument
        task_arg = sys.argv[1].replace("--task=", "").replace("--task", "")
        if task_arg in tasks_to_run:
            tasks_to_run = [task_arg]

    for task_name in tasks_to_run:
        config = get_config(task_name)
        eval_env = TrafficEnv(config)
        eval_agent = LLMAgent()
        
        state = eval_env.reset()
        eval_agent.reset()
        
        print("[START]", flush=True)
        
        done = False
        step_idx = 0
        total_reward = 0.0
        
        while not done:
            action = eval_agent.select_action(state)
            state, reward, done, info = eval_env.step(action)
            print(f"[STEP] step={step_idx}, reward={reward}, done={done}", flush=True)
            step_idx += 1
            total_reward += reward
            
        print("[END]", flush=True)