Spaces:
Running
Running
| import gradio as gr | |
| import pandas as pd | |
| import geopandas as gpd | |
| import folium | |
| from folium import plugins | |
| import plotly.express as px | |
| import plotly.graph_objects as go | |
| from huggingface_hub import InferenceClient | |
| import json | |
| import os | |
| import tempfile | |
| import io | |
| from datetime import datetime | |
| import numpy as np | |
| from pathlib import Path | |
| import warnings | |
| import logging | |
| # Suppress GeoPandas CRS warnings (area/centroid calculations are approximate for demo purposes) | |
| warnings.filterwarnings('ignore', message='.*Geometry is in a geographic CRS.*') | |
| # Suppress asyncio cleanup warnings in Python 3.13+ (harmless but noisy) | |
| warnings.filterwarnings('ignore', message='.*Invalid file descriptor.*') | |
| logging.getLogger('asyncio').setLevel(logging.CRITICAL) | |
| import branca.colormap as cm | |
| import zipfile | |
| import urllib.request | |
def download_natural_earth_data(data_dir: Path) -> Path:
    """Return the Natural Earth 110m countries shapefile, downloading it if absent.

    Args:
        data_dir: Directory that holds (or will hold) the extracted shapefile.
            It is created, including parents, when a download is needed.

    Returns:
        Path to ``ne_110m_admin_0_countries.shp`` inside ``data_dir``.

    Raises:
        FileNotFoundError: If the archive was extracted but the expected
            shapefile is still missing.
        Exception: Errors raised by ``urllib.request.urlretrieve`` or
            ``zipfile.ZipFile`` propagate unchanged.
    """
    shp_file = data_dir / "ne_110m_admin_0_countries.shp"
    if shp_file.exists():
        return shp_file

    # Create data directory
    data_dir.mkdir(parents=True, exist_ok=True)

    # Natural Earth 110m countries download URL
    url = "https://naciscdn.org/naturalearth/110m/cultural/ne_110m_admin_0_countries.zip"
    zip_path = data_dir / "ne_110m_admin_0_countries.zip"

    print(f"Downloading Natural Earth data from {url}...")
    try:
        urllib.request.urlretrieve(url, zip_path)
        print(f"Extracting to {data_dir}...")
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(data_dir)
    finally:
        # Always remove the archive, even when the download or extraction
        # failed, so a partial/corrupt zip is never left behind.
        zip_path.unlink(missing_ok=True)

    # Fail here with a clear message instead of handing back a path that
    # does not exist and letting the reader fail later.
    if not shp_file.exists():
        raise FileNotFoundError(
            f"Expected shapefile {shp_file} was not found after extracting {url}"
        )

    print("Natural Earth data ready.")
    return shp_file
def format_number(num):
    """Format a number with a K/M/B/T suffix and one decimal place.

    Args:
        num: A numeric value (Python or NumPy scalar), or ``None``.

    Returns:
        ``'N/A'`` for ``None`` or NaN; otherwise the value scaled by the
        largest matching power of 1000 (e.g. ``1500`` -> ``'1.5K'``), or the
        plain value with one decimal when its magnitude is below 1000.
    """
    if num is None:
        return 'N/A'
    try:
        # np.isnan covers every float flavour (float, np.float32, np.float64),
        # not just the built-in ``float`` type.
        if np.isnan(num):
            return 'N/A'
    except (TypeError, ValueError):
        # Not something NaN-checkable (e.g. a huge Python int); fall through.
        pass

    abs_num = abs(num)
    for threshold, suffix in ((1e12, 'T'), (1e9, 'B'), (1e6, 'M'), (1e3, 'K')):
        if abs_num >= threshold:
            return f'{num / threshold:.1f}{suffix}'
    return f'{num:.1f}'
# Path to local Natural Earth data (downloaded at runtime if not present).
# NOTE: the download call below is a module-level statement, so it executes
# when this module is imported/run and may perform network I/O.
DATA_DIR = Path(__file__).parent / "data" / "ne_110m_admin_0_countries"
NATURAL_EARTH_SHP = download_natural_earth_data(DATA_DIR)
# Initialize HF Inference Client.
# The token is read from the HF_TOKEN environment variable; os.environ.get
# yields None when the variable is unset.
client = InferenceClient(token=os.environ.get("HF_TOKEN"))
# ===== UI/UX Enhancement Constants =====
# UI label -> tile identifier passed as ``tiles=`` to folium.Map.
MAP_STYLES = {
    "Light": "CartoDB positron",
    "Dark": "CartoDB dark_matter",
    "Street": "OpenStreetMap",
    "Satellite": "Esri.WorldImagery"
}
# UI label -> Plotly qualitative colour sequence used by the charts.
# Note that the "Earth" label maps to Plotly's "Safe" palette.
COLOR_SCHEMES = {
    "Default": px.colors.qualitative.Plotly,
    "Vivid": px.colors.qualitative.Vivid,
    "Pastel": px.colors.qualitative.Pastel,
    "Bold": px.colors.qualitative.Bold,
    "Earth": px.colors.qualitative.Safe
}
# UI label -> short colour-ramp name; create_interactive_map resolves these
# names to branca colormaps. Note that "Oranges" maps to the "OrRd" ramp.
CHOROPLETH_COLORS = {
    "Yellow-Orange-Red": "YlOrRd",
    "Yellow-Green-Blue": "YlGnBu",
    "Purple-Red": "PuRd",
    "Blue-Purple": "BuPu",
    "Greens": "Greens",
    "Blues": "Blues",
    "Oranges": "OrRd",
    "Spectral": "Spectral"
}
# Display name -> data column name for the supported indicators.
# NOTE(review): this mapping is not referenced in the portion of the file
# reviewed here - confirm whether it is still needed.
INDICATORS = {
    "Population": "pop_est",
    "GDP (Million $)": "gdp_md_est",
    "Population Density": "pop_density",
    "GDP per Capita": "gdp_per_capita"
}
# Module-level memo holding the trimmed world dataset after the first load
_world_data_cache = None


