"""Gradio front-end for EarthEmbeddingExplorer.

At import time this module loads whichever embedding models succeed
(DINOv2, SigLIP, SatCLIP, FarSLIP) and then builds a Gradio UI for text,
image, and location similarity search over each model's ``df_embed`` table.
"""

import io
import os
import tempfile
import time
import traceback
import zipfile
from concurrent.futures import ThreadPoolExecutor, as_completed

import gradio as gr
import numpy as np
import pandas as pd
import torch

# Import custom modules
from models.siglip_model import SigLIPModel
from models.satclip_model import SatCLIPModel
from models.farslip_model import FarSLIPModel
from models.dinov2_model import DINOv2Model
from models.load_config import load_and_process_config
from visualize import (
    format_results_for_gallery,
    plot_top5_overview,
    plot_location_distribution,
    plot_global_map_static,
    plot_geographic_distribution,
)
from data_utils import download_and_process_image, get_esri_satellite_image, get_placeholder_image
from PIL import Image as PILImage
from PIL import ImageDraw, ImageFont

# Configuration
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Running on device: {device}")

# Load and process configuration
config = load_and_process_config()
print(config)

# The threshold slider is expressed in per mille (‰); dividing by this gives
# the fraction handed to ``model.search`` and ``plot_geographic_distribution``.
PER_MILLE = 1000.0

# (display name, model class, config section, config keys forwarded as kwargs)
_MODEL_SPECS = (
    ("DINOv2", DINOv2Model, "dinov2", ("ckpt_path", "embedding_path")),
    ("SigLIP", SigLIPModel, "siglip", ("ckpt_path", "tokenizer_path", "embedding_path")),
    ("SatCLIP", SatCLIPModel, "satclip", ("ckpt_path", "embedding_path")),
    ("FarSLIP", FarSLIPModel, "farslip", ("ckpt_path", "model_name", "embedding_path")),
)


def _load_model(model_cls, section, keys):
    """Instantiate ``model_cls`` from ``config[section]`` if present, else with defaults.

    Keys absent from a present section are forwarded as ``None`` (``dict.get``).
    """
    if config and section in config:
        section_cfg = config[section]
        kwargs = {key: section_cfg.get(key) for key in keys}
        return model_cls(device=device, **kwargs)
    return model_cls(device=device)


# Initialize Models
print("Initializing models...")
models = {}
for _name, _cls, _section, _keys in _MODEL_SPECS:
    # Best effort: a model that fails to load is simply absent from ``models``;
    # get_active_model reports that to the user at query time.
    try:
        models[_name] = _load_model(_cls, _section, _keys)
    except Exception as e:
        print(f"Failed to load {_name}: {e}")


def get_active_model(model_name):
    """Return ``(model, None)`` if ``model_name`` was loaded, else ``(None, error_message)``."""
    if model_name not in models:
        return None, f"Model {model_name} not loaded."
    return models[model_name], None


def combine_images(img1, img2):
    """Stack two PIL images vertically, scaling both to the wider width.

    If either image is ``None`` the other one is returned unchanged.
    """
    if img1 is None:
        return img2
    if img2 is None:
        return img1

    # Resize to match width
    w1, h1 = img1.size
    w2, h2 = img2.size
    new_w = max(w1, w2)
    new_h1 = int(h1 * new_w / w1)
    new_h2 = int(h2 * new_w / w2)
    img1 = img1.resize((new_w, new_h1))
    img2 = img2.resize((new_w, new_h2))

    dst = PILImage.new('RGB', (new_w, new_h1 + new_h2), (255, 255, 255))
    dst.paste(img1, (0, 0))
    dst.paste(img2, (0, new_h1))
    return dst


