"""Neuro-mechanism backend: a FastAPI proxy over public biomedical web APIs.

The service wraps a handful of external JSON APIs behind a small in-memory
TTL cache, ranks brain regions for a gene or drug by Europe PMC hit counts,
and exposes a manifest / section / download flow for the assembled results.
"""
import asyncio
import gzip
import hashlib
import io
import json
import math
import time
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
from urllib.parse import quote

import httpx
from fastapi import FastAPI, HTTPException, Path, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse

# ------------------ App constants ------------------
APP_NAME = "neuro-mechanism-backend"
CALLER_ID = "neuro-mech-backend-demo"  # polite ID for STRING
UA = {"User-Agent": f"{APP_NAME}/1.2 (HF Space)"}

# Minimum spacing, in seconds, between outgoing STRING requests.
STRING_MIN_INTERVAL = 1.05
# Upper bound on remembered jobs so JOBS cannot grow without limit.
MAX_JOBS = 1024
# Species used when neither the request nor a stored job supplies one.
DEFAULT_SPECIES = 9606

# ------------------ FastAPI app ------------------
app = FastAPI(title=APP_NAME)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    # A wildcard origin must not be combined with credentialed requests;
    # nothing in this module reads cookies or auth headers, so credentials
    # are disabled rather than the origin list being narrowed.
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/", include_in_schema=False)
def root():
    """Redirect the bare root URL to the interactive API docs."""
    return RedirectResponse(url="/docs")


@app.get("/health", include_in_schema=False)
def health():
    """Return a static liveness payload."""
    return {"ok": True, "app": APP_NAME}


@app.get("/endpoints", include_in_schema=False)
def endpoints():
    """Return a hand-maintained list of example GET URLs."""
    return JSONResponse({
        "GET": [
            "/mechanism_graph_manifest?receptor=HTR2A&symptom=apathy",
            "/mechanism_graph/regions?receptor=HTR2A&symptom=apathy",
            "/download/{job_id}/{section}",
            "/heuristics/regions_from_string?receptor=HTR2A",
            "/util/synonyms?term=ACC&kind=region",
            "/lit/eupmc?query=HTR2A%20AND%20apathy&pageSize=5",
            "/string/network?identifiers=HTR2A&species=9606",
            "/gpcrdb/protein?entry=htr2a_human",
            "/uniprot/search?query=HTR2A&size=5",
            "/rxnav/rxcui?name=fluoxetine",
            "/pubchem/compound_by_name?name=fluoxetine",
            "/trials/search?q=HTR2A&pageSize=5",
            "/health",
            "/docs",
        ]
    })


# ------------------ tiny in-memory TTL cache ------------------
class TTLCache:
    """Small in-memory cache of JSON GET responses with per-entry expiry.

    Entries are keyed by a hash of the URL plus its sorted query parameters.
    When the cache is full the oldest-inserted entry is evicted (FIFO).
    """

    def __init__(self, max_items=512):
        # key -> (absolute expiry timestamp from time.time(), decoded JSON)
        self.store: Dict[str, Tuple[float, Any]] = {}
        self.max_items = max_items
        self._lock = asyncio.Lock()

    def _mk(self, url: str, params: Optional[dict]) -> str:
        """Build the cache key for a URL and its query parameters."""
        key = url + "?" + (json.dumps(params, sort_keys=True) if params else "")
        return hashlib.sha1(key.encode()).hexdigest()

    async def get(
        self,
        url: str,
        params: Optional[dict],
        ttl: float,
        before_fetch: Optional[Callable[[], Awaitable[None]]] = None,
    ):
        """Return the cached JSON for ``url``/``params``, fetching on a miss.

        Args:
            url: Absolute URL to GET.
            params: Query parameters, or None.
            ttl: Seconds the fetched response stays valid.
            before_fetch: Optional coroutine function awaited only when a
                network request is actually going to be made (for example a
                rate limiter), so cache hits are never delayed by it.

        Raises:
            httpx.HTTPError: On transport errors or a non-2xx status.
            ValueError: If the response body is not valid JSON.
        """
        k = self._mk(url, params)
        async with self._lock:
            item = self.store.get(k)
            if item and (time.time() < item[0]):
                return item[1]
        if before_fetch is not None:
            await before_fetch()
        # The lock is not held during the request, so concurrent misses for
        # the same key may each fetch; the last one to finish wins.
        async with httpx.AsyncClient(headers=UA, timeout=30) as client:
            r = await client.get(url, params=params)
            r.raise_for_status()
            data = r.json()
        async with self._lock:
            # Drop any stale copy first so a refresh never evicts an unrelated
            # entry and the refreshed key moves to the back of the FIFO order.
            self.store.pop(k, None)
            while self.store and len(self.store) >= self.max_items:
                self.store.pop(next(iter(self.store)))
            self.store[k] = (time.time() + ttl, data)
        return data


