""" ─────────────────────────────────────────────────────────────────────────────── JaamCTRLTrafficEnv — core Gymnasium environment loop. This class owns: - SUMO subprocess lifecycle (_launch_sumo, _close_sumo) - Action application and yellow-phase safety enforcement - Delegating to observation.py, reward.py, incident_manager.py - Episode metrics accumulation and success checking - The three Gymnasium contract methods: reset(), step(), state() It does NOT define task configs, reward coefficients, or observation math — those all live in their respective modules. ─────────────────────────────────────────────────────────────────────────────── """ from __future__ import annotations import json import logging import os import subprocess import sys import time from collections import deque from pathlib import Path from typing import Any, Dict, List, Optional, Tuple import numpy as np import gymnasium as gym from gymnasium import spaces # ── TraCI import ───────────────────────────────────────────────────────────── if "SUMO_HOME" in os.environ: sys.path.append(os.path.join(os.environ["SUMO_HOME"], "tools")) try: import traci TRACI_AVAILABLE = True except ImportError: TRACI_AVAILABLE = False # ── Intra-package imports ───────────────────────────────────────────────────── from env import ( TASK_CONFIGS, YELLOW_PHASES, YELLOW_DURATION, MIN_GREEN_S, ) from env.reward import compute_reward, reward_breakdown from env.observation import ( collect_telemetry, mock_telemetry, build_obs, OBS_DIM, ) from env.incident_manager import IncidentManager logger = logging.getLogger("JaamCTRL.BaseEnv") class JaamCTRLTrafficEnv(gym.Env): """ Gymnasium-compatible adaptive traffic signal control environment. Supports 3 progressive difficulty tasks via `task_id`. Designed for use with the OpenEnv judging harness and stable-baselines3. Parameters ---------- task_id : 1 = Easy, 2 = Medium, 3 = Hard sumo_cfg_path : path to .sumocfg (default "sumo/corridor.sumocfg") use_gui : launch sumo-gui (True) or headless sumo (False) port : TraCI port; 0 = auto-select a free port seed : RNG seed for full reproducibility mock_sumo : skip SUMO entirely, use synthetic observations (useful for CI / unit tests without SUMO installed) """ metadata = {"render_modes": ["human", "none"], "render_fps": 10} # ── Construction ───────────────────────────────────────────────────────── def __init__( self, task_id: int = 1, sumo_cfg_path: str = "sumo/corridor.sumocfg", use_gui: bool = False, port: int = 0, seed: Optional[int] = None, mock_sumo: bool = False, ) -> None: super().__init__() assert task_id in (1, 2, 3), f"task_id must be 1, 2 or 3; got {task_id}" self.task_id = task_id self.cfg = TASK_CONFIGS[task_id] self.use_gui = use_gui self.port = port self.mock_sumo = mock_sumo or not TRACI_AVAILABLE self.sumo_cfg = Path(sumo_cfg_path) self._rng = np.random.default_rng(seed) self.n_tl = self.cfg["active_intersections"] # ── Spaces ──────────────────────────────────────────────────────── # Action: phase index per active TL (0–3 each) self.action_space = spaces.MultiDiscrete([4] * self.n_tl) # Observation: Dict with a pre-flattened "flat" key for PPO self.observation_space = spaces.Dict({ "queue_lengths": spaces.Box(0.0, 50.0, shape=(3, 4), dtype=np.float32), "current_phase": spaces.MultiDiscrete([4, 4, 4]), "phase_elapsed": spaces.Box(0.0, 120.0, shape=(3,), dtype=np.float32), "probe_density": spaces.Box(0.0, 1.0, shape=(3, 8), dtype=np.float32), "incident_flag": spaces.MultiBinary(3), "time_of_day_norm": spaces.Box(0.0, 1.0, shape=(1,), dtype=np.float32), "flat": spaces.Box(-1.0, 50.0, shape=(OBS_DIM,), dtype=np.float32), }) # ── Internal state ──────────────────────────────────────────────── self._step_count = 0 self._sim_time_s = 0.0 self._sumo_process = None self._phase_elapsed = np.zeros(3, dtype=np.float32) self._current_phases = np.zeros(3, dtype=np.int32) self._episode_throughput = np.zeros(3, dtype=np.float32) self._episode_delay_sum = 0.0 self._episode_stops = 0 self._overflow_events = 0 self._last_telemetry: Dict[str, Any] = {} # Phase history for green-wave computation (observation.py) # Entries: (sim_time_s: float, tl_index: int, phase: int) self._phase_history: deque = deque(maxlen=200) # Baseline metrics for episode_summary comparison self._baseline_avg_delay: Optional[float] = None self._baseline_throughput: Optional[float] = None # Incident manager — owns all chaos events self._incident_mgr = IncidentManager( cfg=self.cfg, rng=self._rng, mock_sumo=self.mock_sumo, ) logger.info( "JaamCTRLTrafficEnv | task=%d | %s | n_tl=%d | mock=%s", task_id, self.cfg["name"], self.n_tl, self.mock_sumo, ) # ── Gymnasium API ───────────────────────────────────────────────────────── def _to_serializable(self, obj: Any) -> Any: """Convert numpy arrays and other non-JSON types to native Python types.""" if isinstance(obj, np.ndarray): return obj.tolist() elif isinstance(obj, (np.integer, np.floating)): return obj.item() elif isinstance(obj, dict): return {k: self._to_serializable(v) for k, v in obj.items()} elif isinstance(obj, (list, tuple)): return [self._to_serializable(v) for v in obj] return obj def reset( self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None, ) -> Tuple[Dict[str, Any], Dict[str, Any]]: """ Reset to the start of a new episode. `options` dict accepts: "task_id" / "difficulty" : int — switch task on reset "use_gui" : bool — override GUI flag Returns ------- obs : dict (see observation_space) info : dict (task metadata + reset flag) """ if seed is not None: self._rng = np.random.default_rng(seed) # Optional task switch if options: new_task = options.get("task_id") or options.get("difficulty") if new_task and int(new_task) != self.task_id: self.task_id = int(new_task) self.cfg = TASK_CONFIGS[self.task_id] self.n_tl = self.cfg["active_intersections"] self.action_space = spaces.MultiDiscrete([4] * self.n_tl) self._incident_mgr = IncidentManager( cfg=self.cfg, rng=self._rng, mock_sumo=self.mock_sumo ) logger.info("Task switched to %d on reset.", self.task_id) if "use_gui" in options: self.use_gui = bool(options["use_gui"]) self._close_sumo() # Reset episode counters self._step_count = 0 self._sim_time_s = 0.0 self._phase_elapsed[:] = 0.0 self._current_phases[:] = 0 self._episode_throughput[:] = 0.0 self._episode_delay_sum = 0.0 self._episode_stops = 0 self._overflow_events = 0 self._last_telemetry = {} self._phase_history.clear() self._incident_mgr.reset() if not self.mock_sumo: self._launch_sumo() # Generate first observation via mock telemetry (no sim steps yet) tel = ( mock_telemetry(self._rng, self.n_tl, self.cfg["probe_noise_sigma"]) if self.mock_sumo else self._fetch_telemetry() ) self._last_telemetry = tel obs = build_obs( telemetry=tel, current_phases=self._current_phases, phase_elapsed=self._phase_elapsed, active_incidents=self._incident_mgr.active_incidents, step_count=self._step_count, max_steps=self.cfg["max_steps"], n_tl=self.n_tl, ) info = self._build_info(reward=0.0, terminated=False, truncated=False) info["reset"] = True return self._to_serializable(obs), self._to_serializable(info) def step( self, action: np.ndarray, ) -> Tuple[Dict[str, Any], float, bool, bool, Dict[str, Any]]: """ Apply phase actions and advance the simulation by one decision step. Parameters ---------- action : np.ndarray shape (n_tl,) dtype int Phase index (0–3) for each active TL. Returns ------- obs, reward, terminated, truncated, info """ action_arr = np.asarray(action, dtype=np.int64).flatten() padded_action = self._pad_action(action_arr) thrash_count = 0 # ── Apply phase actions ─────────────────────────────────────────── for i in range(self.n_tl): desired = int(padded_action[i]) current = int(self._current_phases[i]) # Thrash guard: ignore switch if minimum green not yet elapsed if ( desired != current and desired not in YELLOW_PHASES and self._phase_elapsed[i] < MIN_GREEN_S ): thrash_count += 1 desired = current # keep current phase # Auto-insert yellow transition between green phases if ( current in (0, 2) and desired in (0, 2) and desired != current ): yellow = current + 1 # 0→1 (NS) or 2→3 (EW) self._set_phase(i, yellow) self._advance_sim(YELLOW_DURATION[yellow]) self._set_phase(i, desired) self._phase_history.append((self._sim_time_s, i, desired)) # ── Advance simulation ──────────────────────────────────────────── self._advance_sim(self.cfg["decision_interval_s"]) self._step_count += 1 # ── Incident tick ───────────────────────────────────────────────── traci_ref = traci if not self.mock_sumo else None self._incident_mgr.tick( step=self._step_count, n_tl=self.n_tl, traci=traci_ref, ) # ── Telemetry ───────────────────────────────────────────────────── tel = ( mock_telemetry(self._rng, self.n_tl, self.cfg["probe_noise_sigma"]) if self.mock_sumo else self._fetch_telemetry() ) tel["thrash_count"] = thrash_count # Incident clearance check (updates incident_mgr internal flag) self._incident_mgr.check_clearance(tel["queue_lengths"]) tel["incident_cleared"] = self._incident_mgr.incident_cleared self._last_telemetry = tel # ── Update phase elapsed ────────────────────────────────────────── for i in range(self.n_tl): if int(self._current_phases[i]) == int(padded_action[i]): self._phase_elapsed[i] += self.cfg["decision_interval_s"] else: self._phase_elapsed[i] = 0.0 # ── Build obs ───────────────────────────────────────────────────── obs = build_obs( telemetry=tel, current_phases=self._current_phases, phase_elapsed=self._phase_elapsed, active_incidents=self._incident_mgr.active_incidents, step_count=self._step_count, max_steps=self.cfg["max_steps"], n_tl=self.n_tl, ) # ── Reward ──────────────────────────────────────────────────────── reward = compute_reward( telemetry=tel, cfg=self.cfg, step_count=self._step_count, n_tl=self.n_tl, ) # ── Accumulate episode metrics ──────────────────────────────────── self._episode_throughput += tel["throughput"] self._episode_delay_sum += tel["total_waiting_time_s"] self._episode_stops += int(tel["new_stops"].sum()) if tel["overflow_lanes"] > 0: self._overflow_events += 1 # ── Termination ─────────────────────────────────────────────────── truncated = self._step_count >= self.cfg["max_steps"] terminated = False # no early-success termination info = self._build_info(reward, terminated, truncated, tel) if truncated or terminated: info["episode_summary"] = self._build_episode_summary() return self._to_serializable(obs), reward, terminated, truncated, self._to_serializable(info) def state(self) -> Dict[str, Any]: """ Return a fully JSON-serialisable snapshot of current env state. Called by the OpenEnv grader after each episode. """ state_dict = { "task_id": self.task_id, "task_name": self.cfg["name"], "step": int(self._step_count), "sim_time_s": float(self._sim_time_s), "current_phases": self._current_phases.tolist(), "phase_elapsed_s": self._phase_elapsed.tolist(), "active_incidents": self._incident_mgr.active_incidents, "episode_throughput": self._episode_throughput.tolist(), "episode_delay_sum_s": float(self._episode_delay_sum), "episode_stops": int(self._episode_stops), "overflow_events": int(self._overflow_events), "incident_cleared": bool(self._incident_mgr.incident_cleared), "success_thresholds": self.cfg["success_thresholds"], } return self._to_serializable(state_dict) def render(self, mode: str = "none") -> None: """SUMO-GUI handles rendering; this is a no-op for headless mode.""" def close(self) -> None: self._close_sumo() # ── SUMO lifecycle ──────────────────────────────────────────────────────── def _launch_sumo(self) -> None: """Start a SUMO subprocess and connect via TraCI.""" binary = "sumo-gui" if self.use_gui else "sumo" if self.port == 0: import socket with socket.socket() as s: s.bind(("", 0)) self.port = s.getsockname()[1] cmd = [ binary, "--configuration-file", str(self.sumo_cfg), "--route-files", str(self.cfg["route_file"]), "--remote-port", str(self.port), "--no-step-log", "true", "--no-warnings", "true", "--collision.action", "remove", "--seed", str(int(self._rng.integers(0, 99_999))), ] logger.debug("SUMO cmd: %s", " ".join(cmd)) self._sumo_process = subprocess.Popen( cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, ) time.sleep(1.0) traci.init(self.port) logger.debug("TraCI connected on port %d.", self.port) def _close_sumo(self) -> None: try: if TRACI_AVAILABLE: traci.close() except Exception: pass if self._sumo_process is not None: self._sumo_process.terminate() try: self._sumo_process.wait(timeout=5) except subprocess.TimeoutExpired: self._sumo_process.kill() self._sumo_process = None def _advance_sim(self, seconds: float) -> None: if self.mock_sumo: self._sim_time_s += seconds return for _ in range(max(1, int(seconds))): traci.simulationStep() self._sim_time_s += seconds def _set_phase(self, tl_index: int, phase: int) -> None: self._current_phases[tl_index] = phase if not self.mock_sumo: traci.trafficlight.setPhase(self.cfg["tl_ids"][tl_index], phase) # ── Telemetry via TraCI ─────────────────────────────────────────────────── def _fetch_telemetry(self) -> Dict[str, Any]: """Collect metrics from live TraCI connection.""" return collect_telemetry( traci=traci, cfg=self.cfg, n_tl=self.n_tl, rng=self._rng, phase_history=self._phase_history, active_incidents=self._incident_mgr.active_incidents, incident_cleared_flag=self._incident_mgr.incident_cleared, ) # ── Helpers ─────────────────────────────────────────────────────────────── def _pad_action(self, action: np.ndarray) -> np.ndarray: """Zero-pad a task-sized action to length 3 for internal loops.""" padded = np.zeros(3, dtype=np.int64) padded[:len(action)] = action[:3] return padded def _build_info( self, reward: float, terminated: bool, truncated: bool, telemetry: Optional[Dict] = None, ) -> Dict[str, Any]: info: Dict[str, Any] = { "task_id": self.task_id, "step": int(self._step_count), "sim_time_s": float(self._sim_time_s), "reward": float(reward), "terminated": terminated, "truncated": truncated, } if telemetry: info.update({ "queue_total": float(telemetry["queue_lengths"][:self.n_tl].sum()), "throughput_total": float(telemetry["throughput"][:self.n_tl].sum()), "overflow_lanes": int(telemetry["overflow_lanes"]), "long_wait_count": int(telemetry["long_wait_count"]), "green_wave_hits": int(telemetry["green_wave_hits"]), "incident_cleared": bool(telemetry["incident_cleared"]), "active_incidents": self._incident_mgr.active_incidents, "reward_breakdown": reward_breakdown( telemetry, self.cfg, self._step_count, self.n_tl ), }) return info def _build_episode_summary(self) -> Dict[str, Any]: """Compute end-of-episode metrics and check success thresholds.""" total_steps = max(1, self._step_count) avg_delay = self._episode_delay_sum / total_steps total_through = float(self._episode_throughput[:self.n_tl].sum()) # Fall back to heuristic baseline if set_baseline() was never called baseline_delay = self._baseline_avg_delay or max(1e-6, avg_delay * 1.30) baseline_throughput = self._baseline_throughput or max(1e-6, total_through * 0.80) delay_reduction_pct = ( 100.0 * (baseline_delay - avg_delay) / baseline_delay ) throughput_improvement_pct = ( 100.0 * (total_through - baseline_throughput) / baseline_throughput ) t = self.cfg["success_thresholds"] passed = all([ delay_reduction_pct >= t.get("delay_reduction_pct", 0.0), throughput_improvement_pct >= t.get("throughput_improvement_pct", 0.0), self._overflow_events <= t.get("overflow_events", 9999), ]) summary = { "task_id": self.task_id, "total_steps": total_steps, "avg_delay_s": round(float(avg_delay), 4), "total_throughput": round(float(total_through), 4), "delay_reduction_pct": round(float(delay_reduction_pct), 4), "throughput_improvement_pct": round(float(throughput_improvement_pct), 4), "overflow_events": int(self._overflow_events), "success": bool(passed), "thresholds": t, } logger.info("Episode done | %s", json.dumps( {k: v for k, v in summary.items() if k != "thresholds"} )) return summary def set_baseline(self, avg_delay: float, throughput: float) -> None: """ Inject fixed-time baseline metrics for episode_summary comparison. Called by inference.py after running a fixed-time reference episode. """ self._baseline_avg_delay = float(avg_delay) self._baseline_throughput = float(throughput) def reward_breakdown_last(self) -> Dict[str, float]: """Return per-term reward breakdown for the most recent step.""" if not self._last_telemetry: return {} return reward_breakdown( self._last_telemetry, self.cfg, self._step_count, self.n_tl ) def __repr__(self) -> str: return ( f"JaamCTRLTrafficEnv(" f"task={self.task_id}, " f"n_tl={self.n_tl}, " f"step={self._step_count}/{self.cfg['max_steps']})" )