def create_text_image(text, size=(384, 384)):
    """Render ``text`` (one line per comma-separated part) on a light-grey card."""
    img = PILImage.new('RGB', size, color=(240, 240, 240))
    d = ImageDraw.Draw(img)

    # Prefer a scalable TrueType font; fall back to Pillow's bitmap default
    # when the font file is missing (OSError) or FreeType is unavailable.
    try:
        font = ImageFont.truetype("DejaVuSans.ttf", 40)
    except (OSError, ImportError):
        font = ImageFont.load_default()

    margin = 20
    offset = 100
    for line in text.split(','):
        d.text((margin, offset), line.strip(), font=font, fill=(0, 0, 0))
        offset += 50

    d.text((margin, offset + 50), "Text Query", font=font, fill=(0, 0, 255))
    return img


def fetch_top_k_images(top_indices, probs, df_embed, query_text=None):
    """
    Fetches top-k images using actual dataset download (ModelScope) via download_and_process_image.

    Downloads run on a 5-worker thread pool. A tile whose download yields no
    image falls back to an Esri satellite image at its centre coordinate; a
    tile whose fetch raises is logged and omitted. Returns a list of dicts
    (image_384, image_full, score, lat, lon, id) sorted by score, descending.
    """
    results = []
    with ThreadPoolExecutor(max_workers=5) as executor:
        future_to_idx = {
            executor.submit(
                download_and_process_image,
                df_embed.iloc[idx]['product_id'],
                df_source=df_embed,
                verbose=False,
            ): idx
            for idx in top_indices
        }

        for future in as_completed(future_to_idx):
            idx = future_to_idx[future]
            try:
                img_384, img_full = future.result()
                row = df_embed.iloc[idx]

                if img_384 is None:
                    # Fallback to Esri if download fails
                    print(f"Download failed for idx {idx}, falling back to Esri...")
                    img_384 = get_esri_satellite_image(
                        row['centre_lat'], row['centre_lon'],
                        score=probs[idx], rank=0, query=query_text,
                    )
                    img_full = img_384

                results.append({
                    'image_384': img_384,
                    'image_full': img_full,
                    'score': probs[idx],
                    'lat': row['centre_lat'],
                    'lon': row['centre_lon'],
                    'id': row['product_id'],
                })
            except Exception as e:
                print(f"Error fetching image for idx {idx}: {e}")

    # Futures complete in arbitrary order, so re-rank by score.
    results.sort(key=lambda x: x['score'], reverse=True)
    return results


def get_all_results_metadata(model, filtered_indices, probs):
    """Return id/lat/lon/score dicts for every filtered row, highest score first."""
    filtered_indices = np.asarray(filtered_indices)
    if len(filtered_indices) == 0:
        return []

    # Sort by score descending
    filtered_scores = probs[filtered_indices]
    sorted_order = np.argsort(filtered_scores)[::-1]
    sorted_indices = filtered_indices[sorted_order]

    df_results = model.df_embed.iloc[sorted_indices].copy()
    df_results['score'] = probs[sorted_indices]
    df_results = df_results.rename(columns={'product_id': 'id', 'centre_lat': 'lat', 'centre_lon': 'lon'})
    return df_results[['id', 'lat', 'lon', 'score']].to_dict('records')


# ---------------------------------------------------------------------------
# Output helpers shared by the search generators. Every search yields a
# 7-tuple matching the outputs wired below:
# (plot_map_interactive, gallery_images, status_output, results_plot,
#  current_fig, map_data_state, plot_map)
# ---------------------------------------------------------------------------

def _status_only(message):
    """Outputs that set only the status text and clear everything else."""
    return None, None, message, None, None, None, None


def _progress_text(steps, current):
    """Progress message in which ``steps[:current]`` are ticked and ``steps[current]`` is running."""
    lines = [f"{step}... ✓" for step in steps[:current]]
    lines.append(f"{steps[current]}...")
    return "\n".join(lines)


def _map_progress(geo_dist_map, df_filtered, message):
    """Progress outputs once the distribution map exists (keeps it visible)."""
    return (
        gr.update(visible=False), None, message, None, None,
        df_filtered, gr.update(value=geo_dist_map, visible=True),
    )


def _format_timings(timings):
    """One-line timing summary followed by a blank line."""
    return (
        f"Encoding {timings['Encoding']:.1f}s, Retrieval {timings['Retrieval']:.1f}s, "
        f"Download {timings['Download']:.1f}s, Visualization {timings['Visualization']:.1f}s\n\n"
    )


