gemeo-twin-stack / src /gemeo /cohort.py
timmers's picture
GEMEO world-model — initial release (module + NeuralSurv ckpt + RareBench v49 + KG embeddings)
089d665 verified
"""Patients-like-mine — cohort retrieval.
Two paths, in order of preference:
1. **Registry kNN** — Neo4j vector index over `Patient.embedding`. Used when
a real patient registry exists in the graph (e.g. raras-app).
2. **Literature fallback** — PubMed case reports retrieved by HPO+disease
similarity. Used when no registry exists, which is the common case for
a case-driven workflow (a doctor pastes a case into My Scientist and
wants a comparable cohort *now*).
The result is a `Cohort` that surfaces both real patients and curated
literature cases, with a `source` field per member so the UI can disambiguate.
"""
from __future__ import annotations
import logging
from typing import Optional
from .types import Cohort, CohortMember
logger = logging.getLogger("gemeo.cohort")
async def _safe_query(cypher: str, params: dict = None, timeout: float = 10.0) -> list:
    """Run a Cypher query through ``space_graph`` on a best-effort basis.

    The ``space_graph`` helper is imported lazily inside the guarded region,
    so a missing module is handled the same way as a failing query.

    Args:
        cypher: Cypher statement to execute.
        params: Query parameters; a falsy value is sent as an empty dict.
        timeout: Forwarded to the underlying helper as ``timeout``.

    Returns:
        Whatever the underlying helper returns, or ``[]`` if the import or
        the query raises an ``Exception`` (the failure is logged at DEBUG).
    """
    rows: list = []
    try:
        from space_graph import _safe_query as run_cypher

        bound_params = params if params else {}
        rows = await run_cypher(cypher, bound_params, timeout=timeout)
    except Exception as e:
        # Deliberate best-effort: callers treat "no rows" and "query failed"
        # identically, so the error is only recorded for debugging.
        logger.debug(f"cypher failed: {e}")
    return rows
# ─── registry path ─────────────────────────────────────────────────────────
async def _vector_knn(embedding, limit: int = 20) -> list:
    """Return registry patients nearest to ``embedding`` via the vector index.

    Queries the ``patient_embedding`` vector index and keeps only nodes whose
    ``wantsCommunity`` or ``is_profile_public`` property is true.

    Args:
        embedding: Iterable of numbers forming the query vector, or ``None``.
        limit: Number of neighbours requested from the index and the maximum
            number of rows returned.

    Returns:
        Rows with ``space_id``, ``similarity``, ``dx_orpha``, ``dx_name``,
        ``hpo_ids`` (at most 15 per row) and ``sus_region``, best match
        first; ``[]`` when ``embedding`` is ``None``.
    """
    if embedding is None:
        return []
    # ORDER BY is required before LIMIT: after the collect() aggregation the
    # row order is otherwise unspecified, so LIMIT could keep an arbitrary
    # subset instead of the closest matches.
    cypher = """
    CALL db.index.vector.queryNodes('patient_embedding', $limit, $vec)
    YIELD node, score
    WHERE node.wantsCommunity = true OR node.is_profile_public = true
    OPTIONAL MATCH (node)-[:HAS_CONDITION {isPrimary: true}]->(d:Disease)
    OPTIONAL MATCH (node)-[:HAS_PHENOTYPE]->(p:Phenotype)
    RETURN node.supabaseId AS space_id,
           score AS similarity,
           d.orphaCode AS dx_orpha,
           d.name AS dx_name,
           collect(DISTINCT p.hpoId)[..15] AS hpo_ids,
           node.state AS sus_region
    ORDER BY similarity DESC
    LIMIT $limit
    """
    # Send plain Python floats; array-library scalar types (e.g. float32)
    # are not guaranteed to serialise as query parameters.
    vector = [float(component) for component in embedding]
    return await _safe_query(cypher, {"vec": vector, "limit": int(limit)})
