# traffic-visualizer / env/observation_builder.py
# Uploaded by tokev with the upload-large-folder tool (commit 5893134, verified).
from __future__ import annotations
from dataclasses import dataclass
import numpy as np
from env.intersection_config import DistrictConfig, IntersectionConfig
from env.utils import normalize_scalar
@dataclass(eq=True, frozen=True)
class ObservationConfig:
    """Immutable settings for observation construction.

    The fields control how many per-lane slots each observation row has, how raw
    counts and times are scaled, and which optional feature groups are appended.
    """

    # Number of per-lane slots per intersection; incoming lanes beyond this are dropped.
    max_incoming_lanes: int = 16
    # Divisor applied to vehicle / waiting counts before they enter the observation.
    count_scale: float = 20.0
    # Scale used to normalize the time spent in the current phase.
    elapsed_time_scale: float = 60.0
    # When False the outgoing-congestion feature is left at 0 (its slot is still present).
    include_outgoing_congestion: bool = True
    # Append district-level mean vehicle / waiting counts to each row.
    include_district_context: bool = True
    # Append a one-hot encoding of the intersection's district type to each row.
    include_district_type_feature: bool = True
class ObservationBuilder:
    """Assemble per-intersection observation arrays for the signal controllers.

    Each call to :meth:`build` produces one row per intersection, ordered by
    sorted intersection id. Every row of ``observations`` is laid out as::

        [ lane counts       (max_incoming_lanes)   ]  count / count_scale
        [ lane waiting      (max_incoming_lanes)   ]  waiting / count_scale
        [ lane mask         (max_incoming_lanes)   ]  1.0 for a real lane slot
        [ meta features     (META_FEATURE_DIM)     ]  see ``_meta_features``
        [ district type     (NUM_DISTRICT_TYPES)   ]  one-hot, optional
        [ district context  (DISTRICT_CONTEXT_DIM) ]  district means, optional

    Incoming lanes beyond ``max_incoming_lanes`` are dropped from the per-lane
    features.
    """

    # Width of the one-hot district-type block appended to each row.
    NUM_DISTRICT_TYPES = 4
    # Number of scalar features produced by ``_meta_features``.
    META_FEATURE_DIM = 7
    # District context values: mean vehicle count and mean waiting count.
    DISTRICT_CONTEXT_DIM = 2

    def __init__(
        self,
        intersections: dict[str, IntersectionConfig],
        districts: dict[str, DistrictConfig],
        config: ObservationConfig | None = None,
    ):
        """Store the topology and precompute lookups and the observation width.

        Args:
            intersections: Intersection configs keyed by intersection id.
            districts: District configs keyed by district id.
            config: Feature/normalization settings; defaults to ``ObservationConfig()``.
        """
        self.intersections = intersections
        self.districts = districts
        self.config = config or ObservationConfig()
        # Sorted so that row order is deterministic across calls.
        self.intersection_ids = tuple(sorted(intersections))
        # NOTE(review): not read anywhere in this class; may be used externally.
        self._district_lookup = {
            intersection_id: intersections[intersection_id].district_id
            for intersection_id in self.intersection_ids
        }
        # Clamped to 1 so an empty district never divides by zero.
        self._district_sizes = {
            district_id: max(1, len(district.intersection_ids))
            for district_id, district in districts.items()
        }
        self.observation_dim = self._compute_observation_dim()

    def build(
        self,
        lane_vehicle_count: dict[str, int],
        lane_waiting_count: dict[str, int],
        phase_positions: dict[str, int],
        phase_elapsed_times: dict[str, int],
        switch_allowed: dict[str, bool],
    ) -> dict[str, np.ndarray | tuple[str, ...]]:
        """Build the batched observation dict for every intersection.

        Args:
            lane_vehicle_count: Vehicle count per lane id; missing lanes count as 0.
            lane_waiting_count: Waiting-vehicle count per lane id; missing lanes count as 0.
            phase_positions: Current green-phase index per intersection id.
            phase_elapsed_times: Time spent in the current phase per intersection id.
            switch_allowed: Whether each intersection may switch now. When falsy,
                column 1 of that intersection's ``action_mask`` row is set to 0.

        Returns:
            A dict of arrays with one leading row per intersection, in
            ``self.intersection_ids`` order. It also contains tuples of
            intersection ids, district ids and district types.

        Raises:
            KeyError: If an intersection id is missing from ``phase_positions``,
                ``phase_elapsed_times`` or ``switch_allowed``.
            ValueError: If the district-type feature is enabled and an
                intersection's ``district_type_index`` is outside
                ``[0, NUM_DISTRICT_TYPES)``.
        """
        district_context = self._compute_district_context(
            lane_vehicle_count=lane_vehicle_count,
            lane_waiting_count=lane_waiting_count,
        )
        num_intersections = len(self.intersection_ids)
        max_lanes = self.config.max_incoming_lanes
        count_scale = self.config.count_scale

        observations = np.zeros(
            (num_intersections, self.observation_dim),
            dtype=np.float32,
        )
        incoming_counts = np.zeros((num_intersections, max_lanes), dtype=np.float32)
        incoming_waiting = np.zeros((num_intersections, max_lanes), dtype=np.float32)
        lane_mask = np.zeros((num_intersections, max_lanes), dtype=np.float32)
        # Column 1 is zeroed below for intersections that may not switch.
        action_mask = np.ones((num_intersections, 2), dtype=np.float32)
        current_phase = np.zeros(num_intersections, dtype=np.int64)
        phase_elapsed = np.zeros(num_intersections, dtype=np.float32)
        outgoing_congestion = np.zeros(num_intersections, dtype=np.float32)
        district_type_indices = np.zeros(num_intersections, dtype=np.int64)
        boundary_mask = np.zeros(num_intersections, dtype=np.float32)

        for row_index, intersection_id in enumerate(self.intersection_ids):
            config = self.intersections[intersection_id]
            lane_count_vector, waiting_vector, mask_vector = self._lane_vectors(
                config=config,
                lane_vehicle_count=lane_vehicle_count,
                lane_waiting_count=lane_waiting_count,
            )
            incoming_counts[row_index] = lane_count_vector
            incoming_waiting[row_index] = waiting_vector
            lane_mask[row_index] = mask_vector

            phase_index = int(phase_positions[intersection_id])
            phase_time = float(phase_elapsed_times[intersection_id])
            current_phase[row_index] = phase_index
            phase_elapsed[row_index] = phase_time
            district_type_indices[row_index] = config.district_type_index
            boundary_mask[row_index] = 1.0 if config.is_boundary else 0.0
            if self.config.include_outgoing_congestion:
                outgoing_congestion[row_index] = self._mean_outgoing_congestion(
                    config=config,
                    lane_vehicle_count=lane_vehicle_count,
                )
            can_switch = bool(switch_allowed[intersection_id])

            # ``row`` is a view, so slice assignments write into ``observations``.
            row = observations[row_index]
            next_col = 0
            row[next_col : next_col + max_lanes] = lane_count_vector / count_scale
            next_col += max_lanes
            row[next_col : next_col + max_lanes] = waiting_vector / count_scale
            next_col += max_lanes
            row[next_col : next_col + max_lanes] = mask_vector
            next_col += max_lanes

            meta_features = self._meta_features(
                phase_index=phase_index,
                phase_time=phase_time,
                phase_count=max(1, config.num_green_phases),
                outgoing=float(outgoing_congestion[row_index]),
                total_incoming=float(lane_count_vector.sum()),
                can_switch=can_switch,
                is_boundary=boundary_mask[row_index],
            )
            row[next_col : next_col + len(meta_features)] = meta_features
            next_col += len(meta_features)

            if self.config.include_district_type_feature:
                row[next_col + self._district_type_slot(config)] = 1.0
                next_col += self.NUM_DISTRICT_TYPES
            if self.config.include_district_context:
                district_values = district_context.get(
                    config.district_id,
                    (0.0,) * self.DISTRICT_CONTEXT_DIM,
                )
                row[next_col : next_col + len(district_values)] = district_values

            if not can_switch:
                action_mask[row_index, 1] = 0.0

        return {
            "observations": observations,
            "incoming_counts": incoming_counts,
            "incoming_waiting": incoming_waiting,
            "lane_mask": lane_mask,
            "action_mask": action_mask,
            "current_phase": current_phase,
            "phase_elapsed": phase_elapsed,
            "outgoing_congestion": outgoing_congestion,
            "boundary_mask": boundary_mask,
            "district_type_indices": district_type_indices,
            "district_types": tuple(
                self.intersections[intersection_id].district_type
                for intersection_id in self.intersection_ids
            ),
            "district_ids": tuple(
                self.intersections[intersection_id].district_id
                for intersection_id in self.intersection_ids
            ),
            "intersection_ids": self.intersection_ids,
        }

    def _meta_features(
        self,
        phase_index: int,
        phase_time: float,
        phase_count: int,
        outgoing: float,
        total_incoming: float,
        can_switch: bool,
        is_boundary: float,
    ) -> list[float]:
        """Return the ``META_FEATURE_DIM`` scalar features for one intersection.

        The order is: phase index, phase elapsed time, mean outgoing
        congestion, total incoming count, number of green phases, the
        switch-allowed flag, and the boundary flag.
        """
        count_scale = self.config.count_scale
        return [
            # Scaled by the largest valid phase index. A single-phase
            # intersection carries no phase information.
            normalize_scalar(phase_index, phase_count - 1) if phase_count > 1 else 0.0,
            normalize_scalar(phase_time, self.config.elapsed_time_scale),
            normalize_scalar(outgoing, count_scale),
            normalize_scalar(total_incoming, count_scale),
            # NOTE(review): 4.0 presumably is a typical max number of green phases; confirm.
            normalize_scalar(float(phase_count), 4.0),
            1.0 if can_switch else 0.0,
            is_boundary,
        ]

    def _district_type_slot(self, config: IntersectionConfig) -> int:
        """Return the intersection's district-type index after a range check.

        Without this check, an out-of-range index would silently write into the
        district-context columns. A negative index would wrap around to the end
        of the row.

        Raises:
            ValueError: If the index is outside ``[0, NUM_DISTRICT_TYPES)``.
        """
        index = int(config.district_type_index)
        if not 0 <= index < self.NUM_DISTRICT_TYPES:
            raise ValueError(
                f"district_type_index must be in [0, {self.NUM_DISTRICT_TYPES}), got {index}"
            )
        return index

    def _compute_observation_dim(self) -> int:
        """Return the width of one observation row for the current config."""
        dim = self.config.max_incoming_lanes * 3 + self.META_FEATURE_DIM
        if self.config.include_district_type_feature:
            dim += self.NUM_DISTRICT_TYPES
        if self.config.include_district_context:
            dim += self.DISTRICT_CONTEXT_DIM
        return dim

    def _lane_vectors(
        self,
        config: IntersectionConfig,
        lane_vehicle_count: dict[str, int],
        lane_waiting_count: dict[str, int],
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Return fixed-width (count, waiting, mask) vectors for the incoming lanes.

        Only the first ``max_incoming_lanes`` incoming lanes are used. Missing
        lane ids count as 0. Unused slots stay 0 and have a mask value of 0.
        """
        max_lanes = self.config.max_incoming_lanes
        count_vector = np.zeros(max_lanes, dtype=np.float32)
        waiting_vector = np.zeros(max_lanes, dtype=np.float32)
        mask_vector = np.zeros(max_lanes, dtype=np.float32)
        for lane_index, lane_id in enumerate(config.incoming_lanes[:max_lanes]):
            count_vector[lane_index] = float(lane_vehicle_count.get(lane_id, 0))
            waiting_vector[lane_index] = float(lane_waiting_count.get(lane_id, 0))
            mask_vector[lane_index] = 1.0
        return count_vector, waiting_vector, mask_vector

    def _mean_outgoing_congestion(
        self,
        config: IntersectionConfig,
        lane_vehicle_count: dict[str, int],
    ) -> float:
        """Return the mean vehicle count over the outgoing lanes.

        Returns 0.0 when the intersection has no outgoing lanes.
        """
        if not config.outgoing_lanes:
            return 0.0
        total = sum(float(lane_vehicle_count.get(lane_id, 0)) for lane_id in config.outgoing_lanes)
        return total / float(len(config.outgoing_lanes))

    def _compute_district_context(
        self,
        lane_vehicle_count: dict[str, int],
        lane_waiting_count: dict[str, int],
    ) -> dict[str, tuple[float, float]]:
        """Return normalized (mean count, mean waiting) per district id.

        Sums cover ALL incoming lanes of each member intersection, with no
        ``max_incoming_lanes`` truncation. The sums are divided by the number of
        intersections in the district. Returns an empty dict when district
        context is disabled.
        """
        context: dict[str, tuple[float, float]] = {}
        if not self.config.include_district_context:
            return context
        count_scale = self.config.count_scale
        for district_id, district in self.districts.items():
            total_count = 0.0
            total_waiting = 0.0
            for intersection_id in district.intersection_ids:
                lanes = self.intersections[intersection_id].incoming_lanes
                total_count += sum(float(lane_vehicle_count.get(lane_id, 0)) for lane_id in lanes)
                total_waiting += sum(float(lane_waiting_count.get(lane_id, 0)) for lane_id in lanes)
            size = float(self._district_sizes[district_id])
            context[district_id] = (
                normalize_scalar(total_count / size, count_scale),
                normalize_scalar(total_waiting / size, count_scale),
            )
        return context