|
|
import gradio as gr |
|
|
import torch |
|
|
import time |
|
|
import os |
|
|
import tempfile |
|
|
import zipfile |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
from concurrent.futures import ThreadPoolExecutor, as_completed |
|
|
|
|
|
|
|
|
from models.siglip_model import SigLIPModel |
|
|
from models.satclip_model import SatCLIPModel |
|
|
from models.farslip_model import FarSLIPModel |
|
|
from models.dinov2_model import DINOv2Model |
|
|
from models.load_config import load_and_process_config |
|
|
from visualize import format_results_for_gallery, plot_top5_overview, plot_location_distribution, plot_global_map_static, plot_geographic_distribution |
|
|
from data_utils import download_and_process_image, get_esri_satellite_image, get_placeholder_image |
|
|
from PIL import Image as PILImage |
|
|
from PIL import ImageDraw, ImageFont |
|
|
|
|
|
|
|
|
# Pick the compute device once; every model below is constructed on it.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Running on device: {device}")

# Load the model configuration and echo it to the console for debugging.
config = load_and_process_config()
print(config)
|
|
|
|
|
|
|
|
print("Initializing models...")

models = {}

# One entry per model: (display name used as the `models` key and in the UI,
# config section name, model class, config fields forwarded as kwargs).
# When the config has no section for a model, it is built with defaults.
_MODEL_SPECS = (
    ("DINOv2", "dinov2", DINOv2Model, ("ckpt_path", "embedding_path")),
    ("SigLIP", "siglip", SigLIPModel, ("ckpt_path", "tokenizer_path", "embedding_path")),
    ("SatCLIP", "satclip", SatCLIPModel, ("ckpt_path", "embedding_path")),
    ("FarSLIP", "farslip", FarSLIPModel, ("ckpt_path", "model_name", "embedding_path")),
)

for _display_name, _section_key, _model_cls, _fields in _MODEL_SPECS:
    # A model that fails to load is reported and simply left out of `models`;
    # the handlers report "not loaded" for it via get_active_model().
    try:
        if config and _section_key in config:
            _section = config[_section_key]
            _kwargs = {field: _section.get(field) for field in _fields}
            models[_display_name] = _model_cls(**_kwargs, device=device)
        else:
            models[_display_name] = _model_cls(device=device)
    except Exception as e:
        print(f"Failed to load {_display_name}: {e}")
|
|
|
|
|
def get_active_model(model_name):
    """Look up a loaded model by its display name.

    Returns:
        ``(model, None)`` when ``model_name`` is a key of the module-level
        ``models`` dict, otherwise ``(None, message)`` describing the problem
        (e.g. the model failed to load at start-up).
    """
    if model_name in models:
        return models[model_name], None
    return None, f"Model {model_name} not loaded."
|
|
|
|
|
def combine_images(img1, img2):
    """Stack two PIL images vertically on a white RGB canvas.

    Both images are rescaled to the larger of the two widths (keeping their
    aspect ratios) and ``img1`` is placed above ``img2``. If either argument
    is ``None``, the other one is returned unchanged.
    """
    if img1 is None:
        return img2
    if img2 is None:
        return img1

    target_w = max(img1.size[0], img2.size[0])
    # Height each image gets once scaled to the common width.
    heights = [int(im.size[1] * target_w / im.size[0]) for im in (img1, img2)]

    top = img1.resize((target_w, heights[0]))
    bottom = img2.resize((target_w, heights[1]))

    canvas = PILImage.new('RGB', (target_w, sum(heights)), (255, 255, 255))
    canvas.paste(top, (0, 0))
    canvas.paste(bottom, (0, heights[0]))
    return canvas
|
|
|
|
|
def create_text_image(text, size=(384, 384)):
    """Render a comma-separated text query onto a light-grey tile.

    Each comma-separated part is drawn on its own line (50 px apart, starting
    100 px from the top), followed by a blue "Text Query" caption.

    Args:
        text: query string; ``None`` is treated as an empty string.
        size: ``(width, height)`` of the generated image in pixels.

    Returns:
        A new RGB ``PIL.Image``.
    """
    img = PILImage.new('RGB', size, color=(240, 240, 240))
    draw = ImageDraw.Draw(img)

    # The TrueType font may not be installed (OSError) or FreeType support may
    # be missing (ImportError); fall back to PIL's built-in font in that case.
    # Catching only these keeps KeyboardInterrupt/SystemExit propagating.
    try:
        font = ImageFont.truetype("DejaVuSans.ttf", 40)
    except (OSError, ImportError):
        font = ImageFont.load_default()

    margin = 20
    offset = 100
    for line in (text or "").split(','):
        draw.text((margin, offset), line.strip(), font=font, fill=(0, 0, 0))
        offset += 50

    # Caption sits one extra line below the last text line.
    draw.text((margin, offset + 50), "Text Query", font=font, fill=(0, 0, 255))
    return img
