Spaces:
Running
Running
| import os | |
| import json | |
| import subprocess | |
| import pandas as pd | |
| import numpy as np | |
| import geopandas as gpd | |
| from fastapi import FastAPI, Query, HTTPException | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.middleware.cors import CORSMiddleware | |
| import uvicorn | |
app = FastAPI(title="iNaturalist & AVONET Explorer API")

# Default explorer dataset, taken from the EXPLORER_MODE env var and lower-cased.
EXPLORER_MODE = os.environ.get("EXPLORER_MODE", "birds").lower()

# Startup extraction check for the thumbnail archives.
# NOTE: IMAGE_DIR is reassigned later in this module, just before /images is mounted.
IMAGE_DIR = "data/images"
PLANTS_CHECK_DIR = IMAGE_DIR
BIRDS_CHECK = os.path.join(PLANTS_CHECK_DIR, "train")
# Check if bird thumbnails exist (look for Aves class directories)
def has_bird_images(image_dir=None):
    """Return True if ``<image_dir>/train`` has an entry whose name contains "Aves".

    Args:
        image_dir: thumbnail root directory. Defaults to the module-level IMAGE_DIR,
            which is read at call time.

    Every entry is scanned. ``os.listdir`` returns entries in arbitrary order, and
    bird and plant folders share ``train/``, so inspecting only the first few entries
    could miss the bird folders and trigger a re-download on every startup.
    """
    root = IMAGE_DIR if image_dir is None else image_dir
    train_dir = os.path.join(root, "train")
    # isdir (not exists): a stray *file* named "train" must not make listdir raise.
    if not os.path.isdir(train_dir):
        return False
    return any("Aves" in entry for entry in os.listdir(train_dir))
def has_plant_images(image_dir=None):
    """Return True if ``<image_dir>/train`` has an entry whose name contains "Plantae".

    Args:
        image_dir: thumbnail root directory. Defaults to the module-level IMAGE_DIR,
            which is read at call time.

    Every entry is scanned. ``os.listdir`` returns entries in arbitrary order, and
    bird and plant folders share ``train/``, so inspecting only the first few entries
    could miss the plant folders and trigger a re-download on every startup.
    """
    root = IMAGE_DIR if image_dir is None else image_dir
    train_dir = os.path.join(root, "train")
    # isdir (not exists): a stray *file* named "train" must not make listdir raise.
    if not os.path.isdir(train_dir):
        return False
    return any("Plantae" in entry for entry in os.listdir(train_dir))
def _fetch_thumbnail_archive(download, repo_id, filename, description, label, token):
    """Download one thumbnail tarball from a Hugging Face dataset repo and untar it.

    Best-effort: any failure is printed and swallowed so the API still starts
    without thumbnails.

    Args:
        download: the ``hf_hub_download`` callable.
        repo_id: dataset repository that holds the archive.
        filename: archive name inside the repository.
        description: archive name used in progress/error messages.
        label: capitalised dataset label for the completion message.
        token: optional Hugging Face access token (may be None).
    """
    try:
        print(f"Downloading {description} from Hugging Face Dataset...")
        archive_path = download(
            repo_id=repo_id,
            filename=filename,
            repo_type="dataset",
            token=token,
            local_dir=".",
        )
        print(f"Downloaded {filename} to {archive_path}. Extracting...")
        os.makedirs(IMAGE_DIR, exist_ok=True)
        # tar extracts relative to the current working directory.
        # Presumably the archive is laid out as data/images/... — TODO confirm.
        subprocess.run(["tar", "-xf", archive_path], check=True)
        print(f"{label} thumbnail extraction complete!")
        if os.path.exists(archive_path):
            os.remove(archive_path)
    except Exception as e:
        print(f"Error fetching/extracting {description}: {e}")


if not has_bird_images() or not has_plant_images():
    print("Some thumbnails are missing. Fetching from Hugging Face Dataset...")
    token = os.environ.get("HF_TOKEN")
    try:
        # Imported lazily: only needed when thumbnails are missing. A missing package
        # is reported rather than aborting startup, matching the best-effort handling
        # of download/extraction failures.
        from huggingface_hub import hf_hub_download
    except ImportError as e:
        print(f"huggingface_hub is not installed ({e}); skipping thumbnail download.")
    else:
        # Each set is re-checked right before its own download, i.e. after any
        # earlier extraction has finished.
        if not has_bird_images():
            _fetch_thumbnail_archive(
                hf_hub_download,
                "mayesh/inat-thumbnails",
                "images.tar",
                "birds images.tar",
                "Birds",
                token,
            )
        if not has_plant_images():
            _fetch_thumbnail_archive(
                hf_hub_download,
                "mayesh/Side-Info_Generation",
                "plants_images.tar",
                "plants_images.tar",
                "Plants",
                token,
            )
