# microbe-model / api/main.py
# Miyu Horiuchi
# Deploy app from main@a3254bf (no paper/ binaries)
# 0ed74db
"""FastAPI backend for microbe-model.

Endpoints:
    GET  /api/health       — sanity check
    GET  /api/catalog      — full uncultured catalog as JSON (cached, served gzip-encoded by middleware)
    GET  /api/ncbi-search  — resolve an organism name (?q=...) to NCBI assembly accessions
    POST /api/predict      — { target: "GCA_..." | "Thermus thermophilus" | ">my_fasta\\nACGT..." }
                             → predicted phenotypes + ranked media

Static React build is mounted at /. Run:

    uvicorn api.main:app --host 0.0.0.0 --port 7860
"""
from __future__ import annotations

import json
import math
import os
import re
import sys
import time
from functools import lru_cache
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Any

import pandas as pd
import requests
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.responses import FileResponse, Response
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
# Repository root: two directories up from this file (api/main.py -> repo root).
ROOT = Path(__file__).resolve().parent.parent
# Put the repo root and scripts/ on sys.path so the in-repo imports below
# (`microbe_model`, `recommend`) resolve when launched via uvicorn. Each insert
# goes to position 0, so scripts/ ends up ahead of ROOT in the search order.
sys.path.insert(0, str(ROOT))
sys.path.insert(0, str(ROOT / "scripts"))
from microbe_model import config # noqa: E402
from microbe_model.train.media_recommender import load_models # noqa: E402
from recommend import ( # noqa: E402
_format_recipe_summary,
_load_genome_features,
_predict_phenotypes,
)
# Base URL for NCBI E-utilities requests made by _eutils_get().
EUTILS_BASE = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
app = FastAPI(title="microbe-model API", version="2.0.0")
# Compress responses of at least 1024 bytes (notably the large /api/catalog payload).
app.add_middleware(GZipMiddleware, minimum_size=1024)
# Allow cross-origin GET/POST from any origin with any request headers.
# NOTE(review): later add_middleware() calls wrap earlier ones in Starlette, so
# CORS currently sits outside GZip — confirm before reordering these two calls.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
)
# ──────────────────────────────────────────────────────────────────────
# Cached resources (loaded once at startup)
# ──────────────────────────────────────────────────────────────────────
# Process-wide resource cache, filled by _load_resources() at startup with the
# keys "models", "feature_cols", "recipes", "name_by_id", "catalog" and
# "catalog_source".
_state: dict[str, Any] = {}
def _phylum(tax: str | None) -> str:
if not isinstance(tax, str):
return "β€”"
for part in tax.split(";"):
part = part.strip()
if part.startswith("p__"):
return part[3:] or "β€”"
return "β€”"
def _load_catalog_frame() -> tuple[pd.DataFrame, str]:
    """Read the deploy catalog, overlaying hybrid phenotype predictions when present.

    Reads ``uncultured_predictions.parquet`` from ``config.ARTIFACTS``. If
    ``hybrid_predictions.parquet`` also exists there and has a ``genome_accession``
    column plus at least one ``pred_*`` column, it is left-joined on
    ``genome_accession``: for ``pred_*`` columns present in both files a non-null
    hybrid value wins and a null one falls back to the tabular value; hybrid-only
    ``pred_*`` columns are added unchanged.

    Returns:
        ``(frame, source)``. ``source`` is ``"hybrid"`` if the overlay was applied,
        otherwise ``"tabular"``. ``frame`` always carries a
        ``pred_oxygen_requirement_source`` column; on the hybrid path its nulls are
        filled with ``"tabular"``.
    """
    artifacts = config.ARTIFACTS
    source_col = "pred_oxygen_requirement_source"

    base = pd.read_parquet(artifacts / "uncultured_predictions.parquet")
    if source_col not in base.columns:
        base[source_col] = "tabular"

    hybrid_path = artifacts / "hybrid_predictions.parquet"
    if not hybrid_path.exists():
        return base, "tabular"

    hybrid = pd.read_parquet(hybrid_path)
    if "genome_accession" not in hybrid.columns:
        print(f" ! ignoring {hybrid_path}: missing genome_accession")
        return base, "tabular"
    pred_cols = [name for name in hybrid.columns if name.startswith("pred_")]
    if not pred_cols:
        print(f" ! ignoring {hybrid_path}: no pred_* columns")
        return base, "tabular"

    overlay = hybrid[["genome_accession", *pred_cols]].drop_duplicates("genome_accession")
    merged = base.merge(overlay, on="genome_accession", how="left", suffixes=("", "_hybrid"))

    # pred_* columns present in both frames come out of the merge as "<col>_hybrid".
    # Prefer the hybrid value, fall back to the tabular one where it is missing, then
    # drop the helper column. Hybrid-only columns already sit under their own name.
    overlapping = [
        col for col in pred_cols
        if col in base.columns and f"{col}_hybrid" in merged.columns
    ]
    for col in overlapping:
        merged[col] = merged[f"{col}_hybrid"].combine_first(merged[col])
    merged = merged.drop(columns=[f"{col}_hybrid" for col in overlapping])

    if source_col not in merged.columns:
        merged[source_col] = "tabular"
    merged[source_col] = merged[source_col].fillna("tabular")
    return merged, "hybrid"
