Spaces:
Sleeping
Sleeping
| """ | |
| ─────────────────────────────────────────────────────────────────────────────── | |
| JaamCTRLTrafficEnv — core Gymnasium environment loop. | |
| This class owns: | |
| - SUMO subprocess lifecycle (_launch_sumo, _close_sumo) | |
| - Action application and yellow-phase safety enforcement | |
| - Delegating to observation.py, reward.py, incident_manager.py | |
| - Episode metrics accumulation and success checking | |
| - The three Gymnasium contract methods: reset(), step(), state() | |
| It does NOT define task configs, reward coefficients, or observation math — | |
| those all live in their respective modules. | |
| ─────────────────────────────────────────────────────────────────────────────── | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import logging | |
| import os | |
| import subprocess | |
| import sys | |
| import time | |
| from collections import deque | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional, Tuple | |
| import numpy as np | |
| import gymnasium as gym | |
| from gymnasium import spaces | |
| # ── TraCI import ───────────────────────────────────────────────────────────── | |
| if "SUMO_HOME" in os.environ: | |
| sys.path.append(os.path.join(os.environ["SUMO_HOME"], "tools")) | |
| try: | |
| import traci | |
| TRACI_AVAILABLE = True | |
| except ImportError: | |
| TRACI_AVAILABLE = False | |
| # ── Intra-package imports ───────────────────────────────────────────────────── | |
| from env import ( | |
| TASK_CONFIGS, | |
| YELLOW_PHASES, | |
| YELLOW_DURATION, | |
| MIN_GREEN_S, | |
| ) | |
| from env.reward import compute_reward, reward_breakdown | |
| from env.observation import ( | |
| collect_telemetry, | |
| mock_telemetry, | |
| build_obs, | |
| OBS_DIM, | |
| ) | |
| from env.incident_manager import IncidentManager | |
| logger = logging.getLogger("JaamCTRL.BaseEnv") | |
class JaamCTRLTrafficEnv(gym.Env):
    """
    Adaptive traffic-signal control environment exposing the Gymnasium API.

    `task_id` selects one of three progressively harder task configurations.
    The environment is intended for the OpenEnv judging harness and
    stable-baselines3.

    Parameters
    ----------
    task_id       : 1 = Easy, 2 = Medium, 3 = Hard
    sumo_cfg_path : path to the .sumocfg file (default "sumo/corridor.sumocfg")
    use_gui       : True launches `sumo-gui`, False launches headless `sumo`
    port          : TraCI port; 0 means "pick a free port automatically"
    seed          : seed for the environment's NumPy random generator
    mock_sumo     : bypass SUMO and use synthetic telemetry, e.g. for CI or
                    unit tests on machines without SUMO (also forced on when
                    the `traci` module cannot be imported)
    """

    # Render modes advertised to Gymnasium wrappers and tooling.
    metadata = dict(render_modes=["human", "none"], render_fps=10)
| # ── Construction ───────────────────────────────────────────────────────── | |
def __init__(
    self,
    task_id: int = 1,
    sumo_cfg_path: str = "sumo/corridor.sumocfg",
    use_gui: bool = False,
    port: int = 0,
    seed: Optional[int] = None,
    mock_sumo: bool = False,
) -> None:
    """
    Configure the environment for one task; SUMO itself is started by reset().

    Parameters
    ----------
    task_id       : key into TASK_CONFIGS; must be 1, 2 or 3
    sumo_cfg_path : path to the .sumocfg file passed to SUMO
    use_gui       : launch `sumo-gui` instead of headless `sumo`
    port          : TraCI port; 0 lets _launch_sumo() pick a free one
    seed          : seed for the NumPy generator shared with the incident manager
    mock_sumo     : use synthetic telemetry instead of a live simulation

    Raises
    ------
    ValueError
        If `task_id` is not 1, 2 or 3.
    """
    super().__init__()
    # Validate explicitly rather than with `assert`: assertions are stripped
    # under `python -O`, which would turn a bad id into an opaque lookup
    # error on TASK_CONFIGS below.
    if task_id not in (1, 2, 3):
        raise ValueError(f"task_id must be 1, 2 or 3; got {task_id}")
    self.task_id = task_id
    self.cfg = TASK_CONFIGS[task_id]
    self.use_gui = use_gui
    self.port = port
    # Without an importable `traci` there is nothing to connect to, so mock
    # mode is forced.  Say so loudly when the caller asked for real SUMO.
    self.mock_sumo = mock_sumo or not TRACI_AVAILABLE
    if not mock_sumo and not TRACI_AVAILABLE:
        logger.warning(
            "traci could not be imported; falling back to mock telemetry."
        )
    self.sumo_cfg = Path(sumo_cfg_path)
    self._rng = np.random.default_rng(seed)
    self.n_tl = self.cfg["active_intersections"]

    # ── Spaces ────────────────────────────────────────────────────────
    # Action: phase index per active TL (0–3 each)
    self.action_space = spaces.MultiDiscrete([4] * self.n_tl)
    # Observation: Dict with a pre-flattened "flat" key for PPO.  The
    # per-intersection entries are always sized for 3 intersections,
    # independent of n_tl.
    self.observation_space = spaces.Dict({
        "queue_lengths": spaces.Box(0.0, 50.0, shape=(3, 4), dtype=np.float32),
        "current_phase": spaces.MultiDiscrete([4, 4, 4]),
        "phase_elapsed": spaces.Box(0.0, 120.0, shape=(3,), dtype=np.float32),
        "probe_density": spaces.Box(0.0, 1.0, shape=(3, 8), dtype=np.float32),
        "incident_flag": spaces.MultiBinary(3),
        "time_of_day_norm": spaces.Box(0.0, 1.0, shape=(1,), dtype=np.float32),
        "flat": spaces.Box(-1.0, 50.0, shape=(OBS_DIM,), dtype=np.float32),
    })

    # ── Internal state ────────────────────────────────────────────────
    self._step_count = 0
    self._sim_time_s = 0.0
    self._sumo_process = None
    self._phase_elapsed = np.zeros(3, dtype=np.float32)
    self._current_phases = np.zeros(3, dtype=np.int32)
    self._episode_throughput = np.zeros(3, dtype=np.float32)
    self._episode_delay_sum = 0.0
    self._episode_stops = 0
    self._overflow_events = 0
    self._last_telemetry: Dict[str, Any] = {}
    # Phase history for green-wave computation (observation.py)
    # Entries: (sim_time_s: float, tl_index: int, phase: int)
    self._phase_history: deque = deque(maxlen=200)
    # Baseline metrics for episode_summary comparison; None until
    # set_baseline() is called.
    self._baseline_avg_delay: Optional[float] = None
    self._baseline_throughput: Optional[float] = None
    # Incident manager — owns all chaos events
    self._incident_mgr = IncidentManager(
        cfg=self.cfg,
        rng=self._rng,
        mock_sumo=self.mock_sumo,
    )
    logger.info(
        "JaamCTRLTrafficEnv | task=%d | %s | n_tl=%d | mock=%s",
        task_id, self.cfg["name"], self.n_tl, self.mock_sumo,
    )
| # ── Gymnasium API ───────────────────────────────────────────────────────── | |
def _to_serializable(self, obj: Any) -> Any:
    """
    Recursively convert NumPy values into JSON-friendly Python objects.

    - ndarray              -> nested list (via ``tolist()``)
    - any NumPy scalar     -> matching Python scalar (via ``item()``); this
                              includes ``np.bool_``, which is neither an
                              ``np.integer`` nor an ``np.floating``
    - dict                 -> dict with converted values; NumPy-scalar keys
                              are converted too, since ``json.dumps`` rejects
                              them
    - list / tuple         -> list with converted items
    - anything else        -> returned unchanged
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, dict):
        return {
            (key.item() if isinstance(key, np.generic) else key):
                self._to_serializable(value)
            for key, value in obj.items()
        }
    if isinstance(obj, (list, tuple)):
        return [self._to_serializable(item) for item in obj]
    return obj
