File size: 9,501 Bytes
5893134
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
from __future__ import annotations

from dataclasses import dataclass

import numpy as np

from env.intersection_config import DistrictConfig, IntersectionConfig
from env.utils import normalize_scalar


@dataclass(frozen=True)
class ObservationConfig:
    """Options controlling which features ``ObservationBuilder`` emits and how they are scaled.

    Invalid values are rejected at construction time instead of surfacing
    later as ``inf``/``NaN`` observations or as an opaque NumPy shape error.

    Raises:
        ValueError: If ``max_incoming_lanes`` is negative, or if
            ``count_scale`` / ``elapsed_time_scale`` is not strictly positive
            (NaN included).
    """

    # Number of per-lane slots per intersection; incoming lanes beyond this
    # many are dropped from the per-lane features.
    max_incoming_lanes: int = 16
    # Scale for vehicle / waiting count features (must be > 0).
    count_scale: float = 20.0
    # Scale for the phase elapsed-time feature (must be > 0).
    elapsed_time_scale: float = 60.0
    # Whether to compute the mean outgoing-lane vehicle count meta feature.
    include_outgoing_congestion: bool = True
    # Whether to append the per-district mean count / waiting context features.
    include_district_context: bool = True
    # Whether to append a one-hot encoding of the intersection's district type.
    include_district_type_feature: bool = True

    def __post_init__(self) -> None:
        if self.max_incoming_lanes < 0:
            raise ValueError(
                f"max_incoming_lanes must be >= 0, got {self.max_incoming_lanes}"
            )
        # ``not x > 0`` (rather than ``x <= 0``) also rejects NaN.
        if not self.count_scale > 0:
            raise ValueError(f"count_scale must be > 0, got {self.count_scale}")
        if not self.elapsed_time_scale > 0:
            raise ValueError(
                f"elapsed_time_scale must be > 0, got {self.elapsed_time_scale}"
            )


