Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| import time | |
| import os | |
| import tempfile | |
| import zipfile | |
| import numpy as np | |
| import pandas as pd | |
| from concurrent.futures import ThreadPoolExecutor, as_completed | |
| # Import custom modules | |
| from models.siglip_model import SigLIPModel | |
| from models.satclip_model import SatCLIPModel | |
| from models.farslip_model import FarSLIPModel | |
| from models.dinov2_model import DINOv2Model | |
| from models.load_config import load_and_process_config | |
| from visualize import format_results_for_gallery, plot_top5_overview, plot_location_distribution, plot_global_map_static, plot_geographic_distribution | |
| from data_utils import download_and_process_image, get_esri_satellite_image, get_placeholder_image | |
| from PIL import Image as PILImage | |
| from PIL import ImageDraw, ImageFont | |
# Configuration: run on the GPU when torch reports one, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Running on device: {device}")

# Load and process configuration
config = load_and_process_config()
print(config)

# Initialize Models
print("Initializing models...")
models = {}


def _init_model(display_name, config_key, model_cls, config_fields):
    """Instantiate one model into ``models``, reporting failures instead of raising.

    When ``config`` is truthy and has a ``config_key`` section, every name in
    ``config_fields`` is read from that section with ``.get`` and forwarded as
    a keyword argument together with ``device``. Otherwise the model is built
    with ``device`` only. Any exception is printed and the model is simply left
    out of ``models``.
    """
    try:
        if config and config_key in config:
            section = config[config_key]
            options = {field: section.get(field) for field in config_fields}
            models[display_name] = model_cls(device=device, **options)
        else:
            models[display_name] = model_cls(device=device)
    except Exception as e:
        print(f"Failed to load {display_name}: {e}")


# Load order is kept: DINOv2, SigLIP, SatCLIP, FarSLIP.
_init_model('DINOv2', 'dinov2', DINOv2Model, ('ckpt_path', 'embedding_path'))
_init_model('SigLIP', 'siglip', SigLIPModel, ('ckpt_path', 'tokenizer_path', 'embedding_path'))
_init_model('SatCLIP', 'satclip', SatCLIPModel, ('ckpt_path', 'embedding_path'))
_init_model('FarSLIP', 'farslip', FarSLIPModel, ('ckpt_path', 'model_name', 'embedding_path'))
def get_active_model(model_name):
    """Look up a loaded model by its display name.

    Returns a ``(model, error)`` pair: ``(model, None)`` when ``model_name`` is
    a key of the module-level ``models`` dict, otherwise ``(None, message)``.
    """
    if model_name in models:
        return models[model_name], None
    return None, f"Model {model_name} not loaded."
def combine_images(img1, img2):
    """Stack two PIL images vertically on a white RGB canvas.

    Both images are resized to the wider of the two widths, scaling each
    height by the same factor (truncated to an int). If either argument is
    ``None`` the other one is returned unchanged.
    """
    if img1 is None:
        return img2
    if img2 is None:
        return img1

    # Scale both images to the larger width, keeping their aspect ratios.
    (top_w, top_h), (bottom_w, bottom_h) = img1.size, img2.size
    target_w = max(top_w, bottom_w)
    scaled_top_h = int(top_h * target_w / top_w)
    scaled_bottom_h = int(bottom_h * target_w / bottom_w)

    top = img1.resize((target_w, scaled_top_h))
    bottom = img2.resize((target_w, scaled_bottom_h))

    canvas = PILImage.new('RGB', (target_w, scaled_top_h + scaled_bottom_h), (255, 255, 255))
    canvas.paste(top, (0, 0))
    canvas.paste(bottom, (0, scaled_top_h))
    return canvas
def create_text_image(text, size=(384, 384)):
    """Render a text query onto a light-grey RGB image.

    ``text`` is split on commas; each stripped piece is drawn in black on its
    own row (starting at y=100, 50 px apart, 20 px left margin), followed by a
    blue "Text Query" caption 50 px below where the next row would go.

    Args:
        text: Query string; commas act as line breaks.
        size: ``(width, height)`` of the image to create.

    Returns:
        The drawn ``PIL.Image.Image``.
    """
    img = PILImage.new('RGB', size, color=(240, 240, 240))
    d = ImageDraw.Draw(img)

    # Prefer a scalable font; fall back to Pillow's built-in bitmap font when
    # the TTF file cannot be read (OSError) or FreeType support is missing
    # (ImportError). A bare ``except`` here would also swallow
    # KeyboardInterrupt/SystemExit.
    try:
        font = ImageFont.truetype("DejaVuSans.ttf", 40)
    except (OSError, ImportError):
        font = ImageFont.load_default()

    # Simple "wrapping": one row per comma-separated piece.
    margin = 20
    offset = 100
    for line in text.split(','):
        d.text((margin, offset), line.strip(), font=font, fill=(0, 0, 0))
        offset += 50

    d.text((margin, offset + 50), "Text Query", font=font, fill=(0, 0, 255))
    return img
