"""neuro-mechanism-backend: thin FastAPI proxy over public biomedical APIs plus a
server-side "mechanism graph" job (STRING + Europe PMC + OLS/MyGene synonyms)
whose sections are stored on disk as gzipped NDJSON."""
import asyncio
import gzip
import hashlib
import json
import math
import os
import re
import time
from datetime import datetime, timezone
from pathlib import Path as _Path
from typing import Any, Dict, List, Literal, Optional, Tuple
from urllib.parse import quote

import httpx
from fastapi import FastAPI, HTTPException, Path, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse, RedirectResponse, StreamingResponse

APP_NAME = "neuro-mechanism-backend"
CALLER_ID = "neuro-mech-backend-demo"  # sent as STRING's caller_identity; appears in STRING logs
DATA_DIR = _Path("/tmp/neuro_mech_jobs")
DATA_DIR.mkdir(parents=True, exist_ok=True)

app = FastAPI(title=APP_NAME)
# NOTE(review): no route here reads cookies or auth headers; allow_credentials=True
# combined with "*" origins is probably unnecessary — confirm before tightening.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/", include_in_schema=False)
def root():
    """Redirect the bare root to the interactive OpenAPI docs."""
    return RedirectResponse(url="/docs")


@app.get("/health", include_in_schema=False)
def health():
    """Liveness probe."""
    return {"ok": True, "app": APP_NAME}


@app.get("/endpoints", include_in_schema=False)
def endpoints():
    """List example calls for the public routes (discovery aid; {job_id} is a placeholder)."""
    return JSONResponse({
        "GET": [
            "/mechanism_graph_manifest?receptor=HTR2A&symptom=apathy&species=9606",
            "/mechanism_graph/nodes?job_id={job_id}&page=1&page_size=200",
            "/mechanism_graph/edges?job_id={job_id}&page=1&page_size=200",
            "/mechanism_graph/literature?job_id={job_id}&page=1&page_size=50",
            "/mechanism_graph/regions?job_id={job_id}&page=1&page_size=50",
            "/download/{job_id}/nodes (gz)",
            "/download/{job_id}/edges (gz)",
            "/download/{job_id}/literature (gz)",
            "/download/{job_id}/regions (gz)",
            "/util/synonyms?term=apathy&kind=phenotype",
            "/heuristics/regions_from_string?receptor=HTR2A&symptom=apathy&limit=40",
            "/lit/eupmc?query=HTR2A%20AND%20apathy&pageSize=5",
            "/lit/pubmed_esearch?term=HTR2A&retmax=10",
            "/string/network?identifiers=HTR2A&species=9606",
            "/gpcrdb/protein?entry=htr2a_human",
            "/uniprot/search?query=HTR2A&size=5",
            "/rxnav/rxcui?name=fluoxetine",
            "/pubchem/compound_by_name?name=fluoxetine",
            "/openfda/ae?drug=fluoxetine&limit=5",
            "/trials/search?q=HTR2A&pageSize=5",
            "/health",
            "/docs",
        ]
    })


UA = {"User-Agent": f"{APP_NAME}/1.2 (HF Space)"}


