rifatSDAS's picture
Fix map legend and global query issues
d5a8545
import gradio as gr
import pandas as pd
import geopandas as gpd
import folium
from folium import plugins
import plotly.express as px
import plotly.graph_objects as go
from huggingface_hub import InferenceClient
import json
import os
import tempfile
import io
from datetime import datetime
import numpy as np
from pathlib import Path
import warnings
import logging
# Keep the console readable by silencing warnings that are expected here:
# GeoPandas flags area/centroid maths done in a geographic CRS (this demo
# accepts approximate figures), and Python 3.13+ may report asyncio cleanup
# problems with invalid file descriptors at shutdown, which are harmless.
warnings.filterwarnings(action='ignore', message='.*Geometry is in a geographic CRS.*')
warnings.filterwarnings(action='ignore', message='.*Invalid file descriptor.*')
logging.getLogger(name='asyncio').setLevel(level=logging.CRITICAL)
import branca.colormap as cm
import zipfile
import urllib.request
def download_natural_earth_data(
    data_dir: Path,
    url: str = "https://naciscdn.org/naturalearth/110m/cultural/ne_110m_admin_0_countries.zip",
    timeout: float = 60.0,
) -> Path:
    """Return the Natural Earth 110m countries shapefile, downloading it if absent.

    If ``data_dir`` already contains ``ne_110m_admin_0_countries.shp`` it is
    returned immediately. Otherwise the zip archive at ``url`` is downloaded
    into ``data_dir``, extracted there, and the archive is deleted again
    (also when the download or extraction fails).

    Args:
        data_dir: Directory that holds (or will hold) the shapefile; created if missing.
        url: Location of the zipped shapefile. Any URL ``urllib`` can open works,
            including ``file://`` URLs pointing at a local copy.
        timeout: Socket timeout in seconds for the download, so a stalled
            server cannot hang start-up indefinitely.

    Returns:
        Path to the ``.shp`` file.

    Raises:
        urllib.error.URLError: if the download fails.
        zipfile.BadZipFile: if the downloaded file is not a valid zip archive.
        FileNotFoundError: if the archive did not contain the expected shapefile.
    """
    import shutil  # only needed on the download path

    shp_file = data_dir / "ne_110m_admin_0_countries.shp"
    if shp_file.exists():
        return shp_file

    data_dir.mkdir(parents=True, exist_ok=True)
    zip_path = data_dir / "ne_110m_admin_0_countries.zip"
    print(f"Downloading Natural Earth data from {url}...")
    try:
        with urllib.request.urlopen(url, timeout=timeout) as response, open(zip_path, 'wb') as out:
            shutil.copyfileobj(response, out)
        print(f"Extracting to {data_dir}...")
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(data_dir)
    finally:
        # Remove the archive even on failure so a partial download never lingers.
        zip_path.unlink(missing_ok=True)

    if not shp_file.exists():
        raise FileNotFoundError(
            f"Archive downloaded from {url} did not contain {shp_file.name}"
        )
    print("Natural Earth data ready.")
    return shp_file
def format_number(num):
    """Format a number compactly with a K/M/B/T suffix and one decimal place.

    Examples: 1500 -> '1.5K', 2_500_000 -> '2.5M', 12.345 -> '12.3'.
    Values that would round up to 1000 of one unit are promoted to the next
    unit (999_960 -> '1.0M', not '1000.0K'). Values of a quadrillion or more
    stay in trillions (5e15 -> '5000.0T').

    Returns 'N/A' for None, for values that cannot be converted to float, and
    for non-finite values (NaN or +/-inf, including NumPy scalars).
    """
    if num is None:
        return 'N/A'
    try:
        value = float(num)
    except (TypeError, ValueError):
        return 'N/A'
    if not np.isfinite(value):
        return 'N/A'

    suffixes = ('', 'K', 'M', 'B', 'T')
    power = 0
    # Step up a unit while the value, as it would be displayed, reaches 1000.
    while power < len(suffixes) - 1 and abs(round(value / 1000 ** power, 1)) >= 1000:
        power += 1
    return f'{value / 1000 ** power:.1f}{suffixes[power]}'
# Natural Earth 110m countries live under ./data next to this file and are
# downloaded on first start when missing.
DATA_DIR = Path(__file__).parent.joinpath("data", "ne_110m_admin_0_countries")
NATURAL_EARTH_SHP = download_natural_earth_data(DATA_DIR)