# Load world boundaries on startup (best-effort: the API still starts without them).
world_geojson_data = None
geojson_path = 'data/countries.geojson'
if not os.path.exists(geojson_path):
    # Fall back to the original cluster checkout when no local copy is present.
    geojson_path = '/scratch/project/prj-02-visual-ai/mayesh/side-info_generation/inat_probing/data/countries.geojson'
try:
    # GeoJSON is UTF-8 by specification (RFC 7946). Decode explicitly instead of
    # relying on the platform/locale default, which can break non-ASCII country names.
    with open(geojson_path, 'r', encoding='utf-8') as f:
        world_geojson_data = json.load(f)
    print(f"Loaded world geojson boundaries from {geojson_path}")
except Exception as e:
    print(f"Error loading countries.geojson: {e}")
# Enable CORS for development: every origin, method and header is accepted.
# NOTE(review): wildcard origins together with allow_credentials=True may let any
# site make credentialed requests — confirm this is intended before production use.
_cors_settings = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_settings)
# Global data holders: per-explorer-mode caches keyed by lower-case mode name
# ("birds" / "plants"), filled by the loading code further down this module.
data_by_mode = dict()
df_val_aligned_by_mode = dict()
precomputed_similarity_by_mode = dict()
def get_data(mode: str = "birds"):
    """Return the preprocessed observation DataFrame for *mode*, or None.

    *mode* is normalised with ``str(...).lower()``. Unknown modes yield None, as does
    a mode whose dataset failed to load (it is stored as None).
    """
    return data_by_mode.get(str(mode).lower())
def get_val_aligned(mode: str = "birds"):
    """Return the validation-alignment DataFrame cached for *mode*, or None if absent.

    *mode* is lower-cased before the lookup.
    """
    return df_val_aligned_by_mode.get(str(mode).lower())
def get_similarity(mode: str = "birds"):
    """Return the precomputed similarity-match dict for *mode*.

    *mode* is lower-cased before the lookup. When nothing is cached for the mode, a
    fresh empty dict is returned.
    """
    return precomputed_similarity_by_mode.get(str(mode).lower(), {})
def get_trait_maps(mode: str = "birds"):
    """Build integer codes for the categorical trait columns of *mode*'s dataset.

    Returns:
        ``(lifestyle_map, trophic_map)``. Each maps the sorted distinct non-null values
        of 'Primary.Lifestyle' / 'Trophic.Level' to 1, 2, 3, ... (0 is left free for
        "missing"). Both are empty when the dataset is absent or empty.
    """
    df = get_data(mode)
    if df is None or df.empty:
        return {}, {}

    def rank_values(column):
        distinct = sorted(df[column].dropna().unique().tolist())
        return {value: position for position, value in enumerate(distinct, start=1)}

    return rank_values('Primary.Lifestyle'), rank_values('Trophic.Level')
def sanitize_str(val):
    """Return ``str(val)``, or None when *val* is missing according to ``pd.isna``."""
    return None if pd.isna(val) else str(val)
def sanitize_float(val):
    """Convert *val* to a JSON-safe Python float.

    Missing values (per ``pd.isna``) and positive/negative infinity map to None,
    since JSON cannot represent them. Everything else goes through ``float()``.
    """
    if not (pd.isna(val) or np.isinf(val)):
        return float(val)
    return None
def _stratified_subsample(df_loaded, max_observations):
    """Cap *df_loaded* at *max_observations* rows, sampling per species ('name').

    Returns the input unchanged when it is already small enough. Otherwise returns a
    new frame re-indexed 0..n-1.
    """
    if len(df_loaded) <= max_observations:
        return df_loaded
    fraction = max_observations / len(df_loaded)
    try:
        sampled = (
            df_loaded.groupby('name', group_keys=False)
            .apply(lambda g: g.sample(frac=fraction, random_state=42))
        )
        # Per-species rounding can overshoot the cap; trim with a uniform sample.
        if len(sampled) > max_observations:
            sampled = sampled.sample(n=max_observations, random_state=42)
    except Exception as e:
        print(f"Stratified sampling failed ({e}), falling back to random sample.")
        sampled = df_loaded.sample(n=max_observations, random_state=42)
    # NOTE(review): this renumbers rows 0..n-1. If cached files (e.g. the
    # 'original_index' column of the validation cache) refer to master-parquet row
    # labels, they will no longer line up with this frame — confirm.
    sampled = sampled.reset_index(drop=True)
    print(f"Subsampled to {len(sampled)} observations for performance.")
    return sampled


def _add_date_columns(df_loaded):
    """Add 'date_dt', 'year' and 'month' columns parsed from 'date' (in place)."""
    # errors='coerce': unparseable values (e.g. the string 'None' that astype(str)
    # yields for missing dates) become NaT and get the defaults below. Without it,
    # one bad value would abort loading of the whole dataset.
    df_loaded['date_dt'] = pd.to_datetime(df_loaded['date'].astype(str), errors='coerce')
    df_loaded['year'] = df_loaded['date_dt'].dt.year.fillna(2019).astype(int)
    df_loaded['month'] = df_loaded['date_dt'].dt.month.fillna(6).astype(int)


