VoyagerXvoyagerx's picture
test: print config before loading models
8962d56
import gradio as gr
import torch
import time
import os
import tempfile
import zipfile
import numpy as np
import pandas as pd
from concurrent.futures import ThreadPoolExecutor, as_completed
# Import custom modules
from models.siglip_model import SigLIPModel
from models.satclip_model import SatCLIPModel
from models.farslip_model import FarSLIPModel
from models.dinov2_model import DINOv2Model
from models.load_config import load_and_process_config
from visualize import format_results_for_gallery, plot_top5_overview, plot_location_distribution, plot_global_map_static, plot_geographic_distribution
from data_utils import download_and_process_image, get_esri_satellite_image, get_placeholder_image
from PIL import Image as PILImage
from PIL import ImageDraw, ImageFont
# Configuration: choose the compute device once; every model below uses it.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Running on device: {device}")
# Load the processed configuration (may be falsy; model setup below checks for that)
# and echo it so the effective settings show up in the logs.
config = load_and_process_config()
print(config)
# Initialize Models
# Each model is constructed independently: a failure is logged and that model is
# simply absent from `models`, so the rest of the app can still start.
print("Initializing models...")
models = {}


def _init_model(display_name, model_cls, section, extra_keys=()):
    """Construct one model and register it in ``models`` under ``display_name``.

    When ``config`` has a ``section`` entry, its ``ckpt_path``, any
    ``extra_keys`` and ``embedding_path`` values (read in that order, missing
    keys becoming None) are passed as keyword arguments; otherwise the class is
    built with only ``device``. Any exception is printed and swallowed.
    """
    try:
        if config and section in config:
            section_cfg = config[section]
            kwargs = {key: section_cfg.get(key) for key in ('ckpt_path', *extra_keys, 'embedding_path')}
            models[display_name] = model_cls(device=device, **kwargs)
        else:
            models[display_name] = model_cls(device=device)
    except Exception as e:
        print(f"Failed to load {display_name}: {e}")


_init_model('DINOv2', DINOv2Model, 'dinov2')
_init_model('SigLIP', SigLIPModel, 'siglip', ('tokenizer_path',))
_init_model('SatCLIP', SatCLIPModel, 'satclip')
_init_model('FarSLIP', FarSLIPModel, 'farslip', ('model_name',))
def get_active_model(model_name):
    """Look up a loaded model by its display name.

    Returns:
        ``(model, None)`` when ``model_name`` is registered in the module-level
        ``models`` dict, otherwise ``(None, "Model <name> not loaded.")``.
    """
    if model_name in models:
        return models[model_name], None
    return None, f"Model {model_name} not loaded."
def combine_images(img1, img2):
    """Stack ``img1`` above ``img2`` on a white RGB canvas.

    Both images are resized to the wider of the two widths, keeping their
    aspect ratios (heights truncated to int). If either image is None the
    other one is returned unchanged.
    """
    if img1 is None:
        return img2
    if img2 is None:
        return img1
    top_w, top_h = img1.size
    bottom_w, bottom_h = img2.size
    width = max(top_w, bottom_w)
    top_height = int(top_h * width / top_w)
    bottom_height = int(bottom_h * width / bottom_w)
    top = img1.resize((width, top_height))
    bottom = img2.resize((width, bottom_height))
    canvas = PILImage.new('RGB', (width, top_height + bottom_height), (255, 255, 255))
    canvas.paste(top, (0, 0))
    canvas.paste(bottom, (0, top_height))
    return canvas
def create_text_image(text, size=(384, 384)):
    """Render a text query as a simple placeholder image.

    Each comma-separated part of ``text`` is drawn on its own line in black,
    starting 100 px from the top with a 50 px line pitch. A blue
    "Text Query" caption follows after an extra 50 px gap. The canvas is
    light grey and of ``size`` pixels. DejaVuSans at 40 px is used when
    available, otherwise Pillow's default font.

    Args:
        text: Query text; commas act as line breaks.
        size: ``(width, height)`` of the output image.

    Returns:
        A new RGB ``PIL.Image``.
    """
    img = PILImage.new('RGB', size, color=(240, 240, 240))
    draw = ImageDraw.Draw(img)
    try:
        font = ImageFont.truetype("DejaVuSans.ttf", 40)
    except (OSError, ImportError):
        # OSError: font file not found/unreadable; ImportError: Pillow built
        # without FreeType. The previous bare `except:` also swallowed
        # KeyboardInterrupt/SystemExit.
        font = ImageFont.load_default()
    margin = 20
    offset = 100
    for line in text.split(','):
        draw.text((margin, offset), line.strip(), font=font, fill=(0, 0, 0))
        offset += 50
    # Leave one blank line between the query lines and the caption.
    draw.text((margin, offset + 50), "Text Query", font=font, fill=(0, 0, 255))
    return img
