Spaces:
Running
Running
| """FastAPI backend for microbe-model. | |
Endpoints:
    GET  /api/health   — sanity check
    GET  /api/catalog  — full uncultured catalog as JSON (cached, served gzip-encoded by middleware)
    POST /api/predict  — { target: "GCA_..." | "Thermus thermophilus" | ">my_fasta\\nACGT..." }
                       → predicted phenotypes + ranked media
| Static React build is mounted at /. Run: | |
| uvicorn api.main:app --host 0.0.0.0 --port 7860 | |
| """ | |
| from __future__ import annotations | |
import json
import math
import os
import re
import sys
import time
from functools import lru_cache
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Any
| import pandas as pd | |
| import requests | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.middleware.gzip import GZipMiddleware | |
| from fastapi.responses import FileResponse, Response | |
| from fastapi.staticfiles import StaticFiles | |
| from pydantic import BaseModel | |
| ROOT = Path(__file__).resolve().parent.parent | |
| sys.path.insert(0, str(ROOT)) | |
| sys.path.insert(0, str(ROOT / "scripts")) | |
| from microbe_model import config # noqa: E402 | |
| from microbe_model.train.media_recommender import load_models # noqa: E402 | |
| from recommend import ( # noqa: E402 | |
| _format_recipe_summary, | |
| _load_genome_features, | |
| _predict_phenotypes, | |
| ) | |
# Base URL for NCBI E-utilities; _eutils_get appends the script name (esearch.fcgi, ...).
EUTILS_BASE = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
app = FastAPI(title="microbe-model API", version="2.0.0")
# NOTE: middleware registration order matters — Starlette wraps each later-added
# middleware around the earlier ones, so keep these two calls in this order.
# Responses of at least 1024 bytes are gzip-compressed (the catalog payload relies on this).
app.add_middleware(GZipMiddleware, minimum_size=1024)
# Permissive CORS: any origin may call the GET/POST endpoints with any headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
)
# ──────────────────────────────────────────────────────────────────────
# Cached resources (loaded once at startup)
# ──────────────────────────────────────────────────────────────────────
# Process-wide resource cache, filled by _load_resources() with the keys
# "models", "feature_cols", "recipes", "name_by_id", "catalog" and
# "catalog_source"; the endpoints read from it.
_state: dict[str, Any] = {}
| def _phylum(tax: str | None) -> str: | |
| if not isinstance(tax, str): | |
| return "β" | |
| for part in tax.split(";"): | |
| part = part.strip() | |
| if part.startswith("p__"): | |
| return part[3:] or "β" | |
| return "β" | |
def _load_catalog_frame() -> tuple[pd.DataFrame, str]:
    """Read the deploy catalog, overlaying hybrid phenotype predictions when present.

    Returns ``(frame, source)``. ``source`` is ``"hybrid"`` when
    ``hybrid_predictions.parquet`` exists and is usable, otherwise ``"tabular"``.
    For every ``pred_*`` column the hybrid file shares with the base catalog, a
    non-null hybrid value replaces the base value. Hybrid-only ``pred_*`` columns
    are added as-is. ``pred_oxygen_requirement_source`` always exists in the
    result, with gaps filled as ``"tabular"``.
    """
    source_col = "pred_oxygen_requirement_source"
    catalog_df = pd.read_parquet(config.ARTIFACTS / "uncultured_predictions.parquet")
    if source_col not in catalog_df.columns:
        catalog_df[source_col] = "tabular"

    hybrid_path = config.ARTIFACTS / "hybrid_predictions.parquet"
    if not hybrid_path.exists():
        return catalog_df, "tabular"

    overlay = pd.read_parquet(hybrid_path)
    if "genome_accession" not in overlay.columns:
        print(f" ! ignoring {hybrid_path}: missing genome_accession")
        return catalog_df, "tabular"
    overlay_cols = [name for name in overlay.columns if name.startswith("pred_")]
    if not overlay_cols:
        print(f" ! ignoring {hybrid_path}: no pred_* columns")
        return catalog_df, "tabular"

    overlay = overlay[["genome_accession", *overlay_cols]].drop_duplicates("genome_accession")
    merged = catalog_df.merge(
        overlay, on="genome_accession", how="left", suffixes=("", "_hybrid")
    )
    base_cols = set(catalog_df.columns)
    for name in overlay_cols:
        if name not in base_cols:
            # Hybrid-only column: the merge already added it under its own name.
            continue
        shadow = f"{name}_hybrid"
        if shadow not in merged.columns:
            continue
        # Prefer the hybrid value; fall back to the tabular one where it is null.
        merged[name] = merged[shadow].combine_first(merged[name])
        merged = merged.drop(columns=[shadow])

    if source_col not in merged.columns:
        merged[source_col] = "tabular"
    merged[source_col] = merged[source_col].fillna("tabular")
    return merged, "hybrid"