# Hugging Face Inference client; HF_TOKEN may be unset, in which case None is passed.
client = InferenceClient(token=os.getenv("HF_TOKEN"))
# ===== UI/UX Enhancement Constants =====
# Map style label shown in the UI -> folium tile provider name.
MAP_STYLES = dict(
    Light="CartoDB positron",
    Dark="CartoDB dark_matter",
    Street="OpenStreetMap",
    Satellite="Esri.WorldImagery",
)
# Chart colour-scheme label shown in the UI -> Plotly qualitative colour sequence.
COLOR_SCHEMES = dict(
    Default=px.colors.qualitative.Plotly,
    Vivid=px.colors.qualitative.Vivid,
    Pastel=px.colors.qualitative.Pastel,
    Bold=px.colors.qualitative.Bold,
    Earth=px.colors.qualitative.Safe,
)
# Map colour label shown in the UI -> ColorBrewer ramp name looked up by the
# map builder. Note that "Oranges" maps to the OrRd (orange-red) ramp.
CHOROPLETH_COLORS = dict([
    ("Yellow-Orange-Red", "YlOrRd"),
    ("Yellow-Green-Blue", "YlGnBu"),
    ("Purple-Red", "PuRd"),
    ("Blue-Purple", "BuPu"),
    ("Greens", "Greens"),
    ("Blues", "Blues"),
    ("Oranges", "OrRd"),
    ("Spectral", "Spectral"),
])
# Human-readable indicator label -> column name in the world GeoDataFrame.
INDICATORS = dict([
    ("Population", "pop_est"),
    ("GDP (Million $)", "gdp_md_est"),
    ("Population Density", "pop_density"),
    ("GDP per Capita", "gdp_per_capita"),
])
# Module-level cache: once the shapefile has been read successfully it is reused.
_world_data_cache = None


def load_world_data():
    """Return the world countries GeoDataFrame, reading it on first call.

    Only the NAME, CONTINENT, POP_EST and GDP_MD columns (plus geometry) are
    kept, renamed to name, continent, pop_est and gdp_md_est. Later calls
    return the same cached object, so callers should copy before modifying it.
    """
    global _world_data_cache
    if _world_data_cache is not None:
        return _world_data_cache

    renames = {
        'NAME': 'name',
        'CONTINENT': 'continent',
        'POP_EST': 'pop_est',
        'GDP_MD': 'gdp_md_est',
    }
    countries = gpd.read_file(NATURAL_EARTH_SHP)
    _world_data_cache = countries[[*renames, 'geometry']].copy().rename(columns=renames)
    return _world_data_cache
def _extract_json_object(text):
    """Parse the first JSON object embedded in ``text``.

    LLMs often wrap JSON in Markdown code fences or add a sentence around it;
    this skips everything before the first '{' and ignores whatever follows
    the end of that object.

    Raises:
        ValueError: if ``text`` is not a string, contains no '{', or the object
            is not valid JSON (json.JSONDecodeError subclasses ValueError).
    """
    if not isinstance(text, str):
        raise ValueError(f"LLM response content is not text: {text!r}")
    start = text.find('{')
    if start == -1:
        raise ValueError(f"No JSON object found in LLM response: {text!r}")
    obj, _end = json.JSONDecoder().raw_decode(text, start)
    return obj


def _as_string_list(value):
    """Coerce an LLM-provided field to a list of non-empty, stripped strings.

    A bare string becomes a one-item list; None or any other non-list value
    becomes an empty list.
    """
    if isinstance(value, str):
        value = [value]
    if not isinstance(value, (list, tuple)):
        return []
    items = (str(item).strip() for item in value if item is not None)
    return [item for item in items if item]


def parse_query_with_llm(user_query):
    """
    Use an LLM to parse a natural-language query into a structured dict.

    The dict follows the JSON format described in the system prompt below.
    When the LLM supplies "locations" or "indicators", those entries are
    normalised to lists of strings, because callers iterate over them.

    If the LLM call fails or its reply contains no parseable JSON object, a
    fallback dict is returned: no locations, population + GDP indicators,
    table visualization, global query.
    """
    system_prompt = """You are a geospatial and geographic data query parser. Extract structured information from user queries.
Response format (JSON only):
{
"locations": ["country/region names"],
"indicators": ["GDP", "population", "CO2 emissions", etc.],
"time_range": {"start": "YYYY", "end": "YYYY"},
"visualization": "map/chart/table",
"aggregation": "sum/average/comparison",
"query_type": "single_country/multi_country/regional/global"
}
Examples:
- "Show me GDP of Asian countries" → locations: Asia, indicators: GDP, visualization: chart
- "Compare population density in Europe vs Africa" → locations: [Europe, Africa], indicators: population density
- "Environmental data for Brazil over last decade" → locations: [Brazil], indicators: environmental
Return ONLY valid JSON, no explanations."""
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"Parse this query: {user_query}"}
    ]
    try:
        response = client.chat_completion(
            messages=messages,
            model="meta-llama/Llama-3.1-8B-Instruct",
            max_tokens=500,
            temperature=0.1
        )
        raw_content = response.choices[0].message.content
        print(f"LLM raw response: {raw_content}")
        # Tolerate code fences / surrounding prose instead of requiring bare JSON.
        parsed = _extract_json_object(raw_content)
        # A bare string here would otherwise be iterated character by character downstream.
        for key in ('locations', 'indicators'):
            if key in parsed:
                parsed[key] = _as_string_list(parsed[key])
        print(f"Parsed query: {parsed}")
        return parsed
    except Exception as e:
        # Best-effort boundary: any LLM/network/parse failure falls back to a global query.
        print(f"LLM parsing error: {e}")
        print(f"HF_TOKEN set: {bool(os.environ.get('HF_TOKEN'))}")
        return {
            "locations": [],
            "indicators": ["population", "gdp_md_est"],
            "visualization": "table",
            "query_type": "global"
        }
