arrow072 commited on
Commit
22fc9d3
·
verified ·
1 Parent(s): 7b14fc3

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +126 -148
inference.py CHANGED
@@ -1,152 +1,130 @@
1
- import os
2
- from openai import OpenAI
 
3
  from env import TrafficEnv
4
-
5
- # Safe task configs for 3 modes
6
- EASY_CONFIG = {
7
- "max_steps": 20,
8
- "max_queue": 20,
9
- "arrival_rate": (0, 2),
10
- "discharge_rate": (3, 5),
11
- "emergency_prob": 0.01,
12
- "switch_penalty": 0.2,
13
- "starvation_threshold": 10,
14
- "burst_prob": 0.0,
15
- "burst_multiplier": 1.0,
16
- }
17
-
18
- MEDIUM_CONFIG = {
19
- "max_steps": 20,
20
- "max_queue": 20,
21
- "arrival_rate": (1, 3),
22
- "discharge_rate": (3, 5),
23
- "emergency_prob": 0.03,
24
- "switch_penalty": 0.2,
25
- "starvation_threshold": 10,
26
- "burst_prob": 0.2,
27
- "burst_multiplier": 1.5,
28
- }
29
-
30
- HARD_CONFIG = {
31
- "max_steps": 20,
32
- "max_queue": 20,
33
- "arrival_rate": (2, 4),
34
- "discharge_rate": (3, 5),
35
- "emergency_prob": 0.05,
36
- "switch_penalty": 0.2,
37
- "starvation_threshold": 8,
38
- "burst_prob": 0.35,
39
- "burst_multiplier": 2.0,
40
- }
41
-
42
-
43
- def strict_score(x: float) -> float:
44
- # Map RL reward from [-1, 1] to [0, 1], then clamp strictly inside (0,1)
45
- x = (float(x) + 1.0) / 2.0
46
- return max(0.001, min(0.999, x))
47
-
48
-
49
- def build_client():
50
- api_base_url = os.environ.get("API_BASE_URL")
51
- api_key = os.environ.get("API_KEY")
52
- model_name = os.environ.get("MODEL_NAME", "gpt-4o-mini")
53
-
54
- if api_base_url and api_key:
55
- client = OpenAI(
56
- base_url=api_base_url,
57
- api_key=api_key,
58
- )
59
- return client, model_name, True
60
-
61
- # Fallback for environments where these vars are not present
62
- return None, model_name, False
63
-
64
-
65
- def choose_action(client, model_name, state):
66
- prompt = f"""
67
- You are controlling a traffic signal at a 4-way intersection.
68
-
69
- Current state:
70
- {state}
71
-
72
- Available actions:
73
- 0 = keep current signal phase
74
- 1 = switch signal phase
75
-
76
- Reply with only one number: 0 or 1
77
- """.strip()
78
-
79
- response = client.chat.completions.create(
80
- model=model_name,
81
- messages=[
82
- {
83
- "role": "system",
84
- "content": "You are a traffic signal controller. Reply with only 0 or 1."
85
- },
86
- {
87
- "role": "user",
88
- "content": prompt
89
- }
90
- ],
91
- temperature=0,
92
- )
93
-
94
- content = response.choices[0].message.content.strip()
95
-
96
- try:
97
- action = int(content)
98
- if action not in (0, 1):
99
- action = 0
100
- except Exception:
101
- action = 0
102
-
103
- return action
104
-
105
-
106
- def run_task(task_name, config, client, model_name, use_llm):
107
- env = TrafficEnv(config)
108
  state = env.reset()
109
-
110
- print("[START]", flush=True)
111
-
112
- done = False
113
- step_idx = 0
114
- total_reward = 0.0
115
-
116
- while not done:
117
- if use_llm:
118
- action = choose_action(client, model_name, state)
119
- else:
120
- # Safe fallback so the script never crashes outside validator
121
- action = 0
122
-
123
- state, reward, done, info = env.step(action)
124
-
125
- step_score = strict_score(reward)
126
- print(
127
- f"[STEP] task={task_name}, step={step_idx}, score={step_score:.3f}, done={done}",
128
- flush=True,
129
- )
130
-
131
- total_reward += reward
132
- step_idx += 1
133
-
134
- avg_reward = total_reward / max(1, step_idx)
135
- final_score = strict_score(avg_reward)
136
-
137
- print(f"[END] task={task_name}, score={final_score:.3f}", flush=True)
138
-
139
- return final_score
140
-
 
 
 
 
 
141
 
142
  if __name__ == "__main__":
