File size: 9,861 Bytes
ae887ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
from fastapi import FastAPI, Query
from fastapi.middleware.cors import CORSMiddleware
import httpx, asyncio, time, os, hashlib, json
from typing import Dict, Any, Tuple, Optional, List
from fastapi.responses import RedirectResponse, JSONResponse


# Service name; reused as the FastAPI title and in the outbound User-Agent header.
APP_NAME = "neuro-mechanism-backend"
# Sent to STRING as its `caller_identity` parameter so our requests show up in STRING's logs.
CALLER_ID = "neuro-mech-backend-demo"

app = FastAPI(title=APP_NAME)

@app.get("/", include_in_schema=False)
def root():
    """Send visitors of the bare root URL to the interactive API docs."""
    docs_url = "/docs"
    return RedirectResponse(url=docs_url)

@app.get("/health", include_in_schema=False)
def health():
    """Liveness probe: report that the process is up and which app is answering.

    Uses APP_NAME instead of repeating the literal so the reported name cannot
    drift from the configured one.
    """
    return {"ok": True, "app": APP_NAME}

@app.get("/endpoints", include_in_schema=False)
def endpoints():
    """Return a human-readable index of example GET URLs, one per public route.

    Keep this list in sync with the routes registered in this module.
    """
    return JSONResponse({
        "GET": [
            "/mechanism_graph?receptor=HTR2A&symptom=apathy",
            "/heuristics/regions_from_string?receptor=HTR2A&limit=40",
            "/lit/eupmc?query=HTR2A%20AND%20apathy&pageSize=5",
            "/lit/pubmed_esearch?term=HTR2A%20AND%20apathy&retmax=10",
            "/string/network?identifiers=HTR2A&species=9606",
            "/gpcrdb/protein?entry=htr2a_human",
            "/uniprot/search?query=HTR2A&size=5",
            "/rxnav/rxcui?name=fluoxetine",
            "/pubchem/compound_by_name?name=fluoxetine",
            "/openfda/ae?drug=fluoxetine&limit=5",
            "/trials/search?q=HTR2A&pageSize=5",
            "/health", "/docs",
        ]
    })

# Fully open CORS so browser front-ends hosted on any origin can call this API.
# NOTE(review): wildcard origins combined with allow_credentials=True means
# credentialed cross-site requests are accepted from any origin (the CORS spec
# forbids a literal "*" with credentials, so the middleware has to reflect the
# caller's Origin instead). Nothing in this module reads cookies or auth
# headers; consider allow_credentials=False — confirm no client relies on it.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# Default headers for every outbound request made by TTLCache.get; identifies this service (name + version) to upstream APIs.
UA = {"User-Agent": f"{APP_NAME}/1.1 (HF Space)"}

# ----------------- tiny in-memory TTL cache for upstream JSON responses -----------------
class TTLCache:
    """Tiny in-memory cache of decoded JSON GET responses with per-call TTLs.

    Entries are keyed on the URL plus the query params (serialized with sorted
    keys). The store never holds more than ``max_items`` entries: when full,
    expired entries are purged first, then the oldest-inserted entries are
    dropped. There is no stampede protection — concurrent misses for the same
    key each perform their own fetch.
    """

    def __init__(self, max_items=512):
        # key -> (expiry on the time.monotonic() clock, decoded JSON payload)
        self.store: Dict[str, Tuple[float, Any]] = {}
        self.max_items = max_items
        # Guards self.store; never held across the network round-trip.
        self._lock = asyncio.Lock()

    def _mk(self, url: str, params: Optional[dict]) -> str:
        """Return a stable cache key for (url, params); dict ordering of params does not matter."""
        key = url + "?" + (json.dumps(params, sort_keys=True) if params else "")
        return hashlib.sha1(key.encode()).hexdigest()

    def _lookup(self, k: str) -> Tuple[bool, Any]:
        """Return ``(True, value)`` for a live entry, else ``(False, None)``. Caller holds the lock."""
        item = self.store.get(k)
        if item is not None and time.monotonic() < item[0]:
            return True, item[1]
        return False, None

    def _put(self, k: str, data: Any, ttl: float) -> None:
        """Store ``data`` under ``k`` for ``ttl`` seconds, keeping len(store) <= max_items. Caller holds the lock."""
        now = time.monotonic()
        # Remove first so a refreshed key is re-inserted as the newest entry and
        # overwriting never forces an unrelated eviction.
        self.store.pop(k, None)
        if len(self.store) >= self.max_items:
            # Cheapest victims first: entries that have already expired.
            expired = [key for key, (expires, _) in self.store.items() if expires <= now]
            for key in expired:
                del self.store[key]
            # Still full: drop the oldest insertions (dicts preserve insertion order).
            while self.store and len(self.store) >= self.max_items:
                self.store.pop(next(iter(self.store)))
        self.store[k] = (now + ttl, data)

    async def get(self, url: str, params: Optional[dict], ttl: float):
        """Return the decoded JSON body of ``GET url?params``, served from cache while fresh.

        Only successful responses are cached. ``raise_for_status`` turns 4xx/5xx
        answers into ``httpx.HTTPStatusError``; transport errors and invalid
        JSON also propagate to the caller.
        """
        k = self._mk(url, params)
        async with self._lock:
            hit, value = self._lookup(k)
        if hit:
            return value
        # Fetch outside the lock so one slow upstream does not block every other request.
        # Redirects are followed; otherwise httpx hands back the 3xx response and
        # raise_for_status() rejects it.
        async with httpx.AsyncClient(headers=UA, timeout=30, follow_redirects=True) as client:
            r = await client.get(url, params=params)
            r.raise_for_status()
            data = r.json()
        async with self._lock:
            self._put(k, data, ttl)
        return data