def _attach_geography(df_loaded):
    """Add 'country'/'continent' columns via a point-in-polygon join (in place).

    Uses the boundaries at the module-level ``geojson_path``. On any failure both
    columns are set to 'Unknown'.
    """
    print("Mapping points geographically (Spatial Join)...")
    try:
        countries_gdf = gpd.read_file(geojson_path)[['name', 'continent', 'geometry']]
        countries_gdf = countries_gdf.rename(columns={'name': 'country_name'})
        points_gdf = gpd.GeoDataFrame(
            df_loaded[['longitude', 'latitude']],
            geometry=gpd.points_from_xy(df_loaded['longitude'], df_loaded['latitude']),
            crs='EPSG:4326',
        )
        joined = gpd.sjoin(points_gdf, countries_gdf, how='left', predicate='intersects')
        # A point on a shared border, or inside overlapping polygons, matches several
        # countries and is repeated in `joined`. Keep the first match so `joined`
        # aligns 1:1 with df_loaded's index. Otherwise the assignments below raise on
        # the duplicate labels and every row falls back to 'Unknown'.
        joined = joined[~joined.index.duplicated(keep='first')]
        df_loaded['country'] = joined['country_name'].fillna('Ocean/Unknown').astype(str)
        df_loaded['continent'] = joined['continent'].fillna('Ocean/Unknown').astype(str)
        print("Geographic mapping complete!")
    except Exception as e:
        print(f"Warning: Geographic mapping failed: {e}")
        df_loaded['country'] = 'Unknown'
        df_loaded['continent'] = 'Unknown'


def load_and_preprocess_dataset(df_path, max_observations=250_000):
    """Load an observation parquet file and prepare it for the explorer API.

    Steps:
        1. Stratified subsample by species.
        2. Parse 'date' into integer year/month.
        3. Coerce numeric columns (missing latitude/longitude/elevation become 0.0).
        4. Look up country/continent against the world boundaries in ``geojson_path``.

    Args:
        df_path: path to the parquet file.
        max_observations: row cap for the per-species subsample (keeps hover
            responsiveness and memory manageable).

    Returns:
        The prepared DataFrame, or None if *df_path* does not exist.
    """
    if not os.path.exists(df_path):
        print(f"Warning: master dataset parquet not found at {df_path}")
        return None
    print(f"Loading iNaturalist master dataset from {df_path}...")
    df_loaded = pd.read_parquet(df_path)
    print(f"Loaded all {len(df_loaded)} observations from {df_path}.")

    df_loaded = _stratified_subsample(df_loaded, max_observations)
    _add_date_columns(df_loaded)

    # Standardize column types.
    for column in ('latitude', 'longitude', 'elevation'):
        df_loaded[column] = df_loaded[column].fillna(0.0).astype(float)
    for column in ('temperature_2m', 'ndvi', 'Mass', 'Hand-Wing.Index'):
        df_loaded[column] = df_loaded[column].astype(float)

    _attach_geography(df_loaded)
    return df_loaded
# Load the per-mode observation datasets.
data_by_mode['birds'] = load_and_preprocess_dataset('metadata/inat_world_model_master.parquet')
data_by_mode['plants'] = load_and_preprocess_dataset('metadata/plants_metadata_with_traits.parquet')

# Cache file names under metadata/ for each mode: (validation alignment, similarity matches).
_MODE_CACHE_FILES = {
    'birds': ('df_val_aligned.json', 'similarity_matches.json'),
    'plants': ('df_val_aligned_plants.json', 'similarity_matches_plants.json'),
}

# Load aligned validation files (an empty DataFrame marks a missing/broken cache).
for mode, (aligned_filename, _) in _MODE_CACHE_FILES.items():
    aligned_path = os.path.join('metadata', aligned_filename)
    try:
        df_val_aligned_by_mode[mode] = pd.read_json(aligned_path, orient='records')
        print(f"Loaded {mode} validation alignment cache. Shape: {df_val_aligned_by_mode[mode].shape}")
    except Exception as e:
        print(f"Error loading {mode} validation alignment cache ({aligned_path}): {e}")
        df_val_aligned_by_mode[mode] = pd.DataFrame()

# Load similarity caches (an empty dict marks a missing/broken cache).
for mode, (_, sim_filename) in _MODE_CACHE_FILES.items():
    sim_path = os.path.join('metadata', sim_filename)
    try:
        # JSON is UTF-8 (RFC 8259); don't depend on the platform's default encoding.
        with open(sim_path, 'r', encoding='utf-8') as f:
            precomputed_similarity_by_mode[mode] = json.load(f)
        print(f"Loaded {mode} similarity matches cache successfully.")
    except Exception as e:
        print(f"Error loading {mode} similarity matches cache ({sim_path}): {e}")
        precomputed_similarity_by_mode[mode] = {}