def fetch_top_k_images(top_indices, probs, df_embed, query_text=None):
    """Download the images for ``top_indices`` concurrently.

    Each positional index in ``top_indices`` is resolved to a ``product_id``
    through ``df_embed.iloc`` and fetched with ``download_and_process_image``
    on a pool of five worker threads. When a download yields no thumbnail, an
    Esri satellite tile for the row's coordinates is used for both the
    thumbnail and the full image. Per-item failures are printed and skipped.

    Returns:
        A list of dicts with keys ``image_384``, ``image_full``, ``score``,
        ``lat``, ``lon`` and ``id``, sorted by ``score`` descending.
    """
    fetched = []

    with ThreadPoolExecutor(max_workers=5) as pool:
        # Map every submitted download back to the index it belongs to.
        pending = {
            pool.submit(
                download_and_process_image,
                df_embed.iloc[idx]['product_id'],
                df_source=df_embed,
                verbose=False,
            ): idx
            for idx in top_indices
        }

        for finished in as_completed(pending):
            idx = pending[finished]
            try:
                thumb, full = finished.result()
                record = df_embed.iloc[idx]

                if thumb is None:
                    # Fallback to Esri if download fails
                    print(f"Download failed for idx {idx}, falling back to Esri...")
                    thumb = get_esri_satellite_image(
                        record['centre_lat'],
                        record['centre_lon'],
                        score=probs[idx],
                        rank=0,
                        query=query_text,
                    )
                    full = thumb

                fetched.append({
                    'image_384': thumb,
                    'image_full': full,
                    'score': probs[idx],
                    'lat': record['centre_lat'],
                    'lon': record['centre_lon'],
                    'id': record['product_id'],
                })
            except Exception as e:
                print(f"Error fetching image for idx {idx}: {e}")

    # Futures complete in arbitrary order, so restore the ranking by score.
    fetched.sort(key=lambda item: item['score'], reverse=True)
    return fetched
def get_all_results_metadata(model, filtered_indices, probs):
    """Return id/lat/lon/score records for every filtered hit, best first.

    ``filtered_indices`` are positional indices into ``model.df_embed`` and are
    also used to index ``probs``. The rows are ordered by descending score and
    the columns ``product_id``/``centre_lat``/``centre_lon`` are exposed as
    ``id``/``lat``/``lon``. An empty ``filtered_indices`` yields ``[]``.
    """
    if len(filtered_indices) == 0:
        return []

    # argsort is ascending, so reverse it to rank the highest score first.
    ranking = np.argsort(probs[filtered_indices])[::-1]
    ranked_indices = filtered_indices[ranking]

    table = (
        model.df_embed.iloc[ranked_indices]
        .assign(score=probs[ranked_indices])
        .rename(columns={'product_id': 'id', 'centre_lat': 'lat', 'centre_lon': 'lon'})
    )
    return table[['id', 'lat', 'lon', 'score']].to_dict('records')
def search_text(query, threshold, model_name):
    """Run a text query and stream progress to the UI.

    Generator used as a Gradio event handler. Every yielded 7-tuple feeds, in
    order: interactive plot, gallery, status text, results figure, download
    payload, map data and static map. Intermediate yields carry a growing
    progress message; the final yield carries the gallery, the overview
    figure and ``[map, figure, results_text]`` for download.

    ``threshold`` is divided by 1000 before being passed to ``model.search``
    and ``plot_geographic_distribution``. Any exception is printed with its
    traceback and reported as an ``"Error: ..."`` status.
    """
    def status_only(message):
        # Update that touches nothing but the status box.
        return None, None, message, None, None, None, None

    model, error = get_active_model(model_name)
    if error:
        yield status_only(error)
        return

    if not query:
        yield status_only("Please enter a query.")
        return

    steps = (
        "Encoding text...",
        "Retrieving similar images...",
        "Downloading images...",
        "Generating visualizations...",
    )

    def progress(active_step):
        # Finished steps get a check mark; the active one is shown last.
        return " ✓\n".join(steps[:active_step + 1])

    try:
        # 1. Encode Text
        yield status_only(progress(0))
        started = time.time()
        text_features = model.encode_text(query)
        encoding_s = time.time() - started

        if text_features is None:
            yield status_only("Model does not support text encoding or is not initialized.")
            return

        # 2. Search
        yield status_only(progress(1))
        started = time.time()
        probs, filtered_indices, top_indices = model.search(text_features, top_percent=threshold/1000.0)
        retrieval_s = time.time() - started

        if probs is None:
            yield status_only("Search failed (embeddings missing?).")
            return

        # Geographic distribution (deliberately not timed)
        df_embed = model.df_embed
        geo_dist_map, df_filtered = plot_geographic_distribution(
            df_embed, probs, threshold/1000.0, title=f'Similarity to "{query}" ({model_name})'
        )

        def with_map(message):
            # Progress update that also keeps the distribution map on screen.
            return (
                gr.update(visible=False), None, message, None, None,
                df_filtered, gr.update(value=geo_dist_map, visible=True),
            )

        # 3. Download Images
        yield with_map(progress(2))
        started = time.time()
        top_indices = top_indices[:10]
        results = fetch_top_k_images(top_indices, probs, df_embed, query_text=query)
        download_s = time.time() - started

        # 4. Visualize - keep geo_dist_map visible
        yield with_map(progress(3))
        started = time.time()
        fig_results = plot_top5_overview(None, results, query_info=query)
        gallery_items = format_results_for_gallery(results)
        visualization_s = time.time() - started

        # 5. Generate Final Status
        timing_str = f"Encoding {encoding_s:.1f}s, Retrieval {retrieval_s:.1f}s, Download {download_s:.1f}s, Visualization {visualization_s:.1f}s\n\n"
        status_msg = timing_str + generate_status_msg(len(filtered_indices), threshold/100.0, results)

        all_results = get_all_results_metadata(model, filtered_indices, probs)
        results_txt = format_results_to_text(all_results)

        yield (
            gr.update(visible=False), gallery_items, status_msg, fig_results,
            [geo_dist_map, fig_results, results_txt], df_filtered,
            gr.update(value=geo_dist_map, visible=True),
        )
    except Exception as e:
        import traceback
        traceback.print_exc()
        yield status_only(f"Error: {str(e)}")