|
|
|
|
|
def fetch_top_k_images(top_indices, probs, df_embed, query_text=None, max_workers=5):
    """Fetch imagery for the given rows concurrently, best score first.

    Each row's ``product_id`` is downloaded with ``download_and_process_image``.
    If the download returns no image *or raises*, an Esri satellite tile for
    the row's centre coordinates is used instead. A failed download therefore
    no longer silently drops a result.

    Args:
        top_indices: positional row indices into ``df_embed`` (used with ``iloc``).
        probs: similarity scores, indexable by the same positions.
        df_embed: dataframe with ``product_id``, ``centre_lat`` and ``centre_lon``.
        query_text: label forwarded to the Esri fallback renderer.
        max_workers: number of download threads (default 5).

    Returns:
        List of dicts with keys ``image_384``, ``image_full``, ``score``,
        ``lat``, ``lon`` and ``id``, sorted by descending ``score``. Rows for
        which both the download and the fallback fail are logged and omitted.
    """
    results = []

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_idx = {}
        for idx in top_indices:
            pid = df_embed.iloc[idx]['product_id']
            future = executor.submit(download_and_process_image, pid, df_source=df_embed, verbose=False)
            future_to_idx[future] = idx

        for future in as_completed(future_to_idx):
            idx = future_to_idx[future]
            try:
                row = df_embed.iloc[idx]

                # Treat a raising download like one that returned no image so
                # both paths share the Esri fallback below.
                try:
                    img_384, img_full = future.result()
                except Exception as e:
                    print(f"Download raised for idx {idx}: {e}")
                    img_384 = img_full = None

                if img_384 is None:
                    print(f"Download failed for idx {idx}, falling back to Esri...")
                    img_384 = get_esri_satellite_image(row['centre_lat'], row['centre_lon'], score=probs[idx], rank=0, query=query_text)
                    img_full = img_384

                results.append({
                    'image_384': img_384,
                    'image_full': img_full,
                    'score': probs[idx],
                    'lat': row['centre_lat'],
                    'lon': row['centre_lon'],
                    'id': row['product_id']
                })
            except Exception as e:
                print(f"Error fetching image for idx {idx}: {e}")

    # as_completed yields in completion order; restore score order.
    results.sort(key=lambda item: item['score'], reverse=True)
    return results
|
|
|
|
|
def get_all_results_metadata(model, filtered_indices, probs):
    """Build score-sorted metadata records for every retrieved row.

    Args:
        model: object whose ``df_embed`` dataframe has ``product_id``,
            ``centre_lat`` and ``centre_lon`` columns.
        filtered_indices: positional row indices (numpy array or any integer
            sequence) of the rows that passed the threshold; may be ``None``.
        probs: similarity scores indexable by those positions.

    Returns:
        A list of ``{'id', 'lat', 'lon', 'score'}`` dicts ordered by descending
        score, or ``[]`` when there is nothing to report.
    """
    if filtered_indices is None or len(filtered_indices) == 0:
        return []

    # Normalise to an ndarray so fancy indexing works for plain lists too.
    indices = np.asarray(filtered_indices)

    order = np.argsort(probs[indices])[::-1]
    sorted_indices = indices[order]

    df_results = model.df_embed.iloc[sorted_indices].copy()
    df_results['score'] = probs[sorted_indices]
    df_results = df_results.rename(columns={'product_id': 'id', 'centre_lat': 'lat', 'centre_lon': 'lon'})

    return df_results[['id', 'lat', 'lon', 'score']].to_dict('records')