def fetch_geospatial_data(parsed_query):
    """
    Select the countries a parsed query asks for and add derived indicators.

    Args:
        parsed_query: dict from ``parse_query_with_llm``; only its "locations"
            entry is read. Each location is matched case-insensitively against
            both continent and country names. No locations, or only
            "global"-style terms ("world", "all", ...), selects every country.

    Returns:
        A copy of the world GeoDataFrame filtered to the requested locations,
        with two extra columns:
          - ``pop_density``: people per km², using areas measured in an
            equal-area projection (EPSG:6933) instead of in square degrees.
          - ``gdp_per_capita``: dollars per person (``gdp_md_est`` is in
            millions of dollars).
        Divisions by zero yield NaN instead of +/-inf.
    """
    world = load_world_data()

    raw_locations = parsed_query.get("locations") or []
    if isinstance(raw_locations, str):
        raw_locations = [raw_locations]
    wanted = {str(loc).strip().casefold() for loc in raw_locations if loc is not None}
    wanted.discard('')

    # Treat "global", "world", "worldwide", ... as a request for all countries.
    global_terms = {"global", "world", "worldwide", "all", "earth", "globe"}
    if wanted and not wanted <= global_terms:
        # Filter by continent or country, ignoring case ("asia" matches "Asia").
        mask = (
            world['continent'].str.casefold().isin(wanted)
            | world['name'].str.casefold().isin(wanted)
        )
        filtered_data = world[mask].copy()
    else:
        filtered_data = world.copy()

    # Area in a geographic CRS would be in square degrees; project to an
    # equal-area CRS so the density really is "per km²".
    geometry = filtered_data.geometry
    if geometry.crs is None:
        # Assumed WGS84 lon/lat, the CRS Natural Earth publishes in.
        geometry = geometry.set_crs(epsg=4326)
    area_km2 = geometry.to_crs(epsg=6933).area / 1e6

    infinities = [np.inf, -np.inf]
    filtered_data['pop_density'] = (filtered_data['pop_est'] / area_km2).replace(infinities, np.nan)
    filtered_data['gdp_per_capita'] = (
        filtered_data['gdp_md_est'] / filtered_data['pop_est'] * 1000000
    ).replace(infinities, np.nan)
    return filtered_data
def _make_style_function(cmap, indicator, v_min, v_max):
    """Return a folium GeoJson ``style_function`` colouring features by ``indicator``.

    A factory is used so each layer's callback is bound to that layer's own
    colormap, column and range rather than to the caller's last loop values.
    Missing values are drawn light grey; other values are clamped into
    [v_min, v_max] before being passed to ``cmap``.
    """
    def style_function(feature):
        value = feature['properties'].get(indicator)
        if value is None or (isinstance(value, float) and np.isnan(value)):
            fill = 'lightgray'
        else:
            fill = cmap(max(v_min, min(v_max, value)))
        return {
            'fillColor': fill,
            'fillOpacity': 0.7,
            'color': 'white',
            'weight': 0.5
        }
    return style_function


def _legend_html(idx, indicator_name, colormap, vmin, vmax):
    """Return HTML for one fixed-position gradient legend in the bottom-right corner.

    Legend ``idx`` sits 80px above legend ``idx - 1`` so several legends stack
    without overlapping. The gradient samples ``colormap`` at vmin, the
    midpoint and vmax; the end labels use format_number.
    """
    legend_bottom = 50 + (idx * 80)
    return f'''
    <div id="legend-{idx}" style="
        position: fixed;
        bottom: {legend_bottom}px;
        right: 10px;
        z-index: 1000;
        background-color: white;
        padding: 8px 12px;
        border-radius: 5px;
        box-shadow: 0 2px 6px rgba(0,0,0,0.3);
        font-size: 11px;
        font-family: Arial, sans-serif;
        max-width: 200px;
    ">
        <div style="font-weight: bold; margin-bottom: 5px;">{indicator_name}</div>
        <div style="display: flex; align-items: center;">
            <div style="
                width: 150px;
                height: 12px;
                background: linear-gradient(to right, {colormap(vmin)}, {colormap((vmin + vmax) / 2)}, {colormap(vmax)});
                border-radius: 2px;
            "></div>
        </div>
        <div style="display: flex; justify-content: space-between; margin-top: 3px;">
            <span>{format_number(vmin)}</span>
            <span>{format_number(vmax)}</span>
        </div>
    </div>
    '''


