# mayesh's picture
# fix: Image loading and NN match detail expansion
# 85698dc
# Raw
# History Blame Contribute Delete
# 31.2 kB
import os
import json
import subprocess
import pandas as pd
import numpy as np
import geopandas as gpd
from fastapi import FastAPI, Query, HTTPException
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
# FastAPI application object; the route decorators further down register endpoints on it.
app = FastAPI(title="iNaturalist & AVONET Explorer API")
# Explorer mode taken from the EXPLORER_MODE environment variable (lower-cased, default "birds").
# NOTE(review): not referenced anywhere else in the visible source — confirm it is still needed.
EXPLORER_MODE = os.environ.get("EXPLORER_MODE", "birds").lower()
# Startup extraction check for images.tar
# Thumbnail root, relative to the process working directory.
IMAGE_DIR = "data/images"
# NOTE(review): BIRDS_CHECK and PLANTS_CHECK_DIR are not referenced in the visible source;
# the has_*_images() helpers rebuild the "train" path themselves — confirm before removing.
BIRDS_CHECK = os.path.join(IMAGE_DIR, "train")
PLANTS_CHECK_DIR = IMAGE_DIR
# Check if birds images exist (look for Aves directories)
def has_bird_images(image_dir=None):
    """Return True when a bird thumbnail folder is already present on disk.

    Looks inside ``<image_dir>/train`` for any entry whose name contains
    ``"Aves"``.

    Args:
        image_dir: Root thumbnail directory. Defaults to the module-level
            ``IMAGE_DIR`` (resolved at call time).

    Returns:
        bool: False when the ``train`` directory is missing or holds no
        matching entry.
    """
    root = IMAGE_DIR if image_dir is None else image_dir
    train_dir = os.path.join(root, "train")
    if not os.path.isdir(train_dir):
        return False
    # Scan every entry: os.listdir() order is arbitrary, so inspecting only a
    # prefix of the listing can miss bird folders when other taxa share the
    # directory, which would trigger a needless re-download on every startup.
    return any("Aves" in entry for entry in os.listdir(train_dir))
def has_plant_images(image_dir=None):
    """Return True when a plant thumbnail folder is already present on disk.

    Looks inside ``<image_dir>/train`` for any entry whose name contains
    ``"Plantae"``.

    Args:
        image_dir: Root thumbnail directory. Defaults to the module-level
            ``IMAGE_DIR`` (resolved at call time).

    Returns:
        bool: False when the ``train`` directory is missing or holds no
        matching entry.
    """
    root = IMAGE_DIR if image_dir is None else image_dir
    train_dir = os.path.join(root, "train")
    if not os.path.isdir(train_dir):
        return False
    # Scan every entry: os.listdir() order is arbitrary, so inspecting only a
    # prefix of the listing can miss plant folders when other taxa share the
    # directory, which would trigger a needless re-download on every startup.
    return any("Plantae" in entry for entry in os.listdir(train_dir))
def _fetch_thumbnail_archive(repo_id, filename, description, label, token):
    """Download one thumbnail tarball from the Hugging Face Hub and unpack it.

    Failures are reported with ``print`` and swallowed so that a missing
    archive does not stop the server from starting. The downloaded tarball is
    removed afterwards whether or not extraction succeeded, so a failed
    extraction does not leave the archive behind on disk.

    Args:
        repo_id: Hugging Face dataset repository id.
        filename: Archive file name inside that repository.
        description: Text used in the "Downloading ..." and error messages.
        label: Capitalised label used in the completion message.
        token: Hugging Face access token, or None.
    """
    # Imported lazily: only needed when thumbnails are actually missing.
    from huggingface_hub import hf_hub_download

    archive_path = None
    try:
        print(f"Downloading {description} from Hugging Face Dataset...")
        archive_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            repo_type="dataset",
            token=token,
            local_dir=".",
        )
        print(f"Downloaded {filename} to {archive_path}. Extracting...")
        os.makedirs(IMAGE_DIR, exist_ok=True)
        # NOTE(review): tar extracts into the current working directory, so the
        # archive presumably carries the data/images/... prefix — confirm.
        subprocess.run(["tar", "-xf", archive_path], check=True)
        print(f"{label} thumbnail extraction complete!")
    except Exception as e:
        print(f"Error fetching/extracting {description}: {e}")
    finally:
        if archive_path and os.path.exists(archive_path):
            try:
                os.remove(archive_path)
            except OSError as e:
                print(f"Error fetching/extracting {description}: {e}")


