File size: 22,117 Bytes
2d05890 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 
525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 | """Station graph operations using NetworkX — expanded line-graph routing."""
import json
from collections import defaultdict
from pathlib import Path
from dataclasses import dataclass, field
import networkx as nx
# Fixed cost, in minutes, charged for every change of line. It is used as the
# weight of "transfer" edges in the expanded routing graph.
TRANSFER_PENALTY_MIN = 5.0
@dataclass
class RouteResult:
    """Outcome of a routing query between two stations."""

    # Station IDs in travel order.
    path: list[str]
    # Full station info per stop (name, line, is_transfer, etc.).
    stations: list[dict]
    distance_miles: float
    estimated_minutes: float
    transfers: int
    # Lines ridden, e.g. ["red", "blue"] if transferring.
    line_sequence: list[str]
class MetroGraph:
def __init__(self, system_dir: Path):
    """Load graph.json, stations.json, lines.json from a system directory.

    Builds three views of the network:
      - self.stations / self.lines: raw records keyed by their "id"
      - self.G: simple undirected station graph for connectivity checks
      - self._expanded: directed (station, line) graph used for routing

    system_dir may be a Path or a plain string path.

    Raises OSError if a file cannot be opened, json.JSONDecodeError if a
    file is not valid JSON, and KeyError if a record lacks a required field.
    """
    # Accept str as well as Path so callers need not wrap the argument.
    system_dir = Path(system_dir)
    self.system_dir = system_dir

    # Read as UTF-8 explicitly: the platform default encoding is
    # locale-dependent and can mis-decode non-ASCII station names.
    with open(system_dir / "stations.json", encoding="utf-8") as f:
        stations_list = json.load(f)
    self.stations: dict[str, dict] = {s["id"]: s for s in stations_list}

    with open(system_dir / "lines.json", encoding="utf-8") as f:
        self.lines: dict[str, dict] = {l["id"]: l for l in json.load(f)}

    with open(system_dir / "graph.json", encoding="utf-8") as f:
        graph_data = json.load(f)
    self._edges_raw: list[dict] = graph_data["edges"]

    # station_id -> set of line_ids serving that station (derived from edges)
    self.station_lines: dict[str, set[str]] = defaultdict(set)
    for edge in self._edges_raw:
        self.station_lines[edge["from"]].add(edge["line"])
        self.station_lines[edge["to"]].add(edge["line"])

    # Simple graph for connectivity checks (is_valid_path, adjacent_stations).
    # Being a plain nx.Graph it keeps one edge per station pair, so when
    # several lines join the same two stations the last edge's attributes win.
    self.G: nx.Graph = nx.Graph()
    for sid, sdata in self.stations.items():
        self.G.add_node(sid, **sdata)
    for edge in self._edges_raw:
        self.G.add_edge(
            edge["from"],
            edge["to"],
            distance_miles=edge["distance_miles"],
            travel_time_min=edge["travel_time_min"],
            line=edge["line"],
            type=edge["type"],
        )

    # Expanded directed graph for routing
    self._expanded = self._build_expanded(self._edges_raw, set(self.stations))