def _add_info_markers(m, gdf):
    """Add a toggleable "Info Markers" layer to ``m``: one marker per country centroid.

    Each popup lists continent, population and GDP, plus population density and
    GDP per capita when those columns are present.
    """
    markers_group = folium.FeatureGroup(name="📍 Info Markers", show=True)
    for _, row in gdf.iterrows():
        popup_html = f"<b>{row['name']}</b><br>"
        popup_html += f"Continent: {row['continent']}<br><hr>"
        popup_html += f"Population: {format_number(row['pop_est'])}<br>"
        popup_html += f"GDP: ${format_number(row['gdp_md_est'])}M<br>"
        if 'pop_density' in row:
            popup_html += f"Pop Density: {row['pop_density']:.1f}/km²<br>"
        if 'gdp_per_capita' in row:
            popup_html += f"GDP/Capita: ${row['gdp_per_capita']:,.0f}<br>"
        centroid = row.geometry.centroid
        folium.Marker(
            location=[centroid.y, centroid.x],
            popup=popup_html,
            icon=folium.Icon(icon='info-sign', color='blue')
        ).add_to(markers_group)
    markers_group.add_to(m)


def create_interactive_map(gdf, indicators, map_style='Light', color_scale='Yellow-Orange-Red'):
    """
    Create an interactive Folium map with customizable style and colors.

    Each indicator gets its own toggleable choropleth layer; only the first is
    shown initially. The first layer uses the user-selected ramp and further
    layers cycle through other ramps, never reusing the selected one. One
    gradient legend per indicator is stacked bottom-right; all legends stay
    visible regardless of which layers are toggled on.

    Args:
        gdf: GeoDataFrame with 'name', 'continent', 'pop_est', 'gdp_md_est',
            geometry and the indicator columns; expected to be non-empty (the
            map is centred on the mean centroid).
        indicators: Single indicator column name or list of column names.
        map_style: Key of MAP_STYLES (unknown keys use 'CartoDB positron').
        color_scale: Key of CHOROPLETH_COLORS for the first indicator.

    Returns:
        A ``folium.Map``.
    """
    if isinstance(indicators, str):
        indicators = [indicators]

    # Centre the view on the mean centroid of the displayed countries.
    centroids = gdf.geometry.centroid
    m = folium.Map(
        location=[centroids.y.mean(), centroids.x.mean()],
        zoom_start=2,
        tiles=MAP_STYLES.get(map_style, 'CartoDB positron')
    )

    # Ramps used for the second and later indicators (branca colormap names).
    color_scheme_list = ['YlOrRd', 'YlGnBu', 'Greens', 'Blues', 'PuRd', 'OrRd']
    color_map_dict = {
        'YlOrRd': cm.linear.YlOrRd_09,
        'YlGnBu': cm.linear.YlGnBu_09,
        'Blues': cm.linear.Blues_09,
        'Greens': cm.linear.Greens_09,
        'Reds': cm.linear.Reds_09,
        'PuRd': cm.linear.PuRd_09,
        'OrRd': cm.linear.OrRd_09,
        'BuPu': cm.linear.BuPu_09,
        'Spectral': cm.linear.Spectral_11
    }
    fill_color = CHOROPLETH_COLORS.get(color_scale, 'YlOrRd')
    # Exclude the user's ramp so no other layer looks identical to the first.
    secondary_schemes = [name for name in color_scheme_list if name != fill_color]

    # Serialise once; each layer parses its own copy so layers never share dicts.
    geojson_text = gdf.to_json()
    legends_html = []

    for idx, indicator in enumerate(indicators):
        if idx == 0:
            current_color_name = fill_color
        else:
            current_color_name = secondary_schemes[(idx - 1) % len(secondary_schemes)]

        # Range over finite values only, so +/-inf cannot blow up the scale.
        finite_values = gdf[indicator].replace([np.inf, -np.inf], np.nan).dropna()
        if finite_values.empty:
            vmin, vmax = 0, 1
        else:
            vmin, vmax = float(finite_values.min()), float(finite_values.max())
        if vmax <= vmin:
            vmax = vmin + 1

        indicator_name = indicator.replace('_', ' ').title()
        feature_group = folium.FeatureGroup(name=f"📊 {indicator_name}", show=(idx == 0))
        # The same scaled colormap drives both fills and legend, so colours match.
        base_colormap = color_map_dict.get(current_color_name, cm.linear.YlOrRd_09)
        scaled_colormap = base_colormap.scale(vmin, vmax)

        geojson_data = json.loads(geojson_text)
        # Store plain floats (or None) so the style callback and tooltip see clean values.
        for feature, ind_value in zip(geojson_data['features'], gdf[indicator]):
            feature['properties'][indicator] = float(ind_value) if pd.notna(ind_value) else None

        # GeoJson (rather than Choropleth) gives precise per-feature colour control.
        folium.GeoJson(
            geojson_data,
            style_function=_make_style_function(scaled_colormap, indicator, vmin, vmax),
            tooltip=folium.GeoJsonTooltip(
                fields=['name', indicator],
                aliases=['Country:', f'{indicator_name}:'],
                style="background-color: white; color: #333333; font-family: arial; font-size: 12px; padding: 10px;"
            )
        ).add_to(feature_group)
        feature_group.add_to(m)
        legends_html.append(_legend_html(idx, indicator_name, scaled_colormap, vmin, vmax))

    for legend_html in legends_html:
        m.get_root().html.add_child(folium.Element(legend_html))

    _add_info_markers(m, gdf)
    folium.LayerControl(collapsed=False).add_to(m)
    return m
