# Smart-Traffic-openenv / src /environment.py
# Aryansabasana
# Fix score saturation, add controlled variability, update heuristic
# f259f2b
import random
import numpy as np
from typing import Dict, Any, Optional
from src.models import State, Action, StepResult
class TrafficEnv:
    """Step-based simulation of a four-approach signalised intersection.

    Vehicles queue on the north, south, east and west approaches. A single
    signal is either ``"red"``, ``"green_ns"`` or ``"green_ew"``; a green
    phase discharges up to ``CLEARANCE_CAPACITY`` vehicles from each approach
    on its axis per step. Arrivals are random and ramp up over the episode,
    and an emergency vehicle may appear on either axis.

    Randomness is drawn from the module-level ``random`` and ``numpy.random``
    generators, so seeding through :meth:`reset` changes global RNG state.
    """

    # Vehicles a single approach can discharge during one green step.
    CLEARANCE_CAPACITY = 8
    # Number of most recent step rewards averaged into ``reward_trend_avg``.
    REWARD_TREND_WINDOW = 10

    def __init__(self, config: Dict[str, Any]):
        """Read tunables from ``config`` and start a fresh episode.

        Recognised keys, all optional:

        * ``max_time`` -- number of steps per episode (default 100).
        * ``arrival_rate`` -- base expected arrivals per approach per step
          (default 2).
        * ``congestion_multiplier`` -- how strongly arrivals ramp up as the
          episode progresses (default 1.0).
        * ``emergency_prob`` -- per-step probability that an emergency
          vehicle appears when none is present (default 0.0).
        """
        self.config = config
        self.max_time = config.get("max_time", 100)
        self.arrival_rate_base = config.get("arrival_rate", 2)
        self.congestion_multiplier = config.get("congestion_multiplier", 1.0)
        self.emergency_prob = config.get("emergency_prob", 0.0)
        # Hard upper bound on the length of any single approach queue.
        self.queue_cap = 100
        self.reset()

    # Return annotations naming project types are quoted so they are not
    # evaluated when the class body is executed.
    def reset(self, seed: Optional[int] = None) -> "State":
        """Clear all queues, counters and flags and return the initial state.

        Args:
            seed: When given, seeds both ``random`` and ``numpy.random`` so
                the following episode is reproducible.

        Returns:
            The observation of the freshly reset environment.
        """
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)

        # Queue length per approach.
        self.north = 0
        self.south = 0
        self.east = 0
        self.west = 0

        self.current_signal = "red"
        self.waiting_time_total = 0.0
        self.time_step = 0
        self.emergency_present = False
        self.emergency_direction_str = 'none'

        # Episode-level statistics reported through the step info dict.
        self.total_cleared = 0
        self.total_waiting_time = 0.0
        self.emergency_response_time = 0
        self.emergencies_handled = 0
        self.total_emergencies_generated = 0
        self.done = False

        # Axis totals at the start of the latest step, used for growth.
        self.prev_ns_total = 0
        self.prev_ew_total = 0
        # Consecutive steps each axis has had vehicles without a green.
        self.ns_wait_time = 0.0
        self.ew_wait_time = 0.0
        # Reward of every step taken in the current episode.
        self.reward_trends = []
        return self.state()

    def state(self) -> "State":
        """Build the current observation.

        ``ns_growth`` / ``ew_growth`` are the change in each axis total since
        the start of the most recent :meth:`step` call.
        """
        ns_total = self.north + self.south
        ew_total = self.east + self.west
        return State(
            north_queue=self.north,
            south_queue=self.south,
            east_queue=self.east,
            west_queue=self.west,
            current_signal=self.current_signal,
            waiting_time_total=self.waiting_time_total,
            emergency_vehicle_present=self.emergency_present,
            time_step=self.time_step,
            ns_growth=float(ns_total - self.prev_ns_total),
            ew_growth=float(ew_total - self.prev_ew_total),
            emergency_direction=self.emergency_direction_str,
            ns_wait_time=self.ns_wait_time,
            ew_wait_time=self.ew_wait_time,
        )

    def step(self, action_idx: int) -> "StepResult":
        """Apply one signal decision and advance the simulation one step.

        Args:
            action_idx: Passed to ``Action``; the resulting ``action_type``
                selects the signal (0 red, 1 green_ns, 2 green_ew). Any other
                value leaves the signal unchanged.

        Returns:
            A ``StepResult`` holding the new state, the step reward, the done
            flag and an info dict of episode statistics. Once the episode is
            finished, further calls return reward 0.0 and
            ``{"msg": "Done"}`` without changing anything.

        The reward is accumulated from, in order: -0.5 for switching away
        from a green phase; -0.15 per vehicle queued at the start of the
        step; -0.5 while an emergency is pending; +0.75 per vehicle cleared;
        -0.5 for giving green to an axis that had no vehicles; -0.5 if
        vehicles were queued but none cleared; +15 when the emergency's axis
        is given green; a flat -0.1; and uniform noise in [-0.1, 0.1].
        """
        if self.done:
            return StepResult(self.state(), 0.0, True, {"msg": "Done"})

        # Snapshot the axis totals before anything moves; state() reports
        # growth relative to these values.
        ns_total = self.north + self.south
        ew_total = self.east + self.west
        self.prev_ns_total = ns_total
        self.prev_ew_total = ew_total

        action = Action(action_idx)
        reward = 0.0

        prev_signal = self.current_signal
        if action.action_type == 0:
            self.current_signal = "red"
        elif action.action_type == 1:
            self.current_signal = "green_ns"
        elif action.action_type == 2:
            self.current_signal = "green_ew"

        # Leaving red is free; any other phase change costs a switch penalty.
        if prev_signal != self.current_signal and prev_signal != "red":
            reward -= 0.5

        total_waiting = ns_total + ew_total
        reward -= (total_waiting * 0.15)
        self.waiting_time_total += total_waiting
        self.total_waiting_time += total_waiting

        if self.emergency_present:
            self.emergency_response_time += 1
            reward -= 0.5

        cleared_this_step = 0
        emergency_cleared = False
        if self.current_signal == "green_ns":
            cleared_this_step = self._discharge(("north", "south"))
            if self.emergency_present and self.emergency_direction_str == 'ns':
                emergency_cleared = True
        elif self.current_signal == "green_ew":
            cleared_this_step = self._discharge(("east", "west"))
            if self.emergency_present and self.emergency_direction_str == 'ew':
                emergency_cleared = True

        self.total_cleared += cleared_this_step
        reward += cleared_this_step * 0.75

        # Green handed to an axis that was empty at the start of the step.
        if self.current_signal == "green_ns" and ns_total == 0:
            reward -= 0.5
        elif self.current_signal == "green_ew" and ew_total == 0:
            reward -= 0.5

        # Vehicles were queued somewhere but nothing moved.
        if total_waiting > 0 and cleared_this_step == 0:
            reward -= 0.5

        if emergency_cleared:
            reward += 15.0
            self.emergency_present = False
            self.emergency_direction_str = 'none'
            self.emergencies_handled += 1

        # An axis keeps accumulating wait while it has vehicles and no
        # green; otherwise its counter restarts from zero.
        if ns_total > 0 and self.current_signal != "green_ns":
            self.ns_wait_time += 1.0
        else:
            self.ns_wait_time = 0.0
        if ew_total > 0 and self.current_signal != "green_ew":
            self.ew_wait_time += 1.0
        else:
            self.ew_wait_time = 0.0

        reward -= 0.1

        # RNG draws below keep a fixed order (arrivals, emergency, reward
        # noise) so that seeded episodes stay reproducible.
        self._add_arrivals()

        if not self.emergency_present and random.random() < self.emergency_prob:
            self.emergency_present = True
            self.emergency_direction_str = random.choice(['ns', 'ew'])
            self.total_emergencies_generated += 1

        reward += random.uniform(-0.1, 0.1)

        self.time_step += 1
        if self.time_step >= self.max_time:
            self.done = True

        self.reward_trends.append(reward)
        return StepResult(self.state(), reward, self.done, self._episode_info())

    def _discharge(self, lanes: tuple) -> int:
        """Remove up to ``CLEARANCE_CAPACITY`` vehicles from each named lane.

        ``lanes`` holds queue attribute names; returns the total removed.
        """
        moved = 0
        for lane in lanes:
            queued = getattr(self, lane)
            leaving = min(queued, self.CLEARANCE_CAPACITY)
            setattr(self, lane, queued - leaving)
            moved += leaving
        return moved

    def _add_arrivals(self) -> None:
        """Draw this step's arrivals and add them to the four queues."""
        # Guard against max_time == 0, which would otherwise divide by zero.
        progress = self.time_step / self.max_time if self.max_time else 0.0
        current_multiplier = 1.0 + (self.congestion_multiplier * progress)
        total_expected_rate = (self.arrival_rate_base * 4) * current_multiplier
        # +/-15% jitter on the total, then a random split across approaches.
        noisy_rate = total_expected_rate * random.uniform(0.85, 1.15)
        lane_split = np.random.dirichlet([5, 5, 5, 5])
        for lane, share in zip(("north", "south", "east", "west"), lane_split):
            queued = getattr(self, lane) + self._stochastic_round(noisy_rate * share)
            setattr(self, lane, min(self.queue_cap, queued))

    @staticmethod
    def _stochastic_round(rate) -> int:
        """Round ``rate`` down, or up with probability equal to its fraction."""
        base = int(rate)
        return base + 1 if random.random() < (rate - base) else base

    def _episode_info(self) -> Dict[str, Any]:
        """Collect the running episode statistics returned with each step.

        ``avg_waiting_time`` is the accumulated queue length over all steps
        divided by vehicles cleared (at least 1). ``reward_trend_avg`` is the
        mean of the most recent rewards, averaged over however many are
        actually available (up to ``REWARD_TREND_WINDOW``), or 0 if none.
        """
        recent = self.reward_trends[-self.REWARD_TREND_WINDOW:]
        return {
            "total_cleared": self.total_cleared,
            "avg_waiting_time": self.total_waiting_time / max(1, self.total_cleared),
            "emergencies_handled": self.emergencies_handled,
            "total_emergencies": self.total_emergencies_generated,
            "reward_trend_avg": sum(recent) / len(recent) if recent else 0,
        }