File size: 22,117 Bytes
2d05890
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
"""Station graph operations using NetworkX β€” expanded line-graph routing."""

import json
from collections import defaultdict
from pathlib import Path
from dataclasses import dataclass, field

import networkx as nx

# Time cost, in minutes, charged for each change of line at a station; it is
# the weight given to every "transfer" edge of the expanded routing graph.
TRANSFER_PENALTY_MIN = 5.0


@dataclass
class RouteResult:
    """Station-level description of a single computed route."""

    # Station IDs, listed in the order they are visited.
    path: list[str]
    # One dict per stop with the full stop details (name, line,
    # is_transfer, etc.).
    stations: list[dict]
    # Total distance covered by the route, in miles.
    distance_miles: float
    # Estimated duration of the route, in minutes.
    estimated_minutes: float
    # Number of line changes made along the way.
    transfers: int
    # Lines used by the route, e.g. ["red", "blue"] when transferring.
    line_sequence: list[str]


class MetroGraph:
    def __init__(self, system_dir: Path):
        """Load a metro system from disk and build its routing graphs.

        Reads ``stations.json``, ``lines.json`` and ``graph.json`` from
        ``system_dir``, then builds the simple station graph ``self.G`` and
        the expanded directed routing graph ``self._expanded``.

        Args:
            system_dir: Directory containing the three JSON files. A plain
                string path is accepted as well as a ``Path``.

        Raises:
            OSError: If one of the files cannot be opened.
            KeyError: If a record lacks one of the fields read here
                (``id``, ``edges``, ``from``, ``to``, ``line``, ...).
        """
        # Coerce so that callers passing a str do not fail on the "/" joins.
        system_dir = Path(system_dir)
        self.system_dir = system_dir

        # The files are read as UTF-8 explicitly: the platform default
        # encoding is locale dependent and would garble non-ASCII station
        # names on systems whose locale is not UTF-8.
        with open(system_dir / "stations.json", encoding="utf-8") as f:
            stations_list = json.load(f)
        self.stations: dict[str, dict] = {s["id"]: s for s in stations_list}

        with open(system_dir / "lines.json", encoding="utf-8") as f:
            self.lines: dict[str, dict] = {
                line["id"]: line for line in json.load(f)
            }

        with open(system_dir / "graph.json", encoding="utf-8") as f:
            graph_data = json.load(f)

        self._edges_raw: list[dict] = graph_data["edges"]

        # station_id -> set of line_ids serving that station (derived from edges)
        self.station_lines: dict[str, set[str]] = defaultdict(set)
        for edge in self._edges_raw:
            self.station_lines[edge["from"]].add(edge["line"])
            self.station_lines[edge["to"]].add(edge["line"])

        # Simple graph for connectivity checks (is_valid_path, adjacent_stations).
        # nx.Graph keeps a single edge per station pair, so when several raw
        # edges link the same two stations the attributes of the last one win.
        self.G: nx.Graph = nx.Graph()
        for sid, sdata in self.stations.items():
            self.G.add_node(sid, **sdata)
        for edge in self._edges_raw:
            self.G.add_edge(
                edge["from"],
                edge["to"],
                distance_miles=edge["distance_miles"],
                travel_time_min=edge["travel_time_min"],
                line=edge["line"],
                type=edge["type"],
            )

        # Expanded directed graph for routing
        self._expanded = self._build_expanded(self._edges_raw, set(self.stations))

    def _build_expanded(
        self,
        edges: list[dict],
        station_ids: set[str],
    ) -> nx.DiGraph:
        """Build the expanded line graph for transfer-aware Dijkstra.

        Nodes:
          ("enter", station_id)  — virtual entry point
          (station_id, line_id)  — station on a specific line
          ("exit", station_id)   — virtual exit point

        Edges:
          entry:    ("enter", s) → (s, line)        weight=0, distance=0
          exit:     (s, line)    → ("exit", s)       weight=0, distance=0
          travel:   (sA, line)   → (sB, line)        weight=travel_time, distance=d
          transfer: (s, lineA)   → (s, lineB)        weight=TRANSFER_PENALTY_MIN, distance=0

        Args:
            edges: Raw edge records; each must provide ``from``, ``to``,
                ``line``, ``distance_miles`` and ``travel_time_min``.
            station_ids: Stations that may receive entry, exit and transfer
                edges. Stations without any edge in ``edges`` get none.

        Returns:
            The expanded directed graph.
        """
        G = nx.DiGraph()

        # Collect which lines serve each station
        station_lines: dict[str, set[str]] = defaultdict(set)

        for edge in edges:
            s_from, s_to = edge["from"], edge["to"]
            line = edge["line"]
            dist = edge["distance_miles"]
            minutes = edge["travel_time_min"]

            station_lines[s_from].add(line)
            station_lines[s_to].add(line)

            # Travel edges (both directions since graph is undirected)
            for u, v in ((s_from, s_to), (s_to, s_from)):
                G.add_edge(
                    (u, line), (v, line),
                    weight=minutes, distance_miles=dist, line=line,
                    edge_type="travel",
                )

        # Entry, exit, and transfer edges
        for sid in station_ids:
            # Lines are visited in sorted order rather than set order: the
            # iteration order of a set of strings changes between interpreter
            # runs (hash randomisation), and the insertion order of the
            # zero-weight entry edges decides which of several equal-cost
            # routes the search reports.
            lines_list = sorted(station_lines.get(sid, ()))

            for line in lines_list:
                # Entry
                G.add_edge(
                    ("enter", sid), (sid, line),
                    weight=0, distance_miles=0, edge_type="entry",
                )
                # Exit
                G.add_edge(
                    (sid, line), ("exit", sid),
                    weight=0, distance_miles=0, edge_type="exit",
                )

            # Transfer edges between all line pairs at this station
            for i, lineA in enumerate(lines_list):
                for lineB in lines_list[i + 1:]:
                    for src, dst in ((lineA, lineB), (lineB, lineA)):
                        G.add_edge(
                            (sid, src), (sid, dst),
                            weight=TRANSFER_PENALTY_MIN, distance_miles=0,
                            edge_type="transfer",
                        )

        return G

    def lines_for_station(self, station_id: str) -> set[str]:
        """Return the set of line ids that serve a station."""
        sid = self._resolve_station(station_id)
        return set(self.station_lines.get(sid, set()))

    def _line_subgraph(self, line_id: str) -> nx.Graph:
        if line_id not in self.lines:
            raise ValueError(f"Unknown line: {line_id}")
        sub = nx.Graph()
        for edge in self._edges_raw:
            if edge["line"] == line_id:
                sub.add_edge(edge["from"], edge["to"])
        return sub

    def is_loop_line(self, line_id: str) -> bool:
        """True if the line has no terminals (every station has degree >= 2 on its own line)."""
        sub = self._line_subgraph(line_id)
        if sub.number_of_nodes() == 0:
            return False
        return all(deg >= 2 for _, deg in sub.degree())

    def line_terminals(self, line_id: str) -> list[str]:
        """Stations with degree 1 on the line subgraph. Empty list for loop lines."""
        sub = self._line_subgraph(line_id)
        return [n for n, deg in sub.degree() if deg == 1]

    def expand_line_closures(
        self,
        closures: list[dict],
    ) -> list[tuple[str, str]]:
        """Expand line-level closures into segment_closures.

        Each closure dict: {"line": str, "from_station"?: str, "to_station"?: str}.
        Omitting both endpoints closes the entire line. Partial closure requires
        both endpoints and raises ValueError on a loop line (ambiguous).
        """
        segments: list[tuple[str, str]] = []
        for c in closures:
            line_id = c.get("line")
            if not line_id or line_id not in self.lines:
                raise ValueError(f"Unknown line: {line_id}")
            from_s = c.get("from_station")
            to_s = c.get("to_station")
            ordered = list(self.lines[line_id].get("stations", []))
            if not ordered:
                raise ValueError(f"Line '{line_id}' has no stations defined")
            if from_s is None and to_s is None:
                keep = set(ordered)
            elif from_s is None or to_s is None:
                raise ValueError(
                    f"Partial closure on line '{line_id}' requires both from_station and to_station"
                )
            else:
                if self.is_loop_line(line_id):
                    raise ValueError(
                        f"Partial closure on loop line '{line_id}' is ambiguous β€” use whole-line closure or specify segments"
                    )
                a = self._resolve_station(from_s)
                b = self._resolve_station(to_s)
                if a not in ordered or b not in ordered:
                    raise ValueError(
                        f"Endpoints '{from_s}'/'{to_s}' are not on line '{line_id}'"
                    )
                i, j = ordered.index(a), ordered.index(b)
                lo, hi = min(i, j), max(i, j)
                keep = set(ordered[lo:hi + 1])
            for edge in self._edges_raw:
                if (
                    edge["line"] == line_id
                    and edge["from"] in keep
                    and edge["to"] in keep
                ):
                    segments.append((edge["from"], edge["to"]))
        return segments

    def shortest_path(self, origin: str, destination: str) -> RouteResult:
        """Find shortest path by time (with transfer penalty). Returns RouteResult.

        Raises ValueError if either station cannot be resolved.
        Raises nx.NetworkXNoPath if no path exists between the two stations.
        Raises nx.NodeNotFound if a resolved ID is not present in the graph.
        """
        origin_id = self._resolve_station(origin)
        dest_id = self._resolve_station(destination)

        if origin_id == dest_id:
            station = self.stations[origin_id]
            stop = {
                "station_id": origin_id,
                "station_name": station["name"],
                "line": None,
                "is_transfer": False,
                "transfer_to": None,
            }
            return RouteResult(
                path=[origin_id],
                stations=[stop],
                distance_miles=0.0,
                estimated_minutes=0.0,
                transfers=0,
                line_sequence=[],
            )

        return self._route_on_expanded(origin_id, dest_id, self._expanded)

    def shortest_path_avoiding(
        self,
        origin: str,
        destination: str,
        blocked_edges: list[tuple[str, str]] | None = None,
        blocked_stations: list[str] | None = None,
    ) -> RouteResult:
        """Compute shortest path avoiding specified edges and stations.

        Used by case generator for computing post-disruption alternative routes.
        Rebuilds the expanded graph with disrupted edges/stations removed.

        Raises ValueError if origin or destination is blocked or cannot be resolved.
        Raises nx.NetworkXNoPath if no alternative path exists.
        """
        origin_id = self._resolve_station(origin)
        dest_id = self._resolve_station(destination)

        blocked_station_set = set(blocked_stations) if blocked_stations else set()
        blocked_edge_set = set()
        if blocked_edges:
            for u, v in blocked_edges:
                blocked_edge_set.add((u, v))
                blocked_edge_set.add((v, u))

        if origin_id in blocked_station_set:
            raise ValueError(
                f"Origin station '{origin}' is blocked by disruption"
            )
        if dest_id in blocked_station_set:
            raise ValueError(
                f"Destination station '{destination}' is blocked by disruption"
            )

        # Filter edges and stations
        remaining_stations = set(self.stations) - blocked_station_set
        remaining_edges = [
            e for e in self._edges_raw
            if e["from"] not in blocked_station_set
            and e["to"] not in blocked_station_set
            and (e["from"], e["to"]) not in blocked_edge_set
        ]

        expanded = self._build_expanded(remaining_edges, remaining_stations)

        try:
            return self._route_on_expanded(origin_id, dest_id, expanded)
        except (nx.NetworkXNoPath, nx.NodeNotFound):
            raise nx.NetworkXNoPath(
                f"No alternative path between '{origin}' and '{destination}' "
                "with current disruption"
            )

    def shortest_path_with_restrictions(
        self,
        origin: str,
        destination: str,
        station_restrictions: list[dict] | None = None,
        segment_closures: list[tuple[str, str]] | None = None,
    ) -> RouteResult:
        """Compute shortest path with typed station restrictions.

        station_restrictions: list of {"station": name_or_id, "restriction": type}
            - "closed": no entry, exit, transfer, or pass-through
            - "skip": trains pass through but don't stop (no entry/exit/transfer)
            - "no_transfer": can board/alight but cannot change lines

        segment_closures: list of (stationA, stationB) pairs where track is closed.

        Raises ValueError if origin/destination is closed or skip.
        Raises nx.NetworkXNoPath if no path exists with restrictions.
        """
        if not station_restrictions and not segment_closures:
            return self.shortest_path(origin, destination)

        origin_id = self._resolve_station(origin)
        dest_id = self._resolve_station(destination)

        # Build restrictions map: station_id β†’ restriction type
        restrictions_map: dict[str, str] = {}
        for r in (station_restrictions or []):
            sid = self._resolve_station(r["station"])
            restrictions_map[sid] = r["restriction"]

        # Validate origin/destination
        for label, sid, name in [("Origin", origin_id, origin),
                                  ("Destination", dest_id, destination)]:
            restriction = restrictions_map.get(sid)
            if restriction in ("closed", "skip"):
                raise ValueError(
                    f"{label} station '{name}' is {restriction} by disruption"
                )

        # Build segment closure set (both directions)
        closed_segments: set[tuple[str, str]] = set()
        for seg in (segment_closures or []):
            u = self._resolve_station(seg[0])
            v = self._resolve_station(seg[1])
            closed_segments.add((u, v))
            closed_segments.add((v, u))

        # Build expanded graph with restrictions
        closed_stations = {s for s, r in restrictions_map.items() if r == "closed"}
        skip_stations = {s for s, r in restrictions_map.items() if r == "skip"}
        no_transfer_stations = {s for s, r in restrictions_map.items()
                                if r == "no_transfer"}

        G = nx.DiGraph()
        station_lines: dict[str, set[str]] = defaultdict(set)

        # Phase 1: travel edges
        for edge in self._edges_raw:
            s_from, s_to = edge["from"], edge["to"]
            line = edge["line"]
            dist = edge["distance_miles"]
            time = edge["travel_time_min"]

            # Skip segment closures
            if (s_from, s_to) in closed_segments:
                continue
            # Skip travel edges touching closed stations
            if s_from in closed_stations or s_to in closed_stations:
                continue

            station_lines[s_from].add(line)
            station_lines[s_to].add(line)

            G.add_edge(
                (s_from, line), (s_to, line),
                weight=time, distance_miles=dist, line=line,
                edge_type="travel",
            )
            G.add_edge(
                (s_to, line), (s_from, line),
                weight=time, distance_miles=dist, line=line,
                edge_type="travel",
            )

        # Phase 2: entry, exit, transfer edges
        no_entry_exit = closed_stations | skip_stations
        no_transfer = closed_stations | skip_stations | no_transfer_stations

        for sid in set(self.stations) - closed_stations:
            lines = station_lines.get(sid, set())

            if sid not in no_entry_exit:
                for line in lines:
                    G.add_edge(
                        ("enter", sid), (sid, line),
                        weight=0, distance_miles=0, edge_type="entry",
                    )
                    G.add_edge(
                        (sid, line), ("exit", sid),
                        weight=0, distance_miles=0, edge_type="exit",
                    )

            if sid not in no_transfer:
                lines_list = sorted(lines)
                for i, lineA in enumerate(lines_list):
                    for lineB in lines_list[i + 1:]:
                        G.add_edge(
                            (sid, lineA), (sid, lineB),
                            weight=TRANSFER_PENALTY_MIN, distance_miles=0,
                            edge_type="transfer",
                        )
                        G.add_edge(
                            (sid, lineB), (sid, lineA),
                            weight=TRANSFER_PENALTY_MIN, distance_miles=0,
                            edge_type="transfer",
                        )

        try:
            return self._route_on_expanded(origin_id, dest_id, G)
        except (nx.NetworkXNoPath, nx.NodeNotFound):
            raise nx.NetworkXNoPath(
                f"No path between '{origin}' and '{destination}' "
                "with current restrictions"
            )

    def _route_on_expanded(
        self, origin_id: str, dest_id: str, expanded: nx.DiGraph
    ) -> RouteResult:
        """Search the expanded graph and summarise the result per station.

        The search runs from the origin's virtual entry node to the
        destination's virtual exit node, minimising the ``weight`` attribute.
        The node path found is then folded into stops, totals and the list of
        lines used (each line appears once, at its first use).

        Raises nx.NodeNotFound if the entry or exit node is missing from
        ``expanded``.
        Raises nx.NetworkXNoPath if the exit node cannot be reached.
        """
        source = ("enter", origin_id)
        target = ("exit", dest_id)

        # Origin is checked before destination.
        for endpoint, station in ((source, origin_id), (target, dest_id)):
            if endpoint not in expanded:
                raise nx.NodeNotFound(
                    f"Node '{station}' is not in the expanded graph"
                )

        try:
            hops = nx.shortest_path(expanded, source, target, weight="weight")
        except nx.NetworkXNoPath:
            raise nx.NetworkXNoPath(
                f"No path found between '{origin_id}' and '{dest_id}'"
            )

        visited: list[str] = []
        stops: list[dict] = []
        lines_used: list[str] = []
        miles = 0.0
        minutes = 0.0
        transfer_count = 0
        riding: str | None = None

        def add_stop(sid: str) -> None:
            # Record a stop on the line being ridden at the time of the call.
            info = self.stations[sid]
            visited.append(sid)
            stops.append({
                "station_id": sid,
                "station_name": info["name"],
                "line": riding,
                "is_transfer": False,
                "transfer_to": None,
            })

        # Walk the path one edge at a time.
        for here, there in zip(hops, hops[1:]):
            attrs = expanded[here][there]
            kind = attrs["edge_type"]

            if kind == "entry":
                # ("enter", station) -> (station, line): boarding at the origin
                riding = there[1]
                if riding not in lines_used:
                    lines_used.append(riding)
                add_stop(here[1])

            elif kind == "travel":
                # (stationA, line) -> (stationB, line): arriving at stationB
                miles += attrs["distance_miles"]
                minutes += attrs["weight"]
                add_stop(there[0])

            elif kind == "transfer":
                # (station, lineA) -> (station, lineB): changing lines in place
                riding = there[1]
                transfer_count += 1
                minutes += attrs["weight"]
                if riding not in lines_used:
                    lines_used.append(riding)
                # The change happens at the stop recorded last.
                if stops:
                    stops[-1]["is_transfer"] = True
                    stops[-1]["transfer_to"] = riding

            # Exit edges contribute nothing.

        return RouteResult(
            path=visited,
            stations=stops,
            distance_miles=round(miles, 2),
            estimated_minutes=round(minutes, 1),
            transfers=transfer_count,
            line_sequence=lines_used,
        )

    def is_valid_path(self, path: list[str]) -> bool:
        """Check if all consecutive stations in path are adjacent in the graph."""
        if len(path) == 0:
            return False
        for i in range(len(path) - 1):
            if not self.G.has_edge(path[i], path[i + 1]):
                return False
        return True

    def adjacent_stations(self, station_id: str) -> list[str]:
        """Return neighbor station IDs for a given station.

        Raises ValueError if the station cannot be resolved.
        """
        sid = self._resolve_station(station_id)
        return list(self.G.neighbors(sid))

    def station_info(self, station_id: str) -> dict | None:
        """Return full station data, or None if the station does not exist."""
        try:
            sid = self._resolve_station(station_id)
        except ValueError:
            return None
        return self.stations.get(sid)

    def _resolve_station(self, name_or_id: str) -> str:
        """Resolve a station name or ID to its canonical ID.

        Accepts an exact station ID or a station name (case-insensitive).
        Also matches the base name without parenthetical suffixes, e.g.
        "Olympic Park" matches "Aolinpike Gongyuan (Olympic Park)".
        Raises ValueError if no match is found.
        """
        if name_or_id in self.stations:
            return name_or_id

        name_lower = name_or_id.lower().strip()
        for sid, sdata in self.stations.items():
            full = sdata["name"].lower()
            # Exact match
            if full == name_lower:
                return sid
            # Match base name (before parenthetical)
            if "(" in full:
                base = full.split("(")[0].strip()
                paren = full.split("(")[1].rstrip(")").strip()
                if name_lower == base or name_lower == paren:
                    return sid

        raise ValueError(f"Unknown station: '{name_or_id}'")