def search_image(image_input, threshold, model_name):
    """Run an image query and stream progress to the UI.

    Generator used as a Gradio event handler. Every yielded 7-tuple feeds, in
    order: interactive plot, gallery, status text, results figure, download
    payload, map data and static map.

    Args:
        image_input: Query image handed to ``model.encode_image``; ``None``
            produces a "Please upload an image." status.
        threshold: Slider value; divided by 1000 before being passed to
            ``model.search`` and ``plot_geographic_distribution``.
        model_name: Key into the loaded models.

    Yields:
        Progress tuples, then a final tuple with the gallery, the overview
        figure and ``[map, figure, results_text]`` (text limited to the first
        50 results). Any exception is printed with its traceback and reported
        as an ``"Error: ..."`` status.
    """
    def status_only(message):
        # Update that touches nothing but the status box.
        return None, None, message, None, None, None, None

    model, error = get_active_model(model_name)
    if error:
        yield status_only(error)
        return

    if image_input is None:
        yield status_only("Please upload an image.")
        return

    steps = (
        "Encoding image...",
        "Retrieving similar images...",
        "Downloading images...",
        "Generating visualizations...",
    )

    def progress(active_step):
        # Finished steps get a check mark; the active one is shown last.
        return " ✓\n".join(steps[:active_step + 1])

    try:
        # 1. Encode Image
        yield status_only(progress(0))
        started = time.time()
        image_features = model.encode_image(image_input)
        encoding_s = time.time() - started

        if image_features is None:
            yield status_only("Model does not support image encoding.")
            return

        # 2. Search
        yield status_only(progress(1))
        started = time.time()
        probs, filtered_indices, top_indices = model.search(image_features, top_percent=threshold/1000.0)
        retrieval_s = time.time() - started

        # Same guard as the text search: without it a failed search only
        # surfaces later as an opaque exception from the plotting code.
        if probs is None:
            yield status_only("Search failed (embeddings missing?).")
            return

        # Geographic distribution (deliberately not timed)
        df_embed = model.df_embed
        geo_dist_map, df_filtered = plot_geographic_distribution(
            df_embed, probs, threshold/1000.0, title=f'Similarity to Input Image ({model_name})'
        )

        def with_map(message):
            # Progress update that also keeps the distribution map on screen.
            return (
                gr.update(visible=False), None, message, None, None,
                df_filtered, gr.update(value=geo_dist_map, visible=True),
            )

        # 3. Download Images
        yield with_map(progress(2))
        started = time.time()
        top_indices = top_indices[:6]
        results = fetch_top_k_images(top_indices, probs, df_embed, query_text="Image Query")
        download_s = time.time() - started

        # 4. Visualize - keep geo_dist_map visible
        yield with_map(progress(3))
        started = time.time()
        fig_results = plot_top5_overview(image_input, results, query_info="Image Query")
        gallery_items = format_results_for_gallery(results)
        visualization_s = time.time() - started

        # 5. Generate Final Status
        timing_str = f"Encoding {encoding_s:.1f}s, Retrieval {retrieval_s:.1f}s, Download {download_s:.1f}s, Visualization {visualization_s:.1f}s\n\n"
        status_msg = timing_str + generate_status_msg(len(filtered_indices), threshold/100.0, results)

        all_results = get_all_results_metadata(model, filtered_indices, probs)
        results_txt = format_results_to_text(all_results[:50])

        yield (
            gr.update(visible=False), gallery_items, status_msg, fig_results,
            [geo_dist_map, fig_results, results_txt], df_filtered,
            gr.update(value=geo_dist_map, visible=True),
        )
    except Exception as e:
        import traceback
        traceback.print_exc()
        yield status_only(f"Error: {str(e)}")