if not has_bird_images() or not has_plant_images():
    print("Some thumbnails are missing. Fetching from Hugging Face Dataset...")
    token = os.environ.get("HF_TOKEN")
    # 1. Download Birds Thumbnails (if missing)
    if not has_bird_images():
        _fetch_thumbnail_archive(
            repo_id="mayesh/inat-thumbnails",
            filename="images.tar",
            description="birds images.tar",
            label="Birds",
            token=token,
        )
    # 2. Download Plants Thumbnails (if missing)
    # Re-checked after the birds step, as in the original flow.
    if not has_plant_images():
        _fetch_thumbnail_archive(
            repo_id="mayesh/Side-Info_Generation",
            filename="plants_images.tar",
            description="plants_images.tar",
            label="Plants",
            token=token,
        )
# Load world boundaries on startup
# world_geojson_data stays None when loading fails; /api/world_geojson then answers 404.
world_geojson_data = None
geojson_path = 'data/countries.geojson'
if not os.path.exists(geojson_path):
    # Fall back to the absolute path used when the repo-relative file is absent.
    geojson_path = '/scratch/project/prj-02-visual-ai/mayesh/side-info_generation/inat_probing/data/countries.geojson'
try:
    # Read as UTF-8 explicitly so non-ASCII country names do not depend on the
    # platform's default encoding.
    with open(geojson_path, 'r', encoding='utf-8') as f:
        world_geojson_data = json.load(f)
    print(f"Loaded world geojson boundaries from {geojson_path}")
except (OSError, ValueError) as e:
    # OSError: missing/unreadable file; ValueError: invalid JSON or bad encoding.
    print(f"Error loading countries.geojson: {e}")
# Enable CORS for development
# Every origin, method and header is allowed, and credentials are enabled.
# NOTE(review): FastAPI's CORS documentation advises against combining wildcard
# origins with allow_credentials=True — confirm this is restricted outside development.
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Global data holders
# Each dict is keyed by explorer mode ('birds' / 'plants') and filled once at import time below.
# Master observation DataFrame per mode (None when the parquet file was not found).
data_by_mode = {}
# Validation-alignment DataFrame per mode (empty DataFrame when loading failed).
df_val_aligned_by_mode = {}
# Parsed similarity-match JSON per mode (empty dict when loading failed).
precomputed_similarity_by_mode = {}
def get_data(mode: str = "birds"):
    """Return the master dataset stored for ``mode``, or None if there is no entry.

    ``mode`` is passed through ``str()`` and lower-cased before the lookup.
    """
    return data_by_mode.get(str(mode).lower())
def get_val_aligned(mode: str = "birds"):
    """Return the validation-alignment frame stored for ``mode``, or None if absent.

    ``mode`` is passed through ``str()`` and lower-cased before the lookup.
    """
    return df_val_aligned_by_mode.get(str(mode).lower())
def get_similarity(mode: str = "birds"):
    """Return the similarity cache stored for ``mode``, or an empty dict if absent.

    ``mode`` is passed through ``str()`` and lower-cased before the lookup.
    """
    return precomputed_similarity_by_mode.get(str(mode).lower(), {})
def get_trait_maps(mode: str = "birds"):
    """Build 1-based integer codes for the categorical trait columns of ``mode``.

    Returns:
        tuple[dict, dict]: ``(lifestyle_map, trophic_map)`` mapping each sorted
        distinct non-null value of ``Primary.Lifestyle`` / ``Trophic.Level`` to
        its position (starting at 1). Two empty dicts when the dataset is
        missing or empty.
    """
    frame = get_data(mode)
    if frame is None or frame.empty:
        return {}, {}

    def _code_table(column):
        # Sorted distinct labels -> 1, 2, 3, ...
        labels = sorted(frame[column].dropna().unique().tolist())
        return dict(zip(labels, range(1, len(labels) + 1)))

    return _code_table('Primary.Lifestyle'), _code_table('Trophic.Level')
def sanitize_str(val):
    """Return ``str(val)``, or None when ``pd.isna(val)`` is true."""
    return None if pd.isna(val) else str(val)
def sanitize_float(val):
    """Convert ``val`` to a JSON-safe Python float.

    Args:
        val: Any scalar (number, numeric string, None, NaN, pandas NA, ...).

    Returns:
        float | None: The value as a float, or None when it is missing,
        NaN, infinite, or cannot be converted to a float at all.
    """
    if val is None:
        return None
    try:
        number = float(val)
    except (TypeError, ValueError, OverflowError):
        # Non-numeric input (e.g. a stray string or pandas NA) previously made
        # np.isinf raise; treat it as "no value" instead of failing the request.
        return None
    if pd.isna(number) or np.isinf(number):
        return None
    return number
