# Scraped file-view header (not part of the program): darkfrostx,
# "Update app.py", commit ae887ef (verified), 9.86 kB.
from fastapi import FastAPI, Query
from fastapi.middleware.cors import CORSMiddleware
import httpx, asyncio, time, os, hashlib, json
from typing import Dict, Any, Tuple, Optional, List
from fastapi.responses import RedirectResponse, JSONResponse
# Service name: used as the FastAPI title and in the outbound User-Agent header.
APP_NAME = "neuro-mechanism-backend"
# Sent to STRING as the `caller_identity` request parameter.
CALLER_ID = "neuro-mech-backend-demo" # shows up in STRING logs
# The ASGI application object that every route below is registered on.
app = FastAPI(title=APP_NAME)
@app.get("/", include_in_schema=False)
def root():
    """Send visitors of the bare root URL to the interactive docs at ``/docs``."""
    docs_redirect = RedirectResponse(url="/docs")
    return docs_redirect
@app.get("/health", include_in_schema=False)
def health():
    """Liveness probe: report that the app is up, together with its name."""
    # Use the shared constant so the reported name cannot drift from the
    # FastAPI title and the outbound User-Agent, which are built from it too.
    return {"ok": True, "app": APP_NAME}
@app.get("/endpoints", include_in_schema=False)
def endpoints():
    """Return ready-to-click example GET URLs for the routes in this module.

    The payload is ``{"GET": [<example path>, ...]}``. The list is maintained
    by hand, so it must be updated whenever a route is added or renamed.
    """
    examples = [
        "/mechanism_graph?receptor=HTR2A&symptom=apathy",
        "/heuristics/regions_from_string?receptor=HTR2A&limit=40",
        "/lit/eupmc?query=HTR2A%20AND%20apathy&pageSize=5",
        "/string/network?identifiers=HTR2A&species=9606",
        "/gpcrdb/protein?entry=htr2a_human",
        "/uniprot/search?query=HTR2A&size=5",
        "/rxnav/rxcui?name=fluoxetine",
        "/pubchem/compound_by_name?name=fluoxetine",
        "/trials/search?q=HTR2A&pageSize=5",
        # These two routes are registered below but were missing from the list.
        "/lit/pubmed_esearch?term=HTR2A%20AND%20apathy&retmax=10",
        "/openfda/ae?drug=fluoxetine&limit=5",
        "/health", "/docs",
    ]
    return JSONResponse({"GET": examples})
# Allow browser clients from any origin, with any method and any header.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive -- confirm that credentialed cross-origin requests are really needed.
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], allow_credentials=True,
allow_methods=["*"], allow_headers=["*"]
)
# Default headers passed to the httpx client that performs upstream requests.
UA = {"User-Agent": f"{APP_NAME}/1.1 (HF Space)"}
# ----------------- NEW: tiny in-memory TTL cache -----------------
class TTLCache:
    """Small in-memory cache for JSON GET responses with per-entry expiry.

    ``store`` maps a SHA-1 key (derived from URL + params) to a tuple
    ``(expires_at, data)`` where ``expires_at`` is a ``time.time()`` timestamp.
    The cache holds at most ``max_items`` entries.
    """

    def __init__(self, max_items=512):
        self.store: Dict[str, Tuple[float, Any]] = {}
        self.max_items = max_items
        self._lock = asyncio.Lock()

    def _mk(self, url: str, params: Optional[dict]) -> str:
        """Build the cache key; params are serialised with sorted keys so that
        dicts with the same content but different ordering share one entry."""
        key = url + "?" + (json.dumps(params, sort_keys=True) if params else "")
        return hashlib.sha1(key.encode()).hexdigest()

    def _put(self, key: str, data: Any, ttl: float) -> None:
        """Store ``data`` under ``key`` for ``ttl`` seconds, enforcing the cap.

        Must be called while holding ``self._lock``.
        """
        now = time.time()
        # Re-insert at the end so dict order reflects how recently an entry
        # was refreshed; the front of the dict is then the eviction candidate.
        self.store.pop(key, None)
        if len(self.store) >= self.max_items:
            # Reclaim dead entries first; they would never be served anyway.
            expired = [k for k, (expires_at, _) in self.store.items() if expires_at <= now]
            for k in expired:
                del self.store[k]
        # ">=" (not ">") so the store never grows past max_items.
        while self.store and len(self.store) >= self.max_items:
            del self.store[next(iter(self.store))]
        self.store[key] = (now + ttl, data)

    async def get(self, url: str, params: Optional[dict], ttl: float):
        """Return the JSON body of ``GET url`` with ``params``, cached for ``ttl`` seconds.

        A fresh cached entry is returned without any network traffic. Otherwise
        the URL is fetched with the module's ``UA`` headers and a 30 second
        timeout, and the decoded JSON is cached and returned.

        Raises:
            httpx.HTTPStatusError: via ``raise_for_status()`` on an error response.
        """
        k = self._mk(url, params)
        async with self._lock:
            item = self.store.get(k)
            if item is not None and time.time() < item[0]:
                return item[1]
        # The network round trip happens outside the lock so that one slow
        # upstream does not block lookups for unrelated keys.
        async with httpx.AsyncClient(headers=UA, timeout=30) as client:
            r = await client.get(url, params=params)
            r.raise_for_status()
            data = r.json()
        async with self._lock:
            self._put(k, data, ttl)
        return data