# ----------------- tiny in-memory TTL cache -----------------
class TTLCache:
    """In-memory TTL cache in front of JSON GET requests.

    Keys are the SHA-1 of ``url + "?" + sorted-JSON(params)``. On insert, expired
    entries are purged first, then the oldest insertions (FIFO) are evicted until
    the store holds fewer than ``max_items`` entries. Failed requests are never cached.
    """

    def __init__(self, max_items=512):
        self.store: Dict[str, Tuple[float, Any]] = {}  # key -> (monotonic expiry, payload)
        self.max_items = max_items
        self._lock = asyncio.Lock()

    def _mk(self, url: str, params: Optional[dict]) -> str:
        key = url + "?" + (json.dumps(params, sort_keys=True) if params else "")
        return hashlib.sha1(key.encode()).hexdigest()

    async def lookup(self, url: str, params: Optional[dict]) -> Tuple[bool, Any]:
        """Return ``(True, payload)`` for a live entry, else ``(False, None)``. Never fetches."""
        k = self._mk(url, params)
        async with self._lock:
            item = self.store.get(k)
        if item is not None and time.monotonic() < item[0]:
            return True, item[1]
        return False, None

    async def get(self, url: str, params: Optional[dict], ttl: float):
        """Return the cached payload for (url, params), fetching and caching it for ``ttl`` s on a miss.

        Non-JSON bodies are wrapped as ``{"text": ..., "status_code": ...}``.
        Raises ``httpx.HTTPStatusError`` on non-2xx responses (redirects are
        followed first) and ``httpx.RequestError`` on transport failures.
        """
        hit, cached = await self.lookup(url, params)
        if hit:
            return cached
        # httpx does not follow redirects by default; some upstreams (e.g. slash
        # redirects) would otherwise surface as errors from raise_for_status().
        async with httpx.AsyncClient(headers=UA, timeout=30, follow_redirects=True) as client:
            r = await client.get(url, params=params)
            r.raise_for_status()
        # Some third-party APIs return plain text/HTML; try JSON, else wrap as text.
        try:
            data = r.json()
        except ValueError:
            data = {"text": r.text, "status_code": r.status_code}
        k = self._mk(url, params)
        async with self._lock:
            self._insert(k, data, ttl)
        return data

    def _insert(self, k: str, data: Any, ttl: float) -> None:
        """Store ``data`` under ``k``; caller must hold ``self._lock``."""
        now = time.monotonic()
        self.store.pop(k, None)  # re-insert so a refreshed key counts as newest
        for key in [key for key, (expiry, _) in self.store.items() if expiry <= now]:
            del self.store[key]
        while self.store and len(self.store) >= self.max_items:
            self.store.pop(next(iter(self.store)))  # oldest insertion first
        self.store[k] = (now + ttl, data)


CACHE = TTLCache()

# ----------------- polite throttling for STRING ------------------
STRING_MIN_INTERVAL = 1.05  # seconds between real STRING requests (~1 req/sec courtesy)
_last_string_call = -math.inf
_string_lock: Optional[asyncio.Lock] = None


async def throttle_string():
    """Be nice to STRING: serialise callers so real requests are >= STRING_MIN_INTERVAL apart."""
    global _last_string_call, _string_lock
    if _string_lock is None:
        # Created lazily inside the running loop (no await between check and set).
        _string_lock = asyncio.Lock()
    async with _string_lock:
        wait = STRING_MIN_INTERVAL - (time.monotonic() - _last_string_call)
        if wait > 0:
            await asyncio.sleep(wait)
        _last_string_call = time.monotonic()


# ----------------- helpers -----------------
async def get_json_cached(url: str, params: Optional[dict], ttl: int):
    """Fetch ``url`` with ``params`` through the shared TTL cache."""
    return await CACHE.get(url, params, ttl)


def _safe_float(x, default=0.0):
    """``float(x)``, or ``default`` when ``x`` cannot be converted."""
    try:
        return float(x)
    except (TypeError, ValueError):
        return default


def _hash_params(d: dict) -> str:
    """Stable SHA-1 hex digest of a JSON-serialisable dict (used as job_id)."""
    return hashlib.sha1(json.dumps(d, sort_keys=True).encode()).hexdigest()


def _as_str_list(value: Any) -> List[str]:
    """Normalise an API field that may be a string, a list, or missing into a list of strings."""
    if isinstance(value, str):
        return [value]
    if isinstance(value, list):
        return [v for v in value if isinstance(v, str)]
    return []


def _edge_dicts(payload: Any) -> List[dict]:
    """Return the dict entries of a STRING network payload ([] for error/text payloads)."""
    if not isinstance(payload, list):
        return []
    return [e for e in payload if isinstance(e, dict)]