# Single shared in-process response cache (TTLCache default capacity of 512 entries); every upstream fetch goes through it via get_json_cached.
CACHE = TTLCache()

# ----------------- polite throttling for STRING ------------------
_last_string_call = 0.0  # time.monotonic() at which the most recent STRING request slot was taken
_string_lock: Optional[asyncio.Lock] = None  # created lazily inside the running event loop


async def throttle_string(min_interval: float = 1.05) -> None:
    """Wait until at least ``min_interval`` seconds have passed since the previous STRING call.

    A courtesy limit of roughly one request per second toward STRING. Callers
    are serialized through a lock, so concurrent requests are spaced out one
    after another rather than all computing the same wait and firing together.
    A monotonic clock keeps wall-clock adjustments from skipping or stretching
    the wait.
    """
    global _last_string_call, _string_lock
    if _string_lock is None:
        # Created on first use so it belongs to the loop that is serving requests.
        _string_lock = asyncio.Lock()
    async with _string_lock:
        wait = min_interval - (time.monotonic() - _last_string_call)
        if wait > 0:
            await asyncio.sleep(wait)
        _last_string_call = time.monotonic()

# ----------------- Helpers -----------------
async def get_json_cached(url: str, params: Optional[dict], ttl: float) -> Any:
    """GET ``url`` with query ``params`` and return the decoded JSON through the shared cache.

    Args:
        url: absolute upstream URL.
        params: query parameters, or None when everything is already in the URL.
        ttl: seconds a successful response stays fresh in the cache.
    """
    return await CACHE.get(url, params, ttl)

# ----------------- Upstream proxy endpoints (one per external API) ---------
@app.get("/lit/eupmc")
async def europe_pmc_search(query: str, pageSize: int = 5):
    """Proxy a Europe PMC REST search (JSON format); responses are cached for 10 minutes."""
    return await get_json_cached(
        "https://www.ebi.ac.uk/europepmc/webservices/rest/search",
        {"query": query, "format": "json", "pageSize": pageSize},
        ttl=600,
    )

@app.get("/lit/pubmed_esearch")
async def pubmed_esearch(term: str, retmax: int = 10):
    """Proxy an NCBI E-utilities ESearch over PubMed and return its JSON result (cached 10 minutes)."""
    esearch_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
    query = dict(db="pubmed", term=term, retmode="json", retmax=retmax)
    return await get_json_cached(esearch_url, query, ttl=600)

@app.get("/trials/search")
async def ctgov_v2_studies(q: str, pageSize: int = 5):
    """Search ClinicalTrials.gov (API v2 ``/studies``) by free-text term; cached for 15 minutes."""
    return await get_json_cached(
        "https://clinicaltrials.gov/api/v2/studies",
        {"query.term": q, "pageSize": pageSize},
        ttl=900,
    )

@app.get("/rxnav/rxcui")
async def rxnav_rxcui(name: str):
    """Look up a drug name with RxNav's ``rxcui.json`` service; cached for a day."""
    rxcui_endpoint = "https://rxnav.nlm.nih.gov/REST/rxcui.json"
    return await get_json_cached(rxcui_endpoint, dict(name=name), ttl=24 * 60 * 60)

