arrow072 committed on
Commit
7b14fc3
·
verified ·
1 Parent(s): 94ed1c9

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +148 -126
inference.py CHANGED
@@ -1,130 +1,152 @@
1
- from fastapi import FastAPI
2
- from fastapi.responses import HTMLResponse
3
- from pydantic import BaseModel
4
- from env import TrafficEnv
5
- from tasks import get_config
6
- from baseline_agent import RuleBasedAgent
7
  import os
8
- import openai
9
-
10
class LLMAgent:
    """Choose traffic-signal actions by querying a chat model.

    The OpenAI-compatible client is configured from the ``API_BASE_URL`` and
    ``API_KEY`` environment variables. Whenever the client is unavailable or
    a request fails, the decision is delegated to ``RuleBasedAgent``.
    """

    def __init__(self):
        try:
            self.client = openai.OpenAI(
                base_url=os.environ["API_BASE_URL"],
                api_key=os.environ["API_KEY"],
            )
        except Exception:
            # Missing environment variables (KeyError) or a client that could
            # not be constructed: run on the fallback agent only.
            self.client = None
        self.fallback = RuleBasedAgent()

    def select_action(self, state):
        """Return the action for ``state``.

        With a working client the result is 1 (switch phase) or 0 (keep
        phase), taken from the first ``0``/``1`` character of the model
        reply; a reply containing neither yields 0. Without a client, or
        when the request or reply handling raises, the fallback agent's
        action is returned instead.
        """
        # Explicit guard instead of relying on an AttributeError from
        # ``None.chat`` being swallowed by the handler below.
        if self.client is None:
            return self.fallback.select_action(state)

        prompt = f"Traffic state: {state}. Reply with 1 to switch phase, 0 to keep phase. Output only 1 or 0."
        try:
            response = self.client.chat.completions.create(
                model="gpt-3.5-turbo",
                messages=[
                    {"role": "system", "content": "You are a traffic signal controller."},
                    {"role": "user", "content": prompt},
                ],
                max_tokens=5,
                temperature=0.0,
            )
            content = response.choices[0].message.content.strip()
            # Still call fallback to maintain its internal step counter
            self.fallback.select_action(state)
        except Exception:
            return self.fallback.select_action(state)

        # Use the first 0/1 character rather than a substring test, so a
        # reply such as "0, not 1" is read as 0 instead of 1.
        first_digit = next((ch for ch in content if ch in "01"), "0")
        return int(first_digit)

    def reset(self):
        """Reset the fallback agent's internal state."""
        self.fallback.reset()
46
-
47
# Module-level objects shared by every request handler in this file.
app = FastAPI()
# The served environment is always built from the "medium" task configuration.
env = TrafficEnv(get_config("medium"))
agent = LLMAgent()
50
-
51
class Action(BaseModel):
    # Request body of POST /step; the value is passed straight to env.step().
    action: int
53
-
54
@app.get("/", response_class=HTMLResponse)
def root():
    """Serve the contents of ``index.html`` as the landing page."""
    with open("index.html", "r", encoding="utf-8") as page:
        html = page.read()
    return html
58
-
59
@app.post("/reset")
def reset():
    """Reset the shared environment and agent and return the initial state.

    Returns:
        ``{"state": ...}`` where the state has been converted with
        ``tolist()`` when the object offers that method.
    """
    state = env.reset()
    try:
        state = state.tolist()
    except AttributeError:
        # State has no ``tolist`` (e.g. already a plain dict/list): send as is.
        # Only this case is tolerated; a bare ``except`` would also swallow
        # KeyboardInterrupt and SystemExit.
        pass
    agent.reset()
    return {"state": state}
68
-
69
@app.post("/step")
def step(data: Action):
    """Advance the shared environment by one step using the posted action.

    Args:
        data: Request body carrying the integer ``action``.

    Returns:
        A dict with the keys ``state``, ``reward``, ``done`` and ``info`` as
        produced by ``env.step``; the state is converted with ``tolist()``
        when the object offers that method.
    """
    state, reward, done, info = env.step(data.action)
    try:
        state = state.tolist()
    except AttributeError:
        # State has no ``tolist``: send as is. Narrowed from a bare
        # ``except`` so interrupts and unrelated errors are not hidden.
        pass
    return {
        "state": state,
        "reward": reward,
        "done": done,
        "info": info,
    }
