EarthExplorer / app.py
Anonymous
fix: print top 5 metadata
a439024
import gradio as gr
import torch
import time
import os
import tempfile
import zipfile
import numpy as np
import pandas as pd
from concurrent.futures import ThreadPoolExecutor, as_completed
# Import custom modules
from models.siglip_model import SigLIPModel
from models.satclip_model import SatCLIPModel
from models.farslip_model import FarSLIPModel
from models.dinov2_model import DINOv2Model
from models.load_config import load_and_process_config
from visualize import format_results_for_gallery, plot_top5_overview, plot_location_distribution, plot_global_map_static, plot_geographic_distribution
from data_utils import download_and_process_image, get_esri_satellite_image, get_placeholder_image
from PIL import Image as PILImage
from PIL import ImageDraw, ImageFont
# Configuration: choose the compute device once, at import time, and report it.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Running on device: {device}")

# Load and process the model/embedding configuration, then echo it for debugging.
config = load_and_process_config()
print(config)
# Initialize Models
# Each entry: (display name used as the key in ``models`` and in the UI,
#              model class,
#              config section key,
#              config fields forwarded as keyword arguments when that section exists).
_MODEL_SPECS = (
    ('DINOv2', DINOv2Model, 'dinov2', ('ckpt_path', 'embedding_path')),
    ('SigLIP', SigLIPModel, 'siglip', ('ckpt_path', 'tokenizer_path', 'embedding_path')),
    ('SatCLIP', SatCLIPModel, 'satclip', ('ckpt_path', 'embedding_path')),
    ('FarSLIP', FarSLIPModel, 'farslip', ('ckpt_path', 'model_name', 'embedding_path')),
)


def _build_model(model_cls, config_key, field_names):
    """Instantiate ``model_cls`` on the global ``device``.

    If ``config`` contains a ``config_key`` section, each name in ``field_names`` is
    read from it with ``.get`` (absent fields become None) and passed as a keyword
    argument. Otherwise the class is constructed with only ``device``, i.e. with its
    own defaults.
    """
    if config and config_key in config:
        section = config[config_key]
        kwargs = {name: section.get(name) for name in field_names}
        return model_cls(**kwargs, device=device)
    return model_cls(device=device)


print("Initializing models...")
models = {}
for _display_name, _model_cls, _config_key, _fields in _MODEL_SPECS:
    try:
        models[_display_name] = _build_model(_model_cls, _config_key, _fields)
    except Exception as e:
        # Best-effort: a model that fails to load is logged and left out of ``models``.
        print(f"Failed to load {_display_name}: {e}")
def get_active_model(model_name):
    """Look up a successfully loaded model by its display name.

    Returns:
        ``(model, None)`` when ``model_name`` is a key of ``models``; otherwise
        ``(None, message)`` with a human-readable explanation for the UI.
    """
    try:
        found = models[model_name]
    except KeyError:
        return None, f"Model {model_name} not loaded."
    return found, None
def combine_images(img1, img2):
    """Stack ``img1`` above ``img2`` on a white RGB canvas.

    Both images are rescaled to the wider of the two widths, keeping their aspect
    ratio (heights truncated to int). If either argument is None, the other one is
    returned unchanged.
    """
    if img1 is None:
        return img2
    if img2 is None:
        return img1
    canvas_w = max(img1.size[0], img2.size[0])
    heights = [int(im.size[1] * canvas_w / im.size[0]) for im in (img1, img2)]
    top = img1.resize((canvas_w, heights[0]))
    bottom = img2.resize((canvas_w, heights[1]))
    canvas = PILImage.new('RGB', (canvas_w, heights[0] + heights[1]), (255, 255, 255))
    canvas.paste(top, (0, 0))
    canvas.paste(bottom, (0, heights[0]))
    return canvas
def create_text_image(text, size=(384, 384)):
    """Render a text query onto a light-grey placeholder image.

    Each comma-separated fragment of ``text`` is drawn on its own line, followed
    by a blue "Text Query" caption.

    Args:
        text: query string; split on ',' to produce line breaks.
        size: (width, height) of the returned RGB image.

    Returns:
        A new PIL RGB image.
    """
    img = PILImage.new('RGB', size, color=(240, 240, 240))
    draw = ImageDraw.Draw(img)
    try:
        # Prefer a scalable TrueType font so the 40px size is honoured.
        font = ImageFont.truetype("DejaVuSans.ttf", 40)
    except (OSError, ImportError):
        # OSError: font file not found/readable; ImportError: Pillow lacks FreeType.
        # Only these are caught so Ctrl-C / SystemExit are no longer swallowed.
        font = ImageFont.load_default()
    # Simple line wrapping: one line per comma-separated fragment.
    margin = 20
    offset = 100
    for line in text.split(','):
        draw.text((margin, offset), line.strip(), font=font, fill=(0, 0, 0))
        offset += 50
    draw.text((margin, offset + 50), "Text Query", font=font, fill=(0, 0, 255))
    return img
