darkfrostx commited on
Commit
c580fa4
·
verified ·
1 Parent(s): 37671be

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -220
app.py CHANGED
@@ -1,67 +1,16 @@
1
  from fastapi import FastAPI, Query, Path, HTTPException
2
  from fastapi.middleware.cors import CORSMiddleware
3
- from fastapi.responses import RedirectResponse, JSONResponse, StreamingResponse, FileResponse
4
- import httpx, asyncio, time, os, hashlib, json, io, gzip, math
5
  from typing import Dict, Any, Tuple, Optional, List
6
 
7
- # === Block 1: entity resolver (drug vs gene) + safe helpers ===
8
-
9
- KNOWN_DRUG_TARGETS = {
10
- "bupropion": ["SLC6A3","SLC6A2","CHRNA4","CHRNB2","CHRNA3","CHRNB4"],
11
- "bupropion hcl": ["SLC6A3","SLC6A2","CHRNA4","CHRNB2","CHRNA3","CHRNB4"],
12
- "atomoxetine": ["SLC6A2"],
13
- "reboxetine": ["SLC6A2"],
14
- "methylphenidate": ["SLC6A3","SLC6A2"],
15
- "dexmethylphenidate": ["SLC6A3","SLC6A2"],
16
- "lisdexamfetamine": ["SLC6A3","SLC6A2"],
17
- "dextroamphetamine": ["SLC6A3","SLC6A2"],
18
- }
19
-
20
- def _norm(s: str) -> str:
21
- return (s or "").strip().lower()
22
-
23
- async def resolve_mode(receptor: str) -> dict:
24
- """
25
- Decide if 'receptor' is a gene (via MyGene) or a drug (fallback map).
26
- Returns {"mode":"gene"|"drug","canonical":..., "genes":[...]}
27
- """
28
- rq = (receptor or "").strip()
29
- if not rq:
30
- return {"mode":"gene","canonical":rq,"genes":[]}
31
- # Try exact gene symbol via MyGene
32
- try:
33
- async with httpx.AsyncClient(headers=UA, timeout=15) as client:
34
- r = await client.get("https://mygene.info/v3/query",
35
- params={"q": f"symbol:{rq}", "fields": "symbol", "size": 1})
36
- if r.status_code == 200:
37
- js = r.json() or {}
38
- hits = js.get("hits") or []
39
- if hits and hits[0].get("symbol"):
40
- sym = hits[0]["symbol"]
41
- return {"mode":"gene","canonical":sym,"genes":[sym]}
42
- except Exception:
43
- pass
44
- # Not a sure gene: treat as drug
45
- dn = _norm(rq)
46
- genes = KNOWN_DRUG_TARGETS.get(dn, [])
47
- return {"mode":"drug","canonical":rq,"genes":genes}
48
-
49
- # REPLACES the old version
50
- def collect_gene_symbols_from_string(edges: List[dict], focus: str) -> List[str]:
51
- return collect_gene_symbols_from_string_safe(edges, focus)
52
-
53
- """
54
- STRING can fail or return dicts. Be defensive and only read from lists of dicts.
55
- """
56
- genes = set()
57
-
58
-
59
  APP_NAME = "neuro-mechanism-backend"
60
- CALLER_ID = "neuro-mech-backend-demo" # shows in STRING logs / rate fairness
61
  UA = {"User-Agent": f"{APP_NAME}/1.2 (HF Space)"}
62
 
 
63
  app = FastAPI(title=APP_NAME)
64
-
65
  app.add_middleware(
66
  CORSMiddleware,
67
  allow_origins=["*"], allow_credentials=True,
@@ -96,7 +45,7 @@ def endpoints():
96
  ]
97
  })
98
 
99
- # ----------------- tiny in-memory TTL cache -----------------
100
  class TTLCache:
101
  def __init__(self, max_items=512):
102
  self.store: Dict[str, Tuple[float, Any]] = {}
@@ -125,7 +74,7 @@ class TTLCache:
125
 
126
  CACHE = TTLCache()
127
 
128
- # --------------- polite throttling for STRING ----------------
129
  _last_string_call = 0.0
130
  async def throttle_string():
131
  """Be nice to STRING; ~1 req/sec is a good courtesy."""
@@ -136,7 +85,7 @@ async def throttle_string():
136
  await asyncio.sleep(wait)
137
  _last_string_call = time.time()
138
 
139
- # ----------------- Helpers -----------------
140
  async def get_json_cached(url: str, params: Optional[dict], ttl: int):
141
  try:
142
  return await CACHE.get(url, params, ttl)
@@ -154,9 +103,10 @@ def gz_json_bytes(obj: Any) -> bytes:
154
  gz.write(b)
155
  return bio.getvalue()