|
|
|
|
|
def search_text(query, threshold, model_name):
    """Text-to-image retrieval, streamed to the UI as a generator.

    Every ``yield`` is a 7-tuple in the order of the UI outputs:
    (interactive map, gallery items, status text, results figure,
    current-figure state, map-data state, static map).

    Args:
        query: free-text description to search for.
        threshold: slider value in per-mille (‰); ``model.search`` receives
            ``threshold / 1000`` as ``top_percent``.
        model_name: key of the model in the module-level ``models`` dict.
    """
    model, error = get_active_model(model_name)
    if error:
        yield None, None, error, None, None, None, None
        return

    # Whitespace-only input is rejected like an empty query instead of being
    # sent to the text encoder.
    if not query or not query.strip():
        yield None, None, "Please enter a query.", None, None, None, None
        return

    try:
        timings = {}

        # 1) Encode the query text.
        yield None, None, "Encoding text...", None, None, None, None
        t0 = time.time()
        text_features = model.encode_text(query)
        timings['Encoding'] = time.time() - t0

        if text_features is None:
            yield None, None, "Model does not support text encoding or is not initialized.", None, None, None, None
            return

        # 2) Score the embeddings against the query.
        yield None, None, "Encoding text... ✓\nRetrieving similar images...", None, None, None, None
        t0 = time.time()
        probs, filtered_indices, top_indices = model.search(text_features, top_percent=threshold/1000.0)
        timings['Retrieval'] = time.time() - t0

        if probs is None:
            yield None, None, "Search failed (embeddings missing?).", None, None, None, None
            return

        df_embed = model.df_embed
        geo_dist_map, df_filtered = plot_geographic_distribution(df_embed, probs, threshold/1000.0, title=f'Similarity to "{query}" ({model_name})')

        # 3) Fetch imagery for the 10 best matches.
        yield gr.update(visible=False), None, "Encoding text... ✓\nRetrieving similar images... ✓\nDownloading images...", None, None, df_filtered, gr.update(value=geo_dist_map, visible=True)
        t0 = time.time()
        top_indices = top_indices[:10]
        results = fetch_top_k_images(top_indices, probs, df_embed, query_text=query)
        timings['Download'] = time.time() - t0

        # 4) Build the overview figure and gallery.
        yield gr.update(visible=False), None, "Encoding text... ✓\nRetrieving similar images... ✓\nDownloading images... ✓\nGenerating visualizations...", None, None, df_filtered, gr.update(value=geo_dist_map, visible=True)
        t0 = time.time()
        fig_results = plot_top5_overview(None, results, query_info=query)
        gallery_items = format_results_for_gallery(results)
        timings['Visualization'] = time.time() - t0

        timing_str = f"Encoding {timings['Encoding']:.1f}s, Retrieval {timings['Retrieval']:.1f}s, Download {timings['Download']:.1f}s, Visualization {timings['Visualization']:.1f}s\n\n"
        # generate_status_msg multiplies by 100, so threshold/100 makes it
        # display the slider's per-mille value.
        status_msg = timing_str + generate_status_msg(len(filtered_indices), threshold/100.0, results)

        all_results = get_all_results_metadata(model, filtered_indices, probs)
        results_txt = format_results_to_text(all_results)

        # The 5th element feeds the "Download Result" button (see save_plot).
        yield gr.update(visible=False), gallery_items, status_msg, fig_results, [geo_dist_map, fig_results, results_txt], df_filtered, gr.update(value=geo_dist_map, visible=True)

    except Exception as e:
        import traceback
        traceback.print_exc()
        yield None, None, f"Error: {str(e)}", None, None, None, None
|
|
|
|
|
def search_image(image_input, threshold, model_name):
    """Image-to-image retrieval, streamed to the UI as a generator.

    Every ``yield`` is a 7-tuple in the order of the UI outputs:
    (interactive map, gallery items, status text, results figure,
    current-figure state, map-data state, static map).

    Args:
        image_input: query image from the ``gr.Image(type="pil")`` input.
        threshold: slider value in per-mille (‰); ``model.search`` receives
            ``threshold / 1000`` as ``top_percent``.
        model_name: key of the model in the module-level ``models`` dict.
    """
    model, error = get_active_model(model_name)
    if error:
        yield None, None, error, None, None, None, None
        return

    if image_input is None:
        yield None, None, "Please upload an image.", None, None, None, None
        return

    try:
        timings = {}

        # 1) Encode the query image.
        yield None, None, "Encoding image...", None, None, None, None
        t0 = time.time()
        image_features = model.encode_image(image_input)
        timings['Encoding'] = time.time() - t0

        if image_features is None:
            yield None, None, "Model does not support image encoding.", None, None, None, None
            return

        # 2) Score the embeddings against the query.
        yield None, None, "Encoding image... ✓\nRetrieving similar images...", None, None, None, None
        t0 = time.time()
        probs, filtered_indices, top_indices = model.search(image_features, top_percent=threshold/1000.0)
        timings['Retrieval'] = time.time() - t0

        # Same guard as search_text: report missing embeddings clearly instead
        # of failing later inside the plotting code.
        if probs is None:
            yield None, None, "Search failed (embeddings missing?).", None, None, None, None
            return

        df_embed = model.df_embed
        geo_dist_map, df_filtered = plot_geographic_distribution(df_embed, probs, threshold/1000.0, title=f'Similarity to Input Image ({model_name})')

        # 3) Fetch imagery for the 6 best matches.
        yield gr.update(visible=False), None, "Encoding image... ✓\nRetrieving similar images... ✓\nDownloading images...", None, None, df_filtered, gr.update(value=geo_dist_map, visible=True)
        t0 = time.time()
        top_indices = top_indices[:6]
        results = fetch_top_k_images(top_indices, probs, df_embed, query_text="Image Query")
        timings['Download'] = time.time() - t0

        # 4) Build the overview figure (with the query image) and gallery.
        yield gr.update(visible=False), None, "Encoding image... ✓\nRetrieving similar images... ✓\nDownloading images... ✓\nGenerating visualizations...", None, None, df_filtered, gr.update(value=geo_dist_map, visible=True)
        t0 = time.time()
        fig_results = plot_top5_overview(image_input, results, query_info="Image Query")
        gallery_items = format_results_for_gallery(results)
        timings['Visualization'] = time.time() - t0

        timing_str = f"Encoding {timings['Encoding']:.1f}s, Retrieval {timings['Retrieval']:.1f}s, Download {timings['Download']:.1f}s, Visualization {timings['Visualization']:.1f}s\n\n"
        # generate_status_msg multiplies by 100, so threshold/100 makes it
        # display the slider's per-mille value.
        status_msg = timing_str + generate_status_msg(len(filtered_indices), threshold/100.0, results)

        # The text report for image queries is capped at the top 50 rows.
        all_results = get_all_results_metadata(model, filtered_indices, probs)
        results_txt = format_results_to_text(all_results[:50])

        yield gr.update(visible=False), gallery_items, status_msg, fig_results, [geo_dist_map, fig_results, results_txt], df_filtered, gr.update(value=geo_dist_map, visible=True)

    except Exception as e:
        import traceback
        traceback.print_exc()
        yield None, None, f"Error: {str(e)}", None, None, None, None