def load_world_data():
    """Return the world countries GeoDataFrame, reading the shapefile only once.

    The shapefile's upper-case attribute columns are reduced to the ones the
    app uses and renamed to lower-case names (``GDP_MD`` becomes
    ``gdp_md_est``). The same cached object is returned on every call.
    """
    global _world_data_cache
    if _world_data_cache is not None:
        return _world_data_cache

    # Source column -> name used throughout the rest of the app
    column_renames = {
        'NAME': 'name',
        'CONTINENT': 'continent',
        'POP_EST': 'pop_est',
        'GDP_MD': 'gdp_md_est',
    }
    shapefile = gpd.read_file(NATURAL_EARTH_SHP)
    wanted_columns = [*column_renames, 'geometry']
    _world_data_cache = shapefile[wanted_columns].rename(columns=column_renames)
    return _world_data_cache
def _extract_json_object(text):
    """Decode the first JSON object found in *text*, ignoring surrounding text.

    Chat models often wrap the JSON in Markdown code fences or add a sentence
    before/after it; decoding from the first ``{`` tolerates both.

    Raises:
        ValueError: If *text* contains no ``{`` or the object is malformed
            (``json.JSONDecodeError`` is a ``ValueError`` subclass).
    """
    text = text or ""
    start = text.find('{')
    if start == -1:
        raise ValueError("No JSON object found in LLM response")
    # raw_decode stops at the end of the first complete JSON value, so any
    # trailing prose or closing code fence is ignored.
    obj, _ = json.JSONDecoder().raw_decode(text, start)
    return obj


def parse_query_with_llm(user_query):
    """
    Use LLM to parse natural language query into structured format.

    Args:
        user_query: Free-text question typed by the user.

    Returns:
        dict: The JSON object produced by the model (keys as requested in the
        system prompt). If the request fails or the reply contains no
        decodable JSON object, a fallback dict describing a global table
        query is returned instead; this function does not raise.
    """
    system_prompt = """You are a geospatial and geographic data query parser. Extract structured information from user queries.
Response format (JSON only):
{
"locations": ["country/region names"],
"indicators": ["GDP", "population", "CO2 emissions", etc.],
"time_range": {"start": "YYYY", "end": "YYYY"},
"visualization": "map/chart/table",
"aggregation": "sum/average/comparison",
"query_type": "single_country/multi_country/regional/global"
}
Examples:
- "Show me GDP of Asian countries" โ locations: Asia, indicators: GDP, visualization: chart
- "Compare population density in Europe vs Africa" โ locations: [Europe, Africa], indicators: population density
- "Environmental data for Brazil over last decade" โ locations: [Brazil], indicators: environmental
Return ONLY valid JSON, no explanations."""
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"Parse this query: {user_query}"}
    ]
    try:
        response = client.chat_completion(
            messages=messages,
            model="meta-llama/Llama-3.1-8B-Instruct",
            max_tokens=500,
            temperature=0.1
        )
        raw_content = response.choices[0].message.content
        print(f"LLM raw response: {raw_content}")
        # Tolerate code fences / extra prose around the JSON payload.
        parsed = _extract_json_object(raw_content)
        print(f"Parsed query: {parsed}")
        return parsed
    except Exception as e:
        # Boundary handler: any failure (network, auth, malformed reply)
        # degrades to a safe default so the UI keeps working.
        print(f"LLM parsing error: {e}")
        print(f"HF_TOKEN set: {bool(os.environ.get('HF_TOKEN'))}")
        return {
            "locations": [],
            "indicators": ["population", "gdp_md_est"],
            "visualization": "table",
            "query_type": "global"
        }
