# Backend for the iNaturalist & AVONET explorer dashboard.
#
# At import time this module (1) fetches thumbnail archives when they are
# missing, (2) loads the world boundaries GeoJSON, (3) loads the per-mode
# ("birds" / "plants") observation tables and caches, and then registers the
# HTTP endpoints on ``app``.
import json
import os
import subprocess

import geopandas as gpd
import numpy as np
import pandas as pd
import uvicorn
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles

app = FastAPI(title="iNaturalist & AVONET Explorer API")
EXPLORER_MODE = os.environ.get("EXPLORER_MODE", "birds").lower()

# Startup extraction check for images.tar
IMAGE_DIR = "data/images"
BIRDS_CHECK = os.path.join(IMAGE_DIR, "train")
PLANTS_CHECK_DIR = IMAGE_DIR


def _train_dir_has(marker):
    """Return True when any entry name in ``IMAGE_DIR/train`` contains *marker*.

    Every entry is considered (the scan stops at the first hit).  Directory
    listing order is arbitrary, so looking at only a prefix of the listing
    could miss one group once both bird and plant folders are present.
    """
    train_dir = os.path.join(IMAGE_DIR, "train")
    if not os.path.isdir(train_dir):
        return False
    with os.scandir(train_dir) as entries:
        return any(marker in entry.name for entry in entries)


# Check if birds images exist (look for Aves directories)
def has_bird_images():
    """Return True when the train directory contains an entry named like ``*Aves*``."""
    return _train_dir_has("Aves")


def has_plant_images():
    """Return True when the train directory contains an entry named like ``*Plantae*``."""
    return _train_dir_has("Plantae")


def _fetch_thumbnail_archive(repo_id, filename, label, group, hf_token):
    """Download one tar archive from a Hugging Face dataset repo and extract it.

    Args:
        repo_id: Dataset repository id on the Hugging Face Hub.
        filename: Archive file name inside the repository.
        label: Human-readable archive name used in the log messages.
        group: Capitalised group name ("Birds" / "Plants") for the success log.
        hf_token: Access token (may be None).

    Download/extraction errors are logged and swallowed so that a missing
    archive does not prevent the server from starting.  The downloaded
    archive is removed afterwards whether or not extraction succeeded.
    """
    # Imported lazily: the dependency is only needed when thumbnails are missing.
    from huggingface_hub import hf_hub_download

    archive_path = None
    try:
        print(f"Downloading {label} from Hugging Face Dataset...")
        archive_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            repo_type="dataset",
            token=hf_token,
            local_dir=".",
        )
        print(f"Downloaded {filename} to {archive_path}. Extracting...")
        os.makedirs(IMAGE_DIR, exist_ok=True)
        # The archive is extracted relative to the current working directory.
        subprocess.run(["tar", "-xf", archive_path], check=True)
        print(f"{group} thumbnail extraction complete!")
    except Exception as e:
        print(f"Error fetching/extracting {label}: {e}")
    finally:
        # Do not leave the archive behind, even when extraction failed.
        if archive_path and os.path.exists(archive_path):
            try:
                os.remove(archive_path)
            except OSError as e:
                print(f"Error fetching/extracting {label}: {e}")


if not has_bird_images() or not has_plant_images():
    print("Some thumbnails are missing. Fetching from Hugging Face Dataset...")
    token = os.environ.get("HF_TOKEN")

    # 1. Download Birds Thumbnails (if missing)
    if not has_bird_images():
        _fetch_thumbnail_archive(
            "mayesh/inat-thumbnails", "images.tar", "birds images.tar", "Birds", token
        )

    # 2. Download Plants Thumbnails (if missing)
    if not has_plant_images():
        _fetch_thumbnail_archive(
            "mayesh/Side-Info_Generation",
            "plants_images.tar",
            "plants_images.tar",
            "Plants",
            token,
        )

# Load world boundaries on startup
world_geojson_data = None
geojson_path = 'data/countries.geojson'
if not os.path.exists(geojson_path):
    geojson_path = '/scratch/project/prj-02-visual-ai/mayesh/side-info_generation/inat_probing/data/countries.geojson'
try:
    # Explicit UTF-8: the platform default encoding is locale dependent.
    with open(geojson_path, 'r', encoding='utf-8') as f:
        world_geojson_data = json.load(f)
    print(f"Loaded world geojson boundaries from {geojson_path}")
