# Snapshot of commit eb1aec4 (file size: 33,162 bytes).
import gradio as gr
import torch
import time
import os
import tempfile
import zipfile
import numpy as np
import pandas as pd
from concurrent.futures import ThreadPoolExecutor, as_completed
# Import custom modules
from models.siglip_model import SigLIPModel
from models.satclip_model import SatCLIPModel
from models.farslip_model import FarSLIPModel
from models.dinov2_model import DINOv2Model
from models.load_config import load_and_process_config
from visualize import format_results_for_gallery, plot_top5_overview, plot_location_distribution, plot_global_map_static, plot_geographic_distribution
from data_utils import download_and_process_image, get_esri_satellite_image, get_placeholder_image
from PIL import Image as PILImage
from PIL import ImageDraw, ImageFont
# Configuration: run on the GPU when torch reports one, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Running on device: {device}")

# Load and process configuration
config = load_and_process_config()
print(config)

# Initialize Models
print("Initializing models...")

# Registry of the models that loaded successfully, keyed by display name.
models = {}

# One row per model:
# (display name, config section, model class, config keys forwarded as kwargs)
_MODEL_SPECS = (
    ("DINOv2", "dinov2", DINOv2Model, ("ckpt_path", "embedding_path")),
    ("SigLIP", "siglip", SigLIPModel, ("ckpt_path", "tokenizer_path", "embedding_path")),
    ("SatCLIP", "satclip", SatCLIPModel, ("ckpt_path", "embedding_path")),
    ("FarSLIP", "farslip", FarSLIPModel, ("ckpt_path", "model_name", "embedding_path")),
)

for _name, _section, _model_cls, _keys in _MODEL_SPECS:
    try:
        if config and _section in config:
            # Missing keys are forwarded as None, exactly like dict.get().
            _kwargs = {_key: config[_section].get(_key) for _key in _keys}
            models[_name] = _model_cls(device=device, **_kwargs)
        else:
            # No config section for this model: rely on the class defaults.
            models[_name] = _model_cls(device=device)
    except Exception as e:
        # A model that fails to load is simply left out of the registry.
        print(f"Failed to load {_name}: {e}")
def get_active_model(model_name):
    """Look up a loaded model in the module-level ``models`` registry.

    Args:
        model_name: Key of the model in ``models``.

    Returns:
        ``(model, None)`` when the model is registered, otherwise
        ``(None, message)`` where ``message`` says the model is not loaded.
    """
    try:
        return models[model_name], None
    except KeyError:
        return None, f"Model {model_name} not loaded."
def combine_images(img1, img2):
    """Stack two images vertically on a white canvas.

    Both images are resized to the wider of the two widths, scaling each
    height by the same factor as its width, and ``img1`` is pasted above
    ``img2``.

    Args:
        img1: Top image, or ``None``.
        img2: Bottom image, or ``None``.

    Returns:
        The other image unchanged when one argument is ``None`` (``None`` when
        both are), otherwise a new RGB image holding both.
    """
    if img1 is None:
        return img2
    if img2 is None:
        return img1

    (width_top, height_top), (width_bottom, height_bottom) = img1.size, img2.size
    width = max(width_top, width_bottom)

    # Scale each height by the same factor as its width.
    scaled_height_top = int(height_top * width / width_top)
    scaled_height_bottom = int(height_bottom * width / width_bottom)

    top = img1.resize((width, scaled_height_top))
    bottom = img2.resize((width, scaled_height_bottom))

    canvas = PILImage.new(
        'RGB', (width, scaled_height_top + scaled_height_bottom), (255, 255, 255)
    )
    canvas.paste(top, (0, 0))
    canvas.paste(bottom, (0, scaled_height_top))
    return canvas
def create_text_image(text, size=(384, 384)):
    """Render a text query onto a light-grey placeholder image.

    The text is split on commas and each piece is drawn on its own line in
    black, followed by a blue "Text Query" caption.

    Args:
        text: Query string; commas act as line breaks.
        size: ``(width, height)`` of the generated image in pixels.

    Returns:
        A new RGB PIL image.
    """
    img = PILImage.new('RGB', size, color=(240, 240, 240))
    draw = ImageDraw.Draw(img)

    # Prefer a scalable font; fall back to Pillow's built-in one when the font
    # file cannot be read (OSError) or FreeType support is missing
    # (ImportError). A bare ``except`` would also hide KeyboardInterrupt and
    # SystemExit.
    try:
        font = ImageFont.truetype("DejaVuSans.ttf", 40)
    except (OSError, ImportError):
        font = ImageFont.load_default()

    # Wrap text simply: one comma-separated piece per line, 50 px apart.
    margin = 20
    offset = 100
    for line in text.split(','):
        draw.text((margin, offset), line.strip(), font=font, fill=(0, 0, 0))
        offset += 50

    # Caption sits one extra line below the last piece of text.
    draw.text((margin, offset + 50), "Text Query", font=font, fill=(0, 0, 255))
    return img