def _final_outputs(*, timings, match_count, threshold, results, gallery_items,
                   fig_results, geo_dist_map, df_filtered, results_txt):
    """Build the final 7-tuple; ``current_fig`` gets [map, results figure, results text]."""
    status_msg = _format_timings(timings) + generate_status_msg(match_count, threshold / 100.0, results)
    return (
        gr.update(visible=False),
        gallery_items,
        status_msg,
        fig_results,
        [geo_dist_map, fig_results, results_txt],
        df_filtered,
        gr.update(value=geo_dist_map, visible=True),
    )


def _nearest_product_id(df, lat, lon):
    """Return ``product_id`` of the row in ``df`` closest to (lat, lon) in degree space.

    Non-numeric coordinates are coerced to NaN, which ``idxmin`` skips.
    """
    lats = pd.to_numeric(df['centre_lat'], errors='coerce')
    lons = pd.to_numeric(df['centre_lon'], errors='coerce')
    dists = (lats - lat) ** 2 + (lons - lon) ** 2
    return df.loc[dists.idxmin(), 'product_id']


def search_text(query, threshold, model_name):
    """Generator: text-to-image search, yielding progress then final UI outputs.

    ``threshold`` is the slider value in per mille (‰).
    """
    model, error = get_active_model(model_name)
    if error:
        yield _status_only(error)
        return
    if not query:
        yield _status_only("Please enter a query.")
        return

    steps = ["Encoding text", "Retrieving similar images", "Downloading images", "Generating visualizations"]
    try:
        timings = {}

        # 1. Encode Text
        yield _status_only(_progress_text(steps, 0))
        t0 = time.time()
        text_features = model.encode_text(query)
        timings['Encoding'] = time.time() - t0
        if text_features is None:
            yield _status_only("Model does not support text encoding or is not initialized.")
            return

        # 2. Search
        yield _status_only(_progress_text(steps, 1))
        t0 = time.time()
        top_fraction = threshold / PER_MILLE
        probs, filtered_indices, top_indices = model.search(text_features, top_percent=top_fraction)
        timings['Retrieval'] = time.time() - t0
        if probs is None:
            yield _status_only("Search failed (embeddings missing?).")
            return

        # Show geographic distribution (not timed)
        df_embed = model.df_embed
        geo_dist_map, df_filtered = plot_geographic_distribution(
            df_embed, probs, top_fraction, title=f'Similarity to "{query}" ({model_name})'
        )

        # 3. Download Images
        yield _map_progress(geo_dist_map, df_filtered, _progress_text(steps, 2))
        t0 = time.time()
        results = fetch_top_k_images(top_indices[:10], probs, df_embed, query_text=query)
        timings['Download'] = time.time() - t0

        # 4. Visualize - keep geo_dist_map visible
        yield _map_progress(geo_dist_map, df_filtered, _progress_text(steps, 3))
        t0 = time.time()
        fig_results = plot_top5_overview(None, results, query_info=query)
        gallery_items = format_results_for_gallery(results)
        timings['Visualization'] = time.time() - t0

        # 5. Generate Final Status
        all_results = get_all_results_metadata(model, filtered_indices, probs)
        yield _final_outputs(
            timings=timings, match_count=len(filtered_indices), threshold=threshold,
            results=results, gallery_items=gallery_items, fig_results=fig_results,
            geo_dist_map=geo_dist_map, df_filtered=df_filtered,
            results_txt=format_results_to_text(all_results),
        )
    except Exception as e:
        traceback.print_exc()
        yield _status_only(f"Error: {str(e)}")


