# Traffic-Control / environment / traffic_env.py
# Dhaerya's picture
# Add files
# b00d5d5
"""
Traffic Environment — Gymnasium-compatible RL environment for traffic signal control.

State space  : [N_SR, N_L, E_SR, E_L, S_SR, S_L, W_SR, W_L, current_phase]
               (9 features, float32 ∈ [0, 1])
Action space : Discrete(2) → 0 = keep phase, 1 = switch to next phase
Reward       : −total_queue / 20, clipped to [−1, 1]

Key design decisions (from PROJECT_EXPLANATION.md):
  • Queue normalisation keeps every state feature in [0, 1] and prevents
    state saturation.
  • Directional phases (N, E, S, W) eliminate turning collisions.
  • Extended green time (10 steps) when switching makes actions impactful.
  • Reward clipping prevents gradient explosion during DQN training.

NOTE: PROJECT_EXPLANATION.md describes dynamic normalisation (divide by the
current maximum queue) and a reward of −total_queue / 1000. The code in this
module instead divides each queue by a fixed cap of 20 vehicles and computes
the reward as −total_queue / 20.
"""
import numpy as np
import gymnasium as gym
from gymnasium import spaces
from .traffic_generator import TrafficGenerator
class TrafficEnvironment(gym.Env):
    """
    Single-intersection traffic signal control environment.

    The agent controls a 4-phase signal, one phase per approach
    (N / E / S / W). Each approach owns two queues (labelled SR and L),
    for 8 queues in total. The reward penalises the total number of
    vehicles still queued after each step.
    """

    metadata = {"render_modes": ["human"], "render_fps": 30}

    # Phase -> green queue indices mapping (8 queues total)
    #   Phase 0: North (0=SR, 1=L), Phase 1: East (2=SR, 3=L)
    #   Phase 2: South (4=SR, 5=L), Phase 3: West (6=SR, 7=L)
    # NOTE(review): SR / L presumably mean straight-right and left-turn.
    _PHASE_GREEN: dict = {
        0: [0, 1],
        1: [2, 3],
        2: [4, 5],
        3: [6, 7],
    }

    # Named constants for values that were previously inline literals.
    _NUM_PHASES: int = len(_PHASE_GREEN)
    _NUM_QUEUES: int = 8
    _QUEUE_CAP: float = 20.0      # Queue length that maps to 1.0 in the state
    _REWARD_SCALE: float = 20.0   # Divisor applied to the total queue length

    def __init__(self, config=None, render_mode=None):
        """
        Args:
            config: Configuration module/object exposing NUM_LANES,
                EPISODE_LENGTH and YELLOW_TIME. The top-level ``config``
                module is imported and used when None.
            render_mode: None (no rendering) or one of
                ``metadata["render_modes"]``.

        Raises:
            ValueError: If ``render_mode`` is not None and not supported.
        """
        super().__init__()

        if render_mode is not None and render_mode not in self.metadata["render_modes"]:
            raise ValueError(
                f"Unsupported render_mode {render_mode!r}. "
                f"Expected None or one of {self.metadata['render_modes']}."
            )

        if config is None:
            # Imported lazily so an explicit config never needs this module.
            import config as default_config
            config = default_config
        self.config = config

        # Environment parameters
        self.num_lanes = config.NUM_LANES
        self.episode_length = config.EPISODE_LENGTH
        self.min_green_time = 8  # Steps before a switch is allowed
        self.extended_green_time = 10  # Extra clearing passes after a switch
        self.yellow_time = config.YELLOW_TIME

        # Traffic simulator
        self.traffic_generator = TrafficGenerator(config)

        # Observation space: 8 queues + phase, all normalised to [0, 1]
        self.observation_space = spaces.Box(
            low=0.0, high=1.0, shape=(self._NUM_QUEUES + 1,), dtype=np.float32
        )

        # Action space: 0 = keep current phase | 1 = switch to next phase
        self.action_space = spaces.Discrete(2)

        # Internal state
        self.current_step: int = 0
        self.current_phase: int = 0
        self.time_in_phase: int = 0
        self.queue_lengths: np.ndarray = np.zeros(self._NUM_QUEUES, dtype=np.float32)
        self.waiting_times: np.ndarray = np.zeros(self._NUM_QUEUES, dtype=np.float32)
        self.vehicles_passed: int = 0
        self.last_action: int = 0
        self.render_mode = render_mode

    # ------------------------------------------------------------------
    # Gymnasium API
    # ------------------------------------------------------------------
    def reset(self, seed=None, options=None):
        """
        Reset the environment to its initial state.

        Args:
            seed: Forwarded to ``gym.Env.reset`` to seed ``self.np_random``,
                which drives the random vehicle clearing in this class.
            options: Accepted for API compatibility; not used.

        Returns:
            (observation, info)
        """
        super().reset(seed=seed)

        self.current_step = 0
        self.current_phase = 0
        self.time_in_phase = 0
        self.queue_lengths = np.zeros(self._NUM_QUEUES, dtype=np.float32)
        self.waiting_times = np.zeros(self._NUM_QUEUES, dtype=np.float32)
        self.vehicles_passed = 0
        self.last_action = 0
        self.traffic_generator.reset()

        observation = self._get_observation()
        info = self._get_info()
        return observation, info

    def step(self, action: int):
        """
        Execute one decision step.

        A switch request is honoured only once the current phase has been
        held for ``min_green_time`` steps; otherwise it is treated as "keep".
        A successful switch runs ``extended_green_time`` extra clearing
        passes for the new phase before the regular processing of the step.

        Args:
            action: 0 = keep current phase, 1 = switch to next phase.

        Returns:
            (observation, reward, terminated, truncated, info)

        Raises:
            ValueError: If ``action`` is not contained in the action space.
        """
        if not self.action_space.contains(action):
            raise ValueError(f"Invalid action {action!r}. Must be 0 or 1.")

        is_switching = bool(action == 1)

        # Phase switch (gated by the minimum green time)
        if is_switching and self.time_in_phase >= self.min_green_time:
            self.current_phase = (self.current_phase + 1) % self._NUM_PHASES
            self.time_in_phase = 0
            # Extended green: several clearing passes so a switch has a
            # visible effect on the queues.
            for _ in range(self.extended_green_time):
                cleared = self._process_phase()
                self.vehicles_passed += int(cleared)

        # Clocks advance on every step, switch or not, so the minimum-green
        # gate can eventually open and the episode can reach its end.
        self.time_in_phase += 1
        self.current_step += 1

        # Vehicle arrivals
        new_vehicles = self.traffic_generator.generate(self.current_step)
        self.queue_lengths = self.queue_lengths + new_vehicles

        # Normal phase processing
        vehicles_passing = self._process_phase()
        self.vehicles_passed += int(vehicles_passing)

        # Every vehicle still queued waits one more step
        self.waiting_times = self.waiting_times + self.queue_lengths

        reward = float(self._calculate_reward())
        self.last_action = action

        # NOTE(review): reaching the step limit is reported as termination;
        # Gymnasium would normally call a time limit truncation. Kept as-is
        # because callers may depend on the current flag.
        terminated = bool(self.current_step >= self.episode_length)
        truncated = False

        observation = self._get_observation()
        info = self._get_info()
        info["waiting_time"] = float(np.sum(self.waiting_times))
        info["queue_length"] = float(np.sum(self.queue_lengths))
        return observation, reward, terminated, truncated, info

    def render(self):
        """Print a one-line status summary when ``render_mode`` is "human"."""
        if self.render_mode == "human":
            print(
                f"Step: {self.current_step:4d} | Phase: {self.current_phase} | "
                f"Queues: {self.queue_lengths} | Passed: {self.vehicles_passed}"
            )

    def close(self):
        """No resources to release."""
        pass

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _get_observation(self) -> np.ndarray:
        """
        Build the 9-dimensional state vector.

        Each queue length is divided by a fixed cap of 20 vehicles and
        clipped to [0, 1]; the current phase index is scaled to [0, 1] by
        dividing by the highest phase index.

        Returns:
            float32 array of shape (9,): 8 queue features followed by the
            phase feature.
        """
        queue_state = np.asarray(self.queue_lengths, dtype=np.float32)
        # Absolute normalisation: queues at or above the cap saturate at 1.0
        queue_state = np.clip(queue_state / self._QUEUE_CAP, 0.0, 1.0)

        phase_state = np.array(
            [float(self.current_phase) / float(self._NUM_PHASES - 1)],
            dtype=np.float32,
        )
        observation = np.concatenate([queue_state, phase_state])

        # Internal sanity checks on the assembled vector
        assert observation.shape == (9,), f"Bad obs shape: {observation.shape}"
        assert observation.dtype == np.float32
        assert not np.any(np.isnan(observation)), "NaN in observation"
        assert not np.any(np.isinf(observation)), "Inf in observation"
        return observation

    def _get_info(self) -> dict:
        """Return summary statistics describing the current state."""
        return {
            "current_step": self.current_step,
            "current_phase": self.current_phase,
            "total_queue_length": float(np.sum(self.queue_lengths)),
            "average_waiting_time": float(np.mean(self.waiting_times)),
            "vehicles_passed": self.vehicles_passed,
        }

    def _process_phase(self) -> float:
        """
        Clear vehicles from the queues that are green in the current phase.

        Each non-empty green queue releases 1 or 2 vehicles (chosen at
        random), never more than it holds. The draw uses the environment's
        own generator so that ``reset(seed=...)`` makes clearing
        reproducible.

        Returns:
            Number of vehicles that cleared during this pass.
        """
        green_dirs = self._PHASE_GREEN.get(self.current_phase, [])
        vehicles_passing = 0.0
        for d in green_dirs:
            if self.queue_lengths[d] > 0:
                passing = min(
                    self.queue_lengths[d],
                    # integers(1, 3) has an exclusive upper bound -> {1, 2}
                    float(self.np_random.integers(1, 3)),
                )
                self.queue_lengths[d] -= passing
                vehicles_passing += passing
        return vehicles_passing

    def _calculate_reward(self) -> float:
        """
        Compute the reward signal.

        reward = -total_queue / 20, clipped to [-1, 1]. The reward is never
        positive and saturates at -1 once 20 or more vehicles are queued in
        total.
        """
        total_queue = float(np.sum(self.queue_lengths))
        reward = -total_queue / self._REWARD_SCALE
        return float(np.clip(reward, -1.0, 1.0))