except (OSError, ValueError) as e:
    print(f"Error loading countries.geojson: {e}")

# Enable CORS for development
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive; tighten allow_origins before exposing this beyond development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global data holders, all keyed by lower-case explorer mode
data_by_mode = {}
df_val_aligned_by_mode = {}
precomputed_similarity_by_mode = {}


def get_data(mode: str = "birds"):
    """Return the observation table loaded for *mode*, or None if absent."""
    return data_by_mode.get(str(mode).lower())


def get_val_aligned(mode: str = "birds"):
    """Return the validation-aligned table loaded for *mode*, or None if absent."""
    return df_val_aligned_by_mode.get(str(mode).lower())


def get_similarity(mode: str = "birds"):
    """Return the similarity cache loaded for *mode*, or an empty dict."""
    return precomputed_similarity_by_mode.get(str(mode).lower(), {})


def get_trait_maps(mode: str = "birds"):
    """Build 1-based integer codes for the categorical trait columns.

    Returns:
        ``(lifestyle_map, trophic_map)``: dicts mapping each distinct,
        non-null value of ``Primary.Lifestyle`` / ``Trophic.Level`` (sorted)
        to ``1..n``.  Both are empty when the dataset is not loaded.
    """
    df = get_data(mode)
    if df is None or df.empty:
        return {}, {}
    lifestyles = sorted(df['Primary.Lifestyle'].dropna().unique().tolist())
    trophics = sorted(df['Trophic.Level'].dropna().unique().tolist())
    lifestyle_map = {val: idx + 1 for idx, val in enumerate(lifestyles)}
    trophic_map = {val: idx + 1 for idx, val in enumerate(trophics)}
    return lifestyle_map, trophic_map


def sanitize_str(val):
    """Return ``str(val)``, or None when *val* is a missing value."""
    if pd.isna(val):
        return None
    return str(val)


def sanitize_float(val):
    """Return *val* as a finite Python float, or None.

    None is returned for missing values (None/NaN/NA/NaT), for infinities and
    for anything that cannot be converted to a float, so the result is always
    safe to serialise as JSON.
    """
    if val is None:
        return None
    try:
        number = float(val)
    except (TypeError, ValueError):
        # pd.NA, pd.NaT, non-numeric strings, ...
        return None
    return number if np.isfinite(number) else None


