"""Station graph operations using NetworkX — expanded line-graph routing.""" import json from collections import defaultdict from pathlib import Path from dataclasses import dataclass, field import networkx as nx TRANSFER_PENALTY_MIN = 5.0 @dataclass class RouteResult: path: list[str] # station IDs in order stations: list[dict] # full station info per stop (name, line, is_transfer, etc.) distance_miles: float estimated_minutes: float transfers: int line_sequence: list[str] # e.g. ["red", "blue"] if transferring class MetroGraph: def __init__(self, system_dir: Path): """Load graph.json, stations.json, lines.json from a system directory.""" self.system_dir = system_dir with open(system_dir / "stations.json") as f: stations_list = json.load(f) self.stations: dict[str, dict] = {s["id"]: s for s in stations_list} with open(system_dir / "lines.json") as f: self.lines: dict[str, dict] = {l["id"]: l for l in json.load(f)} with open(system_dir / "graph.json") as f: graph_data = json.load(f) self._edges_raw: list[dict] = graph_data["edges"] # station_id -> set of line_ids serving that station (derived from edges) self.station_lines: dict[str, set[str]] = defaultdict(set) for edge in self._edges_raw: self.station_lines[edge["from"]].add(edge["line"]) self.station_lines[edge["to"]].add(edge["line"]) # Simple graph for connectivity checks (is_valid_path, adjacent_stations) self.G: nx.Graph = nx.Graph() for sid, sdata in self.stations.items(): self.G.add_node(sid, **sdata) for edge in self._edges_raw: self.G.add_edge( edge["from"], edge["to"], distance_miles=edge["distance_miles"], travel_time_min=edge["travel_time_min"], line=edge["line"], type=edge["type"], ) # Expanded directed graph for routing self._expanded = self._build_expanded(self._edges_raw, set(self.stations)) def _build_expanded( self, edges: list[dict], station_ids: set[str], ) -> nx.DiGraph: """Build the expanded line graph for transfer-aware Dijkstra. Nodes: ("enter", station_id) — virtual entry point (station_id, line_id) — station on a specific line ("exit", station_id) — virtual exit point Edges: entry: ("enter", s) → (s, line) weight=0, distance=0 exit: (s, line) → ("exit", s) weight=0, distance=0 travel: (sA, line) → (sB, line) weight=travel_time, distance=d transfer: (s, lineA) → (s, lineB) weight=TRANSFER_PENALTY_MIN, distance=0 """ G = nx.DiGraph() # Collect which lines serve each station station_lines: dict[str, set[str]] = defaultdict(set) for edge in edges: s_from, s_to = edge["from"], edge["to"] line = edge["line"] dist = edge["distance_miles"] time = edge["travel_time_min"] station_lines[s_from].add(line) station_lines[s_to].add(line) # Travel edges (both directions since graph is undirected) G.add_edge( (s_from, line), (s_to, line), weight=time, distance_miles=dist, line=line, edge_type="travel", ) G.add_edge( (s_to, line), (s_from, line), weight=time, distance_miles=dist, line=line, edge_type="travel", ) # Entry, exit, and transfer edges for sid in station_ids: lines = station_lines.get(sid, set()) for line in lines: # Entry G.add_edge( ("enter", sid), (sid, line), weight=0, distance_miles=0, edge_type="entry", ) # Exit G.add_edge( (sid, line), ("exit", sid), weight=0, distance_miles=0, edge_type="exit", ) # Transfer edges between all line pairs at this station lines_list = sorted(lines) for i, lineA in enumerate(lines_list): for lineB in lines_list[i + 1:]: G.add_edge( (sid, lineA), (sid, lineB), weight=TRANSFER_PENALTY_MIN, distance_miles=0, edge_type="transfer", ) G.add_edge( (sid, lineB), (sid, lineA), weight=TRANSFER_PENALTY_MIN, distance_miles=0, edge_type="transfer", ) return G def lines_for_station(self, station_id: str) -> set[str]: """Return the set of line ids that serve a station.""" sid = self._resolve_station(station_id) return set(self.station_lines.get(sid, set())) def _line_subgraph(self, line_id: str) -> nx.Graph: if line_id not in self.lines: raise ValueError(f"Unknown line: {line_id}") sub = nx.Graph() for edge in self._edges_raw: if edge["line"] == line_id: sub.add_edge(edge["from"], edge["to"]) return sub def is_loop_line(self, line_id: str) -> bool: """True if the line has no terminals (every station has degree >= 2 on its own line).""" sub = self._line_subgraph(line_id) if sub.number_of_nodes() == 0: return False return all(deg >= 2 for _, deg in sub.degree()) def line_terminals(self, line_id: str) -> list[str]: """Stations with degree 1 on the line subgraph. Empty list for loop lines.""" sub = self._line_subgraph(line_id) return [n for n, deg in sub.degree() if deg == 1] def expand_line_closures( self, closures: list[dict], ) -> list[tuple[str, str]]: """Expand line-level closures into segment_closures. Each closure dict: {"line": str, "from_station"?: str, "to_station"?: str}. Omitting both endpoints closes the entire line. Partial closure requires both endpoints and raises ValueError on a loop line (ambiguous). """ segments: list[tuple[str, str]] = [] for c in closures: line_id = c.get("line") if not line_id or line_id not in self.lines: raise ValueError(f"Unknown line: {line_id}") from_s = c.get("from_station") to_s = c.get("to_station") ordered = list(self.lines[line_id].get("stations", [])) if not ordered: raise ValueError(f"Line '{line_id}' has no stations defined") if from_s is None and to_s is None: keep = set(ordered) elif from_s is None or to_s is None: raise ValueError( f"Partial closure on line '{line_id}' requires both from_station and to_station" ) else: if self.is_loop_line(line_id): raise ValueError( f"Partial closure on loop line '{line_id}' is ambiguous — use whole-line closure or specify segments" ) a = self._resolve_station(from_s) b = self._resolve_station(to_s) if a not in ordered or b not in ordered: raise ValueError( f"Endpoints '{from_s}'/'{to_s}' are not on line '{line_id}'" ) i, j = ordered.index(a), ordered.index(b) lo, hi = min(i, j), max(i, j) keep = set(ordered[lo:hi + 1]) for edge in self._edges_raw: if ( edge["line"] == line_id and edge["from"] in keep and edge["to"] in keep ): segments.append((edge["from"], edge["to"])) return segments def shortest_path(self, origin: str, destination: str) -> RouteResult: """Find shortest path by time (with transfer penalty). Returns RouteResult. Raises ValueError if either station cannot be resolved. Raises nx.NetworkXNoPath if no path exists between the two stations. Raises nx.NodeNotFound if a resolved ID is not present in the graph. """ origin_id = self._resolve_station(origin) dest_id = self._resolve_station(destination) if origin_id == dest_id: station = self.stations[origin_id] stop = { "station_id": origin_id, "station_name": station["name"], "line": None, "is_transfer": False, "transfer_to": None, } return RouteResult( path=[origin_id], stations=[stop], distance_miles=0.0, estimated_minutes=0.0, transfers=0, line_sequence=[], ) return self._route_on_expanded(origin_id, dest_id, self._expanded) def shortest_path_avoiding( self, origin: str, destination: str, blocked_edges: list[tuple[str, str]] | None = None, blocked_stations: list[str] | None = None, ) -> RouteResult: """Compute shortest path avoiding specified edges and stations. Used by case generator for computing post-disruption alternative routes. Rebuilds the expanded graph with disrupted edges/stations removed. Raises ValueError if origin or destination is blocked or cannot be resolved. Raises nx.NetworkXNoPath if no alternative path exists. """ origin_id = self._resolve_station(origin) dest_id = self._resolve_station(destination) blocked_station_set = set(blocked_stations) if blocked_stations else set() blocked_edge_set = set() if blocked_edges: for u, v in blocked_edges: blocked_edge_set.add((u, v)) blocked_edge_set.add((v, u)) if origin_id in blocked_station_set: raise ValueError( f"Origin station '{origin}' is blocked by disruption" ) if dest_id in blocked_station_set: raise ValueError( f"Destination station '{destination}' is blocked by disruption" ) # Filter edges and stations remaining_stations = set(self.stations) - blocked_station_set remaining_edges = [ e for e in self._edges_raw if e["from"] not in blocked_station_set and e["to"] not in blocked_station_set and (e["from"], e["to"]) not in blocked_edge_set ] expanded = self._build_expanded(remaining_edges, remaining_stations) try: return self._route_on_expanded(origin_id, dest_id, expanded) except (nx.NetworkXNoPath, nx.NodeNotFound): raise nx.NetworkXNoPath( f"No alternative path between '{origin}' and '{destination}' " "with current disruption" ) def shortest_path_with_restrictions( self, origin: str, destination: str, station_restrictions: list[dict] | None = None, segment_closures: list[tuple[str, str]] | None = None, ) -> RouteResult: """Compute shortest path with typed station restrictions. station_restrictions: list of {"station": name_or_id, "restriction": type} - "closed": no entry, exit, transfer, or pass-through - "skip": trains pass through but don't stop (no entry/exit/transfer) - "no_transfer": can board/alight but cannot change lines segment_closures: list of (stationA, stationB) pairs where track is closed. Raises ValueError if origin/destination is closed or skip. Raises nx.NetworkXNoPath if no path exists with restrictions. """ if not station_restrictions and not segment_closures: return self.shortest_path(origin, destination) origin_id = self._resolve_station(origin) dest_id = self._resolve_station(destination) # Build restrictions map: station_id → restriction type restrictions_map: dict[str, str] = {} for r in (station_restrictions or []): sid = self._resolve_station(r["station"]) restrictions_map[sid] = r["restriction"] # Validate origin/destination for label, sid, name in [("Origin", origin_id, origin), ("Destination", dest_id, destination)]: restriction = restrictions_map.get(sid) if restriction in ("closed", "skip"): raise ValueError( f"{label} station '{name}' is {restriction} by disruption" ) # Build segment closure set (both directions) closed_segments: set[tuple[str, str]] = set() for seg in (segment_closures or []): u = self._resolve_station(seg[0]) v = self._resolve_station(seg[1]) closed_segments.add((u, v)) closed_segments.add((v, u)) # Build expanded graph with restrictions closed_stations = {s for s, r in restrictions_map.items() if r == "closed"} skip_stations = {s for s, r in restrictions_map.items() if r == "skip"} no_transfer_stations = {s for s, r in restrictions_map.items() if r == "no_transfer"} G = nx.DiGraph() station_lines: dict[str, set[str]] = defaultdict(set) # Phase 1: travel edges for edge in self._edges_raw: s_from, s_to = edge["from"], edge["to"] line = edge["line"] dist = edge["distance_miles"] time = edge["travel_time_min"] # Skip segment closures if (s_from, s_to) in closed_segments: continue # Skip travel edges touching closed stations if s_from in closed_stations or s_to in closed_stations: continue station_lines[s_from].add(line) station_lines[s_to].add(line) G.add_edge( (s_from, line), (s_to, line), weight=time, distance_miles=dist, line=line, edge_type="travel", ) G.add_edge( (s_to, line), (s_from, line), weight=time, distance_miles=dist, line=line, edge_type="travel", ) # Phase 2: entry, exit, transfer edges no_entry_exit = closed_stations | skip_stations no_transfer = closed_stations | skip_stations | no_transfer_stations for sid in set(self.stations) - closed_stations: lines = station_lines.get(sid, set()) if sid not in no_entry_exit: for line in lines: G.add_edge( ("enter", sid), (sid, line), weight=0, distance_miles=0, edge_type="entry", ) G.add_edge( (sid, line), ("exit", sid), weight=0, distance_miles=0, edge_type="exit", ) if sid not in no_transfer: lines_list = sorted(lines) for i, lineA in enumerate(lines_list): for lineB in lines_list[i + 1:]: G.add_edge( (sid, lineA), (sid, lineB), weight=TRANSFER_PENALTY_MIN, distance_miles=0, edge_type="transfer", ) G.add_edge( (sid, lineB), (sid, lineA), weight=TRANSFER_PENALTY_MIN, distance_miles=0, edge_type="transfer", ) try: return self._route_on_expanded(origin_id, dest_id, G) except (nx.NetworkXNoPath, nx.NodeNotFound): raise nx.NetworkXNoPath( f"No path between '{origin}' and '{destination}' " "with current restrictions" ) def _route_on_expanded( self, origin_id: str, dest_id: str, expanded: nx.DiGraph ) -> RouteResult: """Run Dijkstra on the expanded graph and convert to RouteResult.""" enter_node = ("enter", origin_id) exit_node = ("exit", dest_id) if enter_node not in expanded: raise nx.NodeNotFound( f"Node '{origin_id}' is not in the expanded graph" ) if exit_node not in expanded: raise nx.NodeNotFound( f"Node '{dest_id}' is not in the expanded graph" ) try: exp_path = nx.shortest_path( expanded, enter_node, exit_node, weight="weight" ) except nx.NetworkXNoPath: raise nx.NetworkXNoPath( f"No path found between '{origin_id}' and '{dest_id}'" ) # Convert expanded path to station-level RouteResult path: list[str] = [] stops: list[dict] = [] line_sequence: list[str] = [] total_distance = 0.0 total_time = 0.0 transfers = 0 current_line: str | None = None for i in range(len(exp_path) - 1): node = exp_path[i] next_node = exp_path[i + 1] edge_data = expanded[node][next_node] edge_type = edge_data["edge_type"] if edge_type == "entry": # (enter, station) -> (station, line): add origin station station_id = node[1] line = next_node[1] current_line = line if line not in line_sequence: line_sequence.append(line) station = self.stations[station_id] path.append(station_id) stops.append({ "station_id": station_id, "station_name": station["name"], "line": current_line, "is_transfer": False, "transfer_to": None, }) elif edge_type == "travel": # (stationA, line) -> (stationB, line): add stationB station_id = next_node[0] total_distance += edge_data["distance_miles"] total_time += edge_data["weight"] station = self.stations[station_id] path.append(station_id) stops.append({ "station_id": station_id, "station_name": station["name"], "line": current_line, "is_transfer": False, "transfer_to": None, }) elif edge_type == "transfer": # (station, lineA) -> (station, lineB): transfer at station new_line = next_node[1] transfers += 1 total_time += edge_data["weight"] if new_line not in line_sequence: line_sequence.append(new_line) # Mark the last stop as a transfer point if stops: stops[-1]["is_transfer"] = True stops[-1]["transfer_to"] = new_line current_line = new_line # exit edges: no action needed return RouteResult( path=path, stations=stops, distance_miles=round(total_distance, 2), estimated_minutes=round(total_time, 1), transfers=transfers, line_sequence=line_sequence, ) def is_valid_path(self, path: list[str]) -> bool: """Check if all consecutive stations in path are adjacent in the graph.""" if len(path) == 0: return False for i in range(len(path) - 1): if not self.G.has_edge(path[i], path[i + 1]): return False return True def adjacent_stations(self, station_id: str) -> list[str]: """Return neighbor station IDs for a given station. Raises ValueError if the station cannot be resolved. """ sid = self._resolve_station(station_id) return list(self.G.neighbors(sid)) def station_info(self, station_id: str) -> dict | None: """Return full station data, or None if the station does not exist.""" try: sid = self._resolve_station(station_id) except ValueError: return None return self.stations.get(sid) def _resolve_station(self, name_or_id: str) -> str: """Resolve a station name or ID to its canonical ID. Accepts an exact station ID or a station name (case-insensitive). Also matches the base name without parenthetical suffixes, e.g. "Olympic Park" matches "Aolinpike Gongyuan (Olympic Park)". Raises ValueError if no match is found. """ if name_or_id in self.stations: return name_or_id name_lower = name_or_id.lower().strip() for sid, sdata in self.stations.items(): full = sdata["name"].lower() # Exact match if full == name_lower: return sid # Match base name (before parenthetical) if "(" in full: base = full.split("(")[0].strip() paren = full.split("(")[1].rstrip(")").strip() if name_lower == base or name_lower == paren: return sid raise ValueError(f"Unknown station: '{name_or_id}'")