agentic-traffic / env /traffic_env.py
Aditya2162's picture
Upload folder using huggingface_hub
3d2dbcf verified
from __future__ import annotations

from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

import numpy as np

from env.cityflow_adapter import CityFlowAdapter
from env.intersection_config import DistrictConfig, IntersectionConfig
from env.observation_builder import ObservationBuilder, ObservationConfig
from env.reward import RewardCalculator, RewardConfig
from env.utils import build_topology, load_json
@dataclass(frozen=True)
class EnvConfig:
    """Immutable settings consumed by ``TrafficEnv``.

    The nested ``observation`` and ``reward`` configs are created through
    ``default_factory`` so each ``EnvConfig`` gets its own instances and the
    class definition does not depend on those config objects being hashable
    (dataclasses reject unhashable instance defaults).
    """

    # Amount added to every intersection's phase-elapsed counter per simulator tick.
    simulator_interval: int = 1
    # Number of simulator ticks executed for each environment ``step`` call.
    decision_interval: int = 5
    # Elapsed time a phase must reach before a switch request is honoured.
    min_green_time: int = 10
    # Forwarded to ``CityFlowAdapter`` as ``thread_num``.
    thread_num: int = 1
    # Forwarded to ``ObservationBuilder``; its ``max_incoming_lanes`` also sizes
    # the per-intersection lane arrays built by the environment.
    observation: ObservationConfig = field(default_factory=ObservationConfig)
    # Forwarded to ``RewardCalculator``; its ``variant`` is reported in the metrics.
    reward: RewardConfig = field(default_factory=RewardConfig)
    # Optional episode horizon; when falsy the simulator config's "step" value is used.
    max_episode_seconds: int | None = None