async def _graph_overlap(hpo_ids: list[str], limit: int = 20) -> list:
    """Rank registry patients by phenotype overlap with the query HPO terms.

    Similarity is ``overlap / (n_query + total_hpo - overlap)``, i.e. a
    Jaccard index between the query terms and the patient's phenotypes.
    Only patients whose ``wantsCommunity`` or ``is_profile_public`` property
    is true are considered.

    Args:
        hpo_ids: Query HPO identifiers. Duplicates and empty entries are
            ignored so they cannot inflate the Jaccard denominator.
        limit: Maximum number of rows returned.

    Returns:
        Rows with ``space_id``, ``similarity``, ``dx_orpha``, ``dx_name``,
        ``hpo_ids`` (the shared terms) and ``sus_region``, best match first;
        ``[]`` when no usable HPO id is supplied.
    """
    # dict.fromkeys de-duplicates while keeping the caller's order.
    unique_hpos = [h for h in dict.fromkeys(hpo_ids or ()) if h]
    if not unique_hpos:
        return []
    cypher = """
    MATCH (other:Patient)-[:HAS_PHENOTYPE]->(p:Phenotype)
    WHERE p.hpoId IN $hpo_ids
      AND (other.wantsCommunity = true OR other.is_profile_public = true)
    WITH other, collect(DISTINCT p.hpoId) AS shared, count(DISTINCT p) AS overlap
    OPTIONAL MATCH (other)-[:HAS_CONDITION {isPrimary: true}]->(d:Disease)
    OPTIONAL MATCH (other)-[:HAS_PHENOTYPE]->(allp:Phenotype)
    WITH other, shared, overlap, d, count(DISTINCT allp) AS total_hpo
    RETURN other.supabaseId AS space_id,
           toFloat(overlap) / toFloat(coalesce($n_query + total_hpo - overlap, 1)) AS similarity,
           d.orphaCode AS dx_orpha,
           d.name AS dx_name,
           shared AS hpo_ids,
           other.state AS sus_region
    ORDER BY similarity DESC
    LIMIT $limit
    """
    return await _safe_query(cypher, {
        "hpo_ids": unique_hpos,
        # Must be the number of *distinct* query terms for the Jaccard
        # denominator to be correct.
        "n_query": len(unique_hpos),
        "limit": int(limit),
    })
# ─── case-record path (existing PatientSpace cases in the graph) ──────────
async def _other_spaces(hpo_ids: list[str], limit: int = 20) -> list:
    """Rank other ``PatientSpace`` cases in the graph by phenotype overlap.

    Phenotypes are reached through ``HAS_EVENT`` / ``INVOLVES``; the
    diagnosis comes from a ``SpaceHypothesis`` with status ``'confirmed'``
    when one exists. Similarity is a Jaccard index between the query terms
    and the phenotypes attached to the space.

    Args:
        hpo_ids: Query HPO identifiers. Duplicates and empty entries are
            ignored so they cannot inflate the Jaccard denominator.
        limit: Maximum number of rows returned.

    Returns:
        Rows with ``space_id``, ``similarity``, ``dx_orpha``, ``dx_name`` and
        ``hpo_ids`` (the shared terms), best match first; ``[]`` when no
        usable HPO id is supplied.
    """
    # dict.fromkeys de-duplicates while keeping the caller's order.
    unique_hpos = [h for h in dict.fromkeys(hpo_ids or ()) if h]
    if not unique_hpos:
        return []
    cypher = """
    MATCH (s:PatientSpace)-[:HAS_EVENT]->(:SpaceEvent)-[:INVOLVES]->(p:Phenotype)
    WHERE p.hpoId IN $hpo_ids
    WITH s, collect(DISTINCT p.hpoId) AS shared
    OPTIONAL MATCH (s)-[:HAS_HYPOTHESIS]->(h:SpaceHypothesis {status: 'confirmed'})-[:TARGETS]->(d:Disease)
    WITH s, shared, d
    OPTIONAL MATCH (s)-[:HAS_EVENT]->(:SpaceEvent)-[:INVOLVES]->(allp:Phenotype)
    WITH s, shared, d, count(DISTINCT allp) AS total_hpo
    WITH s, shared, d, total_hpo,
         toFloat(size(shared)) / toFloat(coalesce($n_query + total_hpo - size(shared), 1)) AS sim
    RETURN s.space_id AS space_id,
           sim AS similarity,
           d.orphaCode AS dx_orpha,
           d.name AS dx_name,
           shared AS hpo_ids
    ORDER BY similarity DESC
    LIMIT $limit
    """
    return await _safe_query(cypher, {
        "hpo_ids": unique_hpos,
        # Distinct count, so the Jaccard denominator is not inflated.
        "n_query": len(unique_hpos),
        "limit": int(limit),
    })