143
- client, model_name, use_llm = build_client()
144
-
145
- tasks = [
146
- ("easy", EASY_CONFIG),
147
- ("medium", MEDIUM_CONFIG),
148
- ("hard", HARD_CONFIG),
149
- ]
150
-
151
- for task_name, config in tasks:
152
- run_task(task_name, config, client, model_name, use_llm)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from fastapi.responses import HTMLResponse
3
+ from pydantic import BaseModel
4
  from env import TrafficEnv
5
+ from tasks import get_config
6
+ from baseline_agent import RuleBasedAgent
7
+ import os
8
+ import openai
9
+
10
+ class LLMAgent:
11
+ def __init__(self):
12
+ try:
13
+ self.client = openai.OpenAI(
14
+ base_url=os.environ["API_BASE_URL"],
15
+ api_key=os.environ["API_KEY"]
16
+ )
17
+ except Exception:
18
+ self.client = None
19
+ self.fallback = RuleBasedAgent()
20
+
21
+ def select_action(self, state):
22
+ prompt = f"Traffic state: {state}. Reply with 1 to switch phase, 0 to keep phase. Output only 1 or 0."
23
+ try:
24
+ response = self.client.chat.completions.create(
25
+ model="gpt-3.5-turbo",
26
+ messages=[
27
+ {"role": "system", "content": "You are a traffic signal controller."},
28
+ {"role": "user", "content": prompt}
29
+ ],
30
+ max_tokens=5,
31
+ temperature=0.0
32
+ )
33
+ content = response.choices[0].message.content.strip()
34
+ # Still call fallback to maintain its internal step counter
35
+ self.fallback.select_action(state)
36
+
37
+ if "1" in content:
38
+ return 1
39
+ else:
40
+ return 0
41
+ except Exception as e:
42
+ return self.fallback.select_action(state)
43
+
44
+ def reset(self):
45
+ self.fallback.reset()
46
+
47
+ app = FastAPI()
48
+ env = TrafficEnv(get_config("medium"))
49
+ agent = LLMAgent()
50
+
51
+ class Action(BaseModel):
52
+ action: int
53
+
54
+ @app.get("/", response_class=HTMLResponse)
55
+ def root():
56
+ with open("index.html", "r", encoding="utf-8") as f:
57
+ return f.read()
58
+
59
+ @app.post("/reset")
60
+ def reset():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  state = env.reset()
62
+ try:
63
+ state = state.tolist()
64
+ except:
65
+ pass
66
+ agent.reset()
67
+ return {"state":state}
68
+
69
+ @app.post("/step")
70
+ def step(data:Action):
71
+ state,reward,done,info = env.step(data.action)
72
+ try:
73
+ state = state.tolist()
74
+ except:
75
+ pass
76
+ return {
77
+ "state":state,
78
+ "reward":reward,
79
+ "done":done,
80
+ "info":info
81
+ }
82
+
83
+ @app.post("/auto_step")
84
+ def auto_step():
85
+ state_dict = env.get_state()
86
+ action = agent.select_action(state_dict)
87
+ state,reward,done,info = env.step(action)
88
+ try:
89
+ state = state.tolist()
90
+ except:
91
+ pass
92
+ return {
93
+ "state":state,
94
+ "reward":reward,
95
+ "done":done,
96
+ "info":info,
97
+ "action_taken": action
98
+ }
99
 
100
  if __name__ == "__main__":
101
+ import sys
102
+ tasks_to_run = ["easy", "medium", "hard"]
103
+ if len(sys.argv) > 1:
104
+ # e.g., if validator optionally passes a task name as argument
105
+ task_arg = sys.argv[1].replace("--task=", "").replace("--task", "")
106
+ if task_arg in tasks_to_run:
107
+ tasks_to_run = [task_arg]
108
+
109
+ for task_name in tasks_to_run:
110
+ config = get_config(task_name)
111
+ eval_env = TrafficEnv(config)
112
+ eval_agent = LLMAgent()
113
+
114
+ state = eval_env.reset()
115
+ eval_agent.reset()
116
+
117
+ print("[START]", flush=True)
118
+
119
+ done = False
120
+ step_idx = 0
121
+ total_reward = 0.0
122
+
123
+ while not done:
124
+ action = eval_agent.select_action(state)
125
+ state, reward, done, info = eval_env.step(action)
126
+ print(f"[STEP] step={step_idx}, reward={reward}, done={done}", flush=True)
127
+ step_idx += 1
128
+ total_reward += reward
129
+
130
+ print("[END]", flush=True)