CACHE = TTLCache()

# ------------------ polite throttling for STRING ------------------
_last_string_call = 0.0
_string_lock = asyncio.Lock()


async def throttle_string():
    """Be nice to STRING; ~1 req/sec is a good courtesy.

    The lock serializes callers so that concurrent requests are spaced out
    one after another instead of all sleeping the same amount and then
    firing together.
    """
    global _last_string_call
    async with _string_lock:
        # Monotonic clock: wall-clock adjustments must not affect the spacing.
        wait = STRING_MIN_INTERVAL - (time.monotonic() - _last_string_call)
        if wait > 0:
            await asyncio.sleep(wait)
        _last_string_call = time.monotonic()


# ------------------ small helpers ------------------
async def get_json_cached(
    url: str,
    params: Optional[dict],
    ttl: int,
    before_fetch: Optional[Callable[[], Awaitable[None]]] = None,
):
    """Fetch JSON through the shared cache, never raising.

    On any failure an ``{"error", "url", "params"}`` dict is returned instead
    of an exception, so callers must be prepared for that shape.
    ``before_fetch`` is forwarded to :meth:`TTLCache.get`.
    """
    try:
        return await CACHE.get(url, params, ttl, before_fetch=before_fetch)
    except Exception as e:
        # Deliberate catch-all: this is the boundary that turns upstream
        # failures into a JSON error payload for the API consumer.
        return {"error": str(e), "url": url, "params": params}


def job_key(receptor: str, symptom: str) -> str:
    """Return a 16-hex-char job id derived from the inputs and the current second."""
    raw = f"{receptor}|{symptom}|{int(time.time())}"
    return hashlib.sha1(raw.encode()).hexdigest()[:16]


def gz_json_bytes(obj: Any) -> bytes:
    """Serialize ``obj`` to UTF-8 JSON and return it gzip-compressed."""
    return gzip.compress(json.dumps(obj, ensure_ascii=False).encode("utf-8"))


def _as_list(value: Any) -> List[Any]:
    """Normalize a JSON field that may be a single string or a list to a list."""
    if isinstance(value, str):
        return [value]
    if isinstance(value, (list, tuple)):
        return list(value)
    return []


def _dedupe_ci(items: Any, cap: int = 50) -> List[str]:
    """Strip, drop empties, de-duplicate case-insensitively (first wins), cap length."""
    out: List[str] = []
    seen = set()
    for item in items:
        if item is None:
            continue
        text = str(item).strip()
        folded = text.lower()
        if text and folded not in seen:
            seen.add(folded)
            out.append(text)
    return out[:cap]


def _hit_count(data: Any) -> int:
    """Extract Europe PMC's ``hitCount`` from a response, or 0 if unavailable."""
    if not isinstance(data, dict):
        return 0
    try:
        return int(data.get("hitCount", 0))
    except (TypeError, ValueError):
        return 0


def _epmc_term(term: Any) -> str:
    """Render one search term for a boolean Europe PMC query.

    Terms that are not a single plain token are wrapped in double quotes so
    they are searched as a phrase rather than as separate words.
    """
    text = " ".join(str(term).replace('"', " ").split())
    plain = text.replace("-", "").replace("_", "").isalnum()
    return text if (plain or not text) else f'"{text}"'


def _or_clause(terms: Any) -> str:
    """Join terms into an ``A OR B OR ...`` clause, skipping blank terms."""
    rendered = (_epmc_term(t) for t in terms)
    return " OR ".join(r for r in rendered if r)


def _lit_query(receptor: str, symptom: Optional[str]) -> str:
    """Build the literature query; omit the symptom clause when none is given."""
    return f"{receptor} AND {symptom}" if symptom else receptor


# ------------------ External API wrappers ------------------
@app.get("/lit/eupmc")
async def europe_pmc_search(query: str, pageSize: int = 5):
    """Proxy a Europe PMC search."""
    # Europe PMC returns a 'hitCount' field in the response for quick counts.
    url = "https://www.ebi.ac.uk/europepmc/webservices/rest/search"
    params = {"query": query, "format": "json", "pageSize": pageSize}
    return await get_json_cached(url, params, ttl=600)