async def _gather_bounded(coros: List[Any], limit: int) -> List[Any]:
    """``asyncio.gather`` over ``coros`` with at most ``limit`` running at once; results keep input order."""
    sem = asyncio.Semaphore(max(1, limit))

    async def run(coro):
        async with sem:
            return await coro

    return await asyncio.gather(*(run(c) for c in coros))


# ----------------- base connectors -----------------
@app.get("/lit/eupmc")
async def europe_pmc_search(query: str, pageSize: int = 5, page: int = 1, cursorMark: Optional[str] = None):
    """Europe PMC REST search (JSON).

    For deep paging pass ``cursorMark`` ("*" for the first page, then the
    response's ``nextCursorMark``); ``page`` is forwarded only when no cursor is given.
    Docs: https://europepmc.org/RestfulWebService
    """
    url = "https://www.ebi.ac.uk/europepmc/webservices/rest/search"
    params: Dict[str, Any] = {"query": query, "format": "json", "pageSize": pageSize}
    if cursorMark is not None:
        params["cursorMark"] = cursorMark
    else:
        params["page"] = page
    return await get_json_cached(url, params, ttl=600)


@app.get("/lit/pubmed_esearch")
async def pubmed_esearch(term: str, retmax: int = 10):
    """NCBI E-utilities esearch over PubMed (JSON)."""
    url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
    params = {"db": "pubmed", "term": term, "retmode": "json", "retmax": retmax}
    return await get_json_cached(url, params, ttl=600)


@app.get("/trials/search")
async def ctgov_v2_studies(q: str, pageSize: int = 5):
    """ClinicalTrials.gov v2 studies search."""
    url = "https://clinicaltrials.gov/api/v2/studies"
    params = {"query.term": q, "pageSize": pageSize}
    return await get_json_cached(url, params, ttl=900)


@app.get("/rxnav/rxcui")
async def rxnav_rxcui(name: str):
    """RxNav RxCUI lookup by drug name."""
    url = "https://rxnav.nlm.nih.gov/REST/rxcui.json"
    params = {"name": name}
    return await get_json_cached(url, params, ttl=86400)


@app.get("/openfda/ae")
async def openfda_adverse_events(drug: str, limit: int = 5):
    """openFDA drug adverse-event reports matching ``medicinalproduct``."""
    url = "https://api.fda.gov/drug/event.json"
    params = {"search": f'patient.drug.medicinalproduct:"{drug}"', "limit": limit}
    return await get_json_cached(url, params, ttl=3600)


@app.get("/pubchem/compound_by_name")
async def pubchem_by_name(name: str):
    """PubChem PUG REST compound record by name (name is percent-encoded as one path segment)."""
    url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{quote(name, safe='')}/JSON"
    return await get_json_cached(url, None, ttl=86400)


@app.get("/uniprot/search")
async def uniprot_search(query: str, size: int = 5):
    """UniProtKB search (JSON)."""
    url = "https://rest.uniprot.org/uniprotkb/search"
    params = {"query": query, "format": "json", "size": size}
    return await get_json_cached(url, params, ttl=86400)


@app.get("/gpcrdb/protein")
async def gpcrdb_protein(entry: str):
    """GPCRdb protein record (entry is percent-encoded as one path segment)."""
    url = f"https://gpcrdb.org/services/protein/{quote(entry, safe='')}"
    return await get_json_cached(url, None, ttl=86400)


@app.get("/string/network")
async def string_network(identifiers: str, species: int = 9606, limit: int = 50):
    """STRING JSON network; the courtesy throttle only applies when STRING is actually called."""
    url = "https://string-db.org/api/json/network"
    params = {"identifiers": identifiers, "species": species, "caller_identity": CALLER_ID, "limit": limit}
    hit, cached = await CACHE.lookup(url, params)
    if hit:
        return cached
    await throttle_string()
    return await get_json_cached(url, params, ttl=3600)