def fetch_top_k_images(top_indices, probs, df_embed, query_text=None):
    """Download the images for the given rows in parallel.

    Each row's ``product_id`` is passed to ``download_and_process_image`` on a
    pool of five worker threads. When a download yields no image, an Esri
    satellite tile for the row's centre coordinates is used instead. Rows whose
    retrieval raises are reported on stdout and left out of the result.

    Args:
        top_indices: Positional row indices into ``df_embed``.
        probs: Scores indexable by those same indices.
        df_embed: DataFrame providing ``product_id``, ``centre_lat`` and
            ``centre_lon`` columns.
        query_text: Query label forwarded to the Esri fallback.

    Returns:
        A list of dicts with keys ``image_384``, ``image_full``, ``score``,
        ``lat``, ``lon`` and ``id``, sorted by descending score.
    """
    fetched = []

    with ThreadPoolExecutor(max_workers=5) as pool:
        # Map every submitted download back to the row it belongs to.
        pending = {
            pool.submit(
                download_and_process_image,
                df_embed.iloc[idx]['product_id'],
                df_source=df_embed,
                verbose=False,
            ): idx
            for idx in top_indices
        }

        for finished in as_completed(pending):
            idx = pending[finished]
            row = df_embed.iloc[idx]
            try:
                img_384, img_full = finished.result()
                if img_384 is None:
                    # Fallback to Esri if download fails
                    print(f"Download failed for idx {idx}, falling back to Esri...")
                    img_384 = get_esri_satellite_image(
                        row['centre_lat'],
                        row['centre_lon'],
                        score=probs[idx],
                        rank=0,
                        query=query_text,
                    )
                    img_full = img_384
                fetched.append({
                    'image_384': img_384,
                    'image_full': img_full,
                    'score': probs[idx],
                    'lat': row['centre_lat'],
                    'lon': row['centre_lon'],
                    'id': row['product_id'],
                })
            except Exception as e:
                print(f"Error fetching image for idx {idx}: {e}")

    # Futures finish in arbitrary order, so restore the ranking by score.
    fetched.sort(key=lambda item: item['score'], reverse=True)
    return fetched
def get_all_results_metadata(model, filtered_indices, probs):
    """Return id, location and score for every filtered row, best score first.

    Args:
        model: Object exposing the embeddings DataFrame as ``df_embed``.
        filtered_indices: Array of positional row indices that passed the filter.
        probs: Array of scores indexable by those indices.

    Returns:
        A list of ``{'id', 'lat', 'lon', 'score'}`` dicts ordered by descending
        score; an empty list when ``filtered_indices`` is empty.
    """
    if len(filtered_indices) == 0:
        return []

    # Positions within filtered_indices, ordered from highest to lowest score.
    descending = np.argsort(probs[filtered_indices])[::-1]
    ranked = filtered_indices[descending]

    table = (
        model.df_embed.iloc[ranked]
        .assign(score=probs[ranked])
        .rename(columns={'product_id': 'id', 'centre_lat': 'lat', 'centre_lon': 'lon'})
    )
    return table.loc[:, ['id', 'lat', 'lon', 'score']].to_dict(orient='records')