@app.get("/lit/pubmed_esearch")
async def pubmed_esearch(term: str, retmax: int = 10):
    """Proxy an NCBI E-utilities ESearch against PubMed."""
    url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
    params = {"db": "pubmed", "term": term, "retmode": "json", "retmax": retmax}
    return await get_json_cached(url, params, ttl=600)


@app.get("/trials/search")
async def ctgov_v2_studies(q: str, pageSize: int = 5):
    """Proxy a ClinicalTrials.gov v2 study search."""
    url = "https://clinicaltrials.gov/api/v2/studies"
    params = {"query.term": q, "pageSize": pageSize}
    return await get_json_cached(url, params, ttl=900)


@app.get("/rxnav/rxcui")
async def rxnav_rxcui(name: str):
    """Proxy an RxNav RxCUI lookup by drug name."""
    url = "https://rxnav.nlm.nih.gov/REST/rxcui.json"
    params = {"name": name}
    return await get_json_cached(url, params, ttl=86400)


@app.get("/openfda/ae")
async def openfda_adverse_events(drug: str, limit: int = 5):
    """Proxy an openFDA drug adverse-event search by medicinal product name."""
    url = "https://api.fda.gov/drug/event.json"
    params = {"search": f'patient.drug.medicinalproduct:"{drug}"', "limit": limit}
    return await get_json_cached(url, params, ttl=3600)


@app.get("/pubchem/compound_by_name")
async def pubchem_by_name(name: str):
    """Proxy a PubChem PUG REST compound lookup by name."""
    # The name becomes a URL path segment, so percent-encode it fully
    # (including '/') to keep user input from altering the request path.
    url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{quote(name, safe='')}/JSON"
    return await get_json_cached(url, None, ttl=86400)


@app.get("/uniprot/search")
async def uniprot_search(query: str, size: int = 5):
    """Proxy a UniProtKB search."""
    url = "https://rest.uniprot.org/uniprotkb/search"
    params = {"query": query, "format": "json", "size": size}
    return await get_json_cached(url, params, ttl=86400)


@app.get("/gpcrdb/protein")
async def gpcrdb_protein(entry: str):
    """Proxy a GPCRdb protein lookup by entry name."""
    # Percent-encode the path segment; see pubchem_by_name.
    url = f"https://gpcrdb.org/services/protein/{quote(entry, safe='')}"
    return await get_json_cached(url, None, ttl=86400)


@app.get("/string/network")
async def string_network(identifiers: str, species: int = 9606, limit: int = 50):
    """Proxy a STRING network query (rate-limited on cache misses only)."""
    # STRING API supports a 'caller_identity' parameter – good practice to include it.
    url = "https://string-db.org/api/json/network"
    params = {
        "identifiers": identifiers,
        "species": species,
        "caller_identity": CALLER_ID,
        "limit": limit,
    }
    # Throttle only when a request is really sent; cached answers return at once.
    return await get_json_cached(url, params, ttl=3600, before_fetch=throttle_string)


# ------------------ Entity resolver (drug vs gene) ------------------
KNOWN_DRUG_TARGETS = {
    "bupropion": ["SLC6A3", "SLC6A2", "CHRNA4", "CHRNB2", "CHRNA3", "CHRNB4"],
    "bupropion hcl": ["SLC6A3", "SLC6A2", "CHRNA4", "CHRNB2", "CHRNA3", "CHRNB4"],
    "atomoxetine": ["SLC6A2"],
    "reboxetine": ["SLC6A2"],
    "methylphenidate": ["SLC6A3", "SLC6A2"],
    "dexmethylphenidate": ["SLC6A3", "SLC6A2"],
    "lisdexamfetamine": ["SLC6A3", "SLC6A2"],
    "dextroamphetamine": ["SLC6A3", "SLC6A2"],
}


def _norm(s: str) -> str:
    """Lower-case and trim a string; None becomes the empty string."""
    return (s or "").strip().lower()