def load_and_preprocess_dataset(df_path):
    """Load one observation parquet file and prepare it for serving.

    Steps: optional species-stratified subsampling to 250k rows, derivation
    of ``year`` / ``month`` from ``date``, numeric column normalisation and a
    spatial join that adds ``country`` / ``continent`` columns.

    Args:
        df_path: Path of the parquet file.

    Returns:
        The prepared DataFrame, or None when *df_path* does not exist.
    """
    if not os.path.exists(df_path):
        print(f"Warning: master dataset parquet not found at {df_path}")
        return None

    print(f"Loading iNaturalist master dataset from {df_path}...")
    df_loaded = pd.read_parquet(df_path)
    print(f"Loaded all {len(df_loaded)} observations from {df_path}.")

    # Subsample to 250k observations for performance (hover responsiveness, memory)
    # Stratified by species to preserve taxonomic diversity
    MAX_OBSERVATIONS = 250_000
    if len(df_loaded) > MAX_OBSERVATIONS:
        try:
            keep_fraction = MAX_OBSERVATIONS / len(df_loaded)
            df_loaded = (
                df_loaded.groupby('name', group_keys=False)
                .apply(lambda g: g.sample(frac=keep_fraction, random_state=42))
            )
            # Per-group rounding can overshoot the cap slightly.
            if len(df_loaded) > MAX_OBSERVATIONS:
                df_loaded = df_loaded.sample(n=MAX_OBSERVATIONS, random_state=42)
        except Exception as e:
            print(f"Stratified sampling failed ({e}), falling back to random sample.")
            df_loaded = df_loaded.sample(n=MAX_OBSERVATIONS, random_state=42)
        df_loaded = df_loaded.reset_index(drop=True)
        print(f"Subsampled to {len(df_loaded)} observations for performance.")

    # Unparseable dates become NaT and then take the defaults below instead of
    # aborting the whole load.
    df_loaded['date_dt'] = pd.to_datetime(df_loaded['date'].astype(str), errors='coerce')
    unparsed = int(df_loaded['date_dt'].isna().sum())
    if unparsed:
        print(f"Warning: {unparsed} observations have no parseable date; using default year/month.")
    df_loaded['year'] = df_loaded['date_dt'].dt.year.fillna(2019).astype(int)
    df_loaded['month'] = df_loaded['date_dt'].dt.month.fillna(6).astype(int)

    # Standardize column types
    for column in ('latitude', 'longitude', 'elevation'):
        df_loaded[column] = df_loaded[column].fillna(0.0).astype(float)
    for column in ('temperature_2m', 'ndvi', 'Mass', 'Hand-Wing.Index'):
        df_loaded[column] = df_loaded[column].astype(float)

    # Perform spatial join for country and continent
    print("Mapping points geographically (Spatial Join)...")
    try:
        countries_gdf = gpd.read_file(geojson_path)[['name', 'continent', 'geometry']]
        countries_gdf = countries_gdf.rename(columns={'name': 'country_name'})
        points_gdf = gpd.GeoDataFrame(
            df_loaded[['longitude', 'latitude']],
            geometry=gpd.points_from_xy(df_loaded['longitude'], df_loaded['latitude']),
            crs='EPSG:4326'
        )
        joined = gpd.sjoin(points_gdf, countries_gdf, how='left', predicate='intersects')
        # A point touching several polygons yields several rows with the same
        # index; keep one so the result aligns 1:1 with df_loaded.
        joined = joined[~joined.index.duplicated(keep='first')]
        df_loaded['country'] = joined['country_name'].fillna('Ocean/Unknown').astype(str)
        df_loaded['continent'] = joined['continent'].fillna('Ocean/Unknown').astype(str)
        print("Geographic mapping complete!")
    except Exception as e:
        print(f"Warning: Geographic mapping failed: {e}")
        df_loaded['country'] = 'Unknown'
        df_loaded['continent'] = 'Unknown'

    return df_loaded


# Load datasets
data_by_mode['birds'] = load_and_preprocess_dataset('metadata/inat_world_model_master.parquet')
data_by_mode['plants'] = load_and_preprocess_dataset('metadata/plants_metadata_with_traits.parquet')

# Load aligned validation files
for mode in ['birds', 'plants']:
    aligned_filename = 'df_val_aligned_plants.json' if mode == 'plants' else 'df_val_aligned.json'
    aligned_path = os.path.join('metadata', aligned_filename)
    try:
        df_val_aligned_by_mode[mode] = pd.read_json(aligned_path, orient='records')
        print(f"Loaded {mode} validation alignment cache. Shape: {df_val_aligned_by_mode[mode].shape}")
    except Exception as e:
        print(f"Error loading {mode} validation alignment cache ({aligned_path}): {e}")
        df_val_aligned_by_mode[mode] = pd.DataFrame()

# Load similarity caches
for mode in ['birds', 'plants']:
    sim_filename = 'similarity_matches_plants.json' if mode == 'plants' else 'similarity_matches.json'
    sim_path = os.path.join('metadata', sim_filename)
    try:
        with open(sim_path, 'r', encoding='utf-8') as f:
            precomputed_similarity_by_mode[mode] = json.load(f)
        print(f"Loaded {mode} similarity matches cache successfully.")
    except Exception as e:
        print(f"Error loading {mode} similarity matches cache ({sim_path}): {e}")
        precomputed_similarity_by_mode[mode] = {}


# Response key -> dataframe column for the /api/filters payload.
_FILTER_COLUMNS = (
    ("orders", 'order'),
    ("families", 'family'),
    ("species", 'name'),
    ("common_names", 'common_name'),
    ("continents", 'continent'),
    ("countries", 'country'),
    ("lifestyles", 'Primary.Lifestyle'),
    ("trophic_levels", 'Trophic.Level'),
)