82
-
83
@app.post("/auto_step")
def auto_step():
    """Let the agent pick an action for the current state and apply it.

    Returns:
        The same payload as ``/step`` (``state``, ``reward``, ``done``,
        ``info``) plus ``action_taken``, the action the agent selected.
    """
    state_dict = env.get_state()
    action = agent.select_action(state_dict)
    state, reward, done, info = env.step(action)
    try:
        state = state.tolist()
    except AttributeError:
        # State has no ``tolist``: send as is. Narrowed from a bare
        # ``except`` so interrupts and unrelated errors are not hidden.
        pass
    return {
        "state": state,
        "reward": reward,
        "done": done,
        "info": info,
        "action_taken": action,
    }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
  if __name__ == "__main__":
101
- import sys
102
- tasks_to_run = ["easy", "medium", "hard"]
103
- if len(sys.argv) > 1:
104
- # e.g., if validator optionally passes a task name as argument
105
- task_arg = sys.argv[1].replace("--task=", "").replace("--task", "")
106
- if task_arg in tasks_to_run:
107
- tasks_to_run = [task_arg]
108
-
109
- for task_name in tasks_to_run:
110
- config = get_config(task_name)
111
- eval_env = TrafficEnv(config)
112
- eval_agent = LLMAgent()
113
-
114
- state = eval_env.reset()
115
- eval_agent.reset()
116
-
117
- print("[START]", flush=True)
118
-
119
- done = False
120
- step_idx = 0
121
- total_reward = 0.0
122
-
123
- while not done:
124
- action = eval_agent.select_action(state)
125
- state, reward, done, info = eval_env.step(action)
126
- print(f"[STEP] step={step_idx}, reward={reward}, done={done}", flush=True)
127
- step_idx += 1
128
- total_reward += reward
129
-
130
- print("[END]", flush=True)
 
 
 
 
 
 
 
1
  import os
2
+ from openai import OpenAI
3
+ from env import TrafficEnv
4
+
5
# Environment presets for the three task modes (easy / medium / hard).
# Relative to EASY, the harder presets raise the arrival range, the emergency
# and burst probabilities and the burst multiplier; HARD also lowers the
# starvation threshold.
EASY_CONFIG = dict(
    max_steps=20,
    max_queue=20,
    arrival_rate=(0, 2),
    discharge_rate=(3, 5),
    emergency_prob=0.01,
    switch_penalty=0.2,
    starvation_threshold=10,
    burst_prob=0.0,
    burst_multiplier=1.0,
)

MEDIUM_CONFIG = dict(
    max_steps=20,
    max_queue=20,
    arrival_rate=(1, 3),
    discharge_rate=(3, 5),
    emergency_prob=0.03,
    switch_penalty=0.2,
    starvation_threshold=10,
    burst_prob=0.2,
    burst_multiplier=1.5,
)

HARD_CONFIG = dict(
    max_steps=20,
    max_queue=20,
    arrival_rate=(2, 4),
    discharge_rate=(3, 5),
    emergency_prob=0.05,
    switch_penalty=0.2,
    starvation_threshold=8,
    burst_prob=0.35,
    burst_multiplier=2.0,
)
41
+
42
+
43
def strict_score(x: float) -> float:
    """Convert a reward into a score strictly inside (0, 1).

    The reward is mapped linearly from [-1, 1] to [0, 1] and then clamped to
    [0.001, 0.999], so rewards outside [-1, 1] saturate at the bounds.

    Args:
        x: Reward value; anything accepted by ``float()``.

    Returns:
        The clamped score. A NaN reward yields the lowest score, 0.001.
    """
    import math

    score = (float(x) + 1.0) / 2.0
    # NaN compares false against everything, so ``min(0.999, nan)`` would
    # return 0.999 and an invalid reward would earn the best score.
    if math.isnan(score):
        return 0.001
    return max(0.001, min(0.999, score))
47
+
48
+
49
def build_client():
    """Create the OpenAI-compatible client from environment variables.

    Reads ``API_BASE_URL``, ``API_KEY`` and ``MODEL_NAME``. Values that are
    unset, empty or whitespace-only are treated as missing, so a blank
    ``MODEL_NAME`` falls back to ``"gpt-4o-mini"`` instead of being sent to
    the API as an empty model name.

    Returns:
        ``(client, model_name, use_llm)``. ``client`` is an ``OpenAI``
        instance and ``use_llm`` is True only when both the base URL and the
        API key are present; otherwise ``client`` is None and ``use_llm`` is
        False.
    """
    api_base_url = (os.environ.get("API_BASE_URL") or "").strip()
    api_key = (os.environ.get("API_KEY") or "").strip()
    model_name = (os.environ.get("MODEL_NAME") or "").strip() or "gpt-4o-mini"

    if api_base_url and api_key:
        client = OpenAI(
            base_url=api_base_url,
            api_key=api_key,
        )
        return client, model_name, True

    # Fallback for environments where these vars are not present
    return None, model_name, False
