File size: 22,173 Bytes
08e15f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
"""
Real-world routing service using OSMnx for road network data.

This module provides:
- OSMnxRoutingService: Downloads OSM network, caches locally, computes routes
- DistanceMatrix: Precomputes all pairwise routes with times and geometries
- Haversine fallback when OSMnx is unavailable
"""

from __future__ import annotations

import logging
import math
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Optional

import polyline

if TYPE_CHECKING:
    from .domain import Location

logger = logging.getLogger(__name__)

# Cache directory for OSM network data
CACHE_DIR = Path(__file__).parent.parent.parent / ".osm_cache"


@dataclass
class RouteResult:
    """
    Outcome of a single origin -> destination routing query.

    Instances are plain value objects: two integer totals plus an optional
    encoded path that can be drawn on a map.
    """

    # Total travel time along the route, in whole seconds.
    duration_seconds: int
    # Total length of the route, in whole meters.
    distance_meters: int
    # Route shape as an encoded polyline string; None when no shape is known.
    geometry: Optional[str] = None


@dataclass
class DistanceMatrix:
    """
    Lookup table of precomputed routes between pairs of locations.

    Each entry maps an (origin, destination) coordinate pair to the
    RouteResult stored for it, so callers read travel data with a single
    dict access instead of issuing a new routing query.
    """

    # Keyed by ((origin_lat, origin_lon), (dest_lat, dest_lon)).
    _matrix: dict[tuple[tuple[float, float], tuple[float, float]], RouteResult] = field(
        default_factory=dict
    )

    def _key(
        self, origin: "Location", destination: "Location"
    ) -> tuple[tuple[float, float], tuple[float, float]]:
        """Build the hashable dict key for an ordered pair of locations."""
        start = (origin.latitude, origin.longitude)
        end = (destination.latitude, destination.longitude)
        return (start, end)

    def set_route(
        self, origin: "Location", destination: "Location", result: RouteResult
    ) -> None:
        """Record ``result`` as the route from ``origin`` to ``destination``."""
        pair = self._key(origin, destination)
        self._matrix[pair] = result

    def get_route(
        self, origin: "Location", destination: "Location"
    ) -> Optional[RouteResult]:
        """Return the stored route for the pair, or None if none was stored."""
        pair = self._key(origin, destination)
        return self._matrix.get(pair)

    def get_driving_time(self, origin: "Location", destination: "Location") -> int:
        """
        Return the driving time in seconds between two locations.

        Pairs missing from the matrix are estimated with the haversine
        fallback instead of failing.
        """
        route = self.get_route(origin, destination)
        if route is not None:
            return route.duration_seconds
        # Not precomputed: estimate from the great-circle distance.
        return _haversine_driving_time(origin, destination)

    def get_geometry(
        self, origin: "Location", destination: "Location"
    ) -> Optional[str]:
        """Return the encoded polyline stored for the pair, or None if absent."""
        route = self.get_route(origin, destination)
        if not route:
            return None
        return route.geometry


def _haversine_driving_time(origin: "Location", destination: "Location") -> int:
    """
    Calculate driving time using haversine formula (fallback).

    Uses 50 km/h average speed assumption.
    """
    if (
        origin.latitude == destination.latitude
        and origin.longitude == destination.longitude
    ):
        return 0

    EARTH_RADIUS_M = 6371000
    AVERAGE_SPEED_KMPH = 50

    lat1 = math.radians(origin.latitude)
    lon1 = math.radians(origin.longitude)
    lat2 = math.radians(destination.latitude)
    lon2 = math.radians(destination.longitude)

    # Haversine formula
    dlat = lat2 - lat1
    dlon = lon2 - lon1
    a = math.sin(dlat / 2) ** 2 + math.cos(lat1) * math.cos(lat2) * math.sin(dlon / 2) ** 2
    c = 2 * math.asin(math.sqrt(a))
    distance_meters = EARTH_RADIUS_M * c

    # Convert to driving time
    return round(distance_meters / AVERAGE_SPEED_KMPH * 3.6)