# ─── literature fallback ───────────────────────────────────────────────────
async def _literature_cases(hpo_ids: list[str], orpha_codes: list[str], limit: int = 8) -> list:
    """Pull case reports from the local Neo4j Paper index.

    Schema used by the query (raras-app):
        (:Paper {pmid, title, abstract, year, embedding})
        (:Paper)-[:ABOUT]->(:Disease)
        (:Paper)-[:MENTIONS]->(:Phenotype)

    A paper qualifies when it is about one of ``orpha_codes`` or mentions one
    of ``hpo_ids``, and additionally either looks like a case report (title
    contains "case" or "report", or ``is_case_report`` is true) or shares
    more than 20% of the query HPO terms.

    Args:
        hpo_ids: Query HPO identifiers; duplicates and empty entries are
            dropped. May be empty or ``None``.
        orpha_codes: Candidate disease codes; empty entries are dropped. May
            be empty or ``None``.
        limit: Maximum number of rows returned.

    Returns:
        Rows with ``pmid``, ``title``, ``year``, ``dx_orpha``, ``dx_name``,
        ``hpo_ids`` (shared terms) and ``similarity`` (fraction of the query
        HPO terms the paper mentions), best match first; ``[]`` when neither
        HPO ids nor disease codes are given.
    """
    hpos = [h for h in dict.fromkeys(hpo_ids or ()) if h]
    orphas = [o for o in dict.fromkeys(orpha_codes or ()) if o]
    if not hpos and not orphas:
        # Nothing can match `IN []`; skip the round trip.
        return []
    # The CASE guard matters for disease-only queries: with an empty $hpos
    # the previous `x / coalesce(size($hpos), 1)` divided by zero (coalesce
    # only replaces null), which produced a non-finite similarity or an error.
    # Titles are lower-cased so "Case Report" is recognised as well.
    cypher = """
    MATCH (p:Paper)-[:ABOUT]->(d:Disease)
    WHERE d.orphaCode IN $orphas
       OR EXISTS {
            MATCH (p)-[:MENTIONS]->(pheno:Phenotype)
            WHERE pheno.hpoId IN $hpos
       }
    OPTIONAL MATCH (p)-[:MENTIONS]->(pheno:Phenotype)
    WHERE pheno.hpoId IN $hpos
    WITH p, d, collect(DISTINCT pheno.hpoId) AS shared_hpos
    WITH p, d, shared_hpos,
         CASE WHEN size($hpos) = 0 THEN 0.0
              ELSE toFloat(size(shared_hpos)) / toFloat(size($hpos))
         END AS sim
    WHERE toLower(p.title) CONTAINS 'case'
       OR toLower(p.title) CONTAINS 'report'
       OR p.is_case_report = true
       OR sim > 0.2
    RETURN p.pmid AS pmid, p.title AS title, p.year AS year,
           d.orphaCode AS dx_orpha, d.name AS dx_name,
           shared_hpos AS hpo_ids,
           sim AS similarity
    ORDER BY similarity DESC, p.year DESC
    LIMIT $limit
    """
    return await _safe_query(cypher, {
        "hpos": hpos,
        "orphas": orphas,
        "limit": int(limit),
    })
# ─── merge ─────────────────────────────────────────────────────────────────
def _merge(*hits_lists: list, k: int = 10) -> list:
    """Fuse hit lists into at most ``k`` de-duplicated member dicts.

    Each row is keyed by its ``space_id`` or, failing that, by
    ``"pmid:<pmid>"``; rows with neither are skipped. The first occurrence
    of a key wins, so lists passed earlier take precedence over later ones.

    Args:
        *hits_lists: Lists of row dicts with optional keys ``space_id``,
            ``pmid``, ``similarity``, ``hpo_ids``, ``dx_orpha``, ``dx_name``
            and ``sus_region``. A ``None`` list is treated as empty.
        k: Maximum number of members returned; values below zero are treated
            as zero.

    Returns:
        Member dicts (``space_id``, ``similarity``, ``shared_phenotypes``,
        ``shared_diseases``, ``confirmed_diagnosis``, ``confirmed_orpha``,
        ``sus_region``) sorted by descending similarity.
    """
    import math

    def _finite_similarity(value) -> float:
        # Null, non-numeric or NaN scores count as 0.0 so a single bad row
        # can neither raise nor corrupt the sort order.
        try:
            score = float(value)
        except (TypeError, ValueError):
            return 0.0
        return 0.0 if math.isnan(score) else score

    merged: dict = {}
    for hits in hits_lists:
        for row in hits or ():
            pmid = row.get("pmid")
            member_id = row.get("space_id") or (f"pmid:{pmid}" if pmid else None)
            if not member_id or member_id in merged:
                continue
            orpha = row.get("dx_orpha")
            merged[member_id] = {
                "space_id": member_id,
                "similarity": _finite_similarity(row.get("similarity", 0)),
                "shared_phenotypes": row.get("hpo_ids") or [],
                "shared_diseases": [orpha] if orpha else [],
                "confirmed_diagnosis": row.get("dx_name"),
                "confirmed_orpha": orpha,
                "sus_region": row.get("sus_region"),
            }
    ranked = sorted(merged.values(), key=lambda member: member["similarity"], reverse=True)
    return ranked[:max(k, 0)]