async def resolve_mode(receptor: str) -> dict:
    """
    Decide if 'receptor' is a gene (via MyGene) or a drug (fallback map).
    Returns {"mode":"gene"|"drug","canonical":..., "genes":[...]}

    An empty input is reported as mode "gene" with no genes. If the MyGene
    lookup fails or yields no symbol, the input is treated as a drug and its
    genes come from KNOWN_DRUG_TARGETS (empty list when unknown).
    """
    rq = (receptor or "").strip()
    if not rq:
        return {"mode": "gene", "canonical": rq, "genes": []}
    # Try exact gene symbol first (MyGene)
    try:
        async with httpx.AsyncClient(headers=UA, timeout=15) as client:
            r = await client.get(
                "https://mygene.info/v3/query",
                params={"q": f"symbol:{rq}", "fields": "symbol", "size": 1},
            )
            if r.status_code == 200:
                js = r.json()
                hits = js.get("hits") if isinstance(js, dict) else None
                first = hits[0] if isinstance(hits, list) and hits else None
                sym = first.get("symbol") if isinstance(first, dict) else None
                if sym:
                    return {"mode": "gene", "canonical": sym, "genes": [sym]}
    except (httpx.HTTPError, ValueError):
        # Best-effort lookup: a network or JSON-decoding failure simply means
        # the input falls through to the drug path below.
        pass
    # Not a sure gene → treat as drug
    genes = KNOWN_DRUG_TARGETS.get(_norm(rq), [])
    return {"mode": "drug", "canonical": rq, "genes": genes}


# ------------------ Synonyms (regions/genes/phenotypes) ------------------
REGION_SEED_SYNONYMS = {
    "prefrontal cortex": ["PFC", "mPFC", "vmPFC", "dlPFC", "dorsolateral prefrontal cortex", "ventromedial prefrontal cortex"],
    "anterior cingulate cortex": ["ACC", "dACC", "pgACC", "sgACC", "subgenual cingulate"],
    "nucleus accumbens": ["NAc", "ventral striatum", "accumbens"],
    "ventral tegmental area": ["VTA"],
    "substantia nigra": ["SN", "SNc", "pars compacta"],
    "hippocampus": ["HC"],
    "amygdala": [],
    "insula": ["insular cortex"],
    "thalamus": [],
    "hypothalamus": [],
    "cerebellum": [],
}


async def ols4_synonyms(term: str, ontology: Optional[str] = None) -> List[str]:
    """Collect up to 50 synonyms and labels for ``term`` from an OLS4 search.

    Returns an empty list when the lookup fails or the payload does not have
    the ``response.docs`` shape this function reads.
    """
    url = "https://www.ebi.ac.uk/ols4/api/search"
    params = {"q": term, "rows": 20}
    if ontology:
        params["ontology"] = ontology
    data = await get_json_cached(url, params, ttl=86400)
    response = data.get("response") if isinstance(data, dict) else None
    docs = response.get("docs") if isinstance(response, dict) else None
    syns: List[Any] = []
    for d in docs if isinstance(docs, list) else []:
        if not isinstance(d, dict):
            continue
        # 'synonym' is normalized so a lone string is not split into characters.
        syns.extend(_as_list(d.get("synonym")))
        if "label" in d:
            syns.append(d["label"])
    return _dedupe_ci(syns)


async def mygene_synonyms(symbol: str) -> List[str]:
    """Collect up to 50 names and aliases for ``symbol`` from a MyGene query.

    Returns an empty list when the lookup fails or there are no hits.
    """
    url = "https://mygene.info/v3/query"
    params = {"q": symbol, "fields": "symbol,name,alias,other_names", "size": 5}
    data = await get_json_cached(url, params, ttl=86400)
    hits = data.get("hits") if isinstance(data, dict) else None
    syns: List[Any] = []
    for hit in hits if isinstance(hits, list) else []:
        if not isinstance(hit, dict):
            continue
        for k in ("symbol", "name"):
            if k in hit:
                syns.append(hit[k])
        for k in ("alias", "other_names"):
            # Accept either a list or a single string value for these fields.
            syns.extend(_as_list(hit.get(k)))
    return _dedupe_ci(syns)


@app.get("/util/synonyms")
async def util_synonyms(term: str, kind: str = Query("region", enum=["region", "gene", "phenotype"])):
    """Return the sorted, de-duplicated synonym set for a term.

    ``kind`` selects the source: seed map plus OLS4/uberon for regions,
    MyGene for genes, OLS4/hp for anything else (phenotypes).
    """
    term_norm = term.strip()
    if kind == "region":
        seeds = REGION_SEED_SYNONYMS.get(term_norm.lower(), [])
        ols = await ols4_synonyms(term_norm, ontology="uberon")
        return {"term": term_norm, "kind": kind, "synonyms": sorted(set([term_norm] + seeds + ols))}
    elif kind == "gene":
        mg = await mygene_synonyms(term_norm)
        return {"term": term_norm, "kind": kind, "synonyms": sorted(set([term_norm] + mg))}
    else:
        ols = await ols4_synonyms(term_norm, ontology="hp")
        return {"term": term_norm, "kind": kind, "synonyms": sorted(set([term_norm] + ols))}