def fetch_top_k_images(top_indices, probs, df_embed, query_text=None):
    """Fetch preview images for the top-k retrieved tiles.

    Tiles are downloaded in parallel (5 worker threads) with
    ``download_and_process_image``. When a download returns no image or raises,
    an Esri satellite image centred on the tile is used instead.

    Args:
        top_indices: positional row indices into ``df_embed``, in rank order.
        probs: similarity scores indexable by the same positions.
        df_embed: table with ``product_id``, ``centre_lat`` and ``centre_lon`` columns.
        query_text: label forwarded to the Esri fallback.

    Returns:
        A list of dicts with keys ``image_384``, ``image_full``, ``score``, ``lat``,
        ``lon`` and ``id``, sorted by score (highest first). Equal scores keep their
        order from ``top_indices``. Tiles for which every source failed are omitted.
    """
    top_indices = list(top_indices)
    images_by_rank = {}  # rank position -> (img_384, img_full)

    with ThreadPoolExecutor(max_workers=5) as executor:
        future_to_rank = {}
        for rank, idx in enumerate(top_indices):
            pid = df_embed.iloc[idx]['product_id']
            future = executor.submit(download_and_process_image, pid, df_source=df_embed, verbose=False)
            future_to_rank[future] = rank

        for future in as_completed(future_to_rank):
            rank = future_to_rank[future]
            idx = top_indices[rank]
            img_384 = img_full = None
            try:
                img_384, img_full = future.result()
            except Exception as e:
                # Treat a failed download like a missing image so the fallback still runs.
                print(f"Error fetching image for idx {idx}: {e}")
            if img_384 is None:
                print(f"Download failed for idx {idx}, falling back to Esri...")
                row = df_embed.iloc[idx]
                try:
                    img_384 = get_esri_satellite_image(row['centre_lat'], row['centre_lon'], score=probs[idx], rank=0, query=query_text)
                except Exception as e:
                    print(f"Error fetching image for idx {idx}: {e}")
                    continue
                img_full = img_384
            images_by_rank[rank] = (img_384, img_full)

    results = []
    for rank in sorted(images_by_rank):
        idx = top_indices[rank]
        row = df_embed.iloc[idx]
        img_384, img_full = images_by_rank[rank]
        results.append({
            'image_384': img_384,
            'image_full': img_full,
            'score': probs[idx],
            'lat': row['centre_lat'],
            'lon': row['centre_lon'],
            'id': row['product_id'],
        })
    # Futures finish in arbitrary order, so results were assembled in rank order
    # above; list.sort is stable (also with reverse=True), so ties keep that order.
    results.sort(key=lambda r: r['score'], reverse=True)
    return results
def get_all_results_metadata(model, filtered_indices, probs):
    """Return id/lat/lon/score records for every filtered match, best score first.

    Args:
        model: object exposing ``df_embed`` with ``product_id``, ``centre_lat``
            and ``centre_lon`` columns.
        filtered_indices: positional row indices into ``model.df_embed``. Any
            array-like is accepted (a plain list works too); None or empty
            yields ``[]``.
        probs: similarity scores indexable by those positions.

    Returns:
        A list of dicts with keys ``id``, ``lat``, ``lon`` and ``score``. The list
        is sorted by descending score; ties keep their order in ``filtered_indices``.
    """
    if filtered_indices is None:
        return []
    indices = np.asarray(filtered_indices, dtype=np.intp)
    if indices.size == 0:
        return []
    # Stable sort on the negated score: descending order with deterministic ties
    # (reversing an unstable ascending argsort shuffled ties and put NaNs first).
    sorted_order = np.argsort(-probs[indices], kind='stable')
    sorted_indices = indices[sorted_order]
    df_results = model.df_embed.iloc[sorted_indices].copy()
    df_results['score'] = probs[sorted_indices]
    df_results = df_results.rename(columns={'product_id': 'id', 'centre_lat': 'lat', 'centre_lon': 'lon'})
    return df_results[['id', 'lat', 'lon', 'score']].to_dict('records')