def search_image(image_input, threshold, model_name):
    """Generator: image-to-image search, yielding progress then final UI outputs.

    ``threshold`` is the slider value in per mille (‰). The downloadable
    results text is capped at the top 50 matches.
    """
    model, error = get_active_model(model_name)
    if error:
        yield _status_only(error)
        return
    if image_input is None:
        yield _status_only("Please upload an image.")
        return

    steps = ["Encoding image", "Retrieving similar images", "Downloading images", "Generating visualizations"]
    try:
        timings = {}

        # 1. Encode Image
        yield _status_only(_progress_text(steps, 0))
        t0 = time.time()
        image_features = model.encode_image(image_input)
        timings['Encoding'] = time.time() - t0
        if image_features is None:
            yield _status_only("Model does not support image encoding.")
            return

        # 2. Search
        yield _status_only(_progress_text(steps, 1))
        t0 = time.time()
        top_fraction = threshold / PER_MILLE
        probs, filtered_indices, top_indices = model.search(image_features, top_percent=top_fraction)
        timings['Retrieval'] = time.time() - t0
        if probs is None:
            yield _status_only("Search failed (embeddings missing?).")
            return

        # Show geographic distribution (not timed)
        df_embed = model.df_embed
        geo_dist_map, df_filtered = plot_geographic_distribution(
            df_embed, probs, top_fraction, title=f'Similarity to Input Image ({model_name})'
        )

        # 3. Download Images
        yield _map_progress(geo_dist_map, df_filtered, _progress_text(steps, 2))
        t0 = time.time()
        results = fetch_top_k_images(top_indices[:6], probs, df_embed, query_text="Image Query")
        timings['Download'] = time.time() - t0

        # 4. Visualize - keep geo_dist_map visible
        yield _map_progress(geo_dist_map, df_filtered, _progress_text(steps, 3))
        t0 = time.time()
        fig_results = plot_top5_overview(image_input, results, query_info="Image Query")
        gallery_items = format_results_for_gallery(results)
        timings['Visualization'] = time.time() - t0

        # 5. Generate Final Status
        all_results = get_all_results_metadata(model, filtered_indices, probs)
        yield _final_outputs(
            timings=timings, match_count=len(filtered_indices), threshold=threshold,
            results=results, gallery_items=gallery_items, fig_results=fig_results,
            geo_dist_map=geo_dist_map, df_filtered=df_filtered,
            results_txt=format_results_to_text(all_results[:50]),
        )
    except Exception as e:
        traceback.print_exc()
        yield _status_only(f"Error: {str(e)}")


def search_location(lat, lon, threshold):
    """Generator: SatCLIP location search, yielding progress then final UI outputs.

    ``threshold`` is the slider value in per mille (‰), applied consistently
    to both retrieval and the distribution map.
    """
    model_name = "SatCLIP"
    model, error = get_active_model(model_name)
    if error:
        yield _status_only(error)
        return
    if lat is None or lon is None:
        yield _status_only("Please specify coordinates first.")
        return

    steps = [
        "Encoding location", "Retrieving similar images", "Generating distribution map",
        "Downloading images", "Generating visualizations",
    ]
    try:
        timings = {}
        lat_f, lon_f = float(lat), float(lon)

        # 1. Encode Location
        yield _status_only(_progress_text(steps, 0))
        t0 = time.time()
        loc_features = model.encode_location(lat_f, lon_f)
        timings['Encoding'] = time.time() - t0
        if loc_features is None:
            yield _status_only("Location encoding failed.")
            return

        # 2. Search (per mille, like the other search modes and the map below)
        yield _status_only(_progress_text(steps, 1))
        t0 = time.time()
        top_fraction = threshold / PER_MILLE
        probs, filtered_indices, top_indices = model.search(loc_features, top_percent=top_fraction)
        timings['Retrieval'] = time.time() - t0
        if probs is None:
            yield _status_only("Search failed (embeddings missing?).")
            return

        # 3. Generate Distribution Map (not timed)
        yield _status_only(_progress_text(steps, 2))
        df_embed = model.df_embed
        geo_dist_map, df_filtered = plot_geographic_distribution(
            df_embed, probs, top_fraction, title=f'Similarity to Location ({lat}, {lon})'
        )

        # 4. Download Images (includes the query tile)
        yield _map_progress(geo_dist_map, df_filtered, _progress_text(steps, 3))
        t0 = time.time()
        results = fetch_top_k_images(top_indices[:6], probs, df_embed, query_text=f"Loc: {lat},{lon}")

        # The query "image" is the dataset tile nearest to the clicked point.
        query_tile = None
        try:
            pid = _nearest_product_id(df_embed, lat_f, lon_f)
            query_tile, _ = download_and_process_image(pid, df_source=df_embed, verbose=False)
        except Exception as e:
            print(f"Error fetching nearest MajorTOM image: {e}")
        if query_tile is None:
            query_tile = get_placeholder_image(f"Query Location\n({lat}, {lon})")
        timings['Download'] = time.time() - t0

        # 5. Visualize - keep geo_dist_map visible
        yield _map_progress(geo_dist_map, df_filtered, _progress_text(steps, 4))
        t0 = time.time()
        fig_results = plot_top5_overview(query_tile, results, query_info=f"Loc: {lat},{lon}")
        gallery_items = format_results_for_gallery(results)
        timings['Visualization'] = time.time() - t0

        # 6. Generate Final Status
        all_results = get_all_results_metadata(model, filtered_indices, probs)
        yield _final_outputs(
            timings=timings, match_count=len(filtered_indices), threshold=threshold,
            results=results, gallery_items=gallery_items, fig_results=fig_results,
            geo_dist_map=geo_dist_map, df_filtered=df_filtered,
            results_txt=format_results_to_text(all_results),
        )
    except Exception as e:
        traceback.print_exc()
        yield _status_only(f"Error: {str(e)}")