# Expose filter categories
def get_filters(mode: str = "birds"):
    """Return the sorted distinct values available for each UI filter in *mode*.

    Raises:
        HTTPException(404): the dataset for *mode* is not loaded or is empty.
    """
    df = get_data(mode)
    if df is None or df.empty:
        raise HTTPException(status_code=404, detail=f"Dataset for mode '{mode}' not loaded")

    def distinct(column):
        return sorted(df[column].dropna().unique().tolist())

    return {
        "mode": mode,
        "orders": distinct('order'),
        "families": distinct('family'),
        "species": distinct('name'),
        "common_names": distinct('common_name'),
        "continents": distinct('continent'),
        "countries": distinct('country'),
        "lifestyles": distinct('Primary.Lifestyle'),
        "trophic_levels": distinct('Trophic.Level'),
    }
def _range_filter(df, column, bounds):
    """Return the rows of *df* whose *column* lies in ``[bounds[0], bounds[1]]`` (inclusive)."""
    low, high = bounds[0], bounds[1]
    return df[(df[column] >= low) & (df[column] <= high)]


def _text_search_filter(df, query):
    """Keep rows where any searchable text column contains *query* (case-insensitive).

    *query* comes straight from the user, so it is matched as a literal substring
    (``regex=False``). Characters such as '(' or '.' must not be treated as regex
    syntax: '(' would raise, and '.' would match anything.
    """
    q = str(query).strip().lower()
    if not q:
        return df
    search_cols = ['name', 'common_name', 'genus', 'family', 'order', 'country', 'continent']
    masks = [
        df[col].str.lower().str.contains(q, na=False, regex=False)
        for col in search_cols
        if col in df.columns
    ]
    if not masks:
        return df
    final_mask = masks[0]
    for mask in masks[1:]:
        final_mask = final_mask | mask
    return df[final_mask]


def _observation_traits(row):
    """Return the sanitized trait/environment fields shared by every point payload."""
    return {
        "lifestyle": sanitize_str(row.get('Primary.Lifestyle')),
        "trophic": sanitize_str(row.get('Trophic.Level')),
        "niche": sanitize_str(row.get('Trophic.Niche')),
        "mass": sanitize_float(row.get('Mass')),
        "hwi": sanitize_float(row.get('Hand-Wing.Index')),
        "temp": sanitize_float(row.get('temperature_2m')),
        "elev": sanitize_float(row.get('elevation')),
        "ndvi": sanitize_float(row.get('ndvi')),
    }