def search_location(lat, lon, threshold):
    """Run a location query with the SatCLIP model and stream progress to the UI.

    Generator used as a Gradio event handler. Every yielded 7-tuple feeds, in
    order: interactive plot, gallery, status text, results figure, download
    payload, map data and static map.

    Args:
        lat: Query latitude; converted with ``float``.
        lon: Query longitude; converted with ``float``.
        threshold: Slider value in per-mille; divided by 1000 for both the
            search and the distribution map.

    Yields:
        Progress tuples, then a final tuple with the gallery, the overview
        figure and ``[map, figure, results_text]``. Any exception (including a
        failed ``float`` conversion) is printed with its traceback and
        reported as an ``"Error: ..."`` status.
    """
    def status_only(message):
        # Update that touches nothing but the status box.
        return None, None, message, None, None, None, None

    model_name = "SatCLIP"
    model, error = get_active_model(model_name)
    if error:
        yield status_only(error)
        return

    steps = (
        "Encoding location...",
        "Retrieving similar images...",
        "Generating distribution map...",
        "Downloading images...",
        "Generating visualizations...",
    )

    def progress(active_step):
        # Finished steps get a check mark; the active one is shown last.
        return " ✓\n".join(steps[:active_step + 1])

    try:
        # 1. Encode Location
        yield status_only(progress(0))
        started = time.time()
        loc_features = model.encode_location(float(lat), float(lon))
        encoding_s = time.time() - started

        if loc_features is None:
            yield status_only("Location encoding failed.")
            return

        # 2. Search
        yield status_only(progress(1))
        started = time.time()
        # The slider is in per-mille, so scale by 1000. This used to be /100,
        # which disagreed with the map below, the "top N‰" status line and the
        # text/image searches.
        probs, filtered_indices, top_indices = model.search(loc_features, top_percent=threshold/1000.0)
        retrieval_s = time.time() - started

        # Same guard as the text search for a search that produced nothing.
        if probs is None:
            yield status_only("Search failed (embeddings missing?).")
            return

        # 3. Generate Distribution Map (deliberately not timed)
        yield status_only(progress(2))
        df_embed = model.df_embed
        geo_dist_map, df_filtered = plot_geographic_distribution(
            df_embed, probs, threshold/1000.0, title=f'Similarity to Location ({lat}, {lon})'
        )

        def with_map(message):
            # Progress update that also keeps the distribution map on screen.
            return (
                gr.update(visible=False), None, message, None, None,
                df_filtered, gr.update(value=geo_dist_map, visible=True),
            )

        # 4. Download Images
        yield with_map(progress(3))
        started = time.time()
        top_6_indices = top_indices[:6]
        results = fetch_top_k_images(top_6_indices, probs, df_embed, query_text=f"Loc: {lat},{lon}")

        # Query tile: the dataset image closest to the requested coordinates,
        # or a placeholder when that lookup/download fails.
        query_tile = None
        try:
            lats = pd.to_numeric(df_embed['centre_lat'], errors='coerce')
            lons = pd.to_numeric(df_embed['centre_lon'], errors='coerce')
            dists = (lats - float(lat))**2 + (lons - float(lon))**2
            nearest_idx = dists.idxmin()
            pid = df_embed.loc[nearest_idx, 'product_id']
            query_tile, _ = download_and_process_image(pid, df_source=df_embed, verbose=False)
        except Exception as e:
            print(f"Error fetching nearest MajorTOM image: {e}")

        if query_tile is None:
            query_tile = get_placeholder_image(f"Query Location\n({lat}, {lon})")
        download_s = time.time() - started

        # 5. Visualize - keep geo_dist_map visible
        yield with_map(progress(4))
        started = time.time()
        fig_results = plot_top5_overview(query_tile, results, query_info=f"Loc: {lat},{lon}")
        gallery_items = format_results_for_gallery(results)
        visualization_s = time.time() - started

        # 6. Generate Final Status
        timing_str = f"Encoding {encoding_s:.1f}s, Retrieval {retrieval_s:.1f}s, Download {download_s:.1f}s, Visualization {visualization_s:.1f}s\n\n"
        status_msg = timing_str + generate_status_msg(len(filtered_indices), threshold/100.0, results)

        all_results = get_all_results_metadata(model, filtered_indices, probs)
        results_txt = format_results_to_text(all_results)

        yield (
            gr.update(visible=False), gallery_items, status_msg, fig_results,
            [geo_dist_map, fig_results, results_txt], df_filtered,
            gr.update(value=geo_dist_map, visible=True),
        )
    except Exception as e:
        import traceback
        traceback.print_exc()
        yield status_only(f"Error: {str(e)}")