def _tag_prediction_sources(phenotypes: dict[str, Any]) -> dict[str, Any]:
    """Label each known phenotype entry with the model ``source`` that produced it.

    Only dict-valued entries are touched, and an existing ``"source"`` key is
    never overwritten. The prediction values themselves are left alone. The input
    mapping is mutated in place and also returned for convenient chaining.
    """
    tabular_keys = (
        "optimal_temperature_c",
        "optimal_ph",
        "oxygen_requirement",
        "salt_tolerance_pct",
    )
    for entry in (phenotypes.get(key) for key in tabular_keys):
        if isinstance(entry, dict):
            entry.setdefault("source", "tabular")
    return phenotypes
def _load_resources() -> None:
    """Load recommender models, media tables and the catalog into ``_state``.

    Populates ``models``, ``feature_cols``, ``recipes``, ``name_by_id``,
    ``catalog`` (with derived ``phylum`` and ``truly_uncultured`` columns) and
    ``catalog_source``. NOTE(review): no caller is visible in this section —
    presumably registered as a startup hook; confirm.
    """
    print("Loading recommender models...")
    models, feature_cols = load_models(ROOT / "models" / "recommender")
    _state["models"], _state["feature_cols"] = models, feature_cols
    print(f" β {len(models)} per-medium classifiers loaded")

    media_meta = pd.read_parquet(config.DATA / "media_metadata.parquet")
    _state["recipes"] = pd.read_parquet(config.DATA / "media_recipes.parquet")
    medium_ids = media_meta["medium_id"].astype(str)
    _state["name_by_id"] = {
        medium_id: name
        for medium_id, name in zip(medium_ids, media_meta["name"], strict=True)
    }

    catalog_df, source = _load_catalog_frame()
    catalog_df["phylum"] = catalog_df["gtdb_taxonomy"].map(_phylum)
    organism_names = catalog_df["ncbi_organism_name"].fillna("")
    catalog_df["truly_uncultured"] = organism_names.str.lower().str.startswith("uncultured")
    _state["catalog"] = catalog_df
    _state["catalog_source"] = source

    n_uncultured = int(catalog_df["truly_uncultured"].sum())
    print(f" β {len(catalog_df):,} catalog rows ({n_uncultured:,} truly uncultured)")
# ──────────────────────────────────────────────────────────────────────
# Request / response models
# ──────────────────────────────────────────────────────────────────────
class PredictRequest(BaseModel):
    """JSON body accepted by the predict endpoint."""

    # Accession, organism name, or FASTA string starting with > or ACGT.
    target: str
    # Number of highest-confidence media returned in the response's "media" list.
    top_k: int = 8
class CatalogRow(BaseModel):
    """Schema of one row returned by the catalog endpoint.

    Field names match the per-row dict keys the catalog endpoint builds.
    NOTE(review): this model is not referenced in this section — confirm whether
    it is wired up as a response model elsewhere.
    """

    accession: str  # from genome_accession
    name: str  # ncbi_organism_name, falling back to the accession
    phylum: str  # parsed from gtdb_taxonomy
    completeness: float  # from checkm_completeness
    truly_uncultured: bool  # organism name starts with "uncultured"
    T_opt: float  # from pred_optimal_temperature_c
    pH: float  # from pred_optimal_ph
    O2: str  # from pred_oxygen_requirement
    O2_conf: float  # from pred_oxygen_requirement_confidence
    O2_source: str = "tabular"  # model family that produced O2
    salt: float  # from pred_salt_tolerance_pct
    top_medium_id: str
    top_medium_name: str
    top_confidence: float
    # Second/third-ranked media are optional (null when not available).
    top2_medium_id: str | None = None
    top2_medium_name: str | None = None
    top2_confidence: float | None = None
    top3_medium_id: str | None = None
    top3_medium_name: str | None = None
    top3_confidence: float | None = None