def search_text(query, threshold, model_name):
    """Retrieve images matching a text query, streaming progress updates.

    Generator used as a Gradio event handler. Every yielded value is a 7-tuple
    in the order of the handler's outputs: interactive-map update, gallery
    items, status text, results figure, downloadable artefacts, map data and
    static-map update.

    Args:
        query: Free-text description to search for.
        threshold: Slider value; divided by 1000 before being passed to
            ``model.search`` as ``top_percent``.
        model_name: Key of the model in the ``models`` registry.

    Yields:
        Progress tuples while working, then one final tuple with the results.
        On any failure a tuple carrying only an error message is yielded.
    """
    tick = " β\n"
    stages = [
        "Encoding text...",
        "Retrieving similar images...",
        "Downloading images...",
        "Generating visualizations...",
    ]

    def progress(reached):
        # Status text listing the first `reached` stages, earlier ones ticked off.
        return tick.join(stages[:reached])

    def status_only(message):
        # Update nothing but the status box.
        return None, None, message, None, None, None, None

    def with_map(message):
        # Status update that keeps the static distribution map on screen.
        return (
            gr.update(visible=False),
            None,
            message,
            None,
            None,
            df_filtered,
            gr.update(value=geo_dist_map, visible=True),
        )

    model, error = get_active_model(model_name)
    if error:
        yield status_only(error)
        return
    if not query:
        yield status_only("Please enter a query.")
        return

    try:
        timings = {}

        # 1. Encode Text
        yield status_only(progress(1))
        started = time.time()
        text_features = model.encode_text(query)
        timings['Encoding'] = time.time() - started
        if text_features is None:
            yield status_only("Model does not support text encoding or is not initialized.")
            return

        # 2. Search
        yield status_only(progress(2))
        started = time.time()
        probs, filtered_indices, top_indices = model.search(text_features, top_percent=threshold/1000.0)
        timings['Retrieval'] = time.time() - started
        if probs is None:
            yield status_only("Search failed (embeddings missing?).")
            return

        # Show geographic distribution (not timed)
        df_embed = model.df_embed
        geo_dist_map, df_filtered = plot_geographic_distribution(
            df_embed, probs, threshold/1000.0,
            title=f'Similarity to "{query}" ({model_name})',
        )

        # 3. Download Images (ten best matches)
        yield with_map(progress(3))
        started = time.time()
        top_indices = top_indices[:10]
        results = fetch_top_k_images(top_indices, probs, df_embed, query_text=query)
        timings['Download'] = time.time() - started

        # 4. Visualize - keep geo_dist_map visible
        yield with_map(progress(4))
        started = time.time()
        fig_results = plot_top5_overview(None, results, query_info=query)
        gallery_items = format_results_for_gallery(results)
        timings['Visualization'] = time.time() - started

        # 5. Generate Final Status (timings were inserted in display order)
        timing_str = ", ".join(
            f"{stage} {seconds:.1f}s" for stage, seconds in timings.items()
        ) + "\n\n"
        status_msg = timing_str + generate_status_msg(len(filtered_indices), threshold/100.0, results)
        all_results = get_all_results_metadata(model, filtered_indices, probs)
        results_txt = format_results_to_text(all_results)

        yield (
            gr.update(visible=False),
            gallery_items,
            status_msg,
            fig_results,
            [geo_dist_map, fig_results, results_txt],
            df_filtered,
            gr.update(value=geo_dist_map, visible=True),
        )
    except Exception as e:
        import traceback
        traceback.print_exc()
        yield status_only(f"Error: {str(e)}")
def search_image(image_input, threshold, model_name):
    """Retrieve images similar to a query image, streaming progress updates.

    Generator used as a Gradio event handler. Every yielded value is a 7-tuple
    in the order of the handler's outputs: interactive-map update, gallery
    items, status text, results figure, downloadable artefacts, map data and
    static-map update.

    Args:
        image_input: Query image; ``None`` means nothing was uploaded.
        threshold: Slider value; divided by 1000 before being passed to
            ``model.search`` as ``top_percent``.
        model_name: Key of the model in the ``models`` registry.

    Yields:
        Progress tuples while working, then one final tuple with the results.
        On any failure a tuple carrying only an error message is yielded.
    """
    def status_only(message):
        # Update nothing but the status box.
        return None, None, message, None, None, None, None

    model, error = get_active_model(model_name)
    if error:
        yield status_only(error)
        return
    if image_input is None:
        yield status_only("Please upload an image.")
        return

    try:
        timings = {}

        # 1. Encode Image
        yield status_only("Encoding image...")
        t0 = time.time()
        image_features = model.encode_image(image_input)
        timings['Encoding'] = time.time() - t0
        if image_features is None:
            yield status_only("Model does not support image encoding.")
            return

        # 2. Search
        yield status_only("Encoding image... β\nRetrieving similar images...")
        t0 = time.time()
        probs, filtered_indices, top_indices = model.search(image_features, top_percent=threshold/1000.0)
        timings['Retrieval'] = time.time() - t0
        # Same guard as the text search: report a failed search explicitly
        # instead of letting the plotting code raise on missing scores.
        if probs is None:
            yield status_only("Search failed (embeddings missing?).")
            return

        # Show geographic distribution (not timed)
        df_embed = model.df_embed
        geo_dist_map, df_filtered = plot_geographic_distribution(
            df_embed, probs, threshold/1000.0,
            title=f'Similarity to Input Image ({model_name})',
        )

        # 3. Download Images (six best matches)
        yield gr.update(visible=False), None, "Encoding image... β\nRetrieving similar images... β\nDownloading images...", None, None, df_filtered, gr.update(value=geo_dist_map, visible=True)
        t0 = time.time()
        top_indices = top_indices[:6]
        results = fetch_top_k_images(top_indices, probs, df_embed, query_text="Image Query")
        timings['Download'] = time.time() - t0

        # 4. Visualize - keep geo_dist_map visible
        yield gr.update(visible=False), None, "Encoding image... β\nRetrieving similar images... β\nDownloading images... β\nGenerating visualizations...", None, None, df_filtered, gr.update(value=geo_dist_map, visible=True)
        t0 = time.time()
        fig_results = plot_top5_overview(image_input, results, query_info="Image Query")
        gallery_items = format_results_for_gallery(results)
        timings['Visualization'] = time.time() - t0

        # 5. Generate Final Status
        timing_str = f"Encoding {timings['Encoding']:.1f}s, Retrieval {timings['Retrieval']:.1f}s, Download {timings['Download']:.1f}s, Visualization {timings['Visualization']:.1f}s\n\n"
        status_msg = timing_str + generate_status_msg(len(filtered_indices), threshold/100.0, results)
        all_results = get_all_results_metadata(model, filtered_indices, probs)
        # Only the 50 best rows go into the downloadable text report.
        results_txt = format_results_to_text(all_results[:50])

        yield gr.update(visible=False), gallery_items, status_msg, fig_results, [geo_dist_map, fig_results, results_txt], df_filtered, gr.update(value=geo_dist_map, visible=True)
    except Exception as e:
        import traceback
        traceback.print_exc()
        yield status_only(f"Error: {str(e)}")