# ------------------ Regions heuristic ------------------
REGION_TERMS_DEFAULT = [
    "prefrontal cortex", "anterior cingulate cortex", "nucleus accumbens", "ventral striatum",
    "dorsal striatum", "caudate", "putamen", "amygdala", "hippocampus", "thalamus", "hypothalamus",
    "insula", "ventral tegmental area", "substantia nigra", "cerebellum",
]


async def eupmc_hitcount(q: str) -> int:
    """Return the Europe PMC hit count for query ``q`` (0 on any failure)."""
    # Europe PMC responses include "hitCount" for fast counts.
    url = "https://www.ebi.ac.uk/europepmc/webservices/rest/search"
    params = {"query": q, "format": "json", "pageSize": 0}
    data = await get_json_cached(url, params, ttl=1800)
    return _hit_count(data)


def collect_gene_symbols_from_string(edges: Any, focus: str) -> List[str]:
    """Defensive parse of STRING results (may not always be a list).

    Returns the distinct ``preferredName_A``/``preferredName_B`` values other
    than ``focus`` (compared case-insensitively), in first-seen order so that
    callers slicing the result get a reproducible subset.
    """
    if not isinstance(edges, list):
        return []
    f = (focus or "").upper()
    genes: Dict[str, None] = {}
    for e in edges:
        if not isinstance(e, dict):
            continue
        for k in ("preferredName_A", "preferredName_B"):
            g = e.get(k)
            if isinstance(g, str) and g.upper() != f:
                genes.setdefault(g, None)
    return list(genes)


@app.get("/heuristics/regions_from_string")
async def regions_from_string(
    receptor: str = Query(..., description="gene or drug"),
    species: int = 9606,
    limit: int = 40,
    regions: Optional[str] = Query(None, description="comma-separated override"),
    use_synonyms: bool = True,
    symptom: Optional[str] = None
):
    """
    Robust region scoring that works for genes AND drugs.
    gene: (region_syns) AND (receptor OR STRING neighbors OR gene_syns)
    drug: (region_syns) AND (drug)

    Region, gene and drug terms are treated as literal names and quoted as
    phrases when they contain more than one word; ``symptom`` is inserted
    verbatim. Each region is scored as log10(hits + 1) and the list is
    returned sorted by that score, highest first.
    """
    mode = await resolve_mode(receptor)
    is_gene = mode["mode"] == "gene"

    # candidates
    region_list = [r.strip() for r in (regions.split(",") if regions else REGION_TERMS_DEFAULT) if r.strip()]

    # region synonyms
    region_syns_map: Dict[str, List[str]] = {}
    if use_synonyms:
        for rgn in region_list:
            syn = await util_synonyms(rgn, "region")
            # Put the region itself first so the cap of 10 can never drop it
            # (the synonym list arrives alphabetically sorted).
            region_syns_map[rgn] = _dedupe_ci([rgn] + list(syn.get("synonyms", [])), cap=10)
    else:
        region_syns_map = {rgn: [rgn] for rgn in region_list}

    # RHS terms
    rhs_terms: List[str] = []
    neighbors: List[str] = []
    gene_syns: List[str] = []
    if is_gene:
        sdata = await string_network(receptor, species=species, limit=limit)
        neighbors = collect_gene_symbols_from_string(sdata, receptor)
        gene_syns = (await mygene_synonyms(receptor))[:20]
        # Order-preserving de-duplication keeps the query string (and hence
        # the cache key) identical from one run to the next.
        rhs_terms = _dedupe_ci([receptor] + neighbors[:25] + gene_syns[:25], cap=51)
    else:
        # drug path — skip STRING
        rhs_terms = [mode["canonical"]]

    # Europe PMC hit counts (fast)
    results = []
    symptom_clause = f" AND ({symptom})" if symptom else ""
    receptor_term = _epmc_term(receptor)
    for region in region_list:
        lhs = _or_clause(region_syns_map.get(region, [region]))
        if is_gene:
            rhs = _or_clause(rhs_terms) or receptor_term
        else:
            rhs = _epmc_term(mode["canonical"])
        hits = await eupmc_hitcount(f"({lhs}) AND ({rhs}){symptom_clause}")
        tier = "T1"
        # Progressively narrower fallbacks for genes. NOTE(review): these are
        # logically subsets of T1, so they presumably only matter when the
        # broad query itself failed and was counted as 0 — confirm intent.
        if hits == 0 and is_gene:
            hits = await eupmc_hitcount(f"({lhs}) AND ({receptor_term}){symptom_clause}")
            tier = "T2"
            if hits == 0:
                hits = await eupmc_hitcount(
                    f"({_epmc_term(region)}) AND ({receptor_term}){symptom_clause}"
                )
                tier = "T3"
        score = math.log10(hits + 1.0)
        results.append({"region": region, "hits": hits, "tier": tier, "weighted_score": round(score, 4)})

    results.sort(key=lambda x: x["weighted_score"], reverse=True)
    return {
        "focus": receptor,
        "neighbors_considered": neighbors[:25] if neighbors else [],
        "regions_ranked": results,
        "notes": f"mode={mode['mode']} genes={mode['genes']}"
    }