def fetch_top_k_images(top_indices, probs, df_embed, query_text=None, max_workers=5):
    """Download the dataset images for the given top-ranked rows, in parallel.

    Each row's ``product_id`` is fetched with ``download_and_process_image``.
    If that returns no image *or raises*, an Esri satellite image centred on
    the row's coordinates is used instead. Previously an exception from the
    download skipped the fallback and silently dropped the row.

    Args:
        top_indices: Positional (``iloc``) row indices into ``df_embed``.
        probs: Per-row similarity scores, indexable by the same positions.
        df_embed: DataFrame with ``product_id``, ``centre_lat`` and
            ``centre_lon`` columns.
        query_text: Optional query label forwarded to the Esri fallback.
        max_workers: Size of the download thread pool (default 5).

    Returns:
        A list of dicts with keys ``image_384``, ``image_full``, ``score``,
        ``lat``, ``lon`` and ``id``, sorted by ``score`` descending. Rows for
        which both the download and the fallback fail are omitted.
    """
    results = []
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_idx = {}
        for idx in top_indices:
            pid = df_embed.iloc[idx]['product_id']
            future = executor.submit(download_and_process_image, pid, df_source=df_embed, verbose=False)
            future_to_idx[future] = idx
        for future in as_completed(future_to_idx):
            idx = future_to_idx[future]
            try:
                row = df_embed.iloc[idx]
                try:
                    img_384, img_full = future.result()
                except Exception as e:
                    print(f"Download raised for idx {idx}: {e}")
                    img_384, img_full = None, None
                if img_384 is None:
                    print(f"Download failed for idx {idx}, falling back to Esri...")
                    # NOTE(review): rank is always 0 here (as before); confirm
                    # whether get_esri_satellite_image expects the list position.
                    img_384 = get_esri_satellite_image(row['centre_lat'], row['centre_lon'], score=probs[idx], rank=0, query=query_text)
                    img_full = img_384
                results.append({
                    'image_384': img_384,
                    'image_full': img_full,
                    'score': probs[idx],
                    'lat': row['centre_lat'],
                    'lon': row['centre_lon'],
                    'id': row['product_id']
                })
            except Exception as e:
                print(f"Error fetching image for idx {idx}: {e}")
    # Futures complete in arbitrary order, so restore best-first ordering.
    results.sort(key=lambda x: x['score'], reverse=True)
    return results
def get_all_results_metadata(model, filtered_indices, probs):
    """Return id/lat/lon/score records for every retrieved row, best first.

    Args:
        model: Object exposing ``df_embed``, a DataFrame with ``product_id``,
            ``centre_lat`` and ``centre_lon`` columns.
        filtered_indices: Positional row indices of the retrieved rows. Any
            integer sequence (ndarray or list) is accepted.
        probs: Per-row similarity scores indexable by those positions.

    Returns:
        A list of dicts with keys ``id``, ``lat``, ``lon`` and ``score``,
        sorted by score descending. Empty when there are no indices.
    """
    if filtered_indices is None or len(filtered_indices) == 0:
        return []
    # Fancy indexing (indices[order]) needs an ndarray; a plain list raised TypeError.
    indices = np.asarray(filtered_indices)
    filtered_scores = probs[indices]
    sorted_order = np.argsort(filtered_scores)[::-1]
    sorted_indices = indices[sorted_order]
    df_results = model.df_embed.iloc[sorted_indices].copy()
    df_results['score'] = probs[sorted_indices]
    df_results = df_results.rename(columns={'product_id': 'id', 'centre_lat': 'lat', 'centre_lon': 'lon'})
    return df_results[['id', 'lat', 'lon', 'score']].to_dict('records')
def search_text(query, threshold, model_name):
    """Retrieve tiles matching a natural-language query, streaming progress.

    Generator: every yield is a 7-tuple for the UI outputs
    (plot_map_interactive, gallery_images, status_output, results_plot,
    current_fig, map_data_state, plot_map).

    Args:
        query: Free-text query passed to ``model.encode_text``.
        threshold: Slider value in per-mille; the model receives ``threshold / 1000``.
        model_name: Key of the model in the ``models`` registry.
    """
    def status_only(message):
        # Only the status box gets a value; every other output is cleared.
        return None, None, message, None, None, None, None

    model, error = get_active_model(model_name)
    if error:
        yield status_only(error)
        return
    if not query:
        yield status_only("Please enter a query.")
        return

    steps = ["Encoding text...", "Retrieving similar images...",
             "Downloading images...", "Generating visualizations..."]

    def progress(current):
        # Completed steps get a check mark; the current step is listed last.
        return "\n".join([f"{label} ✓" for label in steps[:current]] + [steps[current]])

    try:
        timings = {}

        yield status_only(progress(0))
        started = time.time()
        text_features = model.encode_text(query)
        timings['Encoding'] = time.time() - started
        if text_features is None:
            yield status_only("Model does not support text encoding or is not initialized.")
            return

        yield status_only(progress(1))
        fraction = threshold / 1000.0
        started = time.time()
        probs, filtered_indices, top_indices = model.search(text_features, top_percent=fraction)
        timings['Retrieval'] = time.time() - started
        if probs is None:
            yield status_only("Search failed (embeddings missing?).")
            return

        # The distribution map is deliberately left out of the timings.
        df_embed = model.df_embed
        geo_dist_map, df_filtered = plot_geographic_distribution(df_embed, probs, fraction, title=f'Similarity to "{query}" ({model_name})')

        def with_map(message):
            # Keep the static map visible and publish the filtered rows for click snapping.
            return (gr.update(visible=False), None, message, None, None,
                    df_filtered, gr.update(value=geo_dist_map, visible=True))

        yield with_map(progress(2))
        started = time.time()
        results = fetch_top_k_images(top_indices[:10], probs, df_embed, query_text=query)
        timings['Download'] = time.time() - started

        yield with_map(progress(3))
        started = time.time()
        fig_results = plot_top5_overview(None, results, query_info=query)
        gallery_items = format_results_for_gallery(results)
        timings['Visualization'] = time.time() - started

        timing_str = ", ".join(f"{name} {secs:.1f}s" for name, secs in timings.items()) + "\n\n"
        status_msg = timing_str + generate_status_msg(len(filtered_indices), threshold / 100.0, results)
        results_txt = format_results_to_text(get_all_results_metadata(model, filtered_indices, probs))
        yield (gr.update(visible=False), gallery_items, status_msg, fig_results,
               [geo_dist_map, fig_results, results_txt], df_filtered,
               gr.update(value=geo_dist_map, visible=True))
    except Exception as e:
        import traceback
        traceback.print_exc()
        yield status_only(f"Error: {str(e)}")