# Expose filter categories
@app.get("/api/filters")
def get_filters(mode: str = "birds"):
    """Return the sorted distinct values available for each filter control.

    Raises:
        HTTPException: 404 when no dataset is loaded for *mode*.
    """
    df = get_data(mode)
    if df is None or df.empty:
        raise HTTPException(status_code=404, detail=f"Dataset for mode '{mode}' not loaded")

    filters = {"mode": mode}
    for key, column in _FILTER_COLUMNS:
        filters[key] = sorted(df[column].dropna().unique().tolist())
    return filters


# Columns scanned by the free-text search of /api/query.
_QUERY_SEARCH_COLUMNS = ('name', 'common_name', 'genus', 'family', 'order', 'country', 'continent')

# Request key -> dataframe column for the inclusive [low, high] range filters.
_QUERY_RANGE_FILTERS = (
    ('temp', 'temperature_2m'),
    ('elev', 'elevation'),
    ('ndvi', 'ndvi'),
    ('mass', 'Mass'),
    ('hwi', 'Hand-Wing.Index'),
)


def _filter_by_text(frame, query):
    """Keep rows where any searchable column contains *query* (case-insensitive).

    The query is matched as a literal substring, never as a regular
    expression, so user input such as ``(`` or ``a.b`` cannot raise or
    match unexpectedly.
    """
    needle = str(query).strip().lower()
    if not needle:
        return frame
    masks = [
        frame[col].str.lower().str.contains(needle, na=False, regex=False)
        for col in _QUERY_SEARCH_COLUMNS
        if col in frame.columns
    ]
    if not masks:
        return frame
    combined = masks[0]
    for mask in masks[1:]:
        combined = combined | mask
    return frame[combined]


def _shared_point_fields(rec):
    """Return the trait/environment fields common to every point payload."""
    return {
        "name": sanitize_str(rec['name']),
        "common_name": sanitize_str(rec.get('common_name')),
        "lifestyle": sanitize_str(rec.get('Primary.Lifestyle')),
        "trophic": sanitize_str(rec.get('Trophic.Level')),
        "niche": sanitize_str(rec.get('Trophic.Niche')),
        "mass": sanitize_float(rec.get('Mass')),
        "hwi": sanitize_float(rec.get('Hand-Wing.Index')),
        "temp": sanitize_float(rec.get('temperature_2m')),
        "elev": sanitize_float(rec.get('elevation')),
        "ndvi": sanitize_float(rec.get('ndvi')),
    }