def search_location(lat, lon, threshold):
    """Retrieve images similar to a geographic location using SatCLIP.

    Generator used as a Gradio event handler. Every yielded value is a 7-tuple
    in the order of the handler's outputs: interactive-map update, gallery
    items, status text, results figure, downloadable artefacts, map data and
    static-map update.

    Args:
        lat: Query latitude (anything ``float()`` accepts).
        lon: Query longitude (anything ``float()`` accepts).
        threshold: Slider value; divided by 1000 before being passed to
            ``model.search`` as ``top_percent``, matching the text and image
            searches and the distribution map drawn here.

    Yields:
        Progress tuples while working, then one final tuple with the results.
        On any failure a tuple carrying only an error message is yielded.
    """
    def status_only(message):
        # Update nothing but the status box.
        return None, None, message, None, None, None, None

    model_name = "SatCLIP"
    model, error = get_active_model(model_name)
    if error:
        yield status_only(error)
        return
    # Empty coordinate fields would otherwise only fail inside float().
    if lat is None or lon is None:
        yield status_only("Please specify coordinates first.")
        return

    try:
        timings = {}
        lat_value = float(lat)
        lon_value = float(lon)

        # 1. Encode Location
        yield status_only("Encoding location...")
        t0 = time.time()
        loc_features = model.encode_location(lat_value, lon_value)
        timings['Encoding'] = time.time() - t0
        if loc_features is None:
            yield status_only("Location encoding failed.")
            return

        # 2. Search
        yield status_only("Encoding location... β\nRetrieving similar images...")
        t0 = time.time()
        # The threshold is per-mille everywhere else (slider label, map and
        # status message), so the retrieval must use the same /1000 scale.
        probs, filtered_indices, top_indices = model.search(loc_features, top_percent=threshold/1000.0)
        timings['Retrieval'] = time.time() - t0
        if probs is None:
            yield status_only("Search failed (embeddings missing?).")
            return

        # 3. Generate Distribution Map (not timed)
        yield status_only("Encoding location... β\nRetrieving similar images... β\nGenerating distribution map...")
        df_embed = model.df_embed
        geo_dist_map, df_filtered = plot_geographic_distribution(
            df_embed, probs, threshold/1000.0,
            title=f'Similarity to Location ({lat}, {lon})',
        )

        # 4. Download Images (six best matches)
        yield gr.update(visible=False), None, "Encoding location... β\nRetrieving similar images... β\nGenerating distribution map... β\nDownloading images...", None, None, df_filtered, gr.update(value=geo_dist_map, visible=True)
        t0 = time.time()
        top_6_indices = top_indices[:6]
        results = fetch_top_k_images(top_6_indices, probs, df_embed, query_text=f"Loc: {lat},{lon}")

        # Get query tile: the dataset image whose centre is closest (smallest
        # squared lat/lon difference) to the query point.
        query_tile = None
        try:
            lats = pd.to_numeric(df_embed['centre_lat'], errors='coerce')
            lons = pd.to_numeric(df_embed['centre_lon'], errors='coerce')
            dists = (lats - lat_value)**2 + (lons - lon_value)**2
            nearest_idx = dists.idxmin()
            pid = df_embed.loc[nearest_idx, 'product_id']
            query_tile, _ = download_and_process_image(pid, df_source=df_embed, verbose=False)
        except Exception as e:
            print(f"Error fetching nearest MajorTOM image: {e}")
        if query_tile is None:
            query_tile = get_placeholder_image(f"Query Location\n({lat}, {lon})")
        timings['Download'] = time.time() - t0

        # 5. Visualize - keep geo_dist_map visible
        yield gr.update(visible=False), None, "Encoding location... β\nRetrieving similar images... β\nGenerating distribution map... β\nDownloading images... β\nGenerating visualizations...", None, None, df_filtered, gr.update(value=geo_dist_map, visible=True)
        t0 = time.time()
        fig_results = plot_top5_overview(query_tile, results, query_info=f"Loc: {lat},{lon}")
        gallery_items = format_results_for_gallery(results)
        timings['Visualization'] = time.time() - t0

        # 6. Generate Final Status
        timing_str = f"Encoding {timings['Encoding']:.1f}s, Retrieval {timings['Retrieval']:.1f}s, Download {timings['Download']:.1f}s, Visualization {timings['Visualization']:.1f}s\n\n"
        status_msg = timing_str + generate_status_msg(len(filtered_indices), threshold/100.0, results)
        all_results = get_all_results_metadata(model, filtered_indices, probs)
        results_txt = format_results_to_text(all_results)

        yield gr.update(visible=False), gallery_items, status_msg, fig_results, [geo_dist_map, fig_results, results_txt], df_filtered, gr.update(value=geo_dist_map, visible=True)
    except Exception as e:
        import traceback
        traceback.print_exc()
        yield status_only(f"Error: {str(e)}")