def search_text(query, threshold, model_name):
    """Text-to-image search, streamed as a generator of UI updates.

    Every yielded value is a 7-tuple for the outputs
    (plot_map_interactive, gallery_images, status_output, results_plot,
    current_fig, map_data_state, plot_map).

    Args:
        query: free-text description to encode.
        threshold: slider value in per-mille; the top ``threshold / 1000``
            fraction of the embedding table is searched.
        model_name: key into ``models``.
    """
    def status_only(message):
        # Clear every output except the status box.
        return None, None, message, None, None, None, None

    model, failure = get_active_model(model_name)
    if failure:
        yield status_only(failure)
        return
    if not query:
        yield status_only("Please enter a query.")
        return

    try:
        elapsed = {}

        # Step 1: text embedding.
        yield status_only("Encoding text...")
        started = time.time()
        text_features = model.encode_text(query)
        elapsed['Encoding'] = time.time() - started
        if text_features is None:
            yield status_only("Model does not support text encoding or is not initialized.")
            return

        # Step 2: similarity search over the embedding table.
        yield status_only("Encoding text... ✓\nRetrieving similar images...")
        started = time.time()
        probs, filtered_indices, top_indices = model.search(text_features, top_percent=threshold/1000.0)
        elapsed['Retrieval'] = time.time() - started
        if probs is None:
            yield status_only("Search failed (embeddings missing?).")
            return

        # The distribution map is rendered outside the timed steps.
        df_embed = model.df_embed
        geo_dist_map, df_filtered = plot_geographic_distribution(
            df_embed, probs, threshold/1000.0, title=f'Similarity to "{query}" ({model_name})'
        )

        def progress(message):
            # Keep the static distribution map visible while later steps run.
            return (gr.update(visible=False), None, message, None, None,
                    df_filtered, gr.update(value=geo_dist_map, visible=True))

        # Step 3: fetch the ten best tiles.
        yield progress("Encoding text... ✓\nRetrieving similar images... ✓\nDownloading images...")
        started = time.time()
        results = fetch_top_k_images(top_indices[:10], probs, df_embed, query_text=query)
        elapsed['Download'] = time.time() - started

        # Step 4: figures for the results panel and gallery.
        yield progress("Encoding text... ✓\nRetrieving similar images... ✓\nDownloading images... ✓\nGenerating visualizations...")
        started = time.time()
        fig_results = plot_top5_overview(None, results, query_info=query)
        gallery_items = format_results_for_gallery(results)
        elapsed['Visualization'] = time.time() - started

        # Step 5: final status plus the bundle used by the download button.
        timing_str = ", ".join(
            f"{step} {elapsed[step]:.1f}s" for step in ('Encoding', 'Retrieval', 'Download', 'Visualization')
        ) + "\n\n"
        status_msg = timing_str + generate_status_msg(len(filtered_indices), threshold/100.0, results)
        results_txt = format_results_to_text(get_all_results_metadata(model, filtered_indices, probs))
        yield (gr.update(visible=False), gallery_items, status_msg, fig_results,
               [geo_dist_map, fig_results, results_txt], df_filtered,
               gr.update(value=geo_dist_map, visible=True))
    except Exception as exc:
        import traceback
        traceback.print_exc()
        yield status_only(f"Error: {str(exc)}")
def search_image(image_input, threshold, model_name):
    """Image-to-image search, streamed as a generator of UI updates.

    Every yielded value is a 7-tuple for the outputs
    (plot_map_interactive, gallery_images, status_output, results_plot,
    current_fig, map_data_state, plot_map).

    Args:
        image_input: query image from the ``gr.Image(type="pil")`` input.
        threshold: slider value in per-mille; the top ``threshold / 1000``
            fraction of the embedding table is searched.
        model_name: key into ``models``.
    """
    model, error = get_active_model(model_name)
    if error:
        yield None, None, error, None, None, None, None
        return
    if image_input is None:
        yield None, None, "Please upload an image.", None, None, None, None
        return
    try:
        timings = {}
        # 1. Encode Image
        yield None, None, "Encoding image...", None, None, None, None
        t0 = time.time()
        image_features = model.encode_image(image_input)
        timings['Encoding'] = time.time() - t0
        if image_features is None:
            yield None, None, "Model does not support image encoding.", None, None, None, None
            return
        # 2. Search
        yield None, None, "Encoding image... ✓\nRetrieving similar images...", None, None, None, None
        t0 = time.time()
        probs, filtered_indices, top_indices = model.search(image_features, top_percent=threshold/1000.0)
        timings['Retrieval'] = time.time() - t0
        if probs is None:
            # Same guard as text search: without scores there is nothing to plot or rank.
            yield None, None, "Search failed (embeddings missing?).", None, None, None, None
            return
        # Show geographic distribution (not timed)
        df_embed = model.df_embed
        geo_dist_map, df_filtered = plot_geographic_distribution(df_embed, probs, threshold/1000.0, title=f'Similarity to Input Image ({model_name})')
        # 3. Download Images
        yield gr.update(visible=False), None, "Encoding image... ✓\nRetrieving similar images... ✓\nDownloading images...", None, None, df_filtered, gr.update(value=geo_dist_map, visible=True)
        t0 = time.time()
        results = fetch_top_k_images(top_indices[:6], probs, df_embed, query_text="Image Query")
        timings['Download'] = time.time() - t0
        # 4. Visualize - keep geo_dist_map visible
        yield gr.update(visible=False), None, "Encoding image... ✓\nRetrieving similar images... ✓\nDownloading images... ✓\nGenerating visualizations...", None, None, df_filtered, gr.update(value=geo_dist_map, visible=True)
        t0 = time.time()
        fig_results = plot_top5_overview(image_input, results, query_info="Image Query")
        gallery_items = format_results_for_gallery(results)
        timings['Visualization'] = time.time() - t0
        # 5. Generate Final Status
        timing_str = f"Encoding {timings['Encoding']:.1f}s, Retrieval {timings['Retrieval']:.1f}s, Download {timings['Download']:.1f}s, Visualization {timings['Visualization']:.1f}s\n\n"
        status_msg = timing_str + generate_status_msg(len(filtered_indices), threshold/100.0, results)
        all_results = get_all_results_metadata(model, filtered_indices, probs)
        # Only the 50 best matches go into the downloadable text report.
        results_txt = format_results_to_text(all_results[:50])
        yield gr.update(visible=False), gallery_items, status_msg, fig_results, [geo_dist_map, fig_results, results_txt], df_filtered, gr.update(value=geo_dist_map, visible=True)
    except Exception as e:
        import traceback
        traceback.print_exc()
        yield None, None, f"Error: {str(e)}", None, None, None, None