def _tag_prediction_sources(phenotypes: dict[str, Any]) -> dict[str, Any]:
"""Expose model source metadata to the UI without changing prediction values."""
for key, source in {
"optimal_temperature_c": "tabular",
"optimal_ph": "tabular",
"oxygen_requirement": "tabular",
"salt_tolerance_pct": "tabular",
}.items():
item = phenotypes.get(key)
if isinstance(item, dict):
item.setdefault("source", source)
return phenotypes
@app.on_event("startup")
def _load_resources() -> None:
    """Load models, media tables and the catalog into ``_state`` once at startup.

    Populates:
        models / feature_cols: per-medium classifiers from ``models/recommender``.
        recipes: ``media_recipes.parquet``.
        name_by_id: medium_id (as str) -> medium name, from ``media_metadata.parquet``.
        catalog: catalog frame plus derived ``phylum`` and ``truly_uncultured`` columns.
        catalog_source: ``"tabular"`` or ``"hybrid"`` (see ``_load_catalog_frame``).
    """
    print("Loading recommender models...")
    _state["models"], _state["feature_cols"] = load_models(ROOT / "models" / "recommender")
    print(f" → {len(_state['models'])} per-medium classifiers loaded")

    media_meta = pd.read_parquet(config.DATA / "media_metadata.parquet")
    _state["recipes"] = pd.read_parquet(config.DATA / "media_recipes.parquet")
    _state["name_by_id"] = dict(
        zip(media_meta["medium_id"].astype(str), media_meta["name"], strict=True)
    )

    unc, catalog_source = _load_catalog_frame()
    unc["phylum"] = unc["gtdb_taxonomy"].map(_phylum)
    # Flag rows whose NCBI organism name starts with "uncultured" (case-insensitive).
    # astype(str) guarantees the .str accessor never yields NaN (bool(NaN) is True,
    # which would mark non-string names as uncultured downstream).
    unc["truly_uncultured"] = (
        unc["ncbi_organism_name"].fillna("").astype(str).str.lower().str.startswith("uncultured")
    )
    _state["catalog"] = unc
    _state["catalog_source"] = catalog_source
    print(f" → {len(unc):,} catalog rows ({int(unc['truly_uncultured'].sum()):,} truly uncultured)")
# ──────────────────────────────────────────────────────────────────────
# Request / response schemas (pydantic models)
# ──────────────────────────────────────────────────────────────────────
class PredictRequest(BaseModel):
    """Request body for ``POST /api/predict``."""

    target: str  # accession, organism name, or FASTA string starting with > or ACGT
    # Number of highest-confidence media to include in the response's "media" list.
    top_k: int = 8
