File size: 2,553 Bytes
0245295
e22b3e1
 
 
aa49a5f
 
 
 
 
 
 
 
 
 
 
 
 
e22b3e1
 
 
785b66d
 
 
 
aa49a5f
785b66d
8dac098
785b66d
aa49a5f
785b66d
aa49a5f
8dac098
785b66d
 
aa49a5f
785b66d
 
aa49a5f
785b66d
 
 
 
 
 
aa49a5f
 
785b66d
aa49a5f
8dac098
 
aa49a5f
8dac098
aa49a5f
 
785b66d
 
aa49a5f
785b66d
 
e22b3e1
785b66d
aa49a5f
e22b3e1
785b66d
 
 
 
 
 
 
 
 
 
 
 
 
8dac098
785b66d
 
 
 
 
 
 
 
 
 
 
 
 
 
e22b3e1
785b66d
 
0245295
785b66d
 
 
 
 
5e3bb19
785b66d
9c16f41
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import os
from openai import OpenAI
from env import TrafficEnv

# Environment settings handed to TrafficEnv (through run_task) for every task.
# The keys are interpreted by env.TrafficEnv, which is not visible in this
# file; the per-key notes below are inferred from the key names only.
# TODO(review): confirm each note against env.TrafficEnv.
CONFIG = {
    "max_steps": 20,  # presumably the episode length cap
    "max_queue": 20,  # presumably the per-lane queue capacity
    "arrival_rate": (0, 2),  # looks like a (low, high) range — confirm
    "discharge_rate": (3, 5),  # looks like a (low, high) range — confirm
    "emergency_prob": 0.02,  # presumably a per-step probability
    "switch_penalty": 0.2,  # presumably a reward penalty for action 1
    "starvation_threshold": 10,  # units not visible here — confirm
    "burst_prob": 0.1,  # presumably a per-step probability
    "burst_multiplier": 1.2,  # presumably scales arrivals during a burst
}

def strict_score(x):
    """Rescale *x* with ``(x + 1) / 2`` and clamp the result to [0.001, 0.999].

    Args:
        x: Any value accepted by ``float()``.

    Returns:
        The rescaled value as a float, never below 0.001 nor above 0.999.
    """
    lower, upper = 0.001, 0.999
    rescaled = (float(x) + 1.0) / 2.0

    if rescaled < lower:
        return lower
    # Anything that is not strictly below the ceiling collapses onto it.
    if rescaled < upper:
        return rescaled
    return upper

def build_client():
    """Create an OpenAI client from environment variables when possible.

    Reads ``API_BASE_URL``, ``API_KEY`` and ``MODEL_NAME`` (default
    ``"gpt-4o-mini"``) from the process environment.

    Returns:
        A ``(client, model_name, use_llm)`` tuple. ``client`` is an ``OpenAI``
        instance and ``use_llm`` is ``True`` only when both the base URL and
        the key are set to non-empty values; otherwise ``client`` is ``None``
        and ``use_llm`` is ``False``.
    """
    environ = os.environ
    base_url = environ.get("API_BASE_URL")
    key = environ.get("API_KEY")
    model = environ.get("MODEL_NAME", "gpt-4o-mini")

    # Without both credentials there is nothing to connect to.
    if not (base_url and key):
        return None, model, False

    return OpenAI(base_url=base_url, api_key=key), model, True

def choose_action(client, model_name, state):
    """Ask the chat model whether to keep (0) or switch (1) the signal phase.

    Args:
        client: Object exposing ``chat.completions.create`` (an ``OpenAI``
            client in this script).
        model_name: Model identifier forwarded to the API call.
        state: Current environment state; interpolated into the prompt as-is.

    Returns:
        ``0`` or ``1``. A reply that is missing, is not an integer, or lies
        outside ``{0, 1}`` falls back to ``0`` (keep the current phase).

    Raises:
        Whatever ``client.chat.completions.create`` raises; API errors are
        not handled here.
    """
    prompt = f"""
You are controlling a traffic signal at a 4-way intersection.

Current state:
{state}

Available actions:
0 = keep current signal phase
1 = switch signal phase

Reply with only one number: 0 or 1
""".strip()

    response = client.chat.completions.create(
        model=model_name,
        messages=[
            {"role": "system", "content": "Reply with only 0 or 1."},
            {"role": "user", "content": prompt},
        ],
        temperature=0,
    )

    content = response.choices[0].message.content
    if content is None:
        # The message may carry no text at all; calling .strip() on None
        # would raise AttributeError, so treat it like any unusable reply.
        return 0

    try:
        action = int(content.strip())
    except ValueError:
        # Non-numeric reply (e.g. extra words or punctuation).
        return 0

    return action if action in (0, 1) else 0

def run_task(task_name, config, client, model_name, use_llm):
    """Run one episode of ``TrafficEnv`` and print per-step and final scores.

    Args:
        task_name: Label included in every ``[STEP]`` / ``[END]`` log line.
        config: Settings passed straight to ``TrafficEnv``.
        client: Chat client forwarded to ``choose_action`` when ``use_llm``.
        model_name: Model identifier forwarded to ``choose_action``.
        use_llm: When false, action ``0`` is taken at every step.

    The episode always performs at least one step and ends once the
    environment reports ``done``. The final score is ``strict_score`` of the
    mean reward over all steps taken.
    """
    environment = TrafficEnv(config)
    observation = environment.reset()

    print("[START]", flush=True)

    steps_taken = 0
    reward_sum = 0.0

    while True:
        if use_llm:
            action = choose_action(client, model_name, observation)
        else:
            action = 0

        observation, reward, done, _info = environment.step(action)

        step_score = strict_score(reward)
        step_line = (
            f"[STEP] task={task_name}, step={steps_taken}, action={action}, "
            f"score={step_score:.3f}, done={done}"
        )
        print(step_line, flush=True)

        reward_sum += reward
        steps_taken += 1

        if done:
            break

    # max(1, ...) keeps the division safe even if no step were counted.
    mean_reward = reward_sum / max(1, steps_taken)
    print(
        f"[END] task={task_name}, score={strict_score(mean_reward):.3f}",
        flush=True,
    )

if __name__ == "__main__":
    # Build the client once and share it across every task run.
    client, model_name, use_llm = build_client()

    # NOTE(review): all three task labels currently share the same CONFIG.
    tasks = [
        ("easy", CONFIG),
        ("medium", CONFIG),
        ("hard", CONFIG),
    ]

    for task_name, config in tasks:
        run_task(task_name, config, client, model_name, use_llm)