Spaces:
Running
Running
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| import numpy as np | |
| from env.intersection_config import DistrictConfig, IntersectionConfig | |
| from env.utils import normalize_scalar | |
| class ObservationConfig: | |
| max_incoming_lanes: int = 16 | |
| count_scale: float = 20.0 | |
| elapsed_time_scale: float = 60.0 | |
| include_outgoing_congestion: bool = True | |
| include_district_context: bool = True | |
| include_district_type_feature: bool = True | |
| class ObservationBuilder: | |
| def __init__( | |
| self, | |
| intersections: dict[str, IntersectionConfig], | |
| districts: dict[str, DistrictConfig], | |
| config: ObservationConfig | None = None, | |
| ): | |
| self.intersections = intersections | |
| self.districts = districts | |
| self.config = config or ObservationConfig() | |
| self.intersection_ids = tuple(sorted(intersections)) | |
| self._district_lookup = { | |
| intersection_id: intersections[intersection_id].district_id | |
| for intersection_id in self.intersection_ids | |
| } | |
| self._district_sizes = { | |
| district_id: max(1, len(district.intersection_ids)) | |
| for district_id, district in districts.items() | |
| } | |
| self.observation_dim = self._compute_observation_dim() | |
| def build( | |
| self, | |
| lane_vehicle_count: dict[str, int], | |
| lane_waiting_count: dict[str, int], | |
| phase_positions: dict[str, int], | |
| phase_elapsed_times: dict[str, int], | |
| switch_allowed: dict[str, bool], | |
| ) -> dict[str, np.ndarray | tuple[str, ...]]: | |
| district_context = self._compute_district_context( | |
| lane_vehicle_count=lane_vehicle_count, | |
| lane_waiting_count=lane_waiting_count, | |
| ) | |
| num_intersections = len(self.intersection_ids) | |
| max_lanes = self.config.max_incoming_lanes | |
| observations = np.zeros( | |
| (num_intersections, self.observation_dim), | |
| dtype=np.float32, | |
| ) | |
| incoming_counts = np.zeros((num_intersections, max_lanes), dtype=np.float32) | |
| incoming_waiting = np.zeros((num_intersections, max_lanes), dtype=np.float32) | |
| lane_mask = np.zeros((num_intersections, max_lanes), dtype=np.float32) | |
| action_mask = np.ones((num_intersections, 2), dtype=np.float32) | |
| current_phase = np.zeros(num_intersections, dtype=np.int64) | |
| phase_elapsed = np.zeros(num_intersections, dtype=np.float32) | |
| outgoing_congestion = np.zeros(num_intersections, dtype=np.float32) | |
| district_type_indices = np.zeros(num_intersections, dtype=np.int64) | |
| boundary_mask = np.zeros(num_intersections, dtype=np.float32) | |
| for row_index, intersection_id in enumerate(self.intersection_ids): | |
| config = self.intersections[intersection_id] | |
| lane_count_vector, waiting_vector, mask_vector = self._lane_vectors( | |
| config=config, | |
| lane_vehicle_count=lane_vehicle_count, | |
| lane_waiting_count=lane_waiting_count, | |
| ) | |
| incoming_counts[row_index] = lane_count_vector | |
| incoming_waiting[row_index] = waiting_vector | |
| lane_mask[row_index] = mask_vector | |
| phase_index = int(phase_positions[intersection_id]) | |
| phase_time = float(phase_elapsed_times[intersection_id]) | |
| phase_count = max(1, config.num_green_phases) | |
| current_phase[row_index] = phase_index | |
| phase_elapsed[row_index] = phase_time | |
| district_type_indices[row_index] = config.district_type_index | |
| boundary_mask[row_index] = 1.0 if config.is_boundary else 0.0 | |
| next_col = 0 | |
| observations[row_index, next_col : next_col + max_lanes] = ( | |
| lane_count_vector / self.config.count_scale | |
| ) | |
| next_col += max_lanes | |
| observations[row_index, next_col : next_col + max_lanes] = ( | |
| waiting_vector / self.config.count_scale | |
| ) | |
| next_col += max_lanes | |
| observations[row_index, next_col : next_col + max_lanes] = mask_vector | |
| next_col += max_lanes | |
| if self.config.include_outgoing_congestion: | |
| outgoing_congestion[row_index] = self._mean_outgoing_congestion( | |
| config=config, | |
| lane_vehicle_count=lane_vehicle_count, | |
| ) | |
| meta_features = [ | |
| normalize_scalar(phase_index, max(1, phase_count - 1)) | |
| if phase_count > 1 | |
| else 0.0, | |
| normalize_scalar(phase_time, self.config.elapsed_time_scale), | |
| normalize_scalar(float(outgoing_congestion[row_index]), self.config.count_scale), | |
| normalize_scalar(float(lane_count_vector.sum()), self.config.count_scale), | |
| normalize_scalar(float(phase_count), 4.0), | |
| 1.0 if switch_allowed[intersection_id] else 0.0, | |
| boundary_mask[row_index], | |
| ] | |
| observations[row_index, next_col : next_col + len(meta_features)] = meta_features | |
| next_col += len(meta_features) | |
| if self.config.include_district_type_feature: | |
| observations[row_index, next_col + config.district_type_index] = 1.0 | |
| next_col += 4 | |
| if self.config.include_district_context: | |
| district_values = district_context.get( | |
| config.district_id, | |
| (0.0, 0.0), | |
| ) | |
| observations[row_index, next_col : next_col + len(district_values)] = district_values | |
| if not switch_allowed[intersection_id]: | |
| action_mask[row_index, 1] = 0.0 | |
| return { | |
| "observations": observations, | |
| "incoming_counts": incoming_counts, | |
| "incoming_waiting": incoming_waiting, | |
| "lane_mask": lane_mask, | |
| "action_mask": action_mask, | |
| "current_phase": current_phase, | |
| "phase_elapsed": phase_elapsed, | |
| "outgoing_congestion": outgoing_congestion, | |
| "boundary_mask": boundary_mask, | |
| "district_type_indices": district_type_indices, | |
| "district_types": tuple( | |
| self.intersections[intersection_id].district_type | |
| for intersection_id in self.intersection_ids | |
| ), | |
| "district_ids": tuple( | |
| self.intersections[intersection_id].district_id | |
| for intersection_id in self.intersection_ids | |
| ), | |
| "intersection_ids": self.intersection_ids, | |
| } | |
| def _compute_observation_dim(self) -> int: | |
| base_dim = self.config.max_incoming_lanes * 3 + 7 | |
| if self.config.include_district_type_feature: | |
| base_dim += 4 | |
| if self.config.include_district_context: | |
| base_dim += 2 | |
| return base_dim | |
| def _lane_vectors( | |
| self, | |
| config: IntersectionConfig, | |
| lane_vehicle_count: dict[str, int], | |
| lane_waiting_count: dict[str, int], | |
| ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: | |
| max_lanes = self.config.max_incoming_lanes | |
| count_vector = np.zeros(max_lanes, dtype=np.float32) | |
| waiting_vector = np.zeros(max_lanes, dtype=np.float32) | |
| mask_vector = np.zeros(max_lanes, dtype=np.float32) | |
| for lane_index, lane_id in enumerate(config.incoming_lanes[:max_lanes]): | |
| count_vector[lane_index] = float(lane_vehicle_count.get(lane_id, 0)) | |
| waiting_vector[lane_index] = float(lane_waiting_count.get(lane_id, 0)) | |
| mask_vector[lane_index] = 1.0 | |
| return count_vector, waiting_vector, mask_vector | |
| def _mean_outgoing_congestion( | |
| self, | |
| config: IntersectionConfig, | |
| lane_vehicle_count: dict[str, int], | |
| ) -> float: | |
| if not config.outgoing_lanes: | |
| return 0.0 | |
| total = sum(float(lane_vehicle_count.get(lane_id, 0)) for lane_id in config.outgoing_lanes) | |
| return total / float(len(config.outgoing_lanes)) | |
| def _compute_district_context( | |
| self, | |
| lane_vehicle_count: dict[str, int], | |
| lane_waiting_count: dict[str, int], | |
| ) -> dict[str, tuple[float, float]]: | |
| context: dict[str, tuple[float, float]] = {} | |
| if not self.config.include_district_context: | |
| return context | |
| for district_id, district in self.districts.items(): | |
| total_count = 0.0 | |
| total_waiting = 0.0 | |
| for intersection_id in district.intersection_ids: | |
| config = self.intersections[intersection_id] | |
| total_count += sum( | |
| float(lane_vehicle_count.get(lane_id, 0)) | |
| for lane_id in config.incoming_lanes | |
| ) | |
| total_waiting += sum( | |
| float(lane_waiting_count.get(lane_id, 0)) | |
| for lane_id in config.incoming_lanes | |
| ) | |
| size = float(self._district_sizes[district_id]) | |
| context[district_id] = ( | |
| normalize_scalar(total_count / size, self.config.count_scale), | |
| normalize_scalar(total_waiting / size, self.config.count_scale), | |
| ) | |
| return context | |