def generate_status_msg(count, threshold, results):
    """Summarize match count and the top-3 results.

    ``threshold`` is the slider value divided by 100, so ``threshold * 100``
    is the per-mille value shown to the user.
    """
    status_msg = f"Found {count} matches in top {threshold*100:.0f}‰.\n\nTop {len(results)} similar images:\n"
    for i, res in enumerate(results[:3]):
        status_msg += (
            f"{i+1}. Product ID: {res['id']}, Location: ({res['lat']:.4f}, {res['lon']:.4f}), "
            f"Score: {res['score']:.4f}\n"
        )
    return status_msg


def _global_map():
    """Render the static global map from the first loaded model with embeddings.

    Preference order is DINOv2, SigLIP, FarSLIP, SatCLIP. Returns
    ``(None, None)`` when no loaded model exposes a ``df_embed``.
    """
    for name in ("DINOv2", "SigLIP", "FarSLIP", "SatCLIP"):
        model = models.get(name)
        if model is not None and getattr(model, 'df_embed', None) is not None:
            return plot_global_map_static(model.df_embed)
    return None, None


def get_initial_plot():
    """Initial page-load outputs: global map, figure state, map data, hidden interactive plot."""
    img, df_vis = _global_map()
    fig_state = [img] if img is not None else None
    return gr.update(value=img, visible=True), fig_state, df_vis, gr.update(visible=False)