def search_location(lat, lon, threshold):
    """Location-to-image search with the SatCLIP location encoder, streamed as UI updates.

    Every yielded value is a 7-tuple for the outputs
    (plot_map_interactive, gallery_images, status_output, results_plot,
    current_fig, map_data_state, plot_map).

    Args:
        lat, lon: query coordinates in degrees; the values as given are also
            used verbatim in titles and labels.
        threshold: slider value in per-mille; as in the text and image searches,
            the top ``threshold / 1000`` fraction of the embedding table is kept.
    """
    model_name = "SatCLIP"
    model, error = get_active_model(model_name)
    if error:
        yield None, None, error, None, None, None, None
        return
    if lat is None or lon is None:
        yield None, None, "Please specify coordinates first.", None, None, None, None
        return
    try:
        timings = {}
        lat_f, lon_f = float(lat), float(lon)
        # 1. Encode Location
        yield None, None, "Encoding location...", None, None, None, None
        t0 = time.time()
        loc_features = model.encode_location(lat_f, lon_f)
        timings['Encoding'] = time.time() - t0
        if loc_features is None:
            yield None, None, "Location encoding failed.", None, None, None, None
            return
        # 2. Search. Per-mille, so the match count agrees with the distribution
        # map below and with the "top N‰" status message.
        yield None, None, "Encoding location... ✓\nRetrieving similar images...", None, None, None, None
        top_fraction = threshold / 1000.0
        t0 = time.time()
        probs, filtered_indices, top_indices = model.search(loc_features, top_percent=top_fraction)
        timings['Retrieval'] = time.time() - t0
        if probs is None:
            yield None, None, "Search failed (embeddings missing?).", None, None, None, None
            return
        # 3. Generate Distribution Map (not timed)
        yield None, None, "Encoding location... ✓\nRetrieving similar images... ✓\nGenerating distribution map...", None, None, None, None
        df_embed = model.df_embed
        geo_dist_map, df_filtered = plot_geographic_distribution(df_embed, probs, top_fraction, title=f'Similarity to Location ({lat}, {lon})')
        # 4. Download Images
        yield gr.update(visible=False), None, "Encoding location... ✓\nRetrieving similar images... ✓\nGenerating distribution map... ✓\nDownloading images...", None, None, df_filtered, gr.update(value=geo_dist_map, visible=True)
        t0 = time.time()
        results = fetch_top_k_images(top_indices[:6], probs, df_embed, query_text=f"Loc: {lat},{lon}")
        # Query tile: the dataset tile whose centre is nearest to the query point.
        query_tile = None
        try:
            lats = pd.to_numeric(df_embed['centre_lat'], errors='coerce')
            lons = pd.to_numeric(df_embed['centre_lon'], errors='coerce')
            dists = (lats - lat_f)**2 + (lons - lon_f)**2
            nearest_idx = dists.idxmin()
            pid = df_embed.loc[nearest_idx, 'product_id']
            query_tile, _ = download_and_process_image(pid, df_source=df_embed, verbose=False)
        except Exception as e:
            print(f"Error fetching nearest MajorTOM image: {e}")
        if query_tile is None:
            query_tile = get_placeholder_image(f"Query Location\n({lat}, {lon})")
        timings['Download'] = time.time() - t0
        # 5. Visualize - keep geo_dist_map visible
        yield gr.update(visible=False), None, "Encoding location... ✓\nRetrieving similar images... ✓\nGenerating distribution map... ✓\nDownloading images... ✓\nGenerating visualizations...", None, None, df_filtered, gr.update(value=geo_dist_map, visible=True)
        t0 = time.time()
        fig_results = plot_top5_overview(query_tile, results, query_info=f"Loc: {lat},{lon}")
        gallery_items = format_results_for_gallery(results)
        timings['Visualization'] = time.time() - t0
        # 6. Generate Final Status
        timing_str = f"Encoding {timings['Encoding']:.1f}s, Retrieval {timings['Retrieval']:.1f}s, Download {timings['Download']:.1f}s, Visualization {timings['Visualization']:.1f}s\n\n"
        status_msg = timing_str + generate_status_msg(len(filtered_indices), threshold/100.0, results)
        all_results = get_all_results_metadata(model, filtered_indices, probs)
        results_txt = format_results_to_text(all_results)
        yield gr.update(visible=False), gallery_items, status_msg, fig_results, [geo_dist_map, fig_results, results_txt], df_filtered, gr.update(value=geo_dist_map, visible=True)
    except Exception as e:
        import traceback
        traceback.print_exc()
        yield None, None, f"Error: {str(e)}", None, None, None, None