# ------------------ Manifest / Section / Download ------------------
JOBS: Dict[str, Dict[str, Any]] = {}


def _remember_job(jid: str, record: Dict[str, Any]) -> Dict[str, Any]:
    """Store ``record`` under ``jid``, evicting the oldest jobs beyond MAX_JOBS."""
    JOBS.pop(jid, None)
    while JOBS and len(JOBS) >= MAX_JOBS:
        JOBS.pop(next(iter(JOBS)))
    JOBS[jid] = record
    return record


async def _build_overview(
    receptor: str, symptom: Optional[str], species: int, string_limit: int
) -> Dict[str, Any]:
    """Compute the lightweight overview (counts only) for a receptor/symptom pair."""
    sdata = await string_network(receptor, species=species, limit=string_limit)
    s_count = len(sdata) if isinstance(sdata, list) else 0
    ldata = await europe_pmc_search(_lit_query(receptor, symptom), pageSize=0)
    lit_hits = _hit_count(ldata)
    rdata = await regions_from_string(
        receptor=receptor, species=species, limit=40,
        regions=None, use_synonyms=True, symptom=symptom
    )
    r_count = len(rdata.get("regions_ranked", [])) if isinstance(rdata, dict) else 0
    return {
        "receptor": receptor,
        "symptom": symptom,
        "counts": {"string_edges": s_count, "literature_hits": lit_hits, "regions": r_count}
    }


@app.get("/mechanism_graph_manifest")
async def mechanism_graph_manifest(
    receptor: str = Query(..., description="e.g., HTR2A"),
    symptom: str = Query("apathy"),
    species: int = 9606,
    string_limit: int = 50,
    lit_page_size: int = 10
):
    """Return a job_id + available sections (lightweight counts only)."""
    jid = job_key(receptor, symptom)
    overview = await _build_overview(receptor, symptom, species, string_limit)
    counts = overview["counts"]
    _remember_job(jid, {
        "_meta": {"receptor": receptor, "symptom": symptom, "species": species},
        "overview": overview
    })
    sections = [
        {"name": "overview", "approx_size": "small"},
        {"name": "network", "approx_size": f"{counts['string_edges']} edges (limit={string_limit})"},
        {"name": "literature", "approx_size": f"{counts['literature_hits']} hits (pageSize={lit_page_size})"},
        {"name": "regions", "approx_size": f"{counts['regions']} entries"}
    ]
    return {"job_id": jid, "sections": sections}