def fetch_geospatial_data(parsed_query):
    """
    Fetch and process geospatial data based on parsed query.

    Args:
        parsed_query: Dict from the query parser. Only its ``"locations"``
            entry is read here (a list of continent and/or country names, a
            single string, or missing/empty for a global query).

    Returns:
        A copy of the world GeoDataFrame, filtered to the requested
        continents/countries (matched case-insensitively), with two computed
        columns added: ``pop_density`` (people per square kilometre) and
        ``gdp_per_capita`` (dollars per person). The result may be empty when
        nothing matches.
    """
    world = load_world_data()

    # Normalise locations: accept a bare string, skip None/blank entries and
    # compare case-insensitively so "asia" or " Asia " still match "Asia".
    raw_locations = parsed_query.get("locations") or []
    if isinstance(raw_locations, str):
        raw_locations = [raw_locations]
    wanted = {str(loc).strip().casefold() for loc in raw_locations if loc is not None}
    wanted.discard('')

    # Treat "global", "world", "worldwide", "all" as requests for all data
    global_terms = {"global", "world", "worldwide", "all", "earth", "globe"}
    # An empty set is a subset of anything, so "no locations" is global too.
    is_global_query = wanted <= global_terms

    if not is_global_query:
        # Filter by continent or country
        mask = (
            world['continent'].str.casefold().isin(wanted)
            | world['name'].str.casefold().isin(wanted)
        )
        filtered_data = world[mask].copy()
    else:
        filtered_data = world.copy()

    # Areas must be measured in an equal-area projected CRS; .area on
    # lon/lat geometries yields squared degrees, which is not a usable
    # denominator for a per-km^2 density.
    geometries = filtered_data.geometry
    if geometries.crs is None:
        # Natural Earth data is distributed in WGS84 lon/lat.
        geometries = geometries.set_crs(epsg=4326)
    area_km2 = geometries.to_crs(epsg=6933).area / 1e6

    # Add computed indicators (division by zero becomes NaN rather than inf
    # so downstream min/max colour scaling stays finite).
    non_finite = [np.inf, -np.inf]
    filtered_data['pop_density'] = (
        filtered_data['pop_est'] / area_km2
    ).replace(non_finite, np.nan)
    # gdp_md_est is in millions of dollars, hence the 1e6 factor.
    filtered_data['gdp_per_capita'] = (
        filtered_data['gdp_md_est'] / filtered_data['pop_est'] * 1000000
    ).replace(non_finite, np.nan)
    return filtered_data