def handle_map_click(evt: gr.SelectData, df_vis):
    """Convert a click on the static world map into ``(lat, lon, product_id, status)``.

    NOTE(review): the pixel geometry below (3000x1500 image, fixed margins,
    2:1 map area with a linear lon/lat mapping) is hard-coded and must match
    what visualize.plot_global_map_static renders. If ``df_vis`` is given, the
    point snaps to the nearest tile when its squared degree distance is < 25.
    """
    if evt is None:
        return None, None, None, "No point selected."

    try:
        x, y = evt.index[0], evt.index[1]

        # Image dimensions (New)
        img_width = 3000
        img_height = 1500

        # Scaled Margins (Proportional to 4000x2000)
        left_margin = 110 * 0.75
        right_margin = 110 * 0.75
        top_margin = 100 * 0.75
        bottom_margin = 67 * 0.75

        plot_width = img_width - left_margin - right_margin
        plot_height = img_height - top_margin - bottom_margin

        # The map keeps a 360:180 aspect ratio inside the plot area, so it is
        # letterboxed either horizontally or vertically.
        map_aspect = 360.0 / 180.0
        plot_aspect = plot_width / plot_height
        if plot_aspect > map_aspect:
            actual_map_width = plot_height * map_aspect
            actual_map_height = plot_height
            h_offset = (plot_width - actual_map_width) / 2
            v_offset = 0
        else:
            actual_map_width = plot_width
            actual_map_height = plot_width / map_aspect
            h_offset = 0
            v_offset = (plot_height - actual_map_height) / 2

        x_in_plot = x - left_margin
        y_in_plot = y - top_margin

        if (x_in_plot < h_offset or x_in_plot > h_offset + actual_map_width or
                y_in_plot < v_offset or y_in_plot > v_offset + actual_map_height):
            return None, None, None, "Click outside map area. Please click on the map."

        x_rel = max(0, min(1, (x_in_plot - h_offset) / actual_map_width))
        y_rel = max(0, min(1, (y_in_plot - v_offset) / actual_map_height))

        lon = x_rel * 360 - 180
        lat = 90 - y_rel * 180

        # Snap to the nearest displayed tile if one is close enough
        pid = ""
        if df_vis is not None:
            dists = (df_vis['centre_lat'] - lat) ** 2 + (df_vis['centre_lon'] - lon) ** 2
            min_idx = dists.idxmin()
            nearest_row = df_vis.loc[min_idx]
            if dists[min_idx] < 25:
                lat = nearest_row['centre_lat']
                lon = nearest_row['centre_lon']
                pid = nearest_row['product_id']

    except Exception as e:
        print(f"Error handling click: {e}")
        traceback.print_exc()
        return None, None, None, f"Error: {e}"

    return lat, lon, pid, f"Selected Point: ({lat:.4f}, {lon:.4f})"


def download_image_by_location(lat, lon, pid, model_name):
    """Download and return the image at the specified location.

    When ``pid`` is empty, the tile nearest to (lat, lon) in the selected
    model's ``df_embed`` is used. Returns ``(image_or_None, status_message)``.
    """
    if lat is None or lon is None:
        return None, "Please specify coordinates first."

    model, error = get_active_model(model_name)
    if error:
        return None, error

    try:
        lat = float(lat)
        lon = float(lon)

        if not pid:
            pid = _nearest_product_id(model.df_embed, lat, lon)

        img_384, _ = download_and_process_image(pid, df_source=model.df_embed, verbose=True)
        if img_384 is None:
            return None, f"Failed to download image for location ({lat:.4f}, {lon:.4f})"
        return img_384, f"Downloaded image at ({lat:.4f}, {lon:.4f})"

    except Exception as e:
        traceback.print_exc()
        return None, f"Error: {str(e)}"


def reset_to_global_map():
    """Reset the map to the initial global distribution view."""
    img, df_vis = _global_map()
    fig_state = [img] if img is not None else None
    return gr.update(value=img, visible=True), fig_state, df_vis


def format_results_to_text(results):
    """Render result dicts (id, lat, lon, score) as a ranked plain-text report."""
    if not results:
        return "No results found."

    parts = [f"Top {len(results)} Retrieval Results\n", "=" * 30 + "\n\n"]
    for rank, res in enumerate(results, start=1):
        parts.append(
            f"Rank: {rank}\n"
            f"Product ID: {res['id']}\n"
            f"Location: Latitude {res['lat']:.6f}, Longitude {res['lon']:.6f}\n"
            f"Similarity Score: {res['score']:.6f}\n"
            + "-" * 30 + "\n"
        )
    return "".join(parts)


def _save_png_to_temp(img):
    """Save a PIL image to a fresh, uniquely named temp PNG and return its path."""
    fd, path = tempfile.mkstemp(suffix='.png', prefix='earth_explorer_map_')
    os.close(fd)
    img.save(path)
    return path


def _png_bytes(img):
    """Encode a PIL image as PNG bytes in memory."""
    buf = io.BytesIO()
    img.save(buf, format='PNG')
    return buf.getvalue()


