| """ |
| milestone_2a.py |
| --------------- |
| Week 2A Milestone: Symptom-to-candidate-disease via graph phenotype matching. |
| |
| Given a list of clinical symptoms, this script: |
| 1. Maps symptoms to HPO term IDs via the graph (HPOTerm name search) |
| 2. Runs the MANIFESTS_AS graph traversal to find matching diseases |
| 3. Runs BioLORD-2023 semantic search in ChromaDB in parallel |
| 4. Merges both rankings into a unified differential diagnosis list |
| |
| This is the first real diagnostic query in RareDx. |
| |
| Usage: |
| python milestone_2a.py |
| python milestone_2a.py "arachnodactyly" "ectopia lentis" "aortic dilation" |
| """ |
|
|
| import io |
| import os |
| import sys |
| import time |
| import concurrent.futures |
| from pathlib import Path |
|
|
| import chromadb |
| from chromadb.config import Settings |
| from sentence_transformers import SentenceTransformer |
| from dotenv import load_dotenv |
|
|
| |
# Force UTF-8 console output so the block-character similarity bars and
# non-ASCII disease names cannot raise UnicodeEncodeError on narrow console
# code pages; unencodable characters are replaced instead.
# reconfigure() adjusts the existing stream in place. That keeps its current
# line-buffering mode, so interactive progress lines still appear promptly.
# A fresh TextIOWrapper over .buffer would default to line_buffering=False.
for _stream in (sys.stdout, sys.stderr):
    # Streams swapped in by IDEs or test harnesses may not support this.
    if hasattr(_stream, "reconfigure"):
        _stream.reconfigure(encoding="utf-8", errors="replace")
del _stream
|
|
# Project root, two directories above this script. The .env file and the
# embedded ChromaDB directory are both located relative to it.
_PROJECT_ROOT = Path(__file__).parents[2]

# Populate os.environ from <root>/.env before any setting below is read.
load_dotenv(_PROJECT_ROOT / ".env")

# Connection and model settings; each can be overridden via the environment.
CHROMA_HOST = os.getenv("CHROMA_HOST", "localhost")
CHROMA_PORT = int(os.getenv("CHROMA_PORT", "8000"))
COLLECTION_NAME = os.getenv("CHROMA_COLLECTION", "rare_diseases")
EMBED_MODEL = os.getenv("EMBED_MODEL", "FremyCompany/BioLORD-2023")

# On-disk store used by chroma_search() when the HTTP server is unreachable.
CHROMA_PERSIST_DIR = _PROJECT_ROOT.joinpath("data", "chromadb")
|
|
| |
# Demo presentation used when no symptoms are supplied on the command line.
DEFAULT_SYMPTOMS = [
    "arachnodactyly",
    "ectopia lentis",
    "aortic root dilatation",
    "scoliosis",
    "tall stature",
]


def parse_symptoms(args: list[str]) -> list[str]:
    """
    Return cleaned symptom strings from command-line arguments.

    Surrounding whitespace is stripped and blank arguments are dropped. The
    HPO lookup is a substring match, and an empty string is a substring of
    every term, so a blank argument would otherwise resolve to an arbitrary
    HPO term. When nothing usable remains, a *copy* of DEFAULT_SYMPTOMS is
    returned so callers cannot mutate the defaults through it.
    """
    cleaned = [arg.strip() for arg in args if arg.strip()]
    return cleaned or list(DEFAULT_SYMPTOMS)


# Symptoms for this run, parsed once at import time. main() and
# print_graph_hits() read this module-level list.
symptoms = parse_symptoms(sys.argv[1:])
|
|
|
|
| |
| |
| |
|
|
# Ranks diseases by how many resolved HPO terms they manifest, then by a
# frequency-weighted score. Edges with frequency_order 5 are skipped.
# coalesce() keeps edges that have no frequency_order at all; a bare
# `r.frequency_order <> 5` would drop them, because in Cypher null <> 5 is
# null and WHERE treats null as false. Such edges score 1 via the CASE's ELSE.
# orpha_code is a final tie-breaker so LIMIT 10 returns a stable set.
_NEO4J_RANK_QUERY = """
UNWIND $hpo_ids AS hid
MATCH (d:Disease)-[r:MANIFESTS_AS]->(h:HPOTerm {hpo_id: hid})
WHERE coalesce(r.frequency_order, 0) <> 5
WITH d, count(h) AS match_count,
     sum(CASE r.frequency_order
           WHEN 1 THEN 5 WHEN 2 THEN 4
           WHEN 3 THEN 3 WHEN 4 THEN 2
           ELSE 1 END) AS freq_score,
     collect({hpo_id: h.hpo_id, term: h.term,
              freq: r.frequency_label}) AS matched_hpo
RETURN d.orpha_code AS orpha_code, d.name AS name,
       d.definition AS definition,
       match_count, freq_score, matched_hpo
ORDER BY match_count DESC, freq_score DESC, orpha_code
LIMIT 10
"""