class CatalogRow(BaseModel):
    """Schema of one row in the ``/api/catalog`` response.

    NOTE(review): not referenced by any endpoint in this file — the catalog endpoint
    builds plain dicts with these same keys. Confirm whether it is used elsewhere
    or should be wired in as a response model.
    """

    accession: str  # from genome_accession
    name: str  # ncbi_organism_name, falling back to the accession
    phylum: str  # parsed from gtdb_taxonomy
    completeness: float  # from checkm_completeness
    truly_uncultured: bool  # organism name starts with "uncultured"
    T_opt: float  # from pred_optimal_temperature_c
    pH: float  # from pred_optimal_ph
    O2: str  # from pred_oxygen_requirement
    O2_conf: float  # from pred_oxygen_requirement_confidence
    O2_source: str = "tabular"  # from pred_oxygen_requirement_source
    salt: float  # from pred_salt_tolerance_pct
    # Best-ranked medium (top1_* columns).
    top_medium_id: str
    top_medium_name: str
    top_confidence: float
    # Second- and third-ranked media; None when the catalog has no value.
    top2_medium_id: str | None = None
    top2_medium_name: str | None = None
    top2_confidence: float | None = None
    top3_medium_id: str | None = None
    top3_medium_name: str | None = None
    top3_confidence: float | None = None
# ──────────────────────────────────────────────────────────────────────
# NCBI helpers
# ──────────────────────────────────────────────────────────────────────
# NCBI assembly accessions: GCA_/GCF_ followed by 9 digits, optional .version
_ACCESSION_RE = re.compile(r"^GC[AF]_\d{9}(\.\d+)?$", re.IGNORECASE)
def _looks_like_accession(target: str) -> bool:
return bool(_ACCESSION_RE.match(target.strip()))
def _eutils_get(endpoint: str, params: dict, *, retries: int = 3) -> dict:
    """GET an NCBI E-utilities endpoint and return the decoded JSON body.

    Adds ``api_key`` from the ``NCBI_API_KEY`` environment variable when it is set.
    Anonymous eutils is limited to 3 req/sec (10/sec with NCBI_API_KEY), so transient
    429s are expected under concurrent load. Those, and 5xx responses, are retried
    with exponential backoff (0.5 s, 1 s, 2 s, ...) rather than surfacing the rate
    limit to the user. No sleep happens after the final attempt.

    Args:
        endpoint: path under ``EUTILS_BASE``, e.g. ``"esearch.fcgi"``.
        params: query parameters; the caller's dict is never mutated.
        retries: total number of attempts; values below 1 are treated as 1.

    Raises:
        requests.RequestException: immediately for non-retryable failures (other
            4xx, connection errors, timeouts), or the last error once all attempts
            are exhausted.
    """
    api_key = os.environ.get("NCBI_API_KEY")
    if api_key:
        params = {**params, "api_key": api_key}
    attempts = max(1, retries)  # retries <= 0 used to end in `raise None` (TypeError)
    for attempt in range(attempts):
        try:
            r = requests.get(f"{EUTILS_BASE}/{endpoint}", params=params, timeout=20)
            r.raise_for_status()
            return r.json()
        except requests.RequestException as e:
            status = getattr(e.response, "status_code", None)
            retryable = status == 429 or (status is not None and status >= 500)
            if not retryable or attempt == attempts - 1:
                raise
            time.sleep(0.5 * (2 ** attempt))  # 0.5s, 1s, 2s, ...
    # Unreachable: every iteration either returns or raises.
    raise RuntimeError("eutils retry loop exited without a result")
@lru_cache(maxsize=512)
def _ncbi_assembly_hits_cached(q_norm: str, retmax: int) -> tuple[dict, ...]:
    """Look up NCBI assemblies for an already-normalized organism query (memoized).

    Runs esearch on the assembly db (``<q>[Organism] AND latest[filter]``), then
    esummary for the returned UIDs. Yields one ``{"accession", "organism", "level"}``
    dict per summary, stably ordered Complete Genome, Chromosome, Scaffold, Contig,
    then anything else. The result is a tuple so ``lru_cache`` can hold it; failed
    requests raise and are therefore not cached.
    """
    search = _eutils_get(
        "esearch.fcgi",
        {
            "db": "assembly",
            "term": f"{q_norm}[Organism] AND latest[filter]",
            "retmode": "json",
            "retmax": retmax,
        },
    )
    uids = search.get("esearchresult", {}).get("idlist", [])
    if not uids:
        return ()

    summary = _eutils_get(
        "esummary.fcgi",
        {"db": "assembly", "id": ",".join(uids), "retmode": "json"},
    ).get("result", {})

    level_rank = {"Complete Genome": 0, "Chromosome": 1, "Scaffold": 2, "Contig": 3}
    docs = (summary.get(uid, {}) for uid in summary.get("uids", []))
    hits = [
        {
            "accession": str(doc.get("assemblyaccession", "")),
            "organism": str(doc.get("organism", "")),
            "level": str(doc.get("assemblystatus", "")),
        }
        for doc in docs
    ]
    return tuple(sorted(hits, key=lambda hit: level_rank.get(hit["level"], 99)))