def generate_status_msg(count, threshold, results):
    """Build the status summary shown after a search.

    The header reports ``count`` matches and ``threshold * 100`` (rounded to
    an integer) as a per-mille figure, followed by one line for each of the
    first three entries of ``results`` (dicts with ``id``, ``lat``, ``lon``
    and ``score``).
    """
    header = f"Found {count} matches in top {threshold*100:.0f}‰.\n\nTop {len(results)} similar images:\n"
    entries = [
        f"{rank}. Product ID: {res['id']}, Location: ({res['lat']:.4f}, {res['lon']:.4f}), Score: {res['score']:.4f}\n"
        for rank, res in enumerate(results[:3], start=1)
    ]
    return header + "".join(entries)
def get_initial_plot():
    """Build the global overview map shown when the page loads.

    The embeddings table of the first loaded model that has one is plotted,
    trying DINOv2, then SigLIP, then any other loaded model. Previously the
    fallback indexed ``models['SigLIP']`` unconditionally, which raised
    ``KeyError`` on page load whenever SigLIP had failed to initialise.

    Returns:
        A 4-tuple for the static map, the download payload, the map data and
        the interactive plot: ``(update showing the image, [image], df_vis,
        update hiding the interactive plot)``. When no model offers an
        embeddings table the image and map data are ``None`` and the download
        payload is ``None``.
    """
    img = None
    df_vis = None

    preferred = ('DINOv2', 'SigLIP')
    candidates = list(preferred) + [name for name in models if name not in preferred]
    for name in candidates:
        df_embed = getattr(models.get(name), 'df_embed', None)
        if df_embed is not None:
            img, df_vis = plot_global_map_static(df_embed)
            break

    payload = [img] if img is not None else None
    return gr.update(value=img, visible=True), payload, df_vis, gr.update(visible=False)
def handle_map_click(evt: gr.SelectData, df_vis):
    """Translate a click on the static map image into a latitude/longitude.

    The pixel position in ``evt.index`` is interpreted against a hard-coded
    3000x1500 image whose plot area is inset by fixed margins and letterboxed
    to a 2:1 (360 x 180 degree) aspect ratio. When ``df_vis`` is given and its
    nearest point is closer than a squared distance of 25 (in the units of
    ``centre_lat``/``centre_lon``), the result snaps to that point and its
    ``product_id`` is returned as well.

    Returns:
        ``(lat, lon, pid, message)``. ``lat``, ``lon`` and ``pid`` are ``None``
        when there is no event, the click falls outside the map or an error
        occurs; ``pid`` is ``""`` when no nearby point was snapped to.
    """
    if evt is None:
        return None, None, None, "No point selected."

    try:
        click_x, click_y = evt.index[0], evt.index[1]

        # Image dimensions and margins. The margins are the 110/110/100/67
        # values scaled by 0.75 (noted in the source as proportional to a
        # 4000x2000 layout).
        img_width, img_height = 3000, 1500
        scale = 0.75
        left_margin, right_margin = 110 * scale, 110 * scale
        top_margin, bottom_margin = 100 * scale, 67 * scale

        plot_width = img_width - left_margin - right_margin
        plot_height = img_height - top_margin - bottom_margin

        # Letterbox the 2:1 world map inside the plot area.
        map_aspect = 360.0 / 180.0
        if plot_width / plot_height > map_aspect:
            map_width, map_height = plot_height * map_aspect, plot_height
            h_offset, v_offset = (plot_width - map_width) / 2, 0
        else:
            map_width, map_height = plot_width, plot_width / map_aspect
            h_offset, v_offset = 0, (plot_height - map_height) / 2

        # Click position relative to the plot area's top-left corner.
        x_in_plot = click_x - left_margin
        y_in_plot = click_y - top_margin

        outside_map = (
            x_in_plot < h_offset or x_in_plot > h_offset + map_width
            or y_in_plot < v_offset or y_in_plot > v_offset + map_height
        )
        if outside_map:
            return None, None, None, "Click outside map area. Please click on the map."

        # Fractional position inside the map, clamped to [0, 1].
        x_rel = max(0, min(1, (x_in_plot - h_offset) / map_width))
        y_rel = max(0, min(1, (y_in_plot - v_offset) / map_height))

        # Left edge is -180 degrees longitude, top edge is +90 degrees latitude.
        lon = x_rel * 360 - 180
        lat = 90 - y_rel * 180

        # Snap to the nearest plotted point when one is close enough.
        pid = ""
        if df_vis is not None:
            sq_dist = (df_vis['centre_lat'] - lat)**2 + (df_vis['centre_lon'] - lon)**2
            nearest_label = sq_dist.idxmin()
            nearest = df_vis.loc[nearest_label]

            if sq_dist[nearest_label] < 25:
                lat = nearest['centre_lat']
                lon = nearest['centre_lon']
                pid = nearest['product_id']
    except Exception as e:
        print(f"Error handling click: {e}")
        import traceback
        traceback.print_exc()
        return None, None, None, f"Error: {e}"

    return lat, lon, pid, f"Selected Point: ({lat:.4f}, {lon:.4f})"