# One HPO term per symptom: the shortest term containing the symptom
# (case-insensitive, so an exact match wins), tie-broken by HPO ID. Without
# ORDER BY, LIMIT 1 would return an arbitrary match.
_NEO4J_HPO_LOOKUP = (
    "MATCH (h:HPOTerm) WHERE toLower(h.term) CONTAINS toLower($s) "
    "RETURN h.hpo_id AS hid ORDER BY size(h.term), h.hpo_id LIMIT 1"
)

_NEO4J_LABEL = "Neo4j (Docker)"
_LOCAL_LABEL = "LocalGraphStore (JSON)"


def graph_search(symptom_list: list[str]) -> tuple[list[dict], list[str], str]:
    """
    Map symptoms to HPO terms and rank diseases by phenotype overlap.

    Tries Neo4j first. If the driver is not installed, the server is
    unreachable or a query fails, the reason is written to stderr and the
    search falls back to the JSON-backed LocalGraphStore.

    Args:
        symptom_list: free-text symptom names. Each is matched as a
            case-insensitive substring of an HPOTerm's ``term``. Blank
            entries are ignored.

    Returns:
        (ranked_diseases, resolved_hpo_ids, backend_label). Resolved HPO IDs
        are de-duplicated, so a term hit by two symptoms is counted once.
    """
    cleaned = [s.strip() for s in symptom_list if s and s.strip()]
    try:
        return _graph_search_neo4j(cleaned)
    except Exception as exc:  # deliberate best-effort: fall back, but say why
        print(
            f"[graph_search] Neo4j unavailable ({type(exc).__name__}: {exc}); "
            "falling back to LocalGraphStore",
            file=sys.stderr,
        )
    return _graph_search_local(cleaned)


def _graph_search_neo4j(symptom_list: list[str]) -> tuple[list[dict], list[str], str]:
    """Resolve HPO IDs and run the MANIFESTS_AS ranking query against Neo4j."""
    from neo4j import GraphDatabase

    driver = GraphDatabase.driver(
        os.getenv("NEO4J_URI", "bolt://localhost:7687"),
        auth=(
            os.getenv("NEO4J_USER", "neo4j"),
            os.getenv("NEO4J_PASSWORD", "raredx_password"),
        ),
    )
    # try/finally so the driver is closed on every path, including failures
    # in verify_connectivity() or either query.
    try:
        driver.verify_connectivity()
        with driver.session() as session:
            hpo_ids: list[str] = []
            for sym in symptom_list:
                rec = session.run(_NEO4J_HPO_LOOKUP, s=sym).single()
                if rec is not None and rec["hid"] is not None and rec["hid"] not in hpo_ids:
                    hpo_ids.append(rec["hid"])

            if not hpo_ids:
                return [], [], _NEO4J_LABEL

            result = session.run(_NEO4J_RANK_QUERY, hpo_ids=hpo_ids)
            diseases = [dict(r) for r in result]
    finally:
        driver.close()
    return diseases, hpo_ids, _NEO4J_LABEL


def _graph_search_local(symptom_list: list[str]) -> tuple[list[dict], list[str], str]:
    """Resolve HPO IDs and rank diseases using the JSON LocalGraphStore."""
    from graph_store import LocalGraphStore

    store = LocalGraphStore()
    # Build the (lower-cased term, hpo_id) index once rather than rescanning
    # every graph node for each symptom.
    hpo_terms = [
        ((attrs.get("term") or "").lower(), attrs.get("hpo_id"))
        for _, attrs in store.graph.nodes(data=True)
        if attrs.get("type") == "HPOTerm" and attrs.get("hpo_id") is not None
    ]
    hpo_ids = _resolve_hpo_ids(symptom_list, hpo_terms)
    diseases = store.find_diseases_by_hpo(hpo_ids, top_n=10)
    return diseases, hpo_ids, _LOCAL_LABEL