# Expose dynamic query endpoint
@app.post("/api/query")
def query_data(params: dict):
    """Filter observations and return plot points plus a thumbnail grid sample.

    Recognised keys of *params* (all optional): ``explorerMode``, ``mode``
    ("spatial" or "pca"), ``model``, ``order``, ``family``, ``species`` (str
    or list), ``q``, ``lifestyle``, ``trophic``, ``year``, ``year_mode``,
    ``temp`` / ``elev`` / ``ndvi`` / ``mass`` / ``hwi`` (``[low, high]``),
    ``bounds`` (``south``/``north``/``west``/``east``) and ``zoom``.

    Returns:
        ``{"points": [...], "grid_images": [...], "total_matches": int}``;
        only the two lists when the working table is not loaded.

    Raises:
        HTTPException: 400 in PCA mode when the table has no PCA columns for
            the requested model.
    """
    explorer_mode = params.get('explorerMode', 'birds')
    mode = params.get('mode', 'spatial')  # spatial or pca
    model_name = str(params.get('model') or 'DINOv3').lower()

    # Select working dataframe
    working_df = get_val_aligned(explorer_mode) if mode == 'pca' else get_data(explorer_mode)
    if working_df is None or working_df.empty:
        return {"points": [], "grid_images": []}

    # Every filter below builds a new frame, so the shared table is never
    # mutated and no defensive copy is needed.
    filtered = working_df

    # Apply Taxonomy filters
    if params.get('order'):
        filtered = filtered[filtered['order'] == params['order']]
    if params.get('family'):
        filtered = filtered[filtered['family'] == params['family']]
    species = params.get('species')
    if isinstance(species, list) and species:
        filtered = filtered[filtered['name'].isin(species)]
    elif isinstance(species, str) and species:
        filtered = filtered[filtered['name'] == species]

    # Apply global text search query
    if params.get('q'):
        filtered = _filter_by_text(filtered, params['q'])

    # Apply Categorical filters
    if params.get('lifestyle'):
        filtered = filtered[filtered['Primary.Lifestyle'] == params['lifestyle']]
    if params.get('trophic'):
        filtered = filtered[filtered['Trophic.Level'] == params['trophic']]

    # Apply Year filter (Cumulative vs Single Year)
    if params.get('year'):
        try:
            year_val = int(params['year'])
            if params.get('year_mode', 'cumulative') == 'cumulative':
                filtered = filtered[filtered['year'] <= year_val]
            else:
                filtered = filtered[filtered['year'] == year_val]
        except Exception as e:
            # Best effort: an unusable year filter is ignored, not fatal.
            print(f"Error filtering by year: {e}")

    # Apply Ranges (both ends inclusive)
    for key, column in _QUERY_RANGE_FILTERS:
        limits = params.get(key)
        if limits:
            filtered = filtered[filtered[column].between(limits[0], limits[1])]

    # Apply Spatial bounds if spatial mode and bounds are active
    if mode == 'spatial' and params.get('bounds'):
        b = params['bounds']
        min_lat, max_lat = b['south'], b['north']
        min_lon, max_lon = b['west'], b['east']
        filtered = filtered[filtered['latitude'].between(min_lat, max_lat)]
        if min_lon > max_lon:
            # Viewport crosses the antimeridian: keep both sides of it.
            filtered = filtered[(filtered['longitude'] >= min_lon) | (filtered['longitude'] <= max_lon)]
        else:
            filtered = filtered[filtered['longitude'].between(min_lon, max_lon)]

    total_matches = len(filtered)

    # The PCA branch below reads 'file_name_key'; fall back to it when the
    # working table has no 'file_name' column.
    # NOTE(review): aligned-table schema is not visible here - confirm.
    file_col = 'file_name' if 'file_name' in filtered.columns else 'file_name_key'

    # Generate random sample of 24 images for the right grid panel
    grid_sample = filtered.dropna(subset=[file_col])
    if not grid_sample.empty:
        grid_sample = grid_sample.sample(n=min(len(grid_sample), 24), random_state=42)
    grid_images = []
    for rec in grid_sample.to_dict('records'):
        image_id = rec.get('image_id')
        entry = {
            "image_id": '' if pd.isna(image_id) else str(image_id),
            "file_name": sanitize_str(rec[file_col]),
        }
        entry.update(_shared_point_fields(rec))
        entry["lat"] = sanitize_float(rec.get('latitude'))
        entry["lon"] = sanitize_float(rec.get('longitude'))
        grid_images.append(entry)

    points = []
    # Process Map/PCA point display outputs
    if mode == 'spatial':
        # Apply Zoom-dependent Bounding-Box smart sampling to ensure smooth Leaflet Canvas drawing
        zoom = params.get('zoom')
        if zoom is None:
            zoom = 2
        sample_limit = 5000 if zoom < 5 else 15000
        if len(filtered) > sample_limit:
            plot_sample = filtered.sample(n=sample_limit, random_state=42)
        else:
            plot_sample = filtered

        for rec in plot_sample.to_dict('records'):
            point = {
                "lat": sanitize_float(rec['latitude']),
                "lon": sanitize_float(rec['longitude']),
                "order": sanitize_str(rec.get('order')),
                "file_name": sanitize_str(rec[file_col]),
            }
            point.update(_shared_point_fields(rec))
            points.append(point)
    else:
        # PCA Mode: Return PCA coordinates for chosen model
        pca1_col = f'{model_name}_pca1'
        pca2_col = f'{model_name}_pca2'
        if pca1_col not in filtered.columns or pca2_col not in filtered.columns:
            raise HTTPException(
                status_code=400,
                detail=f"PCA coordinates not available for model: {model_name}",
            )

        # Limit to 3000 points in PCA space to prevent canvas overlap/lag
        if len(filtered) > 3000:
            plot_sample = filtered.sample(n=3000, random_state=42)
        else:
            plot_sample = filtered

        for rec in plot_sample.to_dict('records'):
            point = {
                "pca1": sanitize_float(rec[pca1_col]),
                "pca2": sanitize_float(rec[pca2_col]),
                "file_name": sanitize_str(rec['file_name_key']),
            }
            point.update(_shared_point_fields(rec))
            points.append(point)

    return {
        "points": points,
        "grid_images": grid_images,
        "total_matches": total_matches
    }