def generate_status_msg(count, threshold, results):
    """Build the human-readable summary shown in the status box.

    Args:
        count: Number of matches found.
        threshold: Value that is multiplied by 100 and shown, rounded to a
            whole number, in the header line.
        results: Result dicts with ``id``, ``lat``, ``lon`` and ``score`` keys.

    Returns:
        A header line followed by one numbered line for each of the first five
        results.
    """
    header = f"Found {count} matches in top {threshold*100:.0f}β°.\n\nTop {len(results)} similar images:\n"
    entries = [
        f"{rank}. Product ID: {res['id']}, Location: ({res['lat']:.4f}, {res['lon']:.4f}), Score: {res['score']:.4f}\n"
        for rank, res in enumerate(results[:5], start=1)
    ]
    return header + "".join(entries)
def get_initial_plot():
    """Build the global distribution map shown when the page first loads.

    The embeddings of DINOv2 are preferred, then SigLIP, then any other loaded
    model; the first one whose ``df_embed`` is not ``None`` is plotted. When no
    model provides embeddings an empty map is returned instead of raising.

    Returns:
        A 4-tuple: update for the static map image, the downloadable artefact
        list (``[img]``, or ``None`` when there is no image), the DataFrame
        returned by ``plot_global_map_static``, and an update hiding the
        interactive map.
    """
    preferred = ['DINOv2', 'SigLIP']
    candidates = preferred + [name for name in models if name not in preferred]

    img = None
    df_vis = None
    for name in candidates:
        df_embed = getattr(models.get(name), 'df_embed', None)
        if df_embed is not None:
            img, df_vis = plot_global_map_static(df_embed)
            break
    else:
        print("No loaded model provides embeddings; showing an empty map.")

    # Leave the download state empty when there is nothing to save.
    current = [img] if img is not None else None
    return gr.update(value=img, visible=True), current, df_vis, gr.update(visible=False)
def handle_map_click(evt: gr.SelectData, df_vis):
    """Translate a click on the static map image into geographic coordinates.

    The click position in pixels is mapped onto longitude/latitude assuming a
    3000x1500 image with fixed margins around a plot with a 2:1 aspect ratio.
    If ``df_vis`` contains a point whose squared lat/lon distance to the click
    is below 25, the result snaps to that point and its product id is returned.

    Args:
        evt: Gradio selection event; ``evt.index`` holds the clicked ``(x, y)``.
        df_vis: DataFrame with ``centre_lat``, ``centre_lon`` and
            ``product_id`` columns, or ``None``.

    Returns:
        ``(lat, lon, pid, message)``. ``lat``, ``lon`` and ``pid`` are ``None``
        when nothing usable was clicked; ``pid`` is ``""`` when no nearby data
        point exists.
    """
    if evt is None:
        return None, None, None, "No point selected."
    try:
        x, y = evt.index[0], evt.index[1]

        # Image dimensions (New)
        img_width = 3000
        img_height = 1500

        # Scaled Margins (Proportional to 4000x2000)
        left_margin = 110 * 0.75
        right_margin = 110 * 0.75
        top_margin = 100 * 0.75
        bottom_margin = 67 * 0.75
        plot_width = img_width - left_margin - right_margin
        plot_height = img_height - top_margin - bottom_margin

        # Adjust for aspect ratio preservation: the map keeps a 2:1 ratio and
        # is centred inside the plot area along the axis with spare room.
        map_aspect = 360.0 / 180.0  # 2.0
        plot_aspect = plot_width / plot_height
        if plot_aspect > map_aspect:
            actual_map_width = plot_height * map_aspect
            actual_map_height = plot_height
            h_offset = (plot_width - actual_map_width) / 2
            v_offset = 0
        else:
            actual_map_width = plot_width
            actual_map_height = plot_width / map_aspect
            h_offset = 0
            v_offset = (plot_height - actual_map_height) / 2

        # Calculate relative position within the plot area
        x_in_plot = x - left_margin
        y_in_plot = y - top_margin

        # Check if click is within the actual map bounds
        if (x_in_plot < h_offset or x_in_plot > h_offset + actual_map_width or
                y_in_plot < v_offset or y_in_plot > v_offset + actual_map_height):
            return None, None, None, "Click outside map area. Please click on the map."

        # Calculate relative position within the map (0 to 1)
        x_rel = (x_in_plot - h_offset) / actual_map_width
        y_rel = (y_in_plot - v_offset) / actual_map_height

        # Clamp to [0, 1]
        x_rel = max(0, min(1, x_rel))
        y_rel = max(0, min(1, y_rel))

        # Convert to geographic coordinates (image y grows downwards)
        lon = x_rel * 360 - 180
        lat = 90 - y_rel * 180

        # Find nearest point in df_vis if available. An empty frame is skipped
        # because idxmin() raises on it, which would turn a valid click into
        # an error message.
        pid = ""
        if df_vis is not None and len(df_vis) > 0:
            dists = (df_vis['centre_lat'] - lat)**2 + (df_vis['centre_lon'] - lon)**2
            min_idx = dists.idxmin()
            nearest_row = df_vis.loc[min_idx]
            # Snap only when the nearest point is within 5 degrees (25 = 5**2).
            if dists[min_idx] < 25:
                lat = nearest_row['centre_lat']
                lon = nearest_row['centre_lon']
                pid = nearest_row['product_id']
    except Exception as e:
        print(f"Error handling click: {e}")
        import traceback
        traceback.print_exc()
        return None, None, None, f"Error: {e}"
    return lat, lon, pid, f"Selected Point: ({lat:.4f}, {lon:.4f})"