# ----------------- synonym utilities -----------------
# curated region slang/aliases (additive to OLS)
CURATED_REGION_SYNONYMS = {
    "prefrontal cortex": ["PFC", "frontal cortex", "dorsolateral prefrontal cortex", "dlPFC",
                          "ventromedial prefrontal cortex", "vmPFC", "orbitofrontal cortex", "OFC"],
    "anterior cingulate cortex": ["ACC", "dorsal ACC", "dACC", "rostral ACC", "rACC"],
    "nucleus accumbens": ["NAc", "ventral striatum"],
    "ventral tegmental area": ["VTA"],
    "substantia nigra": ["SN", "pars compacta", "SNc"],
    "hippocampus": ["hippocampal formation", "CA1", "CA3", "dentate gyrus"],
    "amygdala": ["basolateral amygdala", "BLA", "central amygdala"],
}

# Fields requested from OLS search. Synonyms are indexed as "synonym" and are not
# part of the default field list.
# NOTE(review): confirm OLS4 honours fieldList for this endpoint.
OLS_FIELD_LIST = "iri,label,short_form,obo_id,ontology_name,synonym"


async def _ols_synonyms(term: str, ontologies: Optional[List[str]] = None) -> List[str]:
    """Return the synonyms of the top 5 OLS4 search hits for ``term`` (best-effort; [] on odd payloads).

    ``ontologies`` restricts the search; OLS takes them as one comma-separated ``ontology`` value.
    """
    url = "https://www.ebi.ac.uk/ols4/api/search"
    params: Dict[str, Any] = {"q": term, "fieldList": OLS_FIELD_LIST}
    if ontologies:
        params["ontology"] = ",".join(ontologies)
    data = await get_json_cached(url, params, ttl=86400)
    response = data.get("response") if isinstance(data, dict) else None
    docs = response.get("docs") if isinstance(response, dict) else None
    if not isinstance(docs, list):
        return []
    syns = set()
    for d in docs[:5]:
        if not isinstance(d, dict):
            continue
        # Accept both spellings, and scalar or list values (a bare string must
        # not be iterated character by character).
        for field in ("synonym", "synonyms"):
            syns.update(_as_str_list(d.get(field)))
    return list(syns)


async def _mygene_aliases(symbol: str) -> List[str]:
    """Return symbol, name and alias strings for the top human MyGene.info v3 hit for ``symbol``."""
    url = "https://mygene.info/v3/query"
    params = {"q": f"symbol:{symbol}", "fields": "symbol,name,alias,alias_symbol,other_names",
              "size": 1, "species": "human"}
    data = await get_json_cached(url, params, ttl=86400)
    hits = data.get("hits") if isinstance(data, dict) else None
    if not isinstance(hits, list) or not hits or not isinstance(hits[0], dict):
        return []
    h = hits[0]
    syns = set()
    # MyGene returns a bare string when a multi-valued field has a single value.
    for fld in ("symbol", "name", "alias", "alias_symbol", "other_names"):
        syns.update(_as_str_list(h.get(fld)))
    return list(syns)


@app.get("/util/synonyms")
async def util_synonyms(term: str, kind: Literal["region", "gene", "phenotype", "auto"] = "auto"):
    """
    Fetch synonyms for a term (the term itself is always included; entries > 60 chars are dropped).

      region:    curated slang + OLS4 (uberon, hbp, hp, ncit)
      gene:      MyGene.info aliases
      phenotype: OLS4 (hp, efo, mondo)
      auto:      "gene" if ``term.isupper()``, otherwise "phenotype" (no region fallback).
    """
    k = kind
    if k == "auto":
        k = "gene" if term.isupper() else "phenotype"
    syns = {term}
    if k == "region":
        syns.update(CURATED_REGION_SYNONYMS.get(term.lower(), []))
        # "hp" is the OLS id of the Human Phenotype Ontology.
        # NOTE(review): confirm "hbp" is a valid OLS ontology id.
        syns.update(await _ols_synonyms(term, ontologies=["uberon", "hbp", "hp", "ncit"]))
    elif k == "gene":
        syns.update(await _mygene_aliases(term))
    elif k == "phenotype":
        syns.update(await _ols_synonyms(term, ontologies=["hp", "efo", "mondo"]))
    return {"term": term, "kind": k,
            "synonyms": sorted({s for s in syns if isinstance(s, str) and len(s) <= 60})}