def generate_status_msg(count, threshold, results):
    """Build the status-box summary for a finished search.

    Args:
        count: number of matches inside the threshold.
        threshold: the per-mille slider value divided by 100 (callers pass
            ``threshold/100.0``); it is multiplied back by 100 for display.
        results: ranked result dicts with ``id``, ``lat``, ``lon``, ``score``;
            only the first five are listed.
    """
    header = f"Found {count} matches in top {threshold*100:.0f}‰.\n\nTop {len(results)} similar images:\n"
    lines = [
        f"{rank}. Product ID: {entry['id']}, Location: ({entry['lat']:.4f}, {entry['lon']:.4f}), Score: {entry['score']:.4f}\n"
        for rank, entry in enumerate(results[:5], start=1)
    ]
    return header + "".join(lines)
def get_initial_plot():
    """Render the static global-distribution map shown when the page loads.

    The embeddings table of the first loaded model is plotted, in the order DINOv2,
    then SigLIP, then any other entry of ``models``. Models whose ``df_embed`` is
    None are skipped. If no model provides embeddings, the map is left empty
    instead of raising.

    Returns:
        (plot_map update, [img] for the current_fig state, df_vis for
        map_data_state, update hiding plot_map_interactive)
    """
    img, df_vis = None, None
    preference = ['DINOv2', 'SigLIP'] + [name for name in models if name not in ('DINOv2', 'SigLIP')]
    for name in preference:
        df_embed = getattr(models.get(name), 'df_embed', None)
        if df_embed is not None:
            img, df_vis = plot_global_map_static(df_embed)
            break
    else:
        print("No loaded model provides embeddings; global map unavailable.")
    return gr.update(value=img, visible=True), [img], df_vis, gr.update(visible=False)
def handle_map_click(evt: gr.SelectData, df_vis):
    """Convert a click on the static world map into coordinates and a product ID.

    The click's pixel position is mapped through the map's plot-area geometry
    onto an equirectangular grid (lon -180..180, lat 90..-90). If ``df_vis``
    (the points drawn on the map) has a tile centre within 5 degrees of that
    point, measured as squared degree distance < 25, the selection snaps to that
    tile and its ``product_id`` is returned.

    Returns:
        ``(lat, lon, product_id, status message)``. The first three are None when
        the click is outside the map or cannot be handled. ``product_id`` is ""
        when no tile is close enough.
    """
    if evt is None:
        return None, None, None, "No point selected."
    try:
        x, y = evt.index[0], evt.index[1]

        # Pixel size of the static map image. NOTE(review): assumed to match the
        # figure produced by plot_global_map_static; confirm if that figure changes.
        img_width = 3000
        img_height = 1500
        # Plot-area margins, scaled by 0.75 from the original 4000x2000 layout.
        left_margin = 110 * 0.75
        right_margin = 110 * 0.75
        top_margin = 100 * 0.75
        bottom_margin = 67 * 0.75
        plot_width = img_width - left_margin - right_margin
        plot_height = img_height - top_margin - bottom_margin

        # The map keeps a 2:1 aspect ratio inside the plot area, so it is padded
        # horizontally or vertically depending on which dimension limits it.
        map_aspect = 360.0 / 180.0
        if plot_width / plot_height > map_aspect:
            actual_map_width = plot_height * map_aspect
            actual_map_height = plot_height
            h_offset = (plot_width - actual_map_width) / 2
            v_offset = 0
        else:
            actual_map_width = plot_width
            actual_map_height = plot_width / map_aspect
            h_offset = 0
            v_offset = (plot_height - actual_map_height) / 2

        x_in_plot = x - left_margin
        y_in_plot = y - top_margin
        inside = (h_offset <= x_in_plot <= h_offset + actual_map_width
                  and v_offset <= y_in_plot <= v_offset + actual_map_height)
        if not inside:
            return None, None, None, "Click outside map area. Please click on the map."

        x_rel = min(1.0, max(0.0, (x_in_plot - h_offset) / actual_map_width))
        y_rel = min(1.0, max(0.0, (y_in_plot - v_offset) / actual_map_height))
        lon = x_rel * 360 - 180
        lat = 90 - y_rel * 180

        pid = ""
        if df_vis is not None and len(df_vis) > 0:
            # Coerce to numbers: coordinate columns may not be numeric dtype, and
            # the snapped values are formatted with ':.4f' below.
            lats = pd.to_numeric(df_vis['centre_lat'], errors='coerce').to_numpy(dtype=float)
            lons = pd.to_numeric(df_vis['centre_lon'], errors='coerce').to_numpy(dtype=float)
            sq_dists = (lats - lat) ** 2 + (lons - lon) ** 2
            if np.isfinite(sq_dists).any():
                nearest = int(np.nanargmin(sq_dists))
                if sq_dists[nearest] < 25:
                    lat = float(lats[nearest])
                    lon = float(lons[nearest])
                    pid = df_vis['product_id'].iloc[nearest]
    except Exception as e:
        print(f"Error handling click: {e}")
        import traceback
        traceback.print_exc()
        return None, None, None, f"Error: {e}"
    return lat, lon, pid, f"Selected Point: ({lat:.4f}, {lon:.4f})"