@app.get("/mechanism_graph/{section}")
async def mechanism_graph_section(
    section: str = Path(..., description="one of: overview, network, literature, regions"),
    receptor: Optional[str] = None,
    symptom: Optional[str] = None,
    species: Optional[int] = None,
    string_limit: int = 50,
    lit_page_size: int = 10,
    job_id: Optional[str] = Query(None, description="optional; use manifest if you want stable ids")
):
    """Return one section; builds on-the-fly if no job_id.

    When ``job_id`` names a known job, any of receptor/symptom/species the
    caller omitted are taken from that job. ``species`` finally falls back to
    9606 (the previous fixed default).

    Raises:
        HTTPException: 422 if no receptor can be determined, 404 for an
            unknown section name.
    """
    if job_id and job_id in JOBS:
        ctx = JOBS[job_id].get("_meta", {})
        receptor = receptor or ctx.get("receptor")
        symptom = symptom or ctx.get("symptom")
        species = species or ctx.get("species")
    species = species or DEFAULT_SPECIES

    if not receptor:
        raise HTTPException(status_code=422, detail="receptor is required (query param)")

    if section == "overview":
        if not job_id or job_id not in JOBS:
            job_id = job_key(receptor, symptom or "")
            if job_id not in JOBS:
                _remember_job(job_id, {
                    "_meta": {"receptor": receptor, "symptom": symptom or "", "species": species}
                })
        # Hold a direct reference: the job could be evicted from JOBS by other
        # requests while the awaits below are in flight.
        job = JOBS[job_id]
        if "overview" not in job:
            job["overview"] = await _build_overview(receptor, symptom, species, string_limit)
        return {"job_id": job_id, "section": "overview", "data": job["overview"]}

    elif section == "network":
        net = await string_network(receptor, species=species, limit=string_limit)
        return {"job_id": job_id, "section": "network", "data": net}

    elif section == "literature":
        lit = await europe_pmc_search(_lit_query(receptor, symptom), pageSize=lit_page_size)
        return {"job_id": job_id, "section": "literature", "data": lit}

    elif section == "regions":
        reg = await regions_from_string(
            receptor=receptor, species=species, limit=40,
            regions=None, use_synonyms=True, symptom=symptom
        )
        return {"job_id": job_id, "section": "regions", "data": reg}

    else:
        raise HTTPException(status_code=404, detail=f"unknown section: {section}")


@app.get("/download/{job_id}/{section}")
async def download_section(job_id: str, section: str):
    """Gzipped JSON download of a section; returns what's there.

    If the job has no data stored under ``section`` the job's ``_meta`` is
    returned instead.

    Raises:
        HTTPException: 404 if the job is unknown or has nothing to return.
    """
    job = JOBS.get(job_id, {})
    data = job.get(section) or job.get("_meta")
    if not data:
        raise HTTPException(status_code=404, detail="job/section not found")
    gz = gz_json_bytes({"job_id": job_id, "section": section, "data": data})
    # ``section`` is caller-controlled; keep only filename-safe characters so
    # it cannot inject quotes or control characters into the header.
    # (``job_id`` is a key of JOBS at this point, i.e. a generated hex id.)
    safe_section = "".join(ch for ch in section if ch.isalnum() or ch in "-_") or "section"
    return StreamingResponse(
        io.BytesIO(gz),
        media_type="application/gzip",
        headers={"Content-Disposition": f'attachment; filename="{job_id}_{safe_section}.json.gz"'}
    )


@app.get("/selfcheck")
async def selfcheck(
    receptor: str = Query("bupropion"),
    symptom: str = Query("anhedonia")
):
    """
    Runs the same checks your GPT does:
      - /health
      - /mechanism_graph_manifest
      - /heuristics/regions_from_string
    Returns a compact PASS/FAIL summary in one JSON.
    """
    out = {"input": {"receptor": receptor, "symptom": symptom}}

    # Each step catches broadly on purpose: the point of this endpoint is to
    # report a failure in the summary rather than abort the whole check.

    # health
    try:
        out["health"] = {"ok": True, "data": health()}
    except Exception as e:
        out["health"] = {"ok": False, "error": str(e)}

    # manifest
    try:
        mani = await mechanism_graph_manifest(
            receptor=receptor, symptom=symptom, species=9606, string_limit=50, lit_page_size=10
        )
        out["manifest"] = {
            "ok": True,
            "job_id": mani.get("job_id"),
            "sections": mani.get("sections", [])
        }
    except Exception as e:
        out["manifest"] = {"ok": False, "error": str(e)}

    # regions (robust to drugs)
    try:
        reg = await regions_from_string(
            receptor=receptor, species=9606, limit=25, regions=None, use_synonyms=True, symptom=None
        )
        out["regions"] = {
            "ok": True,
            "count": len(reg.get("regions_ranked", [])),
            "sample": reg.get("regions_ranked", [])[:5]
        }
    except Exception as e:
        out["regions"] = {"ok": False, "error": str(e)}

    # overall
    out["overall_ok"] = all([
        out["health"].get("ok"),
        out["manifest"].get("ok"),
        out["regions"].get("ok"),
    ])
    return out