def reset(
    self,
    *,
    seed: Optional[int] = None,
    options: Optional[Dict[str, Any]] = None,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Start a new episode and return its first observation.

    `options` dict accepts:
        "task_id" / "difficulty" : int  — switch task on reset
        "use_gui"                : bool — override GUI flag

    Any running SUMO instance is closed; in non-mock mode a new one is
    launched before the first telemetry is collected.

    Returns
    -------
    obs  : dict produced by build_obs(), converted to plain Python types
    info : dict from _build_info() plus ``"reset": True``

    Notes
    -----
    A requested task id that TASK_CONFIGS does not contain makes the lookup
    raise before any attribute is changed, so the environment stays on its
    previous task.
    """
    # The incident manager receives the generator at construction time, so
    # replacing self._rng alone would leave it on the old random stream
    # (presumably it keeps that reference — the constructor takes `rng=`).
    # Rebuild it whenever the generator or the task config changes.
    rebuild_incident_mgr = False
    if seed is not None:
        self._rng = np.random.default_rng(seed)
        rebuild_incident_mgr = True

    if options:
        requested = options.get("task_id") or options.get("difficulty")
        if requested and int(requested) != self.task_id:
            new_task_id = int(requested)
            # Look the config up first: a bad id must not leave task_id
            # pointing at a task whose cfg was never loaded.
            new_cfg = TASK_CONFIGS[new_task_id]
            self.task_id = new_task_id
            self.cfg = new_cfg
            self.n_tl = self.cfg["active_intersections"]
            self.action_space = spaces.MultiDiscrete([4] * self.n_tl)
            rebuild_incident_mgr = True
            logger.info("Task switched to %d on reset.", self.task_id)
        if "use_gui" in options:
            self.use_gui = bool(options["use_gui"])

    if rebuild_incident_mgr:
        self._incident_mgr = IncidentManager(
            cfg=self.cfg, rng=self._rng, mock_sumo=self.mock_sumo
        )

    self._close_sumo()

    # Reset episode counters
    self._step_count = 0
    self._sim_time_s = 0.0
    self._phase_elapsed[:] = 0.0
    self._current_phases[:] = 0
    self._episode_throughput[:] = 0.0
    self._episode_delay_sum = 0.0
    self._episode_stops = 0
    self._overflow_events = 0
    self._last_telemetry = {}
    self._phase_history.clear()
    self._incident_mgr.reset()

    if not self.mock_sumo:
        self._launch_sumo()

    # First observation, taken before any simulation step: synthetic
    # telemetry in mock mode, otherwise a live TraCI read.
    tel = (
        mock_telemetry(self._rng, self.n_tl, self.cfg["probe_noise_sigma"])
        if self.mock_sumo
        else self._fetch_telemetry()
    )
    self._last_telemetry = tel
    obs = build_obs(
        telemetry=tel,
        current_phases=self._current_phases,
        phase_elapsed=self._phase_elapsed,
        active_incidents=self._incident_mgr.active_incidents,
        step_count=self._step_count,
        max_steps=self.cfg["max_steps"],
        n_tl=self.n_tl,
    )
    info = self._build_info(reward=0.0, terminated=False, truncated=False)
    info["reset"] = True
    return self._to_serializable(obs), self._to_serializable(info)
def step(
    self,
    action: np.ndarray,
) -> Tuple[Dict[str, Any], float, bool, bool, Dict[str, Any]]:
    """
    Apply phase actions and advance the simulation by one decision step.

    Parameters
    ----------
    action : np.ndarray shape (n_tl,) dtype int
        Phase index (0–3) for each active TL.  The array is flattened and
        zero-padded to length 3 by _pad_action().

    Returns
    -------
    obs        : observation dict, converted to plain Python types
    reward     : value returned by compute_reward()
    terminated : always False (there is no early-success termination)
    truncated  : True once the step counter reaches cfg["max_steps"]
    info       : dict from _build_info(); on the final step it also carries
                 "episode_summary"
    """
    action_arr = np.asarray(action, dtype=np.int64).flatten()
    padded_action = self._pad_action(action_arr)
    interval_s = self.cfg["decision_interval_s"]
    # Snapshot of the phases before this step's actions.  The phase-elapsed
    # bookkeeping below must compare against what was actually running, not
    # against what the agent asked for.
    previous_phases = self._current_phases.copy()
    thrash_count = 0

    # ── Apply phase actions ───────────────────────────────────────────
    # NOTE(review): the yellow interval is simulated inside this per-TL
    # loop, so several intersections switching in the same step each add
    # their own yellow duration — confirm that this is intended.
    for i in range(self.n_tl):
        desired = int(padded_action[i])
        current = int(self._current_phases[i])
        # Thrash guard: ignore switch if minimum green not yet elapsed
        if (
            desired != current
            and desired not in YELLOW_PHASES
            and self._phase_elapsed[i] < MIN_GREEN_S
        ):
            thrash_count += 1
            desired = current  # keep current phase
        # Auto-insert yellow transition between green phases
        if (
            current in (0, 2)
            and desired in (0, 2)
            and desired != current
        ):
            yellow = current + 1  # 0→1 (NS) or 2→3 (EW)
            self._set_phase(i, yellow)
            self._advance_sim(YELLOW_DURATION[yellow])
        self._set_phase(i, desired)
        self._phase_history.append((self._sim_time_s, i, desired))

    # ── Advance simulation ────────────────────────────────────────────
    self._advance_sim(interval_s)
    self._step_count += 1

    # ── Incident tick ─────────────────────────────────────────────────
    traci_ref = traci if not self.mock_sumo else None
    self._incident_mgr.tick(
        step=self._step_count,
        n_tl=self.n_tl,
        traci=traci_ref,
    )

    # ── Telemetry ─────────────────────────────────────────────────────
    tel = (
        mock_telemetry(self._rng, self.n_tl, self.cfg["probe_noise_sigma"])
        if self.mock_sumo
        else self._fetch_telemetry()
    )
    tel["thrash_count"] = thrash_count
    # Incident clearance check (updates incident_mgr internal flag)
    self._incident_mgr.check_clearance(tel["queue_lengths"])
    tel["incident_cleared"] = self._incident_mgr.incident_cleared
    self._last_telemetry = tel

    # ── Update phase elapsed ──────────────────────────────────────────
    # A phase that is unchanged since before this step keeps accumulating
    # time; one that really switched restarts at zero.  Comparing against
    # the requested action instead would reset the timer on every
    # thrash-guarded request (the guard could then never release) and
    # would never reset it after a genuine switch.
    for i in range(self.n_tl):
        if int(self._current_phases[i]) == int(previous_phases[i]):
            self._phase_elapsed[i] += interval_s
        else:
            self._phase_elapsed[i] = 0.0

    # ── Build obs ─────────────────────────────────────────────────────
    obs = build_obs(
        telemetry=tel,
        current_phases=self._current_phases,
        phase_elapsed=self._phase_elapsed,
        active_incidents=self._incident_mgr.active_incidents,
        step_count=self._step_count,
        max_steps=self.cfg["max_steps"],
        n_tl=self.n_tl,
    )

    # ── Reward ────────────────────────────────────────────────────────
    reward = compute_reward(
        telemetry=tel,
        cfg=self.cfg,
        step_count=self._step_count,
        n_tl=self.n_tl,
    )

    # ── Accumulate episode metrics ────────────────────────────────────
    self._episode_throughput += tel["throughput"]
    self._episode_delay_sum += tel["total_waiting_time_s"]
    self._episode_stops += int(tel["new_stops"].sum())
    if tel["overflow_lanes"] > 0:
        self._overflow_events += 1

    # ── Termination ───────────────────────────────────────────────────
    truncated = self._step_count >= self.cfg["max_steps"]
    terminated = False  # no early-success termination
    info = self._build_info(reward, terminated, truncated, tel)
    if truncated or terminated:
        info["episode_summary"] = self._build_episode_summary()
    return (
        self._to_serializable(obs),
        reward,
        terminated,
        truncated,
        self._to_serializable(info),
    )
def state(self) -> Dict[str, Any]:
    """
    Snapshot the current environment state as JSON-serialisable data.

    Called by the OpenEnv grader after each episode.  The snapshot is passed
    through _to_serializable() before it is returned.
    """
    incidents = self._incident_mgr
    snapshot = dict(
        task_id=self.task_id,
        task_name=self.cfg["name"],
        step=int(self._step_count),
        sim_time_s=float(self._sim_time_s),
        current_phases=self._current_phases.tolist(),
        phase_elapsed_s=self._phase_elapsed.tolist(),
        active_incidents=incidents.active_incidents,
        episode_throughput=self._episode_throughput.tolist(),
        episode_delay_sum_s=float(self._episode_delay_sum),
        episode_stops=int(self._episode_stops),
        overflow_events=int(self._overflow_events),
        incident_cleared=bool(incidents.incident_cleared),
        success_thresholds=self.cfg["success_thresholds"],
    )
    return self._to_serializable(snapshot)
def render(self, mode: str = "none") -> None:
    """Do nothing: SUMO-GUI draws the scene itself and headless runs have no output."""
    return None
def close(self) -> None:
    """Release the environment by delegating to _close_sumo()."""
    return self._close_sumo()
| # ── SUMO lifecycle ──────────────────────────────────────────────────────── | |
def _launch_sumo(self) -> None:
    """
    Start a SUMO subprocess and connect to it via TraCI.

    When `self.port` is 0 a free port is chosen first and stored back on
    `self.port`.  SUMO's stdout and stderr are discarded.

    Raises
    ------
    Whatever `traci.init` raises when the connection cannot be established;
    the SUMO child process is stopped before the exception propagates.
    """
    binary = "sumo-gui" if self.use_gui else "sumo"
    if self.port == 0:
        import socket
        # Bind to port 0 so the OS assigns a free port, then release it for
        # SUMO.  There is a short window in which another process could
        # take the port before SUMO binds it.
        with socket.socket() as s:
            s.bind(("", 0))
            self.port = s.getsockname()[1]
    cmd = [
        binary,
        "--configuration-file", str(self.sumo_cfg),
        "--route-files", str(self.cfg["route_file"]),
        "--remote-port", str(self.port),
        "--no-step-log", "true",
        "--no-warnings", "true",
        "--collision.action", "remove",
        # SUMO's own seed is drawn from the environment generator.
        "--seed", str(int(self._rng.integers(0, 99_999))),
    ]
    logger.debug("SUMO cmd: %s", " ".join(cmd))
    self._sumo_process = subprocess.Popen(
        cmd,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )
    time.sleep(1.0)  # give SUMO a moment to open its TraCI socket
    try:
        traci.init(self.port)
    except Exception:
        # Do not leave an orphaned SUMO process behind when the connection
        # fails (bad config, missing binary output, port clash, ...).
        logger.exception(
            "TraCI could not connect on port %d (SUMO exit code: %s); "
            "stopping SUMO.",
            self.port, self._sumo_process.poll(),
        )
        self._close_sumo()
        raise
    logger.debug("TraCI connected on port %d.", self.port)
def _close_sumo(self) -> None:
    """
    Close the TraCI connection and stop the SUMO subprocess, if any.

    Both steps are best-effort so the method is safe to call when nothing
    is running (reset() calls it unconditionally).
    """
    if TRACI_AVAILABLE:
        try:
            traci.close()
        except Exception as exc:
            # Deliberately non-fatal — presumably the usual cause is that
            # no connection is open.  Logged instead of silently dropped.
            logger.debug("traci.close() did not complete: %s", exc)

    process = self._sumo_process
    if process is None:
        return
    # Drop the handle first so it is cleared even if stopping fails below.
    self._sumo_process = None
    process.terminate()
    try:
        process.wait(timeout=5)
    except subprocess.TimeoutExpired:
        process.kill()
        process.wait()  # reap the killed child so it does not linger as a zombie
def _advance_sim(self, seconds: float) -> None:
    """
    Move the simulation clock forward by `seconds`.

    In mock mode only the internal clock is updated.  Otherwise
    `traci.simulationStep()` is called once per whole second requested,
    and at least once.
    """
    if not self.mock_sumo:
        # NOTE(review): one step per second presumes a 1 s SUMO step
        # length — confirm against the .sumocfg.
        n_steps = max(1, int(seconds))
        for _ in range(n_steps):
            traci.simulationStep()
    self._sim_time_s += seconds
def _set_phase(self, tl_index: int, phase: int) -> None:
    """Record `phase` for intersection `tl_index` and, unless mocked, push it to SUMO."""
    self._current_phases[tl_index] = phase
    if self.mock_sumo:
        return
    tl_id = self.cfg["tl_ids"][tl_index]
    traci.trafficlight.setPhase(tl_id, phase)
| # ── Telemetry via TraCI ─────────────────────────────────────────────────── | |
def _fetch_telemetry(self) -> Dict[str, Any]:
    """Gather telemetry from the live TraCI connection via collect_telemetry()."""
    incidents = self._incident_mgr
    telemetry_args = dict(
        traci=traci,
        cfg=self.cfg,
        n_tl=self.n_tl,
        rng=self._rng,
        phase_history=self._phase_history,
        active_incidents=incidents.active_incidents,
        incident_cleared_flag=incidents.incident_cleared,
    )
    return collect_telemetry(**telemetry_args)
| # ── Helpers ─────────────────────────────────────────────────────────────── | |
def _pad_action(self, action: np.ndarray) -> np.ndarray:
    """
    Return a length-3 int64 copy of `action` for the internal loops.

    Shorter inputs are padded with zeros on the right; entries beyond the
    third are dropped.
    """
    width = 3
    padded = np.zeros(width, dtype=np.int64)
    n_given = min(len(action), width)
    padded[:n_given] = action[:n_given]
    return padded
def _build_info(
    self,
    reward: float,
    terminated: bool,
    truncated: bool,
    telemetry: Optional[Dict] = None,
) -> Dict[str, Any]:
    """
    Assemble the `info` dict returned by reset() and step().

    The base entries (task id, step, simulated time, reward and the two end
    flags) are always present.  When a non-empty `telemetry` dict is given,
    per-step traffic metrics and the reward breakdown are added.
    """
    info: Dict[str, Any] = dict(
        task_id=self.task_id,
        step=int(self._step_count),
        sim_time_s=float(self._sim_time_s),
        reward=float(reward),
        terminated=terminated,
        truncated=truncated,
    )
    if not telemetry:
        return info

    # Totals cover only the intersections active in the current task.
    active = slice(None, self.n_tl)
    extras: Dict[str, Any] = {
        "queue_total": float(telemetry["queue_lengths"][active].sum()),
        "throughput_total": float(telemetry["throughput"][active].sum()),
        "overflow_lanes": int(telemetry["overflow_lanes"]),
        "long_wait_count": int(telemetry["long_wait_count"]),
        "green_wave_hits": int(telemetry["green_wave_hits"]),
        "incident_cleared": bool(telemetry["incident_cleared"]),
        "active_incidents": self._incident_mgr.active_incidents,
        "reward_breakdown": reward_breakdown(
            telemetry, self.cfg, self._step_count, self.n_tl
        ),
    }
    info.update(extras)
    return info
def _build_episode_summary(self) -> Dict[str, Any]:
    """
    Compute end-of-episode metrics and check them against the task's
    success thresholds.

    Average delay and total throughput (active intersections only) are
    compared with the baseline injected through set_baseline().  When no
    baseline was injected, a heuristic one is derived from the episode
    itself (1.30x its average delay, 0.80x its throughput); in that case
    the two improvement percentages no longer depend on how the agent
    performed.

    Returns
    -------
    dict with the metrics, the thresholds used, and a boolean "success".
    """
    total_steps = max(1, self._step_count)
    avg_delay = self._episode_delay_sum / total_steps
    total_through = float(self._episode_throughput[:self.n_tl].sum())

    # Test for None explicitly: a legitimately injected baseline of 0.0 is
    # falsy and must not be replaced by the heuristic fallback.  The 1e-6
    # floor keeps the divisions below well defined.
    if self._baseline_avg_delay is not None:
        baseline_delay = max(1e-6, self._baseline_avg_delay)
    else:
        baseline_delay = max(1e-6, avg_delay * 1.30)
    if self._baseline_throughput is not None:
        baseline_throughput = max(1e-6, self._baseline_throughput)
    else:
        baseline_throughput = max(1e-6, total_through * 0.80)

    delay_reduction_pct = (
        100.0 * (baseline_delay - avg_delay) / baseline_delay
    )
    throughput_improvement_pct = (
        100.0 * (total_through - baseline_throughput) / baseline_throughput
    )

    t = self.cfg["success_thresholds"]
    # A threshold missing from the config never fails the episode.
    passed = all([
        delay_reduction_pct >= t.get("delay_reduction_pct", 0.0),
        throughput_improvement_pct >= t.get("throughput_improvement_pct", 0.0),
        self._overflow_events <= t.get("overflow_events", 9999),
    ])
    summary = {
        "task_id": self.task_id,
        "total_steps": total_steps,
        "avg_delay_s": round(float(avg_delay), 4),
        "total_throughput": round(float(total_through), 4),
        "delay_reduction_pct": round(float(delay_reduction_pct), 4),
        "throughput_improvement_pct": round(float(throughput_improvement_pct), 4),
        "overflow_events": int(self._overflow_events),
        "success": bool(passed),
        "thresholds": t,
    }
    # Thresholds are left out of the log line to keep it short.
    logger.info("Episode done | %s", json.dumps(
        {k: v for k, v in summary.items() if k != "thresholds"}
    ))
    return summary
def set_baseline(self, avg_delay: float, throughput: float) -> None:
    """
    Store fixed-time reference metrics used by the episode summary.

    inference.py calls this after running a fixed-time reference episode.
    Both values are stored as floats.
    """
    baseline_fields = (
        ("_baseline_avg_delay", avg_delay),
        ("_baseline_throughput", throughput),
    )
    for attr_name, value in baseline_fields:
        setattr(self, attr_name, float(value))
def reward_breakdown_last(self) -> Dict[str, float]:
    """
    Per-term reward breakdown for the most recent step.

    Returns an empty dict while no telemetry has been recorded.
    """
    telemetry = self._last_telemetry
    if telemetry:
        return reward_breakdown(
            telemetry, self.cfg, self._step_count, self.n_tl
        )
    return {}
def __repr__(self) -> str:
    """Short debug representation: task id, intersection count and step progress."""
    progress = f"{self._step_count}/{self.cfg['max_steps']}"
    fields = f"task={self.task_id}, n_tl={self.n_tl}, step={progress}"
    return f"JaamCTRLTrafficEnv({fields})"