def load_and_preprocess_dataset(df_path):
    """Load a master observation parquet file and prepare it for the API.

    Steps: optional subsampling to at most 250,000 rows, derivation of
    ``date_dt`` / ``year`` / ``month`` from ``date``, float casts of the numeric
    columns, and a spatial join against ``geojson_path`` that adds ``country``
    and ``continent``.

    Args:
        df_path: Path of the parquet file.

    Returns:
        pandas.DataFrame | None: The prepared frame, or None when ``df_path``
        does not exist.
    """
    if not os.path.exists(df_path):
        print(f"Warning: master dataset parquet not found at {df_path}")
        return None
    print(f"Loading iNaturalist master dataset from {df_path}...")
    df_loaded = pd.read_parquet(df_path)
    print(f"Loaded all {len(df_loaded)} observations from {df_path}.")

    # Subsample to 250k observations for performance (hover responsiveness, memory)
    # Stratified by species to preserve taxonomic diversity
    MAX_OBSERVATIONS = 250_000
    if len(df_loaded) > MAX_OBSERVATIONS:
        try:
            sample_frac = MAX_OBSERVATIONS / len(df_loaded)
            df_loaded = (
                df_loaded.groupby('name', group_keys=False)
                .apply(lambda g: g.sample(frac=sample_frac, random_state=42))
            )
            # Per-group rounding can overshoot the cap; trim if so.
            if len(df_loaded) > MAX_OBSERVATIONS:
                df_loaded = df_loaded.sample(n=MAX_OBSERVATIONS, random_state=42)
        except Exception as e:
            print(f"Stratified sampling failed ({e}), falling back to random sample.")
            df_loaded = df_loaded.sample(n=MAX_OBSERVATIONS, random_state=42)
        df_loaded = df_loaded.reset_index(drop=True)
        print(f"Subsampled to {len(df_loaded)} observations for performance.")

    # errors='coerce' turns unparseable dates into NaT so the year/month
    # fallbacks below apply instead of the whole load failing on one bad value.
    df_loaded['date_dt'] = pd.to_datetime(df_loaded['date'].astype(str), errors='coerce')
    df_loaded['year'] = df_loaded['date_dt'].dt.year.fillna(2019).astype(int)
    df_loaded['month'] = df_loaded['date_dt'].dt.month.fillna(6).astype(int)

    # Standardize column types
    for column in ('latitude', 'longitude', 'elevation'):
        df_loaded[column] = df_loaded[column].fillna(0.0).astype(float)
    for column in ('temperature_2m', 'ndvi', 'Mass', 'Hand-Wing.Index'):
        df_loaded[column] = df_loaded[column].astype(float)

    # Perform spatial join for country and continent
    print("Mapping points geographically (Spatial Join)...")
    try:
        countries_gdf = gpd.read_file(geojson_path)[['name', 'continent', 'geometry']]
        countries_gdf = countries_gdf.rename(columns={'name': 'country_name'})
        points_gdf = gpd.GeoDataFrame(
            df_loaded[['longitude', 'latitude']],
            geometry=gpd.points_from_xy(df_loaded['longitude'], df_loaded['latitude']),
            crs='EPSG:4326'
        )
        joined = gpd.sjoin(points_gdf, countries_gdf, how='left', predicate='intersects')
        # A point touching several polygons yields several joined rows with the
        # same index; keep the first so the assignment below can align by index
        # instead of failing and marking every observation as 'Unknown'.
        joined = joined[~joined.index.duplicated(keep='first')]
        df_loaded['country'] = joined['country_name'].fillna('Ocean/Unknown').astype(str)
        df_loaded['continent'] = joined['continent'].fillna('Ocean/Unknown').astype(str)
        print("Geographic mapping complete!")
    except Exception as e:
        print(f"Warning: Geographic mapping failed: {e}")
        df_loaded['country'] = 'Unknown'
        df_loaded['continent'] = 'Unknown'
    return df_loaded
# Load datasets
for mode, parquet_path in (
    ('birds', 'metadata/inat_world_model_master.parquet'),
    ('plants', 'metadata/plants_metadata_with_traits.parquet'),
):
    data_by_mode[mode] = load_and_preprocess_dataset(parquet_path)

# Load aligned validation files
# A failed load leaves an empty DataFrame for that mode.
for mode, aligned_filename in (
    ('birds', 'df_val_aligned.json'),
    ('plants', 'df_val_aligned_plants.json'),
):
    aligned_path = os.path.join('metadata', aligned_filename)
    try:
        df_val_aligned_by_mode[mode] = pd.read_json(aligned_path, orient='records')
        print(f"Loaded {mode} validation alignment cache. Shape: {df_val_aligned_by_mode[mode].shape}")
    except Exception as e:
        print(f"Error loading {mode} validation alignment cache ({aligned_path}): {e}")
        df_val_aligned_by_mode[mode] = pd.DataFrame()

# Load similarity caches
# A failed load leaves an empty dict for that mode.
for mode, sim_filename in (
    ('birds', 'similarity_matches.json'),
    ('plants', 'similarity_matches_plants.json'),
):
    sim_path = os.path.join('metadata', sim_filename)
    try:
        with open(sim_path, 'r') as f:
            precomputed_similarity_by_mode[mode] = json.load(f)
        print(f"Loaded {mode} similarity matches cache successfully.")
    except Exception as e:
        print(f"Error loading {mode} similarity matches cache ({sim_path}): {e}")
        precomputed_similarity_by_mode[mode] = {}
