"""Run the traffic-signal environment with an optional LLM-driven controller.

For each task the script steps a ``TrafficEnv`` until it reports ``done``,
printing ``[START]`` / ``[STEP]`` / ``[END]`` lines to stdout. When
``API_BASE_URL`` and ``API_KEY`` are set, actions come from a chat model;
otherwise action ``0`` is used on every step.
"""

import os
from typing import Optional, Tuple

from openai import OpenAI

from env import TrafficEnv

# Passed unchanged to TrafficEnv; the meaning of each key is defined there.
CONFIG = {
    "max_steps": 20,
    "max_queue": 20,
    "arrival_rate": (0, 2),
    "discharge_rate": (3, 5),
    "emergency_prob": 0.02,
    "switch_penalty": 0.2,
    "starvation_threshold": 10,
    "burst_prob": 0.1,
    "burst_multiplier": 1.2,
}

# Actions understood by choose_action's prompt: 0 = keep phase, 1 = switch.
_VALID_ACTIONS = (0, 1)
_DEFAULT_ACTION = 0


def strict_score(x):
    """Map a reward onto a score strictly inside the open interval (0, 1).

    The value is transformed with ``(x + 1) / 2`` (so -1 -> 0 and 1 -> 1) and
    then clamped to ``[0.001, 0.999]``.

    Args:
        x: Reward value; anything accepted by ``float()``.
            Presumably rewards lie in [-1, 1] -- confirm against TrafficEnv.

    Returns:
        The clamped score as a float.
    """
    scaled = (float(x) + 1.0) / 2.0
    return max(0.001, min(0.999, scaled))


def build_client() -> Tuple[Optional[OpenAI], str, bool]:
    """Create the OpenAI client from environment variables.

    Reads ``API_BASE_URL``, ``API_KEY`` and ``MODEL_NAME`` (default
    ``"gpt-4o-mini"``).

    Returns:
        ``(client, model_name, use_llm)``. When either ``API_BASE_URL`` or
        ``API_KEY`` is missing or empty, ``client`` is ``None`` and
        ``use_llm`` is ``False``.
    """
    api_base_url = os.environ.get("API_BASE_URL")
    api_key = os.environ.get("API_KEY")
    model_name = os.environ.get("MODEL_NAME", "gpt-4o-mini")

    if not (api_base_url and api_key):
        return None, model_name, False

    client = OpenAI(base_url=api_base_url, api_key=api_key)
    return client, model_name, True


def _parse_action(content) -> int:
    """Turn the model's raw reply into a valid action, defaulting to 0.

    ``content`` may be ``None`` (the SDK types it as optional), empty, or
    arbitrary text; every unusable reply yields the default action.
    """
    if not content:
        return _DEFAULT_ACTION
    try:
        action = int(content.strip())
    except ValueError:
        # Reply was not a bare integer (e.g. "Action: 1").
        return _DEFAULT_ACTION
    return action if action in _VALID_ACTIONS else _DEFAULT_ACTION


def choose_action(client, model_name, state):
    """Ask the chat model whether to keep (0) or switch (1) the signal phase.

    Args:
        client: An ``OpenAI`` client instance.
        model_name: Model identifier passed to the chat completions API.
        state: Current environment state; interpolated into the prompt via
            its string representation.

    Returns:
        ``0`` or ``1``. Any reply that is missing, non-numeric, or outside
        ``{0, 1}`` results in ``0``.

    Raises:
        Errors raised by the API call itself are not caught here.
    """
    prompt = (
        "You are controlling a traffic signal at a 4-way intersection.\n"
        "\n"
        "Current state:\n"
        f"{state}\n"
        "\n"
        "Available actions:\n"
        "0 = keep current signal phase\n"
        "1 = switch signal phase\n"
        "\n"
        "Reply with only one number: 0 or 1"
    )

    response = client.chat.completions.create(
        model=model_name,
        messages=[
            {"role": "system", "content": "Reply with only 0 or 1."},
            {"role": "user", "content": prompt},
        ],
        temperature=0,
    )

    return _parse_action(response.choices[0].message.content)


def run_task(task_name, config, client, model_name, use_llm):
    """Run one episode of ``TrafficEnv`` and print per-step and final scores.

    Args:
        task_name: Label included in the ``[STEP]`` and ``[END]`` log lines.
        config: Configuration passed to ``TrafficEnv``.
        client: OpenAI client, or ``None`` when ``use_llm`` is false.
        model_name: Model identifier forwarded to ``choose_action``.
        use_llm: When false, action ``0`` is taken on every step and the
            client is never used.
    """
    env = TrafficEnv(config)
    state = env.reset()

    print("[START]", flush=True)

    done = False
    step_idx = 0
    total_reward = 0.0

    while not done:
        if use_llm:
            action = choose_action(client, model_name, state)
        else:
            action = _DEFAULT_ACTION

        state, reward, done, _info = env.step(action)
        step_score = strict_score(reward)

        # flush=True so each step is visible immediately to whatever is
        # consuming stdout.
        print(
            f"[STEP] task={task_name}, step={step_idx}, action={action}, "
            f"score={step_score:.3f}, done={done}",
            flush=True,
        )

        total_reward += reward
        step_idx += 1

    # max(1, ...) guards the division if the loop body never ran.
    avg_reward = total_reward / max(1, step_idx)
    final_score = strict_score(avg_reward)

    print(f"[END] task={task_name}, score={final_score:.3f}", flush=True)


def main():
    """Build the client and run every task in sequence."""
    client, model_name, use_llm = build_client()

    # NOTE(review): all three tasks currently share the same CONFIG object,
    # so "easy", "medium" and "hard" differ only by label -- confirm intent.
    tasks = [
        ("easy", CONFIG),
        ("medium", CONFIG),
        ("hard", CONFIG),
    ]

    for task_name, config in tasks:
        run_task(task_name, config, client, model_name, use_llm)


if __name__ == "__main__":
    main()