class TrafficEnv:
    """Traffic-signal control environment over a set of controllable intersections.

    Each :meth:`step` applies one action per controlled intersection (``1``
    requests a switch to the next green phase, any other value holds the
    current phase), advances the simulator ``decision_interval`` times through
    the adapter, and returns the per-intersection rewards produced by
    ``RewardCalculator``.

    Intersections are always processed in sorted-id order
    (``controlled_intersection_ids``). That order defines the rows of the lane
    arrays built here and the expected order of list/array actions.
    """

    def __init__(
        self,
        city_id: str,
        scenario_name: str,
        city_dir: str | Path,
        scenario_dir: str | Path,
        config_path: str | Path,
        roadnet_path: str | Path,
        district_map_path: str | Path | None = None,
        metadata_path: str | Path | None = None,
        env_config: EnvConfig | None = None,
    ):
        """Load the topology and wire up the simulator, observation and reward helpers.

        Args:
            city_id: Identifier echoed in observations and info dicts.
            scenario_name: Scenario label echoed in observations and info dicts.
            city_dir: Stored as ``self.city_dir``; not otherwise used here.
            scenario_dir: Stored as ``self.scenario_dir``; not otherwise used here.
            config_path: Simulator config JSON, passed to ``CityFlowAdapter`` and
                also read for its ``"step"`` entry (episode horizon fallback).
            roadnet_path: Road network file passed to ``build_topology``.
            district_map_path: Optional path passed to ``build_topology``.
            metadata_path: Optional path passed to ``build_topology`` and loaded
                into ``self.metadata``.
            env_config: Environment settings; a default ``EnvConfig`` is used
                when omitted.

        Raises:
            ValueError: If ``decision_interval`` is below 1, or if the topology
                contains no controllable intersections.
        """
        self.city_id = city_id
        self.scenario_name = scenario_name
        self.city_dir = Path(city_dir)
        self.scenario_dir = Path(scenario_dir)
        self.original_config_path = Path(config_path)
        self.roadnet_path = Path(roadnet_path)
        self.district_map_path = Path(district_map_path) if district_map_path else None
        self.metadata_path = Path(metadata_path) if metadata_path else None
        self.env_config = env_config or EnvConfig()

        # A non-positive interval would run zero simulator ticks per step and
        # divide the lane averages by zero, so reject it up front.
        if self.env_config.decision_interval < 1:
            raise ValueError(
                "decision_interval must be >= 1, got "
                f"{self.env_config.decision_interval}."
            )

        self.intersections, self.districts = build_topology(
            roadnet_path=self.roadnet_path,
            district_map_path=self.district_map_path,
            metadata_path=self.metadata_path,
        )
        if not self.intersections:
            raise ValueError(
                f"No controllable intersections found for {self.city_id}/{self.scenario_name}."
            )

        # Sorted once so every per-intersection array has a stable row order.
        self.controlled_intersection_ids = tuple(sorted(self.intersections))
        self.observation_builder = ObservationBuilder(
            intersections=self.intersections,
            districts=self.districts,
            config=self.env_config.observation,
        )
        self.reward_calculator = RewardCalculator(self.env_config.reward)
        self.adapter = CityFlowAdapter(
            config_path=self.original_config_path,
            thread_num=self.env_config.thread_num,
        )

        # Horizon: explicit override first, otherwise the config's "step" value.
        # If neither is present this resolves to 0 and ``step`` reports
        # ``done`` as soon as the simulator time is >= 0.
        config_payload = load_json(self.original_config_path)
        self.max_episode_seconds = int(
            self.env_config.max_episode_seconds
            or config_payload.get("step", 0)
        )
        self.metadata = load_json(self.metadata_path) if self.metadata_path else {}

        self._district_type_labels = tuple(
            self.intersections[intersection_id].district_type
            for intersection_id in self.controlled_intersection_ids
        )
        # Clamped to at least 1 per intersection.
        # NOTE(review): presumably so the reward code can divide by it — confirm.
        self._incoming_lane_counts = np.asarray(
            [
                max(1, len(self.intersections[intersection_id].incoming_lanes))
                for intersection_id in self.controlled_intersection_ids
            ],
            dtype=np.float32,
        )

        # Per-episode state; the two phase dicts are populated by ``reset``.
        self.current_phase_positions: dict[str, int] = {}
        self.phase_elapsed_times: dict[str, int] = {}
        self.decision_step_count = 0
        self.episode_return = 0.0
        self.total_episode_return = 0.0
        self.last_info: dict[str, Any] = {}
        self.reward_component_sums: dict[str, float] = {}

    @property
    def observation_dim(self) -> int:
        """Observation dimensionality as reported by the observation builder."""
        return self.observation_builder.observation_dim

    def reset(self, seed: int | None = None) -> dict[str, Any]:
        """Restart the simulator and return the initial observation.

        Every controlled intersection is put on its first green phase
        (position 0) with an elapsed time of 0, the episode counters are
        cleared, and the reward calculator is re-baselined on the initial
        observation.

        Args:
            seed: Accepted but ignored; it is discarded without being used.

        Returns:
            The observation dict built right after the reset.
        """
        del seed  # not forwarded to the adapter or any other component
        self.adapter.reset()
        self.decision_step_count = 0
        self.episode_return = 0.0
        self.total_episode_return = 0.0
        self.reward_component_sums = {}
        self.current_phase_positions = {}
        self.phase_elapsed_times = {}

        for intersection_id in self.controlled_intersection_ids:
            config = self.intersections[intersection_id]
            initial_position = 0
            initial_phase = config.green_phases[initial_position].engine_phase_index
            self.current_phase_positions[intersection_id] = initial_position
            self.phase_elapsed_times[intersection_id] = 0
            self.adapter.set_tl_phase(intersection_id, initial_phase)

        observation = self._build_observation()
        self.reward_calculator.reset(
            incoming_waiting=observation["incoming_waiting"],
            incoming_counts=observation["incoming_counts"],
            incoming_lane_counts=self._incoming_lane_counts,
            finished_vehicle_count=self.adapter.get_finished_vehicle_count(),
        )
        self.last_info = self._build_info(
            rewards=np.zeros(len(self.controlled_intersection_ids), dtype=np.float32),
            avg_incoming_counts=observation["incoming_counts"],
            avg_incoming_waiting=observation["incoming_waiting"],
            reward_components={},
        )
        return observation

    def step(
        self,
        actions: dict[str, int] | list[int] | np.ndarray,
    ) -> tuple[dict[str, Any], np.ndarray, bool, dict[str, Any]]:
        """Apply one decision per intersection and advance the simulation.

        Args:
            actions: Either a mapping from intersection id to action, or a
                sequence/array with one action per controlled intersection in
                ``controlled_intersection_ids`` order.

        Returns:
            ``(observation, rewards, done, info)`` where ``rewards`` is the
            array taken from the reward breakdown and ``done`` is true once
            the simulator time reaches ``max_episode_seconds``.

        Raises:
            RuntimeError: If called before :meth:`reset`.
            ValueError: If a sequence/array action has the wrong shape.
        """
        # ``reset`` fills the phase dicts; without it the lookups below would
        # fail with an uninformative KeyError.
        if not self.phase_elapsed_times:
            raise RuntimeError("TrafficEnv.step() called before reset().")

        normalized_actions = self._normalize_actions(actions)
        self._apply_actions(normalized_actions)
        avg_incoming_counts, avg_incoming_waiting, avg_outgoing_counts = self._advance_simulator()

        reward_breakdown = self.reward_calculator.compute_breakdown(
            incoming_waiting=avg_incoming_waiting,
            incoming_counts=avg_incoming_counts,
            outgoing_counts=avg_outgoing_counts,
            incoming_lane_counts=self._incoming_lane_counts,
            finished_vehicle_count=self.adapter.get_finished_vehicle_count(),
        )
        rewards = reward_breakdown.reward

        # The step counter must be bumped before the running mean is refreshed,
        # because the mean divides by the number of completed steps.
        self.decision_step_count += 1
        self.total_episode_return += float(rewards.sum())
        self.episode_return = self._mean_step_intersection_reward()
        self._accumulate_reward_components(reward_breakdown.components)

        observation = self._build_observation()
        done = self.adapter.get_current_time() >= self.max_episode_seconds
        info = self._build_info(
            rewards=rewards,
            avg_incoming_counts=avg_incoming_counts,
            avg_incoming_waiting=avg_incoming_waiting,
            reward_components=reward_breakdown.components,
        )
        self.last_info = info
        return observation, rewards, done, info

    def _can_switch(self, intersection_id: str) -> bool:
        """Return whether the intersection's current phase has met ``min_green_time``."""
        return self.phase_elapsed_times[intersection_id] >= self.env_config.min_green_time

    def _build_observation(self) -> dict[str, Any]:
        """Build the observation dict from current lane counts and phase state.

        The builder output is extended with ``city_id``, ``scenario_name``,
        ``decision_step`` and ``sim_time``.
        """
        lane_vehicle_count = self.adapter.get_lane_vehicle_count()
        lane_waiting_count = self.adapter.get_lane_waiting_vehicle_count()
        switch_allowed = {
            intersection_id: self._can_switch(intersection_id)
            for intersection_id in self.controlled_intersection_ids
        }
        observation = self.observation_builder.build(
            lane_vehicle_count=lane_vehicle_count,
            lane_waiting_count=lane_waiting_count,
            phase_positions=self.current_phase_positions,
            phase_elapsed_times=self.phase_elapsed_times,
            switch_allowed=switch_allowed,
        )
        observation["city_id"] = self.city_id
        observation["scenario_name"] = self.scenario_name
        observation["decision_step"] = self.decision_step_count
        observation["sim_time"] = self.adapter.get_current_time()
        return observation

    def _apply_actions(self, actions: np.ndarray) -> None:
        """Send each intersection's phase to the simulator for this decision.

        An action of ``1`` advances to the next green phase (wrapping around)
        when the minimum green time has been met; it resets that
        intersection's elapsed time to 0. Every other case re-sends the
        current phase unchanged.
        """
        for action_index, intersection_id in enumerate(self.controlled_intersection_ids):
            config = self.intersections[intersection_id]
            current_position = self.current_phase_positions[intersection_id]
            wants_switch = int(actions[action_index]) == 1

            if wants_switch and self._can_switch(intersection_id):
                next_position = (current_position + 1) % config.num_green_phases
                engine_phase = config.green_phases[next_position].engine_phase_index
                self.adapter.set_tl_phase(intersection_id, engine_phase)
                self.current_phase_positions[intersection_id] = next_position
                self.phase_elapsed_times[intersection_id] = 0
            else:
                # The held phase is pushed again on every decision.
                # NOTE(review): presumably this keeps the engine from running
                # its own signal plan — confirm against the adapter.
                held_phase = config.green_phases[current_position].engine_phase_index
                self.adapter.set_tl_phase(intersection_id, held_phase)

    def _advance_simulator(self) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Run ``decision_interval`` simulator ticks and average the lane counts.

        Returns:
            ``(avg_incoming_counts, avg_incoming_waiting, avg_outgoing_counts)``,
            each a float32 array of shape
            ``(num_intersections, max_incoming_lanes)``. Lanes missing from the
            adapter's count dicts contribute 0. Outgoing lanes are capped with
            the same ``max_incoming_lanes`` limit as incoming lanes.
        """
        ticks = self.env_config.decision_interval
        tick_seconds = self.env_config.simulator_interval
        max_lanes = self.env_config.observation.max_incoming_lanes

        shape = (len(self.controlled_intersection_ids), max_lanes)
        avg_incoming_counts = np.zeros(shape, dtype=np.float32)
        avg_incoming_waiting = np.zeros(shape, dtype=np.float32)
        avg_outgoing_counts = np.zeros(shape, dtype=np.float32)

        # The truncated lane lists do not change between ticks, so slice once.
        lane_plan = [
            (
                intersection_id,
                self.intersections[intersection_id].incoming_lanes[:max_lanes],
                self.intersections[intersection_id].outgoing_lanes[:max_lanes],
            )
            for intersection_id in self.controlled_intersection_ids
        ]

        for _ in range(ticks):
            self.adapter.step()
            lane_vehicle_count = self.adapter.get_lane_vehicle_count()
            lane_waiting_count = self.adapter.get_lane_waiting_vehicle_count()
            for row_index, (intersection_id, incoming, outgoing) in enumerate(lane_plan):
                for lane_index, lane_id in enumerate(incoming):
                    avg_incoming_counts[row_index, lane_index] += float(
                        lane_vehicle_count.get(lane_id, 0)
                    )
                    avg_incoming_waiting[row_index, lane_index] += float(
                        lane_waiting_count.get(lane_id, 0)
                    )
                for lane_index, lane_id in enumerate(outgoing):
                    avg_outgoing_counts[row_index, lane_index] += float(
                        lane_vehicle_count.get(lane_id, 0)
                    )
                self.phase_elapsed_times[intersection_id] += tick_seconds

        # Turn the per-tick sums into per-tick means.
        divisor = float(ticks)
        avg_incoming_counts /= divisor
        avg_incoming_waiting /= divisor
        avg_outgoing_counts /= divisor
        return avg_incoming_counts, avg_incoming_waiting, avg_outgoing_counts

    def _build_info(
        self,
        rewards: np.ndarray,
        avg_incoming_counts: np.ndarray,
        avg_incoming_waiting: np.ndarray,
        reward_components: dict[str, np.ndarray],
    ) -> dict[str, Any]:
        """Assemble the info dict (identifiers, episode totals and metrics).

        Args:
            rewards: Per-intersection rewards for the latest step.
            avg_incoming_counts: Per-intersection, per-lane incoming vehicle counts.
            avg_incoming_waiting: Per-intersection, per-lane waiting vehicle counts.
            reward_components: Named reward component arrays for the latest step.

        Returns:
            A dict whose ``"metrics"`` entry also contains the reward-component
            and per-district-type metrics.
        """
        mean_reward = float(rewards.mean()) if rewards.size else 0.0
        average_travel_time = self.adapter.get_average_travel_time()
        info = {
            "city_id": self.city_id,
            "scenario_name": self.scenario_name,
            "decision_step": self.decision_step_count,
            "sim_time": self.adapter.get_current_time(),
            "episode_return": float(self.episode_return),
            "total_episode_return": float(self.total_episode_return),
            "intersection_ids": self.controlled_intersection_ids,
            "district_types": self._district_type_labels,
            "metrics": {
                "num_controlled_intersections": len(self.controlled_intersection_ids),
                "mean_reward": mean_reward,
                "mean_step_intersection_reward": self._mean_step_intersection_reward(),
                "mean_waiting_vehicles": float(avg_incoming_waiting.sum(axis=1).mean()),
                "mean_incoming_vehicles": float(avg_incoming_counts.sum(axis=1).mean()),
                "total_waiting_vehicles": float(avg_incoming_waiting.sum()),
                "total_incoming_vehicles": float(avg_incoming_counts.sum()),
                "running_vehicles": self.adapter.get_vehicle_count(),
                "throughput": self.adapter.get_finished_vehicle_count(),
                "average_travel_time": average_travel_time,
                "reward_variant": self.env_config.reward.variant,
            },
        }
        info["metrics"].update(self._reward_component_metrics(reward_components))
        info["metrics"].update(
            per_district_type_metrics(
                district_types=self._district_type_labels,
                rewards=rewards,
                avg_incoming_counts=avg_incoming_counts,
                avg_incoming_waiting=avg_incoming_waiting,
            )
        )
        return info

    def _normalize_actions(
        self,
        actions: dict[str, int] | list[int] | np.ndarray,
    ) -> np.ndarray:
        """Convert any supported action container into an int64 vector.

        For a dict, intersections without an entry default to action 0 and
        keys that are not controlled intersections are ignored.

        Raises:
            ValueError: If a sequence/array does not have exactly one entry
                per controlled intersection.
        """
        if isinstance(actions, dict):
            return np.asarray(
                [int(actions.get(intersection_id, 0)) for intersection_id in self.controlled_intersection_ids],
                dtype=np.int64,
            )
        array = np.asarray(actions, dtype=np.int64)
        if array.shape != (len(self.controlled_intersection_ids),):
            raise ValueError(
                "Actions must provide exactly one action per controlled intersection."
            )
        return array

    def _mean_step_intersection_reward(self) -> float:
        """Return the episode return averaged over steps and intersections.

        The denominator is clamped to 1 so the value is defined before the
        first step.
        """
        denominator = max(
            1,
            self.decision_step_count * len(self.controlled_intersection_ids),
        )
        return float(self.total_episode_return) / float(denominator)

    def _accumulate_reward_components(self, components: dict[str, np.ndarray]) -> None:
        """Add each component's mean (taken in float32) to its running episode sum."""
        for name, values in components.items():
            step_mean = float(np.asarray(values, dtype=np.float32).mean())
            self.reward_component_sums[name] = self.reward_component_sums.get(name, 0.0) + step_mean

    def _reward_component_metrics(
        self,
        reward_components: dict[str, np.ndarray],
    ) -> dict[str, float]:
        """Return per-step and per-episode mean metrics for the reward components.

        ``reward_component_step_<name>`` is the mean of the given step values.
        ``reward_component_mean_<name>`` is the accumulated sum divided by the
        number of decision steps, and is only emitted once at least one step
        has been taken.
        """
        metrics: dict[str, float] = {
            f"reward_component_step_{name}": float(np.asarray(values, dtype=np.float32).mean())
            for name, values in reward_components.items()
        }
        if self.decision_step_count <= 0:
            return metrics
        for name, total in self.reward_component_sums.items():
            metrics[f"reward_component_mean_{name}"] = float(total) / float(
                self.decision_step_count
            )
        return metrics