# Expose filter categories
@app.get("/api/filters")
def get_filters(mode: str = "birds"):
    """Return the sorted distinct values available for each filter of ``mode``.

    Raises:
        HTTPException: 404 when the dataset for ``mode`` is missing or empty.
    """
    frame = get_data(mode)
    if frame is None or frame.empty:
        raise HTTPException(status_code=404, detail=f"Dataset for mode '{mode}' not loaded")

    # (response key, source column), in response order.
    key_to_column = (
        ("orders", 'order'),
        ("families", 'family'),
        ("species", 'name'),
        ("common_names", 'common_name'),
        ("continents", 'continent'),
        ("countries", 'country'),
        ("lifestyles", 'Primary.Lifestyle'),
        ("trophic_levels", 'Trophic.Level'),
    )
    filters = {"mode": mode}
    for key, column in key_to_column:
        filters[key] = sorted(frame[column].dropna().unique().tolist())
    return filters
def _query_text_mask(frame, needle):
    """Return a boolean mask of rows whose searchable columns contain ``needle``.

    ``needle`` is matched as a literal, lower-cased substring. Returns None
    when none of the searchable columns exist in ``frame``.
    """
    mask = None
    for col in ('name', 'common_name', 'genus', 'family', 'order', 'country', 'continent'):
        if col not in frame.columns:
            continue
        # regex=False: the query is user input; characters such as "(" or "["
        # must not be interpreted as a regular expression.
        hit = frame[col].str.lower().str.contains(needle, na=False, regex=False)
        mask = hit if mask is None else (mask | hit)
    return mask


def _query_trait_fields(row):
    """Return the trait/environment fields shared by every point payload."""
    return {
        "lifestyle": sanitize_str(row.get('Primary.Lifestyle')),
        "trophic": sanitize_str(row.get('Trophic.Level')),
        "niche": sanitize_str(row.get('Trophic.Niche')),
        "mass": sanitize_float(row.get('Mass')),
        "hwi": sanitize_float(row.get('Hand-Wing.Index')),
        "temp": sanitize_float(row.get('temperature_2m')),
        "elev": sanitize_float(row.get('elevation')),
        "ndvi": sanitize_float(row.get('ndvi')),
    }