def _resolve_hpo_ids(symptom_list: list[str], hpo_terms: list[tuple[str, str]]) -> list[str]:
    """
    Map each symptom to at most one HPO ID, de-duplicated, in symptom order.

    A symptom matches a term when it is a case-insensitive substring of it.
    Among matches, the shortest term wins (an exact match if present), with
    the HPO ID as tie-breaker. This follows the same rule as the Neo4j lookup.
    """
    resolved: list[str] = []
    for sym in symptom_list:
        needle = sym.lower()
        candidates = [(term, hid) for term, hid in hpo_terms if needle in term]
        if not candidates:
            continue
        _, best = min(candidates, key=lambda th: (len(th[0]), str(th[1])))
        if best not in resolved:
            resolved.append(best)
    return resolved
|
|
|
|
| |
| |
| |
|
|
def chroma_search(
    symptom_list: list[str],
    model: SentenceTransformer,
    n: int = 10,
) -> tuple[list[dict], str]:
    """
    Embed the symptom list as one clinical sentence and query ChromaDB.

    Connects to the ChromaDB HTTP server at CHROMA_HOST:CHROMA_PORT. If its
    heartbeat fails, it falls back to the embedded store in
    CHROMA_PERSIST_DIR.

    Args:
        symptom_list: symptom strings, joined into
            "Patient presents with: a, b, c."
        model: embedding model; the query is encoded with
            ``normalize_embeddings=True``.
        n: number of results requested from the collection.

    Returns:
        (hits, backend_label). Each hit has orpha_code, name, definition and
        cosine_similarity (computed as ``1 - distance``, rounded to 4 places).
    """
    query = "Patient presents with: " + ", ".join(symptom_list) + "."

    # Each client gets its own Settings instance: the client constructors may
    # mutate the settings they are given, so sharing one is unsafe.
    try:
        client = chromadb.HttpClient(
            host=CHROMA_HOST,
            port=CHROMA_PORT,
            settings=Settings(anonymized_telemetry=False),
        )
        client.heartbeat()
        backend = "ChromaDB HTTP"
    except Exception:
        # Any connection/handshake failure -> use the on-disk store instead.
        client = chromadb.PersistentClient(
            path=str(CHROMA_PERSIST_DIR),
            settings=Settings(anonymized_telemetry=False),
        )
        backend = "ChromaDB Embedded"

    collection = client.get_collection(COLLECTION_NAME)
    embedding = model.encode([query], normalize_embeddings=True)
    results = collection.query(
        query_embeddings=embedding.tolist(),
        n_results=n,
        include=["metadatas", "distances"],
    )

    hits = []
    for meta, dist in zip(results["metadatas"][0], results["distances"][0]):
        # Records stored without metadata come back as None.
        meta = meta or {}
        hits.append({
            "orpha_code": meta.get("orpha_code"),
            "name": meta.get("name"),
            "definition": meta.get("definition") or "",
            # NOTE(review): 1 - distance is the cosine similarity only if the
            # collection uses hnsw:space="cosine". Chroma's default space is
            # "l2". Confirm against the ingest script.
            "cosine_similarity": round(1 - dist, 4),
        })
    return hits, backend
|
|
|
|
| |
| |
| |
|
|
def fuse_rankings(
    graph_results: list[dict],
    chroma_results: list[dict],
    *,
    k: int = 60,
) -> list[dict]:
    """
    Merge graph and semantic rankings with Reciprocal Rank Fusion (RRF).

    Each disease, keyed by ``str(orpha_code)``, scores
    ``sum(1 / (k + rank))`` over the lists it appears in, using 1-based ranks.
    k=60 is the standard constant.

    A disease listed more than once in the same input list contributes only
    its first (best) rank. Duplicate hits therefore cannot inflate its score
    or overwrite its rank.

    Args:
        graph_results: graph hits in rank order. Each needs orpha_code and
            name, and may have definition and match_count.
        chroma_results: semantic hits in rank order. Each needs orpha_code
            and name, and may have definition and cosine_similarity.
        k: RRF smoothing constant; larger values flatten rank differences.

    Returns:
        One dict per unique disease with orpha_code, name, definition,
        graph_rank, chroma_rank, graph_matches, chroma_sim and rrf_score,
        sorted by rrf_score descending. Ties keep first-seen order, so graph
        hits come before semantic-only hits.
    """
    scores: dict[str, dict] = {}

    # (ranked list, rank field, detail field, detail extractor) per backend.
    sources = (
        (graph_results, "graph_rank", "graph_matches",
         lambda d: d.get("match_count", 0)),
        (chroma_results, "chroma_rank", "chroma_sim",
         lambda d: d.get("cosine_similarity")),
    )

    for results, rank_field, detail_field, detail in sources:
        for rank, d in enumerate(results, 1):
            key = str(d["orpha_code"])
            entry = scores.get(key)
            if entry is None:
                entry = scores[key] = {
                    "orpha_code": d["orpha_code"], "name": d["name"],
                    "definition": d.get("definition") or "",
                    "graph_rank": None, "chroma_rank": None,
                    "graph_matches": None, "chroma_sim": None,
                    "rrf_score": 0.0,
                }
            elif entry[rank_field] is not None:
                # Already counted at a better rank in this same list.
                continue
            else:
                # Seen only via the other backend: fill any gaps it left.
                if not entry["definition"]:
                    entry["definition"] = d.get("definition") or ""
                if entry["name"] is None:
                    entry["name"] = d.get("name")

            entry["rrf_score"] += 1 / (k + rank)
            entry[rank_field] = rank
            entry[detail_field] = detail(d)

    # sorted() is stable even with reverse=True, so equal scores keep
    # insertion order.
    return sorted(scores.values(), key=lambda x: x["rrf_score"], reverse=True)