def download_image_by_location(lat, lon, pid, model_name):
    """Download the dataset image for a location.

    When ``pid`` is empty, the product whose centre has the smallest squared
    lat/lon difference to the given point is looked up in the model's
    embeddings table and downloaded instead.

    Args:
        lat: Latitude (anything ``float()`` accepts), or ``None``.
        lon: Longitude (anything ``float()`` accepts), or ``None``.
        pid: Product id to download; falsy to look up the nearest product.
        model_name: Key of the model whose ``df_embed`` is used.

    Returns:
        ``(image, message)``; ``image`` is ``None`` whenever nothing could be
        downloaded and ``message`` explains the outcome.
    """
    if lon is None or lat is None:
        return None, "Please specify coordinates first."

    model, error = get_active_model(model_name)
    if error:
        return None, error

    try:
        # Convert to float to ensure proper formatting
        lat, lon = float(lat), float(lon)
        embeddings = model.df_embed

        # Find Product ID if not provided
        if not pid:
            lat_col = pd.to_numeric(embeddings['centre_lat'], errors='coerce')
            lon_col = pd.to_numeric(embeddings['centre_lon'], errors='coerce')
            squared_dist = (lat_col - lat)**2 + (lon_col - lon)**2
            pid = embeddings.loc[squared_dist.idxmin(), 'product_id']

        # Download image
        tile, _ = download_and_process_image(pid, df_source=embeddings, verbose=True)

        where = f"({lat:.4f}, {lon:.4f})"
        if tile is None:
            return None, f"Failed to download image for location {where}"
        return tile, f"Downloaded image at {where}"
    except Exception as e:
        import traceback
        traceback.print_exc()
        return None, f"Error: {str(e)}"
def reset_to_global_map():
    """Reset the map to the initial global distribution view.

    Delegates to ``get_initial_plot`` so the choice of which model's
    embeddings to plot lives in one place, and drops that function's last
    element (the interactive-map update), which the reset buttons do not wire.

    Returns:
        A 3-tuple: update for the static map image, the downloadable artefact
        list, and the DataFrame backing the map.
    """
    return get_initial_plot()[:3]
def format_results_to_text(results):
    """Render retrieval results as a plain-text report.

    Args:
        results: Sequence of dicts with ``id``, ``lat``, ``lon`` and ``score``
            keys, already in rank order.

    Returns:
        ``"No results found."`` for an empty or missing sequence, otherwise a
        title, a rule of ``=`` characters and one ranked entry per result, each
        followed by a rule of ``-`` characters.
    """
    if not results:
        return "No results found."

    rule = "-" * 30 + "\n"
    parts = [f"Top {len(results)} Retrieval Results\n", "=" * 30 + "\n\n"]
    for rank, res in enumerate(results, start=1):
        parts.extend((
            f"Rank: {rank}\n",
            f"Product ID: {res['id']}\n",
            f"Location: Latitude {res['lat']:.6f}, Longitude {res['lon']:.6f}\n",
            f"Similarity Score: {res['score']:.6f}\n",
            rule,
        ))
    return "".join(parts)
