Spaces:
Running
Running
File size: 14,836 Bytes
1dc52fb | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 | """
Graph data management using NetworkX and OSMnx.
Responsible for downloading, processing, caching, and serving routing road networks.
"""
import logging
import networkx as nx
import osmnx as ox
from pathlib import Path
from typing import Optional, List
import math
# Per-module logger; records are tagged with this module's import name.
logger = logging.getLogger(__name__)
# NOTE: basicConfig runs at import time and configures the process-wide *root*
# logger (INFO level, timestamped format), not just this module's logger.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
class GraphManager:
    """
    Manages loading, updating, and querying the spatial road graph.

    The active graph is held in ``self.graph``. It is populated either by
    ``load_or_create_graph`` (a whole named place) or by
    ``load_graph_dynamically`` (an area around one or two coordinates).
    Downloaded graphs are cached as GraphML files in ``cache_dir``.
    """

    def __init__(self, cache_dir: str = "./data/graph"):
        """
        Initializes the GraphManager.

        Args:
            cache_dir (str): Directory where graph files are stored locally.
                Created (including parent directories) if it does not exist.
        """
        self.cache_dir = Path(cache_dir)
        # Ensure the cache directory exists
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.graph: Optional[nx.MultiDiGraph] = None
        # Configure OSMnx. These are module-wide settings, so they affect every
        # GraphManager instance (and any other OSMnx user in the process).
        ox.settings.use_cache = True
        ox.settings.log_console = False

    def _impute_edge_attributes(self, G: nx.MultiDiGraph) -> nx.MultiDiGraph:
        """
        Impute missing edge speeds and compute per-edge travel times.

        Args:
            G (nx.MultiDiGraph): Unprocessed NetworkX graph object.

        Returns:
            nx.MultiDiGraph: Graph whose edges carry ``speed_kph`` and
            ``travel_time`` (seconds) attributes.
        """
        logger.info("Imputing edge speeds and travel times...")
        # Typical speeds (km/h) per OSM highway type, applied to edges of that
        # type when their maxspeed tag is missing from the OSM data.
        fallback_speeds = {
            'residential': 30,
            'secondary': 50,
            'tertiary': 40,
            'primary': 60,
            'motorway': 100,
            'unclassified': 30,
        }
        # Speed (km/h) for highway types that are neither listed above nor have
        # any maxspeed data in this graph. OSMnx takes this through its dedicated
        # `fallback` argument; a 'default' key inside `hwy_speeds` would merely be
        # treated as a (non-existent) highway type named "default".
        default_speed_kph = 30
        # Adds the speed_kph attribute to every edge
        G = ox.add_edge_speeds(G, hwy_speeds=fallback_speeds, fallback=default_speed_kph)
        # Adds the travel_time attribute (seconds) from edge length (meters) and speed_kph
        G = ox.add_edge_travel_times(G)
        return G

    @staticmethod
    def _city_cache_stem(city_name: str) -> str:
        """Build a deterministic, filesystem-safe file stem from a place name."""
        import re

        stem = city_name.replace(", ", "_").replace(" ", "_").lower()
        # Replace any character that is not a word character, '.' or '-'
        # (notably path separators) so an arbitrary place string cannot write
        # outside cache_dir. Ordinary names map to the same stem as before.
        return re.sub(r"[^\w.-]", "_", stem)

    def load_or_create_graph(self, city_name: str) -> nx.MultiDiGraph:
        """
        Load a city's drivable road network from cache, or download and cache it.

        On a cache miss (or an unreadable cache file) the network is downloaded
        from OpenStreetMap, edge speeds/travel times are imputed, and the result
        is saved as GraphML so later calls avoid the download.

        Args:
            city_name (str): Full place string (e.g. "San Francisco, California, USA").

        Returns:
            nx.MultiDiGraph: Processed NetworkX routing graph (also stored in
            ``self.graph``).

        Raises:
            RuntimeError: If downloading, processing or saving the graph fails.
        """
        filepath = self.cache_dir / f"{self._city_cache_stem(city_name)}.graphml"
        if filepath.exists():
            logger.info(f"Graph for '{city_name}' found in cache at {filepath}. Loading (this may take a moment)...")
            try:
                self.graph = ox.load_graphml(filepath)
                logger.info(f"Successfully loaded graph: {len(self.graph.nodes)} nodes, {len(self.graph.edges)} edges.")
                return self.graph
            except Exception as e:
                # An unreadable cache file is not fatal: fall through and re-download.
                logger.error(f"Failed to load cached graph: {e}. Will attempt re-download...")

        logger.info(f"Graph for '{city_name}' not found locally. Downloading from OSM (Drive network)...")
        try:
            # 1. Download the drivable road network (simplified topology).
            G = ox.graph_from_place(city_name, network_type="drive", simplify=True)
            # 2. Add speed and travel time attributes
            G = self._impute_edge_attributes(G)
            # 3. Cache to disk
            logger.info(f"Saving graph to {filepath}...")
            ox.save_graphml(G, filepath=filepath)
            logger.info(f"Successfully cached new graph: {len(G.nodes)} nodes, {len(G.edges)} edges.")
        except Exception as e:
            logger.error(f"Failed to generate graph for '{city_name}': {e}")
            raise RuntimeError(f"Graph generation failed. Please verify the city name: '{city_name}'. Error: {e}") from e
        self.graph = G
        return self.graph

    @staticmethod
    def _haversine_km(lat1: float, lng1: float, lat2: float, lng2: float) -> float:
        """Great-circle distance in km between two (lat, lng) points given in degrees."""
        R = 6371.0  # Mean radius of the Earth in km
        dlat = math.radians(lat2 - lat1)
        dlng = math.radians(lng2 - lng1)
        a = math.sin(dlat / 2) ** 2 + math.cos(math.radians(lat1)) * math.cos(math.radians(lat2)) * math.sin(dlng / 2) ** 2
        return R * 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))

    def load_graph_dynamically(self, lat1: float, lng1: float, lat2: Optional[float] = None, lng2: Optional[float] = None, radius_km_override: Optional[float] = None) -> nx.MultiDiGraph:
        """
        Load or download the road network around one point or spanning two points.

        With two points the graph is centered on their midpoint and its radius is
        derived from their great-circle distance; with one point a 5 km radius is
        used. The radius is rounded up to whole kilometres, and the graph is
        cached under its center rounded to 0.01 degrees, so nearby requests can
        share a cache file.

        Args:
            lat1 (float): Latitude of the first (source) point.
            lng1 (float): Longitude of the first (source) point.
            lat2 (Optional[float]): Latitude of the optional second (destination) point.
            lng2 (Optional[float]): Longitude of the optional second (destination) point.
            radius_km_override (Optional[float]): If provided (non-zero), use this
                radius instead of computing it from the coordinates. Used by
                cache warming to download larger city-covering graphs.

        Returns:
            nx.MultiDiGraph: The loaded graph (also stored in ``self.graph``).

        Raises:
            RuntimeError: If the download fails, or (two-point case) if the
                endpoints cannot be resolved to graph nodes. Errors raised while
                imputing or saving the graph propagate unchanged.
        """
        has_destination = lat2 is not None and lng2 is not None
        if has_destination:
            center_lat = (lat1 + lat2) / 2.0
            center_lng = (lng1 + lng2) / 2.0
        else:
            center_lat = lat1
            center_lng = lng1

        if radius_km_override:
            # An explicit radius wins in both modes, as documented above.
            radius_km = radius_km_override
        elif has_destination:
            dist_km = self._haversine_km(lat1, lng1, lat2, lng2)
            # 0.75x multiplier covers both endpoints (each is dist/2 from the
            # center); cap at 15km to stay within Overpass limits.
            # NOTE: for points more than 30 km apart the cap places the endpoints
            # outside the downloaded radius.
            radius_km = max(5.0, min(15.0, dist_km * 0.75))
            logger.info(f"Calculated distance: {dist_km:.2f}km. Setting radius to: {radius_km:.2f}km")
        else:
            radius_km = 5.0

        # Round the radius UP to whole km. The cache file name encodes the radius
        # as whole km, so the stored graph must be at least that large;
        # truncating instead would let e.g. a 5.9 km request reuse a 5 km graph
        # that may not reach the endpoints.
        radius_km = float(math.ceil(radius_km))

        # Round the center to 2 decimal places (~1.1 km grid) so requests with
        # nearby centers reuse the same cached graph.
        round_lat = round(center_lat, 2)
        round_lng = round(center_lng, 2)
        cache_name = f"dynamic_{round_lat}_{round_lng}_{int(radius_km)}km.graphml"
        filepath = self.cache_dir / cache_name
        radius_meters = int(radius_km * 1000)
        logger.info(f"Graph Bounds Initialization - Center: ({center_lat:.6f}, {center_lng:.6f}), Radius: {radius_meters}m")

        graph_loaded = False
        if filepath.exists():
            logger.info(f"Graph for center ({round_lat}, {round_lng}), r={radius_meters}m found in cache. Loading...")
            try:
                self.graph = ox.load_graphml(filepath)
                graph_loaded = True
            except Exception as e:
                # Unreadable cache file: fall through to a fresh download.
                logger.error(f"Failed to load cached dynamic graph: {e}. Re-downloading...")

        if not graph_loaded:
            logger.info(f"Downloading dynamic graph for center {center_lat}, {center_lng} with radius {radius_meters}m...")
            # Give slow Overpass responses more time. Older OSMnx releases read
            # `settings.timeout`, newer ones `settings.requests_timeout`; set both.
            ox.settings.timeout = 300
            ox.settings.requests_timeout = 300
            # For large areas (>7km), use tiled download to avoid Overpass API limits
            TILE_RADIUS_M = 5000  # 5km per tile — safe for any city density
            if radius_meters > 7000:
                G = self._download_tiled_graph(center_lat, center_lng, radius_km, TILE_RADIUS_M)
            else:
                G = self._download_single_graph(center_lat, center_lng, radius_meters)
            G = self._impute_edge_attributes(G)
            logger.info(f"Saving dynamic graph to {filepath}...")
            ox.save_graphml(G, filepath=filepath)
            self.graph = G

        # Validate that both source and destination resolve to graph nodes.
        # nearest_nodes snaps to an existing node, so this mainly guards against
        # empty/unusable graphs; it does not check how far the snap moved a point.
        if has_destination:
            try:
                node1 = ox.distance.nearest_nodes(self.graph, X=lng1, Y=lat1)
                node2 = ox.distance.nearest_nodes(self.graph, X=lng2, Y=lat2)
                if node1 not in self.graph.nodes or node2 not in self.graph.nodes:
                    raise ValueError("Resolved node IDs are missing from the graph.")
                logger.info(f"Validated nodes in graph: Source Node ID {node1}, Dest Node ID {node2}")
            except Exception as e:
                logger.error("Source or destination nodes could not be validated in the graph.")
                raise RuntimeError(f"Node validation failed. Coordinates may be unroutable or graph is empty: {e}") from e
        return self.graph

    def _download_single_graph(self, center_lat: float, center_lng: float, radius_m: int) -> nx.MultiDiGraph:
        """
        Download a single drivable-network graph around a point (areas ≤ 7km radius).

        Raises:
            RuntimeError: If the OSMnx download fails for any reason.
        """
        try:
            G = ox.graph_from_point(
                (center_lat, center_lng), dist=radius_m,
                network_type="drive", simplify=True,
            )
        except Exception as e:
            raise RuntimeError(f"Graph download failed for ({center_lat}, {center_lng}), r={radius_m}m: {e}") from e
        logger.info(f"Single tile: {len(G.nodes)} nodes, {len(G.edges)} edges")
        return G

    def _download_tiled_graph(
        self, center_lat: float, center_lng: float,
        radius_km: float, tile_radius_m: int = 5000,
    ) -> nx.MultiDiGraph:
        """
        Download a large area as a grid of overlapping tiles and merge them.

        Tile centers lie on a square grid spaced at 80% of a tile's diameter
        (so neighbouring tiles overlap); only tiles whose center is within
        ``radius_km + tile_radius_m`` of the area center are kept. With the
        default 5 km tiles, a 15 km radius yields 21 tiles (a 5x5 grid minus its
        corners) and radii just above 7 km yield 9. Each tile is downloaded
        separately (well within Overpass limits) and merged with ``nx.compose``.

        Each tile gets up to 2 attempts, so one flaky download does not abort the
        whole operation; tiles that still fail are skipped.

        Raises:
            RuntimeError: If every tile download fails.
        """
        import time

        max_attempts = 2
        # Tile spacing in degrees of latitude: 80% of the tile diameter
        # (~111 km per degree of latitude).
        tile_spacing_deg = (tile_radius_m * 0.8 * 2) / 111000.0
        cos_lat = math.cos(math.radians(center_lat))
        # Number of grid steps in each direction from the center
        n_tiles = max(1, int(math.ceil(radius_km * 1000 / (tile_radius_m * 1.6))))

        tile_centers = []
        for dy in range(-n_tiles, n_tiles + 1):
            for dx in range(-n_tiles, n_tiles + 1):
                t_lat = center_lat + dy * tile_spacing_deg
                # Longitude degrees shrink with latitude, so widen the step to keep spacing in meters
                t_lng = center_lng + dx * tile_spacing_deg / cos_lat
                # Only include tiles that can overlap the target area
                dlat_m = (t_lat - center_lat) * 111000
                dlng_m = (t_lng - center_lng) * 111000 * cos_lat
                if math.sqrt(dlat_m ** 2 + dlng_m ** 2) <= radius_km * 1000 + tile_radius_m:
                    tile_centers.append((t_lat, t_lng))

        n_total = len(tile_centers)
        logger.info(
            f"Tiled download: {n_total} tiles of {tile_radius_m}m radius "
            f"to cover {radius_km}km area around ({center_lat:.4f}, {center_lng:.4f})"
        )

        merged_graph = None
        success_count = 0
        for i, (t_lat, t_lng) in enumerate(tile_centers, start=1):
            for attempt in range(1, max_attempts + 1):
                try:
                    G_tile = ox.graph_from_point(
                        (t_lat, t_lng), dist=tile_radius_m,
                        network_type="drive", simplify=True,
                    )
                    merged_graph = G_tile if merged_graph is None else nx.compose(merged_graph, G_tile)
                    success_count += 1
                    logger.info(
                        f"  Tile {i}/{n_total} ✓ "
                        f"({len(G_tile.nodes)} nodes, running total: {len(merged_graph.nodes)})"
                    )
                    break  # Success — move to next tile
                except Exception as e:
                    if attempt < max_attempts:
                        logger.warning(f"  Tile {i} attempt {attempt} failed: {e}. Retrying in 5s...")
                        time.sleep(5)
                    else:
                        logger.warning(f"  Tile {i} SKIPPED after {max_attempts} attempts: {e}")
            # Brief pause between tiles to be respectful to Overpass
            if i < n_total:
                time.sleep(1)

        if merged_graph is None or len(merged_graph.nodes) == 0:
            raise RuntimeError(
                f"All {n_total} tile downloads failed for "
                f"({center_lat}, {center_lng}), r={radius_km}km"
            )
        logger.info(
            f"Tiled download complete: {success_count}/{n_total} tiles merged → "
            f"{len(merged_graph.nodes)} nodes, {len(merged_graph.edges)} edges"
        )
        return merged_graph

    def get_nearest_node(self, lat: float, lng: float) -> int:
        """
        Find the graph node ID closest to a geographic coordinate.

        Args:
            lat (float): Latitude
            lng (float): Longitude

        Returns:
            int: The nearest node ID.

        Raises:
            ValueError: If no graph has been loaded yet.
        """
        if self.graph is None:
            raise ValueError("Graph uninitialized. Call load_or_create_graph() first.")
        return ox.distance.nearest_nodes(self.graph, X=lng, Y=lat)

    def calculate_shortest_path(self, orig_node: int, dest_node: int, weight: str = "travel_time") -> List[int]:
        """
        Compute the lowest-cost path between two nodes (Dijkstra for weighted edges).

        Args:
            orig_node (int): Origin node ID.
            dest_node (int): Destination node ID.
            weight (str): Edge attribute to minimize ('length', 'travel_time', etc.)

        Returns:
            List[int]: Node IDs along the path, from orig_node to dest_node.

        Raises:
            ValueError: If no graph has been loaded yet.
            nx.NetworkXNoPath: If dest_node is unreachable from orig_node.
        """
        if self.graph is None:
            raise ValueError("Graph uninitialized. Call load_or_create_graph() first.")
        try:
            # NetworkX names the endpoints `source`/`target`; `orig`/`dest` are
            # OSMnx's parameter names and are rejected by nx.shortest_path.
            return nx.shortest_path(self.graph, source=orig_node, target=dest_node, weight=weight)
        except nx.NetworkXNoPath as err:
            logger.error(f"No path found between node {orig_node} and node {dest_node}.")
            raise nx.NetworkXNoPath("No valid path exists between the requested nodes.") from err
|