arrow072 committed on
Commit
aa49a5f
·
verified ·
1 Parent(s): e22b3e1

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +64 -139
inference.py CHANGED
@@ -1,153 +1,78 @@
1
  import os
2
  from openai import OpenAI
3
-
4
  from env import TrafficEnv
5
 
6
# Prefer the task configs defined in tasks.py. When that module, or any of the
# three names, cannot be imported, fall back to the built-in defaults below.
# Only ImportError triggers the fallback, so an unrelated failure inside
# tasks.py surfaces instead of silently running with default settings.
try:
    from tasks import EASY_CONFIG, MEDIUM_CONFIG, HARD_CONFIG
except ImportError:
    # The three fallbacks share the same keys; they differ in arrival_rate,
    # emergency_prob, starvation_threshold, burst_prob and burst_multiplier.
    # NOTE(review): the meaning of each key is defined by TrafficEnv, which is
    # not visible here.
    EASY_CONFIG = {
        "max_steps": 20,
        "max_queue": 20,
        "arrival_rate": (0, 2),
        "discharge_rate": (3, 5),
        "emergency_prob": 0.01,
        "switch_penalty": 0.2,
        "starvation_threshold": 10,
        "burst_prob": 0.0,
        "burst_multiplier": 1.0,
    }

    MEDIUM_CONFIG = {
        "max_steps": 20,
        "max_queue": 20,
        "arrival_rate": (1, 3),
        "discharge_rate": (3, 5),
        "emergency_prob": 0.03,
        "switch_penalty": 0.2,
        "starvation_threshold": 10,
        "burst_prob": 0.2,
        "burst_multiplier": 1.5,
    }

    HARD_CONFIG = {
        "max_steps": 20,
        "max_queue": 20,
        "arrival_rate": (2, 4),
        "discharge_rate": (3, 5),
        "emergency_prob": 0.05,
        "switch_penalty": 0.2,
        "starvation_threshold": 8,
        "burst_prob": 0.35,
        "burst_multiplier": 2.0,
    }
46
-
47
-
48
def strict_score(x: float) -> float:
    """Convert a raw value into a score strictly inside (0, 1).

    The input is treated as a reward in [-1, 1] and mapped linearly onto
    [0, 1]; the result is then clamped to [0.001, 0.999] so the score is
    never exactly 0.0 or 1.0 (which the validator rejects).

    Args:
        x: Raw value; anything accepted by ``float()``.

    Returns:
        A float in [0.001, 0.999]. A NaN input yields 0.001.

    Raises:
        TypeError, ValueError: If ``x`` cannot be converted by ``float()``.
    """
    import math

    scaled = (float(x) + 1.0) / 2.0
    # Comparisons with NaN are always false, so min()/max() would let NaN
    # come out as 0.999; an undefined reward must not earn the best score.
    if math.isnan(scaled):
        return 0.001
    return max(0.001, min(0.999, scaled))
57
 
58
-
59
class LLMAgent:
    """Agent that asks a chat-completion model for each signal action."""

    def __init__(self):
        """Build the OpenAI client from environment variables.

        Raises:
            KeyError: If API_BASE_URL, API_KEY or MODEL_NAME is not set.
        """
        self.client = OpenAI(
            base_url=os.environ["API_BASE_URL"],
            api_key=os.environ["API_KEY"],
        )
        self.model_name = os.environ["MODEL_NAME"]

    def reset(self):
        """Do nothing; present so callers can reset the agent per episode."""
        pass

    def select_action(self, state: dict) -> int:
        """Ask the model for an action and return it as 0 or 1.

        Args:
            state: Current environment state; interpolated into the prompt
                via its string representation.

        Returns:
            1 only when the model's reply parses to the integer 1; otherwise
            0 (reply missing, not an integer, or an integer other than 0/1).
        """
        prompt = (
            "You are a traffic control agent for a 4-way intersection.\n"
            "\n"
            "Current state:\n"
            f"{state}\n"
            "\n"
            "Available actions:\n"
            "0 = keep current signal phase\n"
            "1 = switch signal phase\n"
            "\n"
            "Reply with only one number: 0 or 1"
        )

        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=[
                {
                    "role": "system",
                    "content": "You are a traffic signal controller. Reply with only 0 or 1.",
                },
                {
                    "role": "user",
                    "content": prompt,
                },
            ],
            temperature=0,
        )

        content = response.choices[0].message.content
        # The message content can be None; treat that like any other
        # unusable reply instead of crashing on .strip().
        if content is None:
            return 0

        try:
            action = int(content.strip())
        except ValueError:
            return 0

        return action if action in (0, 1) else 0