# ──────────────────────────────────────────────────────────────────────
# NCBI helpers
# ──────────────────────────────────────────────────────────────────────
# NCBI assembly accessions: GCA_/GCF_ followed by 9 digits, optional .version
_ACCESSION_RE = re.compile(r"^GC[AF]_\d{9}(\.\d+)?$", re.IGNORECASE)


def _looks_like_accession(target: str) -> bool:
    """Return True when ``target``, ignoring surrounding whitespace, is an assembly accession.

    Matching is case-insensitive (e.g. ``gcf_000005845.2`` is accepted).
    """
    candidate = target.strip()
    return _ACCESSION_RE.match(candidate) is not None
def _eutils_get(endpoint: str, params: dict, *, retries: int = 3) -> dict:
    """GET an NCBI E-utilities endpoint and return its decoded JSON body.

    Adds ``NCBI_API_KEY`` from the environment (if set) to the query and retries
    HTTP 429 and 5xx responses with exponential backoff (0.5 s, 1 s, 2 s, ...).
    Anonymous eutils is limited to 3 req/sec (10/sec with NCBI_API_KEY), so
    transient 429s are expected under concurrent load. Backing off and retrying
    avoids surfacing the rate limit to the user.

    Args:
        endpoint: E-utilities script name, e.g. ``"esearch.fcgi"``.
        params: query parameters. The caller's dict is not mutated.
        retries: total number of attempts. Values below 1 are treated as 1.

    Raises:
        requests.RequestException: once attempts are exhausted, or immediately for
            non-retryable failures (other 4xx responses, and errors with no HTTP
            response such as connection failures or timeouts).
    """
    api_key = os.environ.get("NCBI_API_KEY")
    if api_key:
        params = {**params, "api_key": api_key}
    url = f"{EUTILS_BASE}/{endpoint}"
    # Guard against retries <= 0, which previously ended in `raise None` (TypeError).
    attempts = max(1, retries)
    for attempt in range(attempts):
        try:
            r = requests.get(url, params=params, timeout=20)
            r.raise_for_status()
            return r.json()
        except requests.RequestException as e:
            status = getattr(e.response, "status_code", None)
            retryable = status == 429 or (status is not None and status >= 500)
            # Fail fast on non-retryable errors. Never sleep after the final
            # attempt just to re-raise (that used to add 2 s to every failure).
            if not retryable or attempt == attempts - 1:
                raise
            time.sleep(0.5 * (2 ** attempt))
    raise AssertionError("unreachable: the retry loop always returns or raises")
@lru_cache(maxsize=256)
def _ncbi_assembly_hits_cached(q_norm: str, retmax: int) -> tuple[dict, ...]:
    """Look up NCBI assemblies for an already-normalized organism query (memoized).

    Runs esearch (``<q_norm>[Organism] AND latest[filter]``, up to ``retmax`` IDs)
    and then esummary. Returns one ``{"accession", "organism", "level"}`` dict per
    hit, ordered Complete Genome, Chromosome, Scaffold, Contig, then anything
    else. The sort is stable, so NCBI's order is kept within a level.

    Successful results are memoized per ``(q_norm, retmax)`` in a *bounded*
    cache because the query text is client-controlled. Exceptions from
    ``_eutils_get`` propagate and are not cached, but empty results are. The
    tuple itself is immutable, yet its dicts are shared across calls, so callers
    must copy them before mutating.
    """
    search = _eutils_get(
        "esearch.fcgi",
        {"db": "assembly", "term": f"{q_norm}[Organism] AND latest[filter]",
         "retmode": "json", "retmax": retmax},
    )
    ids = search.get("esearchresult", {}).get("idlist", [])
    if not ids:
        return ()
    summary = _eutils_get(
        "esummary.fcgi",
        {"db": "assembly", "id": ",".join(ids), "retmode": "json"},
    )
    result = summary.get("result", {})
    hits = []
    for uid in result.get("uids", []):
        doc = result.get(uid, {})
        hits.append({
            "accession": str(doc.get("assemblyaccession", "")),
            "organism": str(doc.get("organism", "")),
            "level": str(doc.get("assemblystatus", "")),
        })
    level_rank = {"Complete Genome": 0, "Chromosome": 1, "Scaffold": 2, "Contig": 3}
    hits.sort(key=lambda hit: level_rank.get(hit["level"], 99))
    return tuple(hits)