def create_interactive_map(gdf, indicators, map_style='Light', color_scale='Yellow-Orange-Red'):
    """
    Create an interactive Folium map with customizable style and colors.
    Supports multiple indicators with toggleable layer groups.

    One GeoJson layer (inside its own FeatureGroup) is built per indicator,
    a fixed-position HTML legend is added per indicator, and a marker with a
    popup is placed at each row's centroid.

    Args:
        gdf: GeoDataFrame with country data. The code reads the columns
            'name', 'continent', 'pop_est', 'gdp_md_est', the geometry, and
            every column named in ``indicators``.
        indicators: Single indicator string or list of indicator column names
        map_style: Map tile style (a key of MAP_STYLES; unknown keys fall
            back to 'CartoDB positron')
        color_scale: Color scheme for choropleth (a key of CHOROPLETH_COLORS;
            unknown keys fall back to 'YlOrRd'). Applies to the first
            indicator only.

    Returns:
        The populated ``folium.Map`` instance.
    """
    # Ensure indicators is a list
    if isinstance(indicators, str):
        indicators = [indicators]
    # Calculate center (mean of the per-feature centroids)
    center_lat = gdf.geometry.centroid.y.mean()
    center_lon = gdf.geometry.centroid.x.mean()
    # Get tile style
    tiles = MAP_STYLES.get(map_style, 'CartoDB positron')
    # Create map
    m = folium.Map(
        location=[center_lat, center_lon],
        zoom_start=2,
        tiles=tiles
    )
    # Color schemes for different indicators (these match branca colormap names exactly)
    color_scheme_list = ['YlOrRd', 'YlGnBu', 'Greens', 'Blues', 'PuRd', 'OrRd']
    # Map color scheme names to branca colormap objects
    color_map_dict = {
        'YlOrRd': cm.linear.YlOrRd_09,
        'YlGnBu': cm.linear.YlGnBu_09,
        'Blues': cm.linear.Blues_09,
        'Greens': cm.linear.Greens_09,
        'Reds': cm.linear.Reds_09,
        'PuRd': cm.linear.PuRd_09,
        'OrRd': cm.linear.OrRd_09,
        'BuPu': cm.linear.BuPu_09,
        'Spectral': cm.linear.Spectral_11
    }
    # Get user-selected color scheme for first indicator
    fill_color = CHOROPLETH_COLORS.get(color_scale, 'YlOrRd')
    # Store legends to add after all layers (for proper stacking)
    legends_html = []
    # Create a layer group for each indicator
    for idx, indicator in enumerate(indicators):
        # Use user-selected color for first indicator, cycle through others for additional indicators
        if idx == 0:
            current_color_name = fill_color
        else:
            current_color_name = color_scheme_list[idx % len(color_scheme_list)]
        # Get min/max for this indicator
        valid_values = gdf[indicator].dropna()
        if valid_values.empty:
            vmin, vmax = 0, 1
        else:
            vmin = float(valid_values.min())
            vmax = float(valid_values.max())
        # Ensure valid range: non-finite bounds reset to 0..1, and a
        # degenerate range (all values equal) is widened by 1.
        if np.isnan(vmin) or np.isnan(vmax) or np.isinf(vmin) or np.isinf(vmax):
            vmin, vmax = 0, 1
        if vmax <= vmin:
            vmax = vmin + 1
        # Create feature group for this indicator (only the first is shown initially)
        indicator_name = indicator.replace('_', ' ').title()
        feature_group = folium.FeatureGroup(name=f"๐ {indicator_name}", show=(idx == 0))
        # Get the base branca colormap and scale it to our data range
        base_colormap = color_map_dict.get(current_color_name, cm.linear.YlOrRd_09)
        scaled_colormap = base_colormap.scale(vmin, vmax)
        scaled_colormap.caption = f"{indicator_name} ({format_number(vmin)} - {format_number(vmax)})"
        # Create style function using the SAME scaled colormap for exact color matching.
        # A factory function is used so each layer's style function binds the
        # current loop values rather than the variables' final values.
        def make_style_function(cmap, ind, v_min, v_max):
            def style_function(feature):
                value = feature['properties'].get(ind)
                # Missing values are drawn in a neutral gray
                if value is None or (isinstance(value, float) and np.isnan(value)):
                    return {
                        'fillColor': 'lightgray',
                        'fillOpacity': 0.7,
                        'color': 'white',
                        'weight': 0.5
                    }
                # Clamp value to range
                clamped_value = max(v_min, min(v_max, value))
                return {
                    'fillColor': cmap(clamped_value),
                    'fillOpacity': 0.7,
                    'color': 'white',
                    'weight': 0.5
                }
            return style_function
        style_fn = make_style_function(scaled_colormap, indicator, vmin, vmax)
        # Convert GeoDataFrame to GeoJSON with indicator values in properties
        geojson_data = json.loads(gdf.to_json())
        # Add the indicator value to each feature's properties
        # (relies on features being emitted in the same order as gdf rows;
        # missing values are stored as None)
        for i, feature in enumerate(geojson_data['features']):
            ind_value = gdf.iloc[i][indicator]
            feature['properties'][indicator] = float(ind_value) if pd.notna(ind_value) else None
        # Use GeoJson instead of Choropleth for precise color control
        geojson_layer = folium.GeoJson(
            geojson_data,
            style_function=style_fn,
            tooltip=folium.GeoJsonTooltip(
                fields=['name', indicator],
                aliases=['Country:', f'{indicator_name}:'],
                style="background-color: white; color: #333333; font-family: arial; font-size: 12px; padding: 10px;"
            )
        )
        geojson_layer.add_to(feature_group)
        feature_group.add_to(m)
        # Build legend HTML for this indicator with vertical offset.
        # NOTE(review): every legend is added to the page unconditionally
        # below; nothing in this function hides or shows a legend when its
        # layer is toggled.
        legend_bottom = 50 + (idx * 80)  # Stack legends vertically
        legend_html = f'''
<div id="legend-{idx}" style="
position: fixed;
bottom: {legend_bottom}px;
right: 10px;
z-index: 1000;
background-color: white;
padding: 8px 12px;
border-radius: 5px;
box-shadow: 0 2px 6px rgba(0,0,0,0.3);
font-size: 11px;
font-family: Arial, sans-serif;
max-width: 200px;
">
<div style="font-weight: bold; margin-bottom: 5px;">{indicator_name}</div>
<div style="display: flex; align-items: center;">
<div style="
width: 150px;
height: 12px;
background: linear-gradient(to right, {scaled_colormap(vmin)}, {scaled_colormap((vmin+vmax)/2)}, {scaled_colormap(vmax)});
border-radius: 2px;
"></div>
</div>
<div style="display: flex; justify-content: space-between; margin-top: 3px;">
<span>{format_number(vmin)}</span>
<span>{format_number(vmax)}</span>
</div>
</div>
'''
        legends_html.append(legend_html)
    # Add all legends to the map
    for legend_html in legends_html:
        m.get_root().html.add_child(folium.Element(legend_html))
    # Add one marker per row, with a popup listing its values, in a toggleable layer group
    markers_group = folium.FeatureGroup(name="๐ Info Markers", show=True)
    for _, row in gdf.iterrows():
        popup_html = f"<b>{row['name']}</b><br>"
        popup_html += f"Continent: {row['continent']}<br><hr>"
        popup_html += f"Population: {format_number(row['pop_est'])}<br>"
        popup_html += f"GDP: ${format_number(row['gdp_md_est'])}M<br>"
        # The computed columns are optional; include them only when present
        if 'pop_density' in row:
            popup_html += f"Pop Density: {row['pop_density']:.1f}/kmยฒ<br>"
        if 'gdp_per_capita' in row:
            popup_html += f"GDP/Capita: ${row['gdp_per_capita']:,.0f}<br>"
        folium.Marker(
            location=[row.geometry.centroid.y, row.geometry.centroid.x],
            popup=popup_html,
            icon=folium.Icon(icon='info-sign', color='blue')
        ).add_to(markers_group)
    markers_group.add_to(m)
    # Add layer control
    folium.LayerControl(collapsed=False).add_to(m)
    return m
