"""Patients-like-mine — cohort retrieval. Two paths, in order of preference: 1. **Registry kNN** — Neo4j vector index over `Patient.embedding`. Used when a real patient registry exists in the graph (e.g. raras-app). 2. **Literature fallback** — PubMed case reports retrieved by HPO+disease similarity. Used when no registry exists, which is the common case for a case-driven workflow (a doctor pastes a case into My Scientist and wants a comparable cohort *now*). The result is a `Cohort` that surfaces both real patients and curated literature cases, with a `source` field per member so the UI can disambiguate. """ from __future__ import annotations import logging from typing import Optional from .types import Cohort, CohortMember logger = logging.getLogger("gemeo.cohort") async def _safe_query(cypher: str, params: dict = None, timeout: float = 10.0) -> list: try: from space_graph import _safe_query as q return await q(cypher, params or {}, timeout=timeout) except Exception as e: logger.debug(f"cypher failed: {e}") return [] # ─── registry path ───────────────────────────────────────────────────────── async def _vector_knn(embedding, limit: int = 20) -> list: if embedding is None: return [] cypher = """ CALL db.index.vector.queryNodes('patient_embedding', $limit, $vec) YIELD node, score WHERE node.wantsCommunity = true OR node.is_profile_public = true OPTIONAL MATCH (node)-[:HAS_CONDITION {isPrimary: true}]->(d:Disease) OPTIONAL MATCH (node)-[:HAS_PHENOTYPE]->(p:Phenotype) RETURN node.supabaseId AS space_id, score AS similarity, d.orphaCode AS dx_orpha, d.name AS dx_name, collect(DISTINCT p.hpoId)[..15] AS hpo_ids, node.state AS sus_region LIMIT $limit """ return await _safe_query(cypher, {"vec": list(embedding), "limit": int(limit)}) async def _graph_overlap(hpo_ids: list[str], limit: int = 20) -> list: if not hpo_ids: return [] cypher = """ MATCH (other:Patient)-[:HAS_PHENOTYPE]->(p:Phenotype) WHERE p.hpoId IN $hpo_ids AND (other.wantsCommunity = true OR other.is_profile_public = true) WITH other, collect(DISTINCT p.hpoId) AS shared, count(DISTINCT p) AS overlap OPTIONAL MATCH (other)-[:HAS_CONDITION {isPrimary: true}]->(d:Disease) OPTIONAL MATCH (other)-[:HAS_PHENOTYPE]->(allp:Phenotype) WITH other, shared, overlap, d, count(DISTINCT allp) AS total_hpo RETURN other.supabaseId AS space_id, toFloat(overlap) / toFloat(coalesce($n_query + total_hpo - overlap, 1)) AS similarity, d.orphaCode AS dx_orpha, d.name AS dx_name, shared AS hpo_ids, other.state AS sus_region ORDER BY similarity DESC LIMIT $limit """ return await _safe_query(cypher, { "hpo_ids": hpo_ids, "n_query": len(hpo_ids), "limit": int(limit), }) # ─── case-record path (existing PatientSpace cases in the graph) ────────── async def _other_spaces(hpo_ids: list[str], limit: int = 20) -> list: """Other PatientSpaces in the graph (other cases the doctor or the swarm built).""" if not hpo_ids: return [] cypher = """ MATCH (s:PatientSpace)-[:HAS_EVENT]->(:SpaceEvent)-[:INVOLVES]->(p:Phenotype) WHERE p.hpoId IN $hpo_ids WITH s, collect(DISTINCT p.hpoId) AS shared OPTIONAL MATCH (s)-[:HAS_HYPOTHESIS]->(h:SpaceHypothesis {status: 'confirmed'})-[:TARGETS]->(d:Disease) WITH s, shared, d OPTIONAL MATCH (s)-[:HAS_EVENT]->(:SpaceEvent)-[:INVOLVES]->(allp:Phenotype) WITH s, shared, d, count(DISTINCT allp) AS total_hpo WITH s, shared, d, total_hpo, toFloat(size(shared)) / toFloat(coalesce($n_query + total_hpo - size(shared), 1)) AS sim RETURN s.space_id AS space_id, sim AS similarity, d.orphaCode AS dx_orpha, d.name AS dx_name, shared AS hpo_ids ORDER BY similarity DESC LIMIT $limit """ return await _safe_query(cypher, { "hpo_ids": hpo_ids, "n_query": len(hpo_ids), "limit": int(limit), }) # ─── literature fallback ─────────────────────────────────────────────────── async def _literature_cases(hpo_ids: list[str], orpha_codes: list[str], limit: int = 8) -> list: """Pull case reports from the local Neo4j Paper index (raras-app already indexed PubMed papers with embeddings + disease links). Schema (raras-app): (:Paper {pmid, title, abstract, year, embedding}) (:Paper)-[:ABOUT]->(:Disease) (:Paper)-[:MENTIONS]->(:Phenotype) """ cypher = """ MATCH (p:Paper)-[:ABOUT]->(d:Disease) WHERE d.orphaCode IN $orphas OR EXISTS { MATCH (p)-[:MENTIONS]->(pheno:Phenotype) WHERE pheno.hpoId IN $hpos } OPTIONAL MATCH (p)-[:MENTIONS]->(pheno:Phenotype) WHERE pheno.hpoId IN $hpos WITH p, d, collect(DISTINCT pheno.hpoId) AS shared_hpos WITH p, d, shared_hpos, toFloat(size(shared_hpos)) / toFloat(coalesce(size($hpos), 1)) AS sim WHERE p.title CONTAINS 'case' OR p.title CONTAINS 'report' OR p.is_case_report = true OR sim > 0.2 RETURN p.pmid AS pmid, p.title AS title, p.year AS year, d.orphaCode AS dx_orpha, d.name AS dx_name, shared_hpos AS hpo_ids, sim AS similarity ORDER BY similarity DESC, p.year DESC LIMIT $limit """ return await _safe_query(cypher, { "hpos": hpo_ids, "orphas": orpha_codes, "limit": int(limit), }) # ─── merge ───────────────────────────────────────────────────────────────── def _merge(*hits_lists: list, k: int = 10) -> list: seen = {} for hits in hits_lists: for r in hits: sid = r.get("space_id") or (f"pmid:{r['pmid']}" if r.get("pmid") else None) if not sid or sid in seen: continue seen[sid] = { "space_id": sid, "similarity": float(r.get("similarity", 0)), "shared_phenotypes": r.get("hpo_ids", []) or [], "shared_diseases": [r["dx_orpha"]] if r.get("dx_orpha") else [], "confirmed_diagnosis": r.get("dx_name"), "confirmed_orpha": r.get("dx_orpha"), "sus_region": r.get("sus_region"), } members = list(seen.values()) members.sort(key=lambda m: m["similarity"], reverse=True) return members[:k] async def find_cohort( *, embedding=None, hpo_ids: list[str] = None, orpha_codes: list[str] = None, k: int = 10, include_literature: bool = True, ) -> Cohort: """Find a cohort of similar cases. Tries, in parallel: 1. Real-patient registry (if Neo4j has Patient nodes with embeddings) 2. Other case spaces in the graph (other PatientSpaces) 3. PubMed case reports indexed in the graph A `Cohort` is returned even if all paths are empty, with `n_total_population=0`. """ hpo_ids = hpo_ids or [] orpha_codes = orpha_codes or [] import asyncio async def _safe(c): try: return await c except Exception as e: logger.debug(f"cohort path failed: {e}") return [] tasks = [] if embedding is not None: tasks.append(_safe(_vector_knn(embedding, limit=max(k * 2, 20)))) else: tasks.append(asyncio.sleep(0, result=[])) if hpo_ids: tasks.append(_safe(_graph_overlap(hpo_ids, limit=max(k * 2, 20)))) tasks.append(_safe(_other_spaces(hpo_ids, limit=max(k, 10)))) else: tasks.append(asyncio.sleep(0, result=[])) tasks.append(asyncio.sleep(0, result=[])) if include_literature and (hpo_ids or orpha_codes): tasks.append(_safe(_literature_cases(hpo_ids, orpha_codes, limit=max(k, 8)))) else: tasks.append(asyncio.sleep(0, result=[])) vector_hits, graph_hits, space_hits, lit_hits = await asyncio.gather(*tasks) members = _merge(vector_hits, graph_hits, space_hits, lit_hits, k=k) out = [CohortMember(**m) for m in members] centroid = None if out: from collections import Counter c = Counter( (m.confirmed_orpha, m.confirmed_diagnosis) for m in out if m.confirmed_orpha ) if c: (orpha, name), n = c.most_common(1)[0] centroid = {"orpha": orpha, "name": name, "count": n, "fraction": n / len(out)} method = ( "knn_fused" if vector_hits else ("graph_spaces" if (graph_hits or space_hits) else "literature") ) n_total = 0 try: rows = await _safe_query( "MATCH (n) WHERE n:Patient OR n:PatientSpace RETURN count(n) AS n", {}, ) if rows: n_total = int(rows[0].get("n", 0)) except Exception: pass return Cohort( members=out, method=method, radius=1.0 - (out[-1].similarity if out else 0.0), n_total_population=n_total, centroid_disease=centroid, )