Net Zhang
Fix path resolution: use hf_hub_download to find cached artifacts
c66c0be
"""FastAPI service for the BioCLIP Image Search API.
A thin HTTP layer over the existing ``SearchService`` from ``src/bioclip_lite/``.
Designed to deploy as a separate Hugging Face Space (Docker SDK) alongside the
Gradio demo Space; both Spaces share the same model repo of preloaded artifacts
(FAISS index + DuckDB metadata).
Run locally:
# paths point at the production artifacts
FAISS_INDEX_PATH=/fs/ess/PAS2136/TreeOfLife/bioclip_image_search/TreeOfLife-200M/faiss/index.index \
DUCKDB_PATH=/fs/ess/PAS2136/TreeOfLife/bioclip_image_search/TreeOfLife-200M/duckdb/metadata.duckdb \
uvicorn api_app:app --host 0.0.0.0 --port 7860
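On Hugging Face Spaces (Docker SDK), ``preload_from_hub`` downloads the same files
into the Hugging Face Hub cache at build time; with ``FAISS_INDEX_PATH`` /
``DUCKDB_PATH`` unset, the service resolves them from ``DATA_REPO_ID`` via
``hf_hub_download``, which returns the cached copies without re-downloading.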
HTTP Status Codes:
- 1xx: informational
- 2xx: success
- 3xx: redirection
- 4xx: client error (malformed request, invalid embedding, etc.)
- 5xx: server error (unhandled exception, etc.)
Codes used in this API:
- 200 OK: successful response to /v1/healthz and /v1/search/embedding
- 400 Bad Request: malformed request (invalid JSON, octet-stream body of the wrong length, or invalid embedding values such as NaN/Inf or a zero norm)
- 422 Unprocessable Entity: JSON body failed schema validation (missing fields, wrong types, etc.)
- 415 Unsupported Media Type: content-type is not application/json or application/octet-stream
"""
from __future__ import annotations
import logging
import os
from contextlib import asynccontextmanager
from typing import Literal, Optional
import numpy as np
from fastapi import FastAPI, HTTPException, Query, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, ConfigDict, Field, ValidationError
from bioclip_lite.config import LiteConfig
from bioclip_lite.services.search_service import SearchService
# Root logging setup: INFO level with a timestamped "[LEVEL] logger: msg"
# layout. basicConfig is a no-op if the root logger already has handlers.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    level=logging.INFO,
)
# Module-scoped logger used throughout this service.
logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------
# Configuration via environment variables
# ----------------------------------------------------------------------
#
# Where the FAISS index and DuckDB catalog come from (checked in order):
#   1. FAISS_INDEX_PATH and DUCKDB_PATH, when both are set and both files
#      exist on disk. This is the local-dev route (cardinal-login runs
#      against absolute paths on shared scratch).
#   2. Otherwise, `huggingface_hub.hf_hub_download` against DATA_REPO_ID.
#      On HF Spaces the README.md YAML declares `preload_from_hub`, which
#      fetches the files into the HF Hub cache at build time, so the
#      download call simply returns the already-cached local path.
DATA_REPO_ID = os.getenv("DATA_REPO_ID", "imageomics/bioclip-image-search-api")
FAISS_INDEX_PATH = os.getenv("FAISS_INDEX_PATH")  # None when unset
DUCKDB_PATH = os.getenv("DUCKDB_PATH")  # None when unset

# Integer tuning knobs; a non-integer value fails at import time.
NPROBE = int(os.getenv("NPROBE", "16"))
OVER_FETCH_FACTOR = int(os.getenv("OVER_FETCH_FACTOR", "3"))
def _resolve_data_paths() -> tuple[str, str]:
    """Return ``(faiss_index_path, duckdb_path)``, using env-var paths or HF cache.

    Resolution order:

    1. ``FAISS_INDEX_PATH`` and ``DUCKDB_PATH`` are used as-is when both are
       set and both exist on disk.
    2. Otherwise both files are resolved with ``hf_hub_download`` from
       ``DATA_REPO_ID`` (``faiss/index.index`` and
       ``duckdb/metadata.duckdb``); an already-cached copy is returned
       without re-downloading.

    If an explicit path was configured but the pair cannot be used (one of
    them unset, or a file missing), a warning is logged before falling back.
    Otherwise a typo in a local path would silently trigger a potentially
    large Hub download.
    """
    if (
        FAISS_INDEX_PATH and DUCKDB_PATH
        and os.path.exists(FAISS_INDEX_PATH) and os.path.exists(DUCKDB_PATH)
    ):
        logger.info(
            "Using explicit paths from env vars:\n"
            f" FAISS : {FAISS_INDEX_PATH}\n"
            f" DuckDB : {DUCKDB_PATH}"
        )
        return FAISS_INDEX_PATH, DUCKDB_PATH

    if FAISS_INDEX_PATH or DUCKDB_PATH:
        # At least one explicit path was given but the pair is unusable;
        # surface exactly why before the fallback kicks in.
        for name, path in (
            ("FAISS_INDEX_PATH", FAISS_INDEX_PATH),
            ("DUCKDB_PATH", DUCKDB_PATH),
        ):
            if not path:
                logger.warning(
                    "%s is unset; explicit paths are only used when both are set",
                    name,
                )
            elif not os.path.exists(path):
                logger.warning("%s=%r does not exist", name, path)
        logger.warning(
            "Ignoring explicit data paths; falling back to HF Hub repo %s",
            DATA_REPO_ID,
        )

    # Defer the import so local-dev runs don't need huggingface_hub on disk.
    from huggingface_hub import hf_hub_download

    logger.info(f"Resolving FAISS + DuckDB paths from HF Hub repo {DATA_REPO_ID}")
    faiss_path = hf_hub_download(repo_id=DATA_REPO_ID, filename="faiss/index.index")
    duckdb_path = hf_hub_download(repo_id=DATA_REPO_ID, filename="duckdb/metadata.duckdb")
    logger.info(f" FAISS : {faiss_path}")
    logger.info(f" DuckDB : {duckdb_path}")
    return faiss_path, duckdb_path
# Wire-format constants for the query embedding.
EMBEDDING_DIM = 768  # BioCLIP 2 ViT-L/14 output dimension
EMBEDDING_BYTES = 4 * EMBEDDING_DIM  # float32 payload: 768 * 4 = 3072 bytes

# Allowed values for the ``scope`` result-set filter.
Scope = Literal[
    "all",
    "url_only",
    "inaturalist",
    "bioclip2_training",
]
# ----------------------------------------------------------------------
# Lifespan: load FAISS + DuckDB once at startup, release at shutdown
# ----------------------------------------------------------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load SearchService at boot; release the DuckDB connection on shutdown.

    Loading the FAISS index from disk takes ~30 seconds (5.4 GiB read +
    mmap warm-up). The first ``/v1/search/embedding`` call after a cold
    start therefore waits on this lifespan to finish; subsequent calls
    reuse the already-loaded index. HF Spaces wake/cold-start probes should
    target ``/v1/healthz``, which returns 200 as soon as the lifespan completes.

    The shutdown half runs in a ``finally`` block. That way the DuckDB
    connection is closed even when an exception (e.g. a cancellation) is
    thrown back into this generator at the ``yield``; without it, the
    cleanup code after ``yield`` would be skipped.

    Args:
        app: The FastAPI application. The loaded service is stored on
            ``app.state.search`` for the endpoints to use.
    """
    logger.info("BioCLIP Image Search API starting up")
    faiss_index_path, duckdb_path = _resolve_data_paths()
    logger.info(
        f" nprobe : {NPROBE}\n"
        f" over_fetch : {OVER_FETCH_FACTOR}x"
    )
    search = SearchService(
        faiss_index_path=faiss_index_path,
        duckdb_path=duckdb_path,
        nprobe=NPROBE,
        over_fetch_factor=OVER_FETCH_FACTOR,
        metadata_columns=LiteConfig.METADATA_COLUMNS,
    )
    app.state.search = search
    logger.info(
        f"SearchService ready: ntotal={search.index.ntotal:,}, "
        f"dim={search.index.d}, nprobe={search.index.nprobe}"
    )
    try:
        yield
    finally:
        logger.info("BioCLIP Image Search API shutting down")
        try:
            search.conn.close()
        except Exception:  # pragma: no cover - best-effort cleanup
            logger.exception("ignoring DuckDB close error during shutdown")
app = FastAPI(
    lifespan=lifespan,
    title="BioCLIP Image Search API",
    version="0.1.0",
    description=(
        "HTTP API over the BioCLIP 2 + FAISS + DuckDB image search pipeline. "
        "Clients encode images locally (via the FP16 ONNX visual tower) "
        "and POST embeddings to /v1/search/embedding. "
        "See the project README for the full architecture sketch."
    ),
)

# CORS: the API is public, so every web origin may call it. With a "*"
# origin the CORS spec disallows credentialed requests, which is what we
# want: the service is stateless and currently unauthenticated (auth and
# rate limiting are tracked separately). Preflight results are cached for
# a day (86400 s).
app.add_middleware(
    CORSMiddleware,
    max_age=86400,
    allow_headers=["Content-Type"],
    allow_methods=["GET", "POST"],
    allow_origins=["*"],
)
# ----------------------------------------------------------------------
# Request / response models
# ----------------------------------------------------------------------
class SearchByEmbeddingRequest(BaseModel):
    """JSON request document for ``POST /v1/search/embedding``.

    Unknown keys are rejected (``extra="forbid"``). The ``k``/``nprobe``
    bounds and defaults mirror the query parameters used on the
    octet-stream path. The ``nprobe`` default is the literal 16 and does
    not follow the ``NPROBE`` environment variable.
    """

    model_config = ConfigDict(extra="forbid")

    # Required: exactly EMBEDDING_DIM floats (no client-side normalization needed).
    embedding: list[float] = Field(
        min_length=EMBEDDING_DIM,
        max_length=EMBEDDING_DIM,
        description=f"{EMBEDDING_DIM}-dim BioCLIP 2 visual embedding. The server "
        "L2-normalizes it before searching, so any non-zero norm is accepted.",
    )
    k: int = Field(10, ge=1, le=200, description="Number of results to return.")
    nprobe: int = Field(16, ge=1, le=128, description="IVF cells to probe per query.")
    scope: Scope = Field("all", description="Result-set scope filter.")
class Taxonomy(BaseModel):
    """Taxonomic classification attached to a single search hit.

    ``class`` is a Python keyword, so the attribute is named ``class_`` and
    carries the alias ``class`` (the name used in the JSON response).
    ``populate_by_name`` allows construction with either spelling.
    """

    model_config = ConfigDict(populate_by_name=True)

    # Ranks, broadest first; every rank may be missing.
    kingdom: Optional[str] = None
    phylum: Optional[str] = None
    class_: Optional[str] = Field(None, alias="class")
    order: Optional[str] = None
    family: Optional[str] = None
    genus: Optional[str] = None
    species: Optional[str] = None
    # Vernacular name, when available.
    common_name: Optional[str] = None
class Source(BaseModel):
    """Where a search hit's image came from; every field is optional."""

    # Originating collection, e.g. gbif / eol / bioscan / fathomnet.
    dataset: Optional[str] = None
    # Identifier of the record within that collection.
    source_id: Optional[str] = None
    publisher: Optional[str] = None
    img_type: Optional[str] = None
    basis_of_record: Optional[str] = None
class SearchResult(BaseModel):
    """A single hit in a search response.

    None of the fields has a default, so all must be supplied;
    ``image_url`` may be ``None``.
    """

    uuid: str
    faiss_id: int  # row id in the FAISS index
    distance: float
    image_url: Optional[str]  # required but nullable
    has_url: bool
    taxonomy: Taxonomy
    source: Source
    in_bioclip2_training: bool
class SearchResponse(BaseModel):
    """Response envelope for ``POST /v1/search/embedding``: a list of hits."""

    results: list[SearchResult]  # one entry per returned hit
# ----------------------------------------------------------------------
# Endpoints
# ----------------------------------------------------------------------
@app.get("/v1/healthz")
async def healthz() -> dict:
    """Report liveness plus the size of the loaded data.

    The payload carries the FAISS vector count (``ntotal``) and the DuckDB
    ``metadata`` row count (``row_count``). HF Space cold-start probes and
    client wake calls can use them to confirm the data loaded. The two
    counts are expected to be equal, because the index is 1:1 with the
    catalog by design.
    """
    svc: SearchService = app.state.search
    (rows,) = svc.conn.execute("SELECT COUNT(*) FROM metadata").fetchone()
    index = svc.index

    faiss_info = {
        "ntotal": int(index.ntotal),
        "nprobe": int(index.nprobe),
        "dim": int(index.d),
    }
    duckdb_info = {"row_count": int(rows)}

    return {
        "status": "ok",
        "service": app.title,
        "version": app.version,
        "faiss": faiss_info,
        "duckdb": duckdb_info,
        "device": "cpu",
    }
def _result_to_response(r: dict) -> SearchResult:
    """Map one raw ``SearchService`` hit onto the ``SearchResult`` API model.

    ``uuid``, ``id`` and ``distance`` are read with ``[]`` and must be
    present (``KeyError`` otherwise). Every other key is optional: missing
    values become ``None``, or ``False`` for the ``has_url`` /
    ``in_bioclip2_training`` flags.

    Raw keys that are renamed in the response:
    ``id`` -> ``faiss_id``, ``identifier`` -> ``image_url``,
    ``source_dataset`` -> ``source.dataset``,
    ``basisOfRecord`` -> ``source.basis_of_record``.
    """
    row_uuid = str(r["uuid"])
    faiss_id = int(r["id"])
    distance = float(r["distance"])
    image_url = r.get("identifier")
    has_url = bool(r.get("has_url", False))

    # Raw rank keys equal the Taxonomy field names, except "class", which is
    # the model's alias for ``class_``.
    rank_keys = (
        "kingdom", "phylum", "class", "order",
        "family", "genus", "species", "common_name",
    )
    taxonomy = Taxonomy(**{key: r.get(key) for key in rank_keys})

    # (response field, raw key) pairs for the provenance block.
    source_keys = (
        ("dataset", "source_dataset"),
        ("source_id", "source_id"),
        ("publisher", "publisher"),
        ("img_type", "img_type"),
        ("basis_of_record", "basisOfRecord"),
    )
    source = Source(**{field: r.get(key) for field, key in source_keys})

    return SearchResult(
        uuid=row_uuid,
        faiss_id=faiss_id,
        distance=distance,
        image_url=image_url,
        has_url=has_url,
        taxonomy=taxonomy,
        source=source,
        in_bioclip2_training=bool(r.get("in_bioclip2_training", False)),
    )
@app.post("/v1/search/embedding", response_model=SearchResponse)
async def search_by_embedding(
    request: Request,
    # Query params are only consulted on the octet-stream path; the JSON
    # body carries its own k/nprobe/scope. Defaults match the JSON model.
    k: int = Query(default=10, ge=1, le=200),
    nprobe: int = Query(default=16, ge=1, le=128),
    scope: Scope = Query(default="all"),
) -> SearchResponse:
    """Search by a pre-computed 768-dim BioCLIP 2 embedding.

    Two content types are accepted:

    * ``application/json``: body is ``SearchByEmbeddingRequest`` (carries
      embedding + k + nprobe + scope as one document). A missing
      Content-Type header is treated as JSON.
    * ``application/octet-stream``: body is the raw little-endian float32
      embedding (3072 bytes = 768 * 4); k, nprobe, scope come from the
      query string.

    The server L2-normalizes the embedding inside ``SearchService`` (via
    ``faiss.normalize_L2``) regardless of incoming norm, so clients can
    send either raw model outputs or pre-normalized vectors.

    Raises:
        HTTPException(400): wrong octet-stream length, unparseable JSON,
            NaN/Inf values, zero norm, or a norm too large for float32.
        HTTPException(415): any other content type.
        HTTPException(422): the JSON body fails schema validation.
    """
    content_type = request.headers.get("content-type", "").lower().split(";")[0].strip()
    if content_type == "application/octet-stream":
        raw = await request.body()
        if len(raw) != EMBEDDING_BYTES:
            raise HTTPException(
                400,
                f"octet-stream body must be exactly {EMBEDDING_BYTES} bytes "
                f"({EMBEDDING_DIM} float32), got {len(raw)}",
            )
        # Decode explicitly as little-endian float32 so the wire format does
        # not depend on host byte order; astype() also returns a writable copy.
        embedding = np.frombuffer(raw, dtype="<f4").astype(np.float32)
        # k, nprobe, scope already populated from query params
    elif content_type in ("application/json", ""):
        # Empty content-type happens when the client forgets to set it; we
        # try to parse as JSON before giving up.
        try:
            body = await request.json()
        except Exception as e:
            raise HTTPException(400, detail=f"invalid JSON body: {e}") from e
        # Validation is kept out of the JSON try-block so unexpected errors
        # are not misreported to the client as "invalid JSON".
        try:
            req = SearchByEmbeddingRequest.model_validate(body)
        except ValidationError as e:
            raise HTTPException(422, detail=e.errors()) from e
        embedding = np.asarray(req.embedding, dtype=np.float32)
        k, nprobe, scope = req.k, req.nprobe, req.scope
    else:
        raise HTTPException(415, f"unsupported content-type: {content_type!r}")

    if not np.isfinite(embedding).all():
        raise HTTPException(400, "embedding contains NaN or Inf")
    # np.linalg.norm on float32 input is computed in float32: large but finite
    # components can overflow the norm to inf. L2 normalization would then
    # produce a meaningless vector, so reject that case too.
    norm = float(np.linalg.norm(embedding))
    if norm == 0.0:
        raise HTTPException(400, "embedding has zero norm")
    if not np.isfinite(norm):
        raise HTTPException(
            400, "embedding norm is not finite in float32 (values too large); rescale the vector"
        )

    # NOTE(review): this is a synchronous call made directly on the event
    # loop, so it blocks other requests while it runs. Presumably deliberate,
    # since it keeps calls into the shared SearchService (and its DuckDB
    # connection) from running concurrently. Confirm before moving it to a
    # threadpool.
    raw_results = app.state.search.search(
        query_vector=embedding,
        top_n=k,
        nprobe=nprobe,
        scope=scope,
    )
    return SearchResponse(results=[_result_to_response(r) for r in raw_results])