|
|
|
|
|
def search_location(lat, lon, threshold):
    """Location-to-image retrieval with the SatCLIP location encoder.

    Streams progress to the UI; every ``yield`` is a 7-tuple in the order of
    the UI outputs (interactive map, gallery items, status text, results
    figure, current-figure state, map-data state, static map).

    Args:
        lat, lon: query coordinates in degrees (anything ``float()`` accepts).
        threshold: slider value in per-mille (‰); both ``model.search`` and
            the distribution map receive ``threshold / 1000``.
    """
    model_name = "SatCLIP"
    model, error = get_active_model(model_name)
    if error:
        yield None, None, error, None, None, None, None
        return

    if lat is None or lon is None:
        yield None, None, "Please specify coordinates first.", None, None, None, None
        return

    try:
        timings = {}

        # 1) Encode the query location.
        yield None, None, "Encoding location...", None, None, None, None
        lat_f, lon_f = float(lat), float(lon)
        t0 = time.time()
        loc_features = model.encode_location(lat_f, lon_f)
        timings['Encoding'] = time.time() - t0

        if loc_features is None:
            yield None, None, "Location encoding failed.", None, None, None, None
            return

        # 2) Score the embeddings. The slider is per-mille, so the fraction is
        # threshold/1000 -- the same value given to the distribution map below
        # and used by the text and image searches.
        yield None, None, "Encoding location... ✓\nRetrieving similar images...", None, None, None, None
        t0 = time.time()
        probs, filtered_indices, top_indices = model.search(loc_features, top_percent=threshold/1000.0)
        timings['Retrieval'] = time.time() - t0

        if probs is None:
            yield None, None, "Search failed (embeddings missing?).", None, None, None, None
            return

        yield None, None, "Encoding location... ✓\nRetrieving similar images... ✓\nGenerating distribution map...", None, None, None, None
        df_embed = model.df_embed
        geo_dist_map, df_filtered = plot_geographic_distribution(df_embed, probs, threshold/1000.0, title=f'Similarity to Location ({lat}, {lon})')

        # 3) Fetch imagery for the 6 best matches plus the dataset tile closest
        # to the query point (shown as the "query" image in the overview).
        yield gr.update(visible=False), None, "Encoding location... ✓\nRetrieving similar images... ✓\nGenerating distribution map... ✓\nDownloading images...", None, None, df_filtered, gr.update(value=geo_dist_map, visible=True)
        t0 = time.time()
        top_6_indices = top_indices[:6]
        results = fetch_top_k_images(top_6_indices, probs, df_embed, query_text=f"Loc: {lat},{lon}")

        query_tile = None
        try:
            lats = pd.to_numeric(df_embed['centre_lat'], errors='coerce')
            lons = pd.to_numeric(df_embed['centre_lon'], errors='coerce')
            dists = (lats - lat_f)**2 + (lons - lon_f)**2
            nearest_idx = dists.idxmin()
            pid = df_embed.loc[nearest_idx, 'product_id']
            query_tile, _ = download_and_process_image(pid, df_source=df_embed, verbose=False)
        except Exception as e:
            print(f"Error fetching nearest MajorTOM image: {e}")

        if query_tile is None:
            query_tile = get_placeholder_image(f"Query Location\n({lat}, {lon})")

        timings['Download'] = time.time() - t0

        # 4) Build the overview figure and gallery.
        yield gr.update(visible=False), None, "Encoding location... ✓\nRetrieving similar images... ✓\nGenerating distribution map... ✓\nDownloading images... ✓\nGenerating visualizations...", None, None, df_filtered, gr.update(value=geo_dist_map, visible=True)
        t0 = time.time()
        fig_results = plot_top5_overview(query_tile, results, query_info=f"Loc: {lat},{lon}")
        gallery_items = format_results_for_gallery(results)
        timings['Visualization'] = time.time() - t0

        timing_str = f"Encoding {timings['Encoding']:.1f}s, Retrieval {timings['Retrieval']:.1f}s, Download {timings['Download']:.1f}s, Visualization {timings['Visualization']:.1f}s\n\n"
        # generate_status_msg multiplies by 100, so threshold/100 makes it
        # display the slider's per-mille value.
        status_msg = timing_str + generate_status_msg(len(filtered_indices), threshold/100.0, results)

        all_results = get_all_results_metadata(model, filtered_indices, probs)
        results_txt = format_results_to_text(all_results)

        yield gr.update(visible=False), gallery_items, status_msg, fig_results, [geo_dist_map, fig_results, results_txt], df_filtered, gr.update(value=geo_dist_map, visible=True)

    except Exception as e:
        import traceback
        traceback.print_exc()
        yield None, None, f"Error: {str(e)}", None, None, None, None