def save_plot(figs):
    """Write the current results to a temporary file for download.

    Args:
        figs: One of
            * ``None`` or an empty list/tuple - nothing to save;
            * a PIL image, or a one-element list holding one - saved as PNG;
            * a list/tuple ``[map_img, results_img, results_text]`` - bundled
              into a zip, skipping members that are missing or ``None``;
            * any other object - saved through its ``write_html`` method.

    Returns:
        Path of the created file, or ``None`` when there is nothing to save or
        saving fails (the error is printed).
    """
    if figs is None:
        return None

    def write_png(image):
        # Save one image to its own uniquely named temporary PNG file.
        fd, path = tempfile.mkstemp(suffix='.png', prefix='earth_explorer_map_')
        os.close(fd)
        image.save(path)
        return path

    try:
        # If it's a list/tuple of images [map_img, results_img, results_txt]
        if isinstance(figs, (list, tuple)):
            if not figs:
                return None
            # If only one image in list, save as PNG
            if len(figs) == 1 and isinstance(figs[0], PILImage.Image):
                return write_png(figs[0])

            fd, zip_path = tempfile.mkstemp(suffix='.zip', prefix='earth_explorer_results_')
            os.close(fd)
            # Stage the members in a private directory: fixed file names in
            # the shared temp dir could be overwritten by a concurrent request
            # and would never be cleaned up.
            with tempfile.TemporaryDirectory(prefix='earth_explorer_') as staging_dir:
                with zipfile.ZipFile(zip_path, 'w') as zipf:
                    # Save Map
                    if figs[0] is not None:
                        map_path = os.path.join(staging_dir, 'map_distribution.png')
                        figs[0].save(map_path)
                        zipf.write(map_path, arcname='map_distribution.png')
                    # Save Results
                    if len(figs) > 1 and figs[1] is not None:
                        res_path = os.path.join(staging_dir, 'retrieval_results.png')
                        figs[1].save(res_path)
                        zipf.write(res_path, arcname='retrieval_results.png')
                    # Save Results Text
                    if len(figs) > 2 and figs[2] is not None:
                        txt_path = os.path.join(staging_dir, 'results.txt')
                        with open(txt_path, 'w', encoding='utf-8') as f:
                            f.write(figs[2])
                        zipf.write(txt_path, arcname='results.txt')
            return zip_path

        # If it's a single image (initial state), save as png
        if isinstance(figs, PILImage.Image):
            return write_png(figs)

        # Fallback for Plotly figure (if any)
        fd, path = tempfile.mkstemp(suffix='.html', prefix='earth_explorer_plot_')
        os.close(fd)
        figs.write_html(path)
        return path
    except Exception as e:
        print(f"Error saving: {e}")
        return None