def _pretty(column):
    """Return a display title for a column name, e.g. 'gdp_per_capita' -> 'Gdp Per Capita'."""
    return column.replace('_', ' ').title()


def create_chart(df, indicators, chart_type='bar', color_scheme='Default', top_n=20, use_country_colors=False):
    """
    Create an interactive Plotly chart with customizable options.

    Args:
        df: DataFrame with 'name', 'continent', 'pop_est' and the indicator columns.
        indicators: Non-empty list of indicator column names; the first ranks
            the countries. With several indicators, bar charts are grouped and
            pie/treemap charts get one subplot per indicator.
        chart_type: 'bar', 'horizontal_bar', 'scatter', 'pie', 'treemap' or
            'bubble'; anything else falls back to a plain bar chart.
        color_scheme: Key of COLOR_SCHEMES (unknown keys use Plotly's default).
        top_n: Number of top-ranked rows used by bar, pie and treemap charts.
            Scatter and bubble charts plot every row of ``df``.
        use_country_colors: Colour single-indicator charts by country name
            instead of by continent.

    Returns:
        A plotly ``Figure``.
    """
    colors = COLOR_SCHEMES.get(color_scheme, px.colors.qualitative.Plotly)
    primary = indicators[0]
    # Top-N rows by the primary indicator (NaNs sort last).
    sorted_df = df.sort_values(primary, ascending=False).head(top_n)
    color_col = 'name' if use_country_colors else 'continent'
    has_multiple_indicators = len(indicators) > 1
    joined_names = ", ".join(_pretty(ind) for ind in indicators)
    top_title = f'Top {top_n} Countries by {_pretty(primary)}'
    single_labels = {'name': 'Country', primary: _pretty(primary), color_col: ''}

    if chart_type == 'bar':
        if has_multiple_indicators:
            fig = go.Figure()
            for i, ind in enumerate(indicators):
                fig.add_trace(go.Bar(
                    name=_pretty(ind),
                    x=sorted_df['name'],
                    y=sorted_df[ind],
                    marker_color=colors[i % len(colors)]
                ))
            fig.update_layout(
                barmode='group',
                title=f'Comparison: {joined_names}',
                xaxis_title='Country',
                yaxis_title='Value',
                height=500
            )
        else:
            fig = px.bar(
                sorted_df,
                x='name',
                y=primary,
                color=color_col,
                title=top_title,
                labels=single_labels,
                color_discrete_sequence=colors,
                height=500
            )
    elif chart_type == 'horizontal_bar':
        if has_multiple_indicators:
            fig = go.Figure()
            for i, ind in enumerate(indicators):
                fig.add_trace(go.Bar(
                    name=_pretty(ind),
                    y=sorted_df['name'],
                    x=sorted_df[ind],
                    orientation='h',
                    marker_color=colors[i % len(colors)]
                ))
            fig.update_layout(
                barmode='group',
                title=f'Comparison: {joined_names}',
                xaxis_title='Value',
                yaxis_title='Country',
                height=600
            )
        else:
            fig = px.bar(
                sorted_df,
                y='name',
                x=primary,
                color=color_col,
                title=top_title,
                labels=single_labels,
                color_discrete_sequence=colors,
                orientation='h',
                height=600
            )
            fig.update_layout(yaxis={'categoryorder': 'total ascending'})
    elif chart_type == 'scatter':
        x_col = primary
        y_col = indicators[1] if has_multiple_indicators else 'pop_est'
        fig = px.scatter(
            df,
            x=x_col,
            y=y_col,
            size='pop_est',
            color=color_col,
            hover_name='name',
            title=f'{_pretty(x_col)} vs {_pretty(y_col)}',
            labels={x_col: _pretty(x_col), y_col: _pretty(y_col), color_col: ''},
            color_discrete_sequence=colors,
            height=500
        )
    elif chart_type == 'pie':
        if has_multiple_indicators:
            from plotly.subplots import make_subplots
            fig = make_subplots(
                rows=1, cols=len(indicators),
                specs=[[{'type': 'pie'}] * len(indicators)],
                subplot_titles=[_pretty(ind) for ind in indicators]
            )
            # Cycle the palette so every slice gets a scheme colour even when
            # there are more slices than palette entries.
            slice_colors = [colors[j % len(colors)] for j in range(len(sorted_df))]
            for i, ind in enumerate(indicators):
                fig.add_trace(go.Pie(
                    values=sorted_df[ind],
                    labels=sorted_df['name'],
                    name=_pretty(ind),
                    textposition='inside',
                    textinfo='percent+label',
                    marker_colors=slice_colors
                ), row=1, col=i + 1)
            fig.update_layout(
                title=f'Top {top_n} Countries: {joined_names}',
                height=500,
                showlegend=False
            )
        else:
            fig = px.pie(
                sorted_df,
                values=primary,
                names='name',
                title=top_title,
                color_discrete_sequence=colors,
                height=500
            )
            fig.update_traces(textposition='inside', textinfo='percent+label')
    elif chart_type == 'treemap':
        if has_multiple_indicators:
            from plotly.subplots import make_subplots
            fig = make_subplots(
                rows=1, cols=len(indicators),
                specs=[[{'type': 'treemap'}] * len(indicators)],
                subplot_titles=[_pretty(ind) for ind in indicators]
            )
            continents = sorted_df['continent'].unique().tolist()
            # Explicit ids keep a country whose name equals its continent's
            # (Natural Earth lists Antarctica as both) from colliding with, and
            # becoming the parent of, itself.
            country_ids = [f'{cont}/{name}' for cont, name in zip(sorted_df['continent'], sorted_df['name'])]
            for i, ind in enumerate(indicators):
                fig.add_trace(go.Treemap(
                    ids=country_ids + continents,
                    labels=sorted_df['name'].tolist() + continents,
                    parents=sorted_df['continent'].tolist() + [''] * len(continents),
                    values=sorted_df[ind].tolist() + [0] * len(continents),
                    name=_pretty(ind),
                    textinfo='label+value',
                    hovertemplate='<b>%{label}</b><br>Value: %{value:,.0f}<extra></extra>'
                ), row=1, col=i + 1)
            fig.update_layout(
                title=f'Top {top_n} Countries: {joined_names}',
                height=600
            )
        else:
            fig = px.treemap(
                sorted_df,
                path=['continent', 'name'],
                values=primary,
                title=top_title,
                color='continent',
                color_discrete_sequence=colors,
                height=600,
                hover_data={primary: ':,.0f'}
            )
    elif chart_type == 'bubble':
        x_col = primary
        y_col = indicators[1] if has_multiple_indicators else 'pop_est'
        fig = px.scatter(
            df,
            x=x_col,
            y=y_col,
            size=primary,
            color=color_col,
            hover_name='name',
            title=f'Bubble Chart: {_pretty(x_col)} vs {_pretty(y_col)}',
            labels={x_col: _pretty(x_col), y_col: _pretty(y_col), color_col: ''},
            color_discrete_sequence=colors,
            size_max=60,
            height=500
        )
    else:  # unknown chart type: plain bar chart
        fig = px.bar(
            sorted_df,
            x='name',
            y=primary,
            color=color_col,
            title=top_title,
            color_discrete_sequence=colors,
            height=500
        )

    # Angled x labels, white template, and no legend title for a cleaner look.
    fig.update_layout(
        xaxis_tickangle=-45,
        template='plotly_white',
        legend_title_text=''
    )
    return fig