def search_image(image_input, threshold, model_name):
    """Retrieve tiles similar to an input image, streaming progress.

    Generator: every yield is a 7-tuple for the UI outputs
    (plot_map_interactive, gallery_images, status_output, results_plot,
    current_fig, map_data_state, plot_map).

    Args:
        image_input: Query image passed to ``model.encode_image`` (the UI
            supplies a PIL image).
        threshold: Slider value in per-mille; the model receives ``threshold / 1000``.
        model_name: Key of the model in the ``models`` registry.
    """
    model, error = get_active_model(model_name)
    if error:
        yield None, None, error, None, None, None, None
        return
    if image_input is None:
        yield None, None, "Please upload an image.", None, None, None, None
        return
    try:
        timings = {}
        # 1. Encode Image
        yield None, None, "Encoding image...", None, None, None, None
        t0 = time.time()
        image_features = model.encode_image(image_input)
        timings['Encoding'] = time.time() - t0
        if image_features is None:
            yield None, None, "Model does not support image encoding.", None, None, None, None
            return
        # 2. Search
        yield None, None, "Encoding image... ✓\nRetrieving similar images...", None, None, None, None
        t0 = time.time()
        probs, filtered_indices, top_indices = model.search(image_features, top_percent=threshold/1000.0)
        timings['Retrieval'] = time.time() - t0
        # Same guard as the text search: without it a missing-embeddings
        # failure surfaced later as an obscure plotting error.
        if probs is None:
            yield None, None, "Search failed (embeddings missing?).", None, None, None, None
            return
        # Show geographic distribution (not timed)
        df_embed = model.df_embed
        geo_dist_map, df_filtered = plot_geographic_distribution(df_embed, probs, threshold/1000.0, title=f'Similarity to Input Image ({model_name})')
        # 3. Download Images
        yield gr.update(visible=False), None, "Encoding image... ✓\nRetrieving similar images... ✓\nDownloading images...", None, None, df_filtered, gr.update(value=geo_dist_map, visible=True)
        t0 = time.time()
        top_indices = top_indices[:6]
        results = fetch_top_k_images(top_indices, probs, df_embed, query_text="Image Query")
        timings['Download'] = time.time() - t0
        # 4. Visualize - keep geo_dist_map visible
        yield gr.update(visible=False), None, "Encoding image... ✓\nRetrieving similar images... ✓\nDownloading images... ✓\nGenerating visualizations...", None, None, df_filtered, gr.update(value=geo_dist_map, visible=True)
        t0 = time.time()
        fig_results = plot_top5_overview(image_input, results, query_info="Image Query")
        gallery_items = format_results_for_gallery(results)
        timings['Visualization'] = time.time() - t0
        # 5. Generate Final Status
        timing_str = f"Encoding {timings['Encoding']:.1f}s, Retrieval {timings['Retrieval']:.1f}s, Download {timings['Download']:.1f}s, Visualization {timings['Visualization']:.1f}s\n\n"
        status_msg = timing_str + generate_status_msg(len(filtered_indices), threshold/100.0, results)
        all_results = get_all_results_metadata(model, filtered_indices, probs)
        # The exported text report is capped at the 50 best matches.
        results_txt = format_results_to_text(all_results[:50])
        yield gr.update(visible=False), gallery_items, status_msg, fig_results, [geo_dist_map, fig_results, results_txt], df_filtered, gr.update(value=geo_dist_map, visible=True)
    except Exception as e:
        import traceback
        traceback.print_exc()
        yield None, None, f"Error: {str(e)}", None, None, None, None