CACHE = TTLCache()
# ----------------- polite throttling for STRING ------------------
# Wall-clock time (time.time()) of the most recently reserved STRING slot.
_last_string_call = 0.0


async def throttle_string(min_interval: float = 1.05):
    """Be nice to STRING; ~1 req/sec as a courtesy.

    Each caller reserves the next free time slot *before* it starts waiting.
    Callers that arrive concurrently are therefore queued ``min_interval``
    seconds apart, instead of all sleeping for the same duration and then
    hitting STRING at the same moment.

    Args:
        min_interval: Minimum spacing, in seconds, between two reserved slots.
    """
    global _last_string_call
    now = time.time()
    slot = max(now, _last_string_call + min_interval)
    # Publish the reservation before the first await so that the next caller
    # already sees it (there is no suspension point between read and write).
    _last_string_call = slot
    delay = slot - now
    if delay > 0:
        await asyncio.sleep(delay)
# ----------------- Helpers -----------------
async def get_json_cached(url: str, params: Optional[dict], ttl: int):
    """Fetch ``url`` as JSON through the shared module-level ``CACHE``.

    Args:
        url: Absolute URL to GET.
        params: Query parameters, or ``None`` when the URL needs none
            (several callers in this module pass ``None``).
        ttl: Number of seconds the response may be served from the cache.

    Returns:
        Whatever ``CACHE.get`` returns for the request (the decoded JSON body).
    """
    return await CACHE.get(url, params, ttl)
# ----------------- Existing endpoints ---------
@app.get("/lit/eupmc")
async def europe_pmc_search(query: str, pageSize: int = 5):
    """Proxy a Europe PMC search for ``query``; responses are cached for 600 s."""
    search_params = dict(query=query, format="json", pageSize=pageSize)
    return await get_json_cached(
        "https://www.ebi.ac.uk/europepmc/webservices/rest/search",
        search_params,
        ttl=600,
    )
@app.get("/lit/pubmed_esearch")
async def pubmed_esearch(term: str, retmax: int = 10):
    """Proxy an NCBI E-utilities ``esearch`` against PubMed; cached for 600 s."""
    esearch_params = dict(db="pubmed", term=term, retmode="json", retmax=retmax)
    return await get_json_cached(
        "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi",
        esearch_params,
        ttl=600,
    )
@app.get("/trials/search")
async def ctgov_v2_studies(q: str, pageSize: int = 5):
    """Proxy a ClinicalTrials.gov v2 study search for ``q``; cached for 900 s."""
    endpoint = "https://clinicaltrials.gov/api/v2/studies"
    study_query = {"query.term": q, "pageSize": pageSize}
    studies = await get_json_cached(endpoint, study_query, ttl=900)
    return studies
@app.get("/rxnav/rxcui")
async def rxnav_rxcui(name: str):
    """Proxy the RxNav ``rxcui.json`` lookup for a drug name; cached for 86400 s."""
    return await get_json_cached(
        "https://rxnav.nlm.nih.gov/REST/rxcui.json",
        dict(name=name),
        ttl=86400,
    )
@app.get("/openfda/ae")
async def openfda_adverse_events(drug: str, limit: int = 5):
    """Proxy an openFDA drug adverse-event search for ``drug``; cached for 3600 s."""
    # The drug name is matched as a quoted phrase on the medicinalproduct field.
    search_expr = f'patient.drug.medicinalproduct:"{drug}"'
    return await get_json_cached(
        "https://api.fda.gov/drug/event.json",
        {"search": search_expr, "limit": limit},
        ttl=3600,
    )