# ----------------- region heuristic (upgraded) -----------------
REGION_TERMS_DEFAULT = [
    "prefrontal cortex", "anterior cingulate cortex", "mPFC", "ACC", "nucleus accumbens", "ventral striatum",
    "dorsal striatum", "caudate", "putamen", "amygdala", "hippocampus", "thalamus", "hypothalamus",
    "insula", "ventral tegmental area", "VTA", "substantia nigra", "cerebellum",
]
MAX_QUERY_NEIGHBORS = 25  # highest-confidence STRING neighbours used in literature queries
EUPMC_CONCURRENCY = 8     # max concurrent Europe PMC hitCount requests per heuristic call
OLS_CONCURRENCY = 8       # max concurrent region synonym lookups per heuristic call


def collect_gene_symbols_from_string(edges: List[dict], focus: str) -> List[str]:
    """Return the sorted partner symbols in STRING ``edges``, excluding ``focus`` (case-insensitive)."""
    f = focus.upper()
    genes = set()
    for e in _edge_dicts(edges):
        for k in ("preferredName_A", "preferredName_B"):
            g = e.get(k)
            if isinstance(g, str) and g and g.upper() != f:
                genes.add(g)
    return sorted(genes)


async def _eupmc_hitcount(q: str) -> int:
    """Europe PMC ``hitCount`` for query ``q`` (pageSize=0); 0 if the payload has no usable count."""
    url = "https://www.ebi.ac.uk/europepmc/webservices/rest/search"
    params = {"query": q, "format": "json", "pageSize": 0}
    data = await get_json_cached(url, params, ttl=3600)
    try:
        return int(data.get("hitCount", 0))
    except (AttributeError, TypeError, ValueError):
        return 0


