# open_env_traffic_system / inference.py
# (Hugging Face Hub page header captured during extraction:
#  uploaded by arrow072 — "Update inference.py", commit 9c16f41, verified)
import os
from openai import OpenAI
from env import TrafficEnv
# Environment configuration shared by every task tier (easy/medium/hard).
# Key semantics are defined by TrafficEnv in env.py (not visible here) —
# the comments below are best-effort readings of the key names; verify
# against env.py before relying on them.
CONFIG = {
    "max_steps": 20,             # episode length cap (steps per run_task episode)
    "max_queue": 20,             # presumably a per-approach queue capacity — TODO confirm
    "arrival_rate": (0, 2),      # presumably a (min, max) range of vehicle arrivals per step
    "discharge_rate": (3, 5),    # presumably a (min, max) range of vehicles cleared per step
    "emergency_prob": 0.02,      # presumably per-step probability of an emergency event
    "switch_penalty": 0.2,       # presumably reward penalty for action 1 (phase switch)
    "starvation_threshold": 10,  # presumably steps before a waiting approach counts as starved
    "burst_prob": 0.1,           # presumably probability of a traffic burst
    "burst_multiplier": 1.2,     # presumably arrival multiplier during a burst
}
def strict_score(x):
    """Map a raw reward (nominally in [-1, 1]) into the open interval (0, 1).

    The input is linearly rescaled from [-1, 1] to [0, 1] and then clamped
    to [0.001, 0.999] so downstream consumers never see an exact 0 or 1.

    Args:
        x: A numeric reward; values outside [-1, 1] are clamped after scaling.

    Returns:
        A float in [0.001, 0.999].
    """
    scaled = (float(x) + 1.0) * 0.5
    # Clamp with guard clauses instead of nested max(min(...)).
    if scaled < 0.001:
        return 0.001
    if scaled > 0.999:
        return 0.999
    return scaled
def build_client():
    """Construct an OpenAI client from environment variables.

    Reads API_BASE_URL, API_KEY, and MODEL_NAME (default "gpt-4o-mini").

    Returns:
        A (client, model_name, use_llm) triple. When either API_BASE_URL or
        API_KEY is missing, client is None and use_llm is False.
    """
    base_url = os.environ.get("API_BASE_URL")
    key = os.environ.get("API_KEY")
    model = os.environ.get("MODEL_NAME", "gpt-4o-mini")
    # Guard clause: without full credentials we fall back to non-LLM mode.
    if not (base_url and key):
        return None, model, False
    return OpenAI(base_url=base_url, api_key=key), model, True
def choose_action(client, model_name, state):
    """Ask the chat model to pick a traffic-signal action for `state`.

    Args:
        client: An OpenAI-compatible client with chat.completions.create.
        model_name: Model identifier to query.
        state: Environment state, interpolated verbatim into the prompt.

    Returns:
        0 (keep phase) or 1 (switch phase). Any reply that is not exactly
        one of those two integers falls back to 0.
    """
    prompt = f"""
You are controlling a traffic signal at a 4-way intersection.
Current state:
{state}
Available actions:
0 = keep current signal phase
1 = switch signal phase
Reply with only one number: 0 or 1
""".strip()
    reply = client.chat.completions.create(
        model=model_name,
        temperature=0,
        messages=[
            {"role": "system", "content": "Reply with only 0 or 1."},
            {"role": "user", "content": prompt},
        ],
    )
    text = reply.choices[0].message.content.strip()
    try:
        parsed = int(text)
    except Exception:
        # Non-numeric reply: default to the safe "keep phase" action.
        return 0
    return parsed if parsed in (0, 1) else 0
def run_task(task_name, config, client, model_name, use_llm):
    """Run one TrafficEnv episode, logging per-step and final scores.

    Args:
        task_name: Label included in every log line.
        config: Dict passed straight to TrafficEnv.
        client: LLM client (may be None when use_llm is False).
        model_name: Model identifier forwarded to choose_action.
        use_llm: When False, every step takes action 0 (keep phase).
    """
    env = TrafficEnv(config)
    state = env.reset()
    print("[START]", flush=True)
    reward_sum = 0.0
    step_idx = 0
    done = False
    while not done:
        if use_llm:
            action = choose_action(client, model_name, state)
        else:
            action = 0
        state, reward, done, info = env.step(action)
        step_score = strict_score(reward)
        print(
            f"[STEP] task={task_name}, step={step_idx}, action={action}, score={step_score:.3f}, done={done}",
            flush=True,
        )
        reward_sum += reward
        step_idx += 1
    # max(1, ...) guards against division by zero on a zero-step episode.
    final_score = strict_score(reward_sum / max(1, step_idx))
    print(f"[END] task={task_name}, score={final_score:.3f}", flush=True)
if __name__ == "__main__":
    # Entry point: build the (optional) LLM client once, then run every
    # difficulty tier. All tiers currently share the same CONFIG dict.
    client, model_name, use_llm = build_client()
    tasks = [
        ("easy", CONFIG),
        ("medium", CONFIG),
        ("hard", CONFIG),
    ]
    for task_name, config in tasks:
        # BUG FIX: the original call was missing its closing parenthesis,
        # which was a SyntaxError making the whole module unimportable.
        run_task(task_name, config, client, model_name, use_llm)