@app.get("/openfda/ae")
async def openfda_adverse_events(drug: str, limit: int = 5):
    """Fetch openFDA drug adverse-event reports whose ``patient.drug.medicinalproduct`` matches ``drug``; cached 1 hour."""
    # NOTE(review): `drug` is embedded inside a quoted phrase; a value containing `"`
    # will break the openFDA search expression.
    search_expr = 'patient.drug.medicinalproduct:"{}"'.format(drug)
    return await get_json_cached(
        "https://api.fda.gov/drug/event.json",
        {"search": search_expr, "limit": limit},
        ttl=3600,
    )

@app.get("/pubchem/compound_by_name")
async def pubchem_by_name(name: str):
    """Look up a compound on PubChem PUG REST by name; cached for a day.

    The name goes into the URL *path*, so it is percent-encoded as a single
    segment (including "/") — otherwise names containing "/", "?" or "#" would
    be parsed as extra path segments, a query string or a fragment.
    """
    from urllib.parse import quote

    url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{quote(name, safe='')}/JSON"
    return await get_json_cached(url, None, ttl=86400)

@app.get("/uniprot/search")
async def uniprot_search(query: str, size: int = 5):
    """Search UniProtKB through rest.uniprot.org (JSON results); cached for a day."""
    search_params = {"query": query, "format": "json", "size": size}
    return await get_json_cached("https://rest.uniprot.org/uniprotkb/search", search_params, ttl=24 * 60 * 60)

@app.get("/gpcrdb/protein")
async def gpcrdb_protein(entry: str):
    """Fetch a GPCRdb protein record by entry name (e.g. ``htr2a_human``); cached for a day.

    ``entry`` is percent-encoded as a single path segment so user input cannot
    inject "/" separators, a query string or a fragment into the GPCRdb URL.
    """
    from urllib.parse import quote

    url = f"https://gpcrdb.org/services/protein/{quote(entry, safe='')}"
    return await get_json_cached(url, None, ttl=86400)

@app.get("/string/network")
async def string_network(identifiers: str, species: int = 9606, limit: int = 50):
    """Return the STRING network for ``identifiers`` as JSON; cached for an hour.

    Always waits on throttle_string() first — even when the answer ends up
    coming from the cache — so back-to-back calls are spaced about a second apart.
    """
    await throttle_string()
    query = {
        "identifiers": identifiers,
        "species": species,
        "caller_identity": CALLER_ID,
        "limit": limit,
    }
    return await get_json_cached("https://string-db.org/api/json/network", query, ttl=3600)

# ----------------- STRING → region heuristic -----------------
# Default brain-region terms for /heuristics/regions_from_string, used when the caller
# passes no `regions`. Each term is quoted as a phrase in the Europe PMC query.
# Some regions appear both spelled out and abbreviated (e.g. "anterior cingulate cortex" / "ACC").
# NOTE(review): short acronyms such as "ACC" or "VTA" may also match unrelated literature.
REGION_TERMS_DEFAULT = [
    "prefrontal cortex","anterior cingulate cortex","mPFC","ACC","nucleus accumbens","ventral striatum",
    "dorsal striatum","caudate","putamen","amygdala","hippocampus","thalamus","hypothalamus",
    "insula","VTA","substantia nigra","cerebellum"
]

async def eupmc_hitcount(q: str) -> int:
    """Return Europe PMC's ``hitCount`` for query ``q`` (0 if absent); cached for an hour.

    Only the count is used, so no result records are requested (pageSize 0).
    NOTE(review): confirm Europe PMC accepts pageSize=0 rather than clamping it.
    """
    payload = await get_json_cached(
        "https://www.ebi.ac.uk/europepmc/webservices/rest/search",
        {"query": q, "format": "json", "pageSize": 0},
        ttl=3600,
    )
    return int(payload.get("hitCount", 0))

def collect_gene_symbols_from_string(edges: List[dict], focus: str) -> List[str]:
    """Return the distinct gene symbols in STRING ``edges`` other than ``focus``.

    Names are read from each edge's ``preferredName_A`` / ``preferredName_B``;
    missing or empty names are skipped and ``focus`` is excluded
    case-insensitively. Order is first appearance in ``edges``, which is
    deterministic — unlike iterating a set of strings, whose order changes
    between processes under hash randomization.
    """
    focus_key = focus.upper()
    seen: Dict[str, None] = {}  # dict as an insertion-ordered set
    for edge in edges:
        for field in ("preferredName_A", "preferredName_B"):
            name = edge.get(field)
            if name and name.upper() != focus_key:
                seen.setdefault(name, None)
    return list(seen)

