Spaces:
Runtime error
Runtime error
| import plotly.express as px | |
| import plotly.graph_objects as go | |
| import pandas as pd | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from matplotlib.figure import Figure | |
| from matplotlib.backends.backend_agg import FigureCanvasAgg | |
| from io import BytesIO | |
| from PIL import Image | |
| import json | |
| import os | |
| import cartopy.crs as ccrs | |
| import cartopy.feature as cfeature | |
| from matplotlib.figure import Figure | |
| from matplotlib.backends.backend_agg import FigureCanvasAgg | |
def get_background_map_trace():
    """Return a uniform-color Plotly choropleth of world countries, or None.

    The GeoJSON is read from ``countries.geo.json`` located in the directory
    of this module (not the current working directory). Every feature that
    has an ``id`` key becomes one location. All locations are painted the same
    light-gray land color with thin gray borders, no color scale and hover
    skipped, so the trace can sit underneath other traces as a backdrop.

    Returns ``None`` (after printing a diagnostic) when the file is missing,
    when no feature carries an ``id``, or when reading, parsing or building
    the trace raises.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    path = os.path.join(here, 'countries.geo.json')

    if not os.path.exists(path):
        print(f"Warning: Local GeoJSON not found at {path}")
        # Extra diagnostics that help when debugging a remote deployment.
        print(f"Current Working Directory: {os.getcwd()}")
        try:
            print(f"Files in {here}: {os.listdir(here)}")
        except Exception as err:
            print(f"Error listing files: {err}")
        return None

    try:
        with open(path, 'r', encoding='utf-8') as handle:
            geojson = json.load(handle)
        country_ids = [feat['id'] for feat in geojson['features'] if 'id' in feat]
        print(f"DEBUG: Loaded {len(country_ids)} countries from {path}")
        if not country_ids:
            print("DEBUG: No IDs found in GeoJSON features")
            return None

        # A constant z combined with a single-color scale paints every
        # country identically.
        land_rgb = 'rgb(243, 243, 243)'
        return go.Choropleth(
            geojson=geojson,
            locations=country_ids,
            z=[1] * len(country_ids),
            colorscale=[[0, land_rgb], [1, land_rgb]],
            showscale=False,
            marker_line_color='rgb(204, 204, 204)',  # border color
            marker_line_width=0.5,
            hoverinfo='skip',
            name='Background',
        )
    except Exception as err:
        print(f"Error loading GeoJSON: {err}")
        return None
def plot_global_map_static(df, lat_col='centre_lat', lon_col='centre_lon',
                           max_points=250000, sample_step=2):
    """Render every sample location as a blue dot on a static world map.

    Coordinates are coerced to numbers; rows where either coordinate cannot
    be parsed are dropped. When more than ``max_points`` valid rows remain,
    only every ``sample_step``-th row is plotted to keep rendering tractable.

    Args:
        df: DataFrame containing the coordinate columns, or ``None``.
        lat_col: Name of the latitude column.
        lon_col: Name of the longitude column.
        max_points: Row count above which the data is down-sampled.
        sample_step: Stride used when down-sampling (values below 1 are
            treated as 1).

    Returns:
        ``(image, plotted_df)``. ``image`` is a PIL image of the PNG render
        and ``plotted_df`` holds exactly the rows that were drawn. Returns
        ``(None, None)`` when ``df`` is ``None``.
    """
    if df is None:
        return None, None

    # Coerce coordinates to numeric and drop rows that could not be parsed.
    df_clean = df.copy()
    df_clean[lat_col] = pd.to_numeric(df_clean[lat_col], errors='coerce')
    df_clean[lon_col] = pd.to_numeric(df_clean[lon_col], errors='coerce')
    df_clean = df_clean.dropna(subset=[lat_col, lon_col])

    if len(df_clean) > max_points:
        step = max(1, int(sample_step))
        df_vis = df_clean.iloc[::step]  # every step-th row
        print(f"Sampled {len(df_vis)} points from {len(df_clean)} total points (step={step}) for visualization.")
    else:
        df_vis = df_clean

    # 10 x 5 inches at 300 dpi -> 3000 x 1500 px. The 2:1 aspect ratio matches
    # the 360 x 180 degree extent of the PlateCarree map.
    fig = Figure(figsize=(10, 5), dpi=300)
    FigureCanvasAgg(fig)  # attach an Agg canvas explicitly; no pyplot state
    data_crs = ccrs.PlateCarree()
    ax = fig.add_subplot(111, projection=data_crs)

    # Light land fill plus faint coastlines as a backdrop.
    ax.add_feature(cfeature.LAND, facecolor='lightgray', alpha=0.2)
    ax.add_feature(cfeature.COASTLINE, linewidth=0.5, alpha=0.5)

    ax.scatter(
        df_vis[lon_col],
        df_vis[lat_col],
        s=0.2,
        c="blue",
        marker='o',
        edgecolors='none',
        transform=data_crs,
        label='Samples',
    )

    # Full-world extent, no axes decorations.
    ax.set_extent([-180, 180, -90, 90], crs=data_crs)
    ax.axis('off')

    # markerscale enlarges the tiny scatter markers so the legend entry is visible.
    ax.legend(loc='lower left', markerscale=5, frameon=True, facecolor='white', framealpha=0.9)
    fig.tight_layout()

    buf = BytesIO()
    fig.savefig(buf, format='png', facecolor='white')
    buf.seek(0)
    return Image.open(buf), df_vis
def plot_geographic_distribution(df, scores, threshold, lat_col='centre_lat', lon_col='centre_lon', title="Search Results"):
    """Map the highest-scoring fraction of samples, colored by score.

    Args:
        df: DataFrame with the coordinate columns, or ``None``.
        scores: Per-row scores, stored in a new ``score`` column of a copy
            of ``df``.
        threshold: Fraction of rows to keep, best scores first. At least one
            row is always kept. The legend label shows it in per-mille.
        lat_col: Name of the latitude column.
        lon_col: Name of the longitude column.
        title: Currently unused; drawing a title is disabled because it may
            overlap the map.

    Returns:
        ``(image, top_df)``: a PIL image of the PNG render and the plotted
        rows in descending score order. Returns ``(None, None)`` when ``df``
        or ``scores`` is ``None``.
    """
    if df is None or scores is None:
        return None, None

    ranked = df.copy()
    ranked['score'] = scores
    ranked = ranked.sort_values(by='score', ascending=False)
    n_keep = max(1, int(len(ranked) * threshold))
    top = ranked.head(n_keep)

    plate = ccrs.PlateCarree()
    fig = Figure(figsize=(10, 5), dpi=300)
    ax = fig.add_subplot(111, projection=plate)

    # Light land fill plus faint coastlines as a backdrop.
    ax.add_feature(cfeature.LAND, facecolor='lightgray', alpha=0.2)
    ax.add_feature(cfeature.COASTLINE, linewidth=0.5, alpha=0.5)

    # Matches colored by score on the 'Reds' colormap.
    points = ax.scatter(
        top[lon_col],
        top[lat_col],
        c=top['score'],
        cmap='Reds',
        s=0.35,
        alpha=0.8,
        transform=plate,
        label=f'Top {threshold * 1000:.0f}‰ Matches',
    )
    ax.set_extent([-180, 180, -90, 90], crs=plate)
    ax.axis('off')

    colorbar = fig.colorbar(points, ax=ax, fraction=0.025, pad=0.02)
    colorbar.set_label('Similarity Score')
    ax.legend(loc='lower left', markerscale=3, frameon=True, facecolor='white', framealpha=0.9)
    fig.tight_layout()

    buf = BytesIO()
    fig.savefig(buf, format='png', facecolor='white')
    buf.seek(0)
    return Image.open(buf), top
def format_results_for_gallery(results):
    """
    Build Gradio Gallery items from search results.

    results: iterable of dicts with 'score', 'lat', 'lon' and 'id' keys and
        an optional 'image_384' entry (the thumbnail shown in the gallery).
    Returns: list of (image, caption) tuples, in input order, skipping any
        result whose 'image_384' is missing or None.
    """
    def caption_for(entry):
        # Score to 4 decimals, coordinates to 2 decimals, then the sample id.
        return f"Score: {entry['score']:.4f}\nLat: {entry['lat']:.2f}, Lon: {entry['lon']:.2f}\nID: {entry['id']}"

    return [
        (entry['image_384'], caption_for(entry))
        for entry in results
        if entry.get('image_384') is not None
    ]
def _draw_result_tile(ax, img, make_title):
    """Draw *img* on *ax* titled ``make_title()``, or an "N/A" placeholder if *img* is None."""
    # ``is not None`` rather than truthiness: bool() on a numpy image array raises.
    if img is not None:
        ax.imshow(img)
        ax.set_title(make_title(), fontsize=9)
    else:
        ax.text(0.5, 0.5, "N/A", ha='center', va='center')
    ax.axis('off')


def _figure_to_pil(fig):
    """Lay out *fig*, render it to PNG in memory and return it as a PIL image."""
    fig.tight_layout()
    buf = BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight')
    buf.seek(0)
    return Image.open(buf)


def plot_top5_overview(query_image, results, query_info="Query"):
    """
    Generate a figure showing the query image and the top retrieved images.

    Similar to the visualization in SigLIP_embdding.ipynb. Uses the OO
    Matplotlib API (Figure + FigureCanvasAgg, no pyplot) for thread safety.

    Layouts:
      * Text search (``query_image is None``) with exactly 10 results: a
        2 x 5 grid of the 384x384 thumbnails.
      * Otherwise: two rows. Row 1 holds the query image (if any) followed by
        the 384x384 thumbnails. Row 2 holds an empty cell under the query,
        followed by the full-size originals.

    Args:
        query_image: Image to show as the query, or ``None`` for text search.
            Checked with ``is None``, so numpy arrays as well as PIL images
            are accepted.
        results: Sequence of dicts with 'score', 'lat' and 'lon' keys and
            optional 'image_384' / 'image_full' entries. Missing images
            render as "N/A".
        query_info: Text shown under the "Query" title.

    Returns:
        A PIL image of the rendered PNG, or ``None`` when ``results`` is empty.
    """
    top_k = len(results)
    if top_k == 0:
        return None

    # Text search with 10 results: two rows of five thumbnails.
    if query_image is None and top_k == 10:
        rows, cols = 2, 5
        fig = Figure(figsize=(4 * cols, 4 * rows))  # square-ish cell per image
        FigureCanvasAgg(fig)  # attach an Agg canvas; no global pyplot state
        for i, res in enumerate(results):
            ax = fig.add_subplot(rows, cols, i + 1)  # 1-based subplot index
            # The title lambda is called immediately, so late binding is safe.
            _draw_result_tile(
                ax,
                res.get('image_384'),
                lambda: f"Rank {i+1}\nScore: {res['score']:.4f}\n({res['lat']:.2f}, {res['lon']:.2f})",
            )
        return _figure_to_pil(fig)

    has_query = query_image is not None
    cols = top_k + (1 if has_query else 0)
    rows = 2
    fig = Figure(figsize=(4 * cols, 8))
    FigureCanvasAgg(fig)

    if has_query:
        # Row 1, col 1: the query image.
        ax = fig.add_subplot(rows, cols, 1)
        ax.imshow(query_image)
        ax.set_title(f"Query\n{query_info}", color='blue', fontweight='bold')
        ax.axis('off')
        # Row 2, col 1: intentionally left empty.
        fig.add_subplot(rows, cols, cols + 1).axis('off')
        start_col = 2
    else:
        start_col = 1

    for i, res in enumerate(results):
        # Row 1: 384x384 thumbnail.
        _draw_result_tile(
            fig.add_subplot(rows, cols, start_col + i),
            res.get('image_384'),
            lambda: f"Rank {i+1} (384)\nScore: {res['score']:.4f}\n({res['lat']:.2f}, {res['lon']:.2f})",
        )
        # Row 2: the full-size original, directly below its thumbnail.
        _draw_result_tile(
            fig.add_subplot(rows, cols, cols + start_col + i),
            res.get('image_full'),
            lambda: "Original",
        )
    return _figure_to_pil(fig)
def plot_location_distribution(df_all, query_lat, query_lon, results, query_info="Query"):
    """
    Plot the query coordinate and the retrieved result locations on a world map.

    Reference: improve2_satclip.ipynb

    Args:
        df_all: Full sample table. Here it is only checked for ``None``; no
            background layer of all samples is drawn.
        query_lat: Latitude of the query point (drawn as a red star).
        query_lon: Longitude of the query point.
        results: Sequence of dicts with 'lat' and 'lon' keys. Each is drawn as
            a blue cross joined to the query by a faint dashed line.
        query_info: Text appended to the title.

    Returns:
        A PIL image of the rendered PNG, or ``None`` when ``df_all`` is ``None``.
    """
    if df_all is None:
        return None

    fig = Figure(figsize=(8, 4), dpi=300)
    FigureCanvasAgg(fig)  # attach an Agg canvas; no global pyplot state
    ax = fig.add_subplot(111, projection=ccrs.PlateCarree())

    # Light land fill plus faint coastlines as a backdrop.
    ax.add_feature(cfeature.LAND, facecolor='lightgray', alpha=0.2)
    ax.add_feature(cfeature.COASTLINE, linewidth=0.5, alpha=0.5)

    # Query location, drawn above everything else.
    ax.scatter(query_lon, query_lat, c='red', s=150, marker='*', edgecolors='black', zorder=10, label='Input Coordinate')

    # Retrieved results. No transform is passed, so values are interpreted in
    # the axes' own PlateCarree (lon/lat) coordinates.
    n_results = len(results)
    res_lons = [r['lon'] for r in results]
    res_lats = [r['lat'] for r in results]
    ax.scatter(res_lons, res_lats, c='blue', s=50, marker='x', linewidths=2, label=f'Retrieved Top-{n_results}')

    # Faint dashed line from the query to each result.
    for lon, lat in zip(res_lons, res_lats):
        ax.plot([query_lon, lon], [query_lat, lat], 'b--', alpha=0.2)

    ax.legend(loc='upper right')
    # The count in the title follows the actual number of results, matching the legend.
    ax.set_title(f"Location of Top {n_results} Matched Images ({query_info})")
    ax.set_xlabel("Longitude")
    ax.set_ylabel("Latitude")
    # NOTE(review): cartopy GeoAxes usually draw graticules via ax.gridlines();
    # confirm that ax.grid() and the axis labels actually render here.
    ax.grid(True, alpha=0.2)

    buf = BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight')
    buf.seek(0)
    return Image.open(buf)