def _ncbi_assembly_hits(q: str, retmax: int = 10) -> list[dict]:
    """Resolve an organism name to NCBI assembly hits, most complete assembly first.

    The query is trimmed and lower-cased before lookup, so trivially different
    spellings share one cached result.
    """
    normalized = q.strip().lower()
    cached = _ncbi_assembly_hits_cached(normalized, retmax)
    # Hand out copies: the cached dicts are shared and must stay pristine.
    return list(map(dict, cached))
# ──────────────────────────────────────────────────────────────────────
# Endpoints
# ──────────────────────────────────────────────────────────────────────
def health():
    """Sanity-check payload: loaded model and catalog-row counts plus the catalog source."""
    loaded_models = _state.get("models", {})
    catalog_df = _state.get("catalog", [])
    status = {"ok": True}
    status["models_loaded"] = len(loaded_models)
    status["catalog_rows"] = len(catalog_df)
    status["catalog_source"] = _state.get("catalog_source", "tabular")
    return status
def _safe_float(v, ndigits=3, default=0.0):
    """Coerce ``v`` to a float rounded to ``ndigits`` places, or return ``default``.

    ``default`` is returned when ``v`` cannot be converted (``None``, non-numeric
    strings, ints too large for a float) or the result is not finite. Non-finite
    values must be filtered here: ``NaN``/``inf`` are not valid JSON, and letting
    them through breaks serialization of the whole response.
    """
    try:
        f = float(v)
    except (TypeError, ValueError, OverflowError):
        return default
    if not math.isfinite(f):
        return default
    return round(f, ndigits)
def _safe_str(v, default=None):
    """Stringify ``v``, mapping missing values to ``default``.

    "Missing" is ``None`` or any scalar pandas treats as missing: ``float('nan')``,
    numpy NaN of any float width, ``pd.NA`` and ``pd.NaT``. Checking only for a
    Python ``float`` NaN let ``pd.NA`` and float32 NaN leak out as the literal
    strings ``"<NA>"`` / ``"nan"``.
    """
    if v is None or (pd.api.types.is_scalar(v) and pd.isna(v)):
        return default
    return str(v)
def catalog(limit: int | None = None):
    """Return the uncultured catalog as ``{"count", "source", "rows"}``. Gzipped by middleware.

    Args:
        limit: when given, only the first ``limit`` catalog rows are returned.

    Raises:
        HTTPException: 400 when ``limit`` is below 1.

    NOTE(review): ``"β"`` placeholders look like a mis-encoded dash ("—"); kept
    verbatim since the frontend may rely on the exact value.
    """
    df = _state["catalog"]
    if limit is not None:
        if limit < 1:
            raise HTTPException(status_code=400, detail="limit must be >= 1")
        df = df.head(limit)

    def optional_conf(m: dict, key: str) -> float | None:
        # Absent or missing lower-ranked confidences stay null instead of 0.0.
        v = m.get(key)
        return _safe_float(v, 4) if pd.notna(v) else None

    # to_dict("records") builds plain dicts in one pass. DataFrame.iterrows
    # allocates a Series per row and dominates the cost on a catalog-sized frame.
    rows = [
        {
            "accession": _safe_str(m["genome_accession"], "β"),
            "name": _safe_str(m["ncbi_organism_name"]) or _safe_str(m["genome_accession"], "β"),
            "phylum": _safe_str(m["phylum"], "β"),
            "completeness": _safe_float(m["checkm_completeness"], 1),
            "truly_uncultured": bool(m["truly_uncultured"]),
            "T_opt": _safe_float(m["pred_optimal_temperature_c"], 1),
            "pH": _safe_float(m["pred_optimal_ph"], 2),
            "O2": _safe_str(m["pred_oxygen_requirement"], "β"),
            "O2_conf": _safe_float(m.get("pred_oxygen_requirement_confidence"), 3, 0.0),
            "O2_source": _safe_str(m.get("pred_oxygen_requirement_source"), "tabular"),
            "salt": _safe_float(m["pred_salt_tolerance_pct"], 2),
            "top_medium_id": _safe_str(m["top1_medium_id"], "β"),
            "top_medium_name": _safe_str(m["top1_medium_name"], "β"),
            "top_confidence": _safe_float(m["top1_confidence"], 4),
            "top2_medium_id": _safe_str(m.get("top2_medium_id")),
            "top2_medium_name": _safe_str(m.get("top2_medium_name")),
            "top2_confidence": optional_conf(m, "top2_confidence"),
            "top3_medium_id": _safe_str(m.get("top3_medium_id")),
            "top3_medium_name": _safe_str(m.get("top3_medium_name")),
            "top3_confidence": optional_conf(m, "top3_confidence"),
        }
        for m in df.to_dict(orient="records")
    ]
    return {"count": len(rows), "source": _state.get("catalog_source", "tabular"), "rows": rows}
