Spaces:
Sleeping
Sleeping
| import random | |
| import numpy as np | |
| from typing import Dict, Any, Optional | |
| from src.models import State, Action, StepResult | |
| class TrafficEnv: | |
| def __init__(self, config: Dict[str, Any]): | |
| self.config = config | |
| self.max_time = config.get("max_time", 100) | |
| self.arrival_rate_base = config.get("arrival_rate", 2) | |
| self.congestion_multiplier = config.get("congestion_multiplier", 1.0) | |
| self.emergency_prob = config.get("emergency_prob", 0.0) | |
| self.queue_cap = 100 | |
| self.reset() | |
| def reset(self, seed: Optional[int] = None) -> State: | |
| if seed is not None: | |
| random.seed(seed) | |
| np.random.seed(seed) | |
| self.north = 0 | |
| self.south = 0 | |
| self.east = 0 | |
| self.west = 0 | |
| self.current_signal = "red" | |
| self.waiting_time_total = 0.0 | |
| self.time_step = 0 | |
| self.emergency_present = False | |
| self.emergency_direction_str = 'none' | |
| self.total_cleared = 0 | |
| self.total_waiting_time = 0.0 | |
| self.emergency_response_time = 0 | |
| self.emergencies_handled = 0 | |
| self.total_emergencies_generated = 0 | |
| self.done = False | |
| self.prev_ns_total = 0 | |
| self.prev_ew_total = 0 | |
| self.ns_wait_time = 0.0 | |
| self.ew_wait_time = 0.0 | |
| self.reward_trends = [] | |
| return self.state() | |
| def state(self) -> State: | |
| ns_total = self.north + self.south | |
| ew_total = self.east + self.west | |
| ns_growth = float(ns_total - self.prev_ns_total) | |
| ew_growth = float(ew_total - self.prev_ew_total) | |
| return State( | |
| north_queue=self.north, | |
| south_queue=self.south, | |
| east_queue=self.east, | |
| west_queue=self.west, | |
| current_signal=self.current_signal, | |
| waiting_time_total=self.waiting_time_total, | |
| emergency_vehicle_present=self.emergency_present, | |
| time_step=self.time_step, | |
| ns_growth=ns_growth, | |
| ew_growth=ew_growth, | |
| emergency_direction=self.emergency_direction_str, | |
| ns_wait_time=self.ns_wait_time, | |
| ew_wait_time=self.ew_wait_time | |
| ) | |
| def step(self, action_idx: int) -> StepResult: | |
| if self.done: | |
| return StepResult(self.state(), 0, True, {"msg": "Done"}) | |
| self.prev_ns_total = self.north + self.south | |
| self.prev_ew_total = self.east + self.west | |
| action = Action(action_idx) | |
| reward = 0.0 | |
| prev_signal = self.current_signal | |
| if action.action_type == 0: | |
| self.current_signal = "red" | |
| elif action.action_type == 1: | |
| self.current_signal = "green_ns" | |
| elif action.action_type == 2: | |
| self.current_signal = "green_ew" | |
| if prev_signal != self.current_signal and prev_signal != "red": | |
| reward -= 0.5 | |
| ns_total = self.north + self.south | |
| ew_total = self.east + self.west | |
| total_waiting = ns_total + ew_total | |
| reward -= (total_waiting * 0.15) | |
| self.waiting_time_total += total_waiting | |
| self.total_waiting_time += total_waiting | |
| if self.emergency_present: | |
| self.emergency_response_time += 1 | |
| reward -= 0.5 | |
| cleared_this_step = 0 | |
| clearance_capacity = 8 | |
| emergency_cleared = False | |
| if self.current_signal == "green_ns": | |
| c_n = min(self.north, clearance_capacity) | |
| c_s = min(self.south, clearance_capacity) | |
| self.north -= c_n | |
| self.south -= c_s | |
| cleared_this_step = c_n + c_s | |
| if self.emergency_present and self.emergency_direction_str == 'ns': | |
| emergency_cleared = True | |
| elif self.current_signal == "green_ew": | |
| c_e = min(self.east, clearance_capacity) | |
| c_w = min(self.west, clearance_capacity) | |
| self.east -= c_e | |
| self.west -= c_w | |
| cleared_this_step = c_e + c_w | |
| if self.emergency_present and self.emergency_direction_str == 'ew': | |
| emergency_cleared = True | |
| self.total_cleared += cleared_this_step | |
| reward += cleared_this_step * 0.75 | |
| if self.current_signal == "green_ns" and ns_total == 0: | |
| reward -= 0.5 | |
| elif self.current_signal == "green_ew" and ew_total == 0: | |
| reward -= 0.5 | |
| if total_waiting > 0 and cleared_this_step == 0: | |
| reward -= 0.5 | |
| if emergency_cleared: | |
| reward += 15.0 | |
| self.emergency_present = False | |
| self.emergency_direction_str = 'none' | |
| self.emergencies_handled += 1 | |
| if ns_total > 0 and self.current_signal != "green_ns": | |
| self.ns_wait_time += 1.0 | |
| else: | |
| self.ns_wait_time = 0.0 | |
| if ew_total > 0 and self.current_signal != "green_ew": | |
| self.ew_wait_time += 1.0 | |
| else: | |
| self.ew_wait_time = 0.0 | |
| reward -= 0.1 | |
| current_multiplier = 1.0 + (self.congestion_multiplier * (self.time_step / self.max_time)) | |
| total_expected_rate = (self.arrival_rate_base * 4) * current_multiplier | |
| noise_factor = random.uniform(0.85, 1.15) | |
| noisy_rate = total_expected_rate * noise_factor | |
| lane_split = np.random.dirichlet([5, 5, 5, 5]) | |
| def arrive(r): | |
| base = int(r) | |
| return base + 1 if random.random() < (r - base) else base | |
| self.north = min(self.queue_cap, self.north + arrive(noisy_rate * lane_split[0])) | |
| self.south = min(self.queue_cap, self.south + arrive(noisy_rate * lane_split[1])) | |
| self.east = min(self.queue_cap, self.east + arrive(noisy_rate * lane_split[2])) | |
| self.west = min(self.queue_cap, self.west + arrive(noisy_rate * lane_split[3])) | |
| if not self.emergency_present and random.random() < self.emergency_prob: | |
| self.emergency_present = True | |
| self.emergency_direction_str = random.choice(['ns', 'ew']) | |
| self.total_emergencies_generated += 1 | |
| reward += random.uniform(-0.1, 0.1) | |
| self.time_step += 1 | |
| if self.time_step >= self.max_time: | |
| self.done = True | |
| self.reward_trends.append(reward) | |
| info = { | |
| "total_cleared": self.total_cleared, | |
| "avg_waiting_time": self.total_waiting_time / max(1, self.total_cleared), | |
| "emergencies_handled": self.emergencies_handled, | |
| "total_emergencies": self.total_emergencies_generated, | |
| "reward_trend_avg": sum(self.reward_trends[-10:]) / 10 if self.reward_trends else 0 | |
| } | |
| return StepResult(self.state(), reward, self.done, info) | |