# Expose dynamic query endpoint
@app.post("/api/query")
def query_data(params: dict):
    """Filter the dataset and return plot points plus a sample of grid images.

    Recognised ``params`` keys: ``explorerMode``, ``mode`` ('spatial' or
    'pca'), ``model``, ``order``, ``family``, ``species`` (str or list), ``q``,
    ``lifestyle``, ``trophic``, ``year``, ``year_mode``, the ``[low, high]``
    ranges ``temp`` / ``elev`` / ``ndvi`` / ``mass`` / ``hwi``, ``bounds``
    (``south``/``north``/``west``/``east``) and ``zoom``.

    Returns:
        dict: ``points`` (at most 5000/15000 rows in spatial mode depending on
        zoom, 3000 in PCA mode), ``grid_images`` (at most 24 rows) and
        ``total_matches`` (row count before sampling).

    Raises:
        HTTPException: 400 in PCA mode when the dataset has no PCA columns for
            the requested model.
    """
    explorer_mode = params.get('explorerMode', 'birds')
    mode = params.get('mode', 'spatial')  # spatial or pca
    model_name = str(params.get('model') or 'DINOv3').lower()

    # Select working dataframe
    working_df = get_val_aligned(explorer_mode) if mode == 'pca' else get_data(explorer_mode)
    if working_df is None or working_df.empty:
        # Same keys as the normal response so clients can read total_matches.
        return {"points": [], "grid_images": [], "total_matches": 0}
    # No defensive copy: every step below builds a new frame, nothing mutates.
    filtered = working_df

    # Apply Taxonomy filters
    if params.get('order'):
        filtered = filtered[filtered['order'] == params['order']]
    if params.get('family'):
        filtered = filtered[filtered['family'] == params['family']]
    if params.get('species'):
        sp = params['species']
        if isinstance(sp, list) and sp:
            filtered = filtered[filtered['name'].isin(sp)]
        elif isinstance(sp, str) and sp:
            filtered = filtered[filtered['name'] == sp]

    # Apply global text search query
    if params.get('q'):
        q = str(params['q']).strip().lower()
        if q:
            text_mask = _query_text_mask(filtered, q)
            if text_mask is not None:
                filtered = filtered[text_mask]

    # Apply Categorical filters
    if params.get('lifestyle'):
        filtered = filtered[filtered['Primary.Lifestyle'] == params['lifestyle']]
    if params.get('trophic'):
        filtered = filtered[filtered['Trophic.Level'] == params['trophic']]

    # Apply Year filter (Cumulative vs Single Year)
    if params.get('year'):
        try:
            year_val = int(params['year'])
            if params.get('year_mode', 'cumulative') == 'cumulative':
                filtered = filtered[filtered['year'] <= year_val]
            else:
                filtered = filtered[filtered['year'] == year_val]
        except (KeyError, TypeError, ValueError) as e:
            # Best effort: a bad year value or a missing column skips the filter.
            print(f"Error filtering by year: {e}")

    # Apply Ranges
    for key, column in (
        ('temp', 'temperature_2m'),
        ('elev', 'elevation'),
        ('ndvi', 'ndvi'),
        ('mass', 'Mass'),
        ('hwi', 'Hand-Wing.Index'),
    ):
        bounds = params.get(key)
        if bounds:
            low, high = bounds[0], bounds[1]
            filtered = filtered[(filtered[column] >= low) & (filtered[column] <= high)]

    # Apply Spatial bounds if spatial mode and bounds are active
    if mode == 'spatial' and params.get('bounds'):
        b = params['bounds']
        min_lat, max_lat = b['south'], b['north']
        min_lon, max_lon = b['west'], b['east']
        filtered = filtered[(filtered['latitude'] >= min_lat) & (filtered['latitude'] <= max_lat)]
        # Handle crossing antimeridian
        if min_lon > max_lon:
            filtered = filtered[(filtered['longitude'] >= min_lon) | (filtered['longitude'] <= max_lon)]
        else:
            filtered = filtered[(filtered['longitude'] >= min_lon) & (filtered['longitude'] <= max_lon)]

    total_matches = len(filtered)

    # Generate random sample of 24 images for the right grid panel
    # Fall back to 'file_name_key' when the working frame has no 'file_name'
    # column, so the grid does not fail with a KeyError in that case.
    file_col = 'file_name' if 'file_name' in filtered.columns else 'file_name_key'
    grid_images = []
    grid_sample = filtered.dropna(subset=[file_col])
    if not grid_sample.empty:
        grid_sample = grid_sample.sample(n=min(len(grid_sample), 24), random_state=42)
        for _, row in grid_sample.iterrows():
            grid_images.append({
                "image_id": sanitize_str(row.get('image_id')) or '',
                "file_name": sanitize_str(row[file_col]),
                "name": sanitize_str(row['name']),
                "common_name": sanitize_str(row.get('common_name')),
                **_query_trait_fields(row),
                "lat": sanitize_float(row.get('latitude')),
                "lon": sanitize_float(row.get('longitude')),
            })

    points = []
    # Process Map/PCA point display outputs
    if mode == 'spatial':
        # Apply Zoom-dependent Bounding-Box smart sampling to ensure smooth Leaflet Canvas drawing
        zoom = params.get('zoom')
        if zoom is None:
            zoom = 2
        sample_limit = 5000 if zoom < 5 else 15000
        if len(filtered) > sample_limit:
            plot_sample = filtered.sample(n=sample_limit, random_state=42)
        else:
            plot_sample = filtered
        for _, row in plot_sample.iterrows():
            points.append({
                "lat": sanitize_float(row['latitude']),
                "lon": sanitize_float(row['longitude']),
                "name": sanitize_str(row['name']),
                "common_name": sanitize_str(row.get('common_name')),
                "order": sanitize_str(row.get('order')),
                **_query_trait_fields(row),
                "file_name": sanitize_str(row['file_name']),
            })
    else:
        # PCA Mode: Return PCA coordinates for chosen model
        pca1_col = f'{model_name}_pca1'
        pca2_col = f'{model_name}_pca2'
        # Limit to 3000 points in PCA space to prevent canvas overlap/lag
        if len(filtered) > 3000:
            plot_sample = filtered.sample(n=3000, random_state=42)
        else:
            plot_sample = filtered
        # Only checked when there are rows to emit, so an empty result keeps
        # returning an empty point list as before.
        if not plot_sample.empty and not {pca1_col, pca2_col} <= set(plot_sample.columns):
            raise HTTPException(
                status_code=400,
                detail=f"No PCA projection available for model: {params.get('model', 'DINOv3')}",
            )
        for _, row in plot_sample.iterrows():
            points.append({
                "pca1": sanitize_float(row[pca1_col]),
                "pca2": sanitize_float(row[pca2_col]),
                "name": sanitize_str(row['name']),
                "common_name": sanitize_str(row.get('common_name')),
                **_query_trait_fields(row),
                "file_name": sanitize_str(row['file_name_key']),
            })

    return {
        "points": points,
        "grid_images": grid_images,
        "total_matches": total_matches
    }
# Mount static files and images
# Static assets live in a "static" folder next to this module; it is created if missing.
STATIC_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'static')
os.makedirs(STATIC_DIR, exist_ok=True)
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")