@app.get("/heuristics/regions_from_string")
async def regions_from_string(
    receptor: str = Query(..., description="e.g., HTR2A"),
    species: int = 9606,
    limit: int = 40,
    regions: Optional[str] = Query(None, description="comma-separated region terms (optional)"),
    symptom: Optional[str] = Query(None, description="optional phenotype/symptom to weight co-mentions (e.g., apathy)"),
):
    """
    Heuristic: rank brain regions by STRING neighbours + Europe PMC co-mentions,
    with synonyms & tiered fallbacks.

    Tiers (all unquoted for flexible match):
      T1: (region_syns) AND (receptor_syns OR neighbours) AND (symptom_syns)   weight 1.0
          (without a symptom T1 is the same query as T2, so its hits count 1.0 + 0.6)
      T2: (region_syns) AND (receptor_syns OR neighbours)                      weight 0.6
      T3: (region_syns) AND (receptor_syns)                                    weight 0.5
      T4: (region_syns) AND (symptom_syns)   [only with a symptom]             weight 0.3

    Only the MAX_QUERY_NEIGHBORS highest-confidence neighbours (ties by name) enter
    the queries, so the ranking is deterministic.

    Final score = log10(weighted_hits + 1) * mean best-edge STRING score over all
    neighbours (0.2 when STRING returned none).
    """
    focus = receptor.upper()

    # 1) STRING neighbours with their best edge score (edges among neighbours count too).
    edges = _edge_dicts(await string_network(receptor, species=species, limit=limit))
    conf: Dict[str, float] = {}
    for e in edges:
        score = _safe_float(e.get("score", 0))
        for name in (e.get("preferredName_A"), e.get("preferredName_B")):
            if isinstance(name, str) and name and name.upper() != focus:
                conf[name] = max(conf.get(name, 0.0), score)
    mean_conf = sum(conf.values()) / len(conf) if conf else 0.2
    top_neighbors = sorted(conf, key=lambda g: (-conf[g], g))[:MAX_QUERY_NEIGHBORS]

    # 2) synonyms
    receptor_syns = await _mygene_aliases(receptor)
    symptom_syns: List[str] = []
    if symptom:
        symptom_syns = (await util_synonyms(symptom, kind="phenotype"))["synonyms"]
    raw_regions = regions.split(",") if regions else REGION_TERMS_DEFAULT
    region_list = [r.strip() for r in raw_regions if r.strip()]

    # Unquoted OR lists (see docstring).
    gene_clause = " OR ".join(sorted({receptor} | set(receptor_syns) | set(top_neighbors)))
    receptor_clause = " OR ".join(sorted(set([receptor] + receptor_syns)))
    symptom_clause = " OR ".join(symptom_syns)

    region_payloads = await _gather_bounded(
        [util_synonyms(r, kind="region") for r in region_list], OLS_CONCURRENCY)

    # 3) per-region tier queries
    tier_defs = []
    for region, payload in zip(region_list, region_payloads):
        region_clause = " OR ".join(payload["synonyms"] or [region])  # never emit "()"
        base = f"({region_clause}) AND (({gene_clause}))"
        t1 = f"{base} AND ({symptom_clause})" if symptom and symptom_syns else base
        tiers = [
            ("t1", 1.0, t1),
            ("t2", 0.6, base),
            ("t3", 0.5, f"({region_clause}) AND ({receptor_clause})"),
        ]
        if symptom_syns:
            tiers.append(("t4", 0.3, f"({region_clause}) AND ({symptom_clause})"))
        tier_defs.append((region, tiers))

    # 4) fetch each distinct query once, with bounded concurrency
    unique_queries = list(dict.fromkeys(q for _, tiers in tier_defs for _, _, q in tiers))
    counts = await _gather_bounded([_eupmc_hitcount(q) for q in unique_queries], EUPMC_CONCURRENCY)
    hits_by_query = dict(zip(unique_queries, counts))

    results = []
    for region, tiers in tier_defs:
        tier_counts = {name: hits_by_query[q] for name, _weight, q in tiers}
        weighted = sum(weight * hits_by_query[q] for _name, weight, q in tiers)
        score = math.log10(weighted + 1.0) * mean_conf
        results.append({"region": region, "tiers": tier_counts,
                        "weighted_hits": int(round(weighted)), "weighted_score": round(score, 4)})
    results.sort(key=lambda x: x["weighted_score"], reverse=True)
    return {
        "focus": receptor,
        "neighbors_considered": top_neighbors,
        "regions_ranked": results,
        "notes": "STRING + Europe PMC with synonyms and tiered fallbacks (unquoted).",
    }


# ----------------- MANIFEST + PAGED SECTIONS + DOWNLOAD -----------------
_JOB_ID_RE = re.compile(r"[0-9a-f]{40}")  # job ids are sha1 hex digests (see _hash_params)


def _job_dir(job_id: str, create: bool = True) -> _Path:
    """Return DATA_DIR/<job_id>, creating it when ``create`` is true.

    Raises HTTPException(400) unless ``job_id`` is a 40-char lowercase hex digest.
    This blocks path traversal (``..``, absolute paths) through user-supplied ids.
    """
    if not isinstance(job_id, str) or not _JOB_ID_RE.fullmatch(job_id):
        raise HTTPException(status_code=400, detail="invalid job_id")
    d = DATA_DIR / job_id
    if create:
        d.mkdir(parents=True, exist_ok=True)
    return d