def _ncbi_assembly_hits(q: str, retmax: int = 10) -> list[dict]:
    """Resolve an organism name to NCBI assembly accessions, best (most complete) first.

    The query is normalized (runs of whitespace collapsed, lower-cased) so equivalent
    spellings share one ``_ncbi_assembly_hits_cached`` entry. A blank query returns
    ``[]`` without contacting NCBI; otherwise it would become an unconstrained
    ``[Organism]`` search.

    Returns fresh dict copies so callers cannot mutate the cached results.
    """
    q_norm = " ".join(q.split()).lower()
    if not q_norm:
        return []
    hits = _ncbi_assembly_hits_cached(q_norm, retmax)
    return [dict(h) for h in hits]
# ──────────────────────────────────────────────────────────────────────
# Endpoints
# ──────────────────────────────────────────────────────────────────────
@app.get("/api/health")
def health():
    """Sanity check: report loaded-classifier and catalog-row counts plus the catalog source."""
    loaded_models = _state.get("models", {})
    catalog_frame = _state.get("catalog", [])
    source = _state.get("catalog_source", "tabular")
    return {
        "ok": True,
        "models_loaded": len(loaded_models),
        "catalog_rows": len(catalog_frame),
        "catalog_source": source,
    }
def _safe_float(v, ndigits=3, default=0.0):
try:
f = float(v)
except (TypeError, ValueError):
return default
if pd.isna(f) or f != f:
return default
return round(f, ndigits)
def _safe_str(v, default=None):
if v is None or (isinstance(v, float) and pd.isna(v)):
return default
return str(v)
def _catalog_row(m: dict[str, Any]) -> dict[str, Any]:
    """Convert one catalog record into the JSON row shape returned by /api/catalog."""
    placeholder = "—"  # real em dash (previously mojibake that leaked to the UI)
    return {
        "accession": _safe_str(m["genome_accession"], placeholder),
        "name": _safe_str(m["ncbi_organism_name"]) or _safe_str(m["genome_accession"], placeholder),
        "phylum": _safe_str(m["phylum"], placeholder),
        "completeness": _safe_float(m["checkm_completeness"], 1),
        "truly_uncultured": bool(m["truly_uncultured"]),
        "T_opt": _safe_float(m["pred_optimal_temperature_c"], 1),
        "pH": _safe_float(m["pred_optimal_ph"], 2),
        "O2": _safe_str(m["pred_oxygen_requirement"], placeholder),
        "O2_conf": _safe_float(m.get("pred_oxygen_requirement_confidence"), 3, 0.0),
        "O2_source": _safe_str(m.get("pred_oxygen_requirement_source"), "tabular"),
        "salt": _safe_float(m["pred_salt_tolerance_pct"], 2),
        "top_medium_id": _safe_str(m["top1_medium_id"], placeholder),
        "top_medium_name": _safe_str(m["top1_medium_name"], placeholder),
        "top_confidence": _safe_float(m["top1_confidence"], 4),
        # Ranks 2 and 3 are optional: missing values become null, not 0.
        "top2_medium_id": _safe_str(m.get("top2_medium_id")),
        "top2_medium_name": _safe_str(m.get("top2_medium_name")),
        "top2_confidence": _safe_float(m.get("top2_confidence"), 4, None),
        "top3_medium_id": _safe_str(m.get("top3_medium_id")),
        "top3_medium_name": _safe_str(m.get("top3_medium_name")),
        "top3_confidence": _safe_float(m.get("top3_confidence"), 4, None),
    }


