from fastapi import FastAPI, Query, Path, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse, JSONResponse, FileResponse
from typing import Dict, Any, Tuple, Optional, List, Literal
import httpx, asyncio, time, os, hashlib, json, gzip, math
from pathlib import Path as _Path
from datetime import datetime, timezone
APP_NAME = "neuro-mechanism-backend"
CALLER_ID = "neuro-mech-backend-demo" # appears in STRING logs
DATA_DIR = _Path("/tmp/neuro_mech_jobs")
DATA_DIR.mkdir(parents=True, exist_ok=True)
app = FastAPI(title=APP_NAME)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], allow_credentials=True,
allow_methods=["*"], allow_headers=["*"],
)
@app.get("/", include_in_schema=False)
def root():
return RedirectResponse(url="/docs")
@app.get("/health", include_in_schema=False)
def health():
return {"ok": True, "app": APP_NAME}
@app.get("/endpoints", include_in_schema=False)
def endpoints():
return JSONResponse({
"GET": [
"/mechanism_graph_manifest?receptor=HTR2A&symptom=apathy&species=9606",
"/mechanism_graph/nodes?job_id=<id>&page=1&page_size=200",
"/mechanism_graph/edges?job_id=<id>&page=1&page_size=200",
"/mechanism_graph/literature?job_id=<id>&page=1&page_size=50",
"/mechanism_graph/regions?job_id=<id>&page=1&page_size=50",
"/download/<job_id>/nodes (gz)",
"/download/<job_id>/edges (gz)",
"/download/<job_id>/literature (gz)",
"/download/<job_id>/regions (gz)",
"/util/synonyms?term=apathy&kind=phenotype",
"/heuristics/regions_from_string?receptor=HTR2A&symptom=apathy&limit=40",
"/lit/eupmc?query=HTR2A%20AND%20apathy&pageSize=5",
"/string/network?identifiers=HTR2A&species=9606",
"/gpcrdb/protein?entry=htr2a_human",
"/uniprot/search?query=HTR2A&size=5",
"/rxnav/rxcui?name=fluoxetine",
"/pubchem/compound_by_name?name=fluoxetine",
"/trials/search?q=HTR2A&pageSize=5",
"/health", "/docs"
]
})
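# Typical client flow, as a sketch (job_id value hypothetical):
#   1) GET /mechanism_graph_manifest?receptor=HTR2A&symptom=apathy  -> {"job_id": "ab12...", "counts": {...}}
#   2) GET /mechanism_graph/nodes?job_id=ab12...&page=1&page_size=200
#   3) GET /download/ab12.../edges   (full gzipped NDJSON)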
UA = {"User-Agent": f"{APP_NAME}/1.2 (HF Space)"}
# ----------------- tiny in-memory TTL cache -----------------
class TTLCache:
def __init__(self, max_items=512):
self.store: Dict[str, Tuple[float, Any]] = {}
self.max_items = max_items
self._lock = asyncio.Lock()
def _mk(self, url: str, params: Optional[dict]) -> str:
key = url + "?" + (json.dumps(params, sort_keys=True) if params else "")
return hashlib.sha1(key.encode()).hexdigest()
async def get(self, url: str, params: Optional[dict], ttl: float):
k = self._mk(url, params)
async with self._lock:
item = self.store.get(k)
if item and (time.time() < item[0]):
return item[1]
async with httpx.AsyncClient(headers=UA, timeout=30) as client:
r = await client.get(url, params=params)
r.raise_for_status()
            # raise_for_status() already rejected HTTP error codes; some APIs
            # still return non-JSON bodies on success, so try JSON first and
            # fall back to wrapping the raw text.
try:
data = r.json()
except Exception:
data = {"text": r.text, "status_code": r.status_code}
async with self._lock:
if len(self.store) > self.max_items:
                # dicts preserve insertion order: drop the oldest cached entry
                self.store.pop(next(iter(self.store)))
self.store[k] = (time.time() + ttl, data)
return data
CACHE = TTLCache()
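# Usage sketch (URL and params illustrative):
#   data = await CACHE.get("https://www.ebi.ac.uk/europepmc/webservices/rest/search",
#                          {"query": "HTR2A", "format": "json"}, ttl=600)
# Repeated calls with the same (url, params) within `ttl` seconds hit memory, not the network.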
# ----------------- polite throttling for STRING -----------------
_last_string_call = 0.0
_string_lock = asyncio.Lock()
async def throttle_string():
    """Be nice to STRING; ~1 req/sec as a courtesy (the lock serializes concurrent callers)."""
    global _last_string_call
    async with _string_lock:
        now = time.time()
        wait = 1.05 - (now - _last_string_call)
        if wait > 0:
            await asyncio.sleep(wait)
        _last_string_call = time.time()
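# E.g. two overlapping STRING requests will be spaced ~1.05 s apart instead of
# firing simultaneously.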
# ----------------- helpers -----------------
async def get_json_cached(url: str, params: Optional[dict], ttl: int):
return await CACHE.get(url, params, ttl)
def _safe_float(x, default=0.0):
try:
return float(x)
except Exception:
return default
def _hash_params(d: dict) -> str:
return hashlib.sha1(json.dumps(d, sort_keys=True).encode()).hexdigest()
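# Job ids are deterministic: identical params always hash to the same sha1, so
# re-requesting a manifest reuses the job already on disk, e.g.
#   _hash_params({"receptor": "HTR2A", "species": 9606})  # stable 40-char hex digest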
# ----------------- base connectors -----------------
@app.get("/lit/eupmc")
async def europe_pmc_search(query: str, pageSize: int = 5, page: int = 1):
# Europe PMC REST search (JSON)
# docs: https://europepmc.org/RestfulWebService ; client vignette: europepmc R pkg
url = "https://www.ebi.ac.uk/europepmc/webservices/rest/search"
params = {"query": query, "format": "json", "pageSize": pageSize, "page": page}
return await get_json_cached(url, params, ttl=600)
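# Example (response shape abbreviated; fields follow Europe PMC's JSON format):
#   GET /lit/eupmc?query=HTR2A%20AND%20apathy&pageSize=5
#   -> {"hitCount": ..., "resultList": {"result": [{"id": ..., "title": ...}, ...]}}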
@app.get("/lit/pubmed_esearch")
async def pubmed_esearch(term: str, retmax: int = 10):
url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
params = {"db":"pubmed","term":term,"retmode":"json","retmax":retmax}
return await get_json_cached(url, params, ttl=600)
@app.get("/trials/search")
async def ctgov_v2_studies(q: str, pageSize: int = 5):
url = "https://clinicaltrials.gov/api/v2/studies"
params = {"query.term": q, "pageSize": pageSize}
return await get_json_cached(url, params, ttl=900)
@app.get("/rxnav/rxcui")
async def rxnav_rxcui(name: str):
url = "https://rxnav.nlm.nih.gov/REST/rxcui.json"
params = {"name": name}
return await get_json_cached(url, params, ttl=86400)
@app.get("/openfda/ae")
async def openfda_adverse_events(drug: str, limit: int = 5):
url = "https://api.fda.gov/drug/event.json"
params = {"search": f'patient.drug.medicinalproduct:"{drug}"', "limit": limit}
return await get_json_cached(url, params, ttl=3600)
@app.get("/pubchem/compound_by_name")
async def pubchem_by_name(name: str):
url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{name}/JSON"
return await get_json_cached(url, None, ttl=86400)
@app.get("/uniprot/search")
async def uniprot_search(query: str, size: int = 5):
url = "https://rest.uniprot.org/uniprotkb/search"
params = {"query": query, "format": "json", "size": size}
return await get_json_cached(url, params, ttl=86400)
@app.get("/gpcrdb/protein")
async def gpcrdb_protein(entry: str):
url = f"https://gpcrdb.org/services/protein/{entry}"
return await get_json_cached(url, None, ttl=86400)
@app.get("/string/network")
async def string_network(identifiers: str, species: int = 9606, limit: int = 50):
# STRING JSON network endpoint
await throttle_string()
url = "https://string-db.org/api/json/network"
params = {"identifiers": identifiers, "species": species, "caller_identity": CALLER_ID, "limit": limit}
return await get_json_cached(url, params, ttl=3600)
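# STRING returns a list of edge records; downstream code relies on
# preferredName_A, preferredName_B and score (combined confidence in 0..1), e.g.
#   [{"preferredName_A": "HTR2A", "preferredName_B": "HTR1A", "score": 0.92}, ...]  (values illustrative)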
# ----------------- synonym utilities -----------------
# curated region slang/aliases (additive to OLS)
CURATED_REGION_SYNONYMS = {
"prefrontal cortex": ["PFC", "frontal cortex", "dorsolateral prefrontal cortex", "dlPFC",
"ventromedial prefrontal cortex", "vmPFC", "orbitofrontal cortex", "OFC"],
"anterior cingulate cortex": ["ACC", "dorsal ACC", "dACC", "rostral ACC", "rACC"],
"nucleus accumbens": ["NAc", "ventral striatum"],
"ventral tegmental area": ["VTA"],
"substantia nigra": ["SN", "pars compacta", "SNc"],
"hippocampus": ["hippocampal formation", "CA1", "CA3", "dentate gyrus"],
"amygdala": ["basolateral amygdala", "BLA", "central amygdala"]
}
async def _ols_synonyms(term: str, ontologies: Optional[List[str]] = None) -> List[str]:
# OLS4 search; aggregate synonyms for top hits containing the term
url = "https://www.ebi.ac.uk/ols4/api/search"
params = {"q": term}
if ontologies:
# OLS4 supports multiple ontology filters as repeated params
# We'll just join as comma-separated for brevity (works for OLS4)
params["ontology"] = ",".join(ontologies)
data = await get_json_cached(url, params, ttl=86400)
syns = set()
try:
docs = data.get("response", {}).get("docs", [])
for d in docs[:5]:
for s in d.get("synonyms", []) or []:
if isinstance(s, str):
syns.add(s)
except Exception:
pass
return list(syns)
async def _mygene_aliases(symbol: str) -> List[str]:
# MyGene.info v3; pull aliases/other names for the main focus gene
url = "https://mygene.info/v3/query"
params = {"q": f"symbol:{symbol}", "fields": "symbol,name,alias,alias_symbol,other_names", "size": 1, "species": "human"}
data = await get_json_cached(url, params, ttl=86400)
syns = set()
try:
hits = data.get("hits", [])
if hits:
h = hits[0]
for fld in ("symbol","name"):
v = h.get(fld)
if isinstance(v, str):
syns.add(v)
for fld in ("alias","alias_symbol","other_names"):
v = h.get(fld)
if isinstance(v, list):
for x in v:
if isinstance(x, str):
syns.add(x)
except Exception:
pass
return list(syns)
@app.get("/util/synonyms")
async def util_synonyms(term: str, kind: Literal["region","gene","phenotype","auto"]="auto"):
"""
Fetch synonyms for a term.
region: OLS4 (UBERON,HBP/HPO where applicable) + curated slang
gene: MyGene.info aliases
phenotype: OLS4(HPO)
auto: choose gene if ALLCAPS letters+digits, else phenotype->region fallback.
"""
k = kind
if k == "auto":
k = "gene" if term.isupper() else "phenotype"
syns = set([term])
if k == "region":
syns.update(CURATED_REGION_SYNONYMS.get(term.lower(), []))
syns.update(await _ols_synonyms(term, ontologies=["uberon","hbp","hpo","ncit"]))
elif k == "gene":
syns.update(await _mygene_aliases(term))
elif k == "phenotype":
syns.update(await _ols_synonyms(term, ontologies=["hpo","efo","mondo"]))
return {"term": term, "kind": k, "synonyms": sorted({s for s in syns if isinstance(s, str) and len(s) <= 60})}
# ----------------- region heuristic (upgraded) -----------------
REGION_TERMS_DEFAULT = [
"prefrontal cortex","anterior cingulate cortex","mPFC","ACC","nucleus accumbens","ventral striatum",
"dorsal striatum","caudate","putamen","amygdala","hippocampus","thalamus","hypothalamus",
"insula","ventral tegmental area","VTA","substantia nigra","cerebellum"
]
def collect_gene_symbols_from_string(edges: List[dict], focus: str) -> List[str]:
genes = set()
f = focus.upper()
for e in edges or []:
for k in ("preferredName_A","preferredName_B"):
g = e.get(k)
if g and isinstance(g,str) and g.upper() != f:
genes.add(g)
return list(genes)
async def _eupmc_hitcount(q: str) -> int:
# Europe PMC search hitCount (pageSize=0)
url = "https://www.ebi.ac.uk/europepmc/webservices/rest/search"
params = {"query": q, "format": "json", "pageSize": 0}
data = await get_json_cached(url, params, ttl=3600)
try:
return int(data.get("hitCount", 0))
except Exception:
return 0
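# pageSize=0 asks Europe PMC for the count alone; a query such as
# "(hippocampus) AND (HTR2A)" yields {"hitCount": 1234, ...} (value illustrative).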
@app.get("/heuristics/regions_from_string")
async def regions_from_string(
receptor: str = Query(..., description="e.g., HTR2A"),
species: int = 9606,
limit: int = 40,
regions: Optional[str] = Query(None, description="comma-separated region terms (optional)"),
symptom: Optional[str] = Query(None, description="optional phenotype/symptom to weight co-mentions (e.g., apathy)")
):
"""
Heuristic: rank brain regions by STRING neighbors + Europe PMC co-mentions, with synonyms & tiered fallbacks.
Tiers (all unquoted for flexible match):
T1: (region_syns) AND ((receptor_syns) OR neighbors) AND (symptom_syns?) weight 1.0
T2: (region_syns) AND (receptor_syns OR neighbors) weight 0.6
T3: (region_syns) AND (receptor_syns) weight 0.5
T4: (region_syns) AND (symptom_syns) weight 0.3
Final score = log10(weighted_hits+1) * mean_top_STRING_conf
"""
# 1) STRING neighbors
edges = await string_network(receptor, species=species, limit=limit)
neighbors = collect_gene_symbols_from_string(edges, receptor)
# STRING confidences
conf: Dict[str, float] = {}
for e in edges or []:
a, b, score = e.get("preferredName_A"), e.get("preferredName_B"), _safe_float(e.get("score", 0))
if a and a.upper() != receptor.upper():
conf[a] = max(conf.get(a, 0.0), score)
if b and b.upper() != receptor.upper():
conf[b] = max(conf.get(b, 0.0), score)
    mean_conf = sum(conf.values()) / len(conf) if conf else 0.2
# 2) synonyms
receptor_syns = await _mygene_aliases(receptor)
symptom_syns = []
if symptom:
s = await util_synonyms(symptom, kind="phenotype")
symptom_syns = s["synonyms"]
region_list = [r.strip() for r in (regions.split(",") if regions else REGION_TERMS_DEFAULT) if r.strip()]
# Build clauses (unquoted OR lists)
gene_clause = " OR ".join(sorted({receptor} | set(receptor_syns) | set(neighbors[:25])))
results = []
tasks = []
tier_defs = []
for region in region_list:
# region synonyms
rs = await util_synonyms(region, kind="region")
region_syns = rs["synonyms"]
region_clause = " OR ".join(region_syns)
        # tiers: T1 only differs from T2 when a symptom is supplied, so skip
        # it otherwise to avoid counting the same query twice
        tiers = []
        if symptom and symptom_syns:
            t1 = f"({region_clause}) AND ({gene_clause}) AND ({' OR '.join(symptom_syns)})"
            tiers.append(("t1", 1.0, t1))
        t2 = f"({region_clause}) AND ({gene_clause})"
        t3 = f"({region_clause}) AND ({' OR '.join(sorted(set([receptor] + receptor_syns)))})"
        tiers.extend([("t2", 0.6, t2), ("t3", 0.5, t3)])
        if symptom_syns:
            t4 = f"({region_clause}) AND ({' OR '.join(symptom_syns)})"
            tiers.append(("t4", 0.3, t4))
# schedule hitCount calls
tier_defs.append((region, tiers))
for _,_,q in tiers:
tasks.append(_eupmc_hitcount(q))
# gather all counts in-order
counts_all = await asyncio.gather(*tasks)
# fold back into regions
idx = 0
for region, tiers in tier_defs:
weighted = 0.0
tier_counts = {}
for name, weight, _q in tiers:
hc = counts_all[idx]; idx += 1
tier_counts[name] = hc
weighted += weight * hc
score = math.log10(weighted + 1.0) * mean_conf
results.append({"region": region, "tiers": tier_counts, "weighted_hits": int(round(weighted)),
"weighted_score": round(score, 4)})
results.sort(key=lambda x: x["weighted_score"], reverse=True)
return {
"focus": receptor,
"neighbors_considered": neighbors[:25],
"regions_ranked": results,
"notes": "STRING + Europe PMC with synonyms and tiered fallbacks (unquoted)."
}
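# Worked example of the scoring (all numbers hypothetical):
#   tier hits t1=120, t2=300, t3=80, t4=40
#   weighted = 1.0*120 + 0.6*300 + 0.5*80 + 0.3*40 = 352
#   with mean STRING confidence 0.8: score = log10(352 + 1) * 0.8 ≈ 2.04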
# ----------------- MANIFEST + PAGED SECTIONS + DOWNLOAD -----------------
def _job_dir(job_id: str) -> _Path:
d = DATA_DIR / job_id
d.mkdir(parents=True, exist_ok=True)
return d
def _write_gz_jsonl(path: _Path, items: List[dict]):
with gzip.open(path, "wt", encoding="utf-8") as gz:
for it in items:
gz.write(json.dumps(it, ensure_ascii=False) + "\n")
def _read_gz_page(path: _Path, page: int, page_size: int) -> Tuple[int, List[dict]]:
    start = (page - 1) * page_size
    end = start + page_size
    total = 0
    out = []
    with gzip.open(path, "rt", encoding="utf-8") as gz:
        for line in gz:
            if not line.strip():
                continue
            # index by non-blank records so paging and `total` stay consistent
            if start <= total < end:
                out.append(json.loads(line))
            total += 1
    return total, out
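# E.g. page=2, page_size=100 returns records 100..199 (0-based) of the file,
# plus the total record count so callers can compute the number of pages.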
async def _build_mech_job(params: dict) -> dict:
"""
Build nodes/edges/literature/regions; write gz NDJSON + meta.
"""
receptor = params["receptor"]
species = int(params.get("species", 9606))
symptom = params.get("symptom")
string_limit = int(params.get("string_limit", 200))
eupmc_page_size = int(params.get("eupmc_page_size", 100))
eupmc_max_pages = int(params.get("eupmc_max_pages", 3))
job_id = _hash_params(params)
d = _job_dir(job_id)
meta_path = d / "meta.json"
if meta_path.exists():
return json.loads(meta_path.read_text("utf-8"))
# 1) STRING edges + nodes
edges = await string_network(receptor, species=species, limit=string_limit)
edge_items = []
nodes = set([receptor])
for e in edges or []:
a = e.get("preferredName_A"); b = e.get("preferredName_B")
score = _safe_float(e.get("score", 0))
if a and b:
edge_items.append({"a": a, "b": b, "score": score})
nodes.add(a); nodes.add(b)
node_items = [{"symbol": n, "seed": (n.upper()==receptor.upper())} for n in sorted(nodes)]
_write_gz_jsonl(d / "edges.jsonl.gz", edge_items)
_write_gz_jsonl(d / "nodes.jsonl.gz", node_items)
# 2) Europe PMC literature for (receptor AND symptom?) else receptor
lit_items = []
base_q = f"{receptor} AND {symptom}" if symptom else receptor
for page in range(1, eupmc_max_pages+1):
res = await europe_pmc_search(base_q, pageSize=eupmc_page_size, page=page)
hits = res.get("resultList", {}).get("result", []) or []
for h in hits:
lit_items.append({
"id": h.get("id"),
"source": h.get("source"), "title": h.get("title"),
"pubYear": h.get("pubYear"), "authorString": h.get("authorString"),
"journalTitle": h.get("journalTitle"), "doi": h.get("doi")
})
# stop early if last page
if len(hits) < eupmc_page_size:
break
_write_gz_jsonl(d / "literature.jsonl.gz", lit_items)
# 3) Regions heuristic (with symptom)
reg = await regions_from_string(receptor=receptor, species=species, limit=min(100, string_limit), regions=None, symptom=symptom)
    reg_items = list(reg.get("regions_ranked", []))
_write_gz_jsonl(d / "regions.jsonl.gz", reg_items)
meta = {
"job_id": job_id,
"created": datetime.utcnow().isoformat() + "Z",
"params": params,
"counts": {
"nodes": len(node_items),
"edges": len(edge_items),
"literature": len(lit_items),
"regions": len(reg_items)
},
"sections": ["nodes","edges","literature","regions"]
}
meta_path.write_text(json.dumps(meta, ensure_ascii=False, indent=2), encoding="utf-8")
return meta
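# The resulting meta.json looks roughly like (values hypothetical):
#   {"job_id": "3f2c...", "created": "2024-01-01T00:00:00Z",
#    "params": {"receptor": "HTR2A", ...},
#    "counts": {"nodes": 51, "edges": 200, "literature": 300, "regions": 18},
#    "sections": ["nodes", "edges", "literature", "regions"]}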
@app.get("/mechanism_graph_manifest")
async def mechanism_graph_manifest(
receptor: str = Query(...),
species: int = 9606,
symptom: Optional[str] = None,
string_limit: int = 200,
eupmc_page_size: int = 100,
eupmc_max_pages: int = 3
):
"""
Build the full mechanism dataset server-side and return a manifest with job_id + counts.
The actual data is stored as gzipped NDJSON and can be:
- paged via /mechanism_graph/{section}?job_id=...&page=1&page_size=...
- or downloaded as a single gz file via /download/{job_id}/{section}
"""
params = {
"receptor": receptor, "species": species, "symptom": symptom,
"string_limit": string_limit, "eupmc_page_size": eupmc_page_size, "eupmc_max_pages": eupmc_max_pages
}
meta = await _build_mech_job(params)
return meta
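# Sketch (the returned job_id is the sha1 of the request params):
#   curl "$HOST/mechanism_graph_manifest?receptor=HTR2A&symptom=apathy&string_limit=200"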
@app.get("/mechanism_graph/{section}")
async def mechanism_graph_section(
section: Literal["nodes","edges","literature","regions"] = Path(...),
job_id: str = Query(...),
page: int = 1,
page_size: int = 100
):
"""
Return a single page from a section (nodes|edges|literature|regions).
"""
    # look up without _job_dir() so unknown job ids don't create empty directories
    p = DATA_DIR / job_id / f"{section}.jsonl.gz"
    if not p.exists():
        raise HTTPException(status_code=404, detail=f"section {section} not found for job {job_id}")
total, items = _read_gz_page(p, page=page, page_size=page_size)
return {
"job_id": job_id,
"section": section,
"page": page, "page_size": page_size,
"total": total,
"items": items
}
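# Sketch:
#   curl "$HOST/mechanism_graph/edges?job_id=<job_id>&page=1&page_size=200"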
@app.get("/download/{job_id}/{section}")
async def download_section(job_id: str, section: Literal["nodes","edges","literature","regions"]):
"""
Download the full gzipped NDJSON for a section.
"""
    # look up without _job_dir() so unknown job ids don't create empty directories
    p = DATA_DIR / job_id / f"{section}.jsonl.gz"
    if not p.exists():
        raise HTTPException(status_code=404, detail=f"section {section} not found for job {job_id}")
return FileResponse(
path=str(p),
filename=f"{APP_NAME}-{job_id}-{section}.jsonl.gz",
media_type="application/gzip"
)
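# Sketch (-OJ saves under the server-supplied filename from Content-Disposition):
#   curl -OJ "$HOST/download/<job_id>/literature"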