# Prefer the repo-relative thumbnail folder, otherwise use the absolute fallback path.
IMAGE_DIR = (
    "data/images"
    if os.path.exists("data/images")
    else "/scratch/project/prj-02-visual-ai/mayesh/side-info_generation/inat_probing/data/images"
)
# /images is only mounted when the chosen directory actually exists.
if os.path.exists(IMAGE_DIR):
    app.mount("/images", StaticFiles(directory=IMAGE_DIR), name="images")
@app.get("/api/world_geojson")
def get_world_geojson():
    """Return the world boundary GeoJSON that was loaded at startup.

    Raises:
        HTTPException: 404 when the GeoJSON could not be loaded at startup.
    """
    if world_geojson_data is not None:
        return world_geojson_data
    raise HTTPException(status_code=404, detail="World GeoJSON map boundaries not loaded on server startup")
@app.get("/api/all_coordinates")
def get_all_coordinates(mode: str = "birds"):
    """Return every observation of ``mode`` as one flat numeric list.

    Each observation contributes nine consecutive values: latitude, longitude,
    elevation, row index, year, temperature, NDVI, lifestyle code, trophic
    code. Trait codes come from ``get_trait_maps`` (0 = no mapped value).

    Raises:
        HTTPException: 500 when the dataset for ``mode`` is missing or empty.
    """
    frame = get_data(mode)
    if frame is None or frame.empty:
        raise HTTPException(status_code=500, detail=f"Dataset for mode '{mode}' not loaded")
    lifestyle_map, trophic_map = get_trait_maps(mode)

    def _filled(column, default):
        return frame[column].fillna(default).values

    def _replace_sentinel(values, default):
        # Values below -1000 are replaced by the same default used for nulls.
        return np.where(values < -1000, default, values)

    def _trait_codes(column, table):
        return frame[column].astype(object).map(table).fillna(0).values.astype(int)

    columns = (
        _filled('latitude', 0.0),
        _filled('longitude', 0.0),
        _filled('elevation', 0.0),
        frame.index.values.astype(int),
        _filled('year', 2019).astype(int),
        _replace_sentinel(_filled('temperature_2m', 286.2), 286.2),
        _replace_sentinel(_filled('ndvi', 0.0), 0.0),
        _trait_codes('Primary.Lifestyle', lifestyle_map),
        _trait_codes('Trophic.Level', trophic_map),
    )
    packed = np.column_stack(columns).flatten()
    return {"points": packed.tolist()}
@app.post("/api/query_coordinates")
def query_coordinates(params: dict):
    """Return the filtered observations of a mode as one flat numeric list.

    Recognised ``params`` keys: ``explorerMode``, ``order``, ``family``,
    ``species`` (str or list), ``q``, ``continent``, ``country``,
    ``hemisphere`` ('Northern' / 'Southern'), ``lifestyle``, ``trophic``.

    Each matching observation contributes nine consecutive values: latitude,
    longitude, elevation, row index, year, temperature, NDVI, lifestyle code,
    trophic code.

    Raises:
        HTTPException: 500 when the dataset for the mode is missing or empty.
    """
    explorer_mode = params.get('explorerMode', 'birds')
    df = get_data(explorer_mode)
    if df is None or df.empty:
        raise HTTPException(status_code=500, detail=f"Dataset for mode '{explorer_mode}' not loaded")
    # No defensive copy: every step below builds a new frame, nothing mutates,
    # so copying the whole dataset on each request is unnecessary.
    filtered = df

    # 1. Taxonomy filters
    if params.get('order'):
        filtered = filtered[filtered['order'] == params['order']]
    if params.get('family'):
        filtered = filtered[filtered['family'] == params['family']]
    if params.get('species'):
        sp = params['species']
        if isinstance(sp, list) and sp:
            filtered = filtered[filtered['name'].isin(sp)]
        elif isinstance(sp, str) and sp:
            filtered = filtered[filtered['name'] == sp]

    # Apply global text search query
    if params.get('q'):
        q = str(params['q']).strip().lower()
        if q:
            final_mask = None
            for col in ('name', 'common_name', 'genus', 'family', 'order', 'country', 'continent'):
                if col not in filtered.columns:
                    continue
                # regex=False: the query is user input; characters such as "("
                # or "[" must not be interpreted as a regular expression.
                hit = filtered[col].str.lower().str.contains(q, na=False, regex=False)
                final_mask = hit if final_mask is None else (final_mask | hit)
            if final_mask is not None:
                filtered = filtered[final_mask]

    # 2. Location filters
    if params.get('continent'):
        filtered = filtered[filtered['continent'] == params['continent']]
    if params.get('country'):
        filtered = filtered[filtered['country'] == params['country']]
    if params.get('hemisphere'):
        if params['hemisphere'] == 'Northern':
            filtered = filtered[filtered['latitude'] >= 0]
        elif params['hemisphere'] == 'Southern':
            filtered = filtered[filtered['latitude'] < 0]

    # 3. Categorical trait filters
    if params.get('lifestyle'):
        filtered = filtered[filtered['Primary.Lifestyle'] == params['lifestyle']]
    if params.get('trophic'):
        filtered = filtered[filtered['Trophic.Level'] == params['trophic']]

    lifestyle_map, trophic_map = get_trait_maps(explorer_mode)
    lat = filtered['latitude'].fillna(0.0).values
    lon = filtered['longitude'].fillna(0.0).values
    elev = filtered['elevation'].fillna(0.0).values
    year = filtered['year'].fillna(2019).values.astype(int)
    idx = filtered.index.values.astype(int)
    # Values below -1000 are replaced by the same default used for nulls.
    temp = filtered['temperature_2m'].fillna(286.2).values
    temp = np.where(temp < -1000, 286.2, temp)
    ndvi = filtered['ndvi'].fillna(0.0).values
    ndvi = np.where(ndvi < -1000, 0.0, ndvi)
    # Trait labels become integer codes; 0 marks a missing/unmapped label.
    lifestyle = filtered['Primary.Lifestyle'].astype(object).map(lifestyle_map).fillna(0).values.astype(int)
    trophic = filtered['Trophic.Level'].astype(object).map(trophic_map).fillna(0).values.astype(int)
    flat = np.column_stack((lat, lon, elev, idx, year, temp, ndvi, lifestyle, trophic)).flatten()
    return {"points": flat.tolist()}