def _build_expanded(
    self,
    edges: list[dict],
    station_ids: set[str],
) -> nx.DiGraph:
    """Build the expanded line graph for transfer-aware Dijkstra.

    Nodes:
        ("enter", station_id) — virtual entry point
        (station_id, line_id) — station on a specific line
        ("exit", station_id)  — virtual exit point

    Edges:
        entry:    ("enter", s) → (s, line)    weight=0, distance=0
        exit:     (s, line) → ("exit", s)     weight=0, distance=0
        travel:   (sA, line) → (sB, line)     weight=travel_time, distance=d
        transfer: (s, lineA) → (s, lineB)     weight=TRANSFER_PENALTY_MIN, distance=0

    edges: raw edge dicts with "from", "to", "line", "distance_miles" and
    "travel_time_min" keys.
    station_ids: stations that receive entry/exit/transfer edges. A station
    that appears in no edge gets no nodes at all.
    """
    G = nx.DiGraph()

    # Collect which lines serve each station
    station_lines: dict[str, set[str]] = defaultdict(set)
    for edge in edges:
        s_from, s_to = edge["from"], edge["to"]
        line = edge["line"]
        dist = edge["distance_miles"]
        minutes = edge["travel_time_min"]
        station_lines[s_from].add(line)
        station_lines[s_to].add(line)
        # Travel edges (both directions since graph is undirected)
        G.add_edge(
            (s_from, line), (s_to, line),
            weight=minutes, distance_miles=dist, line=line,
            edge_type="travel",
        )
        G.add_edge(
            (s_to, line), (s_from, line),
            weight=minutes, distance_miles=dist, line=line,
            edge_type="travel",
        )

    # Entry, exit, and transfer edges
    for sid in station_ids:
        # Work from a sorted list: the order in which edges are inserted
        # decides how equal-cost routes are tie-broken, and iterating the
        # raw set would make that order vary between interpreter runs.
        lines_list = sorted(station_lines.get(sid, set()))
        for line in lines_list:
            # Entry
            G.add_edge(
                ("enter", sid), (sid, line),
                weight=0, distance_miles=0, edge_type="entry",
            )
            # Exit
            G.add_edge(
                (sid, line), ("exit", sid),
                weight=0, distance_miles=0, edge_type="exit",
            )
        # Transfer edges between all line pairs at this station
        for i, lineA in enumerate(lines_list):
            for lineB in lines_list[i + 1:]:
                G.add_edge(
                    (sid, lineA), (sid, lineB),
                    weight=TRANSFER_PENALTY_MIN, distance_miles=0,
                    edge_type="transfer",
                )
                G.add_edge(
                    (sid, lineB), (sid, lineA),
                    weight=TRANSFER_PENALTY_MIN, distance_miles=0,
                    edge_type="transfer",
                )
    return G
def lines_for_station(self, station_id: str) -> set[str]:
    """Return the set of line ids that serve a station.

    Accepts anything _resolve_station accepts. The result is a fresh set,
    so mutating it does not affect the internal index.
    """
    canonical_id = self._resolve_station(station_id)
    serving = self.station_lines.get(canonical_id)
    if serving is None:
        return set()
    return set(serving)
def _line_subgraph(self, line_id: str) -> nx.Graph:
    """Return an undirected graph holding only the edges of one line.

    Raises ValueError if line_id is not a known line.
    """
    if line_id not in self.lines:
        raise ValueError(f"Unknown line: {line_id}")
    on_line = [e for e in self._edges_raw if e["line"] == line_id]
    line_graph = nx.Graph()
    for e in on_line:
        line_graph.add_edge(e["from"], e["to"])
    return line_graph
def is_loop_line(self, line_id: str) -> bool:
    """True if the line has no terminals (every station has degree >= 2 on its own line).

    A line with no edges at all is not considered a loop.
    """
    line_graph = self._line_subgraph(line_id)
    if not line_graph.number_of_nodes():
        return False
    for _, degree in line_graph.degree():
        if degree < 2:
            return False
    return True
def line_terminals(self, line_id: str) -> list[str]:
    """Stations with degree 1 on the line subgraph. Empty list for loop lines."""
    terminals = []
    for station, degree in self._line_subgraph(line_id).degree():
        if degree == 1:
            terminals.append(station)
    return terminals