class OSMnxRoutingService:
    """
    Routing service using OSMnx for real road network data.

    Downloads the OSM drive network for a bounding box, caches it locally as
    GraphML, and computes shortest paths weighted by the ``travel_time`` edge
    attribute using NetworkX.
    """

    def __init__(self, cache_dir: Path = CACHE_DIR):
        """
        Create the service and make sure the cache directory exists.

        Args:
            cache_dir: Directory where downloaded networks are stored.
        """
        self.cache_dir = cache_dir
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        # Road graph; stays None until load_network() succeeds.
        self._graph = None
        # Padded (north, south, east, west) of the currently loaded graph.
        self._graph_bbox = None

    def _get_cache_path(
        self, north: float, south: float, east: float, west: float
    ) -> Path:
        """Generate cache file path for a bounding box."""
        # Round to 2 decimal places for cache key
        key = f"osm_{north:.2f}_{south:.2f}_{east:.2f}_{west:.2f}.graphml"
        return self.cache_dir / key

    def load_network(
        self, north: float, south: float, east: float, west: float, padding: float = 0.01
    ) -> bool:
        """
        Load OSM road network for the given bounding box.

        A non-empty cache file for the padded bbox is reused; otherwise the
        network is downloaded. Either way the graph is enriched with edge
        speeds and travel times and written back to the cache. The graph is
        only installed on the service once it is fully prepared, so a failure
        part-way through never leaves a graph without ``travel_time`` behind.

        Args:
            north, south, east, west: Bounding box coordinates
            padding: Extra padding around the bbox (in degrees)

        Returns:
            True if network loaded successfully, False otherwise
        """
        try:
            import osmnx as ox

            # Add padding to ensure we have roads outside the strict bbox
            north += padding
            south -= padding
            east += padding
            west -= padding

            cache_path = self._get_cache_path(north, south, east, west)

            if cache_path.exists() and cache_path.stat().st_size > 0:
                logger.info(f"Loading cached OSM network from {cache_path}")
                graph = ox.load_graphml(cache_path)

                # Older cache files may predate enrichment; probe one edge.
                sample_edge = next(iter(graph.edges(data=True)), None)
                has_travel_time = (
                    sample_edge is not None and "travel_time" in sample_edge[2]
                )

                if not has_travel_time:
                    logger.info("Adding edge speeds and travel times to cached graph...")
                    graph = ox.add_edge_speeds(graph)
                    graph = ox.add_edge_travel_times(graph)
                    # Re-save with travel times included
                    ox.save_graphml(graph, cache_path)
                    logger.info("Updated cache with travel times")
            else:
                logger.info(
                    f"Downloading OSM network for bbox: N={north:.4f}, S={south:.4f}, E={east:.4f}, W={west:.4f}"
                )
                # OSMnx 2.x uses bbox as tuple: (left, bottom, right, top) = (west, south, east, north)
                bbox_tuple = (west, south, east, north)
                graph = ox.graph_from_bbox(
                    bbox=bbox_tuple,
                    network_type="drive",
                    simplify=True,
                )

                # Add edge speeds and travel times BEFORE caching
                logger.info("Computing edge speeds and travel times...")
                graph = ox.add_edge_speeds(graph)
                graph = ox.add_edge_travel_times(graph)

                # Save enriched graph to cache
                ox.save_graphml(graph, cache_path)
                logger.info(f"Saved enriched OSM network to cache: {cache_path}")

            # Publish the graph only now that every step above succeeded.
            self._graph = graph
            self._graph_bbox = (north, south, east, west)
            logger.info(
                f"OSM network loaded: {self._graph.number_of_nodes()} nodes, "
                f"{self._graph.number_of_edges()} edges"
            )
            return True

        except ImportError:
            logger.warning("OSMnx not installed, falling back to haversine")
            return False
        except Exception as e:
            # Best-effort by design: any failure means "use the haversine fallback".
            logger.warning(f"Failed to load OSM network: {e}, falling back to haversine")
            return False

    def get_nearest_node(self, location: "Location") -> Optional[int]:
        """
        Get the nearest graph node for a location.

        Returns:
            The node id, or None when no graph is loaded or the lookup fails.
        """
        if self._graph is None:
            return None
        try:
            import osmnx as ox
            return ox.nearest_nodes(self._graph, location.longitude, location.latitude)
        except Exception as e:
            # Stay best-effort, but leave a trace instead of failing silently.
            logger.warning(f"Nearest-node lookup failed for {location}: {e}")
            return None

    def _best_edge(self, u, v) -> Optional[dict]:
        """Return the attributes of the fastest edge from ``u`` to ``v``, or None."""
        edge_data = self._graph.get_edge_data(u, v)
        if not edge_data:
            return None
        if not self._graph.is_multigraph():
            # Simple graphs hand back the attribute dict directly.
            return edge_data
        # Multigraphs hand back {edge_key: attrs}. Pick the parallel edge the
        # shortest-path search would have used (lowest travel_time, with a
        # missing weight counted as 1 like NetworkX does) rather than assuming
        # that key 0 exists and is the fastest.
        return min(edge_data.values(), key=lambda attrs: attrs.get("travel_time", 1))

    def _path_metrics(self, path) -> tuple[float, float, list[tuple[float, float]]]:
        """Sum travel time and length along ``path`` and collect its (lat, lon) points."""
        total_time = 0.0
        total_distance = 0.0
        for u, v in zip(path, path[1:]):
            edge = self._best_edge(u, v)
            if edge:
                total_time += edge.get("travel_time", 0)
                total_distance += edge.get("length", 0)

        graph_nodes = self._graph.nodes
        # Graph nodes store longitude as "x" and latitude as "y".
        coordinates = [(graph_nodes[node]["y"], graph_nodes[node]["x"]) for node in path]
        return total_time, total_distance, coordinates

    @staticmethod
    def _zero_route(location: "Location") -> RouteResult:
        """Build the zero-length route used when origin and destination coincide."""
        return RouteResult(
            duration_seconds=0,
            distance_meters=0,
            geometry=polyline.encode(
                [(location.latitude, location.longitude)], precision=5
            ),
        )

    def compute_all_routes(
        self,
        locations: list["Location"],
        progress_callback=None
    ) -> dict[tuple[int, int], RouteResult]:
        """
        Compute all pairwise routes using one single-source search per origin.

        Pairs whose endpoints snap to the same graph node get a zero-length
        route. Pairs with no nearest node or no path are left out of the
        result, so callers can substitute their own fallback.

        Args:
            locations: Locations to route between.
            progress_callback: Optional callback(phase, message, percent, detail).

        Returns:
            A dict mapping (origin_idx, dest_idx) to RouteResult; empty when
            no graph is loaded.
        """
        if self._graph is None:
            return {}

        # Imported only once a graph exists, so the "no graph" path does not
        # require networkx to be installed.
        import networkx as nx

        results: dict[tuple[int, int], RouteResult] = {}
        n = len(locations)

        if progress_callback:
            progress_callback("routes", "Finding nearest road nodes...", 30, f"{n} locations")

        # Snap every location to its nearest graph node (None on failure).
        nodes = [self.get_nearest_node(loc) for loc in locations]

        total_origins = sum(1 for node in nodes if node is not None)
        processed = 0

        for i, origin_node in enumerate(nodes):
            if origin_node is None:
                continue

            # One Dijkstra run per origin yields the paths to ALL destinations,
            # which is far cheaper than one shortest_path call per pair.
            try:
                lengths, paths = nx.single_source_dijkstra(
                    self._graph, origin_node, weight="travel_time"
                )
            except nx.NetworkXError:
                continue

            for j, dest_node in enumerate(nodes):
                if dest_node is None:
                    continue

                if i == j or origin_node == dest_node:
                    # Same location (or both snapped to the same node)
                    results[(i, j)] = self._zero_route(locations[i])
                elif dest_node in paths:
                    _, total_distance, coordinates = self._path_metrics(paths[dest_node])
                    results[(i, j)] = RouteResult(
                        # Duration comes straight from the Dijkstra result.
                        duration_seconds=round(lengths[dest_node]),
                        distance_meters=round(total_distance),
                        geometry=polyline.encode(coordinates, precision=5),
                    )

            processed += 1
            # Report roughly every 10% of processed origins.
            if progress_callback and processed % max(1, total_origins // 10) == 0:
                percent = 30 + int((processed / total_origins) * 65)
                progress_callback(
                    "routes",
                    "Computing routes...",
                    percent,
                    f"{processed}/{total_origins} origins processed"
                )

        return results

    def get_route(
        self, origin: "Location", destination: "Location"
    ) -> Optional[RouteResult]:
        """
        Compute route between two locations.

        Returns:
            RouteResult with duration, distance, and geometry, or None if
            no graph is loaded or routing fails
        """
        if self._graph is None:
            return None

        try:
            import osmnx as ox

            # Find nearest nodes to origin and destination
            origin_node = ox.nearest_nodes(
                self._graph, origin.longitude, origin.latitude
            )
            dest_node = ox.nearest_nodes(
                self._graph, destination.longitude, destination.latitude
            )

            # Same node means same location (or very close)
            if origin_node == dest_node:
                return self._zero_route(origin)

            # Compute shortest path by travel time
            route = ox.shortest_path(
                self._graph, origin_node, dest_node, weight="travel_time"
            )

            if route is None:
                logger.warning(
                    f"No route found between {origin} and {destination}"
                )
                return None

            total_time, total_distance, coordinates = self._path_metrics(route)

            return RouteResult(
                duration_seconds=round(total_time),
                distance_meters=round(total_distance),
                geometry=polyline.encode(coordinates, precision=5),
            )

        except Exception as e:
            # Best-effort by design: callers treat None as "use a fallback".
            logger.warning(f"Routing failed: {e}")
            return None


def compute_distance_matrix(
    locations: list["Location"],
    routing_service: Optional[OSMnxRoutingService] = None,
    bbox: Optional[tuple[float, float, float, float]] = None,
) -> DistanceMatrix:
    """
    Compute distance matrix for all location pairs.

    Convenience wrapper around compute_distance_matrix_with_progress with
    OSM routing enabled and no progress callback.

    Args:
        locations: List of Location objects
        routing_service: Optional pre-configured routing service
        bbox: Optional (north, south, east, west) tuple for network download

    Returns:
        DistanceMatrix with precomputed routes
    """
    # Pass everything by keyword: the delegate's positional order is
    # (locations, bbox, use_osm, ...), so forwarding routing_service and bbox
    # positionally would bind them to the wrong parameters and collide with
    # the explicit use_osm keyword.
    return compute_distance_matrix_with_progress(
        locations,
        bbox=bbox,
        use_osm=True,
        progress_callback=None,
        routing_service=routing_service,
    )


def compute_distance_matrix_with_progress(
    locations: list["Location"],
    bbox: Optional[tuple[float, float, float, float]] = None,
    use_osm: bool = True,
    progress_callback=None,
    routing_service: Optional[OSMnxRoutingService] = None,
) -> DistanceMatrix:
    """
    Compute distance matrix for all location pairs with progress reporting.

    With ``use_osm`` enabled the road network is loaded and routes come from
    the routing service's batch computation; any pair it does not return is
    filled in with a haversine estimate. If the network cannot be loaded, or
    ``use_osm`` is False, every pair uses the haversine estimate.

    Args:
        locations: List of Location objects
        bbox: Optional (north, south, east, west) tuple for network download;
            derived from the locations when omitted
        use_osm: If True, try to use OSMnx for real routing. If False, use haversine.
        progress_callback: Optional callback(phase, message, percent, detail) for progress updates
        routing_service: Optional pre-configured routing service

    Returns:
        DistanceMatrix with precomputed routes (empty when ``locations`` is empty)
    """
    matrix = DistanceMatrix()

    if not locations:
        return matrix

    def report_progress(phase: str, message: str, percent: int, detail: str = ""):
        """Forward a progress update to the callback (if any) and to the log."""
        if progress_callback:
            progress_callback(phase, message, percent, detail)
        logger.info(f"[{phase}] {message} ({percent}%) {detail}")

    def haversine_route(origin: "Location", destination: "Location") -> RouteResult:
        """Build the straight-line fallback route between two locations."""
        return RouteResult(
            duration_seconds=_haversine_driving_time(origin, destination),
            distance_meters=_haversine_distance_meters(origin, destination),
            geometry=_straight_line_geometry(origin, destination),
        )

    # Compute bounding box from locations if not provided
    if bbox is None:
        lats = [loc.latitude for loc in locations]
        lons = [loc.longitude for loc in locations]
        bbox = (max(lats), min(lats), max(lons), min(lons))

    osm_loaded = False

    if use_osm:
        # Create routing service if not provided
        if routing_service is None:
            routing_service = OSMnxRoutingService()

        report_progress("network", "Checking for cached road network...", 5)

        # Reproduce load_network's default padding so the cache file probed
        # here is the same one load_network will look for.
        padding = 0.01
        padded_north = bbox[0] + padding
        padded_south = bbox[1] - padding
        padded_east = bbox[2] + padding
        padded_west = bbox[3] - padding

        cache_path = routing_service._get_cache_path(
            padded_north, padded_south, padded_east, padded_west
        )
        # load_network ignores empty cache files, so an empty file must not be
        # announced as a cache hit.
        is_cached = cache_path.exists() and cache_path.stat().st_size > 0

        if is_cached:
            report_progress("network", "Loading cached road network...", 10, str(cache_path.name))
        else:
            report_progress(
                "network",
                "Downloading OpenStreetMap road network...",
                10,
                f"Area: {abs(padded_north - padded_south):.2f}° × {abs(padded_east - padded_west):.2f}°"
            )

        # Try to load OSM network (load_network applies the padding itself)
        osm_loaded = routing_service.load_network(
            north=bbox[0], south=bbox[1], east=bbox[2], west=bbox[3]
        )

        if osm_loaded:
            node_count = routing_service._graph.number_of_nodes()
            edge_count = routing_service._graph.number_of_edges()
            report_progress(
                "network",
                "Road network loaded",
                25,
                f"{node_count:,} nodes, {edge_count:,} edges"
            )
        else:
            report_progress("network", "OSMnx unavailable, using haversine", 25)
    else:
        report_progress("network", "Using fast haversine mode", 25)

    # Compute all pairwise routes
    total_pairs = len(locations) * len(locations)

    if osm_loaded and routing_service:
        # Use batch routing for OSMnx (MUCH faster than individual calls)
        report_progress(
            "routes",
            f"Computing {total_pairs:,} routes (batch mode)...",
            30,
            f"{len(locations)} locations"
        )

        batch_results = routing_service.compute_all_routes(
            locations,
            progress_callback=report_progress
        )

        # Transfer batch results to matrix, with haversine fallback for missing routes
        computed = 0
        for i, origin in enumerate(locations):
            for j, destination in enumerate(locations):
                route = batch_results.get((i, j))
                if route is None:
                    # Fallback to haversine for routes not found
                    route = haversine_route(origin, destination)
                matrix.set_route(origin, destination, route)
                computed += 1

        report_progress("complete", "Distance matrix ready", 100, f"{computed:,} routes computed")
    else:
        # Use haversine fallback for all routes
        report_progress(
            "routes",
            f"Computing {total_pairs:,} route pairs...",
            30,
            f"{len(locations)} locations"
        )

        # Report progress every 5% of the pairs.
        report_every = max(1, total_pairs // 20)

        computed = 0
        for origin in locations:
            for destination in locations:
                if origin is destination:
                    # A location routed to itself: zero cost, single-point geometry.
                    route = RouteResult(
                        duration_seconds=0,
                        distance_meters=0,
                        geometry=polyline.encode(
                            [(origin.latitude, origin.longitude)], precision=5
                        ),
                    )
                else:
                    route = haversine_route(origin, destination)
                matrix.set_route(origin, destination, route)
                computed += 1

                if total_pairs > 0 and computed % report_every == 0:
                    percent_complete = int(30 + (computed / total_pairs) * 65)
                    report_progress(
                        "routes",
                        "Computing routes...",
                        percent_complete,
                        f"{computed:,}/{total_pairs:,} pairs"
                    )

        report_progress("complete", "Distance matrix ready", 100, f"{computed:,} routes computed")

    return matrix


def _haversine_distance_meters(origin: "Location", destination: "Location") -> int:
    """Calculate haversine distance in meters."""
    if (
        origin.latitude == destination.latitude
        and origin.longitude == destination.longitude
    ):
        return 0

    EARTH_RADIUS_M = 6371000

    lat1 = math.radians(origin.latitude)
    lon1 = math.radians(origin.longitude)
    lat2 = math.radians(destination.latitude)
    lon2 = math.radians(destination.longitude)

    dlat = lat2 - lat1
    dlon = lon2 - lon1
    a = math.sin(dlat / 2) ** 2 + math.cos(lat1) * math.cos(lat2) * math.sin(dlon / 2) ** 2
    c = 2 * math.asin(math.sqrt(a))

    return round(EARTH_RADIUS_M * c)


def _straight_line_geometry(origin: "Location", destination: "Location") -> str:
    """
    Encode the direct segment from ``origin`` to ``destination`` as a polyline.

    The result holds exactly two (latitude, longitude) points, encoded with
    precision 5.
    """
    start_point = (origin.latitude, origin.longitude)
    end_point = (destination.latitude, destination.longitude)
    segment = [start_point, end_point]
    return polyline.encode(segment, precision=5)