| """Patients-like-mine — cohort retrieval. |
| |
| Two paths, in order of preference: |
| 1. **Registry kNN** — Neo4j vector index over `Patient.embedding`. Used when |
| a real patient registry exists in the graph (e.g. raras-app). |
| 2. **Literature fallback** — PubMed case reports retrieved by HPO+disease |
| similarity. Used when no registry exists, which is the common case for |
| a case-driven workflow (a doctor pastes a case into My Scientist and |
| wants a comparable cohort *now*). |
| |
| The result is a `Cohort` that surfaces both real patients and curated |
| literature cases, with a `source` field per member so the UI can disambiguate. |
| """ |
| from __future__ import annotations |
| import logging |
| from typing import Optional |
|
|
| from .types import Cohort, CohortMember |
|
|
| logger = logging.getLogger("gemeo.cohort") |
|
|
|
|
async def _safe_query(cypher: str, params: Optional[dict] = None, timeout: float = 10.0) -> list:
    """Run a Cypher query through ``space_graph._safe_query``, best-effort.

    Args:
        cypher: Cypher statement to execute.
        params: Query parameters; ``None`` is sent as an empty dict.
        timeout: Forwarded to the underlying helper as its ``timeout`` keyword.

    Returns:
        Whatever the underlying helper returns, or ``[]`` if importing or
        running it raised any ``Exception``.
    """
    try:
        # The import sits inside the ``try`` on purpose: a missing or broken
        # ``space_graph`` module then also yields an empty result.
        from space_graph import _safe_query as run_query

        return await run_query(cypher, params or {}, timeout=timeout)
    except Exception as exc:  # deliberate best-effort boundary
        # Lazy %-formatting: the message is only built if DEBUG is enabled.
        logger.debug("cypher failed: %s", exc)
        return []
|
|
|
|
| |
|
|
async def _vector_knn(embedding, limit: int = 20) -> list:
    """Nearest registry patients by embedding, via the Neo4j vector index.

    Queries the ``patient_embedding`` vector index and keeps only nodes whose
    ``wantsCommunity`` or ``is_profile_public`` property is true.

    Args:
        embedding: Query vector (any iterable of numbers), or ``None``.
        limit: Number of neighbours requested from the index and maximum
            number of rows returned.

    Returns:
        Rows with ``space_id``, ``similarity``, ``dx_orpha``, ``dx_name``,
        ``hpo_ids`` (at most 15 per row) and ``sus_region``, best match
        first. ``[]`` when ``embedding`` is ``None``.
    """
    if embedding is None:
        return []

    # NOTE: the visibility filter runs *after* the index has returned its
    # top-$limit neighbours, so fewer than ``limit`` rows may come back.
    #
    # ORDER BY is required before LIMIT: the RETURN aggregates (collect), and
    # a node with several primary diseases yields several rows, so without an
    # explicit order LIMIT could keep an arbitrary subset.
    cypher = """
        CALL db.index.vector.queryNodes('patient_embedding', $limit, $vec)
        YIELD node, score
        WHERE node.wantsCommunity = true OR node.is_profile_public = true
        OPTIONAL MATCH (node)-[:HAS_CONDITION {isPrimary: true}]->(d:Disease)
        OPTIONAL MATCH (node)-[:HAS_PHENOTYPE]->(p:Phenotype)
        RETURN node.supabaseId AS space_id,
               score AS similarity,
               d.orphaCode AS dx_orpha,
               d.name AS dx_name,
               collect(DISTINCT p.hpoId)[..15] AS hpo_ids,
               node.state AS sus_region
        ORDER BY similarity DESC
        LIMIT $limit
    """
    return await _safe_query(cypher, {"vec": list(embedding), "limit": int(limit)})