def create_data_table(df):
    """
    Build the display table: the 50 most populous rows, with formatted numbers.

    Rows are ranked by the numeric ``pop_est`` column *before* values are
    turned into display strings; sorting the formatted strings would order
    them lexicographically (e.g. '9,000,000' above '1,400,000,000').

    Args:
        df: DataFrame with name, continent, pop_est, gdp_md_est, pop_density
            and gdp_per_capita columns (other columns are ignored).

    Returns:
        DataFrame with human-readable column names; numeric cells are strings
        and missing values are shown as 'N/A'.
    """
    display_cols = ['name', 'continent', 'pop_est', 'gdp_md_est', 'pop_density', 'gdp_per_capita']
    table_df = (
        df[display_cols]
        .sort_values('pop_est', ascending=False)
        .head(50)
        .copy()
    )
    table_df.columns = ['Country', 'Continent', 'Population', 'GDP (Million $)',
                        'Pop. Density (per km²)', 'GDP per Capita ($)']

    formats = {
        'Population': '{:,.0f}',
        'GDP (Million $)': '${:,.0f}',
        'Pop. Density (per km²)': '{:.2f}',
        'GDP per Capita ($)': '${:,.2f}',
    }
    for column, pattern in formats.items():
        table_df[column] = table_df[column].map(
            lambda x, p=pattern: p.format(x) if pd.notna(x) else 'N/A'
        )
    return table_df
