arrow072 commited on
Commit
c0b7f5e
·
verified ·
1 Parent(s): 5121e53

Upload 12 files

Browse files
Files changed (12) hide show
  1. Dockerfile +9 -0
  2. README.md +155 -5
  3. baseline_agent.py +154 -0
  4. inference.py +328 -0
  5. openenv.yaml +208 -0
  6. pyproject.toml +21 -0
  7. requirements.txt +6 -0
  8. server/app.py +14 -0
  9. tasks.py +161 -0
  10. test_env.py +331 -0
  11. test_inference.py +19 -0
  12. uv.lock +0 -0
Dockerfile ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10
2
+
3
+ WORKDIR /app
4
+
5
+ COPY . .
6
+
7
+ RUN pip install fastapi uvicorn numpy pydantic openai "openenv-core>=0.2.0"
8
+
9
+ CMD ["uvicorn","inference:app","--host","0.0.0.0","--port","7860"]
README.md CHANGED
@@ -1,10 +1,160 @@
1
  ---
2
- title: 'Meta Env '
3
- emoji: 🦀
4
- colorFrom: pink
5
- colorTo: gray
6
  sdk: docker
 
7
  pinned: false
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Traffic Signal Optimization — OpenEnv Elite
3
+ emoji: 🚦
4
+ colorFrom: blue
5
+ colorTo: green
6
  sdk: docker
7
+ app_port: 7860
8
  pinned: false
9
  ---
10
 