|
|
|
|
|
def generate_status_msg(count, threshold, results):
    """Summarise a search run as human-readable status text.

    Args:
        count: number of rows that passed the similarity threshold.
        threshold: value such that ``threshold * 100`` is the per-mille figure
            shown to the user (callers pass ``slider / 100``).
        results: result dicts with ``id``, ``lat``, ``lon`` and ``score``;
            only the first five are listed.

    Returns:
        A multi-line string, each line terminated by ``\\n``.
    """
    header = f"Found {count} matches in top {threshold*100:.0f}‰.\n\nTop {len(results)} similar images:\n"
    lines = [
        f"{rank}. Product ID: {res['id']}, Location: ({res['lat']:.4f}, {res['lon']:.4f}), Score: {res['score']:.4f}\n"
        for rank, res in enumerate(results[:5], start=1)
    ]
    return header + "".join(lines)
|
|
|
|
|
def get_initial_plot():
    """Render the global dataset overview for the initial page load.

    The embeddings table comes from the first loaded model that has one. It
    tries DINOv2, then SigLIP, then any other loaded model, so the app still
    starts when a preferred model failed to load. If no table is available,
    the static map is left empty instead of raising.

    Returns:
        (static-map update, current-figure state, map-data state,
        interactive-map update) -- the outputs wired to ``demo.load``.
    """
    preferred = ['DINOv2', 'SigLIP']
    candidates = preferred + [name for name in models if name not in preferred]

    df_embed = None
    for name in candidates:
        # getattr on None (model not loaded) also yields None.
        df_embed = getattr(models.get(name), 'df_embed', None)
        if df_embed is not None:
            break

    if df_embed is None:
        print("No loaded model provides embeddings; the global map is empty.")
        return gr.update(value=None, visible=True), None, None, gr.update(visible=False)

    img, df_vis = plot_global_map_static(df_embed)
    return gr.update(value=img, visible=True), [img], df_vis, gr.update(visible=False)
|
|
|
|
|
def handle_map_click(evt: gr.SelectData, df_vis):
    """Translate a click on the static world map into coordinates.

    The click's pixel position (``evt.index``) is mapped linearly to
    longitude/latitude, assuming a 2:1 world map drawn inside fixed margins
    of a 3000x1500 px image. If ``df_vis`` contains a tile whose centre lies
    within 5 degrees of the click (Euclidean distance in degree space), the
    result snaps to that tile's centre and its ``product_id`` is returned.

    Returns:
        ``(lat, lon, product_id, status_markdown)``. The first three are
        ``None`` when the click is outside the map or an error occurs;
        ``product_id`` is ``""`` when no tile was snapped to.
    """
    if evt is None:
        return None, None, None, "No point selected."

    try:
        x, y = evt.index[0], evt.index[1]

        # Pixel layout of the static map image.
        # NOTE(review): these constants must match the figure rendered by
        # plot_global_map_static -- confirm whenever its size/margins change.
        img_width = 3000
        img_height = 1500
        left_margin = 110 * 0.75
        right_margin = 110 * 0.75
        top_margin = 100 * 0.75
        bottom_margin = 67 * 0.75

        plot_width = img_width - left_margin - right_margin
        plot_height = img_height - top_margin - bottom_margin

        # The 360x180 degree map keeps a 2:1 aspect ratio, so it is centred
        # inside the plot area with padding on whichever axis has slack.
        map_aspect = 360.0 / 180.0
        plot_aspect = plot_width / plot_height
        if plot_aspect > map_aspect:
            actual_map_width = plot_height * map_aspect
            actual_map_height = plot_height
            h_offset = (plot_width - actual_map_width) / 2
            v_offset = 0
        else:
            actual_map_width = plot_width
            actual_map_height = plot_width / map_aspect
            h_offset = 0
            v_offset = (plot_height - actual_map_height) / 2

        x_in_plot = x - left_margin
        y_in_plot = y - top_margin

        if (x_in_plot < h_offset or x_in_plot > h_offset + actual_map_width or
                y_in_plot < v_offset or y_in_plot > v_offset + actual_map_height):
            return None, None, None, "Click outside map area. Please click on the map."

        # Relative position inside the map, clamped to [0, 1].
        x_rel = max(0, min(1, (x_in_plot - h_offset) / actual_map_width))
        y_rel = max(0, min(1, (y_in_plot - v_offset) / actual_map_height))

        lon = x_rel * 360 - 180
        lat = 90 - y_rel * 180

        pid = ""
        if df_vis is not None:
            # Coerce coordinates to numbers (as the other nearest-tile lookups
            # do) and drop rows without coordinates. An empty or all-NaN table
            # then leaves the raw click position instead of raising in idxmin().
            tile_lats = pd.to_numeric(df_vis['centre_lat'], errors='coerce')
            tile_lons = pd.to_numeric(df_vis['centre_lon'], errors='coerce')
            dists = ((tile_lats - lat) ** 2 + (tile_lons - lon) ** 2).dropna()
            if not dists.empty:
                min_idx = dists.idxmin()
                # 25 squared degrees == within 5 degrees of the click.
                if dists.loc[min_idx] < 25:
                    lat = tile_lats.loc[min_idx]
                    lon = tile_lons.loc[min_idx]
                    pid = df_vis.loc[min_idx, 'product_id']

    except Exception as e:
        print(f"Error handling click: {e}")
        import traceback
        traceback.print_exc()
        return None, None, None, f"Error: {e}"

    return lat, lon, pid, f"Selected Point: ({lat:.4f}, {lon:.4f})"