@app.get("/heuristics/regions_from_string")
async def regions_from_string(
    receptor: str = Query(..., description="e.g., HTR2A"),
    species: int = 9606,
    limit: int = 40,
    regions: Optional[str] = Query(None, description="comma-separated region terms; default common regions")
):
    """
    Heuristic: rank brain regions by (STRING-weighted) literature co-occurrence.

    1. Fetch the STRING network around ``receptor`` (``limit`` is passed to
       STRING) and record each other gene's highest edge score.
    2. Keep the 25 highest-scoring neighbors and, for every region term, count
       Europe PMC hits for ``("<region>") AND (<receptor> OR <neighbor> OR ...)``.
    3. weighted_score = log10(hits + 1) * mean STRING score over all neighbors
       (0.2 when STRING returned no neighbors).

    The step-3 multiplier is identical for every region, so it scales the
    scores but does not change the ranking, which follows the hit counts.
    """
    import math

    max_neighbors = 25  # caps the size of the Europe PMC OR-clause
    focus = receptor.upper()

    # 1) STRING neighbors, each with its best edge score
    edges = await string_network(receptor, species=species, limit=limit)
    conf: Dict[str, float] = {}
    for e in edges:
        score = float(e.get("score", 0))
        for gene in (e.get("preferredName_A"), e.get("preferredName_B")):
            if gene and gene.upper() != focus:
                conf[gene] = max(conf.get(gene, 0.0), score)
    # Strongest interactors first, so the cap drops the weakest ones instead of an
    # arbitrary subset; the name tiebreak keeps the Europe PMC queries (and their
    # cache keys) identical from run to run.
    neighbors = sorted(conf, key=lambda g: (-conf[g], g))[:max_neighbors]

    region_list = [r.strip() for r in (regions.split(",") if regions else REGION_TERMS_DEFAULT) if r.strip()]

    # 2) Europe PMC hit count per region, fetched concurrently
    gene_clause = " OR ".join([receptor] + neighbors)
    counts = await asyncio.gather(
        *(eupmc_hitcount(f'("{region}") AND ({gene_clause})') for region in region_list)
    )

    # 3) Weighting
    mean_conf = sum(conf.values()) / len(conf) if conf else 0.2
    results = [
        {"region": region, "hits": hc, "weighted_score": round(math.log10(hc + 1.0) * mean_conf, 4)}
        for region, hc in zip(region_list, counts)
    ]
    results.sort(key=lambda x: x["weighted_score"], reverse=True)

    return {
        "focus": receptor,
        "neighbors_considered": neighbors,
        "regions_ranked": results,
        "notes": "Exploratory heuristic using STRING neighbors + Europe PMC co-occurrence."
    }

# ----------------- Aggregator (adds region heuristic) --------
@app.get("/mechanism_graph")
async def mechanism_graph(
    receptor: str = Query(..., description="e.g., HTR2A"),
    species: int = 9606,
    symptom: str = "apathy"
):
    """Aggregate GPCRdb, STRING, Europe PMC and the brain-region heuristic for one receptor.

    The four lookups run concurrently via asyncio.gather, each through the
    shared TTL cache. If any of them raises, the exception propagates and the
    whole request fails.
    """
    # GPCRdb entry names look like "htr2a_human"; accept "HTR2A" or "htr2a_human".
    gpcr_entry = receptor.lower()
    if not gpcr_entry.endswith("_human"):
        gpcr_entry = f"{gpcr_entry}_human"

    # Reuse the single-source endpoint coroutines so URLs, params and TTLs live in one place.
    gpcr = gpcrdb_protein(gpcr_entry)
    # string_network applies throttle_string(), so this STRING request is rate-limited like the rest.
    string_net = string_network(receptor, species=species, limit=50)
    lit = europe_pmc_search(query=f"{receptor} AND {symptom}", pageSize=10)
    # FastAPI's @app.get registers the function and returns it undecorated, so there
    # is no __wrapped__ attribute: call it directly. Every argument is passed
    # explicitly because the Query(...) defaults are only resolved by FastAPI's
    # request handling, not by a plain Python call.
    region_scores = regions_from_string(receptor=receptor, species=species, limit=40, regions=None)

    gpcr_r, string_r, lit_r, regions_r = await asyncio.gather(gpcr, string_net, lit, region_scores)

    return {
        "receptor": receptor,
        "gpcrdb": gpcr_r,
        "string": string_r,
        "literature": lit_r,
        "region_scores": regions_r,
        "notes": "Mechanism aggregator with cache + STRING→region heuristic"
    }