def _pretty_label(column):
    """Turn a column name such as ``gdp_per_capita`` into ``Gdp Per Capita``."""
    return column.replace('_', ' ').title()


def create_chart(df, indicators, chart_type='bar', color_scheme='Default', top_n=20, use_country_colors=False):
    """
    Create interactive Plotly charts with customizable options.
    Supports multiple indicators with grouped/stacked visualizations.

    Args:
        df: DataFrame with at least 'name', 'continent', 'pop_est' and every
            column named in ``indicators``.
        indicators: Indicator column name, or a list of them. The first one
            drives sorting and single-indicator charts. When empty, bar/pie/
            treemap charts fall back to 'pop_est' and scatter/bubble charts
            plot 'gdp_md_est' against 'pop_est'.
        chart_type: One of 'bar', 'horizontal_bar', 'scatter', 'pie',
            'treemap', 'bubble'. Anything else produces a plain bar chart.
        color_scheme: Key of COLOR_SCHEMES (unknown keys use Plotly's default
            qualitative palette).
        top_n: Number of top rows (by the first indicator) shown in bar, pie
            and treemap charts. Scatter and bubble charts plot all rows.
        use_country_colors: Color by country name instead of continent.

    Returns:
        A Plotly figure.
    """
    # Accept a bare string like create_interactive_map does, and tolerate an
    # empty/None indicator list instead of failing on indicators[0].
    if isinstance(indicators, str):
        indicators = [indicators]
    indicators = list(indicators or [])
    primary = indicators[0] if indicators else 'pop_est'
    primary_label = _pretty_label(primary)
    all_labels = ", ".join(_pretty_label(ind) for ind in indicators)

    # Get color sequence
    colors = COLOR_SCHEMES.get(color_scheme, px.colors.qualitative.Plotly)
    # Sort and limit data
    sorted_df = df.sort_values(primary, ascending=False).head(top_n)
    # Determine color column: use country name for specific country queries, continent for regional
    color_col = 'name' if use_country_colors else 'continent'
    # Check if we have multiple indicators for grouped visualization
    has_multiple_indicators = len(indicators) > 1

    if chart_type == 'bar':
        if has_multiple_indicators:
            # Create grouped bar chart for multiple indicators
            fig = go.Figure()
            for i, ind in enumerate(indicators):
                fig.add_trace(go.Bar(
                    name=_pretty_label(ind),
                    x=sorted_df['name'],
                    y=sorted_df[ind],
                    marker_color=colors[i % len(colors)]
                ))
            fig.update_layout(
                barmode='group',
                title=f'Comparison: {all_labels}',
                xaxis_title='Country',
                yaxis_title='Value',
                height=500
            )
        else:
            fig = px.bar(
                sorted_df,
                x='name',
                y=primary,
                color=color_col,
                title=f'Top {top_n} Countries by {primary_label}',
                labels={'name': 'Country', primary: primary_label, color_col: ''},
                color_discrete_sequence=colors,
                height=500
            )
    elif chart_type == 'horizontal_bar':
        if has_multiple_indicators:
            # Create grouped horizontal bar chart
            fig = go.Figure()
            for i, ind in enumerate(indicators):
                fig.add_trace(go.Bar(
                    name=_pretty_label(ind),
                    y=sorted_df['name'],
                    x=sorted_df[ind],
                    orientation='h',
                    marker_color=colors[i % len(colors)]
                ))
            fig.update_layout(
                barmode='group',
                title=f'Comparison: {all_labels}',
                xaxis_title='Value',
                yaxis_title='Country',
                height=600
            )
        else:
            fig = px.bar(
                sorted_df,
                y='name',
                x=primary,
                color=color_col,
                title=f'Top {top_n} Countries by {primary_label}',
                labels={'name': 'Country', primary: primary_label, color_col: ''},
                color_discrete_sequence=colors,
                orientation='h',
                height=600
            )
        # Largest bar on top
        fig.update_layout(yaxis={'categoryorder': 'total ascending'})
    elif chart_type == 'scatter':
        x_col = indicators[0] if indicators else 'gdp_md_est'
        y_col = indicators[1] if has_multiple_indicators else 'pop_est'
        fig = px.scatter(
            df,
            x=x_col,
            y=y_col,
            size='pop_est',
            color=color_col,
            hover_name='name',
            title=f'{_pretty_label(x_col)} vs {_pretty_label(y_col)}',
            labels={
                x_col: _pretty_label(x_col),
                y_col: _pretty_label(y_col),
                color_col: ''
            },
            color_discrete_sequence=colors,
            height=500
        )
    elif chart_type == 'pie':
        if has_multiple_indicators:
            # Create subplots for multiple indicators
            from plotly.subplots import make_subplots
            fig = make_subplots(
                rows=1, cols=len(indicators),
                specs=[[{'type': 'pie'}] * len(indicators)],
                subplot_titles=[_pretty_label(ind) for ind in indicators]
            )
            for i, ind in enumerate(indicators):
                pie_trace = go.Pie(
                    values=sorted_df[ind],
                    labels=sorted_df['name'],
                    name=_pretty_label(ind),
                    textposition='inside',
                    textinfo='percent+label',
                    marker_colors=colors[:len(sorted_df)]
                )
                fig.add_trace(pie_trace, row=1, col=i+1)
            fig.update_layout(
                title=f'Top {top_n} Countries: {all_labels}',
                height=500,
                showlegend=False
            )
        else:
            fig = px.pie(
                sorted_df,
                values=primary,
                names='name',
                title=f'Top {top_n} Countries by {primary_label}',
                color_discrete_sequence=colors,
                height=500
            )
        fig.update_traces(textposition='inside', textinfo='percent+label')
    elif chart_type == 'treemap':
        if has_multiple_indicators:
            # Create subplots for multiple indicators
            from plotly.subplots import make_subplots
            fig = make_subplots(
                rows=1, cols=len(indicators),
                specs=[[{'type': 'treemap'}] * len(indicators)],
                subplot_titles=[_pretty_label(ind) for ind in indicators]
            )
            # Countries are leaves under their continent; continents are
            # root nodes (empty parent) carrying no value of their own.
            country_names = sorted_df['name'].tolist()
            country_parents = sorted_df['continent'].tolist()
            continent_names = sorted_df['continent'].unique().tolist()
            for i, ind in enumerate(indicators):
                # Build treemap data manually for subplot
                treemap_trace = go.Treemap(
                    labels=country_names + continent_names,
                    parents=country_parents + [''] * len(continent_names),
                    values=sorted_df[ind].tolist() + [0] * len(continent_names),
                    name=_pretty_label(ind),
                    textinfo='label+value',
                    hovertemplate='<b>%{label}</b><br>Value: %{value:,.0f}<extra></extra>'
                )
                fig.add_trace(treemap_trace, row=1, col=i+1)
            fig.update_layout(
                title=f'Top {top_n} Countries: {all_labels}',
                height=600
            )
        else:
            fig = px.treemap(
                sorted_df,
                path=['continent', 'name'],
                values=primary,
                title=f'Top {top_n} Countries by {primary_label}',
                color='continent',
                color_discrete_sequence=colors,
                height=600,
                hover_data={primary: ':,.0f'}
            )
    elif chart_type == 'bubble':
        x_col = indicators[0] if indicators else 'gdp_md_est'
        y_col = indicators[1] if has_multiple_indicators else 'pop_est'
        # Bubble size follows the x-axis indicator
        size_col = x_col
        fig = px.scatter(
            df,
            x=x_col,
            y=y_col,
            size=size_col,
            color=color_col,
            hover_name='name',
            title=f'Bubble Chart: {_pretty_label(x_col)} vs {_pretty_label(y_col)}',
            labels={x_col: _pretty_label(x_col), y_col: _pretty_label(y_col), color_col: ''},
            color_discrete_sequence=colors,
            size_max=60,
            height=500
        )
    else:  # default bar
        fig = px.bar(
            sorted_df,
            x='name',
            y=primary,
            color=color_col,
            title=f'Top {top_n} Countries by {primary_label}',
            color_discrete_sequence=colors,
            height=500
        )
    # Remove legend title for cleaner appearance
    fig.update_layout(
        xaxis_tickangle=-45,
        template='plotly_white',
        legend_title_text=''
    )
    return fig