async def find_cohort(
    *,
    embedding=None,
    hpo_ids: list[str] = None,
    orpha_codes: list[str] = None,
    k: int = 10,
    include_literature: bool = True,
) -> Cohort:
    """Find a cohort of cases similar to the query.

    Four retrieval paths are awaited concurrently and then merged:
        1. Vector kNN over the patient registry (needs ``embedding``).
        2. Phenotype overlap with registry patients (needs ``hpo_ids``).
        3. Phenotype overlap with other PatientSpaces (needs ``hpo_ids``).
        4. Case reports indexed in the graph (needs ``include_literature``
           and at least one of ``hpo_ids`` / ``orpha_codes``).
    A path whose inputs are missing contributes an empty list.

    Args:
        embedding: Query vector, or ``None`` to skip the kNN path.
        hpo_ids: Query HPO identifiers; ``None`` is treated as empty.
        orpha_codes: Candidate disease codes; ``None`` is treated as empty.
        k: Maximum number of cohort members.
        include_literature: Whether the literature path may run.

    Returns:
        A ``Cohort``; it is returned even when every path is empty, in which
        case ``members`` is empty and ``radius`` is 1.0.
    """
    import asyncio
    from collections import Counter

    hpo_ids = hpo_ids or []
    orpha_codes = orpha_codes or []

    async def _guarded(path):
        # One failing path must not sink the others.
        try:
            return await path
        except Exception as e:
            logger.debug(f"cohort path failed: {e}")
            return []

    def _no_hits():
        # Awaitable placeholder keeping the gather() result positions fixed.
        return asyncio.sleep(0, result=[])

    wide_limit = max(k * 2, 20)

    if embedding is not None:
        registry_job = _guarded(_vector_knn(embedding, limit=wide_limit))
    else:
        registry_job = _no_hits()

    if hpo_ids:
        overlap_job = _guarded(_graph_overlap(hpo_ids, limit=wide_limit))
        spaces_job = _guarded(_other_spaces(hpo_ids, limit=max(k, 10)))
    else:
        overlap_job = _no_hits()
        spaces_job = _no_hits()

    if include_literature and (hpo_ids or orpha_codes):
        literature_job = _guarded(
            _literature_cases(hpo_ids, orpha_codes, limit=max(k, 8))
        )
    else:
        literature_job = _no_hits()

    vector_hits, graph_hits, space_hits, lit_hits = await asyncio.gather(
        registry_job, overlap_job, spaces_job, literature_job
    )

    merged = _merge(vector_hits, graph_hits, space_hits, lit_hits, k=k)
    out = [CohortMember(**fields) for fields in merged]

    # Most frequent diagnosis among members that carry one.
    centroid = None
    diagnoses = Counter(
        (member.confirmed_orpha, member.confirmed_diagnosis)
        for member in out
        if member.confirmed_orpha
    )
    if diagnoses:
        (orpha, name), n = diagnoses.most_common(1)[0]
        centroid = {"orpha": orpha, "name": name, "count": n, "fraction": n / len(out)}

    # Label by the most preferred path that produced hits.
    if vector_hits:
        method = "knn_fused"
    elif graph_hits or space_hits:
        method = "graph_spaces"
    else:
        method = "literature"

    n_total = 0
    try:
        rows = await _safe_query(
            "MATCH (n) WHERE n:Patient OR n:PatientSpace RETURN count(n) AS n",
            {},
        )
        if rows:
            n_total = int(rows[0].get("n", 0))
    except Exception:
        # Population size is informational only; keep 0 on any failure.
        pass

    # Distance to the last (least similar) member retained.
    farthest_similarity = out[-1].similarity if out else 0.0
    return Cohort(
        members=out,
        method=method,
        radius=1.0 - farthest_similarity,
        n_total_population=n_total,
        centroid_disease=centroid,
    )