def download_image_by_location(lat, lon, pid, model_name):
    """Download the dataset image for a coordinate or a known product id.

    When ``pid`` is empty, the row of the model's embeddings table with the
    smallest squared lat/lon distance to ``(lat, lon)`` supplies the product
    id. Returns ``(image, message)``; the image is ``None`` when coordinates
    are missing, the model is not loaded, the download yields nothing or an
    exception occurs (which is printed with its traceback).
    """
    if lat is None or lon is None:
        return None, "Please specify coordinates first."

    model, error = get_active_model(model_name)
    if error:
        return None, error

    try:
        # Convert to float to ensure proper formatting
        lat, lon = float(lat), float(lon)

        # Find Product ID if not provided
        if not pid:
            table = model.df_embed
            lat_delta = pd.to_numeric(table['centre_lat'], errors='coerce') - lat
            lon_delta = pd.to_numeric(table['centre_lon'], errors='coerce') - lon
            sq_dist = lat_delta**2 + lon_delta**2
            pid = table.loc[sq_dist.idxmin(), 'product_id']

        # Download image
        thumbnail, _ = download_and_process_image(pid, df_source=model.df_embed, verbose=True)

        if thumbnail is not None:
            return thumbnail, f"Downloaded image at ({lat:.4f}, {lon:.4f})"
        return None, f"Failed to download image for location ({lat:.4f}, {lon:.4f})"
    except Exception as e:
        import traceback
        traceback.print_exc()
        return None, f"Error: {str(e)}"
def reset_to_global_map():
    """Reset the map to the initial global distribution view.

    The embeddings table of the first loaded model that has one is plotted,
    trying DINOv2, then SigLIP, then any other loaded model. Previously the
    fallback indexed ``models['SigLIP']`` unconditionally, which raised
    ``KeyError`` whenever SigLIP had failed to initialise.

    Returns:
        A 3-tuple for the static map, the download payload and the map data:
        ``(update showing the image, [image], df_vis)``. When no model offers
        an embeddings table the image and map data are ``None`` and the
        download payload is ``None``.
    """
    img = None
    df_vis = None

    preferred = ('DINOv2', 'SigLIP')
    candidates = list(preferred) + [name for name in models if name not in preferred]
    for name in candidates:
        df_embed = getattr(models.get(name), 'df_embed', None)
        if df_embed is not None:
            img, df_vis = plot_global_map_static(df_embed)
            break

    payload = [img] if img is not None else None
    return gr.update(value=img, visible=True), payload, df_vis
def format_results_to_text(results):
    """Render retrieval results as a plain-text report.

    ``results`` is a sequence of dicts with ``id``, ``lat``, ``lon`` and
    ``score``. The report has a title line with the result count, a rule of
    ``=`` characters, and one ranked entry per result, each followed by a rule
    of ``-`` characters. A falsy ``results`` gives ``"No results found."``.
    """
    if not results:
        return "No results found."

    separator = "-" * 30 + "\n"
    parts = [f"Top {len(results)} Retrieval Results\n", "=" * 30 + "\n\n"]
    for rank, res in enumerate(results, start=1):
        parts.append(
            f"Rank: {rank}\n"
            f"Product ID: {res['id']}\n"
            f"Location: Latitude {res['lat']:.6f}, Longitude {res['lon']:.6f}\n"
            f"Similarity Score: {res['score']:.6f}\n"
        )
        parts.append(separator)
    return "".join(parts)
def _write_temp_png(image):
    """Save ``image`` to a fresh, uniquely named temporary PNG and return its path."""
    fd, path = tempfile.mkstemp(suffix='.png', prefix='earth_explorer_map_')
    os.close(fd)
    image.save(path)
    return path