def create_data_table(df):
    """
    Create formatted data table.

    Args:
        df: DataFrame containing the columns 'name', 'continent', 'pop_est',
            'gdp_md_est', 'pop_density' and 'gdp_per_capita'.

    Returns:
        A new DataFrame with the 50 most populous rows, largest first, with
        display column names and all numeric columns rendered as strings.
    """
    # Select relevant columns
    display_cols = ['name', 'continent', 'pop_est', 'gdp_md_est', 'pop_density', 'gdp_per_capita']
    # Rank on the numeric population BEFORE formatting: once the values are
    # strings such as '99,000' a sort would be lexicographic, not numeric.
    table_df = df[display_cols].sort_values('pop_est', ascending=False).head(50).copy()
    # Rename columns
    table_df.columns = ['Country', 'Continent', 'Population', 'GDP (Million $)',
                        'Pop. Density (per kmยฒ)', 'GDP per Capita ($)']
    # Format numbers
    table_df['Population'] = table_df['Population'].apply(lambda x: f'{x:,.0f}')
    table_df['GDP (Million $)'] = table_df['GDP (Million $)'].apply(lambda x: f'${x:,.0f}')
    table_df['Pop. Density (per kmยฒ)'] = table_df['Pop. Density (per kmยฒ)'].apply(lambda x: f'{x:.2f}')
    table_df['GDP per Capita ($)'] = table_df['GDP per Capita ($)'].apply(lambda x: f'${x:,.2f}')
    return table_df