def _tmp_sibling(path: _Path) -> _Path:
    """Unique temp path next to ``path`` for write-then-rename."""
    return path.with_name(f"{path.name}.{os.getpid()}.{os.urandom(4).hex()}.tmp")


def _write_gz_jsonl(path: _Path, items: List[dict]):
    """Atomically write ``items`` as gzipped NDJSON (one JSON object per line)."""
    tmp = _tmp_sibling(path)
    with gzip.open(tmp, "wt", encoding="utf-8") as gz:
        for it in items:
            gz.write(json.dumps(it, ensure_ascii=False) + "\n")
    os.replace(tmp, path)  # readers never observe a half-written file


def _read_gz_page(path: _Path, page: int, page_size: int) -> Tuple[int, List[dict]]:
    """Return ``(total, items)`` for 1-based ``page`` of a gzipped NDJSON file.

    Blank lines are ignored for both ``total`` and page positions. An invalid
    page or page_size (< 1) yields an empty page.
    """
    start = (page - 1) * page_size
    end = start + page_size
    total = 0
    out = []
    with gzip.open(path, "rt", encoding="utf-8") as gz:
        for line in gz:
            if not line.strip():
                continue
            if start <= total < end:
                out.append(json.loads(line))
            total += 1
    return total, out


async def _fetch_eupmc_literature(query: str, page_size: int, max_pages: int) -> List[dict]:
    """Collect up to ``max_pages`` Europe PMC result pages via cursorMark paging, de-duplicated by (source, id)."""
    items: List[dict] = []
    seen = set()
    cursor = "*"
    for _ in range(max_pages):
        res = await europe_pmc_search(query, pageSize=page_size, cursorMark=cursor)
        if not isinstance(res, dict):
            break
        result_list = res.get("resultList") or {}
        hits = result_list.get("result") or [] if isinstance(result_list, dict) else []
        for h in hits:
            if not isinstance(h, dict):
                continue
            if h.get("id") is not None:
                key = (h.get("source"), h.get("id"))
                if key in seen:
                    continue
                seen.add(key)
            items.append({
                "id": h.get("id"),
                "source": h.get("source"),
                "title": h.get("title"),
                "pubYear": h.get("pubYear"),
                "authorString": h.get("authorString"),
                "journalTitle": h.get("journalTitle"),
                "doi": h.get("doi"),
            })
        next_cursor = res.get("nextCursorMark")
        # A short page, or a missing or unchanged cursor, means no further results.
        if len(hits) < page_size or not next_cursor or next_cursor == cursor:
            break
        cursor = next_cursor
    return items