# Gradio Blocks Interface
# NOTE(review): the nesting of the layout containers below was reconstructed
# from the order of the calls; confirm it against the intended page layout.
with gr.Blocks(title="EarthEmbeddingExplorer") as demo:
    gr.Markdown("# EarthEmbeddingExplorer")
    gr.HTML("""
<div style="font-size: 1.2em;">
EarthEmbeddingExplorer is a tool that allows you to search for satellite images of the Earth using natural language descriptions, images, geolocations, or a simple a click on the map. For example, you can type "tropical rainforest" or "coastline with a city," and the system will find locations on Earth that match your description. It then visualizes these locations on a world map and displays the top matching images.
</div>
""")

    with gr.Row():
        # Left column: query controls, threshold, status and download.
        with gr.Column(scale=4):
            with gr.Tabs():
                with gr.TabItem("Text Search") as tab_text:
                    model_selector_text = gr.Dropdown(
                        choices=["SigLIP", "FarSLIP"],
                        value="FarSLIP",
                        label="Model",
                    )
                    query_input = gr.Textbox(
                        label="Query",
                        placeholder="e.g., rainforest, glacier",
                    )
                    gr.Examples(
                        examples=[
                            ["a satellite image of a river around a city"],
                            ["a satellite image of a rainforest"],
                            ["a satellite image of a slum"],
                            ["a satellite image of a glacier"],
                            ["a satellite image of snow covered mountains"],
                        ],
                        inputs=[query_input],
                        label="Text Examples",
                    )
                    search_btn = gr.Button("Search by Text", variant="primary")

                with gr.TabItem("Image Search") as tab_image:
                    model_selector_img = gr.Dropdown(
                        choices=["SigLIP", "FarSLIP", "SatCLIP", "DINOv2"],
                        value="FarSLIP",
                        label="Model",
                    )
                    gr.Markdown("### Option 1: Upload or Select Image")
                    image_input = gr.Image(type="pil", label="Upload Image")
                    gr.Examples(
                        examples=[
                            ["./examples/example1.png"],
                            ["./examples/example2.png"],
                            ["./examples/example3.png"],
                        ],
                        inputs=[image_input],
                        label="Image Examples",
                    )
                    gr.Markdown("### Option 2: Click Map or Enter Coordinates")
                    btn_reset_map_img = gr.Button("π Reset Map to Global View", variant="secondary", size="sm")
                    with gr.Row():
                        img_lat = gr.Number(label="Latitude", interactive=True)
                        img_lon = gr.Number(label="Longitude", interactive=True)
                    # Hidden field filled in by map clicks.
                    img_pid = gr.Textbox(label="Product ID (auto-filled)", visible=False)
                    img_click_status = gr.Markdown("")
                    btn_download_img = gr.Button("Download Image by Geolocation", variant="secondary")
                    search_img_btn = gr.Button("Search by Image", variant="primary")

                with gr.TabItem("Location Search") as tab_location:
                    gr.Markdown("Search using **SatCLIP** location encoder.")
                    gr.Markdown("### Click Map or Enter Coordinates")
                    btn_reset_map_loc = gr.Button("π Reset Map to Global View", variant="secondary", size="sm")
                    with gr.Row():
                        lat_input = gr.Number(label="Latitude", value=30.0, interactive=True)
                        lon_input = gr.Number(label="Longitude", value=120.0, interactive=True)
                    # Hidden field filled in by map clicks.
                    loc_pid = gr.Textbox(label="Product ID (auto-filled)", visible=False)
                    loc_click_status = gr.Markdown("")
                    gr.Examples(
                        examples=[
                            [30.32, 120.15],
                            [40.7128, -74.0060],
                            [24.65, 46.71],
                            [-3.4653, -62.2159],
                            [64.4, 16.8],
                        ],
                        inputs=[lat_input, lon_input],
                        label="Location Examples",
                    )
                    search_loc_btn = gr.Button("Search by Location", variant="primary")

            threshold_slider = gr.Slider(minimum=1, maximum=30, value=7, step=1, label="Top Percentage (β°)")
            status_output = gr.Textbox(label="Status", lines=10)
            save_btn = gr.Button("Download Result")
            download_file = gr.File(label="Zipped Results", height=40)

        # Right column: maps and retrieved images.
        with gr.Column(scale=6):
            plot_map = gr.Image(
                label="Geographical Distribution",
                type="pil",
                interactive=False,
                height=400,
                width=800,
                visible=True,
            )
            plot_map_interactive = gr.Plot(
                label="Geographical Distribution (Interactive)",
                visible=False,
            )
            results_plot = gr.Image(label="Top 5 Matched Images", type="pil")
            gallery_images = gr.Gallery(label="Top Retrieved Images (Zoom)", columns=3, height="auto")

    # Per-session state: artefacts offered for download and the map's data.
    current_fig = gr.State()
    map_data_state = gr.State()

    # Initial Load
    demo.load(
        fn=get_initial_plot,
        outputs=[plot_map, current_fig, map_data_state, plot_map_interactive],
    )

    # Reset Map Buttons (both tabs share the same wiring)
    for reset_btn in (btn_reset_map_img, btn_reset_map_loc):
        reset_btn.click(
            fn=reset_to_global_map,
            outputs=[plot_map, current_fig, map_data_state],
        )

    # Map Click Event - updates the Image Search coordinates first, then the
    # Location Search coordinates.
    for click_outputs in (
        [img_lat, img_lon, img_pid, img_click_status],
        [lat_input, lon_input, loc_pid, loc_click_status],
    ):
        plot_map.select(
            fn=handle_map_click,
            inputs=[map_data_state],
            outputs=click_outputs,
        )

    # Download Image by Geolocation
    btn_download_img.click(
        fn=download_image_by_location,
        inputs=[img_lat, img_lon, img_pid, model_selector_img],
        outputs=[image_input, img_click_status],
    )

    # All three searches update the same seven components, in this order.
    search_outputs = [
        plot_map_interactive,
        gallery_images,
        status_output,
        results_plot,
        current_fig,
        map_data_state,
        plot_map,
    ]

    # Search Event (Text)
    search_btn.click(
        fn=search_text,
        inputs=[query_input, threshold_slider, model_selector_text],
        outputs=search_outputs,
    )

    # Search Event (Image)
    search_img_btn.click(
        fn=search_image,
        inputs=[image_input, threshold_slider, model_selector_img],
        outputs=search_outputs,
    )

    # Search Event (Location)
    search_loc_btn.click(
        fn=search_location,
        inputs=[lat_input, lon_input, threshold_slider],
        outputs=search_outputs,
    )

    # Save Event
    save_btn.click(
        fn=save_plot,
        inputs=[current_fig],
        outputs=[download_file],
    )

    # Tab Selection Events
    def show_static_map():
        """Show the static map and hide the interactive one."""
        return gr.update(visible=True), gr.update(visible=False)

    for tab in (tab_text, tab_image, tab_location):
        tab.select(fn=show_static_map, outputs=[plot_map, plot_map_interactive])

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
|