@app.get("/api/catalog")
def catalog(limit: int | None = None):
    """Return the uncultured catalog as a list of dicts. Gzipped by middleware.

    The catalog frame is fixed after startup, so its serialized rows are built once
    (keyed on the frame object, so a reload rebuilds them) and reused. ``limit``
    slices that list, which is equivalent to ``df.head(limit)``.

    Args:
        limit: optional number of leading rows to return; must be >= 1.

    Returns:
        ``{"count": int, "source": str, "rows": list[dict]}``.

    Raises:
        HTTPException 400: ``limit`` < 1.
    """
    if limit is not None and limit < 1:
        raise HTTPException(status_code=400, detail="limit must be >= 1")
    df = _state["catalog"]
    cached = _state.get("catalog_rows")
    if cached is None or cached[0] is not df:
        # to_dict("records") is far cheaper than iterrows() on a large frame.
        cached = (df, [_catalog_row(m) for m in df.to_dict("records")])
        _state["catalog_rows"] = cached
    rows = cached[1]
    if limit is not None:
        rows = rows[:limit]
    return {"count": len(rows), "source": _state.get("catalog_source", "tabular"), "rows": rows}
@app.get("/api/ncbi-search")
def ncbi_search(q: str, retmax: int = 10):
    """Resolve an organism name to NCBI assembly accessions (best assembly level first).

    Args:
        q: organism name; a blank query returns no hits without contacting NCBI.
        retmax: maximum number of esearch hits. Must be >= 1; values above 100 are
            capped so a client cannot force an oversized esummary request.

    Raises:
        HTTPException 400: ``retmax`` < 1.
        HTTPException 502: the NCBI request failed.
    """
    max_retmax = 100
    if not q.strip():
        return {"hits": []}
    if retmax < 1:
        raise HTTPException(status_code=400, detail="retmax must be >= 1")
    try:
        return {"hits": _ncbi_assembly_hits(q, min(retmax, max_retmax))}
    except requests.RequestException as e:
        raise HTTPException(status_code=502, detail=f"NCBI search failed: {e}") from e
# Characters allowed in a header-less pasted sequence: IUPAC nucleotide codes plus
# ">" and line/space separators.
_INLINE_FASTA_CHARS = frozenset("ACGTNRYKMSWBDHV>\n\r ")


def _is_inline_fasta(target: str) -> bool:
    """Return True if *target* looks like pasted FASTA rather than an accession/name."""
    if target.startswith(">"):
        return True
    # Header-less raw sequence: long, and made only of nucleotide codes/whitespace.
    return len(target) > 200 and set(target.upper()) <= _INLINE_FASTA_CHARS


def _featurize_inline_fasta(fasta: str):
    """Write pasted FASTA to a temp file, featurize it, and always delete the file."""
    tmp_path = None
    try:
        with NamedTemporaryFile(suffix=".fasta", delete=False, mode="w") as tmp:
            tmp_path = tmp.name  # recorded before writing so a failed write is cleaned up too
            tmp.write(fasta)
        return _load_genome_features(tmp_path)
    finally:
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass


def _featurize_target(target: str):
    """Featurize *target*: inline FASTA, an NCBI assembly accession, or an organism name."""
    if _is_inline_fasta(target):
        return _featurize_inline_fasta(target)
    if _looks_like_accession(target):
        return _load_genome_features(target)
    # Anything else (e.g. "Thermus thermophilus") is an organism name, resolved to
    # its best NCBI assembly accession before fetching. Server-side file paths are
    # deliberately NOT accepted: `target` is untrusted request input, and probing
    # the filesystem with it allowed reading local files (and overlong strings
    # could raise OSError -> 500).
    try:
        hits = _ncbi_assembly_hits(target, retmax=5)
    except requests.RequestException as e:
        raise HTTPException(status_code=502, detail=f"NCBI lookup failed: {e}") from e
    best = next((h["accession"] for h in hits if h["accession"]), None)
    if best is None:
        raise HTTPException(
            status_code=404,
            detail=f'No NCBI genome found for "{target}". '
            f"Try an NCBI accession (e.g. GCF_000005845.2) or a FASTA.",
        )
    return _load_genome_features(best)