def _new_temp_path(suffix):
    """Create an empty, persistent temporary file and return its path.

    The handle is closed before returning so the caller can reopen the file
    by name; per the tempfile docs, reopening a still-open NamedTemporaryFile
    by name does not work on Windows. The file is not deleted automatically.
    """
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        return tmp.name


def _as_list(value):
    """Return ``value`` as a list: None -> [], a bare string -> [string]."""
    if value is None:
        return []
    if isinstance(value, str):
        return [value]
    return list(value)


def process_query(user_query, output_format, chart_type, map_style, color_scheme, choropleth_color, top_n):
    """
    Run a natural-language query end to end and build the requested outputs.

    Args:
        user_query: Free-text question from the user.
        output_format: 'All', 'Map', 'Chart' or 'Table' - which outputs to build.
        chart_type, map_style, color_scheme, choropleth_color: Passed through
            to create_chart / create_interactive_map.
        top_n: Number of top countries (by the first indicator) on map and chart.

    Returns:
        Tuple ``(map_html, chart_fig, table_df, summary, map_file, csv_file)``.
        Outputs that were not requested are None. When nothing matches or an
        error occurs, only the message (position 3) is set.
    """
    try:
        parsed = parse_query_with_llm(user_query)
        gdf = fetch_geospatial_data(parsed)
        if gdf.empty:
            return None, None, None, "No data found for your query. Try different locations or indicators.", None, None

        # Map LLM-parsed indicator names to actual column names.
        indicator_mapping = {
            'gdp': 'gdp_md_est',
            'gdp per capita': 'gdp_per_capita',
            'population': 'pop_est',
            'population density': 'pop_density',
            'pop_density': 'pop_density',
            'gdp_md_est': 'gdp_md_est',
            'pop_est': 'pop_est',
            'gdp_per_capita': 'gdp_per_capita',
            # Additional mappings for common LLM responses
            'economic': 'gdp_md_est',
            'economic indicators': 'gdp_md_est',
            'economy': 'gdp_md_est',
            'economics': 'gdp_md_est',
            'gross domestic product': 'gdp_md_est',
            'density': 'pop_density',
            'people': 'pop_est',
            'inhabitants': 'pop_est'
        }

        requested = [str(ind) for ind in _as_list(parsed.get('indicators'))]
        mapped_indicators = []  # ordered, without duplicates
        unsupported = []
        for ind in requested:
            column = indicator_mapping.get(ind.lower().strip())
            if column is None:
                unsupported.append(ind)
            elif column not in mapped_indicators:
                mapped_indicators.append(column)
        if not mapped_indicators:
            mapped_indicators = ['pop_est']

        # Report what is actually plotted, plus any requested indicators we have no data for.
        column_labels = {column: label for label, column in INDICATORS.items()}
        indicator_display = ', '.join(column_labels.get(col, col) for col in mapped_indicators)
        if unsupported:
            indicator_display += f" (not available: {', '.join(unsupported)})"

        locations = [str(loc).strip() for loc in _as_list(parsed.get('locations'))]
        locations = [loc for loc in locations if loc]
        # Continents plus the "whole world" terms fetch_geospatial_data treats as
        # global; only other names (i.e. specific countries) get per-country colours.
        region_terms = {
            'africa', 'asia', 'europe', 'north america', 'south america', 'oceania', 'antarctica',
            'global', 'world', 'worldwide', 'all', 'earth', 'globe',
        }
        use_country_colors = bool(locations) and not any(
            loc.casefold() in region_terms for loc in locations
        )

        top_n = int(top_n)
        map_html = chart_fig = table_df = map_file = csv_file = None

        summary = (
            f"🔍 **Query:** {user_query}\n\n"
            f"📍 **Locations:** {', '.join(locations) or 'Global'}\n"
            f"📊 **Indicators:** {indicator_display}\n"
            f"🌍 **Countries found:** {len(gdf)}\n\n"
            f"⚙️ **Options:** Chart: {chart_type} | Map: {map_style} | Top N: {top_n}"
        )

        # The map shows only the top_n countries by the primary indicator.
        gdf_top_n = gdf.sort_values(mapped_indicators[0], ascending=False).head(top_n)

        if output_format in ('All', 'Map'):
            m = create_interactive_map(gdf_top_n, mapped_indicators, map_style, choropleth_color)
            map_html = m._repr_html_()
            map_file = _new_temp_path('.html')
            m.save(map_file)

        if output_format in ('All', 'Chart'):
            chart_fig = create_chart(gdf, mapped_indicators, chart_type, color_scheme, top_n, use_country_colors)

        if output_format in ('All', 'Table'):
            table_df = create_data_table(gdf)
            csv_file = _new_temp_path('.csv')
            table_df.to_csv(csv_file, index=False)

        return map_html, chart_fig, table_df, summary, map_file, csv_file

    except Exception as e:
        # UI boundary: keep the full traceback in the logs, show a friendly message.
        logging.getLogger(__name__).exception("Error processing query %r", user_query)
        error_msg = f"Error processing query: {str(e)}\n\nPlease try rephrasing your query."
        return None, None, None, error_msg, None, None