def search_location(lat, lon, threshold):
    """Search the SatCLIP embedding index with an encoded geolocation.

    Generator: every yield is a 7-tuple for the UI outputs
    (plot_map_interactive, gallery_images, status_output, results_plot,
    current_fig, map_data_state, plot_map).

    Args:
        lat, lon: Query coordinates in degrees (anything ``float()`` accepts).
        threshold: Slider value in per-mille. As in the text and image
            searches, the model receives ``threshold / 1000``.
    """
    model_name = "SatCLIP"
    if lat is None or lon is None:
        yield None, None, "Please specify coordinates first.", None, None, None, None
        return
    model, error = get_active_model(model_name)
    if error:
        yield None, None, error, None, None, None, None
        return
    try:
        timings = {}
        # 1. Encode Location
        yield None, None, "Encoding location...", None, None, None, None
        query_lat = float(lat)
        query_lon = float(lon)
        # Per-mille slider -> fraction. Bug fix: retrieval previously used
        # threshold / 100.0 (10x the requested share) while the map below
        # used threshold / 1000.0.
        top_fraction = threshold / 1000.0
        t0 = time.time()
        loc_features = model.encode_location(query_lat, query_lon)
        timings['Encoding'] = time.time() - t0
        if loc_features is None:
            yield None, None, "Location encoding failed.", None, None, None, None
            return
        # 2. Search
        yield None, None, "Encoding location... ✓\nRetrieving similar images...", None, None, None, None
        t0 = time.time()
        probs, filtered_indices, top_indices = model.search(loc_features, top_percent=top_fraction)
        timings['Retrieval'] = time.time() - t0
        if probs is None:
            yield None, None, "Search failed (embeddings missing?).", None, None, None, None
            return
        # 3. Generate Distribution Map (not timed)
        yield None, None, "Encoding location... ✓\nRetrieving similar images... ✓\nGenerating distribution map...", None, None, None, None
        df_embed = model.df_embed
        geo_dist_map, df_filtered = plot_geographic_distribution(df_embed, probs, top_fraction, title=f'Similarity to Location ({lat}, {lon})')
        # 4. Download Images
        yield gr.update(visible=False), None, "Encoding location... ✓\nRetrieving similar images... ✓\nGenerating distribution map... ✓\nDownloading images...", None, None, df_filtered, gr.update(value=geo_dist_map, visible=True)
        t0 = time.time()
        results = fetch_top_k_images(top_indices[:6], probs, df_embed, query_text=f"Loc: {lat},{lon}")
        # Query tile: the dataset tile nearest to the query point (squared
        # lat/lon distance), or a placeholder if that cannot be fetched.
        query_tile = None
        try:
            lats = pd.to_numeric(df_embed['centre_lat'], errors='coerce')
            lons = pd.to_numeric(df_embed['centre_lon'], errors='coerce')
            dists = (lats - query_lat)**2 + (lons - query_lon)**2
            nearest_idx = dists.idxmin()
            pid = df_embed.loc[nearest_idx, 'product_id']
            query_tile, _ = download_and_process_image(pid, df_source=df_embed, verbose=False)
        except Exception as e:
            print(f"Error fetching nearest MajorTOM image: {e}")
        if query_tile is None:
            query_tile = get_placeholder_image(f"Query Location\n({lat}, {lon})")
        timings['Download'] = time.time() - t0
        # 5. Visualize - keep geo_dist_map visible
        yield gr.update(visible=False), None, "Encoding location... ✓\nRetrieving similar images... ✓\nGenerating distribution map... ✓\nDownloading images... ✓\nGenerating visualizations...", None, None, df_filtered, gr.update(value=geo_dist_map, visible=True)
        t0 = time.time()
        fig_results = plot_top5_overview(query_tile, results, query_info=f"Loc: {lat},{lon}")
        gallery_items = format_results_for_gallery(results)
        timings['Visualization'] = time.time() - t0
        # 6. Generate Final Status
        timing_str = f"Encoding {timings['Encoding']:.1f}s, Retrieval {timings['Retrieval']:.1f}s, Download {timings['Download']:.1f}s, Visualization {timings['Visualization']:.1f}s\n\n"
        status_msg = timing_str + generate_status_msg(len(filtered_indices), threshold/100.0, results)
        all_results = get_all_results_metadata(model, filtered_indices, probs)
        results_txt = format_results_to_text(all_results)
        yield gr.update(visible=False), gallery_items, status_msg, fig_results, [geo_dist_map, fig_results, results_txt], df_filtered, gr.update(value=geo_dist_map, visible=True)
    except Exception as e:
        import traceback
        traceback.print_exc()
        yield None, None, f"Error: {str(e)}", None, None, None, None
def generate_status_msg(count, threshold, results):
    """Build the human-readable search summary.

    ``threshold`` is multiplied by 100 and shown with a per-mille sign (callers
    pass ``slider / 100`` so the slider value is printed). Only the first three
    results are listed, although the header reports ``len(results)``.
    """
    header = f"Found {count} matches in top {threshold*100:.0f}‰.\n\nTop {len(results)} similar images:\n"
    listed = [
        f"{rank}. Product ID: {res['id']}, Location: ({res['lat']:.4f}, {res['lon']:.4f}), Score: {res['score']:.4f}\n"
        for rank, res in enumerate(results[:3], start=1)
    ]
    return header + "".join(listed)