def save_plot(figs):
    """Write the current result state to a temp file for download; return its path.

    - a single PIL image (or a one-element list of one) -> PNG
    - a list/tuple ``[map_img, results_img, results_text]`` -> ZIP
    - anything else is assumed to be a Plotly figure -> HTML
    Returns ``None`` on empty input or any error.

    ZIP members are written from memory, so concurrent sessions never share
    intermediate files.
    """
    if figs is None:
        return None

    try:
        if isinstance(figs, PILImage.Image):
            return _save_png_to_temp(figs)

        if isinstance(figs, (list, tuple)):
            if not figs:
                return None
            if len(figs) == 1 and isinstance(figs[0], PILImage.Image):
                return _save_png_to_temp(figs[0])

            fd, zip_path = tempfile.mkstemp(suffix='.zip', prefix='earth_explorer_results_')
            os.close(fd)
            with zipfile.ZipFile(zip_path, 'w') as zipf:
                if figs[0] is not None:
                    zipf.writestr('map_distribution.png', _png_bytes(figs[0]))
                if len(figs) > 1 and figs[1] is not None:
                    zipf.writestr('retrieval_results.png', _png_bytes(figs[1]))
                if len(figs) > 2 and figs[2] is not None:
                    # writestr encodes str members as UTF-8.
                    zipf.writestr('results.txt', figs[2])
            return zip_path

        # Fallback for Plotly figure (if any)
        fd, path = tempfile.mkstemp(suffix='.html', prefix='earth_explorer_plot_')
        os.close(fd)
        figs.write_html(path)
        return path

    except Exception as e:
        print(f"Error saving: {e}")
        return None