|
|
|
|
async def _graph_overlap(hpo_ids: list[str], limit: int = 20) -> list:
    """Rank ``Patient`` nodes by phenotype overlap with the query HPO terms.

    Similarity is the Jaccard index ``|shared| / |query union patient|``.
    Only patients whose ``wantsCommunity`` or ``is_profile_public`` property
    is true are matched.

    Args:
        hpo_ids: HPO identifiers describing the query case.
        limit: Maximum number of rows returned.

    Returns:
        Rows with ``space_id``, ``similarity``, ``dx_orpha``, ``dx_name``,
        ``hpo_ids`` (the shared terms) and ``sus_region``, best match first.
        ``[]`` when no HPO terms are given.
    """
    # The Cypher side counts DISTINCT phenotypes, so the query-set size must
    # also be a distinct count; a repeated input ID would otherwise inflate
    # the union and deflate every score. dict.fromkeys keeps input order.
    query_terms = list(dict.fromkeys(hpo_ids or ()))
    if not query_terms:
        return []

    cypher = """
        MATCH (other:Patient)-[:HAS_PHENOTYPE]->(p:Phenotype)
        WHERE p.hpoId IN $hpo_ids
          AND (other.wantsCommunity = true OR other.is_profile_public = true)
        WITH other, collect(DISTINCT p.hpoId) AS shared, count(DISTINCT p) AS overlap
        OPTIONAL MATCH (other)-[:HAS_CONDITION {isPrimary: true}]->(d:Disease)
        OPTIONAL MATCH (other)-[:HAS_PHENOTYPE]->(allp:Phenotype)
        WITH other, shared, overlap, d, count(DISTINCT allp) AS total_hpo
        RETURN other.supabaseId AS space_id,
               toFloat(overlap) / toFloat(coalesce($n_query + total_hpo - overlap, 1)) AS similarity,
               d.orphaCode AS dx_orpha,
               d.name AS dx_name,
               shared AS hpo_ids,
               other.state AS sus_region
        ORDER BY similarity DESC
        LIMIT $limit
    """
    return await _safe_query(cypher, {
        "hpo_ids": query_terms,
        "n_query": len(query_terms),
        "limit": int(limit),
    })
|
|
|
|
| |
|
|
async def _other_spaces(hpo_ids: list[str], limit: int = 20) -> list:
    """Other PatientSpaces in the graph (other cases the doctor or the swarm built).

    Spaces are matched through the phenotypes involved in their events and
    ranked by the Jaccard index ``|shared| / |query union space|``. The
    diagnosis columns come from a hypothesis with status ``'confirmed'``,
    when the space has one.

    Args:
        hpo_ids: HPO identifiers describing the query case.
        limit: Maximum number of rows returned.

    Returns:
        Rows with ``space_id``, ``similarity``, ``dx_orpha``, ``dx_name`` and
        ``hpo_ids`` (the shared terms), best match first. ``[]`` when no HPO
        terms are given.
    """
    # ``shared`` and ``total_hpo`` are DISTINCT counts in Cypher, so the query
    # size must be distinct too; duplicate input IDs would otherwise deflate
    # the score. dict.fromkeys keeps input order.
    query_terms = list(dict.fromkeys(hpo_ids or ()))
    if not query_terms:
        return []

    cypher = """
        MATCH (s:PatientSpace)-[:HAS_EVENT]->(:SpaceEvent)-[:INVOLVES]->(p:Phenotype)
        WHERE p.hpoId IN $hpo_ids
        WITH s, collect(DISTINCT p.hpoId) AS shared
        OPTIONAL MATCH (s)-[:HAS_HYPOTHESIS]->(h:SpaceHypothesis {status: 'confirmed'})-[:TARGETS]->(d:Disease)
        WITH s, shared, d
        OPTIONAL MATCH (s)-[:HAS_EVENT]->(:SpaceEvent)-[:INVOLVES]->(allp:Phenotype)
        WITH s, shared, d, count(DISTINCT allp) AS total_hpo
        WITH s, shared, d, total_hpo,
             toFloat(size(shared)) / toFloat(coalesce($n_query + total_hpo - size(shared), 1)) AS sim
        RETURN s.space_id AS space_id,
               sim AS similarity,
               d.orphaCode AS dx_orpha,
               d.name AS dx_name,
               shared AS hpo_ids
        ORDER BY similarity DESC
        LIMIT $limit
    """
    return await _safe_query(cypher, {
        "hpo_ids": query_terms,
        "n_query": len(query_terms),
        "limit": int(limit),
    })