def save_plot(figs):
    """Write the current figures to a temporary file for download.

    Args:
        figs: ``None``; a single PIL image; a list/tuple of
            ``[map_image, results_image, results_text]`` (entries may be
            ``None`` or missing); or any object offering ``write_html``.

    Returns:
        Path of the created file, or ``None`` when there is nothing to save or
        saving fails (the error is printed):

        * a single image, or a one-element list holding an image, gives a PNG;
        * any other non-empty list/tuple gives a ZIP with
          ``map_distribution.png``, ``retrieval_results.png`` and
          ``results.txt`` for the entries that are present;
        * anything else is written with ``figs.write_html`` to an HTML file.
    """
    if figs is None:
        return None

    try:
        # If it's a single image (initial state), save as png
        if isinstance(figs, PILImage.Image):
            return _write_temp_png(figs)

        # If it's a list/tuple of images [map_img, results_img, results_txt]
        if isinstance(figs, (list, tuple)):
            if not figs:
                # Nothing to archive; avoid leaving an empty zip behind.
                return None

            # If only one image in list, save as PNG
            if len(figs) == 1 and isinstance(figs[0], PILImage.Image):
                return _write_temp_png(figs[0])

            fd, zip_path = tempfile.mkstemp(suffix='.zip', prefix='earth_explorer_results_')
            os.close(fd)

            try:
                # Stage the PNGs in a private directory. Fixed file names in
                # the shared temp dir let concurrent requests overwrite each
                # other's images and were never cleaned up.
                with tempfile.TemporaryDirectory(prefix='earth_explorer_stage_') as stage_dir:
                    with zipfile.ZipFile(zip_path, 'w') as zipf:
                        image_members = ('map_distribution.png', 'retrieval_results.png')
                        for position, arcname in enumerate(image_members):
                            if len(figs) > position and figs[position] is not None:
                                staged = os.path.join(stage_dir, arcname)
                                figs[position].save(staged)
                                zipf.write(staged, arcname=arcname)

                        # Results text: writestr encodes str data as UTF-8.
                        if len(figs) > 2 and figs[2] is not None:
                            zipf.writestr('results.txt', figs[2])
            except Exception:
                # Do not leave a partial archive behind on failure.
                if os.path.exists(zip_path):
                    os.remove(zip_path)
                raise

            return zip_path

        # Fallback for Plotly figure (if any)
        fd, path = tempfile.mkstemp(suffix='.html', prefix='earth_explorer_plot_')
        os.close(fd)
        figs.write_html(path)
        return path
    except Exception as e:
        print(f"Error saving: {e}")
        return None