|
|
|
|
|
def download_image_by_location(lat, lon, pid, model_name):
    """Download the dataset tile at (or nearest to) the given coordinates.

    ``pid`` is the product ID auto-filled by a map click. It is used only if
    ``model.df_embed`` has that product at exactly these coordinates. If the
    user has since edited lat/lon (or the ID is unknown to this model's
    table), the nearest tile to the current coordinates is used instead of a
    stale ID.

    Args:
        lat, lon: coordinates in degrees (anything ``float()`` accepts).
        pid: optional product ID; falsy values trigger a nearest-tile lookup.
        model_name: model whose ``df_embed`` table is searched.

    Returns:
        ``(image_or_None, status_message)``.
    """
    if lat is None or lon is None:
        return None, "Please specify coordinates first."

    model, error = get_active_model(model_name)
    if error:
        return None, error

    try:
        lat = float(lat)
        lon = float(lon)

        df = model.df_embed
        lats = pd.to_numeric(df['centre_lat'], errors='coerce')
        lons = pd.to_numeric(df['centre_lon'], errors='coerce')

        if pid:
            # Keep the auto-filled ID only while it still matches the
            # coordinates (a click snaps lat/lon to the tile centre exactly).
            same_pid = df['product_id'] == pid
            at_coords = ((lats - lat).abs() < 1e-6) & ((lons - lon).abs() < 1e-6)
            if not (same_pid & at_coords).any():
                pid = None

        if not pid:
            dists = (lats - lat)**2 + (lons - lon)**2
            nearest_idx = dists.idxmin()
            pid = df.loc[nearest_idx, 'product_id']

        img_384, _ = download_and_process_image(pid, df_source=df, verbose=True)

        if img_384 is None:
            return None, f"Failed to download image for location ({lat:.4f}, {lon:.4f})"

        return img_384, f"Downloaded image at ({lat:.4f}, {lon:.4f})"

    except Exception as e:
        import traceback
        traceback.print_exc()
        return None, f"Error: {str(e)}"
|
|
|
|
|
def reset_to_global_map():
    """Reset the static map to the global dataset overview.

    The embeddings table comes from the first loaded model that has one. It
    tries DINOv2, then SigLIP, then any other loaded model. If none is
    available, the map is cleared instead of raising.

    Returns:
        (static-map update, current-figure state, map-data state).
    """
    preferred = ['DINOv2', 'SigLIP']
    candidates = preferred + [name for name in models if name not in preferred]

    df_embed = None
    for name in candidates:
        # getattr on None (model not loaded) also yields None.
        df_embed = getattr(models.get(name), 'df_embed', None)
        if df_embed is not None:
            break

    if df_embed is None:
        print("No loaded model provides embeddings; the global map is empty.")
        return gr.update(value=None, visible=True), None, None

    img, df_vis = plot_global_map_static(df_embed)
    return gr.update(value=img, visible=True), [img], df_vis
|
|
|
|
|
def format_results_to_text(results):
    """Render retrieval records as the plain-text report used in downloads.

    Args:
        results: sequence of dicts with ``id``, ``lat``, ``lon`` and ``score``.

    Returns:
        A header followed by one ranked section per record, or
        ``"No results found."`` when ``results`` is empty/``None``.
    """
    if not results:
        return "No results found."

    separator = "-" * 30 + "\n"
    header = f"Top {len(results)} Retrieval Results\n" + "=" * 30 + "\n\n"
    sections = [
        f"Rank: {rank}\n"
        f"Product ID: {res['id']}\n"
        f"Location: Latitude {res['lat']:.6f}, Longitude {res['lon']:.6f}\n"
        f"Similarity Score: {res['score']:.6f}\n"
        + separator
        for rank, res in enumerate(results, start=1)
    ]
    return header + "".join(sections)