# Mount static files and images
STATIC_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'static')
os.makedirs(STATIC_DIR, exist_ok=True)
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")

# Thumbnails are served from the local data directory when present, otherwise
# from the shared project directory; no /images route exists if neither does.
IMAGE_DIR = "data/images"
if not os.path.exists(IMAGE_DIR):
    IMAGE_DIR = "/scratch/project/prj-02-visual-ai/mayesh/side-info_generation/inat_probing/data/images"

if os.path.exists(IMAGE_DIR):
    app.mount("/images", StaticFiles(directory=IMAGE_DIR), name="images")


@app.get("/api/world_geojson")
def get_world_geojson():
    """Return the world boundaries GeoJSON loaded at startup.

    Raises:
        HTTPException: 404 when the GeoJSON could not be loaded at startup.
    """
    if world_geojson_data is None:
        raise HTTPException(status_code=404, detail="World GeoJSON map boundaries not loaded on server startup")
    return world_geojson_data


def _pack_coordinates(frame, mode):
    """Serialise *frame* as one flat list with nine numbers per observation.

    Per-row layout: latitude, longitude, elevation, dataframe index, year,
    temperature, ndvi, lifestyle code, trophic code.  Missing values are
    replaced by defaults (0.0, year 2019, temperature 286.2, code 0), and
    temperature/ndvi values below -1000 are treated as missing.
    """
    lifestyle_map, trophic_map = get_trait_maps(mode)

    lat = frame['latitude'].fillna(0.0).values
    lon = frame['longitude'].fillna(0.0).values
    elev = frame['elevation'].fillna(0.0).values
    year = frame['year'].fillna(2019).values.astype(int)
    idx = frame.index.values.astype(int)

    temp = frame['temperature_2m'].fillna(286.2).values
    temp = np.where(temp < -1000, 286.2, temp)
    ndvi = frame['ndvi'].fillna(0.0).values
    ndvi = np.where(ndvi < -1000, 0.0, ndvi)

    # Categories absent from the maps are encoded as 0.
    lifestyle = frame['Primary.Lifestyle'].astype(object).map(lifestyle_map).fillna(0).values.astype(int)
    trophic = frame['Trophic.Level'].astype(object).map(trophic_map).fillna(0).values.astype(int)

    flat = np.column_stack((lat, lon, elev, idx, year, temp, ndvi, lifestyle, trophic)).flatten()
    return {"points": flat.tolist()}


@app.get("/api/all_coordinates")
def get_all_coordinates(mode: str = "birds"):
    """Return the packed coordinates of every observation of *mode*.

    Raises:
        HTTPException: 500 when no dataset is loaded for *mode*.
    """
    df = get_data(mode)
    if df is None or df.empty:
        raise HTTPException(status_code=500, detail=f"Dataset for mode '{mode}' not loaded")
    return _pack_coordinates(df, mode)


# Columns scanned by the free-text search of /api/query_coordinates.
_COORD_SEARCH_COLUMNS = ('name', 'common_name', 'genus', 'family', 'order', 'country', 'continent')


def _text_search_rows(frame, query):
    """Keep rows where any searchable column contains *query* (case-insensitive).

    The query is a literal substring, not a regular expression, so input
    containing regex metacharacters cannot raise.
    """
    needle = str(query).strip().lower()
    if not needle:
        return frame
    masks = [
        frame[col].str.lower().str.contains(needle, na=False, regex=False)
        for col in _COORD_SEARCH_COLUMNS
        if col in frame.columns
    ]
    if not masks:
        return frame
    combined = masks[0]
    for mask in masks[1:]:
        combined = combined | mask
    return frame[combined]