# Lookup from integer land-cover class code to a display label, used by map_landcover().
# Codes missing from this table are reported as "Unknown" by map_landcover().
# NOTE(review): the codes look like the Copernicus Global Land Cover discrete
# classification — confirm the data source before relying on that.
landcover_map = {
0: "Unknown",
20: "Shrubs",
30: "Herbaceous vegetation",
40: "Cultivated / Agriculture",
50: "Urban / Built-up",
60: "Bare / Sparse vegetation",
70: "Snow / Ice",
80: "Permanent water",
90: "Herbaceous wetland",
100: "Moss / Lichen",
111: "Closed Forest (Needleleaf)",
112: "Closed Forest (Broadleaf)",
113: "Closed Forest (Deciduous Needleleaf)",
114: "Closed Forest (Deciduous Broadleaf)",
115: "Closed Forest (Mixed)",
116: "Closed Forest (Unknown)",
121: "Open Forest (Needleleaf)",
122: "Open Forest (Broadleaf)",
123: "Open Forest (Deciduous Needleleaf)",
124: "Open Forest (Deciduous Broadleaf)",
125: "Open Forest (Mixed)",
126: "Open Forest (Unknown)",
200: "Ocean / Open Water"
}
def map_landcover(val):
    """Translate a land-cover class code into its display label.

    Args:
        val: Class code as a number or numeric string; may be missing.

    Returns:
        str: The label from ``landcover_map``, or ``"Unknown"`` when ``val`` is
        missing, not numeric, infinite, or not a known code.
    """
    if pd.isna(val):
        return "Unknown"
    try:
        code = int(float(val))
    except (ValueError, TypeError, OverflowError):
        # OverflowError: int(float('inf')) — an infinite value is not a class code.
        return "Unknown"
    return landcover_map.get(code, "Unknown")
@app.get("/api/details/{idx}")
def get_details(idx: int, mode: str = "birds"):
    """Return the detail record of the observation whose index label is ``idx``.

    Raises:
        HTTPException: 500 when the dataset for ``mode`` is missing or empty,
            404 when ``idx`` is not an index label of that dataset.
    """
    frame = get_data(mode)
    if frame is None or frame.empty:
        raise HTTPException(status_code=500, detail=f"Dataset for mode '{mode}' not loaded")
    if idx not in frame.index:
        raise HTTPException(status_code=404, detail=f"Specimen details not found for index {idx} in {mode}")

    record = frame.loc[idx]
    # A duplicated index label yields a DataFrame; use its first row.
    if isinstance(record, pd.DataFrame):
        record = record.iloc[0]

    # (response key, source column, cleaner), in response order.
    field_spec = (
        ("name", 'name', sanitize_str),
        ("common_name", 'common_name', sanitize_str),
        ("lifestyle", 'Primary.Lifestyle', sanitize_str),
        ("trophic", 'Trophic.Level', sanitize_str),
        ("niche", 'Trophic.Niche', sanitize_str),
        ("mass", 'Mass', sanitize_float),
        ("hwi", 'Hand-Wing.Index', sanitize_float),
        ("temp", 'temperature_2m', sanitize_float),
        ("elev", 'elevation', sanitize_float),
        ("ndvi", 'ndvi', sanitize_float),
        ("file_name", 'file_name', sanitize_str),
        # Additional metadata
        ("order", 'order', sanitize_str),
        ("family", 'family', sanitize_str),
        ("genus", 'genus', sanitize_str),
        ("beak_length", 'Beak.Length_Culmen', sanitize_float),
        ("tarsus_length", 'Tarsus.Length', sanitize_float),
        ("landcover", 'landcover_class', map_landcover),
        ("precipitation", 'total_precipitation', sanitize_float),
        ("date", 'date', sanitize_str),
        ("split", 'split', sanitize_str),
    )
    return {key: clean(record.get(column)) for key, column, clean in field_spec}