11
+ # 🚥 Traffic Signal Optimization OpenEnv Elite
12
+
13
+ > **Meta × PyTorch OpenEnv Hackathon Submission**
14
+ >
15
+ > A world-class Reinforcement Learning environment for urban traffic control, featuring stochastic multi-lane dynamics, emergency vehicle prioritization, and sophisticated fairness-driven rewards.
16
+
17
+ ---
18
+
19
+ ## 🏗️ Problem Statement
20
+
21
+ Fixed-cycle traffic signals are a relic of the past. In modern urban environments, they create **needless congestion**, increase **CO2 emissions**, and — most critically — cause **life-threatening delays** for emergency vehicles.
22
+
23
+ This project provides a high-fidelity 4-way intersection simulation designed for OpenEnv. It challenges RL agents to move beyond simple throughput and master the art of **dynamic balancing**: serving high-demand lanes while maintaining fairness for low-traffic directions and clearing "Golden Windows" for emergency responders.
24
+
25
+ ---
26
+
27
+ ## 🚀 Quick Start
28
+
29
+ ```bash
30
+ # Run the complete suite: Simulation + Sanity Checks + Comparison
31
+ python test_env.py
32
+
33
+ # Run a specific high-intensity scenario
34
+ python test_env.py hard
35
+ ```
36
+
37
+ ```python
38
+ from env import TrafficEnv
39
+ from tasks import get_config
40
+ from baseline_agent import RuleBasedAgent
41
+
42
+ # 1. Load a structured difficulty profile
43
+ config = get_config("medium")
44
+ env = TrafficEnv(config)
45
+
46
+ # 2. Initialize our sophisticated Rule-Based Controller
47
+ agent = RuleBasedAgent()
48
+
49
+ state = env.reset()
50
+ done = False
51
+
52
+ while not done:
53
+ action = agent.select_action(state)
54
+ state, reward, done, info = env.step(action)
55
+
56
+ print(f"Total Cleared: {info['total_cleared']}")
57
+ print(f"Fairness Index: {info['fairness_score']:.2f}")
58
+ ```
59
+
60
+ ---
61
+
62
+ ## 🧠 Environment Design Philosophy
63
+
64
+ ### State Space
65
+ The environment exposes a **14-dimensional** continuous observation vector, providing the agent with full situational awareness:
66
+ - **Queues (4)**: Exact vehicle count per lane [N, S, E, W].
67
+ - **Wait Pressure (4)**: Cumulative "impatience" score per lane.
68
+ - **Emergency Flags (4)**: Binary detection of EVs per lane.
69
+ - **Signal State (2)**: Current phase [0=NS, 1=EW] and step count.
70
+
71
+ ### Action Space
72
+ - `0`: **Maintain** — keep the current green phase.
73
+ - `1`: **Switch** — transition the signal (includes yellow-phase discharge friction).
74
+
75
+ ---
76
+
77
+ ## 💎 Reward Engineering (The "Judge's Choice")
78
+
79
+ Our reward function is the core of this submission. It isn't just a count; it's a **multi-objective ethical framework** clipped to `[-1, 1]`:
80
+
81
+ | Component | Logic | Purpose |
82
+ | :--- | :--- | :--- |
83
+ | **Throughput (+)** | `+0.20 * cars_cleared` | Incentivizes active vehicle flow. |
84
+ | **Density (-)** | `-0.40 * total_congestion` | Penalizes letting the intersection fill up. |
85
+ | **Bottleneck (-)** | `-0.15 * max_queue` | Discourages extreme build-up in any single lane. |
86
+ | **Stability (-)** | `-switch_penalty` | Prevents "flickering" and promotes signal stability. |
87
+ | **Fairness (+/-)** | `+0.10` bonus / `-penalty` | Rewards balanced service; penalizes starvation. |
88
+ | **Emergency (🚨)** | `Golden Window` Bonus | Massive reward for clearing EVs within target steps. |
89
+ | **EV Delay (-)** | `Exponential Penalty` | Punishes agents for delaying life-saving vehicles. |
90
+
91
+ ---
92
+
93
+ ## 📊 Evaluation Metrics
94
+
95
+ We track **8 key performance indicators** per episode to ensure a winning submission can be quantified:
96
+
97
+ 1. **Total Cleared**: Raw efficiency metric.
98
+ 2. **Avg Waiting Time**: The "commuter frustration" index.
99
+ 3. **Max Queue Length**: Gauges system robustness against bottlenecks.
100
+ 4. **Signal Switch Count**: Measures policy stability.
101
+ 5. **Congestion Score**: Final system state snapshot.
102
+ 6. **Avg EV Clear Time**: Critical safety metric (lower is better).
103
+ 7. **Fairness Score**: [0, 1] index — how equally did we serve all lanes?
104
+ 8. **Total EV Penalty**: Measures total failure to prioritize safety.
105
+
106
+ ---
107
+
108
+ ## ⚡ Task Difficulty Levels
109
+
110
+ | Parameter | Easy | Medium | Hard |
111
+ | :--- | :--- | :--- | :--- |
112
+ | **Arrival Rate** | 0–1 | 1–3 | 2–5 |
113
+ | **Discharge Rate** | 4–5 | 3–5 | 2–4 |
114
+ | **Burst Frequency** | 0% | 10% | 20% |
115
+ | **Emergency Prob** | 1% | 5% | 15% |
116
+ | **EV Golden Window** | 8 steps | 5 steps | 3 steps |
117
+ | **Fairness Limit** | 20 steps | 15 steps | 10 steps |
118
+
119
+ ---
120
+
121
+ ## 🚑 Emergency & Fairness Logic
122
+
123
+ ### The "Golden Window"
124
+ When an Emergency Vehicle (EV) appears, the agent is granted a bonus if it switches and clears the lane within the **Golden Window** (defined per difficulty). Failing to do so triggers an **exponential delay penalty**, simulating the real-world cost of stopping an ambulance or fire truck.
125
+
126
+ ### Fairness Guard
127
+ To prevent "Starvation" (where the agent ignores a low-traffic lane to optimize throughput on a high-traffic lane), a **Fairness Score** is calculated. If a lane remains red beyond the **Starvation Limit**, the agent suffers a heavy penalty. This forces the agent to learn the complex trade-off between total throughput and social fairness.
128
+
129
+ ---
130
+
131
+ ## 🚶 Step Walkthrough
132
+
133
+ ```text
134
+ Step 12: 🚨 Ambulance detected in East lane (currently RED).
135
+ - EW Queue: 4, EV Timer: 0
136
+ - Agent receives p_emergency penalty.
137
+
138
+ Step 13: Agent Action: 1 (SWITCH to EW).
139
+ - Switch penalty applied (-0.20).
140
+ - NS lanes stop; EW lanes turn GREEN.
141
+
142
+ Step 14: EV Cleared!
143
+ - EV Clear Time: 2 steps.
144
+ - Agent receives r_ev_bonus (+0.25) for "Golden Window" clearance.
145
+ - Total cleared (+0.60 reward).
146
+ ```
147
+
148
+ ---
149
+
150
+ ## 🔮 Future Improvements
151
+
152
+ - **Multi-Intersection Coordination**: Extending to a grid of agents using MARL.
153
+ - **Pedestrian Logic**: Adding crosswalks and pedestrian priority.
154
+ - **V2X Communication**: Providing agents with ahead-of-time traffic predictions.
155
+
156
+ ---
157
+
158
+ ## 📜 License
159
+
160
+ MIT © 2026 Meta x PyTorch OpenEnv Hackathon
baseline_agent.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ baseline_agent.py — Rule-Based Traffic Signal Controller
3
+ =========================================================
4
+
5
+ A deterministic agent that makes signal decisions using handcrafted
6
+ heuristics. Acts as the reproducible baseline for comparison against
7
+ trained RL policies.
8
+
9
+ Decision hierarchy (highest priority first):
10
+ 1. Emergency vehicle preemption — switch if an emergency vehicle is
11
+ stuck at a red light and minimum green time has been served.
12
+ 2. Minimum green time — never switch before a floor number of steps
13
+ to prevent rapid oscillation.
14
+ 3. Queue-imbalance trigger — switch when the queued-vehicle disparity
15
+ between NS and EW exceeds a configurable threshold.
16
+ 4. Maximum green cap — force a switch if one direction has been green
17
+ for too long (fairness guard).
18
+ 5. Default — keep current phase.
19
+
20
+ Usage
21
+ -----
22
+ from baseline_agent import RuleBasedAgent
23
+ agent = RuleBasedAgent(min_green_time=5, imbalance_threshold=5)
24
+ action = agent.select_action(state) # 0 or 1
25
+ """
26
+
27
+ from __future__ import annotations
28
+ from typing import Any, Dict
29
+
30
+
31
+ class RuleBasedAgent:
32
+ """
33
+ Rule-based traffic signal controller.
34
+
35
+ Parameters
36
+ ----------
37
+ min_green_time : int
38
+ Minimum number of steps to hold a phase before switching.
39
+ Prevents oscillatory behaviour.
40
+ imbalance_threshold : int
41
+ Minimum queue difference (NS vs EW) required to trigger a switch.
42
+ max_green_time : int
43
+ Maximum consecutive steps before forcing a phase change.
44
+ Acts as a starvation safety net.
45
+ emergency_min_green : int
46
+ Reduced minimum green time used when an emergency vehicle is
47
+ waiting on a red lane.
48
+ """
49
+
50
+ def __init__(
51
+ self,
52
+ min_green_time: int = 5,
53
+ imbalance_threshold: int = 5,
54
+ max_green_time: int = 20,
55
+ emergency_min_green: int = 2,
56
+ ) -> None:
57
+ self.min_green_time = min_green_time
58
+ self.imbalance_threshold = imbalance_threshold
59
+ self.max_green_time = max_green_time
60
+ self.emergency_min_green = emergency_min_green
61
+
62
+ # Steps since last switch
63
+ self._steps_since_switch: int = 0
64
+
65
+ # ------------------------------------------------------------------
66
+ # Public API
67
+ # ------------------------------------------------------------------
68
+
69
+ def select_action(self, state: Dict[str, Any]) -> int:
70
+ """
71
+ Choose an action given the current environment state.
72
+
73
+ Parameters
74
+ ----------
75
+ state : dict
76
+ State dictionary as returned by ``TrafficEnv.get_state()``.
77
+
78
+ Returns
79
+ -------
80
+ int
81
+ 0 → keep current signal phase
82
+ 1 → switch signal phase
83
+ """
84
+ self._steps_since_switch += 1
85
+
86
+ north = state["north_cars"]
87
+ south = state["south_cars"]
88
+ east = state["east_cars"]
89
+ west = state["west_cars"]
90
+ phase = state["phase"]
91
+
92
+ # emergency_flags may be a dict (TrafficEnv) or a list (legacy)
93
+ ef = state["emergency_flags"]
94
+ if isinstance(ef, dict):
95
+ ev_north, ev_south = ef["north"], ef["south"]
96
+ ev_east, ev_west = ef["east"], ef["west"]
97
+ else:
98
+ ev_north, ev_south, ev_east, ev_west = (bool(x) for x in ef)
99
+
100
+ ns_total = north + south
101
+ ew_total = east + west
102
+
103
+ # ── Rule 1: Emergency preemption ──────────────────────────────
104
+ # High priority: switch if an EV is blocked on a red lane.
105
+ # We apply a small safety buffer (2 steps) to avoid rapid jitter.
106
+ emergency_on_red = False
107
+ if phase == 0 and (ev_east or ev_west):
108
+ emergency_on_red = True
109
+ elif phase == 1 and (ev_north or ev_south):
110
+ emergency_on_red = True
111
+
112
+ if emergency_on_red:
113
+ if self._steps_since_switch >= self.emergency_min_green:
114
+ return self._switch()
115
+
116
+ # ── Rule 2: Oscillation Damping (Minimum Green Time) ──────────
117
+ if self._steps_since_switch < self.min_green_time:
118
+ return 0
119
+
120
+ # ── Rule 3: Congestion/Pressure Trigger ───────────────────────
121
+ # We use a weighted pressure calculation (Queues + EV presence).
122
+ ns_pressure = ns_total + (20 if (ev_north or ev_south) else 0)
123
+ ew_pressure = ew_total + (20 if (ev_east or ev_west) else 0)
124
+
125
+ if phase == 0: # NS currently green
126
+ # Only switch if EW pressure is significantly higher
127
+ if ew_pressure > ns_pressure + self.imbalance_threshold:
128
+ return self._switch()
129
+ else: # EW currently green
130
+ if ns_pressure > ew_pressure + self.imbalance_threshold:
131
+ return self._switch()
132
+
133
+ # ── Rule 4: Fairness Guard (Maximum Green Time) ───���──────────
134
+ if self._steps_since_switch >= self.max_green_time:
135
+ # Only switch if there's actually someone waiting on the other side
136
+ other_side_waiting = (ew_total > 0) if phase == 0 else (ns_total > 0)
137
+ if other_side_waiting:
138
+ return self._switch()
139
+
140
+ # ── Rule 5: Default — hold current phase ─────────────────────
141
+ return 0
142
+
143
+ def reset(self) -> None:
144
+ """Reset internal step counter (call at the start of each episode)."""
145
+ self._steps_since_switch = 0
146
+
147
+ # ------------------------------------------------------------------
148
+ # Internal helpers
149
+ # ------------------------------------------------------------------
150
+
151
+ def _switch(self) -> int:
152
+ """Record a switch and reset the step counter."""
153
+ self._steps_since_switch = 0
154
+ return 1
inference.py ADDED
@@ -0,0 +1,328 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ inference.py — Traffic Signal Optimization · OpenEnv Hackathon Submission
3
+ ============================================================================
4
+
5
+ Env variables expected by the evaluator
6
+ ----------------------------------------
7
+ API_BASE_URL Base URL of the LLM endpoint (e.g. https://router.huggingface.co/v1)
8
+ MODEL_NAME Model identifier (e.g. meta-llama/Llama-3.2-3B-Instruct)
9
+ HF_TOKEN HuggingFace / API key
10
+
11
+ stdout log format (parsed by the OpenEnv validator)
12
+ -----------------------------------------------------
13
+ [START]
14
+ [STEP] step=0, score=0.512300, reward=0.024600, done=False
15
+ ...
16
+ [END]
17
+
18
+ HTTP endpoints (OpenEnv spec: reset / step / state)
19
+ ----------------------------------------------------
20
+ GET / — UI
21
+ GET /health — liveness probe ← returns {"status": "healthy"}
22
+ GET /metadata — env name/description ← required by validator
23
+ GET /schema — action/obs/state ← required by validator
24
+ POST /mcp — JSON-RPC 2.0 stub ← required by validator
25
+ GET /state — current env state (required by OpenEnv spec)
26
+ GET /tasks — enumerate tasks (required by validator)
27
+ POST /reset — start new episode
28
+ POST /step — advance one step
29
+ POST /auto_step — agent picks + steps
30
+ POST /grader — run baseline on all tasks, return scores
31
+ """
32
+
33
+ import os
34
+ import sys
35
+
36
+ from fastapi import FastAPI
37
+ from fastapi.responses import HTMLResponse
38
+ from pydantic import BaseModel
39
+ from env import TrafficEnv
40
+ from tasks import get_config
41
+ from baseline_agent import RuleBasedAgent
42
+ import openai
43
+
44
+
45
+ # ---------------------------------------------------------------------------
46
+ # LLM Agent
47
+ # ---------------------------------------------------------------------------
48
+
49
+ class LLMAgent:
50
+ """
51
+ OpenAI-compatible LLM agent with a rule-based fallback.
52
+ Reads API_BASE_URL / MODEL_NAME / HF_TOKEN from the environment.
53
+ """
54
+
55
+ def __init__(self) -> None:
56
+ api_base = os.environ.get("API_BASE_URL", "").strip()
57
+ api_key = os.environ.get("HF_TOKEN", "not-needed")
58
+ self.model = os.environ.get("MODEL_NAME", "gpt-3.5-turbo")
59
+
60
+ self.client = None
61
+ if api_base:
62
+ try:
63
+ self.client = openai.OpenAI(base_url=api_base, api_key=api_key)
64
+ except Exception:
65
+ self.client = None
66
+
67
+ self.fallback = RuleBasedAgent()
68
+
69
+ def select_action(self, state: dict) -> int:
70
+ if self.client is not None:
71
+ prompt = (
72
+ f"Traffic intersection state:\n{state}\n\n"
73
+ "You control the traffic signal. Reply with ONLY 0 or 1.\n"
74
+ "0 = keep current green phase\n"
75
+ "1 = switch to the other phase"
76
+ )
77
+ try:
78
+ resp = self.client.chat.completions.create(
79
+ model=self.model,
80
+ messages=[
81
+ {"role": "system", "content": "You are a traffic signal controller. Output only 0 or 1."},
82
+ {"role": "user", "content": prompt},
83
+ ],
84
+ max_tokens=5,
85
+ temperature=0.0,
86
+ )
87
+ content = resp.choices[0].message.content.strip()
88
+ self.fallback.select_action(state) # keep step counter in sync
89
+ return 1 if "1" in content else 0
90
+ except Exception:
91
+ pass
92
+ return self.fallback.select_action(state)
93
+
94
+ def reset(self) -> None:
95
+ self.fallback.reset()
96
+
97
+
98
+ # ---------------------------------------------------------------------------
99
+ # Shared server-level env / agent (used by HTTP endpoints)
100
+ # ---------------------------------------------------------------------------
101
+
102
+ _env = TrafficEnv(get_config("medium"))
103
+ _agent = LLMAgent()
104
+
105
+
106
+ # ---------------------------------------------------------------------------
107
+ # FastAPI application
108
+ # ---------------------------------------------------------------------------
109
+
110
+ app = FastAPI(
111
+ title="Traffic Signal Optimization — OpenEnv",
112
+ description="4-way intersection RL environment · Meta × PyTorch OpenEnv Hackathon",
113
+ version="1.0.0",
114
+ )
115
+
116
+
117
+ # ── Meta / liveness ─────────────────────────────────────────────────────────
118
+
119
+ @app.get("/", response_class=HTMLResponse)
120
+ def root() -> str:
121
+ with open("index.html", "r", encoding="utf-8") as fh:
122
+ return fh.read()
123
+
124
+
125
+ # ── FIX 1: /health must return "healthy", not "ok" ──────────────────────────
126
+ @app.get("/health")
127
+ def health() -> dict:
128
+ """Liveness probe — validator strictly checks status == 'healthy'."""
129
+ return {"status": "healthy"}
130
+
131
+
132
+ # ── FIX 2: /metadata endpoint (required by openenv-core validator) ───────────
133
+ @app.get("/metadata")
134
+ def metadata() -> dict:
135
+ """Environment metadata — validator checks for 'name' and 'description' fields."""
136
+ return {
137
+ "name": "TrafficSignalOptimization-v1",
138
+ "description": (
139
+ "AI-driven Traffic Signal Optimization for a 4-way urban intersection. "
140
+ "An RL environment that minimises congestion, reduces average waiting time, "
141
+ "responds to emergency vehicles, and maintains signal stability across "
142
+ "three difficulty tiers: easy, medium, and hard."
143
+ ),
144
+ }
145
+
146
+
147
+ # ── FIX 3: /schema endpoint (required by openenv-core validator) ─────────────
148
+ @app.get("/schema")
149
+ def schema() -> dict:
150
+ """Action / observation / state schemas — all three keys required by validator."""
151
+ return {
152
+ "action": {
153
+ "type": "Discrete",
154
+ "n": 2,
155
+ "description": "0 = keep current phase, 1 = switch phase",
156
+ },
157
+ "observation": {
158
+ "type": "Dict",
159
+ "keys": [
160
+ "north_cars", "south_cars", "east_cars", "west_cars",
161
+ "waiting_times", "phase", "emergency_flags", "step_count",
162
+ ],
163
+ },
164
+ "state": {
165
+ "type": "Dict",
166
+ "keys": [
167
+ "north_cars", "south_cars", "east_cars", "west_cars",
168
+ "waiting_times", "phase", "emergency_flags", "step_count",
169
+ ],
170
+ },
171
+ }
172
+
173
+
174
+ # ── FIX 4: /mcp endpoint (required by openenv-core validator) ────────────────
175
+ @app.post("/mcp")
176
+ def mcp(request: dict = {}) -> dict:
177
+ """JSON-RPC 2.0 stub — validator checks jsonrpc == '2.0'."""
178
+ return {"jsonrpc": "2.0", "id": None, "result": {"status": "ok"}}
179
+
180
+
181
+ @app.get("/tasks")
182
+ def list_tasks() -> dict:
183
+ """Enumerate the 3 difficulty tasks for the validator."""
184
+ return {
185
+ "tasks": [
186
+ {
187
+ "id": "easy",
188
+ "description": "Stable low-volume traffic, rare emergencies (1%)",
189
+ "max_steps": 50,
190
+ "arrival_rate": [0, 1],
191
+ "emergency_prob": 0.01,
192
+ },
193
+ {
194
+ "id": "medium",
195
+ "description": "Moderate traffic with 10% burst events, 5% emergency",
196
+ "max_steps": 100,
197
+ "arrival_rate": [1, 3],
198
+ "emergency_prob": 0.05,
199
+ },
200
+ {
201
+ "id": "hard",
202
+ "description": "High-intensity traffic, 20% bursts, 15% emergency, strict fairness",
203
+ "max_steps": 200,
204
+ "arrival_rate": [2, 5],
205
+ "emergency_prob": 0.15,
206
+ },
207
+ ]
208
+ }
209
+
210
+
211
+ # ── Core OpenEnv API ─────────────────────────────────────────────────────────
212
+
213
+ @app.post("/reset")
214
+ def reset_env() -> dict:
215
+ state = _env.reset()
216
+ _agent.reset()
217
+ return {"state": state}
218
+
219
+
220
+ class Action(BaseModel):
221
+ action: int
222
+
223
+
224
+ @app.post("/step")
225
+ def step_env(data: Action) -> dict:
226
+ state, reward, done, info = _env.step(data.action)
227
+ score = round(max(0.001, min(0.999, (reward + 1.0) / 2.0)), 6)
228
+ return {"state": state, "reward": reward, "score": score, "done": done, "info": info}
229
+
230
+
231
+ @app.get("/state")
232
+ def get_state() -> dict:
233
+ """
234
+ Return current environment state.
235
+ Required by OpenEnv spec (the reset / step / state triple).
236
+ """
237
+ return {"state": _env.get_state()}
238
+
239
+
240
+ # ── Convenience endpoints ────────────────────────────────────────────────────
241
+
242
+ @app.post("/auto_step")
243
+ def auto_step() -> dict:
244
+ state_dict = _env.get_state()
245
+ action = _agent.select_action(state_dict)
246
+ state, reward, done, info = _env.step(action)
247
+ score = round(max(0.001, min(0.999, (reward + 1.0) / 2.0)), 6)
248
+ return {"state": state, "reward": reward, "score": score,
249
+ "done": done, "info": info, "action_taken": action}
250
+
251
+
252
+ @app.post("/grader")
253
+ def grader() -> dict:
254
+ """
255
+ Run the rule-based baseline on all 3 tasks and return per-task scores
256
+ normalised to open interval (0, 1) as required by the validator.
257
+ """
258
+ results: dict = {}
259
+ for task_id in ("easy", "medium", "hard"):
260
+ cfg = get_config(task_id)
261
+ eval_env = TrafficEnv(cfg)
262
+ agent = RuleBasedAgent()
263
+ state = eval_env.reset()
264
+ agent.reset()
265
+
266
+ total_reward = 0.0
267
+ steps = 0
268
+ done = False
269
+
270
+ while not done:
271
+ action = agent.select_action(state)
272
+ state, reward, done, info = eval_env.step(action)
273
+ total_reward += reward
274
+ steps += 1
275
+
276
+ mean_reward = total_reward / max(1, steps)
277
+ score = round(max(0.001, min(0.999, (mean_reward + 1.0) / 2.0)), 6)
278
+ results[task_id] = {
279
+ "score": score,
280
+ "steps": steps,
281
+ "total_reward": round(total_reward, 4),
282
+ "info": info,
283
+ }
284
+ return results
285
+
286
+
287
+ # ---------------------------------------------------------------------------
288
+ # CLI entry-point — produces structured stdout for the OpenEnv validator
289
+ # ---------------------------------------------------------------------------
290
+
291
+ if __name__ == "__main__":
292
+ tasks_to_run = ["easy", "medium", "hard"]
293
+
294
+ if len(sys.argv) > 1:
295
+ raw = sys.argv[1].replace("--task=", "").replace("--task", "").strip()
296
+ if raw in tasks_to_run:
297
+ tasks_to_run = [raw]
298
+
299
+ for task_name in tasks_to_run:
300
+ config = get_config(task_name)
301
+ eval_env = TrafficEnv(config)
302
+ eval_agent = LLMAgent()
303
+
304
+ state = eval_env.reset()
305
+ eval_agent.reset()
306
+
307
+ print("[START]", flush=True)
308
+
309
+ done = False
310
+ step_idx = 0
311
+ total_reward = 0.0
312
+
313
+ while not done:
314
+ action = eval_agent.select_action(state)
315
+ state, reward, done, info = eval_env.step(action)
316
+ total_reward += reward
317
+
318
+ # score: reward normalised to open interval (0, 1)
319
+ score = round(max(0.001, min(0.999, (reward + 1.0) / 2.0)), 6)
320
+
321
+ print(
322
+ f"[STEP] step={step_idx}, score={score}, "
323
+ f"reward={round(reward, 6)}, done={done}",
324
+ flush=True,
325
+ )
326
+ step_idx += 1
327
+
328
+ print("[END]", flush=True)
openenv.yaml ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ version: "1.0"
2
+ name: "TrafficSignalOptimization-v1"
3
+ description: >
4
+ AI-driven Traffic Signal Optimization for a 4-way urban intersection.
5
+ A reinforcement-learning environment that challenges agents to minimise
6
+ congestion, reduce average waiting time, respond to emergency vehicles,
7
+ and maintain signal stability across three difficulty tiers.
8
+
9
+ author: "OpenEnv Submission"
10
+ tags:
11
+ - Reinforcement Learning
12
+ - Traffic Control
13
+ - Smart Cities
14
+ - Safety-Critical
15
+ - Emergency Vehicle Priority
16
+ licence: MIT
17
+
18
+ # ─────────────────────────────────────────────────────────────────────
19
+ # Environment specification
20
+ # ─────────────────────────────────────────────────────────────────────
21
+ environment:
22
+ class: "env.TrafficEnv"
23
+ entry_point: "env:TrafficEnv"
24
+
25
+ state_space:
26
+ type: Dict
27
+ keys:
28
+ north_cars:
29
+ type: Discrete
30
+ description: "Queued vehicles in the North lane"
31
+ range: [0, max_queue]
32
+ south_cars:
33
+ type: Discrete
34
+ description: "Queued vehicles in the South lane"
35
+ range: [0, max_queue]
36
+ east_cars:
37
+ type: Discrete
38
+ description: "Queued vehicles in the East lane"
39
+ range: [0, max_queue]
40
+ west_cars:
41
+ type: Discrete
42
+ description: "Queued vehicles in the West lane"
43
+ range: [0, max_queue]
44
+ waiting_times:
45
+ type: "Dict[str, float]"
46
+ description: "Cumulative waiting-time pressure per lane (north/south/east/west)"
47
+ phase:
48
+ type: Discrete
49
+ values: [0, 1]
50
+ description: "Current green signal: 0 = NS green, 1 = EW green"
51
+ emergency_flags:
52
+ type: "Dict[str, bool]"
53
+ description: "True if an emergency vehicle is present in that lane"
54
+ step_count:
55
+ type: Discrete
56
+ description: "Current step within the episode"
57
+ range: [0, max_steps]
58
+
59
+ action_space:
60
+ type: Discrete
61
+ n: 2
62
+ actions:
63
+ 0: "Keep current signal phase"
64
+ 1: "Switch signal phase (NS ↔ EW)"
65
+
66
+ observation_vector_dim: 14
67
+ # Layout: [N, S, E, W queues | N, S, E, W waits | N, S, E, W EV flags | phase, step]
68
+
69
+ # ─────────────────────────────────────────────────────────────────────
70
+ # Tasks (3 required — validator enumerates and scores each one)
71
+ # ─────────────────────────────────────────────────────────────────────
72
+ tasks:
73
+ - id: easy
74
+ description: "Stable, balanced traffic. Minimal emergencies. Ideal for learning."
75
+ config_key: easy
76
+ max_steps: 50
77
+ score_range: [0.0, 1.0] # open interval (0,1) enforced by grader
78
+ params:
79
+ arrival_rate: [0, 1]
80
+ discharge_rate: [4, 5]
81
+ max_queue: 15
82
+ emergency_prob: 0.01
83
+ burst_prob: 0.0
84
+
85
+ - id: medium
86
+ description: "Random traffic bursts, moderate congestion, occasional emergencies."
87
+ config_key: medium
88
+ max_steps: 100
89
+ score_range: [0.0, 1.0]
90
+ params:
91
+ arrival_rate: [1, 3]
92
+ discharge_rate: [3, 5]
93
+ max_queue: 25
94
+ emergency_prob: 0.05
95
+ burst_prob: 0.10
96
+
97
+ - id: hard
98
+ description: "High-intensity traffic, frequent emergencies, strict fairness constraints."
99
+ config_key: hard
100
+ max_steps: 200
101
+ score_range: [0.0, 1.0]
102
+ params:
103
+ arrival_rate: [2, 5]
104
+ discharge_rate: [2, 4]
105
+ max_queue: 40
106
+ emergency_prob: 0.15
107
+ burst_prob: 0.20
108
+
109
+ # ─────────────────────────────────────────────────────────────────────
110
+ # Reward design (multi-component, clipped to (-0.999, +0.999))
111
+ # Score = (reward + 1) / 2, always in open interval (0, 1)
112
+ # ─────────────────────────────────────────────────────────────────────
113
+ reward:
114
+ range: [-0.999, 0.999]
115
+ score_normalisation: "(reward + 1) / 2 → (0.0005, 0.9995)"
116
+ components:
117
+ efficiency:
118
+ sign: "+"
119
+ description: "Vehicles cleared this step (throughput reward)"
120
+ congestion:
121
+ sign: "-"
122
+ description: "Normalised total queue density"
123
+ max_queue_penalty:
124
+ sign: "-"
125
+ description: "Penalty for extreme bottlenecks in any single lane"
126
+ switch_penalty:
127
+ sign: "-"
128
+ description: "Stability constraint to prevent oscillatory signal toggling"
129
+ improvement_bonus:
130
+ sign: "+"
131
+ description: "Bonus for active decongestion progress"
132
+ fairness_bonus:
133
+ sign: "+"
134
+ description: "Reward for maintaining balanced waiting times across all lanes"
135
+ starvation_penalty:
136
+ sign: "-"
137
+ description: "Penalty for phase-duration exceeding starvation limit"
138
+ emergency_golden_window:
139
+ sign: "+"
140
+ description: "Full bonus for clearing EV within golden window steps"
141
+ emergency_delay:
142
+ sign: "-"
143
+ description: "Exponential penalty for delaying life-saving vehicles"
144
+
145
+ # ─────────────────────────────────────────────────────────────────────
146
+ # Evaluation metrics (returned in info dict on every step)
147
+ # ─────────────────────────────────────────────────────────────────────
148
+ metrics:
149
+ total_cleared:
150
+ type: int
151
+ description: "Total vehicles discharged from the intersection (episode)"
152
+ avg_waiting_time:
153
+ type: float
154
+ description: "Cumulative wait pressure divided by vehicles cleared"
155
+ max_queue_length:
156
+ type: int
157
+ description: "Peak queue length observed in any lane (episode)"
158
+ signal_switch_count:
159
+ type: int
160
+ description: "Total signal changes (lower = more stable)"
161
+ congestion_score:
162
+ type: float
163
+ range: [0.001, 0.999]
164
+ description: "Current normalised total queue depth"
165
+ avg_ev_clear_time:
166
+ type: float
167
+ description: "Average steps taken to clear an emergency vehicle"
168
+ fairness_score:
169
+ type: float
170
+ range: [0.001, 0.999]
171
+ description: "Index representing lane-level service balance"
172
+
173
+ # ─────────────────────────────────────────────────────────────────────
174
+ # Baseline agent
175
+ # ─────────────────────────────────────────────────────────────────────
176
+ baseline:
177
+ class: "baseline_agent.RuleBasedAgent"
178
+ description: >
179
+ Deterministic rule-based agent. Switches based on queue imbalance,
180
+ minimum green time, starvation guard, and emergency preemption.
181
+ parameters:
182
+ min_green_time: 5
183
+ imbalance_threshold: 5
184
+ max_green_time: 15
185
+ emergency_min_green: 2
186
+
187
+ # ─────────────────────────────────────────────────────────────────────
188
+ # HTTP API (OpenEnv spec: reset / step / state)
189
+ # ─────────────────────────────────────────────────────────────────────
190
+ api:
191
+ reset: {method: POST, path: /reset, description: "Start a new episode"}
192
+ step: {method: POST, path: /step, description: "Advance one step"}
193
+ state: {method: GET, path: /state, description: "Get current state"}
194
+ tasks: {method: GET, path: /tasks, description: "List all tasks"}
195
+ grader: {method: POST, path: /grader, description: "Run baseline grader"}
196
+ health: {method: GET, path: /health, description: "Liveness probe"}
197
+
198
+ # ─────────────────────────────────────────────────────────────────────
199
+ # Project files
200
+ # ─────────────────────────────────────────────────────────────────────
201
+ project_structure:
202
+ - env.py: "Core TrafficEnv class"
203
+ - tasks.py: "Easy / Medium / Hard configuration dicts"
204
+ - baseline_agent.py: "Rule-based baseline agent"
205
+ - inference.py: "FastAPI server + LLM agent + CLI validator script"
206
+ - test_env.py: "Simulation runner and correctness checks"
207
+ - openenv.yaml: "This file — environment specification"
208
+ - README.md: "Full documentation"
pyproject.toml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "traffic-signal-openenv"
3
+ version = "0.1.0"
4
+ description = "Traffic Signal Optimization - OpenEnv Elite"
5
+ readme = "README.md"
6
+ requires-python = ">=3.10"
7
+ dependencies = [
8
+ "fastapi>=0.100.0",
9
+ "uvicorn>=0.20.0",
10
+ "numpy>=1.20.0",
11
+ "pydantic>=2.0.0",
12
+ "openenv-core>=0.2.0",
13
+ "openai>=1.0.0",
14
+ ]
15
+
16
+ [project.scripts]
17
+ server = "server.app:main"
18
+
19
+ [build-system]
20
+ requires = ["hatchling"]
21
+ build-backend = "hatchling.build"
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ numpy
4
+ pydantic
5
+ openai
6
+ openenv-core>=0.2.0
server/app.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import uvicorn
4
+
5
+ # Add the parent directory to sys.path so 'inference.py' can be imported and env modules
6
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
7
+
8
+ from inference import app
9
+
10
+ def main():
11
+ uvicorn.run("server.app:app", host="0.0.0.0", port=7860)
12
+
13
+ if __name__ == "__main__":
14
+ main()
tasks.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ tasks.py — Difficulty Configurations for TrafficEnv
3
+ =====================================================
4
+
5
+ Three pre-defined task configurations:
6
+
7
+ EASY_CONFIG – Stable, balanced traffic; good for initial training.
8
+ MEDIUM_CONFIG – Random bursts, moderate congestion; standard benchmark.
9
+ HARD_CONFIG – High intensity, frequent emergencies, strict fairness.
10
+
11
+ Each config is a plain dict consumed by TrafficEnv.__init__().
12
+ """
13
+
14
+ from __future__ import annotations
15
+ from typing import Any, Dict
16
+
17
+
18
+ # ---------------------------------------------------------------------------
19
+ # Easy
20
+ # ---------------------------------------------------------------------------
21
+
22
+ EASY_CONFIG: Dict[str, Any] = {
23
+ # Traffic flow
24
+ "arrival_rate": (0, 1), # 0–1 cars per lane per step
25
+ "discharge_rate": (4, 5), # 4–5 cars discharged per green lane per step
26
+ "max_queue": 15, # queue cap per lane
27
+ "max_steps": 50,
28
+
29
+ # Emergencies — rare
30
+ "emergency_prob": 0.01,
31
+
32
+ # Bursts — none
33
+ "burst_prob": 0.0,
34
+ "burst_multiplier": 1.0,
35
+
36
+ # Reward knobs
37
+ "switch_penalty": 0.10,
38
+ "starvation_threshold": 20,
39
+ "r_efficiency_scale": 0.20,
40
+ "p_congestion_scale": 0.30,
41
+ "p_max_q_scale": 0.10,
42
+ "p_starvation_scale": 0.10,
43
+ "r_fairness_bonus": 0.05,
44
+ "r_improvement_bonus": 0.15,
45
+ "p_emergency_scale": 0.30,
46
+ "r_ev_bonus_scale": 0.20,
47
+
48
+ # Logic thresholds
49
+ "ev_golden_window": 8, # Easy: very generous window
50
+ "ev_max_delay": 20,
51
+ }
52
+
53
+ # ---------------------------------------------------------------------------
54
+ # Medium
55
+ # ---------------------------------------------------------------------------
56
+
57
+ MEDIUM_CONFIG: Dict[str, Any] = {
58
+ # Traffic flow
59
+ "arrival_rate": (1, 3), # moderate, variable arrivals
60
+ "discharge_rate": (3, 5), # standard discharge
61
+ "max_queue": 25,
62
+ "max_steps": 100,
63
+
64
+ # Emergencies — occasional
65
+ "emergency_prob": 0.05,
66
+
67
+ # Random bursts — 10% chance, 1.5× arrivals
68
+ "burst_prob": 0.10,
69
+ "burst_multiplier": 1.5,
70
+
71
+ # Reward knobs
72
+ "switch_penalty": 0.20,
73
+ "starvation_threshold": 15,
74
+ "r_efficiency_scale": 0.20,
75
+ "p_congestion_scale": 0.40,
76
+ "p_max_q_scale": 0.15,
77
+ "p_starvation_scale": 0.15,
78
+ "r_fairness_bonus": 0.10,
79
+ "r_improvement_bonus": 0.20,
80
+ "p_emergency_scale": 0.40,
81
+ "r_ev_bonus_scale": 0.25,
82
+
83
+ # Logic thresholds
84
+ "ev_golden_window": 5, # Medium: standard window
85
+ "ev_max_delay": 15,
86
+ }
87
+
88
+ # ---------------------------------------------------------------------------
89
+ # Hard
90
+ # ---------------------------------------------------------------------------
91
+
92
+ HARD_CONFIG: Dict[str, Any] = {
93
+ # Traffic flow — high intensity
94
+ "arrival_rate": (2, 5), # heavy, bursty arrivals
95
+ "discharge_rate": (2, 4), # reduced discharge (lane friction)
96
+ "max_queue": 40,
97
+ "max_steps": 200,
98
+
99
+ # Emergencies — frequent
100
+ "emergency_prob": 0.15,
101
+
102
+ # Frequent aggressive bursts
103
+ "burst_prob": 0.20,
104
+ "burst_multiplier": 2.0,
105
+
106
+ # Reward knobs — stricter penalties
107
+ "switch_penalty": 0.30,
108
+ "starvation_threshold": 10, # stricter fairness
109
+ "r_efficiency_scale": 0.25,
110
+ "p_congestion_scale": 0.50,
111
+ "p_max_q_scale": 0.20,
112
+ "p_starvation_scale": 0.20,
113
+ "r_fairness_bonus": 0.15,
114
+ "r_improvement_bonus": 0.25,
115
+ "p_emergency_scale": 0.60, # amplified emergency penalty
116
+ "r_ev_bonus_scale": 0.30,
117
+
118
+ # Logic thresholds
119
+ "ev_golden_window": 3, # Hard: must clear immediately
120
+ "ev_max_delay": 10,
121
+ }
122
+
123
+
124
+ # ---------------------------------------------------------------------------
125
+ # Accessor
126
+ # ---------------------------------------------------------------------------
127
+
128
+ _CONFIGS = {
129
+ "easy": EASY_CONFIG,
130
+ "medium": MEDIUM_CONFIG,
131
+ "hard": HARD_CONFIG,
132
+ }
133
+
134
+
135
+ def get_config(mode: str) -> Dict[str, Any]:
136
+ """
137
+ Return the config dict for the requested difficulty mode.
138
+
139
+ Parameters
140
+ ----------
141
+ mode : str
142
+ One of "easy", "medium", "hard" (case-insensitive).
143
+
144
+ Returns
145
+ -------
146
+ dict
147
+ Configuration dictionary suitable for ``TrafficEnv(config)``.
148
+
149
+ Raises
150
+ ------
151
+ ValueError
152
+ If an unknown mode is requested.
153
+ """
154
+ key = mode.strip().lower()
155
+ if key not in _CONFIGS:
156
+ raise ValueError(
157
+ f"Unknown difficulty mode '{mode}'. "
158
+ f"Choose one of: {list(_CONFIGS)}"
159
+ )
160
+ # Return a copy so callers can mutate without side-effects
161
+ return dict(_CONFIGS[key])
test_env.py ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ test_env.py — Simulation Runner & Sanity Tests
3
+ ================================================
4
+
5
+ Provides two entry-points:
6
+
7
+ run_simulation(mode) – Run one full episode and print a formatted report.
8
+ run_all() – Run all three difficulty modes and compare.
9
+ run_sanity_checks() – Fast correctness assertions (no pytest needed).
10
+
11
+ Usage
12
+ -----
13
+ python test_env.py # runs all modes + sanity checks
14
+ python test_env.py easy # run a single mode
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ import sys
20
+ import builtins
21
+ from typing import Dict, Any
22
+
23
+ from env import TrafficEnv
24
+ from tasks import get_config
25
+ from baseline_agent import RuleBasedAgent
26
+
27
+
28
+ # ---------------------------------------------------------------------------
29
+ # Helpers
30
+ # ---------------------------------------------------------------------------
31
+
32
+ _COL = 80 # separator width
33
+
34
+
35
+ def _separator(char: str = "─") -> str:
36
+ return char * _COL
37
+
38
+
39
+ _ASCII_FALLBACKS = (
40
+ ("\u2550", "="),
41
+ ("\u2500", "-"),
42
+ ("\u2502", "|"),
43
+ ("\u00b7", "-"),
44
+ ("\U0001F6A8", "EV"),
45
+ ("\u2713", "PASS"),
46
+ ("\u2717", "FAIL"),
47
+ ("\u26a0\ufe0f", "WARNING"),
48
+ ("\u2705", "PASS"),
49
+ ("\u2014", "-"),
50
+ ("\u2265", ">="),
51
+ ("\u2264", "<="),
52
+ ("\u2208", "in"),
53
+ )
54
+
55
+
56
+ def _safe_text(text: str) -> str:
57
+ encoding = getattr(sys.stdout, "encoding", None) or "utf-8"
58
+ try:
59
+ text.encode(encoding)
60
+ return text
61
+ except UnicodeEncodeError:
62
+ for src, dest in _ASCII_FALLBACKS:
63
+ text = text.replace(src, dest)
64
+ return text
65
+
66
+
67
+ def print(*args, **kwargs) -> None: # type: ignore[override]
68
+ """
69
+ Safe local print wrapper:
70
+ - keeps rich Unicode output when supported
71
+ - falls back to ASCII-safe glyphs on limited encodings (e.g. cp1252)
72
+ """
73
+ file = kwargs.get("file", sys.stdout)
74
+ if file is not sys.stdout:
75
+ builtins.print(*args, **kwargs)
76
+ return
77
+
78
+ sep = kwargs.get("sep", " ")
79
+ end = kwargs.get("end", "\n")
80
+ flush = kwargs.get("flush", False)
81
+ text = sep.join(str(arg) for arg in args)
82
+ builtins.print(_safe_text(text), end=end, flush=flush, file=file)
83
+
84
+
85
+ def _fmt_metric(key: str, value: Any) -> str:
86
+ label = key.replace("_", " ").title()
87
+ if isinstance(value, float):
88
+ return f" {label:<30} {value:.4f}"
89
+ return f" {label:<30} {value}"
90
+
91
+
92
+ # ---------------------------------------------------------------------------
93
+ # Single-mode simulation
94
+ # ---------------------------------------------------------------------------
95
+
96
+ def run_simulation(mode: str = "medium", verbose: bool = True) -> Dict[str, Any]:
97
+ """
98
+ Run one complete episode in the specified difficulty mode.
99
+
100
+ Parameters
101
+ ----------
102
+ mode : str
103
+ "easy", "medium", or "hard"
104
+ verbose : bool
105
+ Print step-by-step output if True.
106
+
107
+ Returns
108
+ -------
109
+ dict
110
+ Final info metrics plus 'cumulative_reward' and 'mode'.
111
+ """
112
+ config = get_config(mode)
113
+ env = TrafficEnv(config)
114
+ agent = RuleBasedAgent(
115
+ min_green_time=5,
116
+ imbalance_threshold=5,
117
+ max_green_time=15,
118
+ emergency_min_green=2,
119
+ )
120
+
121
+ state = env.reset()
122
+ agent.reset()
123
+ done = False
124
+ total_reward = 0.0
125
+ step_rewards = []
126
+
127
+ if verbose:
128
+ print()
129
+ print(_separator("═"))
130
+ print(f" TRAFFIC SIGNAL SIMULATION · Mode: {mode.upper()}")
131
+ print(_separator("═"))
132
+ header = (
133
+ f"{'Step':<6} │ {'Phase':<4} │ "
134
+ f"{'N':>4} {'S':>4} {'E':>4} {'W':>4} │ "
135
+ f"{'NS':>4} {'EW':>4} │ "
136
+ f"{'Reward':>8} │ EV"
137
+ )
138
+ print(header)
139
+ print(_separator())
140
+
141
+ while not done:
142
+ action = agent.select_action(state)
143
+ next_state, reward, done, info = env.step(action)
144
+ total_reward += reward
145
+ step_rewards.append(reward)
146
+
147
+ if verbose:
148
+ phase_str = "NS" if next_state["phase"] == 0 else "EW"
149
+ ns_q = next_state["north_cars"] + next_state["south_cars"]
150
+ ew_q = next_state["east_cars"] + next_state["west_cars"]
151
+ ev_flags = next_state["emergency_flags"]
152
+ ev_active = "🚨" if any(ev_flags.values()) else " "
153
+
154
+ # Print every 5 steps, or whenever there's an emergency
155
+ if env.step_count % 5 == 0 or any(ev_flags.values()):
156
+ print(
157
+ f"{env.step_count:<6} │ {phase_str:<4} │ "
158
+ f"{next_state['north_cars']:>4} "
159
+ f"{next_state['south_cars']:>4} "
160
+ f"{next_state['east_cars']:>4} "
161
+ f"{next_state['west_cars']:>4} │ "
162
+ f"{ns_q:>4} {ew_q:>4} │ "
163
+ f"{reward:>8.3f} │ {ev_active}"
164
+ )
165
+
166
+ state = next_state
167
+
168
+ if verbose:
169
+ print(_separator())
170
+ print(f"\n FINAL METRICS ({mode.upper()})")
171
+ print(_separator())
172
+ for k, v in info.items():
173
+ print(_fmt_metric(k, v))
174
+ print(_fmt_metric("cumulative_reward", total_reward))
175
+ if step_rewards:
176
+ print(_fmt_metric("min_step_reward", min(step_rewards)))
177
+ print(_fmt_metric("max_step_reward", max(step_rewards)))
178
+ print()
179
+
180
+ result = dict(info)
181
+ result["cumulative_reward"] = total_reward
182
+ result["mode"] = mode
183
+ return result
184
+
185
+
186
+ # ---------------------------------------------------------------------------
187
+ # Run all modes and print comparison table
188
+ # ---------------------------------------------------------------------------
189
+
190
+ def run_all() -> None:
191
+ """Run easy, medium and hard in sequence; print a comparison table."""
192
+ results = {}
193
+ for mode in ("easy", "medium", "hard"):
194
+ results[mode] = run_simulation(mode, verbose=True)
195
+
196
+ print()
197
+ print(_separator("═"))
198
+ print(" CROSS-MODE COMPARISON")
199
+ print(_separator("═"))
200
+ metrics = [
201
+ "total_cleared", "avg_waiting_time",
202
+ "max_queue_length", "signal_switch_count",
203
+ "congestion_score", "avg_ev_clear_time",
204
+ "fairness_score", "cumulative_reward",
205
+ ]
206
+ col_w = 18
207
+ header = f" {'Metric':<30}" + "".join(f"{m.upper():>{col_w}}" for m in ("easy", "medium", "hard"))
208
+ print(header)
209
+ print(_separator())
210
+ for m in metrics:
211
+ row = f" {m.replace('_',' ').title():<30}"
212
+ for mode in ("easy", "medium", "hard"):
213
+ val = results[mode].get(m, "—")
214
+ if isinstance(val, float):
215
+ row += f"{val:>{col_w}.3f}"
216
+ else:
217
+ row += f"{val:>{col_w}}"
218
+ print(row)
219
+ print(_separator("═"))
220
+ print()
221
+
222
+
223
+ # ---------------------------------------------------------------------------
224
+ # Sanity / correctness checks (no external test runner needed)
225
+ # ---------------------------------------------------------------------------
226
+
227
+ def run_sanity_checks() -> None:
228
+ """Assert basic correctness invariants for all difficulty modes."""
229
+ print()
230
+ print(_separator("═"))
231
+ print(" SANITY CHECKS")
232
+ print(_separator("═"))
233
+
234
+ passed = 0
235
+ failed = 0
236
+
237
+ def check(name: str, condition: bool) -> None:
238
+ nonlocal passed, failed
239
+ status = "✓ PASS" if condition else "✗ FAIL"
240
+ print(f" [{status}] {name}")
241
+ if condition:
242
+ passed += 1
243
+ else:
244
+ failed += 1
245
+
246
+ for mode in ("easy", "medium", "hard"):
247
+ cfg = get_config(mode)
248
+ env = TrafficEnv(cfg)
249
+ agent = RuleBasedAgent()
250
+
251
+ # 1. reset() returns valid state
252
+ state = env.reset()
253
+ agent.reset()
254
+ check(
255
+ f"[{mode}] reset() returns all-zero queues",
256
+ all(state[f"{d}_cars"] == 0 for d in ("north", "south", "east", "west")),
257
+ )
258
+
259
+ # 2. Step returns correct tuple length
260
+ action = agent.select_action(state)
261
+ result = env.step(action)
262
+ check(f"[{mode}] step() returns 4-tuple", len(result) == 4)
263
+
264
+ ns, reward, done, info = result
265
+
266
+ # 3. Reward is clipped
267
+ check(f"[{mode}] reward in [-1, 1]", -1.0 <= reward <= 1.0)
268
+
269
+ # 4. State keys present
270
+ required_keys = {
271
+ "north_cars", "south_cars", "east_cars", "west_cars",
272
+ "waiting_times", "phase", "emergency_flags", "step_count",
273
+ }
274
+ check(f"[{mode}] state has required keys", required_keys.issubset(ns.keys()))
275
+
276
+ # 5. Info keys present
277
+ required_info = {
278
+ "total_cleared", "avg_waiting_time",
279
+ "max_queue_length", "signal_switch_count",
280
+ "congestion_score", "avg_ev_clear_time",
281
+ "fairness_score",
282
+ }
283
+ check(f"[{mode}] info has required keys", required_info.issubset(info.keys()))
284
+
285
+ # 6. Queues never go negative
286
+ for _ in range(cfg["max_steps"]):
287
+ a = agent.select_action(ns)
288
+ ns, _, done, _ = env.step(a)
289
+ if done:
290
+ break
291
+ all_non_neg = all(v >= 0 for v in env.queues.values())
292
+ check(f"[{mode}] queues never go negative (full episode)", all_non_neg)
293
+
294
+ # 7. Queues never exceed max_queue
295
+ check(
296
+ f"[{mode}] queues never exceed max_queue ({cfg['max_queue']})",
297
+ all(v <= cfg["max_queue"] for v in env.queues.values()),
298
+ )
299
+
300
+ # 8. Signal phase is always 0 or 1
301
+ check(f"[{mode}] phase is always 0 or 1", env.phase in (0, 1))
302
+
303
+ # 9. total_cleared is non-negative
304
+ check(f"[{mode}] total_cleared ≥ 0", env.total_cleared >= 0)
305
+
306
+ # 10. congestion_score in [0, 1]
307
+ score = info["congestion_score"]
308
+ check(f"[{mode}] congestion_score ∈ [0, 1]", 0.0 <= score <= 1.0)
309
+
310
+ print()
311
+
312
+ print(_separator())
313
+ print(f" Results: {passed} passed, {failed} failed")
314
+ print(_separator("═"))
315
+ if failed:
316
+ print(" ⚠️ Some checks failed — review the environment logic.")
317
+ else:
318
+ print(" ✅ All sanity checks passed.")
319
+ print()
320
+
321
+
322
+ # ---------------------------------------------------------------------------
323
+ # CLI entry-point
324
+ # ---------------------------------------------------------------------------
325
+
326
+ if __name__ == "__main__":
327
+ if len(sys.argv) == 2 and sys.argv[1].lower() in ("easy", "medium", "hard"):
328
+ run_simulation(sys.argv[1].lower(), verbose=True)
329
+ else:
330
+ run_all()
331
+ run_sanity_checks()
test_inference.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from env import TrafficEnv
2
+ from tasks import EASY_CONFIG
3
+
4
+ env = TrafficEnv(EASY_CONFIG)
5
+
6
+ print("[START]")
7
+
8
+ state = env.reset()
9
+
10
+ done = False
11
+ step_count = 0
12
+
13
+ while not done:
14
+ action = 0
15
+ next_state, reward, done, info = env.step(action)
16
+ print(f"[STEP] step={step_count}, reward={reward}, done={done}")
17
+ step_count += 1
18
+
19
+ print("[END]")
uv.lock ADDED
The diff for this file is too large to render. See raw diff