def download_image_by_location(lat, lon, pid, model_name):
    """Download the dataset tile at (or nearest to) the given coordinates.

    ``pid`` is the product ID auto-filled by a map click. It is trusted only if that
    product exists in the selected model's embedding table and its centre still
    matches ``(lat, lon)``. Otherwise the tile nearest to ``(lat, lon)`` is used,
    for example when the coordinates were edited by hand after the click, which
    would otherwise leave a stale hidden product ID.

    Returns:
        ``(image or None, status message)``.
    """
    if lat is None or lon is None:
        return None, "Please specify coordinates first."
    model, error = get_active_model(model_name)
    if error:
        return None, error
    try:
        # Convert to float to ensure proper formatting
        lat = float(lat)
        lon = float(lon)
        df = model.df_embed
        lats = pd.to_numeric(df['centre_lat'], errors='coerce').to_numpy(dtype=float)
        lons = pd.to_numeric(df['centre_lon'], errors='coerce').to_numpy(dtype=float)
        if pid:
            is_pid = (df['product_id'] == pid).to_numpy()
            same_place = np.isclose(lats, lat, atol=1e-6) & np.isclose(lons, lon, atol=1e-6)
            if not (is_pid & same_place).any():
                pid = ""
        if not pid:
            sq_dists = (lats - lat) ** 2 + (lons - lon) ** 2
            if not np.isfinite(sq_dists).any():
                return None, "No valid tile coordinates available for this model."
            pid = df['product_id'].iloc[int(np.nanargmin(sq_dists))]
        # Download image
        img_384, _ = download_and_process_image(pid, df_source=df, verbose=True)
        if img_384 is None:
            return None, f"Failed to download image for location ({lat:.4f}, {lon:.4f})"
        return img_384, f"Downloaded image at ({lat:.4f}, {lon:.4f})"
    except Exception as e:
        import traceback
        traceback.print_exc()
        return None, f"Error: {str(e)}"
def reset_to_global_map():
    """Restore the static global-distribution map (used by both "Reset Map" buttons).

    Delegates to ``get_initial_plot`` so the reset view always matches the
    page-load view. Its last element, the interactive-map visibility update, is
    dropped because the reset buttons do not target that component.

    Returns:
        (plot_map update, [img] for the current_fig state, df_vis for map_data_state)
    """
    map_update, fig_state, df_vis, _ = get_initial_plot()
    return map_update, fig_state, df_vis
def format_results_to_text(results):
    """Render ranked retrieval results as the plain-text report put in results.txt.

    Args:
        results: sequence of dicts with ``id``, ``lat``, ``lon`` and ``score``
            keys, already in rank order.

    Returns:
        The report text, or "No results found." when ``results`` is empty/None.
    """
    if not results:
        return "No results found."
    separator = "-" * 30 + "\n"
    parts = [f"Top {len(results)} Retrieval Results\n", "=" * 30 + "\n\n"]
    for rank, entry in enumerate(results, start=1):
        parts.append(
            f"Rank: {rank}\n"
            f"Product ID: {entry['id']}\n"
            f"Location: Latitude {entry['lat']:.6f}, Longitude {entry['lon']:.6f}\n"
            f"Similarity Score: {entry['score']:.6f}\n"
        )
        parts.append(separator)
    return "".join(parts)