# Upper bound on hits per search. retmax is client-controlled, and every returned
# ID is forwarded to esummary (and becomes part of the lookup-cache key), so an
# unbounded value means arbitrarily large NCBI requests and cache churn.
_NCBI_SEARCH_MAX_RETMAX = 100


def ncbi_search(q: str, retmax: int = 10):
    """Resolve an organism name to NCBI assembly accessions, most complete first.

    Args:
        q: organism name. A blank query returns no hits without contacting NCBI.
        retmax: maximum number of hits, clamped to ``[1, _NCBI_SEARCH_MAX_RETMAX]``.

    Raises:
        HTTPException: 502 when the NCBI request fails.
    """
    if not q.strip():
        return {"hits": []}
    retmax = min(max(retmax, 1), _NCBI_SEARCH_MAX_RETMAX)
    try:
        return {"hits": _ncbi_assembly_hits(q, retmax)}
    except requests.RequestException as e:
        raise HTTPException(status_code=502, detail=f"NCBI search failed: {e}") from e
# Characters a header-less nucleotide paste may contain: IUPAC nucleotide codes,
# the FASTA '>' marker, and line/space whitespace.
_FASTA_CHARS = frozenset("ACGTNRYKMSWBDHV>\n\r ")


def _looks_like_fasta(target: str) -> bool:
    """Return True for inline FASTA: a '>' header, or >200 chars of nucleotide codes."""
    if target.startswith(">"):
        return True
    return len(target) > 200 and set(target.upper()) <= _FASTA_CHARS


def _resolve_fetch_target(target: str) -> str:
    """Map a non-FASTA ``target`` to the assembly accession to download.

    Accessions are passed through upper-cased (the accession regex accepts any
    case). Anything else is treated as an organism name and resolved to its best
    NCBI assembly. Local filesystem paths are deliberately NOT honoured.
    ``target`` comes from an unauthenticated HTTP body, and reading arbitrary
    server paths (``/etc/...``, ``/dev/zero``) would be a file-read / DoS hole.

    Raises:
        HTTPException: 502 if the NCBI lookup fails, 404 if it finds no genome.
    """
    if _looks_like_accession(target):
        return target.upper()
    try:
        hits = _ncbi_assembly_hits(target, retmax=5)
    except requests.RequestException as e:
        raise HTTPException(
            status_code=502, detail=f"NCBI lookup failed: {e}"
        ) from e
    if not hits:
        raise HTTPException(
            status_code=404,
            detail=f'No NCBI genome found for "{target}". '
            f"Try an NCBI accession (e.g. GCF_000005845.2) or a FASTA.",
        )
    return hits[0]["accession"]