# Expose dynamic query endpoint
def query_data(params: dict):
    """Filter observations and build map/PCA points plus a random image-grid sample.

    Recognised *params* keys:
        explorerMode: dataset mode (default 'birds').
        mode: 'spatial' (map, default) or anything else, treated as PCA.
        model: embedding model; its lower-cased name selects the
            '<model>_pca1' / '<model>_pca2' columns (default 'DINOv3').
        order, family, species: taxonomy filters. 'species' may be a str or a list.
        q: free-text search (literal, case-insensitive).
        lifestyle, trophic: categorical trait filters.
        year, year_mode: year filter. 'cumulative' (default) keeps year <= value;
            any other year_mode keeps the exact year.
        temp, elev, ndvi, mass, hwi: inclusive ``[low, high]`` range filters.
        bounds: {'south', 'north', 'west', 'east'} viewport (spatial mode only).
        zoom: map zoom; below 5 returns at most 5000 points, otherwise 15000.

    Returns:
        dict with "points", "grid_images" (up to 24 random rows that have a
        file_name) and "total_matches" (row count after filtering; 0 when no data
        is loaded).
    """
    explorer_mode = params.get('explorerMode', 'birds')
    active_df = get_data(explorer_mode)
    active_val_aligned = get_val_aligned(explorer_mode)
    mode = params.get('mode', 'spatial')  # spatial or pca
    model_name = params.get('model', 'DINOv3').lower()

    # PCA mode works on the validation-aligned cache; spatial mode on the full dataset.
    working_df = active_val_aligned if mode == 'pca' else active_df
    if working_df is None or working_df.empty:
        # Same response shape as the non-empty case, so clients can always read total_matches.
        return {"points": [], "grid_images": [], "total_matches": 0}

    # Every filter below builds a new frame by boolean indexing and nothing is mutated
    # in place, so the cached dataset needs no defensive copy.
    filtered = working_df

    # Taxonomy filters
    if params.get('order'):
        filtered = filtered[filtered['order'] == params['order']]
    if params.get('family'):
        filtered = filtered[filtered['family'] == params['family']]
    if params.get('species'):
        sp = params['species']
        if isinstance(sp, list) and sp:
            filtered = filtered[filtered['name'].isin(sp)]
        elif isinstance(sp, str) and sp:
            filtered = filtered[filtered['name'] == sp]

    # Global text search
    if params.get('q'):
        filtered = _text_search_filter(filtered, params['q'])

    # Categorical filters
    if params.get('lifestyle'):
        filtered = filtered[filtered['Primary.Lifestyle'] == params['lifestyle']]
    if params.get('trophic'):
        filtered = filtered[filtered['Trophic.Level'] == params['trophic']]

    # Year filter (cumulative vs single year); a bad value is logged and ignored.
    if params.get('year'):
        try:
            year_val = int(params['year'])
            if params.get('year_mode', 'cumulative') == 'cumulative':
                filtered = filtered[filtered['year'] <= year_val]
            else:
                filtered = filtered[filtered['year'] == year_val]
        except Exception as e:
            print(f"Error filtering by year: {e}")

    # Numeric range filters: request key -> dataframe column
    for param_key, column in (
        ('temp', 'temperature_2m'),
        ('elev', 'elevation'),
        ('ndvi', 'ndvi'),
        ('mass', 'Mass'),
        ('hwi', 'Hand-Wing.Index'),
    ):
        if params.get(param_key):
            filtered = _range_filter(filtered, column, params[param_key])

    # Viewport bounds (spatial mode only)
    if mode == 'spatial' and params.get('bounds'):
        b = params['bounds']
        min_lat, max_lat = b['south'], b['north']
        min_lon, max_lon = b['west'], b['east']
        filtered = filtered[(filtered['latitude'] >= min_lat) & (filtered['latitude'] <= max_lat)]
        if min_lon > max_lon:
            # Viewport crosses the antimeridian: keep both longitude slices.
            filtered = filtered[(filtered['longitude'] >= min_lon) | (filtered['longitude'] <= max_lon)]
        else:
            filtered = filtered[(filtered['longitude'] >= min_lon) & (filtered['longitude'] <= max_lon)]

    total_matches = len(filtered)

    # Random sample of up to 24 images for the right-hand grid panel.
    grid_images = []
    with_images = filtered.dropna(subset=['file_name'])
    if not with_images.empty:
        grid_sample = with_images.sample(n=min(len(with_images), 24), random_state=42)
        for _, row in grid_sample.iterrows():
            image_id = row.get('image_id')
            grid_images.append({
                "image_id": '' if pd.isna(image_id) else str(image_id),
                "file_name": sanitize_str(row['file_name']),
                "name": sanitize_str(row['name']),
                "common_name": sanitize_str(row.get('common_name')),
                **_observation_traits(row),
                "lat": sanitize_float(row['latitude']),
                "lon": sanitize_float(row['longitude']),
            })

    points = []
    if mode == 'spatial':
        # Zoom-dependent cap keeps Leaflet canvas drawing smooth.
        zoom = params.get('zoom', 2)
        sample_limit = 5000 if zoom < 5 else 15000
        if len(filtered) > sample_limit:
            plot_sample = filtered.sample(n=sample_limit, random_state=42)
        else:
            plot_sample = filtered
        for _, row in plot_sample.iterrows():
            points.append({
                "lat": sanitize_float(row['latitude']),
                "lon": sanitize_float(row['longitude']),
                "name": sanitize_str(row['name']),
                "common_name": sanitize_str(row.get('common_name')),
                "order": sanitize_str(row.get('order')),
                **_observation_traits(row),
                "file_name": sanitize_str(row['file_name']),
            })
    else:
        # PCA mode: coordinates come from the chosen model's PCA columns.
        pca1_col = f'{model_name}_pca1'
        pca2_col = f'{model_name}_pca2'
        # At most 3000 points in PCA space to limit canvas overlap/lag.
        if len(filtered) > 3000:
            plot_sample = filtered.sample(n=3000, random_state=42)
        else:
            plot_sample = filtered
        for _, row in plot_sample.iterrows():
            points.append({
                "pca1": sanitize_float(row[pca1_col]),
                "pca2": sanitize_float(row[pca2_col]),
                "name": sanitize_str(row['name']),
                "common_name": sanitize_str(row.get('common_name')),
                **_observation_traits(row),
                "file_name": sanitize_str(row['file_name_key']),
            })

    return {
        "points": points,
        "grid_images": grid_images,
        "total_matches": total_matches,
    }
# Dashboard assets live in the "static" folder next to this file (created if missing).
STATIC_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'static')
os.makedirs(STATIC_DIR, exist_ok=True)
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")

# Thumbnails: prefer the local data/images folder, else the original cluster checkout.
# Nothing is mounted at /images when neither path exists.
# NOTE(review): unlike STATIC_DIR, the local path is relative to the working directory.
IMAGE_DIR = (
    "data/images"
    if os.path.exists("data/images")
    else "/scratch/project/prj-02-visual-ai/mayesh/side-info_generation/inat_probing/data/images"
)
if os.path.exists(IMAGE_DIR):
    app.mount("/images", StaticFiles(directory=IMAGE_DIR), name="images")
def get_world_geojson():
    """Return the country-boundary GeoJSON that was loaded when the module started.

    Raises:
        HTTPException(404): the GeoJSON could not be loaded at startup.
    """
    # Read-only access to the module-level cache, so no `global` declaration is needed.
    if world_geojson_data is not None:
        return world_geojson_data
    raise HTTPException(status_code=404, detail="World GeoJSON map boundaries not loaded on server startup")