def get_initial_plot():
    """Render the static global map shown when the app first loads.

    The embedding table of the first loaded model that has one is plotted,
    preferring DINOv2, then SigLIP, then any other loaded model. Previously
    this raised KeyError when neither DINOv2 nor SigLIP was loaded, and passed
    a None table when SigLIP had no embeddings.

    Returns:
        Values for (plot_map, current_fig, map_data_state, plot_map_interactive).
        When no model has embeddings, the map is empty and the figure and
        point data are None.
    """
    preferred = ['DINOv2', 'SigLIP']
    candidates = preferred + [name for name in models if name not in preferred]
    df_source = next(
        (models[name].df_embed for name in candidates
         if name in models and getattr(models[name], 'df_embed', None) is not None),
        None,
    )
    if df_source is None:
        print("No loaded model has embeddings; showing an empty map.")
        return gr.update(value=None, visible=True), None, None, gr.update(visible=False)
    img, df_vis = plot_global_map_static(df_source)
    return gr.update(value=img, visible=True), [img], df_vis, gr.update(visible=False)
def handle_map_click(evt: gr.SelectData, df_vis):
    """Translate a click on the static world-map image into coordinates.

    The pixel -> lon/lat conversion assumes a 3000x1500 px image with fixed
    axes margins and a 2:1 (360 x 180 degree) map area centred inside the axes.
    NOTE(review): these constants must match the rendered map figures; confirm
    when the plotting code changes.

    If ``df_vis`` is given, the click snaps to its nearest point (squared
    lat/lon distance) when that distance is below 25, i.e. within 5 degrees.
    The point's ``product_id`` is returned too.

    Returns:
        ``(lat, lon, product_id, status_markdown)``. The first three are None
        when there is no event, the click is outside the map, or an error occurs.
    """
    if evt is None:
        return None, None, None, "No point selected."
    try:
        px, py = evt.index[0], evt.index[1]

        # Rendered image size in pixels.
        canvas_w = 3000
        canvas_h = 1500
        # Axes margins, originally tuned for a 4000x2000 canvas, scaled by 0.75.
        margin_scale = 0.75
        margin_left = 110 * margin_scale
        margin_right = 110 * margin_scale
        margin_top = 100 * margin_scale
        margin_bottom = 67 * margin_scale
        axes_w = canvas_w - margin_left - margin_right
        axes_h = canvas_h - margin_top - margin_bottom

        # The map keeps a 2:1 aspect ratio, so it is centred inside the axes
        # with padding along whichever dimension has spare room.
        map_aspect = 360.0 / 180.0
        if axes_w / axes_h > map_aspect:
            map_w = axes_h * map_aspect
            map_h = axes_h
            pad_x = (axes_w - map_w) / 2
            pad_y = 0
        else:
            map_w = axes_w
            map_h = axes_w / map_aspect
            pad_x = 0
            pad_y = (axes_h - map_h) / 2

        # Click position relative to the top-left corner of the axes.
        ax_x = px - margin_left
        ax_y = py - margin_top
        if (ax_x < pad_x or ax_x > pad_x + map_w or
                ax_y < pad_y or ax_y > pad_y + map_h):
            return None, None, None, "Click outside map area. Please click on the map."

        # Fraction across the map (0 = west / north edge), clamped to [0, 1].
        frac_x = max(0, min(1, (ax_x - pad_x) / map_w))
        frac_y = max(0, min(1, (ax_y - pad_y) / map_h))
        lon = frac_x * 360 - 180
        lat = 90 - frac_y * 180

        pid = ""
        if df_vis is not None:
            sq_dist = (df_vis['centre_lat'] - lat)**2 + (df_vis['centre_lon'] - lon)**2
            nearest = sq_dist.idxmin()
            nearest_row = df_vis.loc[nearest]
            if sq_dist[nearest] < 25:
                lat = nearest_row['centre_lat']
                lon = nearest_row['centre_lon']
                pid = nearest_row['product_id']
    except Exception as e:
        print(f"Error handling click: {e}")
        import traceback
        traceback.print_exc()
        return None, None, None, f"Error: {e}"
    return lat, lon, pid, f"Selected Point: ({lat:.4f}, {lon:.4f})"