def save_plot(figs):
    """Write the current results to a temporary file for the download button.

    ``figs`` is the ``current_fig`` state and may be:
      * a single PIL image: saved as a PNG;
      * a list/tuple ``[map_img, results_img, results_txt]`` (items may be None or
        missing): a one-element list holding a PIL image is saved as a PNG,
        otherwise the present items are bundled into a ZIP;
      * anything else: assumed to be a Plotly figure and written as HTML via
        ``write_html``.

    Returns:
        Path of the created temp file, or None for empty input or on any error
        (the error is printed).
    """
    import io

    if figs is None:
        return None

    def _save_png(image):
        # Private temp file per call, so concurrent sessions never share a path.
        fd, path = tempfile.mkstemp(suffix='.png', prefix='earth_explorer_map_')
        os.close(fd)
        image.save(path)
        return path

    try:
        if isinstance(figs, PILImage.Image):
            return _save_png(figs)

        if isinstance(figs, (list, tuple)):
            if len(figs) == 0:
                return None
            if len(figs) == 1 and isinstance(figs[0], PILImage.Image):
                return _save_png(figs[0])

            fd, zip_path = tempfile.mkstemp(suffix='.zip', prefix='earth_explorer_results_')
            os.close(fd)
            # Members are written from memory. Staging them under fixed names in
            # the shared temp directory let concurrent sessions overwrite each
            # other's files and left copies behind after every download.
            with zipfile.ZipFile(zip_path, 'w') as zipf:
                for position, arcname in enumerate(('map_distribution.png', 'retrieval_results.png')):
                    if len(figs) > position and figs[position] is not None:
                        buffer = io.BytesIO()
                        figs[position].save(buffer, format='PNG')
                        zipf.writestr(arcname, buffer.getvalue())
                if len(figs) > 2 and figs[2] is not None:
                    # writestr encodes str data as UTF-8.
                    zipf.writestr('results.txt', figs[2])
            return zip_path

        # Fallback for a Plotly figure.
        fd, path = tempfile.mkstemp(suffix='.html', prefix='earth_explorer_plot_')
        os.close(fd)
        figs.write_html(path)
        return path
    except Exception as e:
        print(f"Error saving: {e}")
        return None