@app.get("/pubchem/compound_by_name")
async def pubchem_by_name(name: str):
    """Proxy the PubChem PUG REST compound-by-name lookup; cached for 86400 s.

    Args:
        name: Compound name supplied by the client. It is percent-encoded
            before being placed in the URL path.
    """
    from urllib.parse import quote

    # `name` is untrusted input that becomes a path segment. Without escaping,
    # characters such as "/", "?" or "#" would change which upstream resource
    # (or query string) is requested.
    safe_name = quote(name, safe="")
    url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{safe_name}/JSON"
    return await get_json_cached(url, None, ttl=86400)
@app.get("/uniprot/search")
async def uniprot_search(query: str, size: int = 5):
    """Proxy a UniProtKB search for ``query``; responses are cached for 86400 s."""
    endpoint = "https://rest.uniprot.org/uniprotkb/search"
    search_params = dict(query=query, format="json", size=size)
    hits = await get_json_cached(endpoint, search_params, ttl=86400)
    return hits
@app.get("/gpcrdb/protein")
async def gpcrdb_protein(entry: str):
    """Proxy the GPCRdb protein service for ``entry``; cached for 86400 s.

    Args:
        entry: GPCRdb entry name supplied by the client, e.g. ``htr2a_human``.
            It is percent-encoded before being placed in the URL path.
    """
    from urllib.parse import quote

    # `entry` is untrusted input that becomes a path segment; escape it so that
    # "/", "?" or "#" cannot redirect the request to a different resource.
    safe_entry = quote(entry, safe="")
    url = f"https://gpcrdb.org/services/protein/{safe_entry}"
    return await get_json_cached(url, None, ttl=86400)
@app.get("/string/network")
async def string_network(identifiers: str, species: int = 9606, limit: int = 50):
    """Proxy the STRING ``network`` API for ``identifiers``; cached for 3600 s.

    The courtesy throttle is applied only when the response is not already
    cached, i.e. only when a request will actually be sent to STRING.
    """
    url = "https://string-db.org/api/json/network"
    params = {"identifiers": identifiers, "species": species, "caller_identity": CALLER_ID, "limit": limit}
    # Peek at the shared cache: entries are (expires_at, data) tuples keyed by
    # CACHE._mk(url, params). A fresh entry means no upstream call will happen,
    # so there is nothing to throttle and no reason to delay the response.
    cached = CACHE.store.get(CACHE._mk(url, params))
    if cached is None or time.time() >= cached[0]:
        await throttle_string()
    return await get_json_cached(url, params, ttl=3600)
# ----------------- STRING → region heuristic -----------------
# Default brain-region search terms, used by /heuristics/regions_from_string
# when the caller does not pass `regions`. Full names and abbreviations
# (e.g. "mPFC", "ACC") are separate entries, so each is queried on its own.
REGION_TERMS_DEFAULT = [
"prefrontal cortex","anterior cingulate cortex","mPFC","ACC","nucleus accumbens","ventral striatum",
"dorsal striatum","caudate","putamen","amygdala","hippocampus","thalamus","hypothalamus",
"insula","VTA","substantia nigra","cerebellum"
]
async def eupmc_hitcount(q: str) -> int:
    """Return the Europe PMC ``hitCount`` for query ``q`` (0 when the field is absent).

    The request asks for a page size of 0 and is cached for 3600 s.
    """
    payload = await get_json_cached(
        "https://www.ebi.ac.uk/europepmc/webservices/rest/search",
        dict(query=q, format="json", pageSize=0),
        ttl=3600,
    )
    hit_count = payload.get("hitCount", 0)
    return int(hit_count)
def collect_gene_symbols_from_string(edges: List[dict], focus: str) -> List[str]:
    """Return the distinct partner gene symbols found in STRING edge records.

    Args:
        edges: Edge dicts; the symbols are read from the ``preferredName_A``
            and ``preferredName_B`` keys. Missing or empty values are skipped.
        focus: The query gene. It is excluded from the result, compared
            case-insensitively.

    Returns:
        The unique symbols in sorted order. Sorting makes the result
        deterministic; a bare ``list(set(...))`` depends on string hash
        randomisation and can differ from one process to the next.
    """
    focus_upper = focus.upper()
    symbols = {
        symbol
        for edge in edges
        for symbol in (edge.get("preferredName_A"), edge.get("preferredName_B"))
        if symbol and symbol.upper() != focus_upper
    }
    return sorted(symbols)