|
|
|
|
| |
|
|
async def _literature_cases(hpo_ids: list[str], orpha_codes: list[str], limit: int = 8) -> list:
    """Pull case reports from the local Neo4j Paper index (raras-app already
    indexed PubMed papers with embeddings + disease links).

    Schema (raras-app):
        (:Paper {pmid, title, abstract, year, embedding})
        (:Paper)-[:ABOUT]->(:Disease)
        (:Paper)-[:MENTIONS]->(:Phenotype)

    A paper is a candidate when it is about one of ``orpha_codes`` or mentions
    one of ``hpo_ids``. It is kept when its title contains "case" or "report"
    (case-insensitively), it is flagged ``is_case_report``, or it shares more
    than 20% of the query HPO terms.

    Args:
        hpo_ids: HPO identifiers describing the query case (may be empty).
        orpha_codes: Orphanet codes of candidate diseases (may be empty).
        limit: Maximum number of rows returned.

    Returns:
        Rows with ``pmid``, ``title``, ``year``, ``dx_orpha``, ``dx_name``,
        ``hpo_ids`` (the shared terms) and ``similarity`` (fraction of query
        HPO terms the paper mentions), ordered by similarity then year,
        both descending. ``[]`` when both inputs are empty.
    """
    # Distinct terms only: the shared set is DISTINCT on the Cypher side, so
    # duplicates in the input would shrink every similarity.
    hpos = list(dict.fromkeys(hpo_ids or ()))
    orphas = list(orpha_codes or ())
    if not hpos and not orphas:
        # Neither predicate of the MATCH can succeed; skip the round-trip.
        return []

    # The denominator is computed here and clamped to >= 1. In Cypher,
    # coalesce() only replaces NULL, so an orpha-only query (no HPO terms)
    # would evaluate 0.0 / 0.0 = NaN for every paper.
    #
    # Title matching is lower-cased so that "Case Report: ..." style titles
    # are recognised as well.
    cypher = """
        MATCH (p:Paper)-[:ABOUT]->(d:Disease)
        WHERE d.orphaCode IN $orphas
           OR EXISTS {
                MATCH (p)-[:MENTIONS]->(pheno:Phenotype)
                WHERE pheno.hpoId IN $hpos
              }
        OPTIONAL MATCH (p)-[:MENTIONS]->(pheno:Phenotype)
        WHERE pheno.hpoId IN $hpos
        WITH p, d, collect(DISTINCT pheno.hpoId) AS shared_hpos
        WITH p, d, shared_hpos,
             toFloat(size(shared_hpos)) / toFloat($n_hpos) AS sim
        WHERE toLower(p.title) CONTAINS 'case'
           OR toLower(p.title) CONTAINS 'report'
           OR p.is_case_report = true
           OR sim > 0.2
        RETURN p.pmid AS pmid, p.title AS title, p.year AS year,
               d.orphaCode AS dx_orpha, d.name AS dx_name,
               shared_hpos AS hpo_ids,
               sim AS similarity
        ORDER BY similarity DESC, p.year DESC
        LIMIT $limit
    """
    return await _safe_query(cypher, {
        "hpos": hpos,
        "orphas": orphas,
        "n_hpos": max(len(hpos), 1),
        "limit": int(limit),
    })
|
|
|
|
| |
|
|
def _merge(*hits_lists: list, k: int = 10) -> list:
    """Fuse hit lists from several retrieval paths into the top-``k`` members.

    Each row is keyed by its ``space_id`` or, for literature rows, by
    ``"pmid:<pmid>"``. Rows with neither key are dropped. When the same key
    appears more than once, the first occurrence wins, so earlier lists take
    precedence over later ones.

    Args:
        *hits_lists: Lists of row mappings; a ``None`` list is treated as empty.
        k: Maximum number of members returned.

    Returns:
        A list of member dicts with keys ``space_id``, ``similarity``,
        ``shared_phenotypes``, ``shared_diseases``, ``confirmed_diagnosis``,
        ``confirmed_orpha`` and ``sus_region``, sorted by similarity in
        descending order (ties keep first-seen order).
    """
    import math

    merged: dict = {}
    for hits in hits_lists:
        for row in hits or ():
            pmid = row.get("pmid")
            member_id = row.get("space_id") or (f"pmid:{pmid}" if pmid else None)
            if not member_id or member_id in merged:
                continue

            # A null score would make float() raise and a NaN score makes the
            # sort order undefined; treat both as "no measurable similarity".
            raw_score = row.get("similarity")
            score = 0.0 if raw_score is None else float(raw_score)
            if not math.isfinite(score):
                score = 0.0

            dx_orpha = row.get("dx_orpha")
            merged[member_id] = {
                "space_id": member_id,
                "similarity": score,
                "shared_phenotypes": row.get("hpo_ids") or [],
                "shared_diseases": [dx_orpha] if dx_orpha else [],
                "confirmed_diagnosis": row.get("dx_name"),
                "confirmed_orpha": dx_orpha,
                "sus_region": row.get("sus_region"),
            }

    ranked = sorted(merged.values(), key=lambda m: m["similarity"], reverse=True)
    return ranked[:k]