@app.post("/api/query_coordinates")
def query_coordinates(params: dict):
    """Return packed coordinates of the observations matching *params*.

    Recognised keys (all optional): ``explorerMode``, ``order``, ``family``,
    ``species`` (str or list), ``q``, ``continent``, ``country``,
    ``hemisphere`` ("Northern"/"Southern"), ``lifestyle``, ``trophic``.

    Raises:
        HTTPException: 500 when no dataset is loaded for the requested mode.
    """
    explorer_mode = params.get('explorerMode', 'birds')
    df = get_data(explorer_mode)
    if df is None or df.empty:
        raise HTTPException(status_code=500, detail=f"Dataset for mode '{explorer_mode}' not loaded")

    # Each filter builds a new frame; the shared table is never mutated, so
    # no defensive copy is required.
    filtered = df

    # 1. Taxonomy filters
    if params.get('order'):
        filtered = filtered[filtered['order'] == params['order']]
    if params.get('family'):
        filtered = filtered[filtered['family'] == params['family']]
    species = params.get('species')
    if isinstance(species, list) and species:
        filtered = filtered[filtered['name'].isin(species)]
    elif isinstance(species, str) and species:
        filtered = filtered[filtered['name'] == species]

    # Apply global text search query
    if params.get('q'):
        filtered = _text_search_rows(filtered, params['q'])

    # 2. Location filters
    if params.get('continent'):
        filtered = filtered[filtered['continent'] == params['continent']]
    if params.get('country'):
        filtered = filtered[filtered['country'] == params['country']]
    hemisphere = params.get('hemisphere')
    if hemisphere == 'Northern':
        filtered = filtered[filtered['latitude'] >= 0]
    elif hemisphere == 'Southern':
        filtered = filtered[filtered['latitude'] < 0]

    # 3. Categorical trait filters
    if params.get('lifestyle'):
        filtered = filtered[filtered['Primary.Lifestyle'] == params['lifestyle']]
    if params.get('trophic'):
        filtered = filtered[filtered['Trophic.Level'] == params['trophic']]

    return _pack_coordinates(filtered, explorer_mode)


# Land-cover class code -> display label.
landcover_map = {
    0: "Unknown",
    20: "Shrubs",
    30: "Herbaceous vegetation",
    40: "Cultivated / Agriculture",
    50: "Urban / Built-up",
    60: "Bare / Sparse vegetation",
    70: "Snow / Ice",
    80: "Permanent water",
    90: "Herbaceous wetland",
    100: "Moss / Lichen",
    111: "Closed Forest (Needleleaf)",
    112: "Closed Forest (Broadleaf)",
    113: "Closed Forest (Deciduous Needleleaf)",
    114: "Closed Forest (Deciduous Broadleaf)",
    115: "Closed Forest (Mixed)",
    116: "Closed Forest (Unknown)",
    121: "Open Forest (Needleleaf)",
    122: "Open Forest (Broadleaf)",
    123: "Open Forest (Deciduous Needleleaf)",
    124: "Open Forest (Deciduous Broadleaf)",
    125: "Open Forest (Mixed)",
    126: "Open Forest (Unknown)",
    200: "Ocean / Open Water"
}


def map_landcover(val):
    """Translate a land-cover class code into its label.

    Returns "Unknown" for missing values, for values that cannot be read as
    a number (including infinities) and for codes not in ``landcover_map``.
    """
    if pd.isna(val):
        return "Unknown"
    try:
        code = int(float(val))
    except (ValueError, TypeError, OverflowError):
        # OverflowError: int(float('inf'))
        return "Unknown"
    return landcover_map.get(code, "Unknown")


def _specimen_payload(row, file_name):
    """Build the JSON-safe description of one specimen row.

    Args:
        row: Mapping-like row (``pd.Series``) of an observation table.
        file_name: Value reported under ``"file_name"``; the source column
            differs between tables, so the caller supplies it.
    """
    return {
        "name": sanitize_str(row.get('name')),
        "common_name": sanitize_str(row.get('common_name')),
        "lifestyle": sanitize_str(row.get('Primary.Lifestyle')),
        "trophic": sanitize_str(row.get('Trophic.Level')),
        "niche": sanitize_str(row.get('Trophic.Niche')),
        "mass": sanitize_float(row.get('Mass')),
        "hwi": sanitize_float(row.get('Hand-Wing.Index')),
        "temp": sanitize_float(row.get('temperature_2m')),
        "elev": sanitize_float(row.get('elevation')),
        "ndvi": sanitize_float(row.get('ndvi')),
        "file_name": sanitize_str(file_name),
        # Additional metadata
        "order": sanitize_str(row.get('order')),
        "family": sanitize_str(row.get('family')),
        "genus": sanitize_str(row.get('genus')),
        "beak_length": sanitize_float(row.get('Beak.Length_Culmen')),
        "tarsus_length": sanitize_float(row.get('Tarsus.Length')),
        "landcover": map_landcover(row.get('landcover_class')),
        "precipitation": sanitize_float(row.get('total_precipitation')),
        "date": sanitize_str(row.get('date')),
        "split": sanitize_str(row.get('split')),
    }