# Gradio Blocks Interface
with gr.Blocks(title="EarthEmbeddingExplorer") as demo:
    gr.Markdown("# EarthEmbeddingExplorer")
    gr.HTML("""
EarthEmbeddingExplorer is a tool that allows you to search for satellite images of the Earth using natural language descriptions, images, geolocations, or a simple click on the map. For example, you can type "tropical rainforest" or "coastline with a city," and the system will find locations on Earth that match your description. It then visualizes these locations on a world map and displays the top matching images.
""")

    with gr.Row():
        with gr.Column(scale=4):
            with gr.Tabs():
                with gr.TabItem("Text Search") as tab_text:
                    model_selector_text = gr.Dropdown(
                        choices=["SigLIP", "FarSLIP"], value="FarSLIP", label="Model"
                    )
                    query_input = gr.Textbox(label="Query", placeholder="e.g., rainforest, glacier")
                    gr.Examples(
                        examples=[
                            ["a satellite image of a river around a city"],
                            ["a satellite image of a rainforest"],
                            ["a satellite image of a slum"],
                            ["a satellite image of a glacier"],
                            ["a satellite image of snow covered mountains"],
                        ],
                        inputs=[query_input],
                        label="Text Examples",
                    )
                    search_btn = gr.Button("Search by Text", variant="primary")

                with gr.TabItem("Image Search") as tab_image:
                    model_selector_img = gr.Dropdown(
                        choices=["SigLIP", "FarSLIP", "SatCLIP", "DINOv2"], value="FarSLIP", label="Model"
                    )

                    gr.Markdown("### Option 1: Upload or Select Image")
                    image_input = gr.Image(type="pil", label="Upload Image")
                    gr.Examples(
                        examples=[
                            ["./examples/example1.png"],
                            ["./examples/example2.png"],
                            ["./examples/example3.png"],
                        ],
                        inputs=[image_input],
                        label="Image Examples",
                    )

                    gr.Markdown("### Option 2: Click Map or Enter Coordinates")
                    btn_reset_map_img = gr.Button(
                        "🔄 Reset Map to Global View", variant="secondary", size="sm"
                    )
                    with gr.Row():
                        img_lat = gr.Number(label="Latitude", interactive=True)
                        img_lon = gr.Number(label="Longitude", interactive=True)
                    img_pid = gr.Textbox(label="Product ID (auto-filled)", visible=False)
                    img_click_status = gr.Markdown("")
                    btn_download_img = gr.Button("Download Image by Geolocation", variant="secondary")

                    search_img_btn = gr.Button("Search by Image", variant="primary")

                with gr.TabItem("Location Search") as tab_location:
                    gr.Markdown("Search using **SatCLIP** location encoder.")

                    gr.Markdown("### Click Map or Enter Coordinates")
                    btn_reset_map_loc = gr.Button(
                        "🔄 Reset Map to Global View", variant="secondary", size="sm"
                    )
                    with gr.Row():
                        lat_input = gr.Number(label="Latitude", value=30.0, interactive=True)
                        lon_input = gr.Number(label="Longitude", value=120.0, interactive=True)
                    loc_pid = gr.Textbox(label="Product ID (auto-filled)", visible=False)
                    loc_click_status = gr.Markdown("")

                    gr.Examples(
                        examples=[
                            [30.32, 120.15],
                            [40.7128, -74.0060],
                            [24.65, 46.71],
                            [-3.4653, -62.2159],
                            [64.4, 16.8],
                        ],
                        inputs=[lat_input, lon_input],
                        label="Location Examples",
                    )
                    search_loc_btn = gr.Button("Search by Location", variant="primary")

            threshold_slider = gr.Slider(minimum=1, maximum=30, value=7, step=1, label="Top Percentage (‰)")
            status_output = gr.Textbox(label="Status", lines=10)
            save_btn = gr.Button("Download Result")
            download_file = gr.File(label="Zipped Results", height=40)

        with gr.Column(scale=6):
            plot_map = gr.Image(
                label="Geographical Distribution",
                type="pil",
                interactive=False,
                height=400,
                width=800,
                visible=True,
            )
            plot_map_interactive = gr.Plot(
                label="Geographical Distribution (Interactive)",
                visible=False,
            )
            results_plot = gr.Image(label="Top 5 Matched Images", type="pil")
            gallery_images = gr.Gallery(label="Top Retrieved Images (Zoom)", columns=3, height="auto")

    current_fig = gr.State()
    map_data_state = gr.State()

    # Initial Load
    demo.load(fn=get_initial_plot, outputs=[plot_map, current_fig, map_data_state, plot_map_interactive])

    # Reset Map Buttons
    btn_reset_map_img.click(fn=reset_to_global_map, outputs=[plot_map, current_fig, map_data_state])
    btn_reset_map_loc.click(fn=reset_to_global_map, outputs=[plot_map, current_fig, map_data_state])

    # Map clicks update both the Image Search and Location Search coordinates.
    plot_map.select(
        fn=handle_map_click,
        inputs=[map_data_state],
        outputs=[img_lat, img_lon, img_pid, img_click_status],
    )
    plot_map.select(
        fn=handle_map_click,
        inputs=[map_data_state],
        outputs=[lat_input, lon_input, loc_pid, loc_click_status],
    )

    # Download Image by Geolocation
    btn_download_img.click(
        fn=download_image_by_location,
        inputs=[img_lat, img_lon, img_pid, model_selector_img],
        outputs=[image_input, img_click_status],
    )

    # All three searches share the same output layout.
    _search_outputs = [
        plot_map_interactive, gallery_images, status_output, results_plot,
        current_fig, map_data_state, plot_map,
    ]

    # Search Event (Text)
    search_btn.click(
        fn=search_text,
        inputs=[query_input, threshold_slider, model_selector_text],
        outputs=_search_outputs,
    )

    # Search Event (Image)
    search_img_btn.click(
        fn=search_image,
        inputs=[image_input, threshold_slider, model_selector_img],
        outputs=_search_outputs,
    )

    # Search Event (Location)
    search_loc_btn.click(
        fn=search_location,
        inputs=[lat_input, lon_input, threshold_slider],
        outputs=_search_outputs,
    )

    # Save Event
    save_btn.click(fn=save_plot, inputs=[current_fig], outputs=[download_file])

    # Tab Selection Events
    def show_static_map():
        """Show the static map image and hide the interactive plot."""
        return gr.update(visible=True), gr.update(visible=False)

    tab_text.select(fn=show_static_map, outputs=[plot_map, plot_map_interactive])
    tab_image.select(fn=show_static_map, outputs=[plot_map, plot_map_interactive])
    tab_location.select(fn=show_static_map, outputs=[plot_map, plot_map_interactive])


if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)