@app.get("/api/similarity/{idx}")
def get_similarity_endpoint(idx: int, model: str = Query('DINOv3'), mode: str = Query('birds')):
    """Return an anchor observation and up to five precomputed nearest matches.

    The anchor is looked up in the validation-alignment frame by file name; if
    it is not there, the first validation row of the same species is used.

    Raises:
        HTTPException: 500 when the dataset or alignment frame is missing or
            empty; 404 when ``idx`` is unknown or no validation row can stand
            in for the anchor; 400 when the cache has no entry for ``model``.
    """
    frame = get_data(mode)
    aligned = get_val_aligned(mode)
    similarity_cache = get_similarity(mode)
    if frame is None or frame.empty or aligned is None or aligned.empty:
        raise HTTPException(status_code=500, detail="Data alignment cache not loaded")
    if idx not in frame.index:
        raise HTTPException(status_code=404, detail="Anchor specimen index not found")

    def _core_traits(rec):
        return {
            "lifestyle": sanitize_str(rec.get('Primary.Lifestyle')),
            "trophic": sanitize_str(rec.get('Trophic.Level')),
            "niche": sanitize_str(rec.get('Trophic.Niche')),
            "mass": sanitize_float(rec.get('Mass')),
            "hwi": sanitize_float(rec.get('Hand-Wing.Index')),
            "temp": sanitize_float(rec.get('temperature_2m')),
            "elev": sanitize_float(rec.get('elevation')),
            "ndvi": sanitize_float(rec.get('ndvi')),
        }

    def _extra_metadata(rec):
        return {
            "order": sanitize_str(rec.get('order')),
            "family": sanitize_str(rec.get('family')),
            "genus": sanitize_str(rec.get('genus')),
            "beak_length": sanitize_float(rec.get('Beak.Length_Culmen')),
            "tarsus_length": sanitize_float(rec.get('Tarsus.Length')),
            "landcover": map_landcover(rec.get('landcover_class')),
            "precipitation": sanitize_float(rec.get('total_precipitation')),
            "date": sanitize_str(rec.get('date')),
            "split": sanitize_str(rec.get('split')),
        }

    # Get anchor specimen filename key
    anchor = frame.loc[idx]
    if isinstance(anchor, pd.DataFrame):
        anchor = anchor.iloc[0]
    anchor_file = anchor.get('file_name')

    # 1. Match anchor in validation set filenames
    candidates = aligned[aligned['file_name_key'] == anchor_file]
    if candidates.empty:
        # Fallback: Find first validation specimen of the same species
        candidates = aligned[aligned['name'] == anchor.get('name')]
    if candidates.empty:
        raise HTTPException(
            status_code=404,
            detail="This specimen is from the train split, and no representative validation specimens exist for its species."
        )
    lookup_key = candidates.iloc[0]['file_name_key']

    # 2. Get precomputed similarity scores for this model
    if model not in similarity_cache:
        raise HTTPException(status_code=400, detail=f"Similarity cache not found for model: {model}")
    ranked = similarity_cache[model].get(lookup_key, [])

    # Take top 5 matching specimens
    matches = []
    for aligned_idx, score in ranked:
        if aligned_idx in aligned.index:
            hit = aligned.loc[aligned_idx]
            original_index = hit.get('original_index')
            matches.append({
                "original_index": -1 if pd.isna(original_index) else int(hit.get('original_index', -1)),
                "name": sanitize_str(hit['name']),
                "common_name": sanitize_str(hit.get('common_name')),
                **_core_traits(hit),
                "file_name": sanitize_str(hit['file_name_key']),
                "similarity": float(score),
                # Additional metadata
                **_extra_metadata(hit),
            })
            if len(matches) >= 5:
                break

    return {
        "anchor": {
            "name": sanitize_str(anchor.get('name')),
            "common_name": sanitize_str(anchor.get('common_name')),
            "file_name": sanitize_str(anchor_file),
            **_core_traits(anchor),
            **_extra_metadata(anchor),
        },
        "matches": matches
    }
# Serve main explorer dashboard
@app.get("/")
def read_root():
    """Serve ``index.html`` from the static directory."""
    from fastapi.responses import FileResponse

    index_page = os.path.join(STATIC_DIR, "index.html")
    return FileResponse(index_page)
if __name__ == "__main__":
    # Listen on all interfaces; the port comes from PORT (default 8000).
    uvicorn.run(
        "server:app",
        host="0.0.0.0",
        port=int(os.environ.get("PORT", 8000)),
        reload=False,
    )