def process_query(user_query, output_format, chart_type, map_style, color_scheme, choropleth_color, top_n):
    """
    Main processing function with advanced options.

    Parses the query with the LLM, fetches the matching data and builds the
    requested outputs.

    Args:
        user_query: Free-text question from the user.
        output_format: 'All', 'Map', 'Chart' or 'Table'.
        chart_type: Chart type forwarded to create_chart.
        map_style: Tile style forwarded to create_interactive_map.
        color_scheme: Chart palette name forwarded to create_chart.
        choropleth_color: Map colour ramp name forwarded to create_interactive_map.
        top_n: Number of top countries to visualise (converted with ``int``).

    Returns:
        A 6-tuple ``(map_html, chart_fig, table_df, summary, map_file,
        csv_file)``. Outputs that were not requested are ``None``. On any
        error the first three and last two items are ``None`` and the fourth
        carries the error message; this function does not raise.
    """
    try:
        # Parse query with LLM
        parsed = parse_query_with_llm(user_query)
        if not isinstance(parsed, dict):
            # A non-object reply carries no usable fields; treat as global.
            parsed = {}
        # Fetch data
        gdf = fetch_geospatial_data(parsed)
        if gdf.empty:
            return None, None, None, "No data found for your query. Try different locations or indicators.", None, None
        # Map LLM-parsed indicators to actual column names
        indicator_mapping = {
            'gdp': 'gdp_md_est',
            'gdp per capita': 'gdp_per_capita',
            'population': 'pop_est',
            'population density': 'pop_density',
            'pop_density': 'pop_density',
            'gdp_md_est': 'gdp_md_est',
            'pop_est': 'pop_est',
            'gdp_per_capita': 'gdp_per_capita',
            # Additional mappings for common LLM responses
            'economic': 'gdp_md_est',
            'economic indicators': 'gdp_md_est',
            'economy': 'gdp_md_est',
            'economics': 'gdp_md_est',
            'gross domestic product': 'gdp_md_est',
            'density': 'pop_density',
            'people': 'pop_est',
            'inhabitants': 'pop_est'
        }
        # Get indicators from LLM parsing - support multiple indicators.
        # The model may return a bare string, null, or non-string items;
        # normalise to a list of strings before lower-casing/joining.
        raw_indicators = parsed.get('indicators') or []
        if isinstance(raw_indicators, str):
            raw_indicators = [raw_indicators]
        llm_indicators = [str(ind) for ind in raw_indicators if ind is not None]
        mapped_indicators = []
        for ind in llm_indicators:
            col = indicator_mapping.get(ind.lower().strip())
            # Avoid duplicates while preserving the model's order
            if col is not None and col not in mapped_indicators:
                mapped_indicators.append(col)
        # Default to population if no valid indicator found
        if not mapped_indicators:
            mapped_indicators = ['pop_est']
        # Get display names for all indicators
        indicator_display = ', '.join(llm_indicators) if llm_indicators else 'Population'
        # Detect if this is a specific country query (not continent/regional)
        # Use country colors in charts when querying specific countries
        raw_locations = parsed.get('locations') or []
        if isinstance(raw_locations, str):
            raw_locations = [raw_locations]
        locations = [str(loc) for loc in raw_locations if loc is not None]
        continents = {'africa', 'asia', 'europe', 'north america', 'south america', 'oceania', 'antarctica'}
        use_country_colors = bool(locations) and not any(
            loc.strip().casefold() in continents for loc in locations
        )
        # Generate outputs based on format
        map_html = None
        chart_fig = None
        table_df = None
        map_file = None
        csv_file = None
        # An empty location list means a global query; show that explicitly
        # rather than an empty string.
        locations_display = ', '.join(locations) if locations else 'Global'
        summary = f"๐ **Query:** {user_query}\n\n"
        summary += f"๐ **Locations:** {locations_display}\n"
        summary += f"๐ **Indicators:** {indicator_display}\n"
        summary += f"๐ **Countries found:** {len(gdf)}\n\n"
        summary += f"โ๏ธ **Options:** Chart: {chart_type} | Map: {map_style} | Top N: {top_n}"
        # Apply top_n limit for visualizations (sort by primary indicator)
        primary_indicator = mapped_indicators[0]
        gdf_sorted = gdf.sort_values(primary_indicator, ascending=False)
        gdf_top_n = gdf_sorted.head(int(top_n))
        if output_format in ['All', 'Map']:
            m = create_interactive_map(gdf_top_n, mapped_indicators, map_style, choropleth_color)
            map_html = m._repr_html_()
            # Save map to temp file for download. The handle is only used to
            # reserve a name, so close it before writing to avoid leaking a
            # file descriptor per request.
            with tempfile.NamedTemporaryFile(delete=False, suffix='.html', mode='w', encoding='utf-8') as tmp:
                map_file = tmp.name
            m.save(map_file)
        if output_format in ['All', 'Chart']:
            # create_chart applies the top-N cut itself, so pass the full data
            chart_fig = create_chart(gdf, mapped_indicators, chart_type, color_scheme, int(top_n), use_country_colors)
        if output_format in ['All', 'Table']:
            table_df = create_data_table(gdf)
            # Save table to temp CSV file for download (same close-first
            # pattern as the map file above)
            with tempfile.NamedTemporaryFile(delete=False, suffix='.csv', mode='w', encoding='utf-8') as tmp:
                csv_file = tmp.name
            table_df.to_csv(csv_file, index=False)
        return map_html, chart_fig, table_df, summary, map_file, csv_file
    except Exception as e:
        # UI boundary: surface the failure in the summary box instead of
        # crashing the request.
        error_msg = f"Error processing query: {str(e)}\n\nPlease try rephrasing your query."
        return None, None, None, error_msg, None, None