def _rank_media(feats_series: pd.Series) -> list[dict[str, Any]]:
    """Score one genome with every per-medium classifier, highest confidence first."""
    models = _state["models"]
    feature_cols = _state["feature_cols"]
    recipes = _state["recipes"]
    name_by_id = _state["name_by_id"]
    # Single-row frame with the feature columns in the order the models expect.
    X_pred = feats_series[feature_cols].to_frame().T
    recs = []
    for medium_id, model in models.items():
        mid = str(medium_id)
        proba = float(model.predict_proba(X_pred)[0, 1])
        recs.append({
            "medium_id": mid,
            "name": name_by_id.get(mid, "(unknown)"),
            "confidence": round(proba, 4),
            "recipe": _format_recipe_summary(mid, recipes),
        })
    recs.sort(key=lambda r: r["confidence"], reverse=True)
    return recs


def predict(req: PredictRequest):
    """Predict phenotypes and rank growth media for one genome.

    ``req.target`` may be an NCBI assembly accession, an organism name (resolved
    via NCBI), or inline FASTA text. The response carries the resolved accession,
    basic genome stats, tagged phenotype predictions and the ``req.top_k``
    highest-confidence media.

    Raises:
        HTTPException: 400 for an empty target; 404 when no genome is found or the
            genome loader exits (``SystemExit``); 502 when the NCBI lookup fails.
    """
    target = req.target.strip()
    if not target:
        raise HTTPException(status_code=400, detail="empty target")
    # A negative top_k would otherwise slice from the end (recs[:-3]).
    top_k = max(req.top_k, 0)
    tmp_path = None
    try:
        if _looks_like_fasta(target):
            # Record the path before writing so `finally` removes the file even
            # if the write itself fails. The context manager closes the handle.
            with NamedTemporaryFile(
                suffix=".fasta", delete=False, mode="w", encoding="utf-8"
            ) as tmp:
                tmp_path = tmp.name
                tmp.write(target)
            feats, accession, n_contigs = _load_genome_features(tmp_path)
        else:
            fetch_target = _resolve_fetch_target(target)
            feats, accession, n_contigs = _load_genome_features(fetch_target)
        feats_series = pd.Series(feats)
        phenotypes = _tag_prediction_sources(_predict_phenotypes(feats_series))
        recs = _rank_media(feats_series)
        return {
            "accession": accession,
            "n_contigs": n_contigs,
            "n_cds": int(feats["n_predicted_cds"]),
            "gc": round(float(feats["gc_content"]), 4),
            "phenotypes": phenotypes,
            "media": recs[:top_k],
        }
    except SystemExit as e:
        # The genome loader signals "not found / unusable input" via sys.exit.
        raise HTTPException(status_code=404, detail=str(e)) from e
    finally:
        if tmp_path:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
# ──────────────────────────────────────────────────────────────────────
# Static frontend (mounted last so /api/* routes win)
# ──────────────────────────────────────────────────────────────────────
WEB_BUILD = ROOT / "web" / "dist"
if WEB_BUILD.exists():
    app.mount("/assets", StaticFiles(directory=WEB_BUILD / "assets"), name="assets")
    # Resolved once: every SPA file request must stay inside this directory.
    _WEB_BUILD_ROOT = WEB_BUILD.resolve()

    # NOTE(review): no route decorators are visible on the handlers below in this
    # copy of the file; confirm they are registered (HEAD "/", GET "/", catch-all).

    def root_head():
        """Answer HEAD / with an empty 200 response."""
        return Response(status_code=200)

    def root():
        """Serve the SPA entry point (index.html)."""
        return FileResponse(WEB_BUILD / "index.html")

    def spa(path: str):
        """SPA fallback: serve a real file from the build, otherwise index.html.

        ``path`` comes straight from the request URL. Joining it blindly would let
        ``../`` segments or an absolute path (``//etc/passwd``) escape the build
        directory and expose any readable file on the server (source code,
        ``.env`` with API keys, ...). The path is therefore resolved and must stay
        under the build root.
        """
        try:
            target = (_WEB_BUILD_ROOT / path).resolve()
        except (OSError, ValueError):
            # e.g. an embedded NUL byte from a %00 in the URL.
            return FileResponse(WEB_BUILD / "index.html")
        if target.is_relative_to(_WEB_BUILD_ROOT) and target.is_file():
            return FileResponse(target)
        return FileResponse(WEB_BUILD / "index.html")