def get_all_coordinates(mode: str = "birds"):
    """Return every observation of *mode* as one flat numeric list.

    Each observation contributes nine consecutive values, in this order:
    latitude, longitude, elevation, row index label, year, temperature_2m, ndvi,
    lifestyle code, trophic code. The codes come from ``get_trait_maps``, with 0
    meaning missing/unknown. Missing values are filled (0.0 for coordinates,
    elevation and ndvi; 2019 for year; 286.2 for temperature).

    Raises:
        HTTPException(500): the dataset for *mode* is not loaded or is empty.
    """
    df = get_data(mode)
    if df is None or df.empty:
        raise HTTPException(status_code=500, detail=f"Dataset for mode '{mode}' not loaded")

    lifestyle_codes, trophic_codes = get_trait_maps(mode)

    temperature = df['temperature_2m'].fillna(286.2).values
    vegetation = df['ndvi'].fillna(0.0).values
    columns = (
        df['latitude'].fillna(0.0).values,
        df['longitude'].fillna(0.0).values,
        df['elevation'].fillna(0.0).values,
        df.index.values.astype(int),
        df['year'].fillna(2019).values.astype(int),
        # Values below -1000 are presumably no-data sentinels; replace with the defaults.
        np.where(temperature < -1000, 286.2, temperature),
        np.where(vegetation < -1000, 0.0, vegetation),
        df['Primary.Lifestyle'].astype(object).map(lifestyle_codes).fillna(0).values.astype(int),
        df['Trophic.Level'].astype(object).map(trophic_codes).fillna(0).values.astype(int),
    )
    return {"points": np.column_stack(columns).flatten().tolist()}
def query_coordinates(params: dict):
    """Filter observations and return them as one flat numeric list.

    Recognised *params* keys:
        explorerMode: dataset mode (default 'birds').
        order, family, species: taxonomy filters. 'species' may be a str or a list.
        q: free-text search (literal substring, case-insensitive).
        continent, country: location filters.
        hemisphere: 'Northern' (lat >= 0) or 'Southern' (lat < 0).
        lifestyle, trophic: categorical trait filters.

    Each remaining observation contributes nine consecutive values: latitude,
    longitude, elevation, row index label, year, temperature_2m, ndvi, lifestyle code,
    trophic code (codes from ``get_trait_maps``, 0 = missing/unknown).

    Raises:
        HTTPException(500): the dataset for the mode is not loaded or is empty.
    """
    explorer_mode = params.get('explorerMode', 'birds')
    df = get_data(explorer_mode)
    if df is None or df.empty:
        raise HTTPException(status_code=500, detail=f"Dataset for mode '{explorer_mode}' not loaded")
    # Every filter below builds a new frame by boolean indexing and nothing is mutated
    # in place, so the cached dataset needs no defensive copy.
    filtered = df

    # 1. Taxonomy filters
    if params.get('order'):
        filtered = filtered[filtered['order'] == params['order']]
    if params.get('family'):
        filtered = filtered[filtered['family'] == params['family']]
    if params.get('species'):
        sp = params['species']
        if isinstance(sp, list) and sp:
            filtered = filtered[filtered['name'].isin(sp)]
        elif isinstance(sp, str) and sp:
            filtered = filtered[filtered['name'] == sp]

    # Global text search. The query is user input, so match it literally
    # (regex=False): '(' would otherwise raise and '.' would match any character.
    if params.get('q'):
        q = str(params['q']).strip().lower()
        if q:
            search_cols = ['name', 'common_name', 'genus', 'family', 'order', 'country', 'continent']
            masks = [
                filtered[col].str.lower().str.contains(q, na=False, regex=False)
                for col in search_cols
                if col in filtered.columns
            ]
            if masks:
                final_mask = masks[0]
                for mask in masks[1:]:
                    final_mask = final_mask | mask
                filtered = filtered[final_mask]

    # 2. Location filters
    if params.get('continent'):
        filtered = filtered[filtered['continent'] == params['continent']]
    if params.get('country'):
        filtered = filtered[filtered['country'] == params['country']]
    if params.get('hemisphere'):
        if params['hemisphere'] == 'Northern':
            filtered = filtered[filtered['latitude'] >= 0]
        elif params['hemisphere'] == 'Southern':
            filtered = filtered[filtered['latitude'] < 0]

    # 3. Categorical trait filters
    if params.get('lifestyle'):
        filtered = filtered[filtered['Primary.Lifestyle'] == params['lifestyle']]
    if params.get('trophic'):
        filtered = filtered[filtered['Trophic.Level'] == params['trophic']]

    lifestyle_map, trophic_map = get_trait_maps(explorer_mode)
    lat = filtered['latitude'].fillna(0.0).values
    lon = filtered['longitude'].fillna(0.0).values
    elev = filtered['elevation'].fillna(0.0).values
    year = filtered['year'].fillna(2019).values.astype(int)
    idx = filtered.index.values.astype(int)
    temp = filtered['temperature_2m'].fillna(286.2).values
    # Values below -1000 are presumably no-data sentinels; replace with the defaults.
    temp = np.where(temp < -1000, 286.2, temp)
    ndvi = filtered['ndvi'].fillna(0.0).values
    ndvi = np.where(ndvi < -1000, 0.0, ndvi)
    lifestyle = filtered['Primary.Lifestyle'].astype(object).map(lifestyle_map).fillna(0).values.astype(int)
    trophic = filtered['Trophic.Level'].astype(object).map(trophic_map).fillna(0).values.astype(int)
    flat = np.column_stack((lat, lon, elev, idx, year, temp, ndvi, lifestyle, trophic)).flatten()
    return {"points": flat.tolist()}