class ObservationBuilder:
    """Assemble per-intersection observation arrays from lane-level traffic state.

    Rows are ordered by sorted intersection id. Each row of the
    ``observations`` matrix is laid out as::

        lane vehicle counts / count_scale   max_incoming_lanes columns
        lane waiting counts / count_scale   max_incoming_lanes columns
        incoming-lane mask (1 = present)    max_incoming_lanes columns
        meta features                       NUM_META_FEATURES columns
        district-type one-hot               NUM_DISTRICT_TYPES columns (optional)
        district context                    NUM_DISTRICT_CONTEXT_FEATURES columns (optional)
    """

    # Width of the district-type one-hot block; valid ``district_type_index``
    # values are ``0 .. NUM_DISTRICT_TYPES - 1``.
    NUM_DISTRICT_TYPES = 4
    # Scalar features written per row by ``build``: phase position, elapsed
    # time, outgoing congestion, total incoming count, phase count, switch
    # flag, boundary flag.
    NUM_META_FEATURES = 7
    # Per-district mean incoming vehicle count and mean waiting count.
    NUM_DISTRICT_CONTEXT_FEATURES = 2
    # Scale passed to ``normalize_scalar`` for the number of green phases.
    _PHASE_COUNT_SCALE = 4.0

    def __init__(
        self,
        intersections: dict[str, IntersectionConfig],
        districts: dict[str, DistrictConfig],
        config: ObservationConfig | None = None,
    ):
        """Precompute row ordering, district sizes and the observation width.

        Args:
            intersections: Intersection configs keyed by intersection id.
            districts: District configs keyed by district id.
            config: Feature options; ``ObservationConfig()`` when omitted.

        Raises:
            ValueError: If the district-type one-hot is enabled and an
                intersection's ``district_type_index`` lies outside
                ``[0, NUM_DISTRICT_TYPES)``. Such an index would otherwise
                write into neighbouring feature columns or past the row end.
        """
        self.intersections = intersections
        self.districts = districts
        self.config = config or ObservationConfig()
        self.intersection_ids = tuple(sorted(intersections))
        self._district_lookup = {
            intersection_id: intersections[intersection_id].district_id
            for intersection_id in self.intersection_ids
        }
        # Clamp to 1 so empty districts do not divide by zero when averaging.
        self._district_sizes = {
            district_id: max(1, len(district.intersection_ids))
            for district_id, district in districts.items()
        }
        if self.config.include_district_type_feature:
            self._validate_district_type_indices()
        self.observation_dim = self._compute_observation_dim()

    def _validate_district_type_indices(self) -> None:
        """Raise ``ValueError`` if any district type index falls outside the one-hot block."""
        for intersection_id in self.intersection_ids:
            type_index = self.intersections[intersection_id].district_type_index
            if not 0 <= type_index < self.NUM_DISTRICT_TYPES:
                raise ValueError(
                    f"Intersection {intersection_id!r} has district_type_index "
                    f"{type_index}; expected 0 <= index < {self.NUM_DISTRICT_TYPES}"
                )

    def build(
        self,
        lane_vehicle_count: dict[str, int],
        lane_waiting_count: dict[str, int],
        phase_positions: dict[str, int],
        phase_elapsed_times: dict[str, int],
        switch_allowed: dict[str, bool],
    ) -> dict[str, np.ndarray | tuple[str, ...]]:
        """Build observation arrays for every intersection.

        Args:
            lane_vehicle_count: Vehicle count per lane id; missing lanes count as 0.
            lane_waiting_count: Waiting-vehicle count per lane id; missing lanes count as 0.
            phase_positions: Current phase position per intersection id.
            phase_elapsed_times: Time spent in the current phase per intersection id.
            switch_allowed: Whether a phase switch is currently permitted per
                intersection id.

        Returns:
            A dict of arrays with one row per intersection (ordered as
            ``self.intersection_ids``). It contains the flat ``observations``
            matrix, the raw per-lane arrays and masks, and metadata tuples.
            ``action_mask`` has two columns. Column 1 is zeroed when switching
            is not allowed; presumably column 1 is the "switch" action.

        Raises:
            KeyError: If an intersection id is missing from ``phase_positions``,
                ``phase_elapsed_times`` or ``switch_allowed``.
        """
        district_context = self._compute_district_context(
            lane_vehicle_count=lane_vehicle_count,
            lane_waiting_count=lane_waiting_count,
        )

        num_intersections = len(self.intersection_ids)
        max_lanes = self.config.max_incoming_lanes
        count_scale = self.config.count_scale

        observations = np.zeros(
            (num_intersections, self.observation_dim),
            dtype=np.float32,
        )
        incoming_counts = np.zeros((num_intersections, max_lanes), dtype=np.float32)
        incoming_waiting = np.zeros((num_intersections, max_lanes), dtype=np.float32)
        lane_mask = np.zeros((num_intersections, max_lanes), dtype=np.float32)
        action_mask = np.ones((num_intersections, 2), dtype=np.float32)
        current_phase = np.zeros(num_intersections, dtype=np.int64)
        phase_elapsed = np.zeros(num_intersections, dtype=np.float32)
        outgoing_congestion = np.zeros(num_intersections, dtype=np.float32)
        district_type_indices = np.zeros(num_intersections, dtype=np.int64)
        boundary_mask = np.zeros(num_intersections, dtype=np.float32)

        for row_index, intersection_id in enumerate(self.intersection_ids):
            config = self.intersections[intersection_id]
            lane_count_vector, waiting_vector, mask_vector = self._lane_vectors(
                config=config,
                lane_vehicle_count=lane_vehicle_count,
                lane_waiting_count=lane_waiting_count,
            )
            incoming_counts[row_index] = lane_count_vector
            incoming_waiting[row_index] = waiting_vector
            lane_mask[row_index] = mask_vector

            phase_index = int(phase_positions[intersection_id])
            phase_time = float(phase_elapsed_times[intersection_id])
            phase_count = max(1, config.num_green_phases)
            can_switch = bool(switch_allowed[intersection_id])
            current_phase[row_index] = phase_index
            phase_elapsed[row_index] = phase_time
            district_type_indices[row_index] = config.district_type_index
            boundary_mask[row_index] = 1.0 if config.is_boundary else 0.0

            # Per-lane blocks: scaled counts, scaled waiting counts, lane mask.
            next_col = 0
            observations[row_index, next_col : next_col + max_lanes] = (
                lane_count_vector / count_scale
            )
            next_col += max_lanes
            observations[row_index, next_col : next_col + max_lanes] = (
                waiting_vector / count_scale
            )
            next_col += max_lanes
            observations[row_index, next_col : next_col + max_lanes] = mask_vector
            next_col += max_lanes

            if self.config.include_outgoing_congestion:
                outgoing_congestion[row_index] = self._mean_outgoing_congestion(
                    config=config,
                    lane_vehicle_count=lane_vehicle_count,
                )

            # Order must match NUM_META_FEATURES and its description above.
            meta_features = [
                # Single-phase intersections have no meaningful position.
                normalize_scalar(phase_index, phase_count - 1)
                if phase_count > 1
                else 0.0,
                normalize_scalar(phase_time, self.config.elapsed_time_scale),
                normalize_scalar(float(outgoing_congestion[row_index]), count_scale),
                # Sum covers only the first max_incoming_lanes lanes.
                normalize_scalar(float(lane_count_vector.sum()), count_scale),
                normalize_scalar(float(phase_count), self._PHASE_COUNT_SCALE),
                1.0 if can_switch else 0.0,
                boundary_mask[row_index],
            ]
            observations[row_index, next_col : next_col + len(meta_features)] = meta_features
            next_col += len(meta_features)

            if self.config.include_district_type_feature:
                # Index range is validated in __init__, so the write stays
                # inside the one-hot block.
                observations[row_index, next_col + config.district_type_index] = 1.0
                next_col += self.NUM_DISTRICT_TYPES

            if self.config.include_district_context:
                district_values = district_context.get(
                    config.district_id,
                    (0.0, 0.0),
                )
                observations[row_index, next_col : next_col + len(district_values)] = district_values

            if not can_switch:
                action_mask[row_index, 1] = 0.0

        return {
            "observations": observations,
            "incoming_counts": incoming_counts,
            "incoming_waiting": incoming_waiting,
            "lane_mask": lane_mask,
            "action_mask": action_mask,
            "current_phase": current_phase,
            "phase_elapsed": phase_elapsed,
            "outgoing_congestion": outgoing_congestion,
            "boundary_mask": boundary_mask,
            "district_type_indices": district_type_indices,
            "district_types": tuple(
                self.intersections[intersection_id].district_type
                for intersection_id in self.intersection_ids
            ),
            "district_ids": tuple(
                self.intersections[intersection_id].district_id
                for intersection_id in self.intersection_ids
            ),
            "intersection_ids": self.intersection_ids,
        }

    def _compute_observation_dim(self) -> int:
        """Return the width of one ``observations`` row for the current config."""
        base_dim = self.config.max_incoming_lanes * 3 + self.NUM_META_FEATURES
        if self.config.include_district_type_feature:
            base_dim += self.NUM_DISTRICT_TYPES
        if self.config.include_district_context:
            base_dim += self.NUM_DISTRICT_CONTEXT_FEATURES
        return base_dim

    def _lane_vectors(
        self,
        config: IntersectionConfig,
        lane_vehicle_count: dict[str, int],
        lane_waiting_count: dict[str, int],
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Return fixed-width (count, waiting, mask) float32 vectors for one intersection.

        Only the first ``max_incoming_lanes`` incoming lanes are encoded; the
        rest are dropped. Unused slots stay zero with mask 0.
        """
        max_lanes = self.config.max_incoming_lanes
        count_vector = np.zeros(max_lanes, dtype=np.float32)
        waiting_vector = np.zeros(max_lanes, dtype=np.float32)
        mask_vector = np.zeros(max_lanes, dtype=np.float32)

        for lane_index, lane_id in enumerate(config.incoming_lanes[:max_lanes]):
            count_vector[lane_index] = float(lane_vehicle_count.get(lane_id, 0))
            waiting_vector[lane_index] = float(lane_waiting_count.get(lane_id, 0))
            mask_vector[lane_index] = 1.0

        return count_vector, waiting_vector, mask_vector

    def _mean_outgoing_congestion(
        self,
        config: IntersectionConfig,
        lane_vehicle_count: dict[str, int],
    ) -> float:
        """Return the mean vehicle count over outgoing lanes, or 0.0 when there are none."""
        if not config.outgoing_lanes:
            return 0.0
        total = sum(float(lane_vehicle_count.get(lane_id, 0)) for lane_id in config.outgoing_lanes)
        return total / float(len(config.outgoing_lanes))

    def _compute_district_context(
        self,
        lane_vehicle_count: dict[str, int],
        lane_waiting_count: dict[str, int],
    ) -> dict[str, tuple[float, float]]:
        """Return per-district (mean incoming count, mean waiting count), normalized.

        Totals are taken over *all* incoming lanes of each member intersection
        (not truncated to ``max_incoming_lanes``). They are averaged per
        intersection and passed through ``normalize_scalar`` with
        ``count_scale``. The result is empty when district context is disabled.
        """
        context: dict[str, tuple[float, float]] = {}
        if not self.config.include_district_context:
            return context

        for district_id, district in self.districts.items():
            total_count = 0.0
            total_waiting = 0.0
            for intersection_id in district.intersection_ids:
                incoming_lanes = self.intersections[intersection_id].incoming_lanes
                total_count += sum(
                    float(lane_vehicle_count.get(lane_id, 0))
                    for lane_id in incoming_lanes
                )
                total_waiting += sum(
                    float(lane_waiting_count.get(lane_id, 0))
                    for lane_id in incoming_lanes
                )

            size = float(self._district_sizes[district_id])
            context[district_id] = (
                normalize_scalar(total_count / size, self.config.count_scale),
                normalize_scalar(total_waiting / size, self.config.count_scale),
            )

        return context