|
|
|
|
| |
| |
| |
|
|
# ANSI styling is emitted only on an interactive terminal and when NO_COLOR
# is unset or empty (https://no-color.org). This keeps redirected output and
# log files free of raw escape sequences.
_COLOR_ENABLED = bool(
    not os.environ.get("NO_COLOR")
    and hasattr(sys.stdout, "isatty")
    and sys.stdout.isatty()
)

BOLD = "\033[1m" if _COLOR_ENABLED else ""
CYAN = "\033[96m" if _COLOR_ENABLED else ""
GREEN = "\033[92m" if _COLOR_ENABLED else ""
YELLOW = "\033[93m" if _COLOR_ENABLED else ""
MAGENTA = "\033[95m" if _COLOR_ENABLED else ""
DIM = "\033[2m" if _COLOR_ENABLED else ""
RESET = "\033[0m" if _COLOR_ENABLED else ""
LINE = "-" * 66  # horizontal rule used under section headings
|
|
|
|
def print_section(title: str, color: str) -> None:
    """Print *title* as a bold, coloured heading preceded by a blank line and followed by a rule."""
    heading = f"\n{BOLD}{color}{title}{RESET}"
    print(heading, LINE, sep="\n")
|
|
|
|
def print_graph_hits(diseases: list[dict], hpo_ids: list[str], backend: str) -> None:
    """
    Print the top five graph-traversal candidates.

    Args:
        diseases: ranked graph hits. Each needs orpha_code and name, and may
            have match_count, total_query_terms and matched_hpo (a list of
            dicts with a term and/or hpo_id).
        hpo_ids: HPO IDs that the query symptoms resolved to.
        backend: label of the graph backend that produced the hits.
    """
    print_section(f"[ Graph Traversal — {backend} ]", CYAN)
    if not diseases:
        print(f" {YELLOW}No graph matches. HPO IDs resolved: {hpo_ids}{RESET}")
        return

    # str() guards against backends that store HPO IDs as non-strings.
    print(f" {DIM}HPO IDs resolved: {', '.join(map(str, hpo_ids))}{RESET}\n")
    for rank, d in enumerate(diseases[:5], 1):
        mc = d.get("match_count", "?")
        # When the backend does not report how many query terms it used,
        # fall back to the number of CLI symptoms.
        total = d.get("total_query_terms", len(symptoms))
        print(f" {rank}. ORPHA:{d['orpha_code']} {BOLD}{d['name']}{RESET}")
        print(f" Phenotype matches: {mc}/{total}")
        matched = d.get("matched_hpo") or []
        if matched:
            terms = ", ".join(
                str(m.get("term", m.get("hpo_id", "?"))) for m in matched[:4]
            )
            print(f" {DIM}Matched: {terms}{RESET}")
|
|
|
|
def print_chroma_hits(hits: list[dict], backend: str) -> None:
    """
    Print the top five semantic hits with a 20-cell similarity bar.

    The number of filled cells is clamped to [0, 20]. Similarities computed
    as ``1 - distance`` can fall outside [0, 1]. Unclamped, a negative value
    would produce a bar wider than 20 cells (``"█" * -n`` is empty, while the
    padding grows), and a value above 1 would overflow it.
    """
    print_section(f"[ Semantic Search — BioLORD-2023 | {backend} ]", GREEN)
    width = 20
    for rank, h in enumerate(hits[:5], 1):
        sim = h["cosine_similarity"]
        filled = max(0, min(width, int(sim * width)))
        bar = "█" * filled + "░" * (width - filled)
        print(f" {rank}. [{bar}] {sim:.4f} ORPHA:{h['orpha_code']} {h['name']}")