109
-
110
-
111
def run_task(task_name: str, config: dict) -> float:
    """Run one episode of a task with an LLM agent and report its score.

    Prints a "[START]" line, one "[STEP]" line per environment step and a
    final "[END]" line, each flushed immediately.

    Args:
        task_name: Label included in the "[STEP]" and "[END]" lines.
        config: Configuration passed to ``TrafficEnv``.

    Returns:
        ``strict_score`` of the mean per-step reward over the episode.
    """
    env = TrafficEnv(config)
    agent = LLMAgent()

    state = env.reset()
    agent.reset()

    print("[START]", flush=True)

    done = False
    step_idx = 0
    total_reward = 0.0

    while not done:
        action = agent.select_action(state)
        # The fourth element returned by step() is not used here.
        state, reward, done, _info = env.step(action)

        step_score = strict_score(reward)
        print(
            f"[STEP] task={task_name}, step={step_idx}, action={action}, score={step_score:.3f}, done={done}",
            flush=True,
        )

        total_reward += reward
        step_idx += 1

    # max(1, ...) guards the division should the loop run zero times.
    avg_reward = total_reward / max(1, step_idx)
    final_score = strict_score(avg_reward)

    print(f"[END] task={task_name}, score={final_score:.3f}", flush=True)

    return final_score
143
 
 
 
144
 
145
if __name__ == "__main__":
    # Run the three tasks in order, one episode each.
    tasks = [
        ("easy", EASY_CONFIG),
        ("medium", MEDIUM_CONFIG),
        ("hard", HARD_CONFIG),
    ]

    for task_name, config in tasks:
        run_task(task_name, config)
 
1
  import os
2
  from openai import OpenAI
 
3
  from env import TrafficEnv
4
 
5
# Minimal config (no tasks.py dependency); passed as-is to TrafficEnv.
# NOTE(review): the meaning of each key is defined by TrafficEnv, which is
# not visible here.
CONFIG = {
    "max_steps": 20,
    "max_queue": 20,
    "arrival_rate": (0, 2),
    "discharge_rate": (3, 5),
    "emergency_prob": 0.02,
    "switch_penalty": 0.2,
    "starvation_threshold": 10,
    "burst_prob": 0.1,
    "burst_multiplier": 1.2,
}
17
+
18
def strict_score(x):
    """Map a value from [-1, 1] onto (0, 1), clamped to [0.001, 0.999].

    Args:
        x: Raw value; anything accepted by ``float()``.

    Returns:
        A float in [0.001, 0.999]. A NaN input yields 0.001.

    Raises:
        TypeError, ValueError: If ``x`` cannot be converted by ``float()``.
    """
    import math

    scaled = (float(x) + 1.0) / 2.0
    # Comparisons with NaN are always false, so min()/max() would let NaN
    # come out as 0.999; an undefined reward must not earn the best score.
    if math.isnan(scaled):
        return 0.001
    return max(0.001, min(0.999, scaled))
22
 
23
# LLM client (IMPORTANT): endpoint, key and model name all come from the
# environment; a missing variable raises KeyError at startup.
client = OpenAI(
    base_url=os.environ["API_BASE_URL"],
    api_key=os.environ["API_KEY"],
)

MODEL_NAME = os.environ["MODEL_NAME"]

env = TrafficEnv(CONFIG)

print("[START]", flush=True)

state = env.reset()
done = False
step_count = 0
total_reward = 0.0

# Run a single episode: ask the model for an action, apply it, log the step.
while not done:
    prompt = f"""
State: {state}
Choose action:
0 = keep
1 = switch
Reply only 0 or 1
"""

    response = client.chat.completions.create(
        model=MODEL_NAME,
        messages=[
            {"role": "system", "content": "Reply only 0 or 1."},
            {"role": "user", "content": prompt},
        ],
        temperature=0,
    )

    # Any unusable reply (no choices, missing content, non-integer text, or
    # an integer other than 0/1) falls back to action 0. Only the errors that
    # reading and parsing the reply can raise are caught, so Ctrl-C and
    # SystemExit still stop the run.
    try:
        action = int(response.choices[0].message.content.strip())
    except (AttributeError, IndexError, TypeError, ValueError):
        action = 0
    if action not in (0, 1):
        action = 0

    state, reward, done, info = env.step(action)

    score = strict_score(reward)

    print(f"[STEP] step={step_count}, score={score:.3f}, done={done}", flush=True)

    total_reward += reward
    step_count += 1

# Final score (IMPORTANT): mean per-step reward, mapped by strict_score.
# max(1, ...) guards the division should the loop run zero times.
final_score = total_reward / max(1, step_count)
final_score = strict_score(final_score)

print(f"[END] score={final_score:.3f}", flush=True)