# Gradio Interface
def create_interface():
    """Build the Gradio Blocks UI and wire the Analyze button to process_query.

    Layout: a header, the query textbox with an output-format selector, a
    collapsed accordion of advanced options, the summary box, three result
    tabs (chart, map, table), example queries and a footer.

    Returns:
        The constructed ``gr.Blocks`` app (not yet launched).
    """
    with gr.Blocks(title="Geospatial AI Query System") as demo:
        gr.Markdown("""
# ๐ Geospatial AI Query System
### Natural Language Interface for Geographic Data
Ask questions about countries, regions, and global indicators using natural language.
โน๏ธ *Note: Current data includes latest available statistics. Time-range queries will filter when historical data is available.*
""")
        with gr.Row():
            with gr.Column(scale=3):
                query_input = gr.Textbox(
                    label="Your Query",
                    placeholder="E.g., Show me GDP and population of BRICS countries",
                    lines=2
                )
            with gr.Column(scale=1):
                output_format = gr.Radio(
                    choices=['All', 'Map', 'Chart', 'Table'],
                    value='All',
                    label="Output Format"
                )
        # Advanced Options in Accordion (collapsed by default)
        with gr.Accordion("โ๏ธ Advanced Options", open=False):
            with gr.Row():
                # Choices match the chart_type values handled by create_chart
                chart_type = gr.Dropdown(
                    choices=['bar', 'horizontal_bar', 'scatter', 'pie', 'treemap', 'bubble'],
                    value='bar',
                    label="๐ Chart Type"
                )
                map_style = gr.Dropdown(
                    choices=list(MAP_STYLES.keys()),
                    value='Light',
                    label="๐บ๏ธ Map Style"
                )
            with gr.Row():
                color_scheme = gr.Dropdown(
                    choices=list(COLOR_SCHEMES.keys()),
                    value='Default',
                    label="๐จ Chart Colors"
                )
                choropleth_color = gr.Dropdown(
                    choices=list(CHOROPLETH_COLORS.keys()),
                    value='Yellow-Orange-Red',
                    label="๐ Map Colors"
                )
            with gr.Row():
                top_n = gr.Slider(
                    minimum=5,
                    maximum=50,
                    value=20,
                    step=5,
                    label="๐ข Top N Countries"
                )
        submit_btn = gr.Button("๐ Analyze", variant="primary", size="lg")
        gr.Markdown("### Results")
        summary_output = gr.Textbox(label="Query Summary", lines=4)
        with gr.Tabs():
            with gr.Tab("๐ Chart"):
                chart_output = gr.Plot(label="Interactive Chart")
            with gr.Tab("๐บ๏ธ Map"):
                map_output = gr.HTML(label="Interactive Map")
                map_download = gr.File(label="๐ฅ Download Map (HTML)", visible=True)
            with gr.Tab("๐ Table"):
                table_output = gr.Dataframe(label="Data Table")
                csv_download = gr.File(label="๐ฅ Download Table (CSV)", visible=True)
        # Examples (each fills the query box and the output-format selector)
        gr.Examples(
            examples=[
                ["Show me population of Asian countries", "All"],
                ["Compare GDP of European nations", "Chart"],
                ["What's the population density in Africa?", "Map"],
                ["Display GDP per capita for South American countries", "Table"],
                ["Show me the top 10 economies in the world", "Chart"],
                ["Compare BRICS countries by GDP", "All"]
            ],
            inputs=[query_input, output_format]
        )
        # Event handler: the order of `outputs` must match the 6-tuple
        # returned by process_query
        # (map_html, chart_fig, table_df, summary, map_file, csv_file).
        submit_btn.click(
            fn=process_query,
            inputs=[query_input, output_format, chart_type, map_style, color_scheme, choropleth_color, top_n],
            outputs=[map_output, chart_output, table_output, summary_output, map_download, csv_download]
        )
        gr.Markdown("""
---
**About:** This app uses LLMs to parse natural language queries and visualize global geospatial data.
**Data Sources:** Natural Earth, World Bank Open Data
**Built by:** [rifatSDAS](https://github.com/rifatSDAS)
""")
    return demo
# Script entry point: build the UI, enable request queuing, then start the server.
if __name__ == "__main__":
    demo = create_interface()
    # Enable queue for better concurrency handling on HF Spaces
    demo.queue(default_concurrency_limit=10)
    demo.launch(theme=gr.themes.Soft())
    # To enable Progressive Web App (PWA) features, uncomment the line below
    # demo.launch(theme=gr.themes.Soft(), pwa=True)