156
 
157
- # ----------------- External API wrappers -----------------
158
  @app.get("/lit/eupmc")
159
  async def europe_pmc_search(query: str, pageSize: int = 5):
 
160
  url = "https://www.ebi.ac.uk/europepmc/webservices/rest/search"
161
  params = {"query": query, "format": "json", "pageSize": pageSize}
162
  return await get_json_cached(url, params, ttl=600)
@@ -204,12 +154,52 @@ async def gpcrdb_protein(entry: str):
204
  @app.get("/string/network")
205
  async def string_network(identifiers: str, species: int = 9606, limit: int = 50):
206
  await throttle_string()
 
207
  url = "https://string-db.org/api/json/network"
208
  params = {"identifiers": identifiers, "species": species, "caller_identity": CALLER_ID, "limit": limit}
209
  return await get_json_cached(url, params, ttl=3600)
210
 
211
- # ----------------- Synonyms (regions/genes/phenotypes) --------------
212
- # Simple built-in expansions + OLS/MyGene lookups.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213
  REGION_SEED_SYNONYMS = {
214
  "prefrontal cortex": ["PFC","mPFC","vmPFC","dlPFC","dorsolateral prefrontal cortex","ventromedial prefrontal cortex"],
215
  "anterior cingulate cortex": ["ACC","dACC","pgACC","sgACC","subgenual cingulate"],
@@ -225,7 +215,6 @@ REGION_SEED_SYNONYMS = {
225
  }
226
 
227
  async def ols4_synonyms(term: str, ontology: Optional[str] = None) -> List[str]:
228
- # OLS4 generic search (best-effort parse)
229
  url = "https://www.ebi.ac.uk/ols4/api/search"
230
  params = {"q": term, "rows": 20}
231
  if ontology:
@@ -233,7 +222,7 @@ async def ols4_synonyms(term: str, ontology: Optional[str] = None) -> List[str]:
233
  data = await get_json_cached(url, params, ttl=86400)
234
  syns = []
235
  try:
236
- docs = data.get("response", {}).get("docs", []) or data.get("response", {}).get("docs", [])
237
  for d in docs:
238
  if "synonym" in d:
239
  syns.extend(d.get("synonym", []))
@@ -241,18 +230,14 @@ async def ols4_synonyms(term: str, ontology: Optional[str] = None) -> List[str]:
241
  syns.append(d["label"])
242
  except Exception:
243
  pass
244
- # Dedup & lowercase normalize
245
- out = []
246
- seen = set()
247
  for s in syns:
248
  s2 = s.strip()
249
- if s2.lower() not in seen:
250
- out.append(s2)
251
- seen.add(s2.lower())
252
  return out[:50]
253
 
254
  async def mygene_synonyms(symbol: str) -> List[str]:
255
- # MyGene.info gene synonyms/aliases
256
  url = "https://mygene.info/v3/query"
257
  params = {"q": symbol, "fields": "symbol,name,alias,other_names", "size": 5}
258
  data = await get_json_cached(url, params, ttl=86400)
@@ -265,7 +250,6 @@ async def mygene_synonyms(symbol: str) -> List[str]:
265
  if k in hit and isinstance(hit[k], list): syns.extend(hit[k])
266
  except Exception:
267
  pass
268
- # unique
269
  out, seen = [], set()
270
  for s in syns:
271
  s2 = str(s).strip()
@@ -284,11 +268,10 @@ async def util_synonyms(term: str, kind: str = Query("region", enum=["region","g
284
  mg = await mygene_synonyms(term_norm)
285
  return {"term": term_norm, "kind": kind, "synonyms": sorted(set([term_norm] + mg))}
286
  else:
287
- # phenotype via OLS (HPO)
288
  ols = await ols4_synonyms(term_norm, ontology="hp")
289
  return {"term": term_norm, "kind": kind, "synonyms": sorted(set([term_norm] + ols))}
290
 
291
- # ----------------- Regions heuristic (improved) -----------------
292
  REGION_TERMS_DEFAULT = [
293
  "prefrontal cortex","anterior cingulate cortex","nucleus accumbens","ventral striatum",
294
  "dorsal striatum","caudate","putamen","amygdala","hippocampus","thalamus","hypothalamus",
@@ -296,6 +279,7 @@ REGION_TERMS_DEFAULT = [
296
  ]
297
 
298
  async def eupmc_hitcount(q: str) -> int:
 
299
  url = "https://www.ebi.ac.uk/europepmc/webservices/rest/search"
300
  params = {"query": q, "format": "json", "pageSize": 0}
301
  data = await get_json_cached(url, params, ttl=1800)
@@ -304,12 +288,12 @@ async def eupmc_hitcount(q: str) -> int:
304
  except Exception:
305
  return 0
306
 
307
- # safe parser for STRING results
308
- def collect_gene_symbols_from_string(edges: list, focus: str) -> list:
309
  genes = set()
310
- f = (focus or "").upper()
311
  if not isinstance(edges, list):
312
  return []
 
313
  for e in edges:
314
  if not isinstance(e, dict):
315
  continue
@@ -319,7 +303,6 @@ def collect_gene_symbols_from_string(edges: list, focus: str) -> list:
319
  genes.add(g)
320
  return list(genes)
321
 
322
-
323
  @app.get("/heuristics/regions_from_string")
324
  async def regions_from_string(
325
  receptor: str = Query(..., description="gene or drug"),
@@ -330,107 +313,75 @@ async def regions_from_string(
330
  symptom: Optional[str] = None
331
  ):
332
  """
333
- Robust region scoring that works for **genes** and **drugs**.
334
- - gene: (region_syns) AND (receptor OR STRING neighbors OR gene_syns)
335
- - drug: (region_syns) AND (drug)
336
- Falls back tier-by-tier and never 500s.
337
  """
338
  mode = await resolve_mode(receptor)
339
 
340
- # --- candidate regions ---
341
- REGION_TERMS_DEFAULT = [
342
- "prefrontal cortex","anterior cingulate cortex","nucleus accumbens","ventral striatum",
343
- "dorsal striatum","caudate","putamen","amygdala","hippocampus","thalamus","hypothalamus",
344
- "insula","ventral tegmental area","substantia nigra","cerebellum"
345
- ]
346
  region_list = [r.strip() for r in (regions.split(",") if regions else REGION_TERMS_DEFAULT) if r.strip()]
347
 
348
- # --- collect region synonyms ---
349
  region_syns_map = {}
350
  if use_synonyms:
351
- syn_results = []
352
  for rgn in region_list:
353
  try:
354
  syn = await util_synonyms(rgn, "region")
355
- syn_results.append((rgn, syn.get("synonyms", [])[:10] or [rgn]))
356
  except Exception:
357
- syn_results.append((rgn, [rgn]))
358
- region_syns_map = dict(syn_results)
359
  else:
360
  region_syns_map = {rgn: [rgn] for rgn in region_list}
361
 
362
- # --- build RHS terms depending on mode ---
363
- rhs_terms = []
364
- neighbors = []
365
- gene_syns = []
366
 
367
  if mode["mode"] == "gene":
368
- # STRING neighbors (defensive)
369
  try:
370
  sdata = await string_network(receptor, species=species, limit=limit)
371
  except Exception:
372
  sdata = []
373
- neighbors = collect_gene_symbols_from_string_safe(sdata, receptor)
374
-
375
- # limited gene synonyms via MyGene
376
  try:
377
  gs = await mygene_synonyms(receptor)
378
  gene_syns = gs[:20]
379
  except Exception:
380
  gene_syns = []
381
-
382
  rhs_terms = list({t for t in [receptor] + neighbors[:25] + gene_syns[:25] if t})
383
  else:
384
- # drug path — no STRING
385
  rhs_terms = [mode["canonical"]]
386
 
387
- # --- query Europe PMC (hitCount) with tiered fallbacks ---
388
- # (Europe PMC REST 'search' returns a 'hitCount' for fast counts.)
389
- # Docs: europepmc.org/RestfulWebService
390
  results = []
391
  symptom_clause = f" AND ({symptom})" if symptom else ""
392
 
393
- async def hitcount(q: str) -> int:
394
- try:
395
- url = "https://www.ebi.ac.uk/europepmc/webservices/rest/search"
396
- params = {"query": q, "format": "json", "pageSize": 0}
397
- data = await get_json_cached(url, params, ttl=1800)
398
- return int(data.get("hitCount", 0))
399
- except Exception:
400
- return 0
401
-
402
  for region in region_list:
403
  syns = region_syns_map.get(region, [region])
404
  lhs = " OR ".join(syns)
405
- # Tier 1
406
  if mode["mode"] == "gene":
407
  rhs = " OR ".join(rhs_terms) if rhs_terms else receptor
408
  q1 = f"({lhs}) AND ({rhs}){symptom_clause}"
409
  else:
410
- # drug
411
  q1 = f"({lhs}) AND ({mode['canonical']}){symptom_clause}"
412
 
413
- h1 = await hitcount(q1)
414
  hits, tier = h1, "T1"
415
 
416
  if h1 == 0 and mode["mode"] == "gene":
417
- # Tier 2: region_syns AND receptor
418
  q2 = f"({lhs}) AND ({receptor}){symptom_clause}"
419
- h2 = await hitcount(q2)
420
  hits, tier = h2, "T2"
421
  if h2 == 0:
422
- # Tier 3: label AND receptor
423
  q3 = f"({region}) AND ({receptor}){symptom_clause}"
424
- h3 = await hitcount(q3)
425
  hits, tier = h3, "T3"
426
 
427
  score = math.log10(hits + 1.0)
428
- results.append({
429
- "region": region,
430
- "hits": hits,
431
- "tier": tier,
432
- "weighted_score": round(score, 4)
433
- })
434
 
435
  results.sort(key=lambda x: x["weighted_score"], reverse=True)
436
  return {
@@ -440,80 +391,7 @@ async def regions_from_string(
440
  "notes": f"mode={mode['mode']} genes={mode['genes']}"
441
  }
442
 
443
- """
444
- Rank brain regions by co-mention with (receptor OR STRING neighbors OR synonyms), with fallbacks.
445
- Tiered search:
446
- T1: (region_syns) AND (receptor OR neighbors OR gene_syns)
447
- T2: (region_syns) AND (receptor)
448
- T3: (region) AND (receptor)
449
- Unquoted broad matches are used to avoid exact-phrase misses.
450
- """
451
- # 1) STRING neighbors
452
- edges = await string_network(receptor, species=species, limit=limit)
453
- neighbors = collect_gene_symbols_from_string(edges, receptor)
454
-
455
- # 2) synonyms
456
- region_list = [r.strip() for r in (regions.split(",") if regions else REGION_TERMS_DEFAULT) if r.strip()]
457
- region_syns_map: Dict[str, List[str]] = {}
458
- if use_synonyms:
459
- syn_tasks = [util_synonyms(r, "region") for r in region_list]
460
- # run as local function calls (not HTTP)
461
- syn_results = await asyncio.gather(*[t if asyncio.iscoroutine(t) else asyncio.create_task(t) for t in syn_tasks])
462
- for r, syn in zip(region_list, syn_results):
463
- region_syns_map[r] = syn.get("synonyms", [])[:10] or [r]
464
- # gene synonyms for top neighbors (cap 20)
465
- gene_syns: List[str] = []
466
- for g in neighbors[:20]:
467
- gs = await util_synonyms(g, "gene")
468
- gene_syns.extend(gs.get("synonyms", [])[:5])
469
- gene_syns = list({s for s in gene_syns if s})
470
- else:
471
- for r in region_list:
472
- region_syns_map[r] = [r]
473
- gene_syns = []
474
-
475
- # 3) Europe PMC hits per region, tiered
476
- results = []
477
- # build RHS (receptor OR neighbors OR gene_syns)
478
- rhs_terms = [receptor] + neighbors[:25] + gene_syns[:25]
479
- rhs = " OR ".join({t for t in rhs_terms if t})
480
-
481
- for region in region_list:
482
- syns = region_syns_map.get(region, [region])
483
- lhs = " OR ".join(syns)
484
- symptom_clause = f" AND ({symptom})" if symptom else ""
485
-
486
- # T1
487
- q1 = f"({lhs}) AND ({rhs}){symptom_clause}"
488
- hc1 = await eupmc_hitcount(q1)
489
- score = math.log10(hc1 + 1.0)
490
- if hc1 == 0:
491
- # T2
492
- q2 = f"({lhs}) AND ({receptor}){symptom_clause}"
493
- hc2 = await eupmc_hitcount(q2)
494
- score = math.log10(hc2 + 1.0)
495
- if hc2 == 0:
496
- # T3
497
- q3 = f"({region}) AND ({receptor}){symptom_clause}"
498
- hc3 = await eupmc_hitcount(q3)
499
- score = math.log10(hc3 + 1.0)
500
- results.append({"region": region, "hits": hc3, "tier": "T3", "weighted_score": round(score, 4)})
501
- else:
502
- results.append({"region": region, "hits": hc2, "tier": "T2", "weighted_score": round(score, 4)})
503
- else:
504
- results.append({"region": region, "hits": hc1, "tier": "T1", "weighted_score": round(score, 4)})
505
-
506
- results.sort(key=lambda x: x["weighted_score"], reverse=True)
507
- return {
508
- "focus": receptor,
509
- "neighbors_considered": neighbors[:25],
510
- "regions_ranked": results,
511
- "notes": "Heuristic uses STRING neighbors + Europe PMC co-mentions with synonyms and fallbacks."
512
- }
513
-
514
- # ----------------- Manifest / Section / Download -----------------
515
-
516
- # ephemeral in-memory store of assembled sections (by job_id)
517
  JOBS: Dict[str, Dict[str, Any]] = {}
518
 
519
  @app.get("/mechanism_graph_manifest")
@@ -524,24 +402,18 @@ async def mechanism_graph_manifest(
524
  string_limit: int = 50,
525
  lit_page_size: int = 10
526
  ):
527
- """
528
- Returns a job_id and the list of available sections with approximate sizes.
529
- """
530
  jid = job_key(receptor, symptom)
531
 
532
- # Pre-compute lightweight counts; store minimal context for later sections
533
- # STRING count
534
  sdata = await string_network(receptor, species=species, limit=string_limit)
535
  s_count = len(sdata) if isinstance(sdata, list) else 0
536
 
537
- # Literature hitCount
538
  ldata = await europe_pmc_search(f"{receptor} AND {symptom}", pageSize=0)
539
  try:
540
  lit_hits = int(ldata.get("hitCount", 0))
541
  except Exception:
542
  lit_hits = 0
543
 
544
- # Regions heuristic preview (no synonyms parameter here; section can recalc)
545
  rdata = await regions_from_string(receptor=receptor, species=species, limit=40, regions=None, use_synonyms=True, symptom=symptom)
546
  r_count = len(rdata.get("regions_ranked", [])) if isinstance(rdata, dict) else 0
547
 
@@ -551,7 +423,6 @@ async def mechanism_graph_manifest(
551
  "receptor": receptor, "symptom": symptom,
552
  "counts": {"string_edges": s_count, "literature_hits": lit_hits, "regions": r_count}
553
  }
554
- # other sections are created lazily below
555
  }
556
 
557
  sections = [
@@ -560,7 +431,6 @@ async def mechanism_graph_manifest(
560
  {"name": "literature", "approx_size": f"{lit_hits} hits (pageSize={lit_page_size})"},
561
  {"name": "regions", "approx_size": f"{r_count} entries"}
562
  ]
563
-
564
  return {"job_id": jid, "sections": sections}
565
 
566
  @app.get("/mechanism_graph/{section}")
@@ -573,10 +443,7 @@ async def mechanism_graph_section(
573
  lit_page_size: int = 10,
574
  job_id: Optional[str] = Query(None, description="optional; use manifest if you want stable ids")
575
  ):
576
- """
577
- Returns one section. If job_id is missing or unknown, builds on the fly.
578
- """
579
- # pull context from job if available
580
  ctx = None
581
  if job_id and job_id in JOBS:
582
  ctx = JOBS[job_id].get("_meta", {})
@@ -592,7 +459,6 @@ async def mechanism_graph_section(
592
  jid = job_key(receptor, symptom or "")
593
  JOBS.setdefault(jid, {"_meta": {"receptor": receptor, "symptom": symptom or "", "species": species}})
594
  job_id = jid
595
- # ensure overview exists
596
  if "overview" not in JOBS[job_id]:
597
  sdata = await string_network(receptor, species=species, limit=string_limit)
598
  s_count = len(sdata) if isinstance(sdata, list) else 0
@@ -623,13 +489,11 @@ async def mechanism_graph_section(
623
 
624
  @app.get("/download/{job_id}/{section}")
625
  async def download_section(job_id: str, section: str):
626
- """
627
- Gzipped JSON download of a section; if section not built yet, tries to return what's there.
628
- """
629
  data = JOBS.get(job_id, {}).get(section) or JOBS.get(job_id, {}).get("_meta")
630
  if not data:
631
  raise HTTPException(status_code=404, detail="job/section not found")
632
  gz = gz_json_bytes({"job_id": job_id, "section": section, "data": data})
633
  return StreamingResponse(io.BytesIO(gz),
634
  media_type="application/gzip",
635
- headers={"Content-Disposition": f'attachment; filename="{job_id}_{section}.json.gz"'})
 
1
  from fastapi import FastAPI, Query, Path, HTTPException
2
  from fastapi.middleware.cors import CORSMiddleware
3
+ from fastapi.responses import RedirectResponse, JSONResponse, StreamingResponse
4
+ import httpx, asyncio, time, hashlib, json, io, gzip, math
5
  from typing import Dict, Any, Tuple, Optional, List
6
 
7
+ # ------------------ App constants ------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  APP_NAME = "neuro-mechanism-backend"
9
+ CALLER_ID = "neuro-mech-backend-demo" # polite ID for STRING
10
  UA = {"User-Agent": f"{APP_NAME}/1.2 (HF Space)"}
11
 
12
+ # ------------------ FastAPI app ------------------
13
  app = FastAPI(title=APP_NAME)
 
14
  app.add_middleware(
15
  CORSMiddleware,
16
  allow_origins=["*"], allow_credentials=True,
 
45
  ]
46
  })
47
 
48
+ # ------------------ tiny in-memory TTL cache ------------------
49
  class TTLCache:
50
  def __init__(self, max_items=512):
51
  self.store: Dict[str, Tuple[float, Any]] = {}
 
74
 
75
  CACHE = TTLCache()
76
 
77
+ # ------------------ polite throttling for STRING ------------------
78
  _last_string_call = 0.0
79
  async def throttle_string():
80
  """Be nice to STRING; ~1 req/sec is a good courtesy."""
 
85
  await asyncio.sleep(wait)
86
  _last_string_call = time.time()
87
 
88
+ # ------------------ small helpers ------------------
89
  async def get_json_cached(url: str, params: Optional[dict], ttl: int):
90
  try:
91
  return await CACHE.get(url, params, ttl)
 
103
  gz.write(b)
104
  return bio.getvalue()
105
 
106
+ # ------------------ External API wrappers ------------------
107
  @app.get("/lit/eupmc")
108
  async def europe_pmc_search(query: str, pageSize: int = 5):
109
+ # Europe PMC returns a 'hitCount' field in the response for quick counts.
110
  url = "https://www.ebi.ac.uk/europepmc/webservices/rest/search"
111
  params = {"query": query, "format": "json", "pageSize": pageSize}
112
  return await get_json_cached(url, params, ttl=600)
 
154
  @app.get("/string/network")
155
  async def string_network(identifiers: str, species: int = 9606, limit: int = 50):
156
  await throttle_string()
157
+ # STRING API supports a 'caller_identity' parameter – good practice to include it.
158
  url = "https://string-db.org/api/json/network"
159
  params = {"identifiers": identifiers, "species": species, "caller_identity": CALLER_ID, "limit": limit}
160
  return await get_json_cached(url, params, ttl=3600)
161
 
162
+ # ------------------ Entity resolver (drug vs gene) ------------------
163
+ KNOWN_DRUG_TARGETS = {
164
+ "bupropion": ["SLC6A3","SLC6A2","CHRNA4","CHRNB2","CHRNA3","CHRNB4"],
165
+ "bupropion hcl": ["SLC6A3","SLC6A2","CHRNA4","CHRNB2","CHRNA3","CHRNB4"],
166
+ "atomoxetine": ["SLC6A2"],
167
+ "reboxetine": ["SLC6A2"],
168
+ "methylphenidate": ["SLC6A3","SLC6A2"],
169
+ "dexmethylphenidate": ["SLC6A3","SLC6A2"],
170
+ "lisdexamfetamine": ["SLC6A3","SLC6A2"],
171
+ "dextroamphetamine": ["SLC6A3","SLC6A2"],
172
+ }
173
+ def _norm(s: str) -> str:
174
+ return (s or "").strip().lower()
175
+
176
+ async def resolve_mode(receptor: str) -> dict:
177
+ """
178
+ Decide if 'receptor' is a gene (via MyGene) or a drug (fallback map).
179
+ Returns {"mode":"gene"|"drug","canonical":..., "genes":[...]}
180
+ """
181
+ rq = (receptor or "").strip()
182
+ if not rq:
183
+ return {"mode":"gene","canonical":rq,"genes":[]}
184
+ # Try exact gene symbol first (MyGene)
185
+ try:
186
+ async with httpx.AsyncClient(headers=UA, timeout=15) as client:
187
+ r = await client.get("https://mygene.info/v3/query",
188
+ params={"q": f"symbol:{rq}", "fields": "symbol", "size": 1})
189
+ if r.status_code == 200:
190
+ js = r.json() or {}
191
+ hits = js.get("hits") or []
192
+ if hits and hits[0].get("symbol"):
193
+ sym = hits[0]["symbol"]
194
+ return {"mode":"gene","canonical":sym,"genes":[sym]}
195
+ except Exception:
196
+ pass
197
+ # Not a sure gene → treat as drug
198
+ dn = _norm(rq)
199
+ genes = KNOWN_DRUG_TARGETS.get(dn, [])
200
+ return {"mode":"drug","canonical":rq,"genes":genes}
201
+
202
+ # ------------------ Synonyms (regions/genes/phenotypes) ------------------
203
  REGION_SEED_SYNONYMS = {
204
  "prefrontal cortex": ["PFC","mPFC","vmPFC","dlPFC","dorsolateral prefrontal cortex","ventromedial prefrontal cortex"],
205
  "anterior cingulate cortex": ["ACC","dACC","pgACC","sgACC","subgenual cingulate"],
 
215
  }
216
 
217
  async def ols4_synonyms(term: str, ontology: Optional[str] = None) -> List[str]:
 
218
  url = "https://www.ebi.ac.uk/ols4/api/search"
219
  params = {"q": term, "rows": 20}
220
  if ontology:
 
222
  data = await get_json_cached(url, params, ttl=86400)
223
  syns = []
224
  try:
225
+ docs = data.get("response", {}).get("docs", []) or []
226
  for d in docs:
227
  if "synonym" in d:
228
  syns.extend(d.get("synonym", []))
 
230
  syns.append(d["label"])
231
  except Exception:
232
  pass
233
+ out, seen = [], set()
 
 
234
  for s in syns:
235
  s2 = s.strip()
236
+ if s2 and s2.lower() not in seen:
237
+ out.append(s2); seen.add(s2.lower())
 
238
  return out[:50]
239
 
240
  async def mygene_synonyms(symbol: str) -> List[str]:
 
241
  url = "https://mygene.info/v3/query"
242
  params = {"q": symbol, "fields": "symbol,name,alias,other_names", "size": 5}
243
  data = await get_json_cached(url, params, ttl=86400)
 
250
  if k in hit and isinstance(hit[k], list): syns.extend(hit[k])
251
  except Exception:
252
  pass
 
253
  out, seen = [], set()
254
  for s in syns:
255
  s2 = str(s).strip()
 
268
  mg = await mygene_synonyms(term_norm)
269
  return {"term": term_norm, "kind": kind, "synonyms": sorted(set([term_norm] + mg))}
270
  else:
 
271
  ols = await ols4_synonyms(term_norm, ontology="hp")
272
  return {"term": term_norm, "kind": kind, "synonyms": sorted(set([term_norm] + ols))}
273
 
274
+ # ------------------ Regions heuristic ------------------
275
  REGION_TERMS_DEFAULT = [
276
  "prefrontal cortex","anterior cingulate cortex","nucleus accumbens","ventral striatum",
277
  "dorsal striatum","caudate","putamen","amygdala","hippocampus","thalamus","hypothalamus",
 
279
  ]
280
 
281
  async def eupmc_hitcount(q: str) -> int:
282
+ # Europe PMC responses include <hitCount> / "hitCount" for fast counts.
283
  url = "https://www.ebi.ac.uk/europepmc/webservices/rest/search"
284
  params = {"query": q, "format": "json", "pageSize": 0}
285
  data = await get_json_cached(url, params, ttl=1800)
 
288
  except Exception:
289
  return 0
290
 
291
+ def collect_gene_symbols_from_string(edges: Any, focus: str) -> List[str]:
292
+ """Defensive parse of STRING results (may not always be a list)."""
293
  genes = set()
 
294
  if not isinstance(edges, list):
295
  return []
296
+ f = (focus or "").upper()
297
  for e in edges:
298
  if not isinstance(e, dict):
299
  continue
 
303
  genes.add(g)
304
  return list(genes)
305
 
 
306
  @app.get("/heuristics/regions_from_string")
307
  async def regions_from_string(
308
  receptor: str = Query(..., description="gene or drug"),
 
313
  symptom: Optional[str] = None
314
  ):
315
  """
316
+ Robust region scoring that works for genes AND drugs.
317
+ gene: (region_syns) AND (receptor OR STRING neighbors OR gene_syns)
318
+ drug: (region_syns) AND (drug)
 
319
  """
320
  mode = await resolve_mode(receptor)
321
 
322
+ # candidates
 
 
 
 
 
323
  region_list = [r.strip() for r in (regions.split(",") if regions else REGION_TERMS_DEFAULT) if r.strip()]
324
 
325
+ # region synonyms
326
  region_syns_map = {}
327
  if use_synonyms:
 
328
  for rgn in region_list:
329
  try:
330
  syn = await util_synonyms(rgn, "region")
331
+ region_syns_map[rgn] = syn.get("synonyms", [])[:10] or [rgn]
332
  except Exception:
333
+ region_syns_map[rgn] = [rgn]
 
334
  else:
335
  region_syns_map = {rgn: [rgn] for rgn in region_list}
336
 
337
+ # RHS terms
338
+ rhs_terms: List[str] = []
339
+ neighbors: List[str] = []
340
+ gene_syns: List[str] = []
341
 
342
  if mode["mode"] == "gene":
 
343
  try:
344
  sdata = await string_network(receptor, species=species, limit=limit)
345
  except Exception:
346
  sdata = []
347
+ neighbors = collect_gene_symbols_from_string(sdata, receptor)
 
 
348
  try:
349
  gs = await mygene_synonyms(receptor)
350
  gene_syns = gs[:20]
351
  except Exception:
352
  gene_syns = []
 
353
  rhs_terms = list({t for t in [receptor] + neighbors[:25] + gene_syns[:25] if t})
354
  else:
355
+ # drug path — skip STRING
356
  rhs_terms = [mode["canonical"]]
357
 
358
+ # Europe PMC hit counts (fast)
 
 
359
  results = []
360
  symptom_clause = f" AND ({symptom})" if symptom else ""
361
 
 
 
 
 
 
 
 
 
 
362
  for region in region_list:
363
  syns = region_syns_map.get(region, [region])
364
  lhs = " OR ".join(syns)
 
365
  if mode["mode"] == "gene":
366
  rhs = " OR ".join(rhs_terms) if rhs_terms else receptor
367
  q1 = f"({lhs}) AND ({rhs}){symptom_clause}"
368
  else:
 
369
  q1 = f"({lhs}) AND ({mode['canonical']}){symptom_clause}"
370
 
371
+ h1 = await eupmc_hitcount(q1)
372
  hits, tier = h1, "T1"
373
 
374
  if h1 == 0 and mode["mode"] == "gene":
 
375
  q2 = f"({lhs}) AND ({receptor}){symptom_clause}"
376
+ h2 = await eupmc_hitcount(q2)
377
  hits, tier = h2, "T2"
378
  if h2 == 0:
 
379
  q3 = f"({region}) AND ({receptor}){symptom_clause}"
380
+ h3 = await eupmc_hitcount(q3)
381
  hits, tier = h3, "T3"
382
 
383
  score = math.log10(hits + 1.0)
384
+ results.append({"region": region, "hits": hits, "tier": tier, "weighted_score": round(score, 4)})
 
 
 
 
 
385
 
386
  results.sort(key=lambda x: x["weighted_score"], reverse=True)
387
  return {
 
391
  "notes": f"mode={mode['mode']} genes={mode['genes']}"
392
  }
393
 
394
+ # ------------------ Manifest / Section / Download ------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
395
  JOBS: Dict[str, Dict[str, Any]] = {}
396
 
397
  @app.get("/mechanism_graph_manifest")
 
402
  string_limit: int = 50,
403
  lit_page_size: int = 10
404
  ):
405
+ """Return a job_id + available sections (lightweight counts only)."""
 
 
406
  jid = job_key(receptor, symptom)
407
 
 
 
408
  sdata = await string_network(receptor, species=species, limit=string_limit)
409
  s_count = len(sdata) if isinstance(sdata, list) else 0
410
 
 
411
  ldata = await europe_pmc_search(f"{receptor} AND {symptom}", pageSize=0)
412
  try:
413
  lit_hits = int(ldata.get("hitCount", 0))
414
  except Exception:
415
  lit_hits = 0
416
 
 
417
  rdata = await regions_from_string(receptor=receptor, species=species, limit=40, regions=None, use_synonyms=True, symptom=symptom)
418
  r_count = len(rdata.get("regions_ranked", [])) if isinstance(rdata, dict) else 0
419
 
 
423
  "receptor": receptor, "symptom": symptom,
424
  "counts": {"string_edges": s_count, "literature_hits": lit_hits, "regions": r_count}
425
  }
 
426
  }
427
 
428
  sections = [
 
431
  {"name": "literature", "approx_size": f"{lit_hits} hits (pageSize={lit_page_size})"},
432
  {"name": "regions", "approx_size": f"{r_count} entries"}
433
  ]
 
434
  return {"job_id": jid, "sections": sections}
435
 
436
  @app.get("/mechanism_graph/{section}")
 
443
  lit_page_size: int = 10,
444
  job_id: Optional[str] = Query(None, description="optional; use manifest if you want stable ids")
445
  ):
446
+ """Return one section; builds on-the-fly if no job_id."""
 
 
 
447
  ctx = None
448
  if job_id and job_id in JOBS:
449
  ctx = JOBS[job_id].get("_meta", {})
 
459
  jid = job_key(receptor, symptom or "")
460
  JOBS.setdefault(jid, {"_meta": {"receptor": receptor, "symptom": symptom or "", "species": species}})
461
  job_id = jid
 
462
  if "overview" not in JOBS[job_id]:
463
  sdata = await string_network(receptor, species=species, limit=string_limit)
464
  s_count = len(sdata) if isinstance(sdata, list) else 0
 
489
 
490
  @app.get("/download/{job_id}/{section}")
491
  async def download_section(job_id: str, section: str):
492
+ """Gzipped JSON download of a section; returns what's there."""
 
 
493
  data = JOBS.get(job_id, {}).get(section) or JOBS.get(job_id, {}).get("_meta")
494
  if not data:
495
  raise HTTPException(status_code=404, detail="job/section not found")
496
  gz = gz_json_bytes({"job_id": job_id, "section": section, "data": data})
497
  return StreamingResponse(io.BytesIO(gz),
498
  media_type="application/gzip",
499
+ headers={"Content-Disposition": f'attachment; filename="{job_id}_{section}.json.gz"'} )