# Integer land-cover class code -> display label.
# NOTE(review): the codes look like the Copernicus CGLS-LC100 discrete classes — confirm
# against the source of the 'landcover_class' column.
landcover_map = dict([
    # Non-forest classes
    (0, "Unknown"),
    (20, "Shrubs"),
    (30, "Herbaceous vegetation"),
    (40, "Cultivated / Agriculture"),
    (50, "Urban / Built-up"),
    (60, "Bare / Sparse vegetation"),
    (70, "Snow / Ice"),
    (80, "Permanent water"),
    (90, "Herbaceous wetland"),
    (100, "Moss / Lichen"),
    # Closed forest (111-116)
    (111, "Closed Forest (Needleleaf)"),
    (112, "Closed Forest (Broadleaf)"),
    (113, "Closed Forest (Deciduous Needleleaf)"),
    (114, "Closed Forest (Deciduous Broadleaf)"),
    (115, "Closed Forest (Mixed)"),
    (116, "Closed Forest (Unknown)"),
    # Open forest (121-126)
    (121, "Open Forest (Needleleaf)"),
    (122, "Open Forest (Broadleaf)"),
    (123, "Open Forest (Deciduous Needleleaf)"),
    (124, "Open Forest (Deciduous Broadleaf)"),
    (125, "Open Forest (Mixed)"),
    (126, "Open Forest (Unknown)"),
    # Open water
    (200, "Ocean / Open Water"),
])
def map_landcover(val):
    """Translate a land-cover class code into its ``landcover_map`` label.

    Accepts ints, floats and numeric strings (e.g. "20.0"). Missing,
    non-numeric, infinite or unknown codes all map to "Unknown".
    """
    if pd.isna(val):
        return "Unknown"
    try:
        code = int(float(val))
    except (ValueError, TypeError, OverflowError):
        # OverflowError: int(float('inf')) — infinity is not NaN, so pd.isna lets it through.
        return "Unknown"
    return landcover_map.get(code, "Unknown")
def get_details(idx: int, mode: str = "birds"):
    """Return the full sanitized metadata record for observation *idx* of *mode*.

    Raises:
        HTTPException(500): the dataset for *mode* is not loaded or is empty.
        HTTPException(404): *idx* is not an index label of that dataset.
    """
    df = get_data(mode)
    if df is None or df.empty:
        raise HTTPException(status_code=500, detail=f"Dataset for mode '{mode}' not loaded")
    if idx not in df.index:
        raise HTTPException(status_code=404, detail=f"Specimen details not found for index {idx} in {mode}")

    record = df.loc[idx]
    # Duplicate index labels make .loc return a DataFrame; use the first match.
    if isinstance(record, pd.DataFrame):
        record = record.iloc[0]

    def text(column):
        return sanitize_str(record.get(column))

    def number(column):
        return sanitize_float(record.get(column))

    return {
        "name": text('name'),
        "common_name": text('common_name'),
        "lifestyle": text('Primary.Lifestyle'),
        "trophic": text('Trophic.Level'),
        "niche": text('Trophic.Niche'),
        "mass": number('Mass'),
        "hwi": number('Hand-Wing.Index'),
        "temp": number('temperature_2m'),
        "elev": number('elevation'),
        "ndvi": number('ndvi'),
        "file_name": text('file_name'),
        # Additional metadata
        "order": text('order'),
        "family": text('family'),
        "genus": text('genus'),
        "beak_length": number('Beak.Length_Culmen'),
        "tarsus_length": number('Tarsus.Length'),
        "landcover": map_landcover(record.get('landcover_class')),
        "precipitation": number('total_precipitation'),
        "date": text('date'),
        "split": text('split'),
    }