def expand_line_closures(
    self,
    closures: list[dict],
) -> list[tuple[str, str]]:
    """Expand line-level closures into segment_closures.

    Each closure dict: {"line": str, "from_station"?: str, "to_station"?: str}.
    Omitting both endpoints closes the entire line. Partial closure requires
    both endpoints and raises ValueError on a loop line (ambiguous).

    Returns (from, to) pairs, one per raw edge of the line whose two
    endpoints both fall inside the closed stretch, in raw-edge order.

    Raises ValueError for an unknown line, a line with no "stations" list,
    a partial closure with only one endpoint, a partial closure on a loop
    line, or endpoints that are not on the line.
    """
    segments: list[tuple[str, str]] = []
    for closure in closures:
        line_id = closure.get("line")
        if not line_id or line_id not in self.lines:
            raise ValueError(f"Unknown line: {line_id}")
        from_s = closure.get("from_station")
        to_s = closure.get("to_station")

        ordered = list(self.lines[line_id].get("stations", []))
        if not ordered:
            raise ValueError(f"Line '{line_id}' has no stations defined")

        if from_s is None and to_s is None:
            # Whole-line closure
            keep = set(ordered)
        elif from_s is None or to_s is None:
            raise ValueError(
                f"Partial closure on line '{line_id}' requires both from_station and to_station"
            )
        else:
            if self.is_loop_line(line_id):
                raise ValueError(
                    f"Partial closure on loop line '{line_id}' is ambiguous — use whole-line closure or specify segments"
                )
            a = self._resolve_station(from_s)
            b = self._resolve_station(to_s)
            if a not in ordered or b not in ordered:
                raise ValueError(
                    f"Endpoints '{from_s}'/'{to_s}' are not on line '{line_id}'"
                )
            # Endpoints may be given in either direction; the closed stretch
            # is the inclusive slice of the line's station list between them.
            i, j = ordered.index(a), ordered.index(b)
            lo, hi = min(i, j), max(i, j)
            keep = set(ordered[lo:hi + 1])

        segments.extend(
            (edge["from"], edge["to"])
            for edge in self._edges_raw
            if edge["line"] == line_id
            and edge["from"] in keep
            and edge["to"] in keep
        )
    return segments
def shortest_path(self, origin: str, destination: str) -> RouteResult:
    """Find the quickest route between two stations, transfer penalty included.

    Raises ValueError if either station cannot be resolved.
    Raises nx.NetworkXNoPath if no path exists between the two stations.
    Raises nx.NodeNotFound if a resolved ID is not present in the graph.
    """
    start_id = self._resolve_station(origin)
    end_id = self._resolve_station(destination)

    if start_id != end_id:
        return self._route_on_expanded(start_id, end_id, self._expanded)

    # Same station at both ends: a zero-length trip on no line, answered
    # without touching the routing graph.
    return RouteResult(
        path=[start_id],
        stations=[
            {
                "station_id": start_id,
                "station_name": self.stations[start_id]["name"],
                "line": None,
                "is_transfer": False,
                "transfer_to": None,
            }
        ],
        distance_miles=0.0,
        estimated_minutes=0.0,
        transfers=0,
        line_sequence=[],
    )
def shortest_path_avoiding(
    self,
    origin: str,
    destination: str,
    blocked_edges: list[tuple[str, str]] | None = None,
    blocked_stations: list[str] | None = None,
) -> RouteResult:
    """Compute shortest path avoiding specified edges and stations.

    Used by case generator for computing post-disruption alternative routes.
    Rebuilds the expanded graph with disrupted edges/stations removed.

    blocked_edges: (stationA, stationB) pairs; each blocks travel in both
    directions on every line joining the two stations.
    blocked_stations: stations removed together with all their edges.
    Both accept station IDs or names; a reference that cannot be resolved
    is kept as given and therefore matches nothing.

    Raises ValueError if origin or destination is blocked or cannot be resolved.
    Raises nx.NetworkXNoPath if no alternative path exists.
    """
    origin_id = self._resolve_station(origin)
    dest_id = self._resolve_station(destination)

    def canonical(ref):
        # Raw edges use canonical IDs, so names must be resolved before
        # comparing. Unresolvable references are passed through unchanged
        # rather than raising, so unknown entries stay harmless.
        try:
            return self._resolve_station(ref)
        except (ValueError, AttributeError, TypeError):
            return ref

    blocked_station_set = {canonical(s) for s in blocked_stations or ()}
    blocked_edge_set = set()
    for u, v in blocked_edges or ():
        cu, cv = canonical(u), canonical(v)
        blocked_edge_set.add((cu, cv))
        blocked_edge_set.add((cv, cu))

    if origin_id in blocked_station_set:
        raise ValueError(
            f"Origin station '{origin}' is blocked by disruption"
        )
    if dest_id in blocked_station_set:
        raise ValueError(
            f"Destination station '{destination}' is blocked by disruption"
        )

    # Filter edges and stations
    remaining_stations = set(self.stations) - blocked_station_set
    remaining_edges = [
        e for e in self._edges_raw
        if e["from"] not in blocked_station_set
        and e["to"] not in blocked_station_set
        and (e["from"], e["to"]) not in blocked_edge_set
    ]
    expanded = self._build_expanded(remaining_edges, remaining_stations)

    try:
        return self._route_on_expanded(origin_id, dest_id, expanded)
    except (nx.NetworkXNoPath, nx.NodeNotFound) as exc:
        # A station left with no edges has no enter/exit node; report that
        # the same way as an unreachable destination.
        raise nx.NetworkXNoPath(
            f"No alternative path between '{origin}' and '{destination}' "
            "with current disruption"
        ) from exc