def download_image_by_location(lat, lon, pid, model_name):
    """Download the dataset tile at, or nearest to, the given coordinates.

    Args:
        lat, lon: Query coordinates (anything ``float()`` accepts).
        pid: Product ID auto-filled by a map click, or empty. It is only used
            when that product's row in the model's ``df_embed`` lies at
            (lat, lon). The hidden pid box is refreshed only on map clicks, so
            after the coordinates are edited by hand it would otherwise point
            at the previously clicked tile. Otherwise the product nearest to
            (lat, lon) is used.
        model_name: Key of the model whose ``df_embed`` is searched.

    Returns:
        ``(image, status_message)``. ``image`` is None on failure.
    """
    if lat is None or lon is None:
        return None, "Please specify coordinates first."
    model, error = get_active_model(model_name)
    if error:
        return None, error
    try:
        lat = float(lat)
        lon = float(lon)
        df = model.df_embed
        lats = pd.to_numeric(df['centre_lat'], errors='coerce')
        lons = pd.to_numeric(df['centre_lon'], errors='coerce')
        if pid:
            # Map clicks copy the product's exact coordinates, so a tiny
            # tolerance is enough to recognise a still-valid pid.
            coord_tol = 1e-6
            same_pid = df['product_id'] == pid
            at_location = ((lats - lat).abs() <= coord_tol) & ((lons - lon).abs() <= coord_tol)
            if not (same_pid & at_location).any():
                pid = ""
        if not pid:
            dists = (lats - lat)**2 + (lons - lon)**2
            nearest_idx = dists.idxmin()
            pid = df.loc[nearest_idx, 'product_id']
        img_384, _ = download_and_process_image(pid, df_source=df, verbose=True)
        if img_384 is None:
            return None, f"Failed to download image for location ({lat:.4f}, {lon:.4f})"
        return img_384, f"Downloaded image at ({lat:.4f}, {lon:.4f})"
    except Exception as e:
        import traceback
        traceback.print_exc()
        return None, f"Error: {str(e)}"
def reset_to_global_map():
    """Reset the map to the initial global distribution view.

    The embedding table of the first loaded model that has one is plotted,
    preferring DINOv2, then SigLIP, then any other loaded model. Previously
    this raised KeyError when neither DINOv2 nor SigLIP was loaded.

    Returns:
        Values for (plot_map, current_fig, map_data_state). When no model has
        embeddings, the map is cleared and the other two are None.
    """
    preferred = ['DINOv2', 'SigLIP']
    candidates = preferred + [name for name in models if name not in preferred]
    df_source = next(
        (models[name].df_embed for name in candidates
         if name in models and getattr(models[name], 'df_embed', None) is not None),
        None,
    )
    if df_source is None:
        print("No loaded model has embeddings; showing an empty map.")
        return gr.update(value=None, visible=True), None, None
    img, df_vis = plot_global_map_static(df_source)
    return gr.update(value=img, visible=True), [img], df_vis
def format_results_to_text(results):
    """Render retrieval records as the plain-text report included in downloads.

    Each record needs ``id``, ``lat``, ``lon`` and ``score`` keys. Returns
    "No results found." for an empty or None ``results``.
    """
    if not results:
        return "No results found."
    rule = "-" * 30
    parts = [f"Top {len(results)} Retrieval Results\n", "=" * 30 + "\n\n"]
    for rank, res in enumerate(results, start=1):
        parts.append(
            f"Rank: {rank}\n"
            f"Product ID: {res['id']}\n"
            f"Location: Latitude {res['lat']:.6f}, Longitude {res['lon']:.6f}\n"
            f"Similarity Score: {res['score']:.6f}\n"
            f"{rule}\n"
        )
    return "".join(parts)
def save_plot(figs):
    """Write the current results to a temporary file for download.

    ``figs`` may be:
      * a single PIL image, saved as PNG;
      * a list/tuple holding exactly one PIL image, saved as PNG;
      * any other non-empty list/tuple ``[map_img, results_img, results_txt]``,
        zipped as ``map_distribution.png``, ``retrieval_results.png`` and
        ``results.txt`` (None entries are skipped);
      * anything else, treated as a Plotly figure and written as HTML.

    Returns:
        The path of the written file, or None when ``figs`` is None or empty,
        or saving fails.
    """
    import io

    if figs is None:
        return None

    def _png_tempfile(img):
        fd, path = tempfile.mkstemp(suffix='.png', prefix='earth_explorer_map_')
        os.close(fd)
        img.save(path)
        return path

    def _png_bytes(img):
        buf = io.BytesIO()
        img.save(buf, format='PNG')
        return buf.getvalue()

    try:
        if isinstance(figs, PILImage.Image):
            return _png_tempfile(figs)
        if isinstance(figs, (list, tuple)):
            if not figs:
                return None
            if len(figs) == 1 and isinstance(figs[0], PILImage.Image):
                return _png_tempfile(figs[0])
            fd, zip_path = tempfile.mkstemp(suffix='.zip', prefix='earth_explorer_results_')
            os.close(fd)
            # Members are written straight from memory. Staging them under
            # fixed names in the shared temp dir could collide across
            # overlapping saves or processes, and left stray files behind.
            with zipfile.ZipFile(zip_path, 'w') as zipf:
                if figs[0] is not None:
                    zipf.writestr('map_distribution.png', _png_bytes(figs[0]))
                if len(figs) > 1 and figs[1] is not None:
                    zipf.writestr('retrieval_results.png', _png_bytes(figs[1]))
                if len(figs) > 2 and figs[2] is not None:
                    zipf.writestr('results.txt', figs[2].encode('utf-8'))
            return zip_path
        # Fallback for Plotly figure (if any)
        fd, path = tempfile.mkstemp(suffix='.html', prefix='earth_explorer_plot_')
        os.close(fd)
        figs.write_html(path)
        return path
    except Exception as e:
        print(f"Error saving: {e}")
        return None