async def _build_mech_job(params: dict) -> dict:
    """
    Build nodes/edges/literature/regions; write gz NDJSON + meta.json.

    The job id is the hash of ``params``. If meta.json already exists, it is
    returned without rebuilding. meta.json is written last, so it only exists
    once every section file is complete.
    """
    receptor = params["receptor"]
    species = int(params.get("species", 9606))
    symptom = params.get("symptom")
    string_limit = int(params.get("string_limit", 200))
    eupmc_page_size = int(params.get("eupmc_page_size", 100))
    eupmc_max_pages = int(params.get("eupmc_max_pages", 3))

    job_id = _hash_params(params)
    d = _job_dir(job_id)
    meta_path = d / "meta.json"
    if meta_path.exists():
        return json.loads(meta_path.read_text("utf-8"))

    # 1) STRING edges + nodes
    focus = receptor.upper()
    edges = _edge_dicts(await string_network(receptor, species=species, limit=string_limit))
    edge_items = []
    nodes = set()
    for e in edges:
        a, b = e.get("preferredName_A"), e.get("preferredName_B")
        if isinstance(a, str) and isinstance(b, str) and a and b:
            edge_items.append({"a": a, "b": b, "score": _safe_float(e.get("score", 0))})
            nodes.update((a, b))
    # Add the seed only if STRING did not already return it (possibly in another case).
    if not any(n.upper() == focus for n in nodes):
        nodes.add(receptor)
    node_items = [{"symbol": n, "seed": n.upper() == focus} for n in sorted(nodes)]
    _write_gz_jsonl(d / "edges.jsonl.gz", edge_items)
    _write_gz_jsonl(d / "nodes.jsonl.gz", node_items)

    # 2) Europe PMC literature for (receptor AND symptom) if a symptom is given, else receptor
    base_q = f"{receptor} AND {symptom}" if symptom else receptor
    lit_items = await _fetch_eupmc_literature(base_q, eupmc_page_size, eupmc_max_pages)
    _write_gz_jsonl(d / "literature.jsonl.gz", lit_items)

    # 3) Regions heuristic (with symptom)
    reg = await regions_from_string(receptor=receptor, species=species, limit=min(100, string_limit),
                                    regions=None, symptom=symptom)
    reg_items = list(reg.get("regions_ranked", []))
    _write_gz_jsonl(d / "regions.jsonl.gz", reg_items)

    meta = {
        "job_id": job_id,
        "created": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
        "params": params,
        "counts": {
            "nodes": len(node_items),
            "edges": len(edge_items),
            "literature": len(lit_items),
            "regions": len(reg_items),
        },
        "sections": ["nodes", "edges", "literature", "regions"],
    }
    tmp = _tmp_sibling(meta_path)
    tmp.write_text(json.dumps(meta, ensure_ascii=False, indent=2), encoding="utf-8")
    os.replace(tmp, meta_path)
    return meta


@app.get("/mechanism_graph_manifest")
async def mechanism_graph_manifest(
    receptor: str = Query(...),
    species: int = 9606,
    symptom: Optional[str] = None,
    string_limit: int = Query(200, ge=1),
    eupmc_page_size: int = Query(100, ge=1, le=1000),  # Europe PMC caps pageSize at 1000
    eupmc_max_pages: int = Query(3, ge=0),
):
    """
    Build the full mechanism dataset server-side and return a manifest with job_id + counts.

    The data is stored as gzipped NDJSON and can be:
      - paged via /mechanism_graph/{section}?job_id=...&page=1&page_size=...
      - or downloaded as a single gz file via /download/{job_id}/{section}
    """
    params = {
        "receptor": receptor,
        "species": species,
        "symptom": symptom,
        "string_limit": string_limit,
        "eupmc_page_size": eupmc_page_size,
        "eupmc_max_pages": eupmc_max_pages,
    }
    return await _build_mech_job(params)


@app.get("/mechanism_graph/{section}")
async def mechanism_graph_section(
    section: Literal["nodes", "edges", "literature", "regions"] = Path(...),
    job_id: str = Query(...),
    page: int = Query(1, ge=1),
    page_size: int = Query(100, ge=1),
):
    """Return one page of a section (nodes|edges|literature|regions); 400 on a malformed job_id, 404 if missing."""
    p = _job_dir(job_id, create=False) / f"{section}.jsonl.gz"
    if not p.exists():
        raise HTTPException(status_code=404, detail=f"section {section} not found for job {job_id}")
    total, items = _read_gz_page(p, page=page, page_size=page_size)
    return {
        "job_id": job_id,
        "section": section,
        "page": page,
        "page_size": page_size,
        "total": total,
        "items": items,
    }


@app.get("/download/{job_id}/{section}")
async def download_section(job_id: str, section: Literal["nodes", "edges", "literature", "regions"]):
    """Download the full gzipped NDJSON for a section; 400 on a malformed job_id, 404 if missing."""
    p = _job_dir(job_id, create=False) / f"{section}.jsonl.gz"
    if not p.exists():
        raise HTTPException(status_code=404, detail=f"section {section} not found for job {job_id}")
    return FileResponse(
        path=str(p),
        filename=f"{APP_NAME}-{job_id}-{section}.jsonl.gz",
        media_type="application/gzip",
    )