def per_district_type_metrics(
    district_types: tuple[str, ...],
    rewards: np.ndarray,
    avg_incoming_counts: np.ndarray,
    avg_incoming_waiting: np.ndarray,
) -> dict[str, float]:
    """Aggregate reward and traffic metrics per district type.

    Args:
        district_types: One district-type label per intersection, aligned with
            the rows of the other arguments.
        rewards: Per-intersection rewards (converted to float32).
        avg_incoming_counts: 2-D array-like of incoming vehicle counts; rows
            are intersections and are summed across axis 1.
        avg_incoming_waiting: 2-D array-like of waiting vehicle counts; rows
            are intersections and are summed across axis 1.

    Returns:
        A dict containing, for each distinct label ``<t>`` in sorted order,
        ``num_<t>_intersections``, ``mean_reward_<t>``,
        ``mean_waiting_vehicles_<t>`` and ``mean_incoming_vehicles_<t>``.
        Empty when ``district_types`` is empty.
    """
    metrics: dict[str, float] = {}
    reward_vector = np.asarray(rewards, dtype=np.float32)
    incoming_totals = np.asarray(avg_incoming_counts).sum(axis=1)
    waiting_totals = np.asarray(avg_incoming_waiting).sum(axis=1)

    # Object dtype keeps plain Python equality semantics for the labels while
    # letting each per-type mask be produced by a single array comparison.
    labels = np.asarray(district_types, dtype=object)

    # Every label iterated here occurs in ``labels``, so each mask selects at
    # least one row and needs no emptiness check.
    for district_type in sorted(set(district_types)):
        mask = labels == district_type
        metrics[f"num_{district_type}_intersections"] = float(mask.sum())
        metrics[f"mean_reward_{district_type}"] = float(reward_vector[mask].mean())
        metrics[f"mean_waiting_vehicles_{district_type}"] = float(waiting_totals[mask].mean())
        metrics[f"mean_incoming_vehicles_{district_type}"] = float(incoming_totals[mask].mean())
    return metrics