def shortest_path_with_restrictions(
    self,
    origin: str,
    destination: str,
    station_restrictions: list[dict] | None = None,
    segment_closures: list[tuple[str, str]] | None = None,
) -> RouteResult:
    """Compute shortest path with typed station restrictions.

    station_restrictions: list of {"station": name_or_id, "restriction": type}
        - "closed": no entry, exit, transfer, or pass-through
        - "skip": trains pass through but don't stop (no entry/exit/transfer)
        - "no_transfer": can board/alight but cannot change lines
      If a station is listed more than once, the last entry wins.
    segment_closures: list of (stationA, stationB) pairs where track is closed
      (both directions, on every line joining the two stations).

    With neither argument given this is the same as shortest_path().

    Raises ValueError if origin/destination is closed or skip, or if any
    referenced station cannot be resolved.
    Raises nx.NetworkXNoPath if no path exists with restrictions.
    """
    if not station_restrictions and not segment_closures:
        return self.shortest_path(origin, destination)

    origin_id = self._resolve_station(origin)
    dest_id = self._resolve_station(destination)

    # Build restrictions map: station_id → restriction type
    # NOTE(review): a restriction type other than the three documented ones
    # is stored but has no effect below - confirm whether it should raise.
    restrictions_map: dict[str, str] = {}
    for r in (station_restrictions or []):
        sid = self._resolve_station(r["station"])
        restrictions_map[sid] = r["restriction"]

    # Validate origin/destination
    for label, sid, name in [("Origin", origin_id, origin),
                             ("Destination", dest_id, destination)]:
        restriction = restrictions_map.get(sid)
        if restriction in ("closed", "skip"):
            raise ValueError(
                f"{label} station '{name}' is {restriction} by disruption"
            )

    # Build segment closure set (both directions)
    closed_segments: set[tuple[str, str]] = set()
    for seg in (segment_closures or []):
        u = self._resolve_station(seg[0])
        v = self._resolve_station(seg[1])
        closed_segments.add((u, v))
        closed_segments.add((v, u))

    # Build expanded graph with restrictions
    closed_stations = {s for s, r in restrictions_map.items() if r == "closed"}
    skip_stations = {s for s, r in restrictions_map.items() if r == "skip"}
    no_transfer_stations = {s for s, r in restrictions_map.items()
                            if r == "no_transfer"}

    G = nx.DiGraph()
    station_lines: dict[str, set[str]] = defaultdict(set)

    # Phase 1: travel edges
    for edge in self._edges_raw:
        s_from, s_to = edge["from"], edge["to"]
        line = edge["line"]
        dist = edge["distance_miles"]
        minutes = edge["travel_time_min"]
        # Skip segment closures
        if (s_from, s_to) in closed_segments:
            continue
        # Skip travel edges touching closed stations
        if s_from in closed_stations or s_to in closed_stations:
            continue
        station_lines[s_from].add(line)
        station_lines[s_to].add(line)
        G.add_edge(
            (s_from, line), (s_to, line),
            weight=minutes, distance_miles=dist, line=line,
            edge_type="travel",
        )
        G.add_edge(
            (s_to, line), (s_from, line),
            weight=minutes, distance_miles=dist, line=line,
            edge_type="travel",
        )

    # Phase 2: entry, exit, transfer edges
    no_entry_exit = closed_stations | skip_stations
    no_transfer = closed_stations | skip_stations | no_transfer_stations
    for sid in set(self.stations) - closed_stations:
        # Sorted so edges are inserted in a stable order: insertion order
        # decides how equal-cost routes are tie-broken, and iterating the
        # raw set would make it vary between interpreter runs.
        lines_list = sorted(station_lines.get(sid, set()))
        if sid not in no_entry_exit:
            for line in lines_list:
                G.add_edge(
                    ("enter", sid), (sid, line),
                    weight=0, distance_miles=0, edge_type="entry",
                )
                G.add_edge(
                    (sid, line), ("exit", sid),
                    weight=0, distance_miles=0, edge_type="exit",
                )
        if sid not in no_transfer:
            for i, lineA in enumerate(lines_list):
                for lineB in lines_list[i + 1:]:
                    G.add_edge(
                        (sid, lineA), (sid, lineB),
                        weight=TRANSFER_PENALTY_MIN, distance_miles=0,
                        edge_type="transfer",
                    )
                    G.add_edge(
                        (sid, lineB), (sid, lineA),
                        weight=TRANSFER_PENALTY_MIN, distance_miles=0,
                        edge_type="transfer",
                    )

    try:
        return self._route_on_expanded(origin_id, dest_id, G)
    except (nx.NetworkXNoPath, nx.NodeNotFound) as exc:
        # An endpoint left with no usable edges has no enter/exit node;
        # report that the same way as an unreachable destination.
        raise nx.NetworkXNoPath(
            f"No path between '{origin}' and '{destination}' "
            "with current restrictions"
        ) from exc