# Gradio Blocks Interface
# Layout: a left column with the three search tabs and shared controls, and a
# right column with the distribution map, results figure and gallery. Every
# search handler yields 7-tuples matching the `outputs` lists wired below.
with gr.Blocks(title="EarthEmbeddingExplorer") as demo:
    gr.Markdown("# EarthEmbeddingExplorer")
    gr.HTML("""
    <div style="font-size: 1.2em;">
    EarthEmbeddingExplorer is a tool that allows you to search for satellite images of the Earth using natural language descriptions, images, geolocations, or a simple click on the map. For example, you can type "tropical rainforest" or "coastline with a city," and the system will find locations on Earth that match your description. It then visualizes these locations on a world map and displays the top matching images.
    </div>
    <div style="display: flex; gap: 0.2em; align-items: center; justify-content: center;">
    <a href="https://www.modelscope.cn/studios/VoyagerX/EarthExplorer"><img src="https://img.shields.io/badge/Open in ModelScope.cn-xGPU-624aff"></a>
    <a href="https://modelscope.cn/datasets/VoyagerX/EarthEmbeddings"><img src="https://img.shields.io/badge/👾 MS-Dataset-624aff"></a>
    <a href="https://huggingface.co/spaces/VoyagerXHF/EarthEmbeddingExplorer/tree/main"><img src="https://img.shields.io/badge/Open in HF Space-CPU-FFD21E"></a>
    <a href="https://huggingface.co/datasets/ML4Sustain/EarthEmbeddings"><img src="https://img.shields.io/badge/🤗 HF-Dataset-FFD21E"></a>
    <a href="https://huggingface.co/spaces/ML4Sustain/EarthExplorer/blob/main/Tutorial.md"> <img src="https://img.shields.io/badge/Tutorial-📖-007bff"> </a>
    <a href="https://modelscope.cn/studios/VoyagerX/EarthExplorer/file/view/master/Tutorial_zh.md?status=1"> <img src="https://img.shields.io/badge/中文教程-📖-007bff"> </a>
    </div>
    """)
    with gr.Row():
        with gr.Column(scale=4):
            with gr.Tabs():
                with gr.TabItem("Text Search") as tab_text:
                    model_selector_text = gr.Dropdown(choices=["SigLIP", "FarSLIP"], value="FarSLIP", label="Model")
                    query_input = gr.Textbox(label="Query", placeholder="e.g., rainforest, glacier")
                    gr.Examples(
                        examples=[
                            ["a satellite image of a river around a city"],
                            ["a satellite image of a rainforest"],
                            ["a satellite image of a slum"],
                            ["a satellite image of a glacier"],
                            ["a satellite image of snow covered mountains"]
                        ],
                        inputs=[query_input],
                        label="Text Examples"
                    )
                    search_btn = gr.Button("Search by Text", variant="primary")
                with gr.TabItem("Image Search") as tab_image:
                    model_selector_img = gr.Dropdown(choices=["SigLIP", "FarSLIP", "SatCLIP", "DINOv2"], value="FarSLIP", label="Model")
                    gr.Markdown("### Option 1: Upload or Select Image")
                    image_input = gr.Image(type="pil", label="Upload Image")
                    gr.Examples(
                        examples=[
                            ["./examples/example1.png"],
                            ["./examples/example2.png"],
                            ["./examples/example3.png"]
                        ],
                        inputs=[image_input],
                        label="Image Examples"
                    )
                    gr.Markdown("### Option 2: Click Map or Enter Coordinates")
                    btn_reset_map_img = gr.Button("🔄 Reset Map to Global View", variant="secondary", size="sm")
                    with gr.Row():
                        img_lat = gr.Number(label="Latitude", interactive=True)
                        img_lon = gr.Number(label="Longitude", interactive=True)
                    # Hidden; filled by map clicks and consumed by the download button.
                    img_pid = gr.Textbox(label="Product ID (auto-filled)", visible=False)
                    img_click_status = gr.Markdown("")
                    btn_download_img = gr.Button("Download Image by Geolocation", variant="secondary")
                    search_img_btn = gr.Button("Search by Image", variant="primary")
                with gr.TabItem("Location Search") as tab_location:
                    gr.Markdown("Search using **SatCLIP** location encoder.")
                    gr.Markdown("### Click Map or Enter Coordinates")
                    btn_reset_map_loc = gr.Button("🔄 Reset Map to Global View", variant="secondary", size="sm")
                    with gr.Row():
                        lat_input = gr.Number(label="Latitude", value=30.0, interactive=True)
                        lon_input = gr.Number(label="Longitude", value=120.0, interactive=True)
                    loc_pid = gr.Textbox(label="Product ID (auto-filled)", visible=False)
                    loc_click_status = gr.Markdown("")
                    gr.Examples(
                        examples=[
                            [30.32, 120.15],
                            [40.7128, -74.0060],
                            [24.65, 46.71],
                            [-3.4653, -62.2159],
                            [64.4, 16.8]
                        ],
                        inputs=[lat_input, lon_input],
                        label="Location Examples"
                    )
                    search_loc_btn = gr.Button("Search by Location", variant="primary")
            # Shared by all three searches; handlers divide the value by 1000 for the searched fraction.
            threshold_slider = gr.Slider(minimum=1, maximum=30, value=7, step=1, label="Top Percentage (‰)")
            status_output = gr.Textbox(label="Status", lines=10)
            save_btn = gr.Button("Download Result")
            download_file = gr.File(label="Zipped Results", height=40)
        with gr.Column(scale=6):
            plot_map = gr.Image(
                label="Geographical Distribution",
                type="pil",
                interactive=False,
                height=400,
                width=800,
                visible=True
            )
            plot_map_interactive = gr.Plot(
                label="Geographical Distribution (Interactive)",
                visible=False
            )
            results_plot = gr.Image(label="Top 5 Matched Images", type="pil")
            gallery_images = gr.Gallery(label="Top Retrieved Images (Zoom)", columns=3, height="auto")
    # current_fig holds what the "Download Result" button saves; map_data_state
    # holds the points currently drawn on the static map (used for click snapping).
    current_fig = gr.State()
    map_data_state = gr.State()
    # Initial Load
    demo.load(fn=get_initial_plot, outputs=[plot_map, current_fig, map_data_state, plot_map_interactive])
    # Reset Map Buttons
    btn_reset_map_img.click(
        fn=reset_to_global_map,
        outputs=[plot_map, current_fig, map_data_state]
    )
    btn_reset_map_loc.click(
        fn=reset_to_global_map,
        outputs=[plot_map, current_fig, map_data_state]
    )
    # Map Click Event - updates Image Search coordinates
    plot_map.select(
        fn=handle_map_click,
        inputs=[map_data_state],
        outputs=[img_lat, img_lon, img_pid, img_click_status]
    )
    # Map Click Event - also updates Location Search coordinates
    plot_map.select(
        fn=handle_map_click,
        inputs=[map_data_state],
        outputs=[lat_input, lon_input, loc_pid, loc_click_status]
    )
    # Download Image by Geolocation
    btn_download_img.click(
        fn=download_image_by_location,
        inputs=[img_lat, img_lon, img_pid, model_selector_img],
        outputs=[image_input, img_click_status]
    )
    # Search Event (Text)
    search_btn.click(
        fn=search_text,
        inputs=[query_input, threshold_slider, model_selector_text],
        outputs=[plot_map_interactive, gallery_images, status_output, results_plot, current_fig, map_data_state, plot_map]
    )
    # Search Event (Image)
    search_img_btn.click(
        fn=search_image,
        inputs=[image_input, threshold_slider, model_selector_img],
        outputs=[plot_map_interactive, gallery_images, status_output, results_plot, current_fig, map_data_state, plot_map]
    )
    # Search Event (Location)
    search_loc_btn.click(
        fn=search_location,
        inputs=[lat_input, lon_input, threshold_slider],
        outputs=[plot_map_interactive, gallery_images, status_output, results_plot, current_fig, map_data_state, plot_map]
    )
    # Save Event
    save_btn.click(
        fn=save_plot,
        inputs=[current_fig],
        outputs=[download_file]
    )

    # Tab Selection Events: switching tabs always shows the static map and hides the Plotly one.
    def show_static_map():
        return gr.update(visible=True), gr.update(visible=False)

    tab_text.select(fn=show_static_map, outputs=[plot_map, plot_map_interactive])
    tab_image.select(fn=show_static_map, outputs=[plot_map, plot_map_interactive])
    tab_location.select(fn=show_static_map, outputs=[plot_map, plot_map_interactive])
if __name__ == "__main__":
    # Listen on all interfaces, port 7860, without creating a public share link.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860, "share": False}
    demo.launch(**launch_options)