|
|
|
|
|
def save_plot(figs):
    """Package the current results into a downloadable temp file.

    Accepted inputs:
      * a single PIL image, or a one-element list/tuple holding one -> PNG;
      * a list/tuple ``[map_image, results_image, results_text]`` -> ZIP with
        ``map_distribution.png``, ``retrieval_results.png`` and
        ``results.txt`` (``None`` entries are skipped);
      * anything else is assumed to provide ``write_html`` (e.g. a Plotly
        figure) and is written to an HTML file.

    Returns:
        Path of the created temp file, or ``None`` when there is nothing to
        save or saving fails.
    """
    import io

    if figs is None:
        return None

    try:
        single = figs
        if isinstance(figs, (list, tuple)) and len(figs) == 1 and isinstance(figs[0], PILImage.Image):
            single = figs[0]

        if isinstance(single, PILImage.Image):
            fd, path = tempfile.mkstemp(suffix='.png', prefix='earth_explorer_map_')
            os.close(fd)
            single.save(path)
            return path

        if isinstance(figs, (list, tuple)):
            if len(figs) == 0:
                return None

            fd, zip_path = tempfile.mkstemp(suffix='.zip', prefix='earth_explorer_results_')
            os.close(fd)
            try:
                # Members are written from memory rather than through
                # fixed-name files in the shared temp dir, which concurrent
                # sessions could overwrite before they are archived.
                with zipfile.ZipFile(zip_path, 'w') as zipf:
                    for pos, arcname in ((0, 'map_distribution.png'), (1, 'retrieval_results.png')):
                        if len(figs) > pos and figs[pos] is not None:
                            buf = io.BytesIO()
                            figs[pos].save(buf, format='PNG')
                            zipf.writestr(arcname, buf.getvalue())
                    if len(figs) > 2 and figs[2] is not None:
                        zipf.writestr('results.txt', str(figs[2]).encode('utf-8'))
            except Exception:
                # Do not leave a half-written archive behind.
                os.remove(zip_path)
                raise
            return zip_path

        fd, path = tempfile.mkstemp(suffix='.html', prefix='earth_explorer_plot_')
        os.close(fd)
        figs.write_html(path)
        return path

    except Exception as e:
        print(f"Error saving: {e}")
        return None
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Gradio UI: a left column with the search tabs and shared controls, and a
# right column with the maps, the overview figure and the image gallery.
# ---------------------------------------------------------------------------
with gr.Blocks(title="EarthEmbeddingExplorer") as demo:
    gr.Markdown("# EarthEmbeddingExplorer")
    gr.HTML("""
    <div style="font-size: 1.2em;">
    EarthEmbeddingExplorer is a tool that allows you to search for satellite images of the Earth using natural language descriptions, images, geolocations, or a simple click on the map. For example, you can type "tropical rainforest" or "coastline with a city," and the system will find locations on Earth that match your description. It then visualizes these locations on a world map and displays the top matching images.
    </div>

    """)

    with gr.Row():
        # ---- Left column: query inputs and controls ----
        with gr.Column(scale=4):
            with gr.Tabs():
                # Text -> image retrieval (only models offered here are used).
                with gr.TabItem("Text Search") as tab_text:
                    model_selector_text = gr.Dropdown(choices=["SigLIP", "FarSLIP"], value="FarSLIP", label="Model")
                    query_input = gr.Textbox(label="Query", placeholder="e.g., rainforest, glacier")

                    gr.Examples(
                        examples=[
                            ["a satellite image of a river around a city"],
                            ["a satellite image of a rainforest"],
                            ["a satellite image of a slum"],
                            ["a satellite image of a glacier"],
                            ["a satellite image of snow covered mountains"]
                        ],
                        inputs=[query_input],
                        label="Text Examples"
                    )

                    search_btn = gr.Button("Search by Text", variant="primary")

                # Image -> image retrieval. The query image is either uploaded
                # or downloaded from the dataset via map click / coordinates.
                with gr.TabItem("Image Search") as tab_image:
                    model_selector_img = gr.Dropdown(choices=["SigLIP", "FarSLIP", "SatCLIP", "DINOv2"], value="FarSLIP", label="Model")

                    gr.Markdown("### Option 1: Upload or Select Image")
                    image_input = gr.Image(type="pil", label="Upload Image")

                    gr.Examples(
                        examples=[
                            ["./examples/example1.png"],
                            ["./examples/example2.png"],
                            ["./examples/example3.png"]
                        ],
                        inputs=[image_input],
                        label="Image Examples"
                    )

                    gr.Markdown("### Option 2: Click Map or Enter Coordinates")
                    btn_reset_map_img = gr.Button("🔄 Reset Map to Global View", variant="secondary", size="sm")

                    with gr.Row():
                        img_lat = gr.Number(label="Latitude", interactive=True)
                        img_lon = gr.Number(label="Longitude", interactive=True)

                    # Hidden; filled by handle_map_click when a click snaps to a tile.
                    img_pid = gr.Textbox(label="Product ID (auto-filled)", visible=False)
                    img_click_status = gr.Markdown("")

                    btn_download_img = gr.Button("Download Image by Geolocation", variant="secondary")

                    search_img_btn = gr.Button("Search by Image", variant="primary")

                # Location -> image retrieval (always uses SatCLIP).
                with gr.TabItem("Location Search") as tab_location:
                    gr.Markdown("Search using **SatCLIP** location encoder.")

                    gr.Markdown("### Click Map or Enter Coordinates")
                    btn_reset_map_loc = gr.Button("🔄 Reset Map to Global View", variant="secondary", size="sm")

                    with gr.Row():
                        lat_input = gr.Number(label="Latitude", value=30.0, interactive=True)
                        lon_input = gr.Number(label="Longitude", value=120.0, interactive=True)

                    # Hidden; filled by handle_map_click when a click snaps to a tile.
                    loc_pid = gr.Textbox(label="Product ID (auto-filled)", visible=False)
                    loc_click_status = gr.Markdown("")

                    gr.Examples(
                        examples=[
                            [30.32, 120.15],
                            [40.7128, -74.0060],
                            [24.65, 46.71],
                            [-3.4653, -62.2159],
                            [64.4, 16.8]
                        ],
                        inputs=[lat_input, lon_input],
                        label="Location Examples"
                    )

                    search_loc_btn = gr.Button("Search by Location", variant="primary")

            # Controls shared by all three tabs; the slider value is in per-mille (‰).
            threshold_slider = gr.Slider(minimum=1, maximum=30, value=7, step=1, label="Top Percentage (‰)")
            status_output = gr.Textbox(label="Status", lines=10)
            save_btn = gr.Button("Download Result")
            download_file = gr.File(label="Zipped Results", height=40)

        # ---- Right column: visual results ----
        with gr.Column(scale=6):
            # Static (clickable) map image; its select events drive handle_map_click.
            plot_map = gr.Image(
                label="Geographical Distribution",
                type="pil",
                interactive=False,
                height=400,
                width=800,
                visible=True
            )
            plot_map_interactive = gr.Plot(
                label="Geographical Distribution (Interactive)",
                visible=False
            )
            results_plot = gr.Image(label="Top 5 Matched Images", type="pil")
            gallery_images = gr.Gallery(label="Top Retrieved Images (Zoom)", columns=3, height="auto")

    # current_fig: what "Download Result" packages (see save_plot).
    # map_data_state: table behind the current map, used to snap clicks to tiles.
    current_fig = gr.State()
    map_data_state = gr.State()

    # Show the global dataset overview on page load.
    demo.load(fn=get_initial_plot, outputs=[plot_map, current_fig, map_data_state, plot_map_interactive])

    btn_reset_map_img.click(
        fn=reset_to_global_map,
        outputs=[plot_map, current_fig, map_data_state]
    )

    btn_reset_map_loc.click(
        fn=reset_to_global_map,
        outputs=[plot_map, current_fig, map_data_state]
    )

    # Both handlers run on every map click, so the selected point is mirrored
    # into the Image Search and the Location Search coordinate fields.
    plot_map.select(
        fn=handle_map_click,
        inputs=[map_data_state],
        outputs=[img_lat, img_lon, img_pid, img_click_status]
    )

    plot_map.select(
        fn=handle_map_click,
        inputs=[map_data_state],
        outputs=[lat_input, lon_input, loc_pid, loc_click_status]
    )

    btn_download_img.click(
        fn=download_image_by_location,
        inputs=[img_lat, img_lon, img_pid, model_selector_img],
        outputs=[image_input, img_click_status]
    )

    # The three search handlers are generators sharing one output layout.
    search_btn.click(
        fn=search_text,
        inputs=[query_input, threshold_slider, model_selector_text],
        outputs=[plot_map_interactive, gallery_images, status_output, results_plot, current_fig, map_data_state, plot_map]
    )

    search_img_btn.click(
        fn=search_image,
        inputs=[image_input, threshold_slider, model_selector_img],
        outputs=[plot_map_interactive, gallery_images, status_output, results_plot, current_fig, map_data_state, plot_map]
    )

    search_loc_btn.click(
        fn=search_location,
        inputs=[lat_input, lon_input, threshold_slider],
        outputs=[plot_map_interactive, gallery_images, status_output, results_plot, current_fig, map_data_state, plot_map]
    )

    save_btn.click(
        fn=save_plot,
        inputs=[current_fig],
        outputs=[download_file]
    )

    def show_static_map():
        """Show the static map and hide the interactive plot."""
        return gr.update(visible=True), gr.update(visible=False)

    # Switching tabs always returns to the static (clickable) map.
    tab_text.select(fn=show_static_map, outputs=[plot_map, plot_map_interactive])
    tab_image.select(fn=show_static_map, outputs=[plot_map, plot_map_interactive])
    tab_location.select(fn=show_static_map, outputs=[plot_map, plot_map_interactive])


if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
|
|
|