def _route_on_expanded(
    self, origin_id: str, dest_id: str, expanded: nx.DiGraph
) -> RouteResult:
    """Run Dijkstra on the expanded graph and convert to RouteResult.

    The search runs from ("enter", origin_id) to ("exit", dest_id) using
    the "weight" edge attribute. estimated_minutes sums travel and
    transfer weights; distance_miles sums travel edges only.

    Raises nx.NodeNotFound if either virtual endpoint is missing from the
    graph, and nx.NetworkXNoPath if the two are not connected.
    """
    source = ("enter", origin_id)
    target = ("exit", dest_id)
    # Checked origin first, so the error names the first missing station.
    for endpoint, station_id in ((source, origin_id), (target, dest_id)):
        if endpoint not in expanded:
            raise nx.NodeNotFound(
                f"Node '{station_id}' is not in the expanded graph"
            )

    try:
        exp_path = nx.shortest_path(expanded, source, target, weight="weight")
    except nx.NetworkXNoPath:
        raise nx.NetworkXNoPath(
            f"No path found between '{origin_id}' and '{dest_id}'"
        )

    def make_stop(station_id, line):
        # One entry of RouteResult.stations; the transfer fields are filled
        # in later if a line change happens at this stop.
        return {
            "station_id": station_id,
            "station_name": self.stations[station_id]["name"],
            "line": line,
            "is_transfer": False,
            "transfer_to": None,
        }

    # Convert expanded path to station-level RouteResult
    visited_ids: list[str] = []
    stops: list[dict] = []
    lines_used: list[str] = []
    miles = 0.0
    minutes = 0.0
    transfer_count = 0
    riding: str | None = None

    for here, there in zip(exp_path, exp_path[1:]):
        attrs = expanded[here][there]
        kind = attrs["edge_type"]

        if kind == "entry":
            # ("enter", station) -> (station, line): record the origin stop
            boarded_at = here[1]
            riding = there[1]
            # NOTE(review): lines are recorded once each, so a route that
            # returns to an earlier line lists it only once - confirm intended.
            if riding not in lines_used:
                lines_used.append(riding)
            first_stop = make_stop(boarded_at, riding)
            visited_ids.append(boarded_at)
            stops.append(first_stop)
        elif kind == "travel":
            # (stationA, line) -> (stationB, line): record stationB
            arrived_at = there[0]
            miles += attrs["distance_miles"]
            minutes += attrs["weight"]
            next_stop = make_stop(arrived_at, riding)
            visited_ids.append(arrived_at)
            stops.append(next_stop)
        elif kind == "transfer":
            # (station, lineA) -> (station, lineB): change lines in place
            switched_to = there[1]
            transfer_count += 1
            minutes += attrs["weight"]
            if switched_to not in lines_used:
                lines_used.append(switched_to)
            # The stop just recorded is where the change happens.
            if stops:
                stops[-1]["is_transfer"] = True
                stops[-1]["transfer_to"] = switched_to
            riding = switched_to
        # exit edges: no action needed

    return RouteResult(
        path=visited_ids,
        stations=stops,
        distance_miles=round(miles, 2),
        estimated_minutes=round(minutes, 1),
        transfers=transfer_count,
        line_sequence=lines_used,
    )
