import gradio as gr import pandas as pd import geopandas as gpd import folium from folium import plugins import plotly.express as px import plotly.graph_objects as go from huggingface_hub import InferenceClient import json import os import tempfile import io from datetime import datetime import numpy as np from pathlib import Path import warnings import logging import sys # Suppress Python 3.13 asyncio cleanup warnings (harmless garbage collection issue) if sys.version_info >= (3, 13): warnings.filterwarnings('ignore', message='.*Invalid file descriptor.*') # Suppress GeoPandas CRS warnings (area/centroid calculations are approximate for demo purposes) warnings.filterwarnings('ignore', message='.*Geometry is in a geographic CRS.*') # Suppress asyncio cleanup warnings in Python 3.13+ (harmless but noisy) warnings.filterwarnings('ignore', message='.*Invalid file descriptor.*') logging.getLogger('asyncio').setLevel(logging.CRITICAL) # Configure logger for this module logger = logging.getLogger(__name__) import branca.colormap as cm import zipfile import urllib.request # Import external API handlers try: from data_utils import external_data, ExternalDataHandler EXTERNAL_APIS_AVAILABLE = external_data.is_available except ImportError: EXTERNAL_APIS_AVAILABLE = False external_data = None def download_natural_earth_data(data_dir: Path) -> Path: """Download Natural Earth countries shapefile if not present.""" shp_file = data_dir / "ne_110m_admin_0_countries.shp" if shp_file.exists(): return shp_file # Create data directory data_dir.mkdir(parents=True, exist_ok=True) # Natural Earth 110m countries download URL url = "https://naciscdn.org/naturalearth/110m/cultural/ne_110m_admin_0_countries.zip" zip_path = data_dir / "ne_110m_admin_0_countries.zip" print(f"Downloading Natural Earth data from {url}...") urllib.request.urlretrieve(url, zip_path) print(f"Extracting to {data_dir}...") with zipfile.ZipFile(zip_path, 'r') as zip_ref: zip_ref.extractall(data_dir) # Clean up zip file 
zip_path.unlink() print("Natural Earth data ready.") return shp_file def format_number(num): """Format large numbers with K/M/B/T suffixes for better readability.""" if num is None or (isinstance(num, float) and np.isnan(num)): return 'N/A' abs_num = abs(num) if abs_num >= 1e12: return f'{num/1e12:.1f}T' elif abs_num >= 1e9: return f'{num/1e9:.1f}B' elif abs_num >= 1e6: return f'{num/1e6:.1f}M' elif abs_num >= 1e3: return f'{num/1e3:.1f}K' else: return f'{num:.1f}' # Path to local Natural Earth data (downloaded at runtime if not present) DATA_DIR = Path(__file__).parent / "data" / "ne_110m_admin_0_countries" NATURAL_EARTH_SHP = download_natural_earth_data(DATA_DIR) # Initialize HF Inference Client client = InferenceClient(token=os.environ.get("HF_TOKEN")) # ===== UI/UX Enhancement Constants ===== MAP_STYLES = { "Light": "CartoDB positron", "Dark": "CartoDB dark_matter", "Street": "OpenStreetMap", "Satellite": "Esri.WorldImagery" } COLOR_SCHEMES = { "Default": px.colors.qualitative.Plotly, "Vivid": px.colors.qualitative.Vivid, "Pastel": px.colors.qualitative.Pastel, "Bold": px.colors.qualitative.Bold, "Earth": px.colors.qualitative.Safe } CHOROPLETH_COLORS = { "Yellow-Orange-Red": "YlOrRd", "Yellow-Green-Blue": "YlGnBu", "Purple-Red": "PuRd", "Blue-Purple": "BuPu", "Greens": "Greens", "Blues": "Blues", "Oranges": "OrRd", "Spectral": "Spectral" } INDICATORS = { "Population": "pop_est", "GDP (Million $)": "gdp_md_est", "Population Density": "pop_density", "GDP per Capita": "gdp_per_capita" } # Global cache for world data _world_data_cache = None def load_world_data(): """Load world countries geospatial data""" global _world_data_cache if _world_data_cache is None: raw = gpd.read_file(NATURAL_EARTH_SHP) # Select only the columns we need (using original uppercase names) # and rename them to match expected lowercase names _world_data_cache = raw[['NAME', 'CONTINENT', 'POP_EST', 'GDP_MD', 'geometry']].copy() _world_data_cache.columns = ['name', 'continent', 'pop_est', 
'gdp_md_est', 'geometry'] return _world_data_cache def parse_query_with_llm(user_query): """ Use LLM to parse natural language query into structured format """ system_prompt = """You are a geospatial and geographic data query parser. Extract structured information from user queries. Response format (JSON only): { "locations": ["country/region names"], "indicators": ["GDP", "population", "CO2 emissions", etc.], "time_range": {"start": "YYYY", "end": "YYYY"}, "visualization": "map/chart/table", "aggregation": "sum/average/comparison", "query_type": "single_country/multi_country/regional/global/poi/conservation", "data_source": "countries/osm_pois/conservation", "poi_category": "Food & Drink/Healthcare/Education/Tourism/Nature/Shopping" (if POI query), "conservation_topic": "wildlife/forests/oceans/freshwater/climate" (if conservation query) } Query type detection: - POI queries: "restaurants in Paris", "hospitals near Berlin", "find hotels in Tokyo" - Conservation queries: "deforestation data", "endangered species", "marine protected areas" - Country data: "GDP of France", "population of Asia", "compare European economies" Examples: - "Show me GDP of Asian countries" β†’ data_source: countries, query_type: regional - "Find restaurants in Paris" β†’ data_source: osm_pois, poi_category: Food & Drink - "Show deforestation datasets" β†’ data_source: conservation, conservation_topic: forests - "Protected areas in Brazil" β†’ data_source: osm_pois, poi_category: Nature Return ONLY valid JSON, no explanations.""" messages = [ {"role": "system", "content": system_prompt}, {"role": "user", "content": f"Parse this query: {user_query}"} ] try: response = client.chat_completion( messages=messages, model="meta-llama/Llama-3.1-8B-Instruct", max_tokens=500, temperature=0.1 ) raw_content = response.choices[0].message.content print(f"LLM raw response: {raw_content}") parsed = json.loads(raw_content) print(f"Parsed query: {parsed}") return parsed except Exception as e: print(f"LLM 
parsing error: {e}") print(f"HF_TOKEN set: {bool(os.environ.get('HF_TOKEN'))}") return { "locations": [], "indicators": ["population", "gdp_md_est"], "visualization": "table", "query_type": "global", "data_source": "countries" } # ============================================================================= # External API Data Fetching Functions # ============================================================================= # City coordinates for POI queries (approximate centers) CITY_COORDINATES = { "paris": (48.8566, 2.3522), "london": (51.5074, -0.1278), "berlin": (52.5200, 13.4050), "tokyo": (35.6762, 139.6503), "new york": (40.7128, -74.0060), "rome": (41.9028, 12.4964), "madrid": (40.4168, -3.7038), "amsterdam": (52.3676, 4.9041), "sydney": (-33.8688, 151.2093), "singapore": (1.3521, 103.8198), "dubai": (25.2048, 55.2708), "mumbai": (19.0760, 72.8777), "beijing": (39.9042, 116.4074), "seoul": (37.5665, 126.9780), "bangkok": (13.7563, 100.5018), "cairo": (30.0444, 31.2357), "moscow": (55.7558, 37.6173), "sao paulo": (-23.5505, -46.6333), "mexico city": (19.4326, -99.1332), "los angeles": (34.0522, -118.2437), } def fetch_poi_data(parsed_query): """ Fetch Points of Interest data from OpenStreetMap Overpass API. Returns a list of POIs with coordinates and metadata. 
""" if not EXTERNAL_APIS_AVAILABLE or external_data is None: return [], "External APIs not available" locations = parsed_query.get("locations", []) poi_category = parsed_query.get("poi_category", "Tourism") # Find city coordinates city_coords = None city_name = None # Also check the original user query for city names # This helps when LLM returns "Berlin, Germany" but we need to match "berlin" all_text_to_search = locations.copy() for loc in all_text_to_search: # Normalize: lowercase, remove commas and common suffixes, extra spaces loc_lower = loc.lower().replace(",", " ").replace(".", " ").strip() # Remove common country suffixes that might interfere for suffix in [' germany', ' france', ' japan', ' usa', ' uk', ' italy', ' spain', ' china', ' india', ' brazil', ' australia']: loc_lower = loc_lower.replace(suffix, '') loc_lower = loc_lower.strip() loc_words = set(loc_lower.split()) # Check direct match first if loc_lower in CITY_COORDINATES: city_coords = CITY_COORDINATES[loc_lower] city_name = loc break # Check if any city name appears in the location string for city, coords in CITY_COORDINATES.items(): city_words = set(city.split()) # Match if city name is contained in location OR any word matches city if city in loc_lower or city_words & loc_words or any(city == word for word in loc_words): city_coords = coords city_name = loc break if city_coords: break if not city_coords: return [], f"City not found in database. Available cities: {', '.join(CITY_COORDINATES.keys())}" # Fetch POIs try: pois = external_data.get_pois_near_location( lat=city_coords[0], lon=city_coords[1], radius_km=5, category=poi_category ) return pois, f"Found {len(pois)} POIs near {city_name}" except Exception as e: return [], f"Error fetching POIs: {str(e)}" def fetch_conservation_data(parsed_query): """ Fetch conservation datasets from WWF GLOBIL. Returns a list of dataset metadata. 
""" if not EXTERNAL_APIS_AVAILABLE or external_data is None: return [], "External APIs not available" topic = parsed_query.get("conservation_topic", "forests") # Also search by keywords in query query_keywords = { "deforestation": "forests", "forest": "forests", "wildlife": "wildlife", "endangered": "wildlife", "species": "wildlife", "marine": "oceans", "ocean": "oceans", "coral": "oceans", "river": "freshwater", "water": "freshwater", "climate": "climate", "carbon": "climate" } # Try to detect topic from locations/indicators all_text = " ".join(parsed_query.get("locations", []) + parsed_query.get("indicators", [])).lower() for keyword, topic_name in query_keywords.items(): if keyword in all_text: topic = topic_name break try: datasets = external_data.search_conservation_data(topic, limit=10) return datasets, f"Found {len(datasets)} datasets about {topic}" except Exception as e: return [], f"Error fetching conservation data: {str(e)}" def create_poi_map(pois, city_name, map_style='Light'): """ Create a Folium map showing Points of Interest with layer control. 
""" if not pois: return None # Calculate center from POIs lats = [p.get('lat', 0) for p in pois if p.get('lat')] lons = [p.get('lon', 0) for p in pois if p.get('lon')] if not lats or not lons: return None center_lat = sum(lats) / len(lats) center_lon = sum(lons) / len(lons) tiles = MAP_STYLES.get(map_style, 'CartoDB positron') m = folium.Map(location=[center_lat, center_lon], zoom_start=14, tiles=tiles) # Group POIs by type for layer control poi_groups = {} for poi in pois: lat = poi.get('lat') lon = poi.get('lon') if not lat or not lon: continue name = poi.get('name', 'Unknown') tags = poi.get('tags', {}) poi_type = tags.get('amenity') or tags.get('tourism') or tags.get('shop') or 'Other' # Categorize POI type for grouping if 'restaurant' in poi_type.lower() or 'food' in poi_type.lower() or 'cafe' in poi_type.lower(): group_name = '🍽️ Restaurants & Cafes' icon_color = 'red' elif 'hotel' in poi_type.lower() or 'hostel' in poi_type.lower() or 'guest' in poi_type.lower(): group_name = '🏨 Hotels & Accommodation' icon_color = 'green' elif 'hospital' in poi_type.lower() or 'clinic' in poi_type.lower() or 'pharmacy' in poi_type.lower() or 'health' in poi_type.lower(): group_name = 'πŸ₯ Healthcare' icon_color = 'white' elif 'museum' in poi_type.lower() or 'attraction' in poi_type.lower() or 'tourism' in str(tags).lower(): group_name = '🎭 Tourism & Attractions' icon_color = 'purple' elif 'shop' in poi_type.lower() or 'supermarket' in poi_type.lower() or 'mall' in poi_type.lower(): group_name = 'πŸ›’ Shopping' icon_color = 'orange' else: group_name = 'πŸ“ Other POIs' icon_color = 'blue' # Create feature group if not exists if group_name not in poi_groups: poi_groups[group_name] = folium.FeatureGroup(name=group_name) popup_html = f"{name}
Type: {poi_type}" if tags.get('cuisine'): popup_html += f"
Cuisine: {tags['cuisine']}" if tags.get('phone'): popup_html += f"
Phone: {tags['phone']}" if tags.get('opening_hours'): popup_html += f"
Hours: {tags['opening_hours']}" folium.Marker( location=[lat, lon], popup=folium.Popup(popup_html, max_width=250), tooltip=name, icon=folium.Icon(color=icon_color, icon='info-sign') ).add_to(poi_groups[group_name]) # Add all feature groups to map for group in poi_groups.values(): group.add_to(m) # Add layer control for toggling POI categories folium.LayerControl(collapsed=False).add_to(m) return m def create_conservation_table(datasets): """ Create a DataFrame from conservation dataset metadata. """ if not datasets: return pd.DataFrame() rows = [] for ds in datasets: rows.append({ 'Title': ds.get('title', 'Unknown'), 'Description': (ds.get('snippet', '') or '')[:100] + '...', 'Type': ds.get('type', 'Unknown'), 'Views': ds.get('views', 0), 'Access': ds.get('access', 'Unknown'), 'Owner': ds.get('owner', 'Unknown') }) return pd.DataFrame(rows) def create_conservation_map(feature_data_list, map_style='Light'): """ Create a Folium map showing conservation features from GLOBIL. Args: feature_data_list: List of dicts with 'dataset' and 'features' keys map_style: Map tile style Returns: Folium map object """ if not feature_data_list: return None # Collect all coordinates to calculate center all_coords = [] for data in feature_data_list: for feature in data.get('features', []): geom = feature.get('geometry', {}) geom_type = geom.get('type', '') coords = geom.get('coordinates', []) if geom_type == 'Point' and len(coords) >= 2: all_coords.append((coords[1], coords[0])) # lat, lon elif geom_type == 'Polygon' and coords: # Get centroid of first ring ring = coords[0] if isinstance(coords[0], list) and coords[0] else [] if ring and len(ring) > 0: lons = [c[0] for c in ring if len(c) >= 2] lats = [c[1] for c in ring if len(c) >= 2] if lons and lats: all_coords.append((sum(lats)/len(lats), sum(lons)/len(lons))) elif geom_type == 'MultiPolygon' and coords: for polygon in coords: if polygon and polygon[0]: ring = polygon[0] lons = [c[0] for c in ring if len(c) >= 2] lats = [c[1] for 
c in ring if len(c) >= 2] if lons and lats: all_coords.append((sum(lats)/len(lats), sum(lons)/len(lons))) if not all_coords: # Default to world center if no coordinates found center_lat, center_lon = 20, 0 else: center_lat = sum(c[0] for c in all_coords) / len(all_coords) center_lon = sum(c[1] for c in all_coords) / len(all_coords) tiles = MAP_STYLES.get(map_style, 'CartoDB positron') m = folium.Map(location=[center_lat, center_lon], zoom_start=3, tiles=tiles) # Color palette for different datasets colors = ['#2ecc71', '#3498db', '#9b59b6', '#e74c3c', '#f39c12', '#1abc9c'] for i, data in enumerate(feature_data_list): dataset = data.get('dataset', {}) features = data.get('features', []) dataset_title = dataset.get('title', f'Dataset {i+1}') color = colors[i % len(colors)] # Create a feature group for this dataset fg = folium.FeatureGroup(name=dataset_title) for feature in features[:100]: # Limit features per dataset geom = feature.get('geometry', {}) props = feature.get('properties', {}) geom_type = geom.get('type', '') # Build popup content popup_lines = [f"{dataset_title}"] for key, value in list(props.items())[:5]: # Show first 5 properties if value and str(value).strip(): popup_lines.append(f"{key}: {value}") popup_html = "
".join(popup_lines) try: if geom_type == 'Point': coords = geom.get('coordinates', []) if len(coords) >= 2: folium.CircleMarker( location=[coords[1], coords[0]], radius=6, color=color, fill=True, fillColor=color, fillOpacity=0.7, popup=popup_html ).add_to(fg) elif geom_type in ['Polygon', 'MultiPolygon']: folium.GeoJson( feature, style_function=lambda x, c=color: { 'fillColor': c, 'color': c, 'weight': 2, 'fillOpacity': 0.4 }, popup=folium.Popup(popup_html, max_width=300) ).add_to(fg) elif geom_type in ['LineString', 'MultiLineString']: folium.GeoJson( feature, style_function=lambda x, c=color: { 'color': c, 'weight': 3 }, popup=folium.Popup(popup_html, max_width=300) ).add_to(fg) except Exception as e: logger.debug(f"Error adding feature: {e}") continue fg.add_to(m) # Add layer control if multiple datasets if len(feature_data_list) > 1: folium.LayerControl().add_to(m) return m def create_conservation_chart(feature_data_list, chart_type='bar'): """ Create a Plotly chart from conservation feature data. 
Args: feature_data_list: List of dicts with 'dataset' and 'features' keys chart_type: Type of chart ('bar', 'pie') Returns: Plotly figure """ if not feature_data_list: return None # Aggregate data by dataset chart_data = [] for data in feature_data_list: dataset = data.get('dataset', {}) features = data.get('features', []) chart_data.append({ 'Dataset': dataset.get('title', 'Unknown')[:40], 'Feature Count': len(features), 'Views': dataset.get('views', 0) }) if not chart_data: return None df = pd.DataFrame(chart_data) if chart_type == 'pie': fig = px.pie( df, values='Feature Count', names='Dataset', title='Conservation Features by Dataset', color_discrete_sequence=px.colors.qualitative.Set2 ) else: # Default to bar fig = px.bar( df, x='Dataset', y='Feature Count', color='Dataset', title='Conservation Features by Dataset', color_discrete_sequence=px.colors.qualitative.Set2 ) fig.update_layout(xaxis_tickangle=-45) fig.update_layout( template='plotly_white', showlegend=True, height=400 ) return fig def fetch_conservation_features_for_query(parsed_query): """ Fetch conservation features based on parsed query. 
Returns: Tuple of (feature_data_list, status_message, metadata_list) """ if not EXTERNAL_APIS_AVAILABLE or external_data is None: return [], "External APIs not available", [] topic = parsed_query.get("conservation_topic", "forests") # Topic detection from query query_keywords = { "deforestation": "forests", "forest": "forests", "wildlife": "wildlife", "endangered": "wildlife", "species": "wildlife", "marine": "oceans", "ocean": "oceans", "coral": "oceans", "protected": "wildlife", "river": "freshwater", "water": "freshwater", "climate": "climate", "carbon": "climate" } all_text = " ".join(parsed_query.get("locations", []) + parsed_query.get("indicators", [])).lower() for keyword, topic_name in query_keywords.items(): if keyword in all_text: topic = topic_name break try: # Fetch features from GLOBIL feature_data = external_data.fetch_conservation_features( topic, max_datasets=3, max_features=200 ) if not feature_data: # Fallback to metadata only datasets = external_data.search_conservation_data(topic, limit=10) return [], f"No feature data available for {topic}. 
Showing metadata only.", datasets total_features = sum(d.get('feature_count', 0) for d in feature_data) datasets_found = len(feature_data) return ( feature_data, f"Found {total_features} features from {datasets_found} datasets about {topic}", [d.get('dataset', {}) for d in feature_data] ) except Exception as e: logger.error(f"Error fetching conservation features: {e}") return [], f"Error fetching data: {str(e)}", [] def fetch_geospatial_data(parsed_query): """ Fetch and process geospatial data based on parsed query """ world = load_world_data() # Filter by locations locations = parsed_query.get("locations", []) # Normalize location names - handle common variations location_normalization = { # Continent variations "european nations": "Europe", "european countries": "Europe", "europe": "Europe", "asian nations": "Asia", "asian countries": "Asia", "asia": "Asia", "african nations": "Africa", "african countries": "Africa", "africa": "Africa", "north american nations": "North America", "north american countries": "North America", "north america": "North America", "south american nations": "South America", "south american countries": "South America", "south america": "South America", "oceanian nations": "Oceania", "oceanian countries": "Oceania", "oceania": "Oceania", "australia and oceania": "Oceania", # Country groups "middle east": "Asia", "middle eastern countries": "Asia", "southeast asia": "Asia", "southeast asian countries": "Asia", "latin america": "South America", "latin american countries": "South America", } # Normalize locations normalized_locations = [] for loc in locations: loc_lower = loc.lower().strip() if loc_lower in location_normalization: normalized_locations.append(location_normalization[loc_lower]) else: # Keep original (might be a country name) normalized_locations.append(loc) # Treat "global", "world", "worldwide", "all" as requests for all data global_terms = {"global", "world", "worldwide", "all", "earth", "globe"} is_global_query = not 
normalized_locations or (len(normalized_locations) == 1 and normalized_locations[0].lower() in global_terms) if not is_global_query: # Filter by continent or country mask = world['continent'].isin(normalized_locations) | world['name'].isin(normalized_locations) filtered_data = world[mask].copy() else: filtered_data = world.copy() # Add computed indicators filtered_data.loc[:, 'pop_density'] = filtered_data['pop_est'] / filtered_data['geometry'].area * 1000000 filtered_data.loc[:, 'gdp_per_capita'] = filtered_data['gdp_md_est'] / filtered_data['pop_est'] * 1000000 return filtered_data def create_interactive_map(gdf, indicators, map_style='Light', color_scale='Yellow-Orange-Red'): """ Create an interactive Folium map with customizable style and colors. Supports multiple indicators with toggleable layer groups. Args: gdf: GeoDataFrame with country data indicators: Single indicator string or list of indicator column names map_style: Map tile style color_scale: Color scheme for choropleth """ # Ensure indicators is a list if isinstance(indicators, str): indicators = [indicators] # Calculate center center_lat = gdf.geometry.centroid.y.mean() center_lon = gdf.geometry.centroid.x.mean() # Get tile style tiles = MAP_STYLES.get(map_style, 'CartoDB positron') # Create map m = folium.Map( location=[center_lat, center_lon], zoom_start=2, tiles=tiles ) # Color schemes for different indicators (these match branca colormap names exactly) color_scheme_list = ['YlOrRd', 'YlGnBu', 'Greens', 'Blues', 'PuRd', 'OrRd'] # Map color scheme names to branca colormap objects color_map_dict = { 'YlOrRd': cm.linear.YlOrRd_09, 'YlGnBu': cm.linear.YlGnBu_09, 'Blues': cm.linear.Blues_09, 'Greens': cm.linear.Greens_09, 'Reds': cm.linear.Reds_09, 'PuRd': cm.linear.PuRd_09, 'OrRd': cm.linear.OrRd_09, 'BuPu': cm.linear.BuPu_09, 'Spectral': cm.linear.Spectral_11 } # Get user-selected color scheme for first indicator fill_color = CHOROPLETH_COLORS.get(color_scale, 'YlOrRd') # Store legends to add after 
all layers (for proper stacking) legends_html = [] # Create a layer group for each indicator for idx, indicator in enumerate(indicators): # Use user-selected color for first indicator, cycle through others for additional indicators if idx == 0: current_color_name = fill_color else: current_color_name = color_scheme_list[idx % len(color_scheme_list)] # Get min/max for this indicator valid_values = gdf[indicator].dropna() if valid_values.empty: vmin, vmax = 0, 1 else: vmin = float(valid_values.min()) vmax = float(valid_values.max()) # Ensure valid range if np.isnan(vmin) or np.isnan(vmax) or np.isinf(vmin) or np.isinf(vmax): vmin, vmax = 0, 1 if vmax <= vmin: vmax = vmin + 1 # Create feature group for this indicator indicator_name = indicator.replace('_', ' ').title() feature_group = folium.FeatureGroup(name=f"πŸ“Š {indicator_name}", show=(idx == 0)) # Get the base branca colormap and scale it to our data range base_colormap = color_map_dict.get(current_color_name, cm.linear.YlOrRd_09) scaled_colormap = base_colormap.scale(vmin, vmax) scaled_colormap.caption = f"{indicator_name} ({format_number(vmin)} - {format_number(vmax)})" # Create style function using the SAME scaled colormap for exact color matching # Use default arguments to capture current loop values def make_style_function(cmap, ind, v_min, v_max): def style_function(feature): value = feature['properties'].get(ind) if value is None or (isinstance(value, float) and np.isnan(value)): return { 'fillColor': 'lightgray', 'fillOpacity': 0.7, 'color': 'white', 'weight': 0.5 } # Clamp value to range clamped_value = max(v_min, min(v_max, value)) return { 'fillColor': cmap(clamped_value), 'fillOpacity': 0.7, 'color': 'white', 'weight': 0.5 } return style_function style_fn = make_style_function(scaled_colormap, indicator, vmin, vmax) # Convert GeoDataFrame to GeoJSON with indicator values in properties geojson_data = json.loads(gdf.to_json()) # Add the indicator value to each feature's properties for i, feature in 
enumerate(geojson_data['features']): ind_value = gdf.iloc[i][indicator] feature['properties'][indicator] = float(ind_value) if pd.notna(ind_value) else None # Use GeoJson instead of Choropleth for precise color control geojson_layer = folium.GeoJson( geojson_data, style_function=style_fn, tooltip=folium.GeoJsonTooltip( fields=['name', indicator], aliases=['Country:', f'{indicator_name}:'], style="background-color: white; color: #333333; font-family: arial; font-size: 12px; padding: 10px;" ) ) geojson_layer.add_to(feature_group) feature_group.add_to(m) # Build legend HTML for this indicator with vertical offset # Only show first indicator's legend by default (others shown when layer toggled) legend_bottom = 50 + (idx * 80) # Stack legends vertically legend_html = f'''
{indicator_name}
{format_number(vmin)} {format_number(vmax)}
''' legends_html.append(legend_html) # Add all legends to the map for legend_html in legends_html: m.get_root().html.add_child(folium.Element(legend_html)) # Add tooltips with all indicator values in a toggleable layer group markers_group = folium.FeatureGroup(name="πŸ“ Info Markers", show=True) for _, row in gdf.iterrows(): popup_html = f"{row['name']}
" popup_html += f"Continent: {row['continent']}

" popup_html += f"Population: {format_number(row['pop_est'])}
" popup_html += f"GDP: ${format_number(row['gdp_md_est'])}M
" if 'pop_density' in row: popup_html += f"Pop Density: {row['pop_density']:.1f}/kmΒ²
" if 'gdp_per_capita' in row: popup_html += f"GDP/Capita: ${row['gdp_per_capita']:,.0f}
" folium.Marker( location=[row.geometry.centroid.y, row.geometry.centroid.x], popup=popup_html, icon=folium.Icon(icon='info-sign', color='blue') ).add_to(markers_group) markers_group.add_to(m) # Add layer control folium.LayerControl(collapsed=False).add_to(m) return m def create_chart(df, indicators, chart_type='bar', color_scheme='Default', top_n=20, use_country_colors=False): """ Create interactive Plotly charts with customizable options. Supports multiple indicators with grouped/stacked visualizations. """ # Get color sequence colors = COLOR_SCHEMES.get(color_scheme, px.colors.qualitative.Plotly) # Sort and limit data sorted_df = df.sort_values(indicators[0], ascending=False).head(top_n) # Determine color column: use country name for specific country queries, continent for regional color_col = 'name' if use_country_colors else 'continent' # Check if we have multiple indicators for grouped visualization has_multiple_indicators = len(indicators) > 1 if chart_type == 'bar': if has_multiple_indicators: # Create grouped bar chart for multiple indicators fig = go.Figure() for i, ind in enumerate(indicators): fig.add_trace(go.Bar( name=ind.replace('_', ' ').title(), x=sorted_df['name'], y=sorted_df[ind], marker_color=colors[i % len(colors)] )) fig.update_layout( barmode='group', title=f'Comparison: {", ".join([i.replace("_", " ").title() for i in indicators])}', xaxis_title='Country', yaxis_title='Value', height=500 ) else: fig = px.bar( sorted_df, x='name', y=indicators[0], color=color_col, title=f'Top {top_n} Countries by {indicators[0].replace("_", " ").title()}', labels={'name': 'Country', indicators[0]: indicators[0].replace('_', ' ').title(), color_col: ''}, color_discrete_sequence=colors, height=500 ) elif chart_type == 'horizontal_bar': if has_multiple_indicators: # Create grouped horizontal bar chart fig = go.Figure() for i, ind in enumerate(indicators): fig.add_trace(go.Bar( name=ind.replace('_', ' ').title(), y=sorted_df['name'], x=sorted_df[ind], 
orientation='h', marker_color=colors[i % len(colors)] )) fig.update_layout( barmode='group', title=f'Comparison: {", ".join([i.replace("_", " ").title() for i in indicators])}', xaxis_title='Value', yaxis_title='Country', height=600 ) else: fig = px.bar( sorted_df, y='name', x=indicators[0], color=color_col, title=f'Top {top_n} Countries by {indicators[0].replace("_", " ").title()}', labels={'name': 'Country', indicators[0]: indicators[0].replace('_', ' ').title(), color_col: ''}, color_discrete_sequence=colors, orientation='h', height=600 ) fig.update_layout(yaxis={'categoryorder': 'total ascending'}) elif chart_type == 'scatter': x_col = indicators[0] if len(indicators) > 0 else 'gdp_md_est' y_col = indicators[1] if len(indicators) > 1 else 'pop_est' fig = px.scatter( df, x=x_col, y=y_col, size='pop_est', color=color_col, hover_name='name', title=f'{x_col.replace("_", " ").title()} vs {y_col.replace("_", " ").title()}', labels={ x_col: x_col.replace('_', ' ').title(), y_col: y_col.replace('_', ' ').title(), color_col: '' }, color_discrete_sequence=colors, height=500 ) elif chart_type == 'pie': if has_multiple_indicators: # Create subplots for multiple indicators from plotly.subplots import make_subplots fig = make_subplots( rows=1, cols=len(indicators), specs=[[{'type': 'pie'}] * len(indicators)], subplot_titles=[ind.replace('_', ' ').title() for ind in indicators] ) for i, ind in enumerate(indicators): pie_trace = go.Pie( values=sorted_df[ind], labels=sorted_df['name'], name=ind.replace('_', ' ').title(), textposition='inside', textinfo='percent+label', marker_colors=colors[:len(sorted_df)] ) fig.add_trace(pie_trace, row=1, col=i+1) fig.update_layout( title=f'Top {top_n} Countries: {", ".join([i.replace("_", " ").title() for i in indicators])}', height=500, showlegend=False ) else: fig = px.pie( sorted_df, values=indicators[0], names='name', title=f'Top {top_n} Countries by {indicators[0].replace("_", " ").title()}', color_discrete_sequence=colors, height=500 ) 
fig.update_traces(textposition='inside', textinfo='percent+label') elif chart_type == 'treemap': if has_multiple_indicators: # Create subplots for multiple indicators from plotly.subplots import make_subplots fig = make_subplots( rows=1, cols=len(indicators), specs=[[{'type': 'treemap'}] * len(indicators)], subplot_titles=[ind.replace('_', ' ').title() for ind in indicators] ) for i, ind in enumerate(indicators): # Build treemap data manually for subplot treemap_trace = go.Treemap( labels=sorted_df['name'].tolist() + sorted_df['continent'].unique().tolist(), parents=sorted_df['continent'].tolist() + [''] * len(sorted_df['continent'].unique()), values=sorted_df[ind].tolist() + [0] * len(sorted_df['continent'].unique()), name=ind.replace('_', ' ').title(), textinfo='label+value', hovertemplate='%{label}
Value: %{value:,.0f}' ) fig.add_trace(treemap_trace, row=1, col=i+1) fig.update_layout( title=f'Top {top_n} Countries: {", ".join([i.replace("_", " ").title() for i in indicators])}', height=600 ) else: fig = px.treemap( sorted_df, path=['continent', 'name'], values=indicators[0], title=f'Top {top_n} Countries by {indicators[0].replace("_", " ").title()}', color='continent', color_discrete_sequence=colors, height=600, hover_data={indicators[0]: ':,.0f'} ) elif chart_type == 'bubble': x_col = indicators[0] if len(indicators) > 0 else 'gdp_md_est' y_col = indicators[1] if len(indicators) > 1 else 'pop_est' size_col = indicators[0] fig = px.scatter( df, x=x_col, y=y_col, size=size_col, color=color_col, hover_name='name', title=f'Bubble Chart: {x_col.replace("_", " ").title()} vs {y_col.replace("_", " ").title()}', labels={x_col: x_col.replace('_', ' ').title(), y_col: y_col.replace('_', ' ').title(), color_col: ''}, color_discrete_sequence=colors, size_max=60, height=500 ) else: # default bar fig = px.bar( sorted_df, x='name', y=indicators[0], color=color_col, title=f'Top {top_n} Countries by {indicators[0].replace("_", " ").title()}', color_discrete_sequence=colors, height=500 ) # Remove legend title for cleaner appearance fig.update_layout( xaxis_tickangle=-45, template='plotly_white', legend_title_text='' ) return fig def create_data_table(df): """ Create formatted data table """ # Select relevant columns display_cols = ['name', 'continent', 'pop_est', 'gdp_md_est', 'pop_density', 'gdp_per_capita'] table_df = df[display_cols].copy() # Rename columns table_df.columns = ['Country', 'Continent', 'Population', 'GDP (Million $)', 'Pop. Density (per kmΒ²)', 'GDP per Capita ($)'] # Format numbers table_df['Population'] = table_df['Population'].apply(lambda x: f'{x:,.0f}') table_df['GDP (Million $)'] = table_df['GDP (Million $)'].apply(lambda x: f'${x:,.0f}') table_df['Pop. Density (per kmΒ²)'] = table_df['Pop. 
def process_query(user_query, output_format, chart_type, map_style, color_scheme, choropleth_color, top_n):
    """
    Main processing function with advanced options.
    Routes queries to appropriate data sources based on LLM parsing.

    Parameters
    ----------
    user_query : str
        Free-text natural-language query from the UI.
    output_format : str
        One of 'All', 'Map', 'Chart', 'Table' — controls which outputs are built.
    chart_type, map_style, color_scheme, choropleth_color : str
        Visualization options forwarded to the chart/map builders.
    top_n : int-like
        How many countries to keep for visualizations (cast with int()).

    Returns
    -------
    tuple of 6 elements:
        (map_html, chart_fig, table_df, summary_markdown,
         map_file_path_or_None, csv_file_path_or_None)
        On any failure, the first three and last two are None and the
        summary slot carries a user-facing error message.

    NOTE(review): temp files are created with delete=False and their paths
    are handed to Gradio for download; nothing here deletes them later —
    presumably cleaned up by the OS temp dir. Confirm this is intended.
    """
    try:
        # Parse query with LLM
        parsed = parse_query_with_llm(user_query)

        # Check data_source to route to appropriate handler
        # (defaults to the Natural Earth country/region path)
        data_source = parsed.get("data_source", "countries")

        # =================================================================
        # EXTERNAL API ROUTING: POI Data (OpenStreetMap Overpass)
        # =================================================================
        if data_source == "osm_pois":
            pois, status_msg = fetch_poi_data(parsed)

            if not pois:
                # Check if it's a rate limit or timeout issue
                # (matches on substrings of the handler's status message)
                if "timeout" in status_msg.lower() or "rate" in status_msg.lower() or "429" in status_msg or "504" in status_msg:
                    error_msg = f"⏳ **OpenStreetMap API Temporarily Unavailable**\n\n"
                    error_msg += "The Overpass API is experiencing high load or rate limiting.\n\n"
                    error_msg += "**Please try:**\n"
                    error_msg += "- Wait 30 seconds and try again\n"
                    error_msg += "- Use a different city\n"
                    error_msg += "- Try the Country/Region queries instead\n\n"
                    error_msg += f"*Technical details: {status_msg}*"
                else:
                    error_msg = f"❌ **POI Query Failed**\n\n{status_msg}\n\n"
                    error_msg += "**Tip:** Try queries like:\n"
                    error_msg += "- 'Find restaurants in Paris'\n"
                    error_msg += "- 'Show hotels near Tokyo'\n"
                    error_msg += "- 'Find hospitals in Berlin'"
                return None, None, None, error_msg, None, None

            map_html = None
            map_file = None
            poi_df = None
            csv_file = None

            # Create POI map (if requested)
            if output_format in ['All', 'Map']:
                # First parsed location is used as the map's city label
                city_name = parsed.get("locations", ["Unknown"])[0]
                poi_map = create_poi_map(pois, city_name, map_style)
                if poi_map:
                    map_html = poi_map._repr_html_()
                    # Save map to temp file (path kept for the download widget)
                    map_file = tempfile.NamedTemporaryFile(delete=False, suffix='.html', mode='w', encoding='utf-8')
                    poi_map.save(map_file.name)
                    map_file = map_file.name

            # Create POI table (if requested)
            if output_format in ['All', 'Table']:
                # POI dicts come from Overpass; 'Type' falls back across
                # amenity -> tourism -> shop tags, defaulting to 'POI'
                poi_df = pd.DataFrame([{
                    'Name': p.get('name', 'Unknown'),
                    'Type': (p.get('tags', {}).get('amenity') or p.get('tags', {}).get('tourism') or p.get('tags', {}).get('shop', 'POI')),
                    'Lat': p.get('lat', ''),
                    'Lon': p.get('lon', '')
                } for p in pois])
                # Save to CSV
                csv_file = tempfile.NamedTemporaryFile(delete=False, suffix='.csv', mode='w', encoding='utf-8')
                poi_df.to_csv(csv_file.name, index=False)
                csv_file = csv_file.name

            summary = f"πŸ” **Query:** {user_query}\n\n"
            summary += "πŸ“ **Data Source:** OpenStreetMap (Overpass API)\n"
            summary += f"🏷️ **Category:** {parsed.get('poi_category', 'All')}\n"
            summary += f"πŸ“Š **Results:** {status_msg}\n\n"
            if output_format in ['All', 'Map']:
                summary += "πŸ’‘ *Use layer control (top-right) to toggle POI categories*"

            # POI path never produces a chart (second slot is None)
            return map_html, None, poi_df, summary, map_file, csv_file

        # =================================================================
        # EXTERNAL API ROUTING: Conservation Data (WWF GLOBIL)
        # =================================================================
        if data_source == "conservation":
            # Try to fetch actual feature data first
            feature_data, status_msg, metadata = fetch_conservation_features_for_query(parsed)

            map_html = None
            chart_fig = None
            map_file = None
            csv_file = None
            cons_df = None

            if feature_data:
                # We have actual feature data - create visualizations based on output_format

                # Create map (if requested)
                if output_format in ['All', 'Map']:
                    cons_map = create_conservation_map(feature_data, map_style)
                    if cons_map:
                        map_html = cons_map._repr_html_()
                        map_file = tempfile.NamedTemporaryFile(delete=False, suffix='.html', mode='w', encoding='utf-8')
                        cons_map.save(map_file.name)
                        map_file = map_file.name

                # Create chart (if requested) — only 'bar'/'pie' are
                # supported for conservation data, others fall back to 'bar'
                if output_format in ['All', 'Chart']:
                    chart_fig = create_conservation_chart(feature_data, chart_type if chart_type in ['bar', 'pie'] else 'bar')

                # Create table (if requested)
                if output_format in ['All', 'Table']:
                    table_rows = []
                    for data in feature_data:
                        dataset_title = data.get('dataset', {}).get('title', 'Unknown')
                        for feat in data.get('features', [])[:50]:  # Limit rows
                            props = feat.get('properties', {})
                            row = {'Dataset': dataset_title}
                            # Get first few meaningful properties
                            # (skip GIS bookkeeping fields, truncate long values)
                            for key, val in list(props.items())[:4]:
                                if val and str(val).strip() and key.lower() not in ['objectid', 'fid', 'shape']:
                                    row[key] = str(val)[:50]
                            # Only keep rows that carry more than the dataset name
                            if len(row) > 1:
                                table_rows.append(row)
                    if table_rows:
                        cons_df = pd.DataFrame(table_rows)
                    else:
                        # Fallback to metadata table
                        cons_df = create_conservation_table(metadata)

                summary = f"πŸ” **Query:** {user_query}\n\n"
                summary += "πŸ“ **Data Source:** WWF GLOBIL (ArcGIS Hub)\n"
                summary += f"🌿 **Topic:** {parsed.get('conservation_topic', 'General')}\n"
                summary += f"πŸ“Š **Results:** {status_msg}\n\n"
                if output_format in ['All', 'Map']:
                    summary += "πŸ—ΊοΈ *Use layer control to toggle conservation areas*\n"
                if output_format in ['All', 'Chart']:
                    summary += "πŸ“Š *Chart shows feature distribution across datasets*"
            else:
                # Fallback to metadata only (no feature data available)
                datasets, fallback_msg = fetch_conservation_data(parsed)
                if not datasets:
                    error_msg = f"❌ **Conservation Data Query Failed**\n\n{fallback_msg}\n\n"
                    error_msg += "**Available Topics:**\n"
                    error_msg += "- forests, wildlife, oceans, freshwater, climate\n\n"
                    error_msg += "**Try queries like:**\n"
                    error_msg += "- 'Search for deforestation datasets'\n"
                    error_msg += "- 'Find wildlife conservation data'\n"
                    error_msg += "- 'Show ocean protection datasets'"
                    return None, None, None, error_msg, None, None

                cons_df = create_conservation_table(datasets)
                summary = f"πŸ” **Query:** {user_query}\n\n"
                summary += "πŸ“ **Data Source:** WWF GLOBIL (ArcGIS Hub)\n"
                summary += f"🌿 **Topic:** {parsed.get('conservation_topic', 'General')}\n"
                summary += f"πŸ“Š **Results:** {status_msg or fallback_msg}\n\n"
                summary += "ℹ️ *Feature geometries not available for these datasets*\n"
                summary += "πŸ’‘ *Showing dataset metadata table*"

            # Save table to CSV (both feature and metadata-fallback tables)
            if cons_df is not None and not cons_df.empty:
                csv_file = tempfile.NamedTemporaryFile(delete=False, suffix='.csv', mode='w', encoding='utf-8')
                cons_df.to_csv(csv_file.name, index=False)
                csv_file = csv_file.name

            return map_html, chart_fig, cons_df, summary, map_file, csv_file

        # =================================================================
        # DEFAULT: Country/Region Data (Natural Earth)
        # =================================================================
        # Fetch data from local shapefile
        gdf = fetch_geospatial_data(parsed)

        if gdf.empty:
            return None, None, None, "No data found for your query. Try different locations or indicators.", None, None

        # Map LLM-parsed indicators to actual column names
        indicator_mapping = {
            'gdp': 'gdp_md_est',
            'gdp per capita': 'gdp_per_capita',
            'population': 'pop_est',
            'population density': 'pop_density',
            'pop_density': 'pop_density',
            'gdp_md_est': 'gdp_md_est',
            'pop_est': 'pop_est',
            'gdp_per_capita': 'gdp_per_capita',
            # Additional mappings for common LLM responses
            'economic': 'gdp_md_est',
            'economic indicators': 'gdp_md_est',
            'economy': 'gdp_md_est',
            'economics': 'gdp_md_est',
            'gross domestic product': 'gdp_md_est',
            'density': 'pop_density',
            'people': 'pop_est',
            'inhabitants': 'pop_est'
        }

        # Get indicators from LLM parsing - support multiple indicators
        llm_indicators = parsed.get('indicators', [])
        mapped_indicators = []
        seen = set()  # Avoid duplicates (several phrases map to one column)
        for ind in llm_indicators:
            ind_lower = ind.lower().strip()
            if ind_lower in indicator_mapping:
                col = indicator_mapping[ind_lower]
                if col not in seen:
                    mapped_indicators.append(col)
                    seen.add(col)

        # Default to population if no valid indicator found
        if not mapped_indicators:
            mapped_indicators = ['pop_est']

        # Get display names for all indicators (raw LLM phrasing, not columns)
        indicator_display = ', '.join(llm_indicators) if llm_indicators else 'Population'

        # Detect if this is a specific country query (not continent/regional)
        # Use country colors in charts when querying specific countries
        locations = parsed.get('locations', [])
        continents = {'Africa', 'Asia', 'Europe', 'North America', 'South America', 'Oceania', 'Antarctica'}
        use_country_colors = bool(locations) and not any(loc in continents for loc in locations)

        # Generate outputs based on format
        map_html = None
        chart_fig = None
        table_df = None
        map_file = None
        csv_file = None

        summary = f"πŸ” **Query:** {user_query}\n\n"
        summary += f"πŸ“ **Locations:** {', '.join(parsed.get('locations', ['Global']))}\n"
        summary += f"πŸ“Š **Indicators:** {indicator_display}\n"
        summary += f"🌍 **Countries found:** {len(gdf)}\n\n"
        summary += f"βš™οΈ **Options:** Chart: {chart_type} | Map: {map_style} | Top N: {top_n}"

        # Apply top_n limit for visualizations (sort by primary indicator)
        primary_indicator = mapped_indicators[0]
        gdf_sorted = gdf.sort_values(primary_indicator, ascending=False)
        gdf_top_n = gdf_sorted.head(int(top_n))

        if output_format in ['All', 'Map']:
            # Map only shows the top-N subset
            m = create_interactive_map(gdf_top_n, mapped_indicators, map_style, choropleth_color)
            map_html = m._repr_html_()
            # Save map to temp file for download
            map_file = tempfile.NamedTemporaryFile(delete=False, suffix='.html', mode='w', encoding='utf-8')
            m.save(map_file.name)
            map_file = map_file.name

        if output_format in ['All', 'Chart']:
            # Chart builder receives the FULL gdf and applies top_n itself
            chart_fig = create_chart(gdf, mapped_indicators, chart_type, color_scheme, int(top_n), use_country_colors)

        if output_format in ['All', 'Table']:
            table_df = create_data_table(gdf)
            # Save table to temp CSV file for download
            csv_file = tempfile.NamedTemporaryFile(delete=False, suffix='.csv', mode='w', encoding='utf-8')
            table_df.to_csv(csv_file.name, index=False)
            csv_file = csv_file.name

        return map_html, chart_fig, table_df, summary, map_file, csv_file

    except Exception as e:
        # Top-level boundary: convert any failure into a user-facing message
        error_msg = f"Error processing query: {str(e)}\n\nPlease try rephrasing your query."
        return None, None, None, error_msg, None, None
# Gradio Interface
def create_interface():
    """
    Build and return the Gradio Blocks UI.

    Wires the query textbox, output-format radio, advanced-option widgets
    and example queries to `process_query`, whose 6-tuple return maps to
    (map HTML, chart, table, summary, map download, CSV download).

    Returns
    -------
    gr.Blocks
        The assembled (unlaunched) interface.
    """
    with gr.Blocks(title="Geospatial AI Query System") as demo:
        gr.Markdown("""
        # 🌍 Geospatial AI Query System
        ### Natural Language Interface for Geographic Data

        Ask questions about countries, regions, and global indicators using natural language.

        ℹ️ *Note: Current data includes latest available statistics. Time-range queries will filter when historical data is available.*
        """)

        with gr.Row():
            with gr.Column(scale=3):
                query_input = gr.Textbox(
                    label="Your Query",
                    placeholder="E.g., Show me GDP and population of BRICS countries",
                    lines=2
                )
            with gr.Column(scale=1):
                output_format = gr.Radio(
                    choices=['All', 'Map', 'Chart', 'Table'],
                    value='All',
                    label="Output Format"
                )

        # Advanced Options in Accordion (collapsed by default)
        with gr.Accordion("βš™οΈ Advanced Options", open=False):
            with gr.Row():
                chart_type = gr.Dropdown(
                    choices=['bar', 'horizontal_bar', 'scatter', 'pie', 'treemap', 'bubble'],
                    value='bar',
                    label="πŸ“Š Chart Type"
                )
                map_style = gr.Dropdown(
                    choices=list(MAP_STYLES.keys()),
                    value='Light',
                    label="πŸ—ΊοΈ Map Style"
                )
            with gr.Row():
                color_scheme = gr.Dropdown(
                    choices=list(COLOR_SCHEMES.keys()),
                    value='Default',
                    label="🎨 Chart Colors"
                )
                choropleth_color = gr.Dropdown(
                    choices=list(CHOROPLETH_COLORS.keys()),
                    value='Yellow-Orange-Red',
                    label="🌈 Map Colors"
                )
            with gr.Row():
                top_n = gr.Slider(
                    minimum=5,
                    maximum=50,
                    value=20,
                    step=5,
                    label="πŸ”’ Top N Countries"
                )

        submit_btn = gr.Button("πŸ” Analyze", variant="primary", size="lg")

        gr.Markdown("### Results")
        summary_output = gr.Textbox(label="Query Summary", lines=4)

        # One tab per output kind; download widgets live beside their tab's view
        with gr.Tabs():
            with gr.Tab("πŸ“Š Chart"):
                chart_output = gr.Plot(label="Interactive Chart")
            with gr.Tab("πŸ—ΊοΈ Map"):
                map_output = gr.HTML(label="Interactive Map")
                map_download = gr.File(label="πŸ“₯ Download Map (HTML)", visible=True)
            with gr.Tab("πŸ“‹ Table"):
                table_output = gr.Dataframe(label="Data Table")
                csv_download = gr.File(label="πŸ“₯ Download Table (CSV)", visible=True)

        # Examples (grouped by which backend data source each query exercises)
        gr.Examples(
            examples=[
                # Country/Regional Data (Natural Earth)
                ["Show me population of Asian countries", "All"],
                ["Compare GDP of European nations", "Chart"],
                ["What's the population density in Africa?", "Map"],
                ["Show me the top 10 economies in the world", "Chart"],
                # POI Queries (OpenStreetMap Overpass API)
                ["Find restaurants and cafes in Paris", "Map"],
                ["Show hospitals near Berlin", "Map"],
                ["What tourist attractions are in Tokyo?", "Map"],
                ["Find hotels in London", "Table"],
                # Conservation Data (WWF GLOBIL - with Map & Charts)
                ["Show protected area datasets", "All"],
                ["Find deforestation data", "All"],
                ["Search for marine conservation areas", "Map"],
                ["Show wildlife habitat data", "Chart"],
            ],
            inputs=[query_input, output_format]
        )

        # Event handler: outputs must stay aligned with process_query's 6-tuple
        submit_btn.click(
            fn=process_query,
            inputs=[query_input, output_format, chart_type, map_style, color_scheme, choropleth_color, top_n],
            outputs=[map_output, chart_output, table_output, summary_output, map_download, csv_download]
        )

        gr.Markdown("""
        ---
        **About:** This app uses LLMs to parse natural language queries and visualize global geospatial data.

        **Data Sources:**
        - 🌍 **Countries/Regions:** Natural Earth, World Bank Open Data
        - πŸ“ **Points of Interest:** OpenStreetMap (Overpass API)
        - 🌿 **Conservation Data:** WWF GLOBIL (ArcGIS Hub)

        **Built by:** [rifatSDAS](https://github.com/rifatSDAS)
        """)

    return demo


if __name__ == "__main__":
    demo = create_interface()
    # Enable queue for better concurrency handling on HF Spaces
    demo.queue(default_concurrency_limit=10)
    # NOTE(review): Gradio documents `theme=` on gr.Blocks(...), not on
    # launch() — confirm this argument actually takes effect here.
    demo.launch(theme=gr.themes.Soft())

    # To enable Progressive Web App (PWA) features, uncomment the line below
    # demo.launch(theme=gr.themes.Soft(), pwa=True)