|
|
|
|
async def find_cohort(
    *,
    embedding=None,
    hpo_ids: Optional[list[str]] = None,
    orpha_codes: Optional[list[str]] = None,
    k: int = 10,
    include_literature: bool = True,
) -> Cohort:
    """Find a cohort of similar cases.

    Runs these lookups concurrently and fuses their hits:
      1. Registry kNN over patient embeddings (only if ``embedding`` is given).
      2. Registry patients and other PatientSpaces sharing HPO terms (only if
         ``hpo_ids`` is non-empty).
      3. Literature cases indexed in the graph (only if ``include_literature``
         is true and ``hpo_ids`` or ``orpha_codes`` is non-empty).

    Args:
        embedding: Query embedding for the registry kNN path, or ``None``.
        hpo_ids: HPO identifiers of the query case.
        orpha_codes: Orphanet codes of candidate diagnoses.
        k: Maximum number of cohort members.
        include_literature: Whether to query the literature path.

    Returns:
        A ``Cohort``, returned even when every path comes back empty.
        ``method`` is ``"knn_fused"`` when the vector path produced hits,
        ``"graph_spaces"`` when only the overlap paths did, and
        ``"literature"`` otherwise. ``radius`` is one minus the similarity of
        the last member (1.0 for an empty cohort). ``n_total_population`` is
        the number of ``Patient`` and ``PatientSpace`` nodes, or 0 when that
        count cannot be read. ``centroid_disease`` describes the most frequent
        confirmed diagnosis among the members, or is ``None``.
    """
    import asyncio
    from collections import Counter

    hpo_ids = hpo_ids or []
    orpha_codes = orpha_codes or []

    async def _guarded(coro) -> list:
        # One failing path must not sink the whole cohort lookup.
        try:
            return (await coro) or []
        except Exception as exc:
            logger.debug("cohort path failed: %s", exc)
            return []

    async def _skipped() -> list:
        # Placeholder so that gather() always yields one result per path.
        return []

    wide = max(k * 2, 20)  # over-fetch: duplicates are dropped when merging

    vector_path = (
        _guarded(_vector_knn(embedding, limit=wide))
        if embedding is not None else _skipped()
    )
    if hpo_ids:
        graph_path = _guarded(_graph_overlap(hpo_ids, limit=wide))
        spaces_path = _guarded(_other_spaces(hpo_ids, limit=max(k, 10)))
    else:
        graph_path, spaces_path = _skipped(), _skipped()
    literature_path = (
        _guarded(_literature_cases(hpo_ids, orpha_codes, limit=max(k, 8)))
        if include_literature and (hpo_ids or orpha_codes) else _skipped()
    )
    # The population count is independent of the hits, so it runs alongside
    # the retrieval paths instead of adding a round-trip afterwards.
    population_path = _guarded(_safe_query(
        "MATCH (n) WHERE n:Patient OR n:PatientSpace RETURN count(n) AS n",
        {},
    ))

    vector_hits, graph_hits, space_hits, lit_hits, population_rows = await asyncio.gather(
        vector_path, graph_path, spaces_path, literature_path, population_path,
    )

    members = _merge(vector_hits, graph_hits, space_hits, lit_hits, k=k)
    out = [CohortMember(**m) for m in members]

    # Most frequent confirmed diagnosis among the members, if there is one.
    centroid = None
    diagnoses = Counter(
        (m.confirmed_orpha, m.confirmed_diagnosis)
        for m in out if m.confirmed_orpha
    )
    if diagnoses:
        (orpha, name), n = diagnoses.most_common(1)[0]
        centroid = {"orpha": orpha, "name": name, "count": n, "fraction": n / len(out)}

    if vector_hits:
        method = "knn_fused"
    elif graph_hits or space_hits:
        method = "graph_spaces"
    else:
        method = "literature"

    n_total = 0
    if population_rows:
        try:
            n_total = int(population_rows[0].get("n") or 0)
        except Exception as exc:  # best-effort: the count is informational only
            logger.debug("population count unreadable: %s", exc)

    return Cohort(
        members=out,
        method=method,
        radius=1.0 - (out[-1].similarity if out else 0.0),
        n_total_population=n_total,
        centroid_disease=centroid,
    )
|
|