def is_valid_path(self, path: list[str]) -> bool:
    """Check if all consecutive stations in path are adjacent in the graph.

    path holds canonical station IDs. An empty path is invalid. A path of
    one station is valid only if that station exists in the graph.
    """
    if not path:
        return False
    # With two or more stops the edge checks below already imply that every
    # station exists; a lone station needs an explicit check, otherwise any
    # unknown ID would count as a valid path.
    if not self.G.has_node(path[0]):
        return False
    return all(self.G.has_edge(a, b) for a, b in zip(path, path[1:]))
def adjacent_stations(self, station_id: str) -> list[str]:
    """List the IDs of stations directly connected to the given station.

    Accepts a station ID or name.
    Raises ValueError if the station cannot be resolved.
    """
    canonical_id = self._resolve_station(station_id)
    neighbors = self.G.neighbors(canonical_id)
    return list(neighbors)
def station_info(self, station_id: str) -> dict | None:
"""Return full station data, or None if the station does not exist."""
try:
sid = self._resolve_station(station_id)
except ValueError:
return None
return self.stations.get(sid)
def _resolve_station(self, name_or_id: str) -> str:
    """Resolve a station name or ID to its canonical ID.

    Accepts an exact station ID or a station name (case-insensitive,
    surrounding whitespace ignored).
    Also matches the base name without parenthetical suffixes, e.g.
    "Olympic Park" matches "Aolinpike Gongyuan (Olympic Park)".

    An exact full-name match always wins over a base/parenthetical match,
    whatever order the stations are stored in. Among several
    base/parenthetical matches the first station in storage order is used.

    Raises ValueError if no match is found.
    """
    if name_or_id in self.stations:
        return name_or_id
    # Anything that is not a string cannot be a name; fail with the
    # documented exception type rather than an AttributeError on .lower().
    if not isinstance(name_or_id, str):
        raise ValueError(f"Unknown station: '{name_or_id}'")

    query = name_or_id.strip()
    if query in self.stations:
        return query
    name_lower = query.lower()

    alias_match = None
    for sid, sdata in self.stations.items():
        full = sdata["name"].lower().strip()
        # Exact match
        if full == name_lower:
            return sid
        # Match base name (before parenthetical) or the parenthetical itself.
        # Only remembered here: a later station may still match exactly.
        if alias_match is None and "(" in full:
            parts = full.split("(")
            base = parts[0].strip()
            paren = parts[1].rstrip(")").strip()
            if name_lower == base or name_lower == paren:
                alias_match = sid

    if alias_match is not None:
        return alias_match
    raise ValueError(f"Unknown station: '{name_or_id}'")
|