def get_similarity_endpoint(idx: int, model: str = Query('DINOv3'), mode: str = Query('birds')):
    """Return anchor specimen *idx* plus up to five nearest validation specimens for *model*.

    The anchor is matched to the validation-aligned cache by file name. If that fails,
    the first validation specimen of the same species is used. Its precomputed
    ``[aligned_index, score]`` neighbour list is then expanded into metadata records,
    skipping indices that are absent from the cache.

    Raises:
        HTTPException(500): the dataset or its validation alignment cache is not loaded.
        HTTPException(404): *idx* is unknown, or no validation specimen matches by file
            name or species.
        HTTPException(400): no precomputed similarities exist for *model*.
    """
    df = get_data(mode)
    val_df = get_val_aligned(mode)
    similarity_cache = get_similarity(mode)

    if df is None or df.empty or val_df is None or val_df.empty:
        raise HTTPException(status_code=500, detail="Data alignment cache not loaded")
    if idx not in df.index:
        raise HTTPException(status_code=404, detail="Anchor specimen index not found")

    anchor = df.loc[idx]
    # Duplicate index labels make .loc return a DataFrame; use the first match.
    if isinstance(anchor, pd.DataFrame):
        anchor = anchor.iloc[0]
    file_name = anchor.get('file_name')

    # Locate the anchor in the validation set: exact file name first, then same species.
    candidates = val_df[val_df['file_name_key'] == file_name]
    if candidates.empty:
        candidates = val_df[val_df['name'] == anchor.get('name')]
    if candidates.empty:
        raise HTTPException(
            status_code=404,
            detail="This specimen is from the train split, and no representative validation specimens exist for its species."
        )
    file_name_key = candidates.iloc[0]['file_name_key']

    if model not in similarity_cache:
        raise HTTPException(status_code=400, detail=f"Similarity cache not found for model: {model}")
    ranked_neighbours = similarity_cache[model].get(file_name_key, [])

    def text(row, column):
        return sanitize_str(row.get(column))

    def number(row, column):
        return sanitize_float(row.get(column))

    matches = []
    for aligned_idx, score in ranked_neighbours:
        if aligned_idx not in val_df.index:
            continue
        row = val_df.loc[aligned_idx]
        raw_original = row.get('original_index')
        matches.append({
            "original_index": -1 if pd.isna(raw_original) else int(raw_original),
            "name": sanitize_str(row['name']),
            "common_name": text(row, 'common_name'),
            "lifestyle": text(row, 'Primary.Lifestyle'),
            "trophic": text(row, 'Trophic.Level'),
            "niche": text(row, 'Trophic.Niche'),
            "mass": number(row, 'Mass'),
            "hwi": number(row, 'Hand-Wing.Index'),
            "temp": number(row, 'temperature_2m'),
            "elev": number(row, 'elevation'),
            "ndvi": number(row, 'ndvi'),
            "file_name": sanitize_str(row['file_name_key']),
            "similarity": float(score),
            # Additional metadata
            "order": text(row, 'order'),
            "family": text(row, 'family'),
            "genus": text(row, 'genus'),
            "beak_length": number(row, 'Beak.Length_Culmen'),
            "tarsus_length": number(row, 'Tarsus.Length'),
            "landcover": map_landcover(row.get('landcover_class')),
            "precipitation": number(row, 'total_precipitation'),
            "date": text(row, 'date'),
            "split": text(row, 'split'),
        })
        # Keep only the top five neighbours.
        if len(matches) >= 5:
            break

    anchor_payload = {
        "name": text(anchor, 'name'),
        "common_name": text(anchor, 'common_name'),
        "file_name": sanitize_str(file_name),
        "lifestyle": text(anchor, 'Primary.Lifestyle'),
        "trophic": text(anchor, 'Trophic.Level'),
        "niche": text(anchor, 'Trophic.Niche'),
        "mass": number(anchor, 'Mass'),
        "hwi": number(anchor, 'Hand-Wing.Index'),
        "temp": number(anchor, 'temperature_2m'),
        "elev": number(anchor, 'elevation'),
        "ndvi": number(anchor, 'ndvi'),
        "order": text(anchor, 'order'),
        "family": text(anchor, 'family'),
        "genus": text(anchor, 'genus'),
        "beak_length": number(anchor, 'Beak.Length_Culmen'),
        "tarsus_length": number(anchor, 'Tarsus.Length'),
        "landcover": map_landcover(anchor.get('landcover_class')),
        "precipitation": number(anchor, 'total_precipitation'),
        "date": text(anchor, 'date'),
        "split": text(anchor, 'split'),
    }
    return {"anchor": anchor_payload, "matches": matches}
# Serve main explorer dashboard
def read_root():
    """Return the single-page explorer dashboard, STATIC_DIR/index.html."""
    from fastapi.responses import FileResponse

    index_html = os.path.join(STATIC_DIR, "index.html")
    return FileResponse(index_html)
if __name__ == "__main__":
    port = int(os.environ.get("PORT", 8000))
    # Pass the already-built app object, not the "server:app" import string.
    # With an import string, uvicorn imports this file a second time under the
    # module name "server", re-running every startup download and dataset load,
    # and it fails outright if the file is not named server.py. An import string
    # is only required for reload/workers, and reload is disabled here.
    uvicorn.run(app, host="0.0.0.0", port=port, reload=False)