|
|
|
|
def print_fused(fused: list[dict]) -> None:
    """Print the top ten fused candidates as a table of RRF score, graph rank and semantic rank."""
    print_section("[ Fused Differential Diagnosis (RRF) ]", MAGENTA)
    print(f" {'Rank':<5} {'RRF':>6} {'Graph':>5} {'Chroma':>6} Disease")
    print(f" {'-'*4} {'-'*6} {'-'*5} {'-'*6} {'-'*35}")

    def rank_cell(value):
        # Ranks are 1-based, so a falsy value means "absent from that list".
        return f"#{value}" if value else " - "

    for position, row in enumerate(fused[:10], start=1):
        graph_cell = rank_cell(row["graph_rank"])
        chroma_cell = rank_cell(row["chroma_rank"])
        score = row["rrf_score"]
        print(f" {position:<5} {score:.4f} {graph_cell:>5} {chroma_cell:>6} {row['name']}")
|
|
|
|
| |
| |
| |
|
|
def _result_or_miss(future, empty: tuple, what: str) -> tuple:
    """
    Return ``future.result()``. If the task raised, report the error on
    stderr and return *empty* plus an "unavailable (<ErrorType>)" backend label.
    """
    try:
        return future.result()
    except Exception as exc:  # one failed backend should not hide the other
        print(f"{what} failed: {type(exc).__name__}: {exc}", file=sys.stderr)
        return (*empty, f"unavailable ({type(exc).__name__})")


def main() -> None:
    """
    Run the Week 2A milestone end to end and print a pass/fail summary.

    Loads the embedding model, then runs graph traversal and semantic search
    concurrently. It prints each ranking and their RRF fusion, and exits with
    status 1 unless both backends and the fusion produced candidates. A
    backend that raises is reported on stderr and counted as MISS rather than
    aborting the run with a traceback.
    """
    print("=" * 66)
    print("RareDx — Week 2A Milestone: Symptom-to-Diagnosis")
    print("=" * 66)
    print(f"\n{BOLD}Clinical query symptoms:{RESET}")
    for s in symptoms:
        print(f" - {s}")

    # Report the model actually loaded, which EMBED_MODEL may override.
    print(f"\nLoading {EMBED_MODEL} ...")
    t0 = time.perf_counter()
    model = SentenceTransformer(EMBED_MODEL)
    print(f" Ready in {time.perf_counter() - t0:.1f}s")

    print("\nRunning graph traversal and semantic search in parallel...")
    t_start = time.perf_counter()
    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
        graph_fut = pool.submit(graph_search, symptoms)
        chroma_fut = pool.submit(chroma_search, symptoms, model, 10)
        graph_diseases, hpo_ids, graph_backend = _result_or_miss(
            graph_fut, ([], []), "Graph traversal")
        chroma_hits, chroma_backend = _result_or_miss(
            chroma_fut, ([],), "Semantic search")
    elapsed = time.perf_counter() - t_start
    print(f" Completed in {elapsed:.2f}s")

    print_graph_hits(graph_diseases, hpo_ids, graph_backend)
    print_chroma_hits(chroma_hits, chroma_backend)

    fused = fuse_rankings(graph_diseases, chroma_hits)
    print_fused(fused)

    graph_ok = bool(graph_diseases)
    chroma_ok = bool(chroma_hits)
    fused_ok = bool(fused)

    print(f"\n{LINE}")
    print(f"{BOLD}Week 2A Milestone Summary{RESET}")
    print(LINE)
    print(f" Graph traversal : {'OK' if graph_ok else 'MISS'} — {len(graph_diseases)} candidates — {graph_backend}")
    print(f" Semantic search : {'OK' if chroma_ok else 'MISS'} — {len(chroma_hits)} candidates — {chroma_backend}")
    print(f" Fused ranking : {'OK' if fused_ok else 'MISS'} — {len(fused)} unique candidates")
    print()

    if not (graph_ok and chroma_ok and fused_ok):
        print(f" {YELLOW}PARTIAL or FAILED — check individual backends above{RESET}")
        sys.exit(1)

    top = fused[0]
    print(f" {BOLD}{GREEN}PASSED{RESET} — Top diagnosis: {top['name']} (ORPHA:{top['orpha_code']})")
    print()


if __name__ == "__main__":
    main()
|
|