# Gradio Blocks Interface
# Layout: a left control column (search tabs, shared threshold slider, status
# box, download controls) and a right output column (map, results figure,
# gallery). Component creation order inside each `with` context defines the
# on-screen order, so statements here must not be reordered.
# NOTE(review): nesting below was reconstructed from a copy whose indentation
# was lost; confirm which components sit inside each gr.Row().
with gr.Blocks(title="EarthEmbeddingExplorer") as demo:
    gr.Markdown("# EarthEmbeddingExplorer")
    # NOTE(review): the intro text contains a typo ("a simple a click").
    gr.HTML("""
<div style="font-size: 1.2em;">
EarthEmbeddingExplorer is a tool that allows you to search for satellite images of the Earth using natural language descriptions, images, geolocations, or a simple a click on the map. For example, you can type "tropical rainforest" or "coastline with a city," and the system will find locations on Earth that match your description. It then visualizes these locations on a world map and displays the top matching images.
</div>
<div style="display: flex; gap: 0.2em; align-items: center; justify-content: center;">
<a href="https://www.modelscope.cn/studios/VoyagerX/EarthExplorer"><img src="https://img.shields.io/badge/Open in ModelScope.cn-xGPU-624aff"></a>
<a href="https://modelscope.cn/datasets/VoyagerX/EarthEmbeddings"><img src="https://img.shields.io/badge/👾 MS-Dataset-624aff"></a>
<a href="https://huggingface.co/spaces/ML4Sustain/EarthExplorer"><img src="https://img.shields.io/badge/Open in HF Space-CPU-FFD21E"></a>
<a href="https://huggingface.co/datasets/ML4Sustain/EarthEmbeddings"><img src="https://img.shields.io/badge/🤗 HF-Dataset-FFD21E"></a>
<a href="https://huggingface.co/spaces/ML4Sustain/EarthExplorer/blob/main/Tutorial.md"> <img src="https://img.shields.io/badge/Tutorial-📖-007bff"> </a>
<a href="https://modelscope.cn/studios/VoyagerX/EarthExplorer/file/view/master/Tutorial_zh.md?status=1"> <img src="https://img.shields.io/badge/中文教程-📖-007bff"> </a>
</div>
""")
    with gr.Row():
        # Left column: inputs and controls.
        with gr.Column(scale=4):
            with gr.Tabs():
                with gr.TabItem("Text Search") as tab_text:
                    # Only SigLIP and FarSLIP are offered here (text search calls model.encode_text).
                    model_selector_text = gr.Dropdown(choices=["SigLIP", "FarSLIP"], value="FarSLIP", label="Model")
                    query_input = gr.Textbox(label="Query", placeholder="e.g., rainforest, glacier")
                    gr.Examples(
                        examples=[
                            ["a satellite image of a river around a city"],
                            ["a satellite image of a rainforest"],
                            ["a satellite image of a slum"],
                            ["a satellite image of a glacier"],
                            ["a satellite image of snow covered mountains"]
                        ],
                        inputs=[query_input],
                        label="Text Examples"
                    )
                    search_btn = gr.Button("Search by Text", variant="primary")
                with gr.TabItem("Image Search") as tab_image:
                    model_selector_img = gr.Dropdown(choices=["SigLIP", "FarSLIP", "SatCLIP", "DINOv2"], value="FarSLIP", label="Model")
                    gr.Markdown("### Option 1: Upload or Select Image")
                    image_input = gr.Image(type="pil", label="Upload Image")
                    gr.Examples(
                        examples=[
                            ["./examples/example1.png"],
                            ["./examples/example2.png"],
                            ["./examples/example3.png"]
                        ],
                        inputs=[image_input],
                        label="Image Examples"
                    )
                    # Option 2 fills image_input with the dataset tile at a clicked/typed location.
                    gr.Markdown("### Option 2: Click Map or Enter Coordinates")
                    btn_reset_map_img = gr.Button("🔄 Reset Map to Global View", variant="secondary", size="sm")
                    with gr.Row():
                        img_lat = gr.Number(label="Latitude", interactive=True)
                        img_lon = gr.Number(label="Longitude", interactive=True)
                    # Hidden; set by map clicks and read by the download button.
                    img_pid = gr.Textbox(label="Product ID (auto-filled)", visible=False)
                    img_click_status = gr.Markdown("")
                    btn_download_img = gr.Button("Download Image by Geolocation", variant="secondary")
                    search_img_btn = gr.Button("Search by Image", variant="primary")
                with gr.TabItem("Location Search") as tab_location:
                    # Location search always uses the SatCLIP model (see search_location).
                    gr.Markdown("Search using **SatCLIP** location encoder.")
                    gr.Markdown("### Click Map or Enter Coordinates")
                    btn_reset_map_loc = gr.Button("🔄 Reset Map to Global View", variant="secondary", size="sm")
                    with gr.Row():
                        lat_input = gr.Number(label="Latitude", value=30.0, interactive=True)
                        lon_input = gr.Number(label="Longitude", value=120.0, interactive=True)
                    # Hidden; set by map clicks but not used by the location search itself.
                    loc_pid = gr.Textbox(label="Product ID (auto-filled)", visible=False)
                    loc_click_status = gr.Markdown("")
                    gr.Examples(
                        examples=[
                            [30.32, 120.15],
                            [40.7128, -74.0060],
                            [24.65, 46.71],
                            [-3.4653, -62.2159],
                            [64.4, 16.8]
                        ],
                        inputs=[lat_input, lon_input],
                        label="Location Examples"
                    )
                    search_loc_btn = gr.Button("Search by Location", variant="primary")
            # Shared by all three searches; the value is in per-mille (‰).
            # NOTE(review): the label says "Percentage" but the unit is ‰.
            threshold_slider = gr.Slider(minimum=1, maximum=30, value=7, step=1, label="Top Percentage (‰)")
            status_output = gr.Textbox(label="Status", lines=10)
            save_btn = gr.Button("Download Result")
            download_file = gr.File(label="Zipped Results", height=40)
        # Right column: outputs.
        with gr.Column(scale=6):
            # Static map image; clicks on it are handled by handle_map_click.
            plot_map = gr.Image(
                label="Geographical Distribution",
                type="pil",
                interactive=False,
                height=400,
                width=800,
                visible=True
            )
            # Interactive Plotly plot; every handler that touches it keeps it hidden.
            plot_map_interactive = gr.Plot(
                label="Geographical Distribution (Interactive)",
                visible=False
            )
            results_plot = gr.Image(label="Top 5 Matched Images", type="pil")
            gallery_images = gr.Gallery(label="Top Retrieved Images (Zoom)", columns=3, height="auto")
    # Per-session state: current_fig holds what save_plot exports (an image
    # list, or [map, results figure, results text] after a search);
    # map_data_state holds the points map clicks snap to.
    current_fig = gr.State()
    map_data_state = gr.State()
    # Initial Load
    demo.load(fn=get_initial_plot, outputs=[plot_map, current_fig, map_data_state, plot_map_interactive])
    # Reset Map Buttons
    btn_reset_map_img.click(
        fn=reset_to_global_map,
        outputs=[plot_map, current_fig, map_data_state]
    )
    btn_reset_map_loc.click(
        fn=reset_to_global_map,
        outputs=[plot_map, current_fig, map_data_state]
    )
    # A map click is handled twice so it fills the coordinate fields of both
    # the Image Search and the Location Search tabs.
    # Map Click Event - updates Image Search coordinates
    plot_map.select(
        fn=handle_map_click,
        inputs=[map_data_state],
        outputs=[img_lat, img_lon, img_pid, img_click_status]
    )
    # Map Click Event - also updates Location Search coordinates
    plot_map.select(
        fn=handle_map_click,
        inputs=[map_data_state],
        outputs=[lat_input, lon_input, loc_pid, loc_click_status]
    )
    # Download Image by Geolocation
    btn_download_img.click(
        fn=download_image_by_location,
        inputs=[img_lat, img_lon, img_pid, model_selector_img],
        outputs=[image_input, img_click_status]
    )
    # The three search handlers are generators yielding 7-tuples in this exact output order.
    # Search Event (Text)
    search_btn.click(
        fn=search_text,
        inputs=[query_input, threshold_slider, model_selector_text],
        outputs=[plot_map_interactive, gallery_images, status_output, results_plot, current_fig, map_data_state, plot_map]
    )
    # Search Event (Image)
    search_img_btn.click(
        fn=search_image,
        inputs=[image_input, threshold_slider, model_selector_img],
        outputs=[plot_map_interactive, gallery_images, status_output, results_plot, current_fig, map_data_state, plot_map]
    )
    # Search Event (Location)
    search_loc_btn.click(
        fn=search_location,
        inputs=[lat_input, lon_input, threshold_slider],
        outputs=[plot_map_interactive, gallery_images, status_output, results_plot, current_fig, map_data_state, plot_map]
    )
    # Save Event
    save_btn.click(
        fn=save_plot,
        inputs=[current_fig],
        outputs=[download_file]
    )
    # Tab Selection Events
    def show_static_map():
        """Show the static map image and hide the interactive plot."""
        return gr.update(visible=True), gr.update(visible=False)
    tab_text.select(fn=show_static_map, outputs=[plot_map, plot_map_interactive])
    tab_image.select(fn=show_static_map, outputs=[plot_map, plot_map_interactive])
    tab_location.select(fn=show_static_map, outputs=[plot_map, plot_map_interactive])
if __name__ == "__main__":
    # Bind to all network interfaces on port 7860, without a Gradio share link.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860, "share": False}
    demo.launch(**launch_options)