def _single_row(selection):
    """Collapse a ``.loc`` result to one row when the index label is duplicated."""
    if isinstance(selection, pd.DataFrame):
        return selection.iloc[0]
    return selection


@app.get("/api/details/{idx}")
def get_details(idx: int, mode: str = "birds"):
    """Return the full description of the observation at index *idx*.

    Raises:
        HTTPException: 500 when no dataset is loaded for *mode*; 404 when
            *idx* is not an index label of that dataset.
    """
    df = get_data(mode)
    if df is None or df.empty:
        raise HTTPException(status_code=500, detail=f"Dataset for mode '{mode}' not loaded")

    if idx not in df.index:
        raise HTTPException(status_code=404, detail=f"Specimen details not found for index {idx} in {mode}")

    row = _single_row(df.loc[idx])
    return _specimen_payload(row, row.get('file_name'))


@app.get("/api/similarity/{idx}")
def get_similarity_endpoint(idx: int, model: str = Query('DINOv3'), mode: str = Query('birds')):
    """Return the anchor specimen at *idx* and up to five precomputed matches.

    The anchor is located in the validation-aligned table by file name; when
    it is not there, the first aligned specimen of the same species is used
    as a stand-in for the similarity lookup.

    Raises:
        HTTPException: 500 when the dataset or aligned table is not loaded;
            404 when *idx* is unknown or no aligned specimen exists for the
            anchor's species; 400 when the cache has no entry for *model*.
    """
    df = get_data(mode)
    df_val_aligned = get_val_aligned(mode)
    precomputed_similarity = get_similarity(mode)

    if df is None or df.empty or df_val_aligned is None or df_val_aligned.empty:
        raise HTTPException(status_code=500, detail="Data alignment cache not loaded")

    if idx not in df.index:
        raise HTTPException(status_code=404, detail="Anchor specimen index not found")

    # Get anchor specimen filename key
    anchor_row = _single_row(df.loc[idx])
    file_name = anchor_row.get('file_name')

    # 1. Match anchor in validation set filenames
    matched_rows = df_val_aligned[df_val_aligned['file_name_key'] == file_name]
    if matched_rows.empty:
        # Fallback: Find first validation specimen of the same species
        species_name = anchor_row.get('name')
        matched_rows = df_val_aligned[df_val_aligned['name'] == species_name]
        if matched_rows.empty:
            raise HTTPException(
                status_code=404,
                detail="This specimen is from the train split, and no representative validation specimens exist for its species."
            )

    file_name_key = matched_rows.iloc[0]['file_name_key']

    # 2. Get precomputed similarity scores for this model
    if model not in precomputed_similarity:
        raise HTTPException(status_code=400, detail=f"Similarity cache not found for model: {model}")

    model_matches = precomputed_similarity[model].get(file_name_key, [])

    # Take top 5 matching specimens (entries are (aligned index, score) pairs)
    matches = []
    for match_aligned_idx, score in model_matches:
        if match_aligned_idx not in df_val_aligned.index:
            continue
        row = _single_row(df_val_aligned.loc[match_aligned_idx])
        original_index = row.get('original_index')
        match = _specimen_payload(row, row['file_name_key'])
        match["original_index"] = -1 if pd.isna(original_index) else int(original_index)
        match["similarity"] = float(score)
        matches.append(match)
        if len(matches) >= 5:
            break

    return {
        "anchor": _specimen_payload(anchor_row, file_name),
        "matches": matches
    }


# Serve main explorer dashboard
@app.get("/")
def read_root():
    """Serve ``index.html`` from the static directory."""
    from fastapi.responses import FileResponse
    return FileResponse(os.path.join(STATIC_DIR, "index.html"))


if __name__ == "__main__":
    port = int(os.environ.get("PORT", 8000))
    # NOTE(review): "server:app" presumes this module is importable as
    # ``server`` - confirm against the actual file name.
    uvicorn.run("server:app", host="0.0.0.0", port=port, reload=False)