def _rank_media(feats_series: pd.Series, top_k: int) -> list[dict[str, Any]]:
    """Score every per-medium classifier and return the *top_k* media, best first.

    Recipe summaries are formatted only for the returned media, not for every model.
    """
    X_pred = feats_series[_state["feature_cols"]].to_frame().T
    scored = []
    for medium_id, model in _state["models"].items():
        proba = float(model.predict_proba(X_pred)[0, 1])
        scored.append((str(medium_id), round(proba, 4)))
    # Stable sort on the rounded confidence: ties keep model-iteration order.
    scored.sort(key=lambda item: item[1], reverse=True)
    name_by_id = _state["name_by_id"]
    recipes = _state["recipes"]
    return [
        {
            "medium_id": mid,
            "name": name_by_id.get(mid, "(unknown)"),
            "confidence": conf,
            "recipe": _format_recipe_summary(mid, recipes),
        }
        for mid, conf in scored[:top_k]
    ]


@app.post("/api/predict")
def predict(req: PredictRequest):
    """Predict phenotypes and rank growth media for one genome.

    ``req.target`` may be inline FASTA, an NCBI assembly accession, or an organism
    name (resolved to its best NCBI assembly).

    Returns:
        accession, n_contigs, n_cds, gc, phenotypes, and ``media``: the ``req.top_k``
        highest-confidence media with recipe summaries.

    Raises:
        HTTPException 400: empty target or ``top_k`` < 1.
        HTTPException 404: no NCBI genome for an organism name, or the feature
            pipeline raised SystemExit (its message is passed through).
        HTTPException 502: an NCBI lookup or genome fetch failed.
    """
    target = req.target.strip()
    if not target:
        raise HTTPException(status_code=400, detail="empty target")
    if req.top_k < 1:
        raise HTTPException(status_code=400, detail="top_k must be >= 1")
    try:
        feats, accession, n_contigs = _featurize_target(target)
        feats_series = pd.Series(feats)
        phenotypes = _tag_prediction_sources(_predict_phenotypes(feats_series))
        media = _rank_media(feats_series, req.top_k)
    except SystemExit as e:
        # The recommend.py helpers may abort via SystemExit(msg); surface it as a 404.
        raise HTTPException(status_code=404, detail=str(e)) from e
    except requests.RequestException as e:
        raise HTTPException(status_code=502, detail=f"Genome fetch failed: {e}") from e
    return {
        "accession": accession,
        "n_contigs": n_contigs,
        "n_cds": int(feats["n_predicted_cds"]),
        "gc": round(float(feats["gc_content"]), 4),
        "phenotypes": phenotypes,
        "media": media,
    }
# ──────────────────────────────────────────────────────────────────────
# Static frontend (mounted last so /api/* routes win)
# ──────────────────────────────────────────────────────────────────────
WEB_BUILD = ROOT / "web" / "dist"


def _static_file(root: Path, rel: str) -> Path | None:
    """Return the file ``root / rel`` if it exists *inside* ``root``, else None.

    ``rel`` comes straight from the URL. Both ``..`` segments (``/..%2F..%2Fsecret``)
    and absolute paths (``//etc/passwd`` -> ``rel="/etc/passwd"``, which replaces
    ``root`` when joined) must not escape the build directory.
    """
    base = root.resolve()
    try:
        candidate = (base / rel).resolve()
    except (OSError, ValueError):  # e.g. embedded NUL bytes
        return None
    if not candidate.is_relative_to(base) or not candidate.is_file():
        return None
    return candidate


if WEB_BUILD.exists():
    assets_dir = WEB_BUILD / "assets"
    # StaticFiles raises at startup if the directory is missing; skip the mount then.
    if assets_dir.is_dir():
        app.mount("/assets", StaticFiles(directory=assets_dir), name="assets")

    @app.head("/")
    def root_head():
        return Response(status_code=200)

    @app.get("/")
    def root():
        return FileResponse(WEB_BUILD / "index.html")

    @app.get("/{path:path}")
    def spa(path: str):
        # SPA fallback: serve index.html for any non-API route. Unknown /api/*
        # paths get a real 404 instead of the HTML shell.
        if path == "api" or path.startswith("api/"):
            raise HTTPException(status_code=404, detail="Not Found")
        target = _static_file(WEB_BUILD, path)
        return FileResponse(target if target is not None else WEB_BUILD / "index.html")