@app.get("/heuristics/regions_from_string")
async def regions_from_string(
    receptor: str = Query(..., description="e.g., HTR2A"),
    species: int = 9606,
    limit: int = 40,
    regions: Optional[str] = Query(None, description="comma-separated region terms; default common regions")
):
    """
    Heuristic: rank brain regions by (STRING-weighted) literature co-occurrence.

    Steps:
      1. Fetch the STRING network for ``receptor`` and collect its neighbours.
      2. For every region term, ask Europe PMC how many records mention the
         region together with the receptor or one of its top neighbours.
      3. Score each region as ``log10(hits + 1) * mean STRING confidence``
         (a factor of 0.2 is used when STRING returned no neighbours).

    Args:
        receptor: Gene symbol to centre the analysis on.
        species: Taxon id forwarded to STRING.
        limit: ``limit`` value forwarded to STRING.
        regions: Comma-separated region terms; ``REGION_TERMS_DEFAULT`` is
            used when omitted.

    Returns:
        A dict with the focus gene, the neighbours used in the literature
        query, the regions sorted by descending score, and a notes string.
    """
    import math

    focus = receptor.upper()

    # 1) STRING neighbors
    edges = await string_network(receptor, species=species, limit=limit)

    # Highest STRING confidence seen for each neighbour, over all its edges.
    conf: Dict[str, float] = {}
    for e in edges:
        # `or 0` also covers an explicit null score in the payload.
        score = float(e.get("score") or 0)
        for key in ("preferredName_A", "preferredName_B"):
            gene = e.get(key)
            if gene and gene.upper() != focus:
                conf[gene] = max(conf.get(gene, 0.0), score)

    # Rank neighbours by confidence (name breaks ties) BEFORE applying the cap.
    # The cap then keeps the strongest partners rather than an arbitrary
    # subset, and the resulting query string is stable, which keeps the
    # Europe PMC cache keys stable as well.
    neighbors = sorted(
        collect_gene_symbols_from_string(edges, receptor),
        key=lambda g: (-conf.get(g, 0.0), g),
    )
    top_neighbors = neighbors[:25]  # cap size

    region_list = [r.strip() for r in (regions.split(",") if regions else REGION_TERMS_DEFAULT) if r.strip()]

    # 2) Europe PMC hitCount per region (one query per region, run concurrently)
    gene_clause = " OR ".join([receptor] + top_neighbors)
    counts = await asyncio.gather(
        *(eupmc_hitcount(f'("{region}") AND ({gene_clause})') for region in region_list)
    )

    # 3) Weighting
    weight = sum(conf.values()) / len(conf) if conf else 0.2
    results = [
        {"region": region, "hits": hc, "weighted_score": round(math.log10(hc + 1.0) * weight, 4)}
        for region, hc in zip(region_list, counts)
    ]
    results.sort(key=lambda x: x["weighted_score"], reverse=True)

    return {
        "focus": receptor,
        "neighbors_considered": top_neighbors,
        "regions_ranked": results,
        "notes": "Exploratory heuristic using STRING neighbors + Europe PMC co-occurrence."
    }
# ----------------- Aggregator (adds region heuristic) --------
@app.get("/mechanism_graph")
async def mechanism_graph(
    receptor: str = Query(..., description="e.g., HTR2A"),
    species: int = 9606,
    symptom: str = "apathy"
):
    """Aggregate GPCRdb, STRING, Europe PMC and the region heuristic for one receptor.

    The four lookups run concurrently. Each one goes through the corresponding
    endpoint coroutine in this module, so caching, TTLs and the STRING
    courtesy throttle behave exactly as they do for the individual routes.

    Args:
        receptor: Gene symbol, e.g. ``HTR2A``. ``_human`` is appended to the
            lower-cased value to form the GPCRdb entry unless already present.
        species: Taxon id forwarded to STRING.
        symptom: Term AND-ed with the receptor in the literature search.

    Returns:
        A dict with keys ``receptor``, ``gpcrdb``, ``string``, ``literature``,
        ``region_scores`` and ``notes``.
    """
    gpcr_entry = receptor.lower()
    if not gpcr_entry.endswith("_human"):
        gpcr_entry = f"{gpcr_entry}_human"

    # FastAPI's route decorators return the original function unchanged, so the
    # endpoint coroutines are called directly. They have no `__wrapped__`
    # attribute; accessing it raised AttributeError on every request.
    gpcr_r, string_r, lit_r, regions_r = await asyncio.gather(
        gpcrdb_protein(gpcr_entry),
        string_network(receptor, species=species, limit=50),
        europe_pmc_search(f"{receptor} AND {symptom}", pageSize=10),
        regions_from_string(receptor=receptor, species=species, limit=40, regions=None),
    )
    return {
        "receptor": receptor,
        "gpcrdb": gpcr_r,
        "string": string_r,
        "literature": lit_r,
        "region_scores": regions_r,
        "notes": "Mechanism aggregator with cache + STRING→region heuristic"
    }