# Gradio Interface
def create_interface():
    """Build the Gradio Blocks UI.

    The UI has a query box, an output-format selector, advanced options,
    result tabs (chart / map / table with downloads) and example queries.
    The Analyze button calls ``process_query``; its six return values fill the
    map, chart, table, summary and the two download components, in that order.
    """
    with gr.Blocks(title="Geospatial AI Query System") as demo:
        gr.Markdown("""
        # 🌍 Geospatial AI Query System
        ### Natural Language Interface for Geographic Data

        Ask questions about countries, regions, and global indicators using natural language.

        ℹ️ *Note: Current data includes latest available statistics. Time-range queries will filter when historical data is available.*
        """)

        with gr.Row():
            with gr.Column(scale=3):
                query_input = gr.Textbox(
                    label="Your Query",
                    placeholder="E.g., Show me GDP and population of BRICS countries",
                    lines=2
                )
            with gr.Column(scale=1):
                output_format = gr.Radio(
                    choices=['All', 'Map', 'Chart', 'Table'],
                    value='All',
                    label="Output Format"
                )

        # Advanced Options in Accordion
        with gr.Accordion("⚙️ Advanced Options", open=False):
            with gr.Row():
                chart_type = gr.Dropdown(
                    choices=['bar', 'horizontal_bar', 'scatter', 'pie', 'treemap', 'bubble'],
                    value='bar',
                    label="📊 Chart Type"
                )
                map_style = gr.Dropdown(
                    choices=list(MAP_STYLES.keys()),
                    value='Light',
                    label="🗺️ Map Style"
                )
            with gr.Row():
                color_scheme = gr.Dropdown(
                    choices=list(COLOR_SCHEMES.keys()),
                    value='Default',
                    label="🎨 Chart Colors"
                )
                choropleth_color = gr.Dropdown(
                    choices=list(CHOROPLETH_COLORS.keys()),
                    value='Yellow-Orange-Red',
                    label="🌈 Map Colors"
                )
            with gr.Row():
                top_n = gr.Slider(
                    minimum=5,
                    maximum=50,
                    value=20,
                    step=5,
                    label="🔢 Top N Countries"
                )

        submit_btn = gr.Button("🔍 Analyze", variant="primary", size="lg")

        gr.Markdown("### Results")
        summary_output = gr.Textbox(label="Query Summary", lines=4)

        with gr.Tabs():
            with gr.Tab("📊 Chart"):
                chart_output = gr.Plot(label="Interactive Chart")
            with gr.Tab("🗺️ Map"):
                map_output = gr.HTML(label="Interactive Map")
                map_download = gr.File(label="📥 Download Map (HTML)", visible=True)
            with gr.Tab("📋 Table"):
                table_output = gr.Dataframe(label="Data Table")
                csv_download = gr.File(label="📥 Download Table (CSV)", visible=True)

        # Examples
        gr.Examples(
            examples=[
                ["Show me population of Asian countries", "All"],
                ["Compare GDP of European nations", "Chart"],
                ["What's the population density in Africa?", "Map"],
                ["Display GDP per capita for South American countries", "Table"],
                ["Show me the top 10 economies in the world", "Chart"],
                ["Compare BRICS countries by GDP", "All"]
            ],
            inputs=[query_input, output_format]
        )

        # Event handler: output order must match process_query's return tuple.
        submit_btn.click(
            fn=process_query,
            inputs=[query_input, output_format, chart_type, map_style, color_scheme, choropleth_color, top_n],
            outputs=[map_output, chart_output, table_output, summary_output, map_download, csv_download]
        )

        gr.Markdown("""
        ---
        **About:** This app uses LLMs to parse natural language queries and visualize global geospatial data.

        **Data Sources:** Natural Earth, World Bank Open Data

        **Built by:** [rifatSDAS](https://github.com/rifatSDAS)
        """)

    return demo
if __name__ == "__main__":
demo = create_interface()
# Enable queue for better concurrency handling on HF Spaces
demo.queue(default_concurrency_limit=10)
demo.launch(theme=gr.themes.Soft())
# To enable Progressive Web App (PWA) features, uncomment the line below
# demo.launch(theme=gr.themes.Soft(), pwa=True)