# Gradio Blocks Interface
# NOTE(review): the nesting of the layout containers below was reconstructed
# from the order of the components (the source had lost its indentation) —
# confirm against the rendered layout.
with gr.Blocks(title="EarthEmbeddingExplorer") as demo:
    gr.Markdown("# EarthEmbeddingExplorer")
    gr.HTML("""
    <div style="font-size: 1.2em;">
    EarthEmbeddingExplorer is a tool that allows you to search for satellite images of the Earth using natural language descriptions, images, geolocations, or a simple a click on the map. For example, you can type "tropical rainforest" or "coastline with a city," and the system will find locations on Earth that match your description. It then visualizes these locations on a world map and displays the top matching images.
    </div>
    <div style="display: flex; gap: 0.2em; align-items: center; justify-content: center;">
    <a href="https://www.modelscope.cn/studios/VoyagerX/EarthExplorer"><img src="https://img.shields.io/badge/Open in ModelScope.cn-xGPU-624aff"></a>
    <a href="https://modelscope.cn/datasets/VoyagerX/EarthEmbeddings"><img src="https://img.shields.io/badge/👾 MS-Dataset-624aff"></a>
    <a href="https://huggingface.co/spaces/ML4Sustain/EarthExplorer"><img src="https://img.shields.io/badge/Open in HF Space-CPU-FFD21E"></a>
    <a href="https://huggingface.co/datasets/ML4Sustain/EarthEmbeddings"><img src="https://img.shields.io/badge/🤗 HF-Dataset-FFD21E"></a>
    <a href="https://huggingface.co/spaces/ML4Sustain/EarthExplorer/blob/main/Tutorial.md"> <img src="https://img.shields.io/badge/Tutorial-📖-007bff"> </a>
    <a href="https://modelscope.cn/studios/VoyagerX/EarthExplorer/file/view/master/Tutorial_zh.md?status=1"> <img src="https://img.shields.io/badge/中文教程-📖-007bff"> </a>
    </div>
    """)

    with gr.Row():
        # Left column: query tabs, threshold, status and download controls.
        with gr.Column(scale=4):
            with gr.Tabs():
                with gr.TabItem("Text Search") as tab_text:
                    model_selector_text = gr.Dropdown(
                        choices=["SigLIP", "FarSLIP"], value="FarSLIP", label="Model"
                    )
                    query_input = gr.Textbox(label="Query", placeholder="e.g., rainforest, glacier")
                    text_examples = [
                        ["a satellite image of a river around a city"],
                        ["a satellite image of a rainforest"],
                        ["a satellite image of a slum"],
                        ["a satellite image of a glacier"],
                        ["a satellite image of snow covered mountains"],
                    ]
                    gr.Examples(examples=text_examples, inputs=[query_input], label="Text Examples")
                    search_btn = gr.Button("Search by Text", variant="primary")

                with gr.TabItem("Image Search") as tab_image:
                    model_selector_img = gr.Dropdown(
                        choices=["SigLIP", "FarSLIP", "SatCLIP", "DINOv2"], value="FarSLIP", label="Model"
                    )
                    gr.Markdown("### Option 1: Upload or Select Image")
                    image_input = gr.Image(type="pil", label="Upload Image")
                    image_examples = [
                        ["./examples/example1.png"],
                        ["./examples/example2.png"],
                        ["./examples/example3.png"],
                    ]
                    gr.Examples(examples=image_examples, inputs=[image_input], label="Image Examples")

                    gr.Markdown("### Option 2: Click Map or Enter Coordinates")
                    btn_reset_map_img = gr.Button("🔄 Reset Map to Global View", variant="secondary", size="sm")
                    with gr.Row():
                        img_lat = gr.Number(label="Latitude", interactive=True)
                        img_lon = gr.Number(label="Longitude", interactive=True)
                    img_pid = gr.Textbox(label="Product ID (auto-filled)", visible=False)
                    img_click_status = gr.Markdown("")
                    btn_download_img = gr.Button("Download Image by Geolocation", variant="secondary")
                    search_img_btn = gr.Button("Search by Image", variant="primary")

                with gr.TabItem("Location Search") as tab_location:
                    gr.Markdown("Search using **SatCLIP** location encoder.")
                    gr.Markdown("### Click Map or Enter Coordinates")
                    btn_reset_map_loc = gr.Button("🔄 Reset Map to Global View", variant="secondary", size="sm")
                    with gr.Row():
                        lat_input = gr.Number(label="Latitude", value=30.0, interactive=True)
                        lon_input = gr.Number(label="Longitude", value=120.0, interactive=True)
                    loc_pid = gr.Textbox(label="Product ID (auto-filled)", visible=False)
                    loc_click_status = gr.Markdown("")
                    location_examples = [
                        [30.32, 120.15],
                        [40.7128, -74.0060],
                        [24.65, 46.71],
                        [-3.4653, -62.2159],
                        [64.4, 16.8],
                    ]
                    gr.Examples(
                        examples=location_examples,
                        inputs=[lat_input, lon_input],
                        label="Location Examples",
                    )
                    search_loc_btn = gr.Button("Search by Location", variant="primary")

            threshold_slider = gr.Slider(minimum=1, maximum=30, value=7, step=1, label="Top Percentage (‰)")
            status_output = gr.Textbox(label="Status", lines=10)
            save_btn = gr.Button("Download Result")
            download_file = gr.File(label="Zipped Results", height=40)

        # Right column: maps and retrieved images.
        with gr.Column(scale=6):
            plot_map = gr.Image(
                label="Geographical Distribution",
                type="pil",
                interactive=False,
                height=400,
                width=800,
                visible=True,
            )
            plot_map_interactive = gr.Plot(
                label="Geographical Distribution (Interactive)",
                visible=False,
            )
            results_plot = gr.Image(label="Top 5 Matched Images", type="pil")
            gallery_images = gr.Gallery(label="Top Retrieved Images (Zoom)", columns=3, height="auto")

    current_fig = gr.State()
    map_data_state = gr.State()

    # Initial Load
    demo.load(
        fn=get_initial_plot,
        outputs=[plot_map, current_fig, map_data_state, plot_map_interactive],
    )

    # Reset Map Buttons (one per tab that offers map picking)
    for reset_button in (btn_reset_map_img, btn_reset_map_loc):
        reset_button.click(
            fn=reset_to_global_map,
            outputs=[plot_map, current_fig, map_data_state],
        )

    # Map Click Event - fills the Image Search coordinates first, then the
    # Location Search coordinates.
    click_targets = (
        [img_lat, img_lon, img_pid, img_click_status],
        [lat_input, lon_input, loc_pid, loc_click_status],
    )
    for click_outputs in click_targets:
        plot_map.select(
            fn=handle_map_click,
            inputs=[map_data_state],
            outputs=click_outputs,
        )

    # Download Image by Geolocation
    btn_download_img.click(
        fn=download_image_by_location,
        inputs=[img_lat, img_lon, img_pid, model_selector_img],
        outputs=[image_input, img_click_status],
    )

    # Search Events (Text, Image, Location) - all three write the same outputs.
    search_outputs = [
        plot_map_interactive, gallery_images, status_output, results_plot,
        current_fig, map_data_state, plot_map,
    ]
    search_btn.click(
        fn=search_text,
        inputs=[query_input, threshold_slider, model_selector_text],
        outputs=search_outputs,
    )
    search_img_btn.click(
        fn=search_image,
        inputs=[image_input, threshold_slider, model_selector_img],
        outputs=search_outputs,
    )
    search_loc_btn.click(
        fn=search_location,
        inputs=[lat_input, lon_input, threshold_slider],
        outputs=search_outputs,
    )

    # Save Event
    save_btn.click(
        fn=save_plot,
        inputs=[current_fig],
        outputs=[download_file],
    )

    # Tab Selection Events
    def show_static_map():
        """Show the static map and hide the interactive plot."""
        return gr.update(visible=True), gr.update(visible=False)

    for tab in (tab_text, tab_image, tab_location):
        tab.select(fn=show_static_map, outputs=[plot_map, plot_map_interactive])

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)