Spaces:
Sleeping
Sleeping
File size: 2,553 Bytes
0245295 e22b3e1 aa49a5f e22b3e1 785b66d aa49a5f 785b66d 8dac098 785b66d aa49a5f 785b66d aa49a5f 8dac098 785b66d aa49a5f 785b66d aa49a5f 785b66d aa49a5f 785b66d aa49a5f 8dac098 aa49a5f 8dac098 aa49a5f 785b66d aa49a5f 785b66d e22b3e1 785b66d aa49a5f e22b3e1 785b66d 8dac098 785b66d e22b3e1 785b66d 0245295 785b66d 5e3bb19 785b66d 9c16f41 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 | import os
from openai import OpenAI
from env import TrafficEnv
# Settings handed to TrafficEnv for every task. The meaning of each key is
# defined by env.TrafficEnv, which is not visible from this file.
CONFIG = dict(
    max_steps=20,
    max_queue=20,
    arrival_rate=(0, 2),
    discharge_rate=(3, 5),
    emergency_prob=0.02,
    switch_penalty=0.2,
    starvation_threshold=10,
    burst_prob=0.1,
    burst_multiplier=1.2,
)
def strict_score(x):
    """Rescale a value and clamp it to the closed range [0.001, 0.999].

    The input is converted to float, shifted by one and halved, so -1 maps
    to 0 and +1 maps to 1 before clamping.
    """
    upper = 0.999
    lower = 0.001
    rescaled = (float(x) + 1.0) / 2.0
    # Written as explicit comparisons that mirror min()/max() exactly,
    # including how they treat values that compare false (e.g. NaN).
    capped = rescaled if rescaled < upper else upper
    return capped if capped > lower else lower
def build_client():
    """Create an OpenAI client from environment variables, if configured.

    Reads API_BASE_URL, API_KEY and MODEL_NAME (default "gpt-4o-mini").

    Returns:
        A (client, model_name, use_llm) tuple. When both API_BASE_URL and
        API_KEY are non-empty, client is an OpenAI instance and use_llm is
        True; otherwise client is None and use_llm is False.
    """
    environ = os.environ
    model_name = environ.get("MODEL_NAME", "gpt-4o-mini")
    base_url = environ.get("API_BASE_URL")
    key = environ.get("API_KEY")

    # Guard clause: without both settings there is nothing to connect to.
    if not (base_url and key):
        return None, model_name, False

    return OpenAI(base_url=base_url, api_key=key), model_name, True
def choose_action(client, model_name, state):
    """Ask the chat model which signal action to take for the given state.

    Args:
        client: Object exposing ``chat.completions.create`` (an OpenAI client).
        model_name: Model identifier passed to the completion request.
        state: Current environment state; interpolated into the prompt as-is.

    Returns:
        1 if the model's reply parses to the integer 1, otherwise 0. The
        fallback 0 ("keep current signal phase") is used for any reply that
        is missing, empty, non-numeric, or outside {0, 1}.
    """
    prompt = (
        "You are controlling a traffic signal at a 4-way intersection.\n"
        "Current state:\n"
        f"{state}\n"
        "Available actions:\n"
        "0 = keep current signal phase\n"
        "1 = switch signal phase\n"
        "Reply with only one number: 0 or 1"
    )
    response = client.chat.completions.create(
        model=model_name,
        messages=[
            {"role": "system", "content": "Reply with only 0 or 1."},
            {"role": "user", "content": prompt},
        ],
        temperature=0,
    )

    # The reply text can be absent (no choices, or content=None on a
    # refusal); treat that like any other unparseable reply instead of
    # crashing on .strip().
    choices = response.choices
    content = choices[0].message.content if choices else None
    if content is None:
        return 0

    try:
        action = int(content.strip())
    except ValueError:
        return 0
    return action if action in (0, 1) else 0
def run_task(task_name, config, client, model_name, use_llm):
    """Run one TrafficEnv episode and print [START]/[STEP]/[END] log lines.

    Args:
        task_name: Label included in every [STEP] and [END] line.
        config: Settings passed to the TrafficEnv constructor.
        client: Chat client forwarded to choose_action when use_llm is true.
        model_name: Model identifier forwarded to choose_action.
        use_llm: When false, action 0 is taken on every step.

    Each step's reward and the episode's mean reward are passed through
    strict_score before being printed. Nothing is returned.
    """
    environment = TrafficEnv(config)
    state = environment.reset()
    print("[START]", flush=True)

    step_idx = 0
    reward_sum = 0.0
    # The environment is always stepped at least once; stop after the step
    # that reports done.
    while True:
        if use_llm:
            action = choose_action(client, model_name, state)
        else:
            action = 0
        state, reward, done, _info = environment.step(action)

        step_score = strict_score(reward)
        print(
            f"[STEP] task={task_name}, step={step_idx}, action={action}, score={step_score:.3f}, done={done}",
            flush=True,
        )

        reward_sum += reward
        step_idx += 1
        if done:
            break

    final_score = strict_score(reward_sum / max(1, step_idx))
    print(f"[END] task={task_name}, score={final_score:.3f}", flush=True)
if __name__ == "__main__":
    # Script entry point: build the (optional) LLM client once, then run
    # each named task in order.
    client, model_name, use_llm = build_client()
    # NOTE(review): all three tasks currently share the same CONFIG, so
    # "easy", "medium" and "hard" differ only by label — confirm intended.
    tasks = [
        ("easy", CONFIG),
        ("medium", CONFIG),
        ("hard", CONFIG),
    ]
    for task_name, config in tasks:
        run_task(task_name, config, client, model_name, use_llm)