Spaces:
Running
Running
| """ | |
| Graph data management using NetworkX and OSMnx. | |
| Responsible for downloading, processing, caching, and serving routing road networks. | |
| """ | |
| import logging | |
| import networkx as nx | |
| import osmnx as ox | |
| from pathlib import Path | |
| from typing import Optional, List | |
| import math | |
| # Configure module logger | |
| logger = logging.getLogger(__name__) | |
| logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') | |
| class GraphManager: | |
| """ | |
| Manages loading, updating, and querying the spatial graph. | |
| """ | |
| def __init__(self, cache_dir: str = "./data/graph"): | |
| """ | |
| Initializes the GraphManager. | |
| Args: | |
| cache_dir (str): Directory where graph files are stored locally. | |
| """ | |
| self.cache_dir = Path(cache_dir) | |
| # Ensure the cache directory exists | |
| self.cache_dir.mkdir(parents=True, exist_ok=True) | |
| self.graph: Optional[nx.MultiDiGraph] = None | |
| # Configure OSMnx performance/cache settings | |
| ox.settings.use_cache = True | |
| ox.settings.log_console = False | |
| def _impute_edge_attributes(self, G: nx.MultiDiGraph) -> nx.MultiDiGraph: | |
| """ | |
| Imputes missing edge speeds and calculates base travel times. | |
| Args: | |
| G (nx.MultiDiGraph): Unprocessed NetworkX graph object. | |
| Returns: | |
| nx.MultiDiGraph: Graph with populated edge speed/time properties. | |
| """ | |
| logger.info("Imputing edge speeds and travel times...") | |
| # Fallback speeds based on standard highway types (in km/h) | |
| # This acts as a fallback when maxspeed tags are missing from OSM data | |
| fallback_speeds = { | |
| 'residential': 30, | |
| 'secondary': 50, | |
| 'tertiary': 40, | |
| 'primary': 60, | |
| 'motorway': 100, | |
| 'unclassified': 30, | |
| 'default': 30 | |
| } | |
| # Adds speed_kph attribute to edges based on fallback / existing data | |
| G = ox.add_edge_speeds(G, hwy_speeds=fallback_speeds) | |
| # Calculates base_travel_time (seconds) based on length (meters) / speed_kph | |
| G = ox.add_edge_travel_times(G) | |
| return G | |
| def load_or_create_graph(self, city_name: str) -> nx.MultiDiGraph: | |
| """ | |
| Downloads a city's road network, extracts nodes & edges, imputes properties, | |
| and safely caches it locally to avoid repeated downloads. | |
| Args: | |
| city_name (str): Full city string (e.g., San Francisco, California, USA). | |
| Returns: | |
| nx.MultiDiGraph: Processed NetworkX routing graph. | |
| """ | |
| # Create a safe, deterministic filename from the city name | |
| safe_name = city_name.replace(", ", "_").replace(" ", "_").lower() | |
| filepath = self.cache_dir / f"{safe_name}.graphml" | |
| if filepath.exists(): | |
| logger.info(f"Graph for '{city_name}' found in cache at {filepath}. Loading (this may take a moment)...") | |
| try: | |
| self.graph = ox.load_graphml(filepath) | |
| logger.info(f"Successfully loaded graph: {len(self.graph.nodes)} nodes, {len(self.graph.edges)} edges.") | |
| return self.graph | |
| except Exception as e: | |
| logger.error(f"Failed to load cached graph: {e}. Will attempt re-download...") | |
| logger.info(f"Graph for '{city_name}' not found locally. Downloading from OSM (Drive network)...") | |
| try: | |
| # 1. Download drive-able road network. | |
| # Note: Simplify=True automatically ensures correct topology and length attributes | |
| G = ox.graph_from_place(city_name, network_type="drive", simplify=True) | |
| # 2. Add speed and base travel time attributes | |
| G = self._impute_edge_attributes(G) | |
| # 3. Cache logically to disk | |
| logger.info(f"Saving graph to {filepath}...") | |
| ox.save_graphml(G, filepath=filepath) | |
| logger.info(f"Successfully cached new graph: {len(G.nodes)} nodes, {len(G.edges)} edges.") | |
| self.graph = G | |
| return self.graph | |
| except Exception as e: | |
| logger.error(f"Failed to generate graph for '{city_name}': {e}") | |
| raise RuntimeError(f"Graph generation failed. Please verify the city name: '{city_name}'. Error: {e}") | |
| def load_graph_dynamically(self, lat1: float, lng1: float, lat2: Optional[float] = None, lng2: Optional[float] = None, radius_km_override: Optional[float] = None) -> nx.MultiDiGraph: | |
| """ | |
| Downloads or loads a road network around a specific coordinate or encompassing two coordinates. | |
| Calculates a center point and a safe radius to ensure the entire route is covered. | |
| Args: | |
| radius_km_override: If provided, use this radius instead of computing from coordinates. | |
| Used by cache warming to download larger city-covering graphs. | |
| """ | |
| if lat2 is not None and lng2 is not None: | |
| center_lat = (lat1 + lat2) / 2.0 | |
| center_lng = (lng1 + lng2) / 2.0 | |
| # Calculate distance using Haversine | |
| R = 6371.0 # Radius of earth in km | |
| dlat = math.radians(lat2 - lat1) | |
| dlng = math.radians(lng2 - lng1) | |
| a = math.sin(dlat / 2)**2 + math.cos(math.radians(lat1)) * math.cos(math.radians(lat2)) * math.sin(dlng / 2)**2 | |
| c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a)) | |
| dist_km = R * c | |
| # 0.75x multiplier covers both endpoints; cap at 15km to stay within Overpass limits | |
| radius_km = max(5.0, min(15.0, dist_km * 0.75)) | |
| logger.info(f"Calculated distance: {dist_km:.2f}km. Setting radius to: {radius_km:.2f}km") | |
| else: | |
| center_lat = lat1 | |
| center_lng = lng1 | |
| radius_km = radius_km_override if radius_km_override else 5.0 | |
| # Round the center to ~2 decimal places for cache sharing (approx 1.1km grid) | |
| # This will reuse graphs if requests are nearby | |
| round_lat = round(center_lat, 2) | |
| round_lng = round(center_lng, 2) | |
| cache_name = f"dynamic_{round_lat}_{round_lng}_{int(radius_km)}km.graphml" | |
| filepath = self.cache_dir / cache_name | |
| radius_meters = int(radius_km * 1000) | |
| logger.info(f"Graph Bounds Initialization - Center: ({center_lat:.6f}, {center_lng:.6f}), Radius: {radius_meters}m") | |
| graph_loaded = False | |
| if filepath.exists(): | |
| logger.info(f"Graph for center ({round_lat}, {round_lng}), r={radius_meters}m found in cache. Loading...") | |
| try: | |
| self.graph = ox.load_graphml(filepath) | |
| graph_loaded = True | |
| except Exception as e: | |
| logger.error(f"Failed to load cached dynamic graph: {e}. Re-downloading...") | |
| if not graph_loaded: | |
| logger.info(f"Downloading dynamic graph for center {center_lat}, {center_lng} with radius {radius_meters}m...") | |
| # Configure OSMnx for reliable downloads | |
| import osmnx | |
| osmnx.settings.timeout = 300 | |
| # For large areas (>7km), use tiled download to avoid Overpass API limits | |
| TILE_RADIUS_M = 5000 # 5km per tile — safe for any city density | |
| if radius_meters > 7000: | |
| G = self._download_tiled_graph(center_lat, center_lng, radius_km, TILE_RADIUS_M) | |
| else: | |
| G = self._download_single_graph(center_lat, center_lng, radius_meters) | |
| G = self._impute_edge_attributes(G) | |
| logger.info(f"Saving dynamic graph to {filepath}...") | |
| ox.save_graphml(G, filepath=filepath) | |
| self.graph = G | |
| # Validate that both source and destination nodes exist in graph | |
| if lat2 is not None and lng2 is not None: | |
| try: | |
| node1 = ox.distance.nearest_nodes(self.graph, X=lng1, Y=lat1) | |
| node2 = ox.distance.nearest_nodes(self.graph, X=lng2, Y=lat2) | |
| if node1 not in self.graph.nodes or node2 not in self.graph.nodes: | |
| raise ValueError("Resolved node IDs are missing from the graph.") | |
| logger.info(f"Validated nodes in graph: Source Node ID {node1}, Dest Node ID {node2}") | |
| except Exception as e: | |
| logger.error("Source or destination nodes could not be validated in the graph.") | |
| raise RuntimeError(f"Node validation failed. Coordinates may be unroutable or graph is empty: {e}") | |
| return self.graph | |
| def _download_single_graph(self, center_lat: float, center_lng: float, radius_m: int) -> nx.MultiDiGraph: | |
| """Download a single graph tile (for areas ≤ 7km radius).""" | |
| try: | |
| G = ox.graph_from_point( | |
| (center_lat, center_lng), dist=radius_m, | |
| network_type="drive", simplify=True, | |
| ) | |
| logger.info(f"Single tile: {len(G.nodes)} nodes, {len(G.edges)} edges") | |
| return G | |
| except Exception as e: | |
| raise RuntimeError(f"Graph download failed for ({center_lat}, {center_lng}), r={radius_m}m: {e}") | |
| def _download_tiled_graph( | |
| self, center_lat: float, center_lng: float, | |
| radius_km: float, tile_radius_m: int = 5000, | |
| ) -> nx.MultiDiGraph: | |
| """ | |
| Download a large area by splitting into overlapping tiles. | |
| For a 15km radius area, this generates ~9-13 tile centers in a grid, | |
| downloads each as a separate 5km-radius graph (well within Overpass limits), | |
| and merges them into one unified graph. | |
| Each tile is individually retried (2 attempts) so one flaky download | |
| doesn't kill the entire operation. | |
| """ | |
| import time | |
| # Generate tile centers in a grid pattern | |
| # Convert tile spacing to degrees (with 20% overlap for seamless merging) | |
| tile_spacing_deg = (tile_radius_m * 0.8 * 2) / 111000.0 # ~80% of diameter | |
| # How many tiles in each direction from center | |
| n_tiles = max(1, int(math.ceil(radius_km * 1000 / (tile_radius_m * 1.6)))) | |
| tile_centers = [] | |
| for dy in range(-n_tiles, n_tiles + 1): | |
| for dx in range(-n_tiles, n_tiles + 1): | |
| t_lat = center_lat + dy * tile_spacing_deg | |
| t_lng = center_lng + dx * tile_spacing_deg / math.cos(math.radians(center_lat)) | |
| # Only include tiles within the target radius | |
| dlat = (t_lat - center_lat) * 111000 | |
| dlng = (t_lng - center_lng) * 111000 * math.cos(math.radians(center_lat)) | |
| dist = math.sqrt(dlat**2 + dlng**2) | |
| if dist <= radius_km * 1000 + tile_radius_m: | |
| tile_centers.append((t_lat, t_lng)) | |
| logger.info( | |
| f"Tiled download: {len(tile_centers)} tiles of {tile_radius_m}m radius " | |
| f"to cover {radius_km}km area around ({center_lat:.4f}, {center_lng:.4f})" | |
| ) | |
| merged_graph = None | |
| success_count = 0 | |
| for i, (t_lat, t_lng) in enumerate(tile_centers): | |
| for attempt in range(1, 3): # 2 attempts per tile | |
| try: | |
| G_tile = ox.graph_from_point( | |
| (t_lat, t_lng), dist=tile_radius_m, | |
| network_type="drive", simplify=True, | |
| ) | |
| if merged_graph is None: | |
| merged_graph = G_tile | |
| else: | |
| merged_graph = nx.compose(merged_graph, G_tile) | |
| success_count += 1 | |
| logger.info( | |
| f" Tile {i+1}/{len(tile_centers)} ✓ " | |
| f"({len(G_tile.nodes)} nodes, running total: {len(merged_graph.nodes)})" | |
| ) | |
| break # Success — move to next tile | |
| except Exception as e: | |
| if attempt < 2: | |
| logger.warning(f" Tile {i+1} attempt {attempt} failed: {e}. Retrying in 5s...") | |
| time.sleep(5) | |
| else: | |
| logger.warning(f" Tile {i+1} SKIPPED after 2 attempts: {e}") | |
| # Brief pause between tiles to be respectful to Overpass | |
| if i < len(tile_centers) - 1: | |
| time.sleep(1) | |
| if merged_graph is None or len(merged_graph.nodes) == 0: | |
| raise RuntimeError( | |
| f"All {len(tile_centers)} tile downloads failed for " | |
| f"({center_lat}, {center_lng}), r={radius_km}km" | |
| ) | |
| logger.info( | |
| f"Tiled download complete: {success_count}/{len(tile_centers)} tiles merged → " | |
| f"{len(merged_graph.nodes)} nodes, {len(merged_graph.edges)} edges" | |
| ) | |
| return merged_graph | |
| def get_nearest_node(self, lat: float, lng: float) -> int: | |
| """ | |
| Finds the physically closest graph node ID to geographical coordinates. | |
| Args: | |
| lat (float): Latitude | |
| lng (float): Longitude | |
| Returns: | |
| int: The nearest node ID. | |
| """ | |
| if self.graph is None: | |
| raise ValueError("Graph uninitialized. Call load_or_create_graph() first.") | |
| return ox.distance.nearest_nodes(self.graph, X=lng, Y=lat) | |
| def calculate_shortest_path(self, orig_node: int, dest_node: int, weight: str = "travel_time") -> List[int]: | |
| """ | |
| Computes the optimal path utilizing standard Dijkstra's algorithm. | |
| Args: | |
| orig_node (int): Origin node ID. | |
| dest_node (int): Destination node ID. | |
| weight (str): Edge optimization attribute ('length', 'travel_time', etc.) | |
| Returns: | |
| List[int]: Sequence of node IDs comprising the shortest path. | |
| """ | |
| if self.graph is None: | |
| raise ValueError("Graph uninitialized. Call load_or_create_graph() first.") | |
| try: | |
| path = nx.shortest_path(self.graph, orig=orig_node, dest=dest_node, weight=weight) | |
| return path | |
| except nx.NetworkXNoPath: | |
| logger.error(f"No path found between node {orig_node} and node {dest_node}.") | |
| raise nx.NetworkXNoPath("No valid path exists between the requested nodes.") | |