63
+
64
+
65
def choose_action(client, model_name, state):
    """Ask the chat model whether to keep (0) or switch (1) the signal phase.

    Args:
        client: Object exposing ``chat.completions.create`` (an OpenAI client).
        model_name: Model identifier sent with the request.
        state: Current environment state; interpolated into the prompt as text.

    Returns:
        0 or 1. Any reply that cannot be read as one of these yields 0.

    Raises:
        Whatever ``client.chat.completions.create`` raises; request errors
        are not handled here.
    """
    prompt = f"""
You are controlling a traffic signal at a 4-way intersection.

Current state:
{state}

Available actions:
0 = keep current signal phase
1 = switch signal phase

Reply with only one number: 0 or 1
""".strip()

    response = client.chat.completions.create(
        model=model_name,
        messages=[
            {
                "role": "system",
                "content": "You are a traffic signal controller. Reply with only 0 or 1.",
            },
            {
                "role": "user",
                "content": prompt,
            },
        ],
        temperature=0,
    )

    # The message content can be None; treat that as an empty reply instead
    # of failing on ``None.strip()``.
    content = (response.choices[0].message.content or "").strip()

    try:
        action = int(content)
    except ValueError:
        # Not a bare integer (e.g. "1." or "Action: 1"): accept the first
        # whitespace-separated token that is exactly 0 or 1 once surrounding
        # punctuation is removed.
        for token in content.split():
            token = token.strip(".,:;!?\"'`*()[]")
            if token in ("0", "1"):
                return int(token)
        return 0

    return action if action in (0, 1) else 0
104
+
105
+
106
def run_task(task_name, config, client, model_name, use_llm):
    """Run one episode of a task and print its progress markers.

    Prints ``[START]``, one ``[STEP]`` line per environment step and a final
    ``[END]`` line on stdout. Each score is the reward passed through
    ``strict_score``.

    Args:
        task_name: Label printed in the ``[STEP]`` and ``[END]`` lines.
        config: Configuration passed to ``TrafficEnv``.
        client: Chat client forwarded to ``choose_action``; unused when
            ``use_llm`` is False.
        model_name: Model identifier forwarded to ``choose_action``.
        use_llm: When False, action 0 is used on every step.

    Returns:
        The final score: ``strict_score`` of the mean per-step reward.
    """
    import sys  # only needed for the stderr diagnostic below

    env = TrafficEnv(config)
    state = env.reset()

    print("[START]", flush=True)

    done = False
    step_idx = 0
    total_reward = 0.0

    while not done:
        # Safe fallback so the script never crashes outside validator
        action = 0
        if use_llm:
            try:
                action = choose_action(client, model_name, state)
            except Exception as exc:
                # A failed request must not abort the episode: keep action 0
                # for this step and report on stderr, leaving the stdout
                # marker stream untouched.
                print(
                    f"[WARN] task={task_name}, step={step_idx}: LLM call failed ({exc!r}); using action 0",
                    file=sys.stderr,
                    flush=True,
                )

        state, reward, done, _info = env.step(action)

        step_score = strict_score(reward)
        print(
            f"[STEP] task={task_name}, step={step_idx}, score={step_score:.3f}, done={done}",
            flush=True,
        )

        total_reward += reward
        step_idx += 1

    # max(1, ...) guards the division if the loop body never ran.
    avg_reward = total_reward / max(1, step_idx)
    final_score = strict_score(avg_reward)

    print(f"[END] task={task_name}, score={final_score:.3f}", flush=True)

    return final_score
140
+
141
 
142
  if __name__ == "__main__":
143
+ client, model_name, use_llm = build_client()
144
+
145
+ tasks = [
146
+ ("easy", EASY_CONFIG),
147
+ ("medium", MEDIUM_CONFIG),
148
+ ("hard", HARD_CONFIG),
149
+ ]
150
+
151
+ for task_name, config in tasks:
152
+ run_task(task_name, config, client, model_name, use_llm)