Aditya2162 commited on
Commit
9d09c45
·
verified ·
1 Parent(s): 027f431

Upload folder using huggingface_hub

Browse files
data/generators/README.md ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # data/generators
2
+
3
+ Synthetic dataset generation code for CityFlow cities and scenarios.
4
+
5
+ ## Main files
6
+
7
+ - [generate_dataset.py](/Users/aditya/Developer/traffic-llm/data/generators/generate_dataset.py)
8
+ Primary dataset build script.
9
+ - [city_generator.py](/Users/aditya/Developer/traffic-llm/data/generators/city_generator.py)
10
+ High-level city assembly.
11
+ - [roadnet_generator.py](/Users/aditya/Developer/traffic-llm/data/generators/roadnet_generator.py)
12
+ Road network generation.
13
+ - [district_generator.py](/Users/aditya/Developer/traffic-llm/data/generators/district_generator.py)
14
+ District assignments and relationships.
15
+ - [flow_generator.py](/Users/aditya/Developer/traffic-llm/data/generators/flow_generator.py)
16
+ Vehicle flow generation.
17
+ - [scenario_generator.py](/Users/aditya/Developer/traffic-llm/data/generators/scenario_generator.py)
18
+ Scenario-specific flow and config generation.
19
+
20
+ ## Notes
21
+
22
+ - This folder is for offline dataset creation.
23
+ - The current training pipeline consumes the generated files directly and does not regenerate them.
data/generators/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ """Synthetic CityFlow dataset generation package."""
2
+
3
+ from .city_generator import CityGenerator
4
+ from .schemas import DatasetGenerationConfig
5
+
6
+ __all__ = ["CityGenerator", "DatasetGenerationConfig"]
data/generators/city_generator.py ADDED
@@ -0,0 +1,341 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """City-level orchestration: topology, districts, scenarios, flows, configs, validation."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import random
6
+ from dataclasses import asdict
7
+ from pathlib import Path
8
+ from typing import Any
9
+
10
+ from .config_generator import ConfigGenerator
11
+ from .district_generator import DistrictGenerator
12
+ from .flow_generator import FlowGenerator
13
+ from .roadnet_generator import RoadnetGenerator
14
+ from .scenario_generator import ScenarioGenerator
15
+ from .schemas import DatasetGenerationConfig, TopologyType
16
+ from .utils import (
17
+ build_road_index,
18
+ build_roadlink_index,
19
+ clamp,
20
+ compute_scenario_diagnostics,
21
+ ensure_dir,
22
+ summarize_route_validation,
23
+ validate_district_contiguity,
24
+ validate_district_exit_capacity,
25
+ validate_inter_district_connectivity,
26
+ validate_unique_ids,
27
+ write_json,
28
+ )
29
+
30
+
31
+ class CityGenerator:
32
+ """Generate one or many synthetic cities with scenario-specific CityFlow files."""
33
+
34
+ def __init__(self) -> None:
35
+ self.roadnet_generator = RoadnetGenerator()
36
+ self.district_generator = DistrictGenerator()
37
+ self.scenario_generator = ScenarioGenerator()
38
+ self.flow_generator = FlowGenerator()
39
+ self.config_generator = ConfigGenerator()
40
+
41
+ def generate_dataset(self, config: DatasetGenerationConfig) -> None:
42
+ ensure_dir(config.output_dir)
43
+ failures: list[tuple[str, str]] = []
44
+ for idx in range(config.num_cities):
45
+ city_id = f"city_{idx + 1:04d}"
46
+ city_seed = config.seed + idx * 10_003
47
+ try:
48
+ self.generate_city(
49
+ city_id=city_id,
50
+ output_dir=config.output_dir / city_id,
51
+ config=config,
52
+ city_seed=city_seed,
53
+ )
54
+ except Exception as exc:
55
+ failures.append((city_id, str(exc)))
56
+ if config.fail_fast:
57
+ raise
58
+ if failures:
59
+ details = "; ".join(
60
+ f"{city}: {message}" for city, message in failures[:5]
61
+ )
62
+ raise RuntimeError(
63
+ f"Dataset generation failed for {len(failures)} city/cities. {details}"
64
+ )
65
+
66
+ def generate_city(
67
+ self,
68
+ city_id: str,
69
+ output_dir: Path,
70
+ config: DatasetGenerationConfig,
71
+ city_seed: int,
72
+ ) -> None:
73
+ ensure_dir(output_dir)
74
+ rng = random.Random(city_seed)
75
+ topology_pool: list[TopologyType] = list(config.topologies)
76
+ if not topology_pool:
77
+ raise ValueError("No topology families provided in configuration.")
78
+
79
+ attempts_per_topology = 10
80
+ ordered_topologies = topology_pool.copy()
81
+ rng.shuffle(ordered_topologies)
82
+ max_attempts = attempts_per_topology * len(ordered_topologies)
83
+ attempt_count = 0
84
+ city_graph = None
85
+ district_data = None
86
+ last_error: Exception | None = None
87
+ for topology in ordered_topologies:
88
+ for topology_attempt in range(attempts_per_topology):
89
+ attempt_count += 1
90
+ attempt_seed = city_seed + ((attempt_count - 1) * 1009)
91
+ target_intersections = clamp(
92
+ rng.randint(
93
+ config.min_districts * config.min_intersections_per_district,
94
+ config.max_districts * config.max_intersections_per_district,
95
+ ),
96
+ low=config.min_districts * config.min_intersections_per_district,
97
+ high=config.max_districts * config.max_intersections_per_district + 36,
98
+ )
99
+ try:
100
+ city_graph = self.roadnet_generator.generate(
101
+ city_id=city_id,
102
+ seed=attempt_seed,
103
+ topology=topology,
104
+ target_intersections=target_intersections,
105
+ ring_diagonal_keep_prob=config.ring_diagonal_keep_prob,
106
+ ring_max_diagonal_fraction=config.ring_max_diagonal_fraction,
107
+ )
108
+ validate_unique_ids(city_graph)
109
+ max_districts = max(
110
+ config.min_districts,
111
+ min(
112
+ config.max_districts,
113
+ max(
114
+ 2,
115
+ len(
116
+ [
117
+ nid
118
+ for nid in city_graph.intersections
119
+ if nid not in city_graph.gateway_intersections
120
+ ]
121
+ )
122
+ // max(1, config.min_intersections_per_district),
123
+ ),
124
+ ),
125
+ )
126
+ if topology == "ring_road":
127
+ max_districts = min(max_districts, max(6, target_intersections // 10))
128
+ max_districts = max(config.min_districts, max_districts)
129
+ num_districts = rng.randint(config.min_districts, max_districts)
130
+
131
+ district_data = self.district_generator.generate(
132
+ city_graph=city_graph,
133
+ num_districts=num_districts,
134
+ seed=attempt_seed + 17,
135
+ )
136
+ validate_district_contiguity(city_graph, district_data)
137
+ validate_inter_district_connectivity(district_data)
138
+ min_exit_roads = 2 if topology == "ring_road" else 3
139
+ min_entry_roads = 2 if topology == "ring_road" else 3
140
+ min_neighbor_districts = 1 if topology == "ring_road" else 2
141
+ validate_district_exit_capacity(
142
+ district_data=district_data,
143
+ min_exit_roads=min_exit_roads,
144
+ min_entry_roads=min_entry_roads,
145
+ min_neighbor_districts=min_neighbor_districts,
146
+ )
147
+ print(
148
+ f"[INFO] {city_id} attempt={attempt_count} "
149
+ f"topology={topology} topology_try={topology_attempt + 1}/{attempts_per_topology} "
150
+ "generated successfully"
151
+ )
152
+ break
153
+ except Exception as exc:
154
+ message = str(exc)
155
+ print(
156
+ f"[WARN] {city_id} attempt={attempt_count} "
157
+ f"topology={topology} topology_try={topology_attempt + 1}/{attempts_per_topology} "
158
+ f"failed: {message}"
159
+ )
160
+ last_error = exc
161
+ city_graph = None
162
+ district_data = None
163
+ continue
164
+
165
+ if city_graph is not None and district_data is not None:
166
+ break
167
+
168
+ if city_graph is None or district_data is None:
169
+ raise ValueError(
170
+ f"Unable to produce a structurally valid city after {max_attempts} attempts: {last_error}"
171
+ )
172
+
173
+ roadnet_path = output_dir / "roadnet.json"
174
+ write_json(roadnet_path, city_graph.roadnet)
175
+
176
+ district_map = {
177
+ "intersection_to_district": district_data.intersection_to_district,
178
+ "district_neighbors": district_data.district_neighbors,
179
+ "boundary_intersections": district_data.boundary_intersections,
180
+ "gateway_intersections": sorted(city_graph.gateway_intersections),
181
+ "gateway_roads": sorted(city_graph.gateway_roads),
182
+ "districts": [
183
+ {
184
+ "id": d.id,
185
+ "type": d.district_type,
186
+ "intersections": d.intersections,
187
+ "neighbors": d.neighbors,
188
+ "boundary_intersections": d.boundary_intersections,
189
+ "entry_roads": d.entry_roads,
190
+ "exit_roads": d.exit_roads,
191
+ }
192
+ for d in district_data.districts.values()
193
+ ],
194
+ }
195
+ write_json(output_dir / "district_map.json", district_map)
196
+
197
+ metadata = self._city_metadata(
198
+ city_id=city_id,
199
+ topology=topology,
200
+ city_seed=city_seed,
201
+ city_graph=city_graph,
202
+ district_data=district_data,
203
+ config=config,
204
+ )
205
+ write_json(output_dir / "metadata.json", metadata)
206
+
207
+ print(f"[INFO] {city_id} generated: topology={topology}, districts={len(district_data.districts)}")
208
+
209
+ scenario_plans = self.scenario_generator.generate(
210
+ city_graph=city_graph,
211
+ district_data=district_data,
212
+ scenario_names=config.scenarios,
213
+ base_seed=city_seed + 1000,
214
+ config=config,
215
+ )
216
+ self._generate_scenarios(
217
+ output_dir=output_dir,
218
+ city_graph=city_graph,
219
+ district_data=district_data,
220
+ scenario_plans=scenario_plans,
221
+ config=config,
222
+ roadnet_path=roadnet_path,
223
+ )
224
+
225
+ def _generate_scenarios(
226
+ self,
227
+ output_dir: Path,
228
+ city_graph: Any,
229
+ district_data: Any,
230
+ scenario_plans: dict[str, Any],
231
+ config: DatasetGenerationConfig,
232
+ roadnet_path: Path,
233
+ ) -> None:
234
+ roads_by_id = build_road_index(city_graph.roadnet)
235
+ roadlinks_by_intersection = build_roadlink_index(city_graph.roadnet)
236
+ scenarios_dir = output_dir / "scenarios"
237
+ ensure_dir(scenarios_dir)
238
+
239
+ for scenario_name, plan in scenario_plans.items():
240
+ scenario_dir = scenarios_dir / scenario_name
241
+ ensure_dir(scenario_dir)
242
+ flows = self.flow_generator.generate(
243
+ city_graph=city_graph,
244
+ district_data=district_data,
245
+ scenario=plan,
246
+ simulation_steps=config.simulation_steps,
247
+ )
248
+ validation_summary = summarize_route_validation(
249
+ flow_entries=flows,
250
+ roads_by_id=roads_by_id,
251
+ roadlinks_by_intersection=roadlinks_by_intersection,
252
+ )
253
+ if validation_summary["invalid_routes"] > 0:
254
+ reasons = ", ".join(
255
+ f"{reason}={count}"
256
+ for reason, count in validation_summary["top_failure_reasons"]
257
+ )
258
+ raise ValueError(
259
+ f"{scenario_name} contains invalid routes after generation: {reasons}"
260
+ )
261
+ write_json(scenario_dir / "flow.json", flows)
262
+ diagnostics = compute_scenario_diagnostics(
263
+ flow_entries=flows,
264
+ city_graph=city_graph,
265
+ district_data=district_data,
266
+ )
267
+
268
+ config_payload = self.config_generator.generate(
269
+ simulation_steps=config.simulation_steps,
270
+ interval=config.interval,
271
+ seed=plan.seed,
272
+ save_replay=config.save_replay,
273
+ roadnet_file=roadnet_path,
274
+ flow_file=scenario_dir / "flow.json",
275
+ scenario_dir=scenario_dir,
276
+ )
277
+ write_json(scenario_dir / "config.json", config_payload)
278
+ write_json(
279
+ scenario_dir / "scenario_metadata.json",
280
+ {
281
+ "name": scenario_name,
282
+ "intensity": plan.intensity,
283
+ "seed": plan.seed,
284
+ "trip_multiplier": plan.trip_multiplier,
285
+ "trip_mix": asdict(plan.trip_mix),
286
+ "blocked_roads": sorted(plan.blocked_roads),
287
+ "penalized_roads": plan.penalized_roads,
288
+ "event_district": plan.event_district,
289
+ "overload_district": plan.overload_district,
290
+ "diagnostics": diagnostics,
291
+ "details": plan.metadata,
292
+ },
293
+ )
294
+
295
+ def _city_metadata(
296
+ self,
297
+ city_id: str,
298
+ topology: TopologyType,
299
+ city_seed: int,
300
+ city_graph: Any,
301
+ district_data: Any,
302
+ config: DatasetGenerationConfig,
303
+ ) -> dict[str, Any]:
304
+ total_lanes = sum(
305
+ road.num_lanes for road in city_graph.directed_roads.values()
306
+ )
307
+ district_types = {
308
+ did: district.district_type
309
+ for did, district in district_data.districts.items()
310
+ }
311
+ return {
312
+ "city_id": city_id,
313
+ "topology": topology,
314
+ "seed": city_seed,
315
+ "counts": {
316
+ "intersections": len(city_graph.intersections),
317
+ "roads": len(city_graph.directed_roads),
318
+ "lanes": total_lanes,
319
+ "districts": len(district_data.districts),
320
+ },
321
+ "district_types": district_types,
322
+ "district_adjacency_graph": district_data.district_neighbors,
323
+ "inter_district_connector_roads": district_data.inter_district_roads,
324
+ "arterial_roads": sorted(city_graph.arterial_roads),
325
+ "gateway_intersections": sorted(city_graph.gateway_intersections),
326
+ "gateway_connector_roads": sorted(city_graph.gateway_roads),
327
+ "generation_parameters": {
328
+ "min_districts": config.min_districts,
329
+ "max_districts": config.max_districts,
330
+ "min_intersections_per_district": config.min_intersections_per_district,
331
+ "max_intersections_per_district": config.max_intersections_per_district,
332
+ "simulation_steps": config.simulation_steps,
333
+ "interval": config.interval,
334
+ "intensity_levels": config.intensity_levels,
335
+ "global_demand_multiplier": config.global_demand_multiplier,
336
+ "intensity_distribution": config.intensity_distribution,
337
+ "scenario_demand_multipliers": config.scenario_demand_multipliers,
338
+ "ring_diagonal_keep_prob": config.ring_diagonal_keep_prob,
339
+ "ring_max_diagonal_fraction": config.ring_max_diagonal_fraction,
340
+ },
341
+ }
data/generators/config_generator.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """CityFlow config generation for per-scenario simulation runs."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+ import os
7
+ from typing import Any
8
+
9
+
10
+ class ConfigGenerator:
11
+ """Build CityFlow-compatible config.json payloads."""
12
+
13
+ def generate(
14
+ self,
15
+ simulation_steps: int,
16
+ interval: float,
17
+ seed: int,
18
+ save_replay: bool,
19
+ roadnet_file: Path,
20
+ flow_file: Path,
21
+ scenario_dir: Path,
22
+ ) -> dict[str, Any]:
23
+ # CityFlow expects a small signed integer for seed.
24
+ safe_seed = int(seed) & 0x7FFFFFFF
25
+ # Use absolute paths so CityFlow can resolve files regardless of working dir.
26
+ roadnet_path = roadnet_file.resolve()
27
+ flow_path = flow_file.resolve()
28
+ scenario_path = scenario_dir.resolve()
29
+ base_dir = roadnet_path.parent
30
+ roadnet_rel = os.path.relpath(roadnet_path, base_dir)
31
+ flow_rel = os.path.relpath(flow_path, base_dir)
32
+ flow_rel_dir = Path(flow_rel).parent
33
+ roadnet_log_rel = str(flow_rel_dir / "roadnetLogFile.json")
34
+ replay_log_rel = str(flow_rel_dir / "replay.txt")
35
+ dir_str = str(base_dir)
36
+ if not dir_str.endswith(os.sep):
37
+ dir_str = dir_str + os.sep
38
+ return {
39
+ "interval": interval,
40
+ "seed": safe_seed,
41
+ "dir": dir_str,
42
+ "roadnetFile": roadnet_rel,
43
+ "flowFile": flow_rel,
44
+ "rlTrafficLight": True,
45
+ "laneChange": False,
46
+ "saveReplay": save_replay,
47
+ "roadnetLogFile": roadnet_log_rel,
48
+ "replayLogFile": replay_log_rel,
49
+ "step": simulation_steps,
50
+ }
data/generators/district_generator.py ADDED
@@ -0,0 +1,367 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """District partitioning and district-level metadata generation."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import random
6
+ from collections import Counter, defaultdict, deque
7
+
8
+ from .schemas import CityGraph, DistrictData, DistrictRecord, DistrictType
9
+ from .utils import connected_components, euclidean
10
+
11
+
12
+ class DistrictGenerator:
13
+ """Generate contiguous district partitions over the city intersection graph."""
14
+
15
+ DISTRICT_TYPE_WEIGHTS: dict[DistrictType, float] = {
16
+ "residential": 0.35,
17
+ "commercial": 0.25,
18
+ "industrial": 0.20,
19
+ "mixed": 0.20,
20
+ }
21
+
22
+ def generate(
23
+ self,
24
+ city_graph: CityGraph,
25
+ num_districts: int,
26
+ seed: int,
27
+ ) -> DistrictData:
28
+ rng = random.Random(seed)
29
+ node_ids = sorted(
30
+ n
31
+ for n in city_graph.intersections.keys()
32
+ if n not in city_graph.gateway_intersections
33
+ )
34
+ if len(node_ids) < 2:
35
+ raise ValueError("Insufficient non-gateway intersections for districting.")
36
+ if num_districts >= len(node_ids):
37
+ num_districts = max(2, len(node_ids) // 2)
38
+
39
+ local_coords = {nid: city_graph.intersections[nid] for nid in node_ids}
40
+ seeds = self._farthest_seeds(local_coords, num_districts, rng)
41
+ assignment = self._grow_contiguous_regions(
42
+ local_coords=local_coords,
43
+ adjacency=city_graph.adjacency,
44
+ seeds=seeds,
45
+ rng=rng,
46
+ )
47
+ assignment = self._enforce_contiguity(
48
+ assignment=assignment,
49
+ adjacency=city_graph.adjacency,
50
+ coords=local_coords,
51
+ )
52
+ assignment = self._fill_empty_districts(
53
+ assignment=assignment,
54
+ node_ids=node_ids,
55
+ adjacency=city_graph.adjacency,
56
+ district_ids=list(seeds.keys()),
57
+ )
58
+
59
+ district_neighbors: dict[str, set[str]] = {
60
+ did: set() for did in seeds.keys()
61
+ }
62
+ boundary: set[str] = set()
63
+ for a, neighbors in city_graph.adjacency.items():
64
+ if a not in assignment:
65
+ continue
66
+ da = assignment[a]
67
+ for b in neighbors:
68
+ if b not in assignment:
69
+ continue
70
+ db = assignment[b]
71
+ if da != db:
72
+ district_neighbors[da].add(db)
73
+ boundary.add(a)
74
+
75
+ entry_roads: dict[str, list[str]] = defaultdict(list)
76
+ exit_roads: dict[str, list[str]] = defaultdict(list)
77
+ inter_district_roads: set[str] = set()
78
+ for road in city_graph.directed_roads.values():
79
+ if (
80
+ road.start_intersection not in assignment
81
+ or road.end_intersection not in assignment
82
+ ):
83
+ continue
84
+ ds = assignment[road.start_intersection]
85
+ de = assignment[road.end_intersection]
86
+ if ds != de:
87
+ inter_district_roads.add(road.id)
88
+ exit_roads[ds].append(road.id)
89
+ entry_roads[de].append(road.id)
90
+ city_graph.inter_district_roads = inter_district_roads
91
+
92
+ district_records: dict[str, DistrictRecord] = {}
93
+ type_values = list(self.DISTRICT_TYPE_WEIGHTS.keys())
94
+ type_weights = list(self.DISTRICT_TYPE_WEIGHTS.values())
95
+ for district_id in seeds:
96
+ members = sorted([n for n, d in assignment.items() if d == district_id])
97
+ d_boundary = sorted([n for n in members if n in boundary])
98
+ district_type = rng.choices(type_values, weights=type_weights, k=1)[0]
99
+ district_records[district_id] = DistrictRecord(
100
+ id=district_id,
101
+ district_type=district_type,
102
+ intersections=members,
103
+ neighbors=sorted(district_neighbors[district_id]),
104
+ boundary_intersections=d_boundary,
105
+ entry_roads=sorted(set(entry_roads[district_id])),
106
+ exit_roads=sorted(set(exit_roads[district_id])),
107
+ )
108
+
109
+ return DistrictData(
110
+ intersection_to_district=assignment,
111
+ districts=district_records,
112
+ district_neighbors={
113
+ k: sorted(v) for k, v in district_neighbors.items()
114
+ },
115
+ boundary_intersections=sorted(boundary),
116
+ inter_district_roads=sorted(inter_district_roads),
117
+ )
118
+
119
+ def _farthest_seeds(
120
+ self,
121
+ coords: dict[str, tuple[float, float]],
122
+ num_districts: int,
123
+ rng: random.Random,
124
+ ) -> dict[str, str]:
125
+ nodes = sorted(coords.keys())
126
+ first = rng.choice(nodes)
127
+ selected = [first]
128
+
129
+ while len(selected) < num_districts:
130
+ best_node = None
131
+ best_dist = -1.0
132
+ for node in nodes:
133
+ if node in selected:
134
+ continue
135
+ nearest = min(
136
+ euclidean(coords[node], coords[s]) for s in selected
137
+ )
138
+ if nearest > best_dist:
139
+ best_dist = nearest
140
+ best_node = node
141
+ if best_node is None:
142
+ break
143
+ selected.append(best_node)
144
+ return {f"d_{idx:02d}": node for idx, node in enumerate(selected)}
145
+
146
+ def _assign_nearest(
147
+ self,
148
+ coords: dict[str, tuple[float, float]],
149
+ seeds: dict[str, str],
150
+ ) -> dict[str, str]:
151
+ assignment: dict[str, str] = {}
152
+ for node, point in coords.items():
153
+ district = min(
154
+ seeds.keys(),
155
+ key=lambda did: euclidean(point, coords[seeds[did]]),
156
+ )
157
+ assignment[node] = district
158
+ return assignment
159
+
160
+ def _grow_contiguous_regions(
161
+ self,
162
+ local_coords: dict[str, tuple[float, float]],
163
+ adjacency: dict[str, set[str]],
164
+ seeds: dict[str, str],
165
+ rng: random.Random,
166
+ ) -> dict[str, str]:
167
+ districts = list(seeds.keys())
168
+ district_sizes = {district_id: 1 for district_id in districts}
169
+ assignment: dict[str, str] = {seed_node: district_id for district_id, seed_node in seeds.items()}
170
+ frontiers: dict[str, deque[str]] = {
171
+ district_id: deque([seed_node]) for district_id, seed_node in seeds.items()
172
+ }
173
+ remaining = set(local_coords.keys()) - set(assignment.keys())
174
+
175
+ if not remaining:
176
+ return assignment
177
+
178
+ target_avg = max(1, len(local_coords) // len(districts))
179
+ target_limits = {
180
+ district_id: target_avg + 2 for district_id in districts
181
+ }
182
+ overcap = 0
183
+
184
+ # Expand from multiple district frontiers to ensure contiguity by construction.
185
+ frontier_order = deque(districts)
186
+ while remaining:
187
+ if not frontier_order:
188
+ break
189
+ district_id = frontier_order.popleft()
190
+ current_frontier = frontiers[district_id]
191
+ if not current_frontier:
192
+ continue
193
+
194
+ source = current_frontier.popleft()
195
+ neighbors = [n for n in adjacency.get(source, set()) if n in remaining]
196
+ rng.shuffle(neighbors)
197
+
198
+ expanded = False
199
+ for neighbor in neighbors:
200
+ if neighbor not in remaining:
201
+ continue
202
+ can_expand = (
203
+ overcap > 3
204
+ or district_sizes[district_id] < target_limits[district_id]
205
+ )
206
+ if not can_expand and overcap <= 3:
207
+ continue
208
+ assignment[neighbor] = district_id
209
+ remaining.remove(neighbor)
210
+ current_frontier.append(neighbor)
211
+ district_sizes[district_id] += 1
212
+ frontier_order.append(district_id)
213
+ frontier_order.append(district_id)
214
+ expanded = True
215
+ break
216
+
217
+ if expanded:
218
+ continue
219
+
220
+ # If all districts reached targets, allow unrestricted growth to consume leftovers.
221
+ overcap += 1
222
+ for fallback_neighbor in neighbors:
223
+ if fallback_neighbor not in remaining:
224
+ continue
225
+ assignment[fallback_neighbor] = district_id
226
+ remaining.remove(fallback_neighbor)
227
+ current_frontier.append(fallback_neighbor)
228
+ district_sizes[district_id] += 1
229
+ frontier_order.append(district_id)
230
+ break
231
+
232
+ if not expanded and not current_frontier:
233
+ # keep exploring this district only if it can still absorb nodes.
234
+ if all(size >= target_limits[d] for d, size in district_sizes.items()):
235
+ continue
236
+
237
+ frontier_order.append(district_id)
238
+
239
+ if overcap > 10_000:
240
+ # Safety break for unexpected stalling.
241
+ break
242
+
243
+ # If anything remains because of local disconnectedness in the non-gateway subgraph,
244
+ # assign by nearest-seed fallback and rely on contiguity enforcement later.
245
+ if remaining:
246
+ fallback = self._assign_nearest(local_coords, seeds)
247
+ for node in remaining:
248
+ assignment[node] = fallback[node]
249
+ return assignment
250
+
251
+ def _enforce_contiguity(
252
+ self,
253
+ assignment: dict[str, str],
254
+ adjacency: dict[str, set[str]],
255
+ coords: dict[str, tuple[float, float]],
256
+ ) -> dict[str, str]:
257
+ district_ids = sorted(set(assignment.values()))
258
+ changed = True
259
+ while changed:
260
+ changed = False
261
+ for district_id in district_ids:
262
+ nodes = [n for n, d in assignment.items() if d == district_id]
263
+ if len(nodes) <= 1:
264
+ continue
265
+ comps = connected_components(nodes, adjacency)
266
+ if len(comps) <= 1:
267
+ continue
268
+ comps.sort(key=len, reverse=True)
269
+ keep = comps[0]
270
+ for comp in comps[1:]:
271
+ for node in comp:
272
+ reassigned = self._best_neighbor_district(
273
+ node=node,
274
+ assignment=assignment,
275
+ adjacency=adjacency,
276
+ coords=coords,
277
+ )
278
+ if reassigned != district_id:
279
+ assignment[node] = reassigned
280
+ changed = True
281
+ return assignment
282
+
283
+ def _best_neighbor_district(
284
+ self,
285
+ node: str,
286
+ assignment: dict[str, str],
287
+ adjacency: dict[str, set[str]],
288
+ coords: dict[str, tuple[float, float]],
289
+ ) -> str:
290
+ neighbors = [n for n in adjacency[node] if n in assignment]
291
+ if not neighbors:
292
+ return assignment[node]
293
+ counts = Counter(assignment[n] for n in neighbors)
294
+ best = counts.most_common(1)[0][0]
295
+ if len(counts) == 1:
296
+ return best
297
+ best_score = float("inf")
298
+ best_district = best
299
+ for district in counts.keys():
300
+ district_nodes = [n for n, d in assignment.items() if d == district]
301
+ if not district_nodes:
302
+ continue
303
+ centroid_x = sum(coords[n][0] for n in district_nodes) / len(district_nodes)
304
+ centroid_y = sum(coords[n][1] for n in district_nodes) / len(district_nodes)
305
+ dist = euclidean(coords[node], (centroid_x, centroid_y))
306
+ if dist < best_score:
307
+ best_score = dist
308
+ best_district = district
309
+ return best_district
310
+
311
+ def _fill_empty_districts(
312
+ self,
313
+ assignment: dict[str, str],
314
+ node_ids: list[str],
315
+ adjacency: dict[str, set[str]],
316
+ district_ids: list[str],
317
+ ) -> dict[str, str]:
318
+ # This method now mainly keeps a lower bound on singleton-heavy districts.
319
+ # Keep existing behavior if an empty district somehow appears.
320
+ counts = Counter(assignment.values())
321
+ empty = [d for d in district_ids if counts[d] == 0]
322
+ if not empty:
323
+ return assignment
324
+
325
+ for empty_id in empty:
326
+ largest = max(district_ids, key=lambda d: counts[d])
327
+ donor_candidates = [n for n in node_ids if assignment[n] == largest]
328
+ if not donor_candidates:
329
+ continue
330
+ pivot = donor_candidates[0]
331
+ for candidate in donor_candidates:
332
+ if any(
333
+ assignment[n] != largest and assignment[n] != empty_id
334
+ for n in adjacency.get(candidate, set())
335
+ if n in assignment
336
+ ):
337
+ pivot = candidate
338
+ break
339
+ assignment[pivot] = empty_id
340
+ counts[largest] -= 1
341
+ counts[empty_id] += 1
342
+
343
+ # For any missing district, steal a boundary node from the largest district.
344
+ for district in district_ids:
345
+ if counts[district] > 1:
346
+ continue
347
+ root = next((n for n in node_ids if assignment[n] == district), None)
348
+ if root is None:
349
+ continue
350
+ queue = deque([root])
351
+ while queue and counts[district] < 2:
352
+ current = queue.popleft()
353
+ for candidate in node_ids:
354
+ if assignment[candidate] == district:
355
+ continue
356
+ if candidate not in adjacency.get(current, set()):
357
+ continue
358
+ old = assignment[candidate]
359
+ if counts[old] <= 2:
360
+ continue
361
+ assignment[candidate] = district
362
+ counts[old] -= 1
363
+ counts[district] += 1
364
+ queue.append(candidate)
365
+ if counts[district] >= 2:
366
+ break
367
+ return assignment
data/generators/flow_generator.py ADDED
@@ -0,0 +1,662 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Flow generation with district-aware O-D pressure and turn-feasible routing."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import heapq
6
+ import math
7
+ import random
8
+ from collections import defaultdict
9
+ from typing import Any
10
+
11
+ from .schemas import CityGraph, DistrictData, ScenarioPlan
12
+ from .utils import (
13
+ build_road_index,
14
+ build_roadlink_index,
15
+ choose_weighted,
16
+ summarize_route_validation,
17
+ validate_route_with_reasons,
18
+ )
19
+
20
+
21
+ class FlowGenerator:
22
+ """Create high-pressure CityFlow flow entries for each scenario."""
23
+
24
+ DISTRICT_BASE_PRODUCTION = {
25
+ "residential": 1.70,
26
+ "commercial": 0.85,
27
+ "industrial": 0.95,
28
+ "mixed": 1.15,
29
+ }
30
+ DISTRICT_BASE_ATTRACTION = {
31
+ "residential": 0.90,
32
+ "commercial": 1.85,
33
+ "industrial": 1.55,
34
+ "mixed": 1.25,
35
+ }
36
+
37
+ INTENSITY_SCALE = {
38
+ "normal": 1.0,
39
+ "moderate_rush": 1.22,
40
+ "heavy_rush": 1.45,
41
+ "overload": 1.75,
42
+ "accident_overload": 2.0,
43
+ }
44
+
45
+ def generate(
46
+ self,
47
+ city_graph: CityGraph,
48
+ district_data: DistrictData,
49
+ scenario: ScenarioPlan,
50
+ simulation_steps: int,
51
+ ) -> list[dict[str, Any]]:
52
+ rng = random.Random(scenario.seed)
53
+ district_ids = sorted(district_data.districts.keys())
54
+ if len(district_ids) < 2:
55
+ raise ValueError("Flow generation requires at least 2 districts.")
56
+
57
+ base_density = int(scenario.metadata.get("base_demand_per_intersection", 36))
58
+ trip_total = int(
59
+ len(city_graph.intersections)
60
+ * base_density
61
+ * scenario.trip_multiplier
62
+ )
63
+ trip_total = max(260, min(42000, trip_total))
64
+
65
+ routing_state = self._build_routing_state(city_graph, scenario)
66
+ district_features = self._build_district_features(city_graph, district_data)
67
+ connector_counts = self._build_connector_counts(city_graph, district_data)
68
+ production, attraction = self._district_weights(
69
+ city_graph=city_graph,
70
+ district_data=district_data,
71
+ district_features=district_features,
72
+ scenario=scenario,
73
+ )
74
+ gateway_context = self._build_gateway_context(city_graph, district_data)
75
+ external_share = self._external_trip_share(scenario, has_gateways=bool(gateway_context["gateways"]))
76
+
77
+ flows: list[dict[str, Any]] = []
78
+ max_global_sampling_attempts = max(5000, trip_total * 8)
79
+ global_attempts = 0
80
+ while len(flows) < trip_total and global_attempts < max_global_sampling_attempts:
81
+ global_attempts += 1
82
+ route = None
83
+ max_attempts = 45
84
+ for _ in range(max_attempts):
85
+ external_mode = self._sample_external_mode(
86
+ rng=rng,
87
+ external_share=external_share,
88
+ )
89
+ category = self._sample_trip_category(rng, scenario.trip_mix)
90
+ origin_node: str
91
+ destination_node: str
92
+ origin_district: str
93
+ destination_district: str
94
+
95
+ if external_mode == "inbound":
96
+ gateway = self._sample_gateway_for_inbound(
97
+ rng=rng,
98
+ gateway_context=gateway_context,
99
+ attraction_weights=attraction,
100
+ )
101
+ if gateway is None:
102
+ continue
103
+ destination_district = self._sample_origin_district(
104
+ rng=rng,
105
+ district_ids=district_ids,
106
+ weights=attraction,
107
+ )
108
+ origin_node = gateway
109
+ destination_node = self._sample_intersection_in_district(
110
+ rng=rng,
111
+ district_data=district_data,
112
+ district_id=destination_district,
113
+ favor_boundary=False,
114
+ excluded_nodes=gateway_context["anchor_nodes"],
115
+ )
116
+ elif external_mode == "outbound":
117
+ origin_district = self._sample_origin_district(
118
+ rng=rng,
119
+ district_ids=district_ids,
120
+ weights=production,
121
+ )
122
+ gateway = self._sample_gateway_for_outbound(
123
+ rng=rng,
124
+ origin_district=origin_district,
125
+ gateway_context=gateway_context,
126
+ connector_counts=connector_counts,
127
+ )
128
+ if gateway is None:
129
+ continue
130
+ origin_node = self._sample_intersection_in_district(
131
+ rng=rng,
132
+ district_data=district_data,
133
+ district_id=origin_district,
134
+ favor_boundary=True,
135
+ excluded_nodes=gateway_context["anchor_nodes"],
136
+ )
137
+ destination_node = gateway
138
+ else:
139
+ origin_district = self._sample_origin_district(
140
+ rng=rng,
141
+ district_ids=district_ids,
142
+ weights=production,
143
+ )
144
+ destination_district = self._sample_destination_district(
145
+ rng=rng,
146
+ city_graph=city_graph,
147
+ origin_district=origin_district,
148
+ category=category,
149
+ district_data=district_data,
150
+ district_features=district_features,
151
+ district_ids=district_ids,
152
+ attraction_weights=attraction,
153
+ connector_counts=connector_counts,
154
+ )
155
+ origin_node = self._sample_intersection_in_district(
156
+ rng=rng,
157
+ district_data=district_data,
158
+ district_id=origin_district,
159
+ favor_boundary=(category != "intra"),
160
+ excluded_nodes=set(),
161
+ )
162
+ destination_node = self._sample_intersection_in_district(
163
+ rng=rng,
164
+ district_data=district_data,
165
+ district_id=destination_district,
166
+ favor_boundary=(category == "long"),
167
+ excluded_nodes=set(),
168
+ )
169
+ if origin_node == destination_node:
170
+ continue
171
+
172
+ route = self._find_turn_feasible_route(
173
+ start_intersection=origin_node,
174
+ end_intersection=destination_node,
175
+ road_lookup=routing_state["road_lookup"],
176
+ start_roads_by_intersection=routing_state["start_roads_by_intersection"],
177
+ transitions=routing_state["transitions"],
178
+ road_cost=routing_state["road_cost"],
179
+ )
180
+ if not route:
181
+ continue
182
+ reasons = validate_route_with_reasons(
183
+ route=route,
184
+ roads_by_id=routing_state["road_lookup"],
185
+ roadlinks_by_intersection=routing_state["roadlinks_by_intersection"],
186
+ )
187
+ if reasons:
188
+ route = None
189
+ continue
190
+ break
191
+
192
+ if route:
193
+ start_time = self._sample_departure(rng, scenario, simulation_steps)
194
+ flows.append(
195
+ {
196
+ "vehicle": {
197
+ "length": 5.0,
198
+ "width": 2.0,
199
+ "maxPosAcc": 2.2,
200
+ "maxNegAcc": 4.6,
201
+ "usualPosAcc": 2.0,
202
+ "usualNegAcc": 4.2,
203
+ "minGap": 2.2,
204
+ "maxSpeed": 13.89,
205
+ "headwayTime": 1.4,
206
+ },
207
+ "route": route,
208
+ "interval": 1.0,
209
+ "startTime": start_time,
210
+ "endTime": start_time,
211
+ }
212
+ )
213
+
214
+ completion_ratio = len(flows) / max(1, trip_total)
215
+ min_completion_ratio = (
216
+ 0.55 if scenario.name in {"accident", "construction"} else 0.70
217
+ )
218
+ if completion_ratio < min_completion_ratio:
219
+ raise ValueError(
220
+ f"Scenario {scenario.name} produced too few valid flows "
221
+ f"({len(flows)}/{trip_total}, completion={completion_ratio:.3f})."
222
+ )
223
+
224
+ if not flows:
225
+ raise ValueError(f"Scenario {scenario.name} produced no valid flows.")
226
+ summary = summarize_route_validation(
227
+ flow_entries=flows,
228
+ roads_by_id=routing_state["road_lookup"],
229
+ roadlinks_by_intersection=routing_state["roadlinks_by_intersection"],
230
+ )
231
+ if summary["invalid_routes"] > 0:
232
+ reasons = ", ".join(
233
+ f"{reason}={count}" for reason, count in summary["top_failure_reasons"]
234
+ )
235
+ raise ValueError(
236
+ f"Scenario {scenario.name} has invalid routes after regeneration: {reasons}"
237
+ )
238
+ return flows
239
+
240
+ def _build_gateway_context(
241
+ self,
242
+ city_graph: CityGraph,
243
+ district_data: DistrictData,
244
+ ) -> dict[str, Any]:
245
+ assignment = district_data.intersection_to_district
246
+ gateways = sorted(city_graph.gateway_intersections)
247
+ anchors_by_gateway: dict[str, str] = {}
248
+ district_by_gateway: dict[str, str] = {}
249
+ gateways_by_district: dict[str, list[str]] = defaultdict(list)
250
+
251
+ for gateway in gateways:
252
+ neighbors = [
253
+ n for n in city_graph.adjacency.get(gateway, set()) if n in assignment
254
+ ]
255
+ if not neighbors:
256
+ continue
257
+ anchor = sorted(neighbors)[0]
258
+ district_id = assignment[anchor]
259
+ anchors_by_gateway[gateway] = anchor
260
+ district_by_gateway[gateway] = district_id
261
+ gateways_by_district[district_id].append(gateway)
262
+
263
+ return {
264
+ "gateways": sorted(anchors_by_gateway.keys()),
265
+ "anchors_by_gateway": anchors_by_gateway,
266
+ "district_by_gateway": district_by_gateway,
267
+ "gateways_by_district": gateways_by_district,
268
+ "anchor_nodes": set(anchors_by_gateway.values()),
269
+ }
270
+
271
+ def _external_trip_share(
272
+ self,
273
+ scenario: ScenarioPlan,
274
+ has_gateways: bool,
275
+ ) -> float:
276
+ if not has_gateways:
277
+ return 0.0
278
+ base = {
279
+ "normal": 0.12,
280
+ "morning_rush": 0.16,
281
+ "evening_rush": 0.16,
282
+ "accident": 0.20,
283
+ "construction": 0.18,
284
+ "event_spike": 0.18,
285
+ "district_overload": 0.19,
286
+ }[scenario.name]
287
+ intensity_boost = {
288
+ "normal": 0.00,
289
+ "moderate_rush": 0.02,
290
+ "heavy_rush": 0.04,
291
+ "overload": 0.07,
292
+ "accident_overload": 0.10,
293
+ }.get(scenario.intensity, 0.0)
294
+ return min(0.42, base + intensity_boost)
295
+
296
+ def _sample_external_mode(
297
+ self,
298
+ rng: random.Random,
299
+ external_share: float,
300
+ ) -> str:
301
+ if external_share <= 0.0:
302
+ return "none"
303
+ if rng.random() >= external_share:
304
+ return "none"
305
+ return "inbound" if rng.random() < 0.5 else "outbound"
306
+
307
+ def _sample_gateway_for_inbound(
308
+ self,
309
+ rng: random.Random,
310
+ gateway_context: dict[str, Any],
311
+ attraction_weights: dict[str, float],
312
+ ) -> str | None:
313
+ gateways = gateway_context["gateways"]
314
+ if not gateways:
315
+ return None
316
+ district_by_gateway = gateway_context["district_by_gateway"]
317
+ weights: list[float] = []
318
+ for gateway in gateways:
319
+ district_id = district_by_gateway[gateway]
320
+ weights.append(1.0 + attraction_weights.get(district_id, 1.0))
321
+ return choose_weighted(rng, gateways, weights)
322
+
323
+ def _sample_gateway_for_outbound(
324
+ self,
325
+ rng: random.Random,
326
+ origin_district: str,
327
+ gateway_context: dict[str, Any],
328
+ connector_counts: dict[tuple[str, str], int],
329
+ ) -> str | None:
330
+ gateways = gateway_context["gateways"]
331
+ if not gateways:
332
+ return None
333
+ district_by_gateway = gateway_context["district_by_gateway"]
334
+ weights: list[float] = []
335
+ for gateway in gateways:
336
+ gateway_district = district_by_gateway[gateway]
337
+ connector_bonus = 1.0 + connector_counts.get(
338
+ (origin_district, gateway_district), 0
339
+ )
340
+ same_district_bonus = 2.0 if gateway_district == origin_district else 1.0
341
+ weights.append(connector_bonus * same_district_bonus)
342
+ return choose_weighted(rng, gateways, weights)
343
+
344
+ def _build_routing_state(
345
+ self,
346
+ city_graph: CityGraph,
347
+ scenario: ScenarioPlan,
348
+ ) -> dict[str, Any]:
349
+ road_lookup = build_road_index(city_graph.roadnet)
350
+ roadlinks_by_intersection = build_roadlink_index(city_graph.roadnet)
351
+ start_roads_by_intersection: dict[str, list[str]] = defaultdict(list)
352
+ road_cost: dict[str, float] = {}
353
+ available_roads: set[str] = set()
354
+
355
+ for road in city_graph.directed_roads.values():
356
+ if road.id in scenario.blocked_roads:
357
+ continue
358
+ cost = max(1.0, road.length / max(road.speed_limit, 1.0))
359
+ if road.id in scenario.penalized_roads:
360
+ cost *= scenario.penalized_roads[road.id]
361
+ road_cost[road.id] = cost
362
+ available_roads.add(road.id)
363
+ start_roads_by_intersection[road.start_intersection].append(road.id)
364
+
365
+ transitions: dict[str, list[str]] = defaultdict(list)
366
+ for pairs in roadlinks_by_intersection.values():
367
+ for start_road, end_road in pairs:
368
+ if start_road not in available_roads or end_road not in available_roads:
369
+ continue
370
+ transitions[start_road].append(end_road)
371
+ return {
372
+ "road_lookup": road_lookup,
373
+ "roadlinks_by_intersection": roadlinks_by_intersection,
374
+ "start_roads_by_intersection": start_roads_by_intersection,
375
+ "transitions": transitions,
376
+ "road_cost": road_cost,
377
+ }
378
+
379
+ def _find_turn_feasible_route(
380
+ self,
381
+ start_intersection: str,
382
+ end_intersection: str,
383
+ road_lookup: dict[str, dict[str, Any]],
384
+ start_roads_by_intersection: dict[str, list[str]],
385
+ transitions: dict[str, list[str]],
386
+ road_cost: dict[str, float],
387
+ ) -> list[str] | None:
388
+ if start_intersection == end_intersection:
389
+ return None
390
+ start_roads = start_roads_by_intersection.get(start_intersection, [])
391
+ if not start_roads:
392
+ return None
393
+
394
+ queue: list[tuple[float, str]] = []
395
+ dist: dict[str, float] = {}
396
+ prev: dict[str, str | None] = {}
397
+
398
+ for road_id in start_roads:
399
+ if road_id not in road_cost:
400
+ continue
401
+ cost = road_cost[road_id]
402
+ dist[road_id] = cost
403
+ prev[road_id] = None
404
+ heapq.heappush(queue, (cost, road_id))
405
+
406
+ best_terminal: str | None = None
407
+ while queue:
408
+ current_cost, current_road = heapq.heappop(queue)
409
+ if current_cost > dist.get(current_road, float("inf")):
410
+ continue
411
+
412
+ current_end = road_lookup[current_road]["endIntersection"]
413
+ if current_end == end_intersection:
414
+ best_terminal = current_road
415
+ break
416
+
417
+ for next_road in transitions.get(current_road, []):
418
+ next_cost = current_cost + road_cost[next_road]
419
+ if next_cost < dist.get(next_road, float("inf")):
420
+ dist[next_road] = next_cost
421
+ prev[next_road] = current_road
422
+ heapq.heappush(queue, (next_cost, next_road))
423
+
424
+ if best_terminal is None:
425
+ return None
426
+
427
+ route: list[str] = []
428
+ cursor: str | None = best_terminal
429
+ while cursor is not None:
430
+ route.append(cursor)
431
+ cursor = prev[cursor]
432
+ route.reverse()
433
+ return route
434
+
435
+ def _build_district_features(
436
+ self,
437
+ city_graph: CityGraph,
438
+ district_data: DistrictData,
439
+ ) -> dict[str, dict[str, float]]:
440
+ features: dict[str, dict[str, float]] = {}
441
+ for did, district in district_data.districts.items():
442
+ members = district.intersections
443
+ size = len(members)
444
+ cx = sum(city_graph.intersections[n][0] for n in members) / max(1, size)
445
+ cy = sum(city_graph.intersections[n][1] for n in members) / max(1, size)
446
+ features[did] = {
447
+ "size": float(size),
448
+ "neighbors": float(len(district.neighbors)),
449
+ "exits": float(len(district.exit_roads)),
450
+ "boundary": float(len(district.boundary_intersections)),
451
+ "cx": cx,
452
+ "cy": cy,
453
+ }
454
+ return features
455
+
456
+ def _build_connector_counts(
457
+ self,
458
+ city_graph: CityGraph,
459
+ district_data: DistrictData,
460
+ ) -> dict[tuple[str, str], int]:
461
+ connector_counts: dict[tuple[str, str], int] = defaultdict(int)
462
+ assignment = district_data.intersection_to_district
463
+ for a, neighbors in city_graph.adjacency.items():
464
+ if a not in assignment:
465
+ continue
466
+ da = assignment[a]
467
+ for b in neighbors:
468
+ if b not in assignment:
469
+ continue
470
+ db = assignment[b]
471
+ if da == db:
472
+ continue
473
+ connector_counts[(da, db)] += 1
474
+ return connector_counts
475
+
476
+ def _district_weights(
477
+ self,
478
+ city_graph: CityGraph,
479
+ district_data: DistrictData,
480
+ district_features: dict[str, dict[str, float]],
481
+ scenario: ScenarioPlan,
482
+ ) -> tuple[dict[str, float], dict[str, float]]:
483
+ production: dict[str, float] = {}
484
+ attraction: dict[str, float] = {}
485
+ intensity = self.INTENSITY_SCALE.get(scenario.intensity, 1.0)
486
+
487
+ for did, district in district_data.districts.items():
488
+ feature = district_features[did]
489
+ size_factor = max(0.85, min(2.1, math.sqrt(feature["size"]) / 2.0))
490
+ connector_factor = 1.0 + min(1.6, feature["exits"] / 7.0)
491
+ base_prod = self.DISTRICT_BASE_PRODUCTION[district.district_type]
492
+ base_attr = self.DISTRICT_BASE_ATTRACTION[district.district_type]
493
+ production[did] = base_prod * size_factor * (0.60 + 0.40 * connector_factor)
494
+ attraction[did] = base_attr * size_factor * (0.62 + 0.38 * connector_factor)
495
+
496
+ if scenario.name == "morning_rush":
497
+ self._scale_by_type(district_data, production, "residential", 3.0 * intensity)
498
+ self._scale_by_type(district_data, production, "mixed", 1.3 * intensity)
499
+ self._scale_by_type(district_data, attraction, "commercial", 3.2 * intensity)
500
+ self._scale_by_type(district_data, attraction, "industrial", 2.8 * intensity)
501
+ self._scale_by_type(district_data, attraction, "residential", 0.58)
502
+ elif scenario.name == "evening_rush":
503
+ self._scale_by_type(district_data, production, "commercial", 3.1 * intensity)
504
+ self._scale_by_type(district_data, production, "industrial", 2.7 * intensity)
505
+ self._scale_by_type(district_data, attraction, "residential", 3.0 * intensity)
506
+ self._scale_by_type(district_data, attraction, "commercial", 0.62)
507
+ elif scenario.name == "event_spike" and scenario.event_district:
508
+ attraction[scenario.event_district] *= 3.8 * intensity
509
+ production[scenario.event_district] *= 1.9 * intensity
510
+ elif scenario.name == "district_overload" and scenario.overload_district:
511
+ production[scenario.overload_district] *= 3.2 * intensity
512
+ attraction[scenario.overload_district] *= 3.0 * intensity
513
+
514
+ if scenario.name in {"accident", "construction"}:
515
+ impacted = self._impacted_districts(city_graph, district_data, scenario)
516
+ for did in impacted:
517
+ attraction[did] *= 1.6 * intensity
518
+ production[did] *= 1.45 * intensity
519
+ for neighbor in district_data.district_neighbors.get(did, []):
520
+ attraction[neighbor] *= 1.18
521
+ production[neighbor] *= 1.16
522
+
523
+ return production, attraction
524
+
525
+ def _impacted_districts(
526
+ self,
527
+ city_graph: CityGraph,
528
+ district_data: DistrictData,
529
+ scenario: ScenarioPlan,
530
+ ) -> set[str]:
531
+ impacted: set[str] = set()
532
+ assignment = district_data.intersection_to_district
533
+ for road_id in set(scenario.blocked_roads) | set(scenario.penalized_roads.keys()):
534
+ road = city_graph.directed_roads.get(road_id)
535
+ if road is None:
536
+ continue
537
+ if road.start_intersection in assignment:
538
+ impacted.add(assignment[road.start_intersection])
539
+ if road.end_intersection in assignment:
540
+ impacted.add(assignment[road.end_intersection])
541
+ return impacted
542
+
543
+ def _scale_by_type(
544
+ self,
545
+ district_data: DistrictData,
546
+ weights: dict[str, float],
547
+ district_type: str,
548
+ factor: float,
549
+ ) -> None:
550
+ for did, district in district_data.districts.items():
551
+ if district.district_type == district_type:
552
+ weights[did] *= factor
553
+
554
+ def _sample_trip_category(self, rng: random.Random, trip_mix: Any) -> str:
555
+ labels = ["intra", "adjacent", "long"]
556
+ weights = [
557
+ trip_mix.intra_district,
558
+ trip_mix.adjacent_district,
559
+ trip_mix.long_distance,
560
+ ]
561
+ return choose_weighted(rng, labels, weights)
562
+
563
+ def _sample_origin_district(
564
+ self,
565
+ rng: random.Random,
566
+ district_ids: list[str],
567
+ weights: dict[str, float],
568
+ ) -> str:
569
+ values = district_ids
570
+ scalar = [weights[d] for d in district_ids]
571
+ return choose_weighted(rng, values, scalar)
572
+
573
+ def _sample_destination_district(
574
+ self,
575
+ rng: random.Random,
576
+ city_graph: CityGraph,
577
+ origin_district: str,
578
+ category: str,
579
+ district_data: DistrictData,
580
+ district_features: dict[str, dict[str, float]],
581
+ district_ids: list[str],
582
+ attraction_weights: dict[str, float],
583
+ connector_counts: dict[tuple[str, str], int],
584
+ ) -> str:
585
+ if category == "intra":
586
+ return origin_district
587
+
588
+ if category == "adjacent":
589
+ neighbors = district_data.district_neighbors.get(origin_district, [])
590
+ if neighbors:
591
+ weights = []
592
+ for neighbor in neighbors:
593
+ connector = 1.0 + connector_counts.get((origin_district, neighbor), 0)
594
+ weights.append(attraction_weights[neighbor] * connector)
595
+ return choose_weighted(rng, neighbors, weights)
596
+
597
+ origin_feature = district_features[origin_district]
598
+ candidates = [d for d in district_ids if d != origin_district]
599
+ if category == "long":
600
+ candidates = [
601
+ did
602
+ for did in candidates
603
+ if did not in district_data.district_neighbors.get(origin_district, [])
604
+ ] or [d for d in district_ids if d != origin_district]
605
+
606
+ weights: list[float] = []
607
+ for candidate in candidates:
608
+ feature = district_features[candidate]
609
+ dx = feature["cx"] - origin_feature["cx"]
610
+ dy = feature["cy"] - origin_feature["cy"]
611
+ distance = math.hypot(dx, dy)
612
+ normalized_distance = max(1.0, distance / 260.0)
613
+ corridor_bonus = 1.0 + min(
614
+ 1.5,
615
+ (
616
+ feature["exits"] + feature["neighbors"]
617
+ ) / 10.0,
618
+ )
619
+ if category == "long":
620
+ weight = attraction_weights[candidate] * normalized_distance * corridor_bonus
621
+ else:
622
+ weight = attraction_weights[candidate] * (0.85 + 0.15 * corridor_bonus)
623
+ weights.append(weight)
624
+ return choose_weighted(rng, candidates, weights)
625
+
626
+ def _sample_intersection_in_district(
627
+ self,
628
+ rng: random.Random,
629
+ district_data: DistrictData,
630
+ district_id: str,
631
+ favor_boundary: bool,
632
+ excluded_nodes: set[str],
633
+ ) -> str:
634
+ district = district_data.districts[district_id]
635
+ values = [node for node in district.intersections if node not in excluded_nodes]
636
+ if not values:
637
+ values = district.intersections
638
+ boundary = set(district.boundary_intersections)
639
+ weights: list[float] = []
640
+ for node in values:
641
+ if node in boundary:
642
+ weights.append(2.3 if favor_boundary else 1.15)
643
+ else:
644
+ weights.append(0.95 if favor_boundary else 1.2)
645
+ return choose_weighted(rng, values, weights)
646
+
647
+ def _sample_departure(
648
+ self,
649
+ rng: random.Random,
650
+ scenario: ScenarioPlan,
651
+ simulation_steps: int,
652
+ ) -> int:
653
+ windows = scenario.departure_windows
654
+ labels = list(range(len(windows)))
655
+ weights = [w for _, _, w in windows]
656
+ selected = choose_weighted(rng, [str(i) for i in labels], weights)
657
+ window = windows[int(selected)]
658
+ start = int(window[0] * simulation_steps)
659
+ end = int(window[1] * simulation_steps)
660
+ if end <= start:
661
+ end = start + 1
662
+ return rng.randint(start, max(start, end - 1))
data/generators/generate_dataset.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """CLI entrypoint for synthetic CityFlow dataset generation."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ from pathlib import Path
7
+
8
+ from .city_generator import CityGenerator
9
+ from .schemas import DatasetGenerationConfig, DemandIntensity
10
+
11
+ DEFAULT_TOPOLOGIES = ["irregular_grid"]
12
+ DEFAULT_SCENARIOS = [
13
+ "normal",
14
+ "morning_rush",
15
+ "evening_rush",
16
+ "accident",
17
+ "construction",
18
+ "event_spike",
19
+ "district_overload",
20
+ ]
21
+ DEFAULT_INTENSITY_LEVELS: list[DemandIntensity] = [
22
+ "normal",
23
+ "moderate_rush",
24
+ "heavy_rush",
25
+ "overload",
26
+ "accident_overload",
27
+ ]
28
+ DEFAULT_INTENSITY_DISTRIBUTION: dict[DemandIntensity, float] = {
29
+ "normal": 0.20,
30
+ "moderate_rush": 0.42,
31
+ "heavy_rush": 0.24,
32
+ "overload": 0.10,
33
+ "accident_overload": 0.04,
34
+ }
35
+ DEFAULT_SCENARIO_DEMAND_MULTIPLIERS: dict[str, float] = {
36
+ "normal": 1.15,
37
+ "morning_rush": 1.35,
38
+ "evening_rush": 1.35,
39
+ "accident": 1.75,
40
+ "construction": 1.55,
41
+ "event_spike": 1.65,
42
+ "district_overload": 1.70,
43
+ }
44
+
45
+
46
+ def _parse_csv_list(raw: str | None) -> list[str] | None:
47
+ if raw is None:
48
+ return None
49
+ values = [part.strip() for part in raw.split(",")]
50
+ return [v for v in values if v]
51
+
52
+
53
+ def _parse_key_value_floats(raw: str | None) -> dict[str, float]:
54
+ if not raw:
55
+ return {}
56
+ result: dict[str, float] = {}
57
+ for token in raw.split(","):
58
+ token = token.strip()
59
+ if not token:
60
+ continue
61
+ if "=" not in token:
62
+ raise ValueError(f"Expected key=value pair, got '{token}'.")
63
+ key, value = token.split("=", 1)
64
+ result[key.strip()] = float(value.strip())
65
+ return result
66
+
67
+
68
+ def build_parser() -> argparse.ArgumentParser:
69
+ parser = argparse.ArgumentParser(
70
+ description="Generate synthetic CityFlow cities with district-aware scenarios."
71
+ )
72
+ parser.add_argument("--num-cities", type=int, default=100)
73
+ parser.add_argument("--output-dir", type=Path, default=Path("data/generated"))
74
+ parser.add_argument("--seed", type=int, default=42)
75
+ parser.add_argument("--min-districts", type=int, default=6)
76
+ parser.add_argument("--max-districts", type=int, default=20)
77
+ parser.add_argument("--min-intersections-per-district", type=int, default=4)
78
+ parser.add_argument("--max-intersections-per-district", type=int, default=10)
79
+ parser.add_argument(
80
+ "--topologies",
81
+ type=str,
82
+ default=None,
83
+ help="Comma-separated list of topologies.",
84
+ )
85
+ parser.add_argument(
86
+ "--scenarios",
87
+ type=str,
88
+ default=None,
89
+ help="Comma-separated list of scenarios.",
90
+ )
91
+ parser.add_argument("--simulation-steps", type=int, default=3600)
92
+ parser.add_argument("--interval", type=float, default=1.0)
93
+ parser.add_argument(
94
+ "--intensity-levels",
95
+ type=str,
96
+ default=None,
97
+ help="Comma-separated intensity levels.",
98
+ )
99
+ parser.add_argument(
100
+ "--intensity-distribution",
101
+ type=str,
102
+ default=None,
103
+ help="Comma-separated key=value weights for intensities.",
104
+ )
105
+ parser.add_argument(
106
+ "--global-demand-multiplier",
107
+ type=float,
108
+ default=1.25,
109
+ help="Global demand multiplier across all scenarios.",
110
+ )
111
+ parser.add_argument(
112
+ "--scenario-demand-multipliers",
113
+ type=str,
114
+ default=None,
115
+ help="Comma-separated key=value multipliers by scenario name.",
116
+ )
117
+ parser.add_argument(
118
+ "--ring-diagonal-keep-prob",
119
+ type=float,
120
+ default=0.07,
121
+ help="Keep probability for optional ring-road interior diagonals.",
122
+ )
123
+ parser.add_argument(
124
+ "--ring-max-diagonal-fraction",
125
+ type=float,
126
+ default=0.03,
127
+ help="Maximum fraction of optional diagonals retained in ring-road topology.",
128
+ )
129
+ parser.add_argument("--save-replay", action="store_true")
130
+ parser.add_argument("--fail-fast", action="store_true")
131
+ return parser
132
+
133
+
134
+ def main() -> None:
135
+ parser = build_parser()
136
+ args = parser.parse_args()
137
+ topologies = _parse_csv_list(args.topologies)
138
+ scenarios = _parse_csv_list(args.scenarios)
139
+ intensity_levels = _parse_csv_list(args.intensity_levels)
140
+ intensity_distribution = _parse_key_value_floats(args.intensity_distribution)
141
+ scenario_demand_multipliers = _parse_key_value_floats(
142
+ args.scenario_demand_multipliers
143
+ )
144
+ config = DatasetGenerationConfig(
145
+ num_cities=args.num_cities,
146
+ output_dir=args.output_dir,
147
+ seed=args.seed,
148
+ min_districts=args.min_districts,
149
+ max_districts=args.max_districts,
150
+ min_intersections_per_district=args.min_intersections_per_district,
151
+ max_intersections_per_district=args.max_intersections_per_district,
152
+ topologies=topologies if topologies is not None else DEFAULT_TOPOLOGIES,
153
+ scenarios=scenarios if scenarios is not None else DEFAULT_SCENARIOS,
154
+ intensity_levels=(
155
+ intensity_levels if intensity_levels is not None else DEFAULT_INTENSITY_LEVELS
156
+ ),
157
+ intensity_distribution=(
158
+ intensity_distribution
159
+ if intensity_distribution
160
+ else DEFAULT_INTENSITY_DISTRIBUTION
161
+ ),
162
+ global_demand_multiplier=args.global_demand_multiplier,
163
+ scenario_demand_multipliers=(
164
+ scenario_demand_multipliers
165
+ if scenario_demand_multipliers
166
+ else DEFAULT_SCENARIO_DEMAND_MULTIPLIERS
167
+ ),
168
+ ring_diagonal_keep_prob=args.ring_diagonal_keep_prob,
169
+ ring_max_diagonal_fraction=args.ring_max_diagonal_fraction,
170
+ simulation_steps=args.simulation_steps,
171
+ interval=args.interval,
172
+ save_replay=args.save_replay,
173
+ fail_fast=args.fail_fast,
174
+ )
175
+ CityGenerator().generate_dataset(config)
176
+
177
+
178
+ if __name__ == "__main__":
179
+ main()
data/generators/main.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Preconfigured dataset generation harness (no CLI arguments)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import secrets
6
+ from pathlib import Path
7
+
8
+ from .city_generator import CityGenerator
9
+ from .schemas import DatasetGenerationConfig
10
+
11
+
12
+ # -----------------------------------------------------------------------------
13
+ # Edit these values to control the default batch generation run.
14
+ # -----------------------------------------------------------------------------
15
+ NUM_CITIES: int = 100
16
+ OUTPUT_DIR: Path = Path("data/generated")
17
+ BASE_SEED: int | None = 42 # set to an int for reproducible runs
18
+ TOPOLOGIES: list[str] = [
19
+ "rectangular_grid",
20
+ "irregular_grid",
21
+ "arterial_local",
22
+ "ring_road",
23
+ "mixed",
24
+ ]
25
+ MIN_DISTRICTS: int = 6
26
+ MAX_DISTRICTS: int = 20
27
+ MIN_INTERSECTIONS_PER_DISTRICT: int = 4
28
+ MAX_INTERSECTIONS_PER_DISTRICT: int = 10
29
+ SIMULATION_STEPS: int = 3600
30
+ INTERVAL: float = 1.0
31
+ FAIL_FAST: bool = False
32
+
33
+
34
+ def main() -> None:
35
+ """Run deterministic-configured dataset generation with pre-set defaults."""
36
+ base_seed = BASE_SEED if BASE_SEED is not None else secrets.randbits(63)
37
+
38
+ config = DatasetGenerationConfig(
39
+ num_cities=NUM_CITIES,
40
+ output_dir=OUTPUT_DIR,
41
+ seed=base_seed,
42
+ topologies=TOPOLOGIES,
43
+ min_districts=MIN_DISTRICTS,
44
+ max_districts=MAX_DISTRICTS,
45
+ min_intersections_per_district=MIN_INTERSECTIONS_PER_DISTRICT,
46
+ max_intersections_per_district=MAX_INTERSECTIONS_PER_DISTRICT,
47
+ simulation_steps=SIMULATION_STEPS,
48
+ interval=INTERVAL,
49
+ fail_fast=FAIL_FAST,
50
+ )
51
+
52
+ print(f"Generating {config.num_cities} cities into {config.output_dir}")
53
+ print(f"Base seed: {config.seed}")
54
+ print(f"Topologies: {', '.join(TOPOLOGIES)}")
55
+
56
+ CityGenerator().generate_dataset(config)
57
+
58
+
59
+ if __name__ == "__main__":
60
+ main()
data/generators/roadnet_generator.py ADDED
@@ -0,0 +1,1034 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Road network topology generation and CityFlow roadnet assembly."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import random
7
+ from collections import defaultdict
8
+ from itertools import product
9
+
10
+ from .schemas import CityGraph, RoadRecord, TopologyType
11
+ from .utils import euclidean
12
+
13
+
14
+ class RoadnetGenerator:
15
+ """Generate city intersection graph and convert to CityFlow roadnet format."""
16
+
17
+ def generate(
18
+ self,
19
+ city_id: str,
20
+ seed: int,
21
+ topology: TopologyType,
22
+ target_intersections: int,
23
+ ring_diagonal_keep_prob: float = 0.07,
24
+ ring_max_diagonal_fraction: float = 0.03,
25
+ ) -> CityGraph:
26
+ rng = random.Random(seed)
27
+ coords, undirected_edges, arterial_pairs = self._build_topology(
28
+ topology=topology,
29
+ target_nodes=target_intersections,
30
+ rng=rng,
31
+ ring_diagonal_keep_prob=ring_diagonal_keep_prob,
32
+ ring_max_diagonal_fraction=ring_max_diagonal_fraction,
33
+ )
34
+ (
35
+ coords,
36
+ undirected_edges,
37
+ gateway_pairs,
38
+ gateway_nodes,
39
+ ) = self._augment_with_perimeter_gateways(
40
+ coords=coords,
41
+ undirected_edges=undirected_edges,
42
+ rng=rng,
43
+ )
44
+ adjacency = self._to_adjacency(coords, undirected_edges)
45
+ directed_roads, arterial_road_ids, gateway_road_ids = self._build_directed_roads(
46
+ coords=coords,
47
+ undirected_edges=undirected_edges,
48
+ arterial_pairs=arterial_pairs,
49
+ gateway_pairs=gateway_pairs,
50
+ )
51
+ roadnet = self._build_roadnet(
52
+ coords=coords,
53
+ adjacency=adjacency,
54
+ directed_roads=directed_roads,
55
+ )
56
+ return CityGraph(
57
+ city_id=city_id,
58
+ topology=topology,
59
+ seed=seed,
60
+ intersections=coords,
61
+ adjacency=adjacency,
62
+ directed_roads=directed_roads,
63
+ roadnet=roadnet,
64
+ arterial_roads=arterial_road_ids,
65
+ gateway_intersections=gateway_nodes,
66
+ gateway_roads=gateway_road_ids,
67
+ )
68
+
69
+ def _build_topology(
70
+ self,
71
+ topology: TopologyType,
72
+ target_nodes: int,
73
+ rng: random.Random,
74
+ ring_diagonal_keep_prob: float,
75
+ ring_max_diagonal_fraction: float,
76
+ ) -> tuple[
77
+ dict[str, tuple[float, float]],
78
+ list[tuple[str, str]],
79
+ set[frozenset[str]],
80
+ ]:
81
+ if topology == "rectangular_grid":
82
+ return self._rectangular_grid(target_nodes, rng)
83
+ if topology == "irregular_grid":
84
+ return self._irregular_grid(target_nodes, rng)
85
+ if topology == "arterial_local":
86
+ return self._arterial_local(target_nodes, rng)
87
+ if topology == "ring_road":
88
+ return self._ring_road(
89
+ target_nodes=target_nodes,
90
+ rng=rng,
91
+ ring_diagonal_keep_prob=ring_diagonal_keep_prob,
92
+ ring_max_diagonal_fraction=ring_max_diagonal_fraction,
93
+ )
94
+ return self._mixed(target_nodes, rng)
95
+
96
+ def _dimensions(self, target_nodes: int) -> tuple[int, int]:
97
+ cols = max(3, int(round(math.sqrt(target_nodes))))
98
+ rows = max(3, int(math.ceil(target_nodes / cols)))
99
+ return rows, cols
100
+
101
+ def _grid_coords(
102
+ self,
103
+ rows: int,
104
+ cols: int,
105
+ spacing: float,
106
+ jitter: float,
107
+ rng: random.Random,
108
+ ) -> dict[str, tuple[float, float]]:
109
+ coords: dict[str, tuple[float, float]] = {}
110
+ idx = 0
111
+ for r, c in product(range(rows), range(cols)):
112
+ x = c * spacing + rng.uniform(-jitter, jitter)
113
+ y = r * spacing + rng.uniform(-jitter, jitter)
114
+ coords[f"i_{idx:04d}"] = (x, y)
115
+ idx += 1
116
+ return coords
117
+
118
+ def _smooth_axis_offsets(
119
+ self,
120
+ size: int,
121
+ max_offset: float,
122
+ step_limit: float,
123
+ rng: random.Random,
124
+ ) -> list[float]:
125
+ offsets = [0.0] * size
126
+ drift = 0.0
127
+ for idx in range(1, size - 1):
128
+ drift += rng.uniform(-step_limit, step_limit)
129
+ drift = max(-max_offset, min(max_offset, drift))
130
+ offsets[idx] = drift
131
+
132
+ if size > 2:
133
+ mean_mid = sum(offsets[1:-1]) / (size - 2)
134
+ for idx in range(1, size - 1):
135
+ centered = offsets[idx] - mean_mid
136
+ offsets[idx] = max(-max_offset, min(max_offset, centered))
137
+
138
+ offsets[0] = 0.0
139
+ offsets[-1] = 0.0
140
+ return offsets
141
+
142
+ def _boundary_stability_weight(self, idx: int, size: int) -> float:
143
+ distance = min(idx, size - 1 - idx)
144
+ if distance <= 0:
145
+ return 0.0
146
+ if distance == 1:
147
+ return 0.2
148
+ if distance == 2:
149
+ return 0.45
150
+ if distance == 3:
151
+ return 0.72
152
+ return 1.00
153
+
154
+ def _axis_positions(
155
+ self,
156
+ size: int,
157
+ spacing: float,
158
+ gap_variation: float,
159
+ rng: random.Random,
160
+ ) -> list[float]:
161
+ if size <= 1:
162
+ return [0.0]
163
+
164
+ gap_profile = self._smooth_drop_profile(size - 1, rng)
165
+ positions = [0.0]
166
+ min_gap = spacing * (1.0 - gap_variation)
167
+ max_gap = spacing * (1.0 + gap_variation)
168
+
169
+ for gap_idx in range(size - 1):
170
+ centered = (gap_profile[gap_idx] - 0.5) * 2.0
171
+ edge_distance = min(gap_idx, (size - 2) - gap_idx)
172
+ if edge_distance <= 0:
173
+ edge_weight = 0.32
174
+ elif edge_distance == 1:
175
+ edge_weight = 0.55
176
+ elif edge_distance == 2:
177
+ edge_weight = 0.8
178
+ else:
179
+ edge_weight = 1.0
180
+
181
+ local_noise = rng.uniform(-0.22, 0.22) * gap_variation * edge_weight
182
+ gap = spacing * (1.0 + (centered * gap_variation * edge_weight) + local_noise)
183
+ gap = max(min_gap, min(max_gap, gap))
184
+ positions.append(positions[-1] + gap)
185
+
186
+ nominal_span = spacing * (size - 1)
187
+ actual_span = positions[-1]
188
+ if actual_span > 1e-9:
189
+ scale = nominal_span / actual_span
190
+ positions = [p * scale for p in positions]
191
+ return positions
192
+
193
+ def _irregular_grid_coords(
194
+ self,
195
+ rows: int,
196
+ cols: int,
197
+ spacing: float,
198
+ rng: random.Random,
199
+ ) -> dict[str, tuple[float, float]]:
200
+ # Keep intersections mostly on row/column lines while varying block size.
201
+ row_positions = self._axis_positions(rows, spacing, gap_variation=0.16, rng=rng)
202
+ col_positions = self._axis_positions(cols, spacing, gap_variation=0.16, rng=rng)
203
+
204
+ # Small line-wise drift and very small local jitter.
205
+ max_row_offset = spacing * 0.018
206
+ max_col_offset = spacing * 0.018
207
+ row_step = spacing * 0.006
208
+ col_step = spacing * 0.006
209
+ local_jitter = spacing * 0.0045
210
+
211
+ row_offsets = self._smooth_axis_offsets(rows, max_row_offset, row_step, rng)
212
+ col_offsets = self._smooth_axis_offsets(cols, max_col_offset, col_step, rng)
213
+
214
+ coords: dict[str, tuple[float, float]] = {}
215
+ idx = 0
216
+ for r, c in product(range(rows), range(cols)):
217
+ # Keep perimeter nodes stable while allowing interior irregularity.
218
+ perimeter_weight = min(
219
+ self._boundary_stability_weight(r, rows),
220
+ self._boundary_stability_weight(c, cols),
221
+ )
222
+ jitter_weight = perimeter_weight * perimeter_weight
223
+ x = col_positions[c] + (col_offsets[c] * perimeter_weight)
224
+ y = row_positions[r] + (row_offsets[r] * perimeter_weight)
225
+ x += rng.uniform(-local_jitter, local_jitter) * jitter_weight
226
+ y += rng.uniform(-local_jitter, local_jitter) * jitter_weight
227
+ coords[f"i_{idx:04d}"] = (x, y)
228
+ idx += 1
229
+ return coords
230
+
231
+ def _grid_edges(self, rows: int, cols: int) -> list[tuple[int, int]]:
232
+ edges: list[tuple[int, int]] = []
233
+ for r in range(rows):
234
+ for c in range(cols):
235
+ idx = r * cols + c
236
+ if c + 1 < cols:
237
+ edges.append((idx, idx + 1))
238
+ if r + 1 < rows:
239
+ edges.append((idx, idx + cols))
240
+ return edges
241
+
242
+ def _estimate_spacing(
243
+ self,
244
+ coords: dict[str, tuple[float, float]],
245
+ undirected_edges: list[tuple[str, str]],
246
+ ) -> float:
247
+ if not undirected_edges:
248
+ return 120.0
249
+ lengths = [
250
+ euclidean(coords[a], coords[b])
251
+ for a, b in undirected_edges
252
+ if a in coords and b in coords
253
+ ]
254
+ if not lengths:
255
+ return 120.0
256
+ lengths.sort()
257
+ return lengths[len(lengths) // 2]
258
+
259
+ def _select_spread_nodes(
260
+ self,
261
+ candidates: list[str],
262
+ coords: dict[str, tuple[float, float]],
263
+ count: int,
264
+ axis: str,
265
+ ) -> list[str]:
266
+ if count <= 0 or not candidates:
267
+ return []
268
+ if len(candidates) <= count:
269
+ return candidates
270
+ axis_idx = 0 if axis == "x" else 1
271
+ ordered = sorted(
272
+ candidates,
273
+ key=lambda nid: (coords[nid][axis_idx], nid),
274
+ )
275
+ selected: list[str] = []
276
+ for i in range(count):
277
+ pos = int(round((i * (len(ordered) - 1)) / max(1, count - 1)))
278
+ selected.append(ordered[pos])
279
+ deduped = sorted(set(selected), key=lambda nid: ordered.index(nid))
280
+ if len(deduped) >= count:
281
+ return deduped[:count]
282
+ for node in ordered:
283
+ if node in deduped:
284
+ continue
285
+ deduped.append(node)
286
+ if len(deduped) >= count:
287
+ break
288
+ return deduped
289
+
290
+ def _augment_with_perimeter_gateways(
291
+ self,
292
+ coords: dict[str, tuple[float, float]],
293
+ undirected_edges: list[tuple[str, str]],
294
+ rng: random.Random,
295
+ ) -> tuple[
296
+ dict[str, tuple[float, float]],
297
+ list[tuple[str, str]],
298
+ set[frozenset[str]],
299
+ set[str],
300
+ ]:
301
+ if not coords:
302
+ return coords, undirected_edges, set(), set()
303
+
304
+ min_x = min(x for x, _ in coords.values())
305
+ max_x = max(x for x, _ in coords.values())
306
+ min_y = min(y for _, y in coords.values())
307
+ max_y = max(y for _, y in coords.values())
308
+ spacing = self._estimate_spacing(coords, undirected_edges)
309
+ threshold = max(4.0, spacing * 0.08)
310
+ per_side_target = 2
311
+
312
+ side_to_candidates: dict[str, list[str]] = {
313
+ "west": [],
314
+ "east": [],
315
+ "south": [],
316
+ "north": [],
317
+ }
318
+ for node_id, (x, y) in coords.items():
319
+ distances = {
320
+ "west": abs(x - min_x),
321
+ "east": abs(x - max_x),
322
+ "south": abs(y - min_y),
323
+ "north": abs(y - max_y),
324
+ }
325
+ side = min(distances, key=distances.get)
326
+ if distances[side] <= threshold:
327
+ side_to_candidates[side].append(node_id)
328
+
329
+ selected_anchors: list[tuple[str, str]] = []
330
+ used_anchors: set[str] = set()
331
+ for side in ("west", "east", "south", "north"):
332
+ candidates = [n for n in side_to_candidates[side] if n not in used_anchors]
333
+ axis = "y" if side in {"west", "east"} else "x"
334
+ chosen = self._select_spread_nodes(
335
+ candidates=candidates,
336
+ coords=coords,
337
+ count=per_side_target,
338
+ axis=axis,
339
+ )
340
+ for anchor in chosen:
341
+ if anchor in used_anchors:
342
+ continue
343
+ used_anchors.add(anchor)
344
+ selected_anchors.append((side, anchor))
345
+
346
+ if not selected_anchors:
347
+ return coords, undirected_edges, set(), set()
348
+
349
+ offset = max(45.0, spacing * 0.82)
350
+ gateway_pairs: set[frozenset[str]] = set()
351
+ gateway_nodes: set[str] = set()
352
+ next_idx = 0
353
+ for side, anchor in selected_anchors:
354
+ ax, ay = coords[anchor]
355
+ if side == "west":
356
+ gx, gy = min_x - offset, ay + rng.uniform(-spacing * 0.03, spacing * 0.03)
357
+ elif side == "east":
358
+ gx, gy = max_x + offset, ay + rng.uniform(-spacing * 0.03, spacing * 0.03)
359
+ elif side == "south":
360
+ gx, gy = ax + rng.uniform(-spacing * 0.03, spacing * 0.03), min_y - offset
361
+ else:
362
+ gx, gy = ax + rng.uniform(-spacing * 0.03, spacing * 0.03), max_y + offset
363
+
364
+ gateway_id = f"g_{next_idx:04d}"
365
+ next_idx += 1
366
+ coords[gateway_id] = (gx, gy)
367
+ undirected_edges.append((anchor, gateway_id))
368
+ gateway_pairs.add(frozenset((anchor, gateway_id)))
369
+ gateway_nodes.add(gateway_id)
370
+
371
+ return coords, undirected_edges, gateway_pairs, gateway_nodes
372
+
373
+ def _arterial_indices(self, size: int) -> list[int]:
374
+ candidates = {size // 3, size // 2, (2 * size) // 3}
375
+ selected = sorted(i for i in candidates if 0 < i < size - 1)
376
+ if not selected and size > 2:
377
+ selected = [size // 2]
378
+ return selected
379
+
380
+ def _smooth_drop_profile(self, size: int, rng: random.Random) -> list[float]:
381
+ values: list[float] = []
382
+ state = rng.uniform(-0.25, 0.25)
383
+ for _ in range(size):
384
+ state = 0.78 * state + 0.22 * rng.uniform(-1.0, 1.0)
385
+ values.append(state)
386
+ low = min(values)
387
+ high = max(values)
388
+ if abs(high - low) < 1e-6:
389
+ return [0.5] * size
390
+ return [(value - low) / (high - low) for value in values]
391
+
392
+ def _edge_orientation(
393
+ self,
394
+ a: int,
395
+ b: int,
396
+ cols: int,
397
+ ) -> tuple[str, int, int]:
398
+ ra, ca = divmod(a, cols)
399
+ rb, cb = divmod(b, cols)
400
+ if ra == rb:
401
+ return ("horizontal", ra, min(ca, cb))
402
+ return ("vertical", ca, min(ra, rb))
403
+
404
+ def _is_edge_protected(
405
+ self,
406
+ a: int,
407
+ b: int,
408
+ rows: int,
409
+ cols: int,
410
+ arterial_rows: set[int],
411
+ arterial_cols: set[int],
412
+ ) -> tuple[bool, bool]:
413
+ ra, ca = divmod(a, cols)
414
+ rb, cb = divmod(b, cols)
415
+ if ra == rb:
416
+ on_perimeter = ra in {0, rows - 1}
417
+ on_arterial = ra in arterial_rows
418
+ else:
419
+ on_perimeter = ca in {0, cols - 1}
420
+ on_arterial = ca in arterial_cols
421
+ return on_perimeter or on_arterial, on_arterial
422
+
423
+ def _has_path_without_edge(
424
+ self,
425
+ start: int,
426
+ goal: int,
427
+ adjacency: dict[int, set[int]],
428
+ edge: tuple[int, int],
429
+ ) -> bool:
430
+ blocked_u, blocked_v = edge
431
+ stack = [start]
432
+ visited = {start}
433
+ while stack:
434
+ node = stack.pop()
435
+ if node == goal:
436
+ return True
437
+ for nxt in adjacency[node]:
438
+ if (
439
+ (node == blocked_u and nxt == blocked_v)
440
+ or (node == blocked_v and nxt == blocked_u)
441
+ ):
442
+ continue
443
+ if nxt in visited:
444
+ continue
445
+ visited.add(nxt)
446
+ stack.append(nxt)
447
+ return False
448
+
449
+ def _is_connected_without_str_edge(
450
+ self,
451
+ start: str,
452
+ goal: str,
453
+ adjacency: dict[str, set[str]],
454
+ edge: tuple[str, str],
455
+ ) -> bool:
456
+ blocked_u, blocked_v = edge
457
+ stack = [start]
458
+ visited = {start}
459
+ while stack:
460
+ node = stack.pop()
461
+ if node == goal:
462
+ return True
463
+ for nxt in adjacency[node]:
464
+ if (
465
+ (node == blocked_u and nxt == blocked_v)
466
+ or (node == blocked_v and nxt == blocked_u)
467
+ ):
468
+ continue
469
+ if nxt in visited:
470
+ continue
471
+ visited.add(nxt)
472
+ stack.append(nxt)
473
+ return False
474
+
475
+ def _rectangular_grid(
476
+ self, target_nodes: int, rng: random.Random
477
+ ) -> tuple[dict[str, tuple[float, float]], list[tuple[str, str]], set[frozenset[str]]]:
478
+ rows, cols = self._dimensions(target_nodes)
479
+ coords = self._grid_coords(rows, cols, spacing=120.0, jitter=6.0, rng=rng)
480
+ edges_raw = self._grid_edges(rows, cols)
481
+ id_list = list(coords.keys())
482
+ edges = [(id_list[a], id_list[b]) for a, b in edges_raw]
483
+ return coords, edges, set()
484
+
485
+ def _irregular_grid(
486
+ self, target_nodes: int, rng: random.Random
487
+ ) -> tuple[dict[str, tuple[float, float]], list[tuple[str, str]], set[frozenset[str]]]:
488
+ rows, cols = self._dimensions(int(target_nodes * 1.08))
489
+ coords = self._irregular_grid_coords(rows, cols, spacing=115.0, rng=rng)
490
+ edges_raw = self._grid_edges(rows, cols)
491
+ filtered: set[tuple[int, int]] = set(edges_raw)
492
+ adjacency: dict[int, set[int]] = defaultdict(set)
493
+ for a, b in edges_raw:
494
+ adjacency[a].add(b)
495
+ adjacency[b].add(a)
496
+
497
+ arterial_rows = set(self._arterial_indices(rows))
498
+ arterial_cols = set(self._arterial_indices(cols))
499
+ row_profile = self._smooth_drop_profile(rows, rng)
500
+ col_profile = self._smooth_drop_profile(cols, rng)
501
+ base_drop_prob = 0.11
502
+
503
+ interior_rows = [r for r in range(2, rows - 2) if r not in arterial_rows]
504
+ interior_cols = [c for c in range(2, cols - 2) if c not in arterial_cols]
505
+ row_gap_rows = set(
506
+ rng.sample(interior_rows, k=min(max(1, rows // 7), len(interior_rows)))
507
+ ) if interior_rows else set()
508
+ col_gap_cols = set(
509
+ rng.sample(interior_cols, k=min(max(1, cols // 7), len(interior_cols)))
510
+ ) if interior_cols else set()
511
+
512
+ arterial_pairs: set[frozenset[str]] = set()
513
+ removable: list[tuple[int, int]] = []
514
+ for a, b in edges_raw:
515
+ protected, arterial = self._is_edge_protected(
516
+ a=a,
517
+ b=b,
518
+ rows=rows,
519
+ cols=cols,
520
+ arterial_rows=arterial_rows,
521
+ arterial_cols=arterial_cols,
522
+ )
523
+ if arterial:
524
+ ida = f"i_{a:04d}"
525
+ idb = f"i_{b:04d}"
526
+ arterial_pairs.add(frozenset((ida, idb)))
527
+ if not protected:
528
+ removable.append((a, b))
529
+
530
+ rng.shuffle(removable)
531
+ for a, b in removable:
532
+ orientation, major_idx, minor_idx = self._edge_orientation(a, b, cols)
533
+ if orientation == "horizontal":
534
+ drop_prob = base_drop_prob + (0.11 * row_profile[major_idx]) + (
535
+ 0.08 * col_profile[minor_idx]
536
+ )
537
+ if major_idx in row_gap_rows:
538
+ drop_prob += 0.16
539
+ else:
540
+ drop_prob = base_drop_prob + (0.11 * col_profile[major_idx]) + (
541
+ 0.08 * row_profile[minor_idx]
542
+ )
543
+ if major_idx in col_gap_cols:
544
+ drop_prob += 0.16
545
+
546
+ # Keep perimeter-adjacent links denser to avoid sparse fringes.
547
+ ra, ca = divmod(a, cols)
548
+ rb, cb = divmod(b, cols)
549
+ boundary_distance = min(
550
+ ra,
551
+ rb,
552
+ rows - 1 - ra,
553
+ rows - 1 - rb,
554
+ ca,
555
+ cb,
556
+ cols - 1 - ca,
557
+ cols - 1 - cb,
558
+ )
559
+ if boundary_distance <= 1:
560
+ drop_prob *= 0.45
561
+ elif boundary_distance == 2:
562
+ drop_prob *= 0.75
563
+
564
+ if rng.random() > min(0.46, max(0.0, drop_prob)):
565
+ continue
566
+ if len(adjacency[a]) <= 2 or len(adjacency[b]) <= 2:
567
+ continue
568
+ if not self._has_path_without_edge(a, b, adjacency, (a, b)):
569
+ continue
570
+
571
+ adjacency[a].remove(b)
572
+ adjacency[b].remove(a)
573
+ filtered.remove((a, b))
574
+
575
+ id_list = list(coords.keys())
576
+ edges = [(id_list[a], id_list[b]) for a, b in sorted(filtered)]
577
+ return coords, edges, arterial_pairs
578
+
579
+ def _arterial_local(
580
+ self, target_nodes: int, rng: random.Random
581
+ ) -> tuple[dict[str, tuple[float, float]], list[tuple[str, str]], set[frozenset[str]]]:
582
+ rows, cols = self._dimensions(target_nodes)
583
+ coords = self._grid_coords(rows, cols, spacing=130.0, jitter=8.0, rng=rng)
584
+ edges_raw = self._grid_edges(rows, cols)
585
+ arterial_rows = {rows // 3, (2 * rows) // 3}
586
+ arterial_cols = {cols // 3, (2 * cols) // 3}
587
+ id_list = list(coords.keys())
588
+ edges: list[tuple[str, str]] = []
589
+ arterial_pairs: set[frozenset[str]] = set()
590
+
591
+ for a, b in edges_raw:
592
+ ra, ca = divmod(a, cols)
593
+ rb, cb = divmod(b, cols)
594
+ ida, idb = id_list[a], id_list[b]
595
+ edges.append((ida, idb))
596
+ if (
597
+ ra in arterial_rows
598
+ or rb in arterial_rows
599
+ or ca in arterial_cols
600
+ or cb in arterial_cols
601
+ ):
602
+ arterial_pairs.add(frozenset((ida, idb)))
603
+ elif rng.random() < 0.06:
604
+ arterial_pairs.add(frozenset((ida, idb)))
605
+ return coords, edges, arterial_pairs
606
+
607
+ def _ring_road(
608
+ self,
609
+ target_nodes: int,
610
+ rng: random.Random,
611
+ ring_diagonal_keep_prob: float,
612
+ ring_max_diagonal_fraction: float,
613
+ ) -> tuple[dict[str, tuple[float, float]], list[tuple[str, str]], set[frozenset[str]]]:
614
+ ring_nodes = max(14, int(target_nodes * 0.22))
615
+ inner_nodes = max(24, target_nodes - ring_nodes)
616
+ rows, cols = self._dimensions(inner_nodes)
617
+ inner_coords = self._grid_coords(rows, cols, spacing=90.0, jitter=10.0, rng=rng)
618
+ inner_ids = list(inner_coords.keys())
619
+ inner_index = {node: idx for idx, node in enumerate(inner_ids)}
620
+ min_x = min(x for x, _ in inner_coords.values())
621
+ max_x = max(x for x, _ in inner_coords.values())
622
+ min_y = min(y for _, y in inner_coords.values())
623
+ max_y = max(y for _, y in inner_coords.values())
624
+ center = ((min_x + max_x) / 2.0, (min_y + max_y) / 2.0)
625
+ radius = max(max_x - min_x, max_y - min_y) * 0.70
626
+
627
+ def norm_edge(a: str, b: str) -> tuple[str, str]:
628
+ return (a, b) if a < b else (b, a)
629
+
630
+ coords: dict[str, tuple[float, float]] = dict(inner_coords)
631
+ edge_set: set[tuple[str, str]] = set()
632
+ arterial_pairs: set[frozenset[str]] = set()
633
+
634
+ for a, b in self._grid_edges(rows, cols):
635
+ edge_set.add(norm_edge(inner_ids[a], inner_ids[b]))
636
+
637
+ ring_ids: list[str] = []
638
+ start_idx = len(coords)
639
+ for i in range(ring_nodes):
640
+ angle = (2.0 * math.pi * i) / ring_nodes
641
+ x = center[0] + radius * math.cos(angle)
642
+ y = center[1] + radius * math.sin(angle)
643
+ rid = f"i_{start_idx + i:04d}"
644
+ ring_ids.append(rid)
645
+ coords[rid] = (x, y)
646
+
647
+ for i, rid in enumerate(ring_ids):
648
+ nxt = ring_ids[(i + 1) % ring_nodes]
649
+ edge = norm_edge(rid, nxt)
650
+ edge_set.add(edge)
651
+ arterial_pairs.add(frozenset(edge))
652
+
653
+ boundary_inner_nodes = [
654
+ nid
655
+ for nid in inner_ids
656
+ if ((inner_index[nid] // cols) in {0, rows - 1})
657
+ or ((inner_index[nid] % cols) in {0, cols - 1})
658
+ ]
659
+ spokes = max(6, ring_nodes // 3)
660
+ anchor_pool = boundary_inner_nodes[:] if boundary_inner_nodes else inner_ids
661
+ anchor_ids = [
662
+ anchor_pool[(idx * len(anchor_pool)) // spokes]
663
+ for idx in range(spokes)
664
+ ]
665
+
666
+ protected_spokes: set[tuple[str, str]] = set()
667
+ for i in range(spokes):
668
+ ring_node = ring_ids[(i * ring_nodes) // spokes]
669
+ inner_node = anchor_ids[i]
670
+ edge = norm_edge(ring_node, inner_node)
671
+ edge_set.add(edge)
672
+ arterial_pairs.add(frozenset(edge))
673
+ protected_spokes.add(edge)
674
+
675
+ # Optional non-primary diagonals: sparse extra radials + sparse interior diagonals.
676
+ diagonal_candidates: list[tuple[tuple[str, str], float]] = []
677
+ extra_spokes = max(2, min(5, ring_nodes // 4))
678
+ for i in range(extra_spokes):
679
+ ring_node = ring_ids[(i * ring_nodes) // extra_spokes]
680
+ inner_node = anchor_pool[(i * len(anchor_pool)) // extra_spokes]
681
+ edge = norm_edge(ring_node, inner_node)
682
+ if edge in edge_set or edge in protected_spokes:
683
+ continue
684
+ score = 1.2 + (0.003 * euclidean(coords[ring_node], coords[inner_node]))
685
+ diagonal_candidates.append((edge, score))
686
+
687
+ def _orientation(
688
+ p: tuple[float, float],
689
+ q: tuple[float, float],
690
+ r: tuple[float, float],
691
+ ) -> int:
692
+ value = ((q[1] - p[1]) * (r[0] - q[0])) - ((q[0] - p[0]) * (r[1] - q[1]))
693
+ if abs(value) < 1e-9:
694
+ return 0
695
+ return 1 if value > 0 else 2
696
+
697
+ def _on_segment(
698
+ p: tuple[float, float],
699
+ q: tuple[float, float],
700
+ r: tuple[float, float],
701
+ ) -> bool:
702
+ return (
703
+ min(p[0], r[0]) <= q[0] <= max(p[0], r[0])
704
+ and min(p[1], r[1]) <= q[1] <= max(p[1], r[1])
705
+ )
706
+
707
+ def _segments_intersect(
708
+ p1: tuple[float, float],
709
+ q1: tuple[float, float],
710
+ p2: tuple[float, float],
711
+ q2: tuple[float, float],
712
+ ) -> bool:
713
+ o1 = _orientation(p1, q1, p2)
714
+ o2 = _orientation(p1, q1, q2)
715
+ o3 = _orientation(p2, q2, p1)
716
+ o4 = _orientation(p2, q2, q1)
717
+ if o1 != o2 and o3 != o4:
718
+ return True
719
+ if o1 == 0 and _on_segment(p1, p2, q1):
720
+ return True
721
+ if o2 == 0 and _on_segment(p1, q2, q1):
722
+ return True
723
+ if o3 == 0 and _on_segment(p2, p1, q2):
724
+ return True
725
+ if o4 == 0 and _on_segment(p2, q1, q2):
726
+ return True
727
+ return False
728
+
729
+ def _has_nonendpoint_intersection(
730
+ edge: tuple[str, str],
731
+ other_edges: set[tuple[str, str]],
732
+ ) -> bool:
733
+ a, b = edge
734
+ p1, q1 = coords[a], coords[b]
735
+ for u, v in other_edges:
736
+ # Shared endpoint is expected at valid junctions.
737
+ if len({a, b, u, v}) < 4:
738
+ continue
739
+ p2, q2 = coords[u], coords[v]
740
+ if _segments_intersect(p1, q1, p2, q2):
741
+ return True
742
+ return False
743
+
744
+ for r in range(rows - 1):
745
+ for c in range(cols - 1):
746
+ if rng.random() > 0.05:
747
+ continue
748
+ tl = inner_ids[(r * cols) + c]
749
+ tr = inner_ids[(r * cols) + c + 1]
750
+ bl = inner_ids[((r + 1) * cols) + c]
751
+ br = inner_ids[((r + 1) * cols) + c + 1]
752
+ a, b = (tl, br) if rng.random() < 0.5 else (tr, bl)
753
+ edge = norm_edge(a, b)
754
+ if edge in edge_set:
755
+ continue
756
+ boundary_bonus = 0.30 if (r in {0, rows - 2} or c in {0, cols - 2}) else 0.0
757
+ diagonal_candidates.append((edge, 1.0 + boundary_bonus))
758
+
759
+ keep_prob = max(0.0, min(1.0, ring_diagonal_keep_prob))
760
+ kept_candidates: list[tuple[tuple[str, str], float]] = []
761
+ geometry_edges = set(edge_set)
762
+ for edge, score in diagonal_candidates:
763
+ if rng.random() > keep_prob:
764
+ continue
765
+ if _has_nonendpoint_intersection(edge, geometry_edges):
766
+ continue
767
+ kept_candidates.append((edge, score))
768
+ geometry_edges.add(edge)
769
+ if diagonal_candidates and not kept_candidates:
770
+ for edge, score in sorted(diagonal_candidates, key=lambda item: item[1], reverse=True):
771
+ if not _has_nonendpoint_intersection(edge, edge_set):
772
+ kept_candidates.append((edge, score))
773
+ break
774
+
775
+ for edge, _ in kept_candidates:
776
+ edge_set.add(edge)
777
+
778
+ # Prune weaker optional diagonals while preserving connectivity.
779
+ max_fraction = max(0.0, min(1.0, ring_max_diagonal_fraction))
780
+ max_kept = max(1, int(round(max_fraction * max(1, len(diagonal_candidates)))))
781
+ hard_cap = max(1, min(3, len(inner_ids) // 60))
782
+ max_kept = min(max_kept, hard_cap)
783
+ if len(kept_candidates) > max_kept:
784
+ adjacency: dict[str, set[str]] = {nid: set() for nid in coords}
785
+ for u, v in edge_set:
786
+ adjacency[u].add(v)
787
+ adjacency[v].add(u)
788
+ kept = len(kept_candidates)
789
+ for edge, _ in sorted(kept_candidates, key=lambda item: item[1]):
790
+ if kept <= max_kept:
791
+ break
792
+ u, v = edge
793
+ if edge not in edge_set:
794
+ continue
795
+ if len(adjacency[u]) <= 1 or len(adjacency[v]) <= 1:
796
+ continue
797
+ if not self._is_connected_without_str_edge(u, v, adjacency, edge):
798
+ continue
799
+ adjacency[u].remove(v)
800
+ adjacency[v].remove(u)
801
+ edge_set.remove(edge)
802
+ kept -= 1
803
+
804
+ edges = sorted(edge_set)
805
+ return coords, edges, arterial_pairs
806
+
807
+ def _mixed(
808
+ self, target_nodes: int, rng: random.Random
809
+ ) -> tuple[dict[str, tuple[float, float]], list[tuple[str, str]], set[frozenset[str]]]:
810
+ coords, edges, arterial_pairs = self._arterial_local(target_nodes, rng)
811
+ ids = list(coords.keys())
812
+ for _ in range(max(3, len(ids) // 20)):
813
+ a, b = rng.sample(ids, 2)
814
+ if euclidean(coords[a], coords[b]) < 220.0:
815
+ edge = (a, b) if a < b else (b, a)
816
+ if edge not in edges:
817
+ edges.append(edge)
818
+ if rng.random() < 0.4:
819
+ arterial_pairs.add(frozenset(edge))
820
+ return coords, edges, arterial_pairs
821
+
822
+ def _to_adjacency(
823
+ self,
824
+ coords: dict[str, tuple[float, float]],
825
+ undirected_edges: list[tuple[str, str]],
826
+ ) -> dict[str, set[str]]:
827
+ adjacency: dict[str, set[str]] = {nid: set() for nid in coords}
828
+ for a, b in undirected_edges:
829
+ if a == b:
830
+ continue
831
+ adjacency[a].add(b)
832
+ adjacency[b].add(a)
833
+ return adjacency
834
+
835
+ def _build_directed_roads(
836
+ self,
837
+ coords: dict[str, tuple[float, float]],
838
+ undirected_edges: list[tuple[str, str]],
839
+ arterial_pairs: set[frozenset[str]],
840
+ gateway_pairs: set[frozenset[str]],
841
+ ) -> tuple[dict[str, RoadRecord], set[str], set[str]]:
842
+ roads: dict[str, RoadRecord] = {}
843
+ arterial_ids: set[str] = set()
844
+ gateway_ids: set[str] = set()
845
+ for a, b in undirected_edges:
846
+ for start, end in ((a, b), (b, a)):
847
+ is_arterial = frozenset((a, b)) in arterial_pairs
848
+ is_gateway = frozenset((a, b)) in gateway_pairs
849
+ if is_gateway:
850
+ speed = 12.0
851
+ lanes = 2
852
+ else:
853
+ speed = 14.0 if is_arterial else 11.0
854
+ lanes = 3 if is_arterial else 2
855
+ rid = f"r_{start}_{end}"
856
+ points = [
857
+ {"x": round(coords[start][0], 3), "y": round(coords[start][1], 3)},
858
+ {"x": round(coords[end][0], 3), "y": round(coords[end][1], 3)},
859
+ ]
860
+ record = RoadRecord(
861
+ id=rid,
862
+ start_intersection=start,
863
+ end_intersection=end,
864
+ length=euclidean(coords[start], coords[end]),
865
+ speed_limit=speed,
866
+ num_lanes=lanes,
867
+ points=points,
868
+ is_arterial=is_arterial,
869
+ )
870
+ roads[rid] = record
871
+ if is_arterial:
872
+ arterial_ids.add(rid)
873
+ if is_gateway:
874
+ gateway_ids.add(rid)
875
+ return roads, arterial_ids, gateway_ids
876
+
877
+ def _build_roadnet(
878
+ self,
879
+ coords: dict[str, tuple[float, float]],
880
+ adjacency: dict[str, set[str]],
881
+ directed_roads: dict[str, RoadRecord],
882
+ ) -> dict[str, list[dict[str, object]]]:
883
+ in_roads_by_node: dict[str, list[RoadRecord]] = defaultdict(list)
884
+ out_roads_by_node: dict[str, list[RoadRecord]] = defaultdict(list)
885
+ for road in directed_roads.values():
886
+ out_roads_by_node[road.start_intersection].append(road)
887
+ in_roads_by_node[road.end_intersection].append(road)
888
+
889
+ min_x = min(x for x, _ in coords.values())
890
+ max_x = max(x for x, _ in coords.values())
891
+ min_y = min(y for _, y in coords.values())
892
+ max_y = max(y for _, y in coords.values())
893
+ border_eps = 3.0
894
+
895
+ intersections: list[dict[str, object]] = []
896
+ for nid in sorted(coords):
897
+ x, y = coords[nid]
898
+ degree = len(adjacency[nid])
899
+ is_border = (
900
+ abs(x - min_x) < border_eps
901
+ or abs(x - max_x) < border_eps
902
+ or abs(y - min_y) < border_eps
903
+ or abs(y - max_y) < border_eps
904
+ )
905
+ # Keep boundary intersections non-virtual when they are part of the street grid.
906
+ # Mark only true stubs/dead-ends as virtual.
907
+ virtual = degree <= 1
908
+
909
+ road_links: list[dict[str, object]] = []
910
+ incoming = sorted(in_roads_by_node[nid], key=lambda r: r.id)
911
+ outgoing = sorted(out_roads_by_node[nid], key=lambda r: r.id)
912
+ for in_road in incoming:
913
+ for out_road in outgoing:
914
+ if out_road.end_intersection == in_road.start_intersection:
915
+ continue
916
+ lane_links = []
917
+ lane_count = min(in_road.num_lanes, out_road.num_lanes)
918
+ for lane_idx in range(lane_count):
919
+ lane_links.append(
920
+ {
921
+ "startLaneIndex": lane_idx,
922
+ "endLaneIndex": lane_idx,
923
+ "points": [
924
+ dict(in_road.points[-1]),
925
+ dict(out_road.points[0]),
926
+ ],
927
+ }
928
+ )
929
+ road_links.append(
930
+ {
931
+ "type": self._movement_type(
932
+ coords[in_road.start_intersection],
933
+ coords[nid],
934
+ coords[out_road.end_intersection],
935
+ ),
936
+ "startRoad": in_road.id,
937
+ "endRoad": out_road.id,
938
+ "laneLinks": lane_links,
939
+ }
940
+ )
941
+
942
+ lightphases = self._light_phases(nid, coords, incoming, road_links)
943
+ connected_roads = sorted(
944
+ {road.id for road in incoming + outgoing}
945
+ )
946
+ intersections.append(
947
+ {
948
+ "id": nid,
949
+ "point": {"x": round(x, 3), "y": round(y, 3)},
950
+ "width": 0,
951
+ "roads": connected_roads,
952
+ "virtual": virtual,
953
+ "roadLinks": road_links,
954
+ "trafficLight": {
955
+ "roadLinkIndices": list(range(len(road_links))),
956
+ "lightphases": lightphases,
957
+ },
958
+ }
959
+ )
960
+
961
+ roads: list[dict[str, object]] = []
962
+ for rid in sorted(directed_roads):
963
+ road = directed_roads[rid]
964
+ roads.append(
965
+ {
966
+ "id": road.id,
967
+ "startIntersection": road.start_intersection,
968
+ "endIntersection": road.end_intersection,
969
+ "points": road.points,
970
+ "lanes": [
971
+ {"maxSpeed": road.speed_limit, "width": 3.2}
972
+ for _ in range(road.num_lanes)
973
+ ],
974
+ }
975
+ )
976
+ return {"intersections": intersections, "roads": roads}
977
+
978
+ def _movement_type(
979
+ self,
980
+ in_start: tuple[float, float],
981
+ center: tuple[float, float],
982
+ out_end: tuple[float, float],
983
+ ) -> str:
984
+ v1 = (center[0] - in_start[0], center[1] - in_start[1])
985
+ v2 = (out_end[0] - center[0], out_end[1] - center[1])
986
+ cross = (v1[0] * v2[1]) - (v1[1] * v2[0])
987
+ dot = (v1[0] * v2[0]) + (v1[1] * v2[1])
988
+ angle = math.atan2(cross, dot)
989
+ if abs(angle) < (math.pi / 4):
990
+ return "go_straight"
991
+ if angle > 0:
992
+ return "turn_left"
993
+ return "turn_right"
994
+
995
+ def _light_phases(
996
+ self,
997
+ node_id: str,
998
+ coords: dict[str, tuple[float, float]],
999
+ incoming: list[RoadRecord],
1000
+ road_links: list[dict[str, object]],
1001
+ ) -> list[dict[str, object]]:
1002
+ if not road_links:
1003
+ return [{"time": 30, "availableRoadLinks": []}]
1004
+
1005
+ vertical_incoming: set[str] = set()
1006
+ horizontal_incoming: set[str] = set()
1007
+ center = coords[node_id]
1008
+ for road in incoming:
1009
+ source = coords[road.start_intersection]
1010
+ vx = center[0] - source[0]
1011
+ vy = center[1] - source[1]
1012
+ if abs(vy) >= abs(vx):
1013
+ vertical_incoming.add(road.id)
1014
+ else:
1015
+ horizontal_incoming.add(road.id)
1016
+
1017
+ vertical_links: list[int] = []
1018
+ horizontal_links: list[int] = []
1019
+ for idx, link in enumerate(road_links):
1020
+ start_road = str(link["startRoad"])
1021
+ if start_road in vertical_incoming:
1022
+ vertical_links.append(idx)
1023
+ else:
1024
+ horizontal_links.append(idx)
1025
+
1026
+ if not vertical_links or not horizontal_links:
1027
+ return [{"time": 35, "availableRoadLinks": list(range(len(road_links)))}]
1028
+
1029
+ return [
1030
+ {"time": 30, "availableRoadLinks": vertical_links},
1031
+ {"time": 5, "availableRoadLinks": []},
1032
+ {"time": 30, "availableRoadLinks": horizontal_links},
1033
+ {"time": 5, "availableRoadLinks": []},
1034
+ ]
data/generators/scenario_generator.py ADDED
@@ -0,0 +1,443 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Scenario plan generation for demand intensity and network disturbances."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import random
6
+ from collections import defaultdict
7
+
8
+ from .schemas import (
9
+ CityGraph,
10
+ DatasetGenerationConfig,
11
+ DemandIntensity,
12
+ DistrictData,
13
+ ScenarioPlan,
14
+ ScenarioType,
15
+ TripMix,
16
+ )
17
+
18
+
19
+ class ScenarioGenerator:
20
+ """Generate scenario-specific modifiers, bottlenecks, and demand intensity."""
21
+
22
+ INTENSITY_MULTIPLIER: dict[DemandIntensity, float] = {
23
+ "normal": 1.0,
24
+ "moderate_rush": 1.45,
25
+ "heavy_rush": 2.1,
26
+ "overload": 2.9,
27
+ "accident_overload": 3.6,
28
+ }
29
+
30
+ SCENARIO_BASE_MULTIPLIER: dict[ScenarioType, float] = {
31
+ "normal": 1.15,
32
+ "morning_rush": 1.35,
33
+ "evening_rush": 1.35,
34
+ "accident": 1.85,
35
+ "construction": 1.65,
36
+ "event_spike": 1.75,
37
+ "district_overload": 1.80,
38
+ }
39
+
40
+ BASE_DEMAND_PER_INTERSECTION: dict[ScenarioType, int] = {
41
+ "normal": 42,
42
+ "morning_rush": 52,
43
+ "evening_rush": 52,
44
+ "accident": 60,
45
+ "construction": 56,
46
+ "event_spike": 62,
47
+ "district_overload": 64,
48
+ }
49
+
50
+ INTENSITY_ALLOWED_BY_SCENARIO: dict[ScenarioType, list[DemandIntensity]] = {
51
+ "normal": ["normal", "moderate_rush", "heavy_rush"],
52
+ "morning_rush": ["moderate_rush", "heavy_rush", "overload"],
53
+ "evening_rush": ["moderate_rush", "heavy_rush", "overload"],
54
+ "accident": ["heavy_rush", "overload", "accident_overload"],
55
+ "construction": ["moderate_rush", "heavy_rush", "overload", "accident_overload"],
56
+ "event_spike": ["moderate_rush", "heavy_rush", "overload", "accident_overload"],
57
+ "district_overload": ["moderate_rush", "heavy_rush", "overload", "accident_overload"],
58
+ }
59
+
60
+ def generate(
61
+ self,
62
+ city_graph: CityGraph,
63
+ district_data: DistrictData,
64
+ scenario_names: list[ScenarioType],
65
+ base_seed: int,
66
+ config: DatasetGenerationConfig,
67
+ ) -> dict[str, ScenarioPlan]:
68
+ plans: dict[str, ScenarioPlan] = {}
69
+ for idx, name in enumerate(scenario_names):
70
+ seed = base_seed + (idx * 101)
71
+ rng = random.Random(seed)
72
+ intensity = self._sample_intensity(name, rng, config)
73
+ trip_multiplier = self._trip_multiplier(name, intensity, config)
74
+ trip_mix = self._trip_mix(name, intensity)
75
+ departure_windows = self._departure_windows(name, intensity)
76
+ base_demand = self.BASE_DEMAND_PER_INTERSECTION[name]
77
+
78
+ if name == "normal":
79
+ plans[name] = ScenarioPlan(
80
+ name=name,
81
+ intensity=intensity,
82
+ seed=seed,
83
+ trip_multiplier=trip_multiplier,
84
+ trip_mix=trip_mix,
85
+ departure_windows=departure_windows,
86
+ metadata={
87
+ "description": "Balanced baseline traffic with sampled intensity.",
88
+ "base_demand_per_intersection": base_demand,
89
+ },
90
+ )
91
+ elif name == "morning_rush":
92
+ plans[name] = ScenarioPlan(
93
+ name=name,
94
+ intensity=intensity,
95
+ seed=seed,
96
+ trip_multiplier=trip_multiplier,
97
+ trip_mix=trip_mix,
98
+ departure_windows=departure_windows,
99
+ metadata={
100
+ "description": "Strong residential outbound and work-district inbound pressure.",
101
+ "base_demand_per_intersection": base_demand,
102
+ },
103
+ )
104
+ elif name == "evening_rush":
105
+ plans[name] = ScenarioPlan(
106
+ name=name,
107
+ intensity=intensity,
108
+ seed=seed,
109
+ trip_multiplier=trip_multiplier,
110
+ trip_mix=trip_mix,
111
+ departure_windows=departure_windows,
112
+ metadata={
113
+ "description": "Strong work-district outbound and residential inbound pressure.",
114
+ "base_demand_per_intersection": base_demand,
115
+ },
116
+ )
117
+ elif name == "accident":
118
+ blocked, penalized = self._accident_impairments(
119
+ city_graph, district_data, intensity, rng
120
+ )
121
+ plans[name] = ScenarioPlan(
122
+ name=name,
123
+ intensity=intensity,
124
+ seed=seed,
125
+ trip_multiplier=trip_multiplier,
126
+ trip_mix=trip_mix,
127
+ departure_windows=departure_windows,
128
+ blocked_roads=blocked,
129
+ penalized_roads=penalized,
130
+ metadata={
131
+ "description": "Severe disruption on connector/arterial corridors.",
132
+ "base_demand_per_intersection": base_demand,
133
+ "accident_roads": sorted(blocked),
134
+ "bottleneck_penalties": {k: v for k, v in penalized.items() if k not in blocked},
135
+ },
136
+ )
137
+ elif name == "construction":
138
+ blocked, penalized = self._construction_impairments(
139
+ city_graph, district_data, intensity, rng
140
+ )
141
+ plans[name] = ScenarioPlan(
142
+ name=name,
143
+ intensity=intensity,
144
+ seed=seed,
145
+ trip_multiplier=trip_multiplier,
146
+ trip_mix=trip_mix,
147
+ departure_windows=departure_windows,
148
+ blocked_roads=blocked,
149
+ penalized_roads=penalized,
150
+ metadata={
151
+ "description": "Localized but severe construction bottlenecks.",
152
+ "base_demand_per_intersection": base_demand,
153
+ "construction_roads": sorted(blocked),
154
+ "bottleneck_penalties": {k: v for k, v in penalized.items() if k not in blocked},
155
+ },
156
+ )
157
+ elif name == "event_spike":
158
+ event_district = self._pick_high_pressure_district(district_data, rng)
159
+ plans[name] = ScenarioPlan(
160
+ name=name,
161
+ intensity=intensity,
162
+ seed=seed,
163
+ trip_multiplier=trip_multiplier,
164
+ trip_mix=trip_mix,
165
+ departure_windows=departure_windows,
166
+ event_district=event_district,
167
+ metadata={
168
+ "description": "Pre-event surge and outbound release around event district.",
169
+ "base_demand_per_intersection": base_demand,
170
+ "event_district": event_district,
171
+ },
172
+ )
173
+ elif name == "district_overload":
174
+ overload = self._pick_high_pressure_district(district_data, rng)
175
+ plans[name] = ScenarioPlan(
176
+ name=name,
177
+ intensity=intensity,
178
+ seed=seed,
179
+ trip_multiplier=trip_multiplier,
180
+ trip_mix=trip_mix,
181
+ departure_windows=departure_windows,
182
+ overload_district=overload,
183
+ metadata={
184
+ "description": "One district receives amplified production and attraction.",
185
+ "base_demand_per_intersection": base_demand,
186
+ "overload_district": overload,
187
+ },
188
+ )
189
+ else:
190
+ raise ValueError(f"Unsupported scenario: {name}")
191
+ return plans
192
+
193
+ def _sample_intensity(
194
+ self,
195
+ scenario: ScenarioType,
196
+ rng: random.Random,
197
+ config: DatasetGenerationConfig,
198
+ ) -> DemandIntensity:
199
+ allowed = self.INTENSITY_ALLOWED_BY_SCENARIO[scenario]
200
+ candidates = [i for i in config.intensity_levels if i in allowed]
201
+ if not candidates:
202
+ candidates = allowed
203
+ scenario_bias: dict[ScenarioType, dict[DemandIntensity, float]] = {
204
+ "normal": {"normal": 1.35},
205
+ "morning_rush": {"heavy_rush": 1.4, "overload": 1.45},
206
+ "evening_rush": {"heavy_rush": 1.4, "overload": 1.45},
207
+ "accident": {"overload": 1.8, "accident_overload": 2.2},
208
+ "construction": {"heavy_rush": 1.3, "overload": 1.7, "accident_overload": 2.0},
209
+ "event_spike": {"heavy_rush": 1.35, "overload": 1.65, "accident_overload": 1.8},
210
+ "district_overload": {"heavy_rush": 1.35, "overload": 1.75, "accident_overload": 1.95},
211
+ }
212
+ bias = scenario_bias.get(scenario, {})
213
+ weights = [
214
+ max(0.0, config.intensity_distribution.get(i, 0.0)) * bias.get(i, 1.0)
215
+ for i in candidates
216
+ ]
217
+ if sum(weights) <= 0.0:
218
+ weights = [1.0] * len(candidates)
219
+ return rng.choices(candidates, weights=weights, k=1)[0]
220
+
221
+ def _trip_multiplier(
222
+ self,
223
+ scenario: ScenarioType,
224
+ intensity: DemandIntensity,
225
+ config: DatasetGenerationConfig,
226
+ ) -> float:
227
+ base = self.SCENARIO_BASE_MULTIPLIER[scenario]
228
+ intensity_scale = self.INTENSITY_MULTIPLIER[intensity]
229
+ per_scenario = config.scenario_demand_multipliers.get(scenario, 1.0)
230
+ return base * intensity_scale * config.global_demand_multiplier * per_scenario
231
+
232
+ def _trip_mix(
233
+ self,
234
+ scenario: ScenarioType,
235
+ intensity: DemandIntensity,
236
+ ) -> TripMix:
237
+ base: dict[ScenarioType, tuple[float, float, float]] = {
238
+ "normal": (0.44, 0.34, 0.22),
239
+ "morning_rush": (0.34, 0.38, 0.28),
240
+ "evening_rush": (0.34, 0.38, 0.28),
241
+ "accident": (0.28, 0.40, 0.32),
242
+ "construction": (0.30, 0.40, 0.30),
243
+ "event_spike": (0.24, 0.40, 0.36),
244
+ "district_overload": (0.26, 0.42, 0.32),
245
+ }
246
+ intra, adjacent, long = base[scenario]
247
+ intensity_shift = {
248
+ "normal": 0.00,
249
+ "moderate_rush": 0.03,
250
+ "heavy_rush": 0.06,
251
+ "overload": 0.09,
252
+ "accident_overload": 0.12,
253
+ }[intensity]
254
+ intra = max(0.14, intra - intensity_shift)
255
+ adjacent = min(0.56, adjacent + (0.55 * intensity_shift))
256
+ long = min(0.44, long + (0.45 * intensity_shift))
257
+ norm = intra + adjacent + long
258
+ return TripMix(
259
+ intra_district=intra / norm,
260
+ adjacent_district=adjacent / norm,
261
+ long_distance=long / norm,
262
+ )
263
+
264
+ def _departure_windows(
265
+ self,
266
+ scenario: ScenarioType,
267
+ intensity: DemandIntensity,
268
+ ) -> list[tuple[float, float, float]]:
269
+ compression = {
270
+ "normal": 1.00,
271
+ "moderate_rush": 0.82,
272
+ "heavy_rush": 0.62,
273
+ "overload": 0.46,
274
+ "accident_overload": 0.35,
275
+ }[intensity]
276
+
277
+ if scenario == "morning_rush":
278
+ peak_width = 0.34 * compression
279
+ peak_start = max(0.07, 0.26 - peak_width / 2.0)
280
+ peak_end = min(0.58, peak_start + peak_width)
281
+ return [(0.0, peak_start, 0.12), (peak_start, peak_end, 0.76), (peak_end, 1.0, 0.12)]
282
+ if scenario == "evening_rush":
283
+ peak_width = 0.34 * compression
284
+ peak_end = min(0.95, 0.74 + peak_width / 2.0)
285
+ peak_start = max(0.34, peak_end - peak_width)
286
+ return [(0.0, peak_start, 0.10), (peak_start, peak_end, 0.76), (peak_end, 1.0, 0.14)]
287
+ if scenario in {"event_spike", "district_overload"}:
288
+ peak_width = 0.42 * compression
289
+ peak_start = max(0.20, 0.52 - peak_width / 2.0)
290
+ peak_end = min(0.90, peak_start + peak_width)
291
+ return [(0.0, peak_start, 0.10), (peak_start, peak_end, 0.74), (peak_end, 1.0, 0.16)]
292
+ if scenario in {"accident", "construction"}:
293
+ peak_width = 0.45 * compression
294
+ peak_start = max(0.16, 0.48 - peak_width / 2.0)
295
+ peak_end = min(0.88, peak_start + peak_width)
296
+ return [(0.0, peak_start, 0.16), (peak_start, peak_end, 0.68), (peak_end, 1.0, 0.16)]
297
+ return [(0.0, 0.28, 0.22), (0.28, 0.72, 0.56), (0.72, 1.0, 0.22)]
298
+
299
+ def _road_importance_scores(
300
+ self,
301
+ city_graph: CityGraph,
302
+ district_data: DistrictData,
303
+ ) -> dict[str, float]:
304
+ boundary_nodes = set(district_data.boundary_intersections)
305
+ scores: dict[str, float] = {}
306
+ for road_id, road in city_graph.directed_roads.items():
307
+ score = 1.0
308
+ if road_id in city_graph.arterial_roads:
309
+ score += 3.0
310
+ if road_id in district_data.inter_district_roads:
311
+ score += 2.5
312
+ if road.start_intersection in boundary_nodes or road.end_intersection in boundary_nodes:
313
+ score += 1.8
314
+ score += 0.30 * len(city_graph.adjacency[road.start_intersection])
315
+ score += 0.30 * len(city_graph.adjacency[road.end_intersection])
316
+ scores[road_id] = score
317
+ return scores
318
+
319
+ def _weighted_sample_without_replacement(
320
+ self,
321
+ candidates: list[str],
322
+ weights: dict[str, float],
323
+ k: int,
324
+ rng: random.Random,
325
+ ) -> list[str]:
326
+ remaining = list(candidates)
327
+ picked: list[str] = []
328
+ k = min(k, len(remaining))
329
+ while remaining and len(picked) < k:
330
+ probs = [max(0.01, weights.get(c, 0.01)) for c in remaining]
331
+ choice = rng.choices(remaining, weights=probs, k=1)[0]
332
+ picked.append(choice)
333
+ remaining.remove(choice)
334
+ return picked
335
+
336
+ def _accident_impairments(
337
+ self,
338
+ city_graph: CityGraph,
339
+ district_data: DistrictData,
340
+ intensity: DemandIntensity,
341
+ rng: random.Random,
342
+ ) -> tuple[set[str], dict[str, float]]:
343
+ importance = self._road_importance_scores(city_graph, district_data)
344
+ ranked = sorted(importance.keys(), key=lambda rid: importance[rid], reverse=True)
345
+ block_count = {
346
+ "normal": 1,
347
+ "moderate_rush": 2,
348
+ "heavy_rush": 2,
349
+ "overload": 3,
350
+ "accident_overload": 4,
351
+ }[intensity]
352
+ blocked = set(self._weighted_sample_without_replacement(ranked[:120], importance, block_count, rng))
353
+ severity = {
354
+ "normal": 7.0,
355
+ "moderate_rush": 8.5,
356
+ "heavy_rush": 10.5,
357
+ "overload": 12.5,
358
+ "accident_overload": 14.5,
359
+ }[intensity]
360
+ penalized: dict[str, float] = {rid: severity for rid in blocked}
361
+
362
+ # Expand impairment to adjacent connector roads to create spillback.
363
+ by_intersection: dict[str, set[str]] = defaultdict(set)
364
+ for road_id, road in city_graph.directed_roads.items():
365
+ by_intersection[road.start_intersection].add(road_id)
366
+ by_intersection[road.end_intersection].add(road_id)
367
+ for rid in blocked:
368
+ road = city_graph.directed_roads[rid]
369
+ nearby = by_intersection[road.start_intersection] | by_intersection[road.end_intersection]
370
+ for neighbor in nearby:
371
+ if neighbor in blocked:
372
+ continue
373
+ if neighbor in district_data.inter_district_roads or neighbor in city_graph.arterial_roads:
374
+ penalized[neighbor] = max(penalized.get(neighbor, 0.0), severity * 0.78)
375
+ return blocked, penalized
376
+
377
+ def _construction_impairments(
378
+ self,
379
+ city_graph: CityGraph,
380
+ district_data: DistrictData,
381
+ intensity: DemandIntensity,
382
+ rng: random.Random,
383
+ ) -> tuple[set[str], dict[str, float]]:
384
+ candidate_district = self._pick_high_pressure_district(district_data, rng)
385
+ members = set(district_data.districts[candidate_district].intersections)
386
+ localized = [
387
+ rid
388
+ for rid, road in city_graph.directed_roads.items()
389
+ if road.start_intersection in members or road.end_intersection in members
390
+ ]
391
+ if not localized:
392
+ localized = sorted(city_graph.directed_roads.keys())
393
+
394
+ importance = self._road_importance_scores(city_graph, district_data)
395
+ block_count = {
396
+ "normal": 1,
397
+ "moderate_rush": 1,
398
+ "heavy_rush": 2,
399
+ "overload": 3,
400
+ "accident_overload": 4,
401
+ }[intensity]
402
+ penalize_count = {
403
+ "normal": 10,
404
+ "moderate_rush": 16,
405
+ "heavy_rush": 22,
406
+ "overload": 30,
407
+ "accident_overload": 36,
408
+ }[intensity]
409
+ blocked = set(
410
+ self._weighted_sample_without_replacement(localized, importance, block_count, rng)
411
+ )
412
+ penalize_candidates = self._weighted_sample_without_replacement(
413
+ localized, importance, penalize_count, rng
414
+ )
415
+ severity = {
416
+ "normal": 5.0,
417
+ "moderate_rush": 6.5,
418
+ "heavy_rush": 8.5,
419
+ "overload": 10.5,
420
+ "accident_overload": 12.5,
421
+ }[intensity]
422
+ penalized: dict[str, float] = {}
423
+ for rid in penalize_candidates:
424
+ factor = severity
425
+ if rid in blocked:
426
+ factor = severity * 1.35
427
+ penalized[rid] = max(penalized.get(rid, 0.0), factor)
428
+ for rid in blocked:
429
+ penalized[rid] = max(penalized.get(rid, 0.0), severity * 1.55)
430
+ return blocked, penalized
431
+
432
+ def _pick_high_pressure_district(
433
+ self,
434
+ district_data: DistrictData,
435
+ rng: random.Random,
436
+ ) -> str:
437
+ district_ids = sorted(district_data.districts.keys())
438
+ weights: list[float] = []
439
+ for did in district_ids:
440
+ district = district_data.districts[did]
441
+ w = 1.0 + 0.20 * len(district.neighbors) + 0.06 * len(district.boundary_intersections)
442
+ weights.append(w)
443
+ return rng.choices(district_ids, weights=weights, k=1)[0]
data/generators/schemas.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Typed schemas and dataclasses for synthetic CityFlow generation."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field
6
+ from pathlib import Path
7
+ from typing import Any, Literal
8
+
9
+ TopologyType = Literal[
10
+ "rectangular_grid",
11
+ "irregular_grid",
12
+ "arterial_local",
13
+ "ring_road",
14
+ "mixed",
15
+ ]
16
+
17
+ DistrictType = Literal["residential", "commercial", "industrial", "mixed"]
18
+ DemandIntensity = Literal[
19
+ "normal",
20
+ "moderate_rush",
21
+ "heavy_rush",
22
+ "overload",
23
+ "accident_overload",
24
+ ]
25
+ ScenarioType = Literal[
26
+ "normal",
27
+ "morning_rush",
28
+ "evening_rush",
29
+ "accident",
30
+ "construction",
31
+ "event_spike",
32
+ "district_overload",
33
+ ]
34
+
35
+
36
+ @dataclass(slots=True, frozen=True)
37
+ class TripMix:
38
+ """Trip category distribution."""
39
+
40
+ intra_district: float = 0.5
41
+ adjacent_district: float = 0.3
42
+ long_distance: float = 0.2
43
+
44
+
45
+ @dataclass(slots=True)
46
+ class DatasetGenerationConfig:
47
+ """Top-level CLI / generation configuration."""
48
+
49
+ num_cities: int
50
+ output_dir: Path
51
+ seed: int = 42
52
+ min_districts: int = 6
53
+ max_districts: int = 20
54
+ min_intersections_per_district: int = 4
55
+ max_intersections_per_district: int = 10
56
+ topologies: list[TopologyType] = field(
57
+ default_factory=lambda: ["irregular_grid"]
58
+ )
59
+ scenarios: list[ScenarioType] = field(
60
+ default_factory=lambda: [
61
+ "normal",
62
+ "morning_rush",
63
+ "evening_rush",
64
+ "accident",
65
+ "construction",
66
+ "event_spike",
67
+ "district_overload",
68
+ ]
69
+ )
70
+ intensity_levels: list[DemandIntensity] = field(
71
+ default_factory=lambda: [
72
+ "normal",
73
+ "moderate_rush",
74
+ "heavy_rush",
75
+ "overload",
76
+ "accident_overload",
77
+ ]
78
+ )
79
+ intensity_distribution: dict[DemandIntensity, float] = field(
80
+ default_factory=lambda: {
81
+ "normal": 0.20,
82
+ "moderate_rush": 0.42,
83
+ "heavy_rush": 0.24,
84
+ "overload": 0.10,
85
+ "accident_overload": 0.04,
86
+ }
87
+ )
88
+ global_demand_multiplier: float = 1.25
89
+ scenario_demand_multipliers: dict[str, float] = field(
90
+ default_factory=lambda: {
91
+ "normal": 1.15,
92
+ "morning_rush": 1.35,
93
+ "evening_rush": 1.35,
94
+ "accident": 1.75,
95
+ "construction": 1.55,
96
+ "event_spike": 1.65,
97
+ "district_overload": 1.70,
98
+ }
99
+ )
100
+ ring_diagonal_keep_prob: float = 0.07
101
+ ring_max_diagonal_fraction: float = 0.03
102
+ simulation_steps: int = 3600
103
+ interval: float = 1.0
104
+ save_replay: bool = False
105
+ fail_fast: bool = False
106
+
107
+
108
+ @dataclass(slots=True, frozen=True)
109
+ class RoadRecord:
110
+ """Directed road edge record."""
111
+
112
+ id: str
113
+ start_intersection: str
114
+ end_intersection: str
115
+ length: float
116
+ speed_limit: float
117
+ num_lanes: int
118
+ points: list[dict[str, float]]
119
+ is_arterial: bool
120
+
121
+
122
+ @dataclass(slots=True)
123
+ class CityGraph:
124
+ """Intermediate graph representation for generation."""
125
+
126
+ city_id: str
127
+ topology: TopologyType
128
+ seed: int
129
+ intersections: dict[str, tuple[float, float]]
130
+ adjacency: dict[str, set[str]]
131
+ directed_roads: dict[str, RoadRecord]
132
+ roadnet: dict[str, Any]
133
+ arterial_roads: set[str]
134
+ gateway_intersections: set[str] = field(default_factory=set)
135
+ gateway_roads: set[str] = field(default_factory=set)
136
+ inter_district_roads: set[str] = field(default_factory=set)
137
+
138
+
139
+ @dataclass(slots=True)
140
+ class DistrictRecord:
141
+ """District-level metadata."""
142
+
143
+ id: str
144
+ district_type: DistrictType
145
+ intersections: list[str]
146
+ neighbors: list[str]
147
+ boundary_intersections: list[str]
148
+ entry_roads: list[str]
149
+ exit_roads: list[str]
150
+
151
+
152
+ @dataclass(slots=True)
153
+ class DistrictData:
154
+ """District overlay output."""
155
+
156
+ intersection_to_district: dict[str, str]
157
+ districts: dict[str, DistrictRecord]
158
+ district_neighbors: dict[str, list[str]]
159
+ boundary_intersections: list[str]
160
+ inter_district_roads: list[str]
161
+
162
+
163
+ @dataclass(slots=True)
164
+ class ScenarioPlan:
165
+ """Scenario-specific route demand and impairment configuration."""
166
+
167
+ name: ScenarioType
168
+ intensity: DemandIntensity
169
+ seed: int
170
+ trip_multiplier: float
171
+ trip_mix: TripMix
172
+ departure_windows: list[tuple[float, float, float]]
173
+ blocked_roads: set[str] = field(default_factory=set)
174
+ penalized_roads: dict[str, float] = field(default_factory=dict)
175
+ event_district: str | None = None
176
+ overload_district: str | None = None
177
+ metadata: dict[str, Any] = field(default_factory=dict)
data/generators/utils.py ADDED
@@ -0,0 +1,377 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Utilities for deterministic generation, IO, graph operations, and validation."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import heapq
6
+ import json
7
+ import math
8
+ import random
9
+ from collections import Counter, defaultdict, deque
10
+ from pathlib import Path
11
+ from typing import Any, Iterable
12
+
13
+ from .schemas import CityGraph, DistrictData
14
+
15
+
16
+ def ensure_dir(path: Path) -> None:
17
+ path.mkdir(parents=True, exist_ok=True)
18
+
19
+
20
+ def write_json(path: Path, payload: Any) -> None:
21
+ ensure_dir(path.parent)
22
+ path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
23
+
24
+
25
+ def euclidean(a: tuple[float, float], b: tuple[float, float]) -> float:
26
+ return math.hypot(a[0] - b[0], a[1] - b[1])
27
+
28
+
29
+ def clamp(value: int, low: int, high: int) -> int:
30
+ return max(low, min(high, value))
31
+
32
+
33
+ def choose_weighted(rng: random.Random, values: list[str], weights: list[float]) -> str:
34
+ total = sum(weights)
35
+ if total <= 0:
36
+ return values[rng.randrange(len(values))]
37
+ cutoff = rng.random() * total
38
+ cursor = 0.0
39
+ for value, weight in zip(values, weights):
40
+ cursor += weight
41
+ if cursor >= cutoff:
42
+ return value
43
+ return values[-1]
44
+
45
+
46
+ def connected_components(nodes: Iterable[str], adjacency: dict[str, set[str]]) -> list[set[str]]:
47
+ pending = set(nodes)
48
+ components: list[set[str]] = []
49
+ while pending:
50
+ root = pending.pop()
51
+ comp = {root}
52
+ queue = deque([root])
53
+ while queue:
54
+ cur = queue.popleft()
55
+ for nxt in adjacency[cur]:
56
+ if nxt in pending:
57
+ pending.remove(nxt)
58
+ comp.add(nxt)
59
+ queue.append(nxt)
60
+ components.append(comp)
61
+ return components
62
+
63
+
64
+ def dijkstra_shortest_path(
65
+ start: str,
66
+ end: str,
67
+ graph: dict[str, list[tuple[str, float, str]]],
68
+ ) -> list[str] | None:
69
+ """Return road-id path from start intersection to end intersection."""
70
+ if start == end:
71
+ return []
72
+ queue: list[tuple[float, str]] = [(0.0, start)]
73
+ dist: dict[str, float] = {start: 0.0}
74
+ prev: dict[str, tuple[str, str]] = {}
75
+
76
+ while queue:
77
+ cost, node = heapq.heappop(queue)
78
+ if node == end:
79
+ break
80
+ if cost > dist[node]:
81
+ continue
82
+ for nxt, edge_cost, road_id in graph.get(node, []):
83
+ candidate = cost + edge_cost
84
+ if candidate < dist.get(nxt, float("inf")):
85
+ dist[nxt] = candidate
86
+ prev[nxt] = (node, road_id)
87
+ heapq.heappush(queue, (candidate, nxt))
88
+
89
+ if end not in prev:
90
+ return None
91
+
92
+ route: list[str] = []
93
+ cursor = end
94
+ while cursor != start:
95
+ parent, road_id = prev[cursor]
96
+ route.append(road_id)
97
+ cursor = parent
98
+ route.reverse()
99
+ return route
100
+
101
+
102
+ def validate_unique_ids(city_graph: CityGraph) -> None:
103
+ intersection_ids = set(city_graph.intersections.keys())
104
+ if len(intersection_ids) != len(city_graph.intersections):
105
+ raise ValueError("Duplicate intersection IDs found.")
106
+ road_ids = set(city_graph.directed_roads.keys())
107
+ if len(road_ids) != len(city_graph.directed_roads):
108
+ raise ValueError("Duplicate road IDs found.")
109
+
110
+
111
+ def validate_district_contiguity(city_graph: CityGraph, district_data: DistrictData) -> None:
112
+ by_district: dict[str, set[str]] = defaultdict(set)
113
+ for intersection_id, district_id in district_data.intersection_to_district.items():
114
+ by_district[district_id].add(intersection_id)
115
+
116
+ for district_id, members in by_district.items():
117
+ if not members:
118
+ raise ValueError(f"District {district_id} has no intersections.")
119
+ components = connected_components(members, city_graph.adjacency)
120
+ if len(components) != 1:
121
+ raise ValueError(f"District {district_id} is not contiguous.")
122
+
123
+
124
+ def validate_inter_district_connectivity(district_data: DistrictData) -> None:
125
+ connected = sum(1 for roads in district_data.district_neighbors.values() if roads)
126
+ if connected == 0:
127
+ raise ValueError("No inter-district connections found.")
128
+
129
+
130
+ def validate_district_exit_capacity(
131
+ district_data: DistrictData,
132
+ min_exit_roads: int = 2,
133
+ min_entry_roads: int = 2,
134
+ min_neighbor_districts: int = 1,
135
+ ) -> None:
136
+ underconnected_exit: list[str] = []
137
+ underconnected_entry: list[str] = []
138
+ underconnected_neighbors: list[str] = []
139
+ for district_id, district in district_data.districts.items():
140
+ if len(district.exit_roads) < min_exit_roads:
141
+ underconnected_exit.append(
142
+ f"{district_id}:{len(district.exit_roads)}"
143
+ )
144
+ if len(district.entry_roads) < min_entry_roads:
145
+ underconnected_entry.append(
146
+ f"{district_id}:{len(district.entry_roads)}"
147
+ )
148
+ if len(district.neighbors) < min_neighbor_districts:
149
+ underconnected_neighbors.append(
150
+ f"{district_id}:{len(district.neighbors)}"
151
+ )
152
+ if underconnected_exit or underconnected_entry or underconnected_neighbors:
153
+ parts: list[str] = []
154
+ if underconnected_exit:
155
+ parts.append(
156
+ f"exit<{min_exit_roads}: " + ", ".join(underconnected_exit[:8])
157
+ )
158
+ if underconnected_entry:
159
+ parts.append(
160
+ f"entry<{min_entry_roads}: " + ", ".join(underconnected_entry[:8])
161
+ )
162
+ if underconnected_neighbors:
163
+ parts.append(
164
+ f"neighbors<{min_neighbor_districts}: "
165
+ + ", ".join(underconnected_neighbors[:8])
166
+ )
167
+ raise ValueError(
168
+ "District external connectivity too low: " + " | ".join(parts)
169
+ )
170
+
171
+
172
+ def validate_routes(
173
+ flow_entries: list[dict[str, Any]],
174
+ roads_by_id: dict[str, Any],
175
+ ) -> None:
176
+ if not flow_entries:
177
+ raise ValueError("Scenario flow is empty.")
178
+ for idx, vehicle in enumerate(flow_entries):
179
+ route = vehicle.get("route", [])
180
+ if not route:
181
+ raise ValueError(f"Flow entry {idx} has empty route.")
182
+ for road_id in route:
183
+ if road_id not in roads_by_id:
184
+ raise ValueError(f"Flow entry {idx} references missing road {road_id}.")
185
+ for left, right in zip(route, route[1:]):
186
+ left_end = roads_by_id[left]["endIntersection"]
187
+ right_start = roads_by_id[right]["startIntersection"]
188
+ if left_end != right_start:
189
+ raise ValueError(
190
+ f"Invalid route transition: {left} -> {right} in entry {idx}."
191
+ )
192
+
193
+
194
+ def build_road_index(roadnet: dict[str, Any]) -> dict[str, dict[str, Any]]:
195
+ return {road["id"]: road for road in roadnet.get("roads", [])}
196
+
197
+
198
+ def build_roadlink_index(
199
+ roadnet: dict[str, Any],
200
+ ) -> dict[str, set[tuple[str, str]]]:
201
+ roadlinks_by_intersection: dict[str, set[tuple[str, str]]] = defaultdict(set)
202
+ for intersection in roadnet.get("intersections", []):
203
+ if intersection.get("virtual", False):
204
+ # CityFlow ignores roadLinks on virtual intersections.
205
+ continue
206
+ iid = intersection["id"]
207
+ for road_link in intersection.get("roadLinks", []):
208
+ pair = (road_link["startRoad"], road_link["endRoad"])
209
+ roadlinks_by_intersection[iid].add(pair)
210
+ return roadlinks_by_intersection
211
+
212
+
213
+ def validate_route_with_reasons(
214
+ route: list[str],
215
+ roads_by_id: dict[str, dict[str, Any]],
216
+ roadlinks_by_intersection: dict[str, set[tuple[str, str]]],
217
+ ) -> list[str]:
218
+ reasons: list[str] = []
219
+ if not route:
220
+ return ["empty_route"]
221
+ if len(route) < 2:
222
+ return ["route_too_short"]
223
+
224
+ for rid in route:
225
+ if rid not in roads_by_id:
226
+ reasons.append(f"missing_road:{rid}")
227
+ return reasons
228
+
229
+ for left, right in zip(route, route[1:]):
230
+ left_road = roads_by_id[left]
231
+ right_road = roads_by_id[right]
232
+ shared_intersection = left_road["endIntersection"]
233
+ if shared_intersection != right_road["startIntersection"]:
234
+ reasons.append("mismatched_intersection_transition")
235
+ continue
236
+ if (left, right) not in roadlinks_by_intersection.get(shared_intersection, set()):
237
+ reasons.append("missing_roadlink_transition")
238
+ return reasons
239
+
240
+
241
+ def summarize_route_validation(
242
+ flow_entries: list[dict[str, Any]],
243
+ roads_by_id: dict[str, dict[str, Any]],
244
+ roadlinks_by_intersection: dict[str, set[tuple[str, str]]],
245
+ ) -> dict[str, Any]:
246
+ reason_counter: Counter[str] = Counter()
247
+ total = len(flow_entries)
248
+ valid = 0
249
+ invalid = 0
250
+ for flow in flow_entries:
251
+ reasons = validate_route_with_reasons(
252
+ route=flow.get("route", []),
253
+ roads_by_id=roads_by_id,
254
+ roadlinks_by_intersection=roadlinks_by_intersection,
255
+ )
256
+ if reasons:
257
+ invalid += 1
258
+ reason_counter.update(reasons)
259
+ else:
260
+ valid += 1
261
+ return {
262
+ "total_routes": total,
263
+ "valid_routes": valid,
264
+ "invalid_routes": invalid,
265
+ "top_failure_reasons": reason_counter.most_common(10),
266
+ }
267
+
268
+
269
+ def compute_scenario_diagnostics(
270
+ flow_entries: list[dict[str, Any]],
271
+ city_graph: CityGraph,
272
+ district_data: DistrictData,
273
+ ) -> dict[str, Any]:
274
+ roads_by_id = build_road_index(city_graph.roadnet)
275
+ assignment = district_data.intersection_to_district
276
+ road_usage: Counter[str] = Counter()
277
+ origin_counter: Counter[str] = Counter()
278
+ destination_counter: Counter[str] = Counter()
279
+ corridor_counter: Counter[str] = Counter()
280
+ external_origin = 0
281
+ external_destination = 0
282
+
283
+ for flow in flow_entries:
284
+ route = flow.get("route", [])
285
+ if not route:
286
+ continue
287
+ first = roads_by_id.get(route[0])
288
+ last = roads_by_id.get(route[-1])
289
+ if first:
290
+ origin_key = assignment.get(first["startIntersection"], "external")
291
+ origin_counter[origin_key] += 1
292
+ if origin_key == "external":
293
+ external_origin += 1
294
+ if last:
295
+ destination_key = assignment.get(last["endIntersection"], "external")
296
+ destination_counter[destination_key] += 1
297
+ if destination_key == "external":
298
+ external_destination += 1
299
+ for road_id in route:
300
+ road_usage[road_id] += 1
301
+ road = roads_by_id[road_id]
302
+ ds = assignment.get(road["startIntersection"], "external")
303
+ de = assignment.get(road["endIntersection"], "external")
304
+ if ds != de:
305
+ corridor_counter[f"{ds}->{de}"] += 1
306
+
307
+ total_routes = len(flow_entries)
308
+ total_road_traversals = sum(road_usage.values())
309
+ total_roads = len(roads_by_id)
310
+ used_roads = len(road_usage)
311
+ unused_roads = total_roads - used_roads
312
+ boundary_roads = set(district_data.inter_district_roads)
313
+ gateway_roads = set(city_graph.gateway_roads)
314
+ boundary_usage = sum(
315
+ count for road_id, count in road_usage.items() if road_id in boundary_roads
316
+ )
317
+ gateway_usage = sum(
318
+ count for road_id, count in road_usage.items() if road_id in gateway_roads
319
+ )
320
+ top_road_usage = road_usage.most_common(15)
321
+ top_corridors = corridor_counter.most_common(10)
322
+ total_lanes = sum(road.num_lanes for road in city_graph.directed_roads.values())
323
+ demand_per_lane = total_routes / max(1.0, float(total_lanes))
324
+ avg_route_len = total_road_traversals / max(1.0, float(total_routes))
325
+ concentration = (
326
+ sum(count for _, count in top_road_usage[:10]) / max(1.0, float(total_road_traversals))
327
+ )
328
+ boundary_share = boundary_usage / max(1.0, float(total_road_traversals))
329
+ congestion_score = min(
330
+ 100.0,
331
+ 1.9 * demand_per_lane
332
+ + 2.3 * avg_route_len
333
+ + 46.0 * concentration
334
+ + 34.0 * boundary_share,
335
+ )
336
+ if congestion_score < 30.0:
337
+ congestion_level = "manageable"
338
+ elif congestion_score < 52.0:
339
+ congestion_level = "moderate"
340
+ elif congestion_score < 72.0:
341
+ congestion_level = "heavy"
342
+ else:
343
+ congestion_level = "extreme"
344
+
345
+ return {
346
+ "total_vehicles": total_routes,
347
+ "vehicles_by_origin_district": dict(origin_counter),
348
+ "vehicles_by_destination_district": dict(destination_counter),
349
+ "vehicles_from_external": external_origin,
350
+ "vehicles_to_external": external_destination,
351
+ "roads_used": used_roads,
352
+ "unused_roads": unused_roads,
353
+ "boundary_road_usage": boundary_usage,
354
+ "gateway_road_usage": gateway_usage,
355
+ "boundary_road_share": round(boundary_share, 4),
356
+ "top_used_roads": [
357
+ {
358
+ "road_id": road_id,
359
+ "traversals": count,
360
+ "is_arterial": road_id in city_graph.arterial_roads,
361
+ "is_inter_district": road_id in boundary_roads,
362
+ "is_gateway": road_id in gateway_roads,
363
+ }
364
+ for road_id, count in top_road_usage
365
+ ],
366
+ "top_used_corridors": [
367
+ {"corridor": corridor, "traversals": count}
368
+ for corridor, count in top_corridors
369
+ ],
370
+ "estimated_congestion_intensity": {
371
+ "score": round(congestion_score, 3),
372
+ "level": congestion_level,
373
+ "demand_per_lane": round(demand_per_lane, 3),
374
+ "avg_route_length": round(avg_route_len, 3),
375
+ "road_usage_concentration": round(concentration, 4),
376
+ },
377
+ }
data/generators/validate_routes.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Validate generated CityFlow routes and print summary statistics."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import json
7
+ import os
8
+ import sys
9
+ from pathlib import Path
10
+ from typing import Any
11
+
12
+ from .utils import build_road_index, build_roadlink_index, summarize_route_validation
13
+
14
+
15
+ def _load_json(path: Path) -> Any:
16
+ return json.loads(path.read_text(encoding="utf-8"))
17
+
18
+
19
+ def _resolve_from_config(config_path: Path) -> tuple[Path, Path]:
20
+ config = _load_json(config_path)
21
+ base_dir = Path(config["dir"])
22
+ roadnet_path = (base_dir / config["roadnetFile"]).resolve()
23
+ flow_path = (base_dir / config["flowFile"]).resolve()
24
+ return roadnet_path, flow_path
25
+
26
+
27
+ def _print_summary(summary: dict[str, Any]) -> None:
28
+ print(f"total routes: {summary['total_routes']}")
29
+ print(f"valid routes: {summary['valid_routes']}")
30
+ print(f"invalid routes: {summary['invalid_routes']}")
31
+ if summary["top_failure_reasons"]:
32
+ formatted = ", ".join(
33
+ f"{reason}={count}" for reason, count in summary["top_failure_reasons"]
34
+ )
35
+ print(f"top failure reasons: {formatted}")
36
+ else:
37
+ print("top failure reasons: none")
38
+
39
+
40
+ def main() -> int:
41
+ parser = argparse.ArgumentParser(description="Validate CityFlow flow routes.")
42
+ parser.add_argument(
43
+ "--config",
44
+ type=Path,
45
+ default=None,
46
+ help="Path to scenario config.json.",
47
+ )
48
+ parser.add_argument(
49
+ "--roadnet",
50
+ type=Path,
51
+ default=None,
52
+ help="Path to roadnet.json (required if --config not provided).",
53
+ )
54
+ parser.add_argument(
55
+ "--flow",
56
+ type=Path,
57
+ default=None,
58
+ help="Path to flow.json (required if --config not provided).",
59
+ )
60
+ args = parser.parse_args()
61
+
62
+ if args.config is not None:
63
+ roadnet_path, flow_path = _resolve_from_config(args.config.resolve())
64
+ else:
65
+ if args.roadnet is None or args.flow is None:
66
+ raise ValueError("Provide --config OR both --roadnet and --flow.")
67
+ roadnet_path = args.roadnet.resolve()
68
+ flow_path = args.flow.resolve()
69
+
70
+ if not roadnet_path.exists():
71
+ raise FileNotFoundError(os.fspath(roadnet_path))
72
+ if not flow_path.exists():
73
+ raise FileNotFoundError(os.fspath(flow_path))
74
+
75
+ roadnet = _load_json(roadnet_path)
76
+ flows = _load_json(flow_path)
77
+ roads_by_id = build_road_index(roadnet)
78
+ roadlinks_by_intersection = build_roadlink_index(roadnet)
79
+ summary = summarize_route_validation(
80
+ flow_entries=flows,
81
+ roads_by_id=roads_by_id,
82
+ roadlinks_by_intersection=roadlinks_by_intersection,
83
+ )
84
+ _print_summary(summary)
85
+ return 1 if summary["invalid_routes"] > 0 else 0
86
+
87
+
88
+ if __name__ == "__main__":
89
+ sys.exit(main())