"""
milestone_2a.py
---------------
Week 2A Milestone: Symptom-to-candidate-disease via graph phenotype matching.

Given a list of clinical symptoms, this script:
  1. Maps symptoms to HPO term IDs via the graph (HPOTerm name search)
  2. Runs the MANIFESTS_AS graph traversal to find matching diseases
  3. Runs BioLORD-2023 semantic search in ChromaDB in parallel
  4. Merges both rankings into a unified differential diagnosis list

This is the first real diagnostic query in RareDx.

Usage:
    python milestone_2a.py
    python milestone_2a.py "arachnodactyly" "ectopia lentis" "aortic dilation"
"""

import io
import os
import sys
import time
import concurrent.futures
from pathlib import Path

import chromadb
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer
from dotenv import load_dotenv


def _utf8_stream(stream):
    """Return *stream* switched to UTF-8 output with ``errors="replace"``.

    ``reconfigure`` changes the encoding in place and keeps the stream's
    buffering mode, so an interactive console stays line-buffered and progress
    messages show up immediately. Re-wrapping ``stream.buffer`` in a fresh
    ``TextIOWrapper`` (without ``line_buffering=True``) would silently make the
    console block-buffered. Streams offering neither ``reconfigure`` nor
    ``buffer`` (e.g. some IDE consoles, or ``None`` under pythonw) are
    returned unchanged.
    """
    reconfigure = getattr(stream, "reconfigure", None)
    if reconfigure is not None:
        reconfigure(encoding="utf-8", errors="replace")
        return stream
    buffer = getattr(stream, "buffer", None)
    if buffer is None:
        return stream
    return io.TextIOWrapper(
        buffer, encoding="utf-8", errors="replace", line_buffering=True
    )


# UTF-8 output for Windows
sys.stdout = _utf8_stream(sys.stdout)
sys.stderr = _utf8_stream(sys.stderr)

load_dotenv(Path(__file__).parents[2] / ".env")

CHROMA_HOST = os.getenv("CHROMA_HOST", "localhost")
CHROMA_PORT = int(os.getenv("CHROMA_PORT", "8000"))
COLLECTION_NAME = os.getenv("CHROMA_COLLECTION", "rare_diseases")
EMBED_MODEL = os.getenv("EMBED_MODEL", "FremyCompany/BioLORD-2023")
CHROMA_PERSIST_DIR = Path(__file__).parents[2] / "data" / "chromadb"

# Default test case: classic Marfan syndrome presentation
DEFAULT_SYMPTOMS = [
    "arachnodactyly",
    "ectopia lentis",
    "aortic root dilatation",
    "scoliosis",
    "tall stature",
]

# Symptoms come from the command line; with no arguments the Marfan case runs.
symptoms = sys.argv[1:] if len(sys.argv) > 1 else DEFAULT_SYMPTOMS


# ---------------------------------------------------------------------------
# Graph query
# ---------------------------------------------------------------------------

# Resolve one free-text symptom to the first HPOTerm whose name contains it
# (case-insensitive substring match).
_HPO_LOOKUP_CYPHER = (
    "MATCH (h:HPOTerm) WHERE toLower(h.term) CONTAINS toLower($s) "
    "RETURN h.hpo_id AS hid LIMIT 1"
)

# Rank diseases by how many resolved HPO terms they manifest, breaking ties with
# a frequency score (frequency_order 1 -> 5 points, 2 -> 4, 3 -> 3, 4 -> 2,
# anything else -> 1). Edges with frequency_order 5 are skipped (presumably the
# "excluded" frequency class -- confirm against the loader).
# NOTE(review): `r.frequency_order <> 5` evaluates to null for edges that have
# no frequency_order, so those edges are dropped too -- confirm intended.
_DISEASE_RANK_CYPHER = """
UNWIND $hpo_ids AS hid
MATCH (d:Disease)-[r:MANIFESTS_AS]->(h:HPOTerm {hpo_id: hid})
WHERE r.frequency_order <> 5
WITH d,
     count(h) AS match_count,
     sum(CASE r.frequency_order
           WHEN 1 THEN 5 WHEN 2 THEN 4 WHEN 3 THEN 3 WHEN 4 THEN 2
           ELSE 1 END) AS freq_score,
     collect({hpo_id: h.hpo_id, term: h.term, freq: r.frequency_label}) AS matched_hpo
WHERE match_count >= 1
RETURN d.orpha_code AS orpha_code, d.name AS name, d.definition AS definition,
       match_count, freq_score, matched_hpo
ORDER BY match_count DESC, freq_score DESC
LIMIT 10
"""


def _graph_search_neo4j(symptom_list: list[str]) -> tuple[list[dict], list[str]]:
    """Resolve symptoms and rank diseases in Neo4j.

    Returns (ranked_diseases, resolved_hpo_ids). Raises on any Neo4j problem
    (driver not installed, server unreachable, auth or query error).
    """
    from neo4j import GraphDatabase

    neo4j_uri = os.getenv("NEO4J_URI", "bolt://localhost:7687")
    neo4j_user = os.getenv("NEO4J_USER", "neo4j")
    neo4j_pass = os.getenv("NEO4J_PASSWORD", "raredx_password")

    # The context manager closes the driver even when a query raises.
    with GraphDatabase.driver(neo4j_uri, auth=(neo4j_user, neo4j_pass)) as driver:
        driver.verify_connectivity()
        with driver.session() as session:
            hpo_ids = []
            for sym in symptom_list:
                if not sym.strip():
                    continue  # an empty needle would CONTAINS-match every term
                rec = session.run(_HPO_LOOKUP_CYPHER, s=sym).single()
                if rec is not None:
                    hpo_ids.append(rec["hid"])
            # Two symptoms can resolve to the same term; duplicated IDs would be
            # UNWOUND twice and inflate match_count / freq_score.
            hpo_ids = list(dict.fromkeys(hpo_ids))
            if not hpo_ids:
                return [], []
            result = session.run(_DISEASE_RANK_CYPHER, hpo_ids=hpo_ids)
            return [dict(r) for r in result], hpo_ids


def _graph_search_local(symptom_list: list[str]) -> tuple[list[dict], list[str]]:
    """Resolve symptoms and rank diseases with the JSON-backed LocalGraphStore.

    Returns (ranked_diseases, resolved_hpo_ids).
    """
    from graph_store import LocalGraphStore

    store = LocalGraphStore()
    # Scan the graph once. Node iteration order is kept, so each symptom still
    # resolves to the first HPOTerm whose name contains it.
    hpo_terms = [
        ((attrs.get("term") or "").lower(), attrs["hpo_id"])
        for _, attrs in store.graph.nodes(data=True)
        if attrs.get("type") == "HPOTerm" and "hpo_id" in attrs
    ]

    hpo_ids = []
    for sym in symptom_list:
        needle = sym.lower()
        if not needle.strip():
            continue  # an empty needle would match every term
        hit = next((hid for term, hid in hpo_terms if needle in term), None)
        if hit is not None:
            hpo_ids.append(hit)
    # Same de-duplication as the Neo4j path so both backends score alike.
    hpo_ids = list(dict.fromkeys(hpo_ids))
    if not hpo_ids:
        return [], []
    return store.find_diseases_by_hpo(hpo_ids, top_n=10), hpo_ids


def graph_search(symptom_list: list[str]) -> tuple[list[dict], list[str], str]:
    """
    Returns (ranked_diseases, resolved_hpo_ids, backend_label).
    Tries Neo4j first, then LocalGraphStore.

    Symptoms are mapped to HPO IDs by case-insensitive substring match on the
    term name (blank symptoms are ignored; duplicate IDs are collapsed). If
    Neo4j is reachable but nothing resolves, an empty Neo4j result is returned
    without consulting the local store.
    """
    try:
        diseases, hpo_ids = _graph_search_neo4j(symptom_list)
    except Exception as exc:
        # Deliberately broad: any Neo4j failure means "use the local store".
        # Report the reason instead of discarding it silently.
        print(
            f"[graph] Neo4j unavailable ({type(exc).__name__}: {exc}); "
            "falling back to LocalGraphStore",
            file=sys.stderr,
        )
    else:
        return diseases, hpo_ids, "Neo4j (Docker)"

    diseases, hpo_ids = _graph_search_local(symptom_list)
    return diseases, hpo_ids, "LocalGraphStore (JSON)"


# ---------------------------------------------------------------------------
# ChromaDB semantic search
# ---------------------------------------------------------------------------

def chroma_search(
    symptom_list: list[str],
    model: "SentenceTransformer",
    n: int = 10,
) -> tuple[list[dict], str]:
    """Embed symptom list as a clinical query and search ChromaDB.

    Tries the ChromaDB HTTP server at CHROMA_HOST:CHROMA_PORT first and falls
    back to the embedded store under CHROMA_PERSIST_DIR.

    Args:
        symptom_list: Free-text symptoms, joined into one query sentence.
        model: Sentence-transformer used to embed the query (normalized).
        n: Number of nearest neighbours requested from the collection.

    Returns:
        (hits, backend_label). Each hit has orpha_code, name, definition and
        cosine_similarity (1 - distance, rounded to 4 decimals).
    """
    query = "Patient presents with: " + ", ".join(symptom_list) + "."

    try:
        client = chromadb.HttpClient(
            host=CHROMA_HOST,
            port=CHROMA_PORT,
            settings=Settings(anonymized_telemetry=False),
        )
        client.heartbeat()
        backend = "ChromaDB HTTP"
    except Exception:
        # Best-effort: any HTTP/connection problem falls back to the local store.
        client = chromadb.PersistentClient(
            path=str(CHROMA_PERSIST_DIR),
            settings=Settings(anonymized_telemetry=False),
        )
        backend = "ChromaDB Embedded"

    collection = client.get_collection(COLLECTION_NAME)
    embedding = model.encode([query], normalize_embeddings=True)
    results = collection.query(
        query_embeddings=embedding.tolist(),
        n_results=n,
        include=["metadatas", "distances"],
    )

    hits = []
    for meta, dist in zip(results["metadatas"][0], results["distances"][0]):
        meta = meta or {}  # metadata may be None for documents stored without any
        hits.append({
            "orpha_code": meta.get("orpha_code"),
            "name": meta.get("name"),
            "definition": meta.get("definition", ""),
            # NOTE(review): 1 - distance is a cosine similarity only if the
            # collection was built with cosine space -- confirm.
            "cosine_similarity": round(1 - dist, 4),
        })
    return hits, backend


# ---------------------------------------------------------------------------
# Score fusion
# ---------------------------------------------------------------------------

def _first_per_disease(results: list[dict]) -> list[dict]:
    """Keep only the first (best-ranked) row per ORPHA code, preserving order."""
    seen: set[str] = set()
    unique = []
    for d in results:
        key = str(d["orpha_code"])
        if key not in seen:
            seen.add(key)
            unique.append(d)
    return unique


def fuse_rankings(
    graph_results: list[dict],
    chroma_results: list[dict],
    *,
    k: int = 60,
) -> list[dict]:
    """
    Reciprocal Rank Fusion (RRF) of graph and semantic rankings.

    RRF score = sum(1 / (k + rank)) for each list the disease appears in.
    k=60 is the standard constant.

    Diseases are keyed by ``str(orpha_code)`` so int and string codes from the
    two backends merge. Within each list only a disease's first (best) row
    counts and ranks are assigned after that de-duplication, so a disease that
    appears several times in one list (e.g. if the collection holds several
    documents per ORPHA code) is not double-counted.

    Returns one dict per disease with orpha_code, name, definition, graph_rank,
    chroma_rank, graph_matches, chroma_sim and rrf_score, sorted by rrf_score
    descending (ties keep first-seen order).
    """
    scores: dict[str, dict] = {}

    def entry_for(d: dict) -> dict:
        # Name/definition come from whichever list mentions the disease first.
        key = str(d["orpha_code"])
        if key not in scores:
            scores[key] = {
                "orpha_code": d["orpha_code"],
                "name": d["name"],
                "definition": d.get("definition", ""),
                "graph_rank": None,
                "chroma_rank": None,
                "graph_matches": None,
                "chroma_sim": None,
                "rrf_score": 0.0,
            }
        return scores[key]

    for rank, d in enumerate(_first_per_disease(graph_results), 1):
        entry = entry_for(d)
        entry["rrf_score"] += 1 / (k + rank)
        entry["graph_rank"] = rank
        entry["graph_matches"] = d.get("match_count", 0)

    for rank, d in enumerate(_first_per_disease(chroma_results), 1):
        entry = entry_for(d)
        entry["rrf_score"] += 1 / (k + rank)
        entry["chroma_rank"] = rank
        entry["chroma_sim"] = d.get("cosine_similarity")

    return sorted(scores.values(), key=lambda x: x["rrf_score"], reverse=True)


# ---------------------------------------------------------------------------
# Display
# ---------------------------------------------------------------------------

BOLD = "\033[1m"
CYAN = "\033[96m"
GREEN = "\033[92m"
YELLOW = "\033[93m"
MAGENTA = "\033[95m"
DIM = "\033[2m"
RESET = "\033[0m"
LINE = "-" * 66
BAR_WIDTH = 20  # cells in the similarity bar


def print_section(title: str, color: str) -> None:
    """Print a bold, colored section title followed by a rule."""
    print(f"\n{BOLD}{color}{title}{RESET}")
    print(LINE)


def print_graph_hits(diseases: list[dict], hpo_ids: list[str], backend: str) -> None:
    """Print the top 5 graph candidates with their matched phenotypes."""
    print_section(f"[ Graph Traversal — {backend} ]", CYAN)
    if not diseases:
        print(f" {YELLOW}No graph matches. HPO IDs resolved: {hpo_ids}{RESET}")
        return

    print(f" {DIM}HPO IDs resolved: {', '.join(map(str, hpo_ids))}{RESET}\n")
    for rank, d in enumerate(diseases[:5], 1):
        mc = d.get("match_count", "?")
        # Denominator: backend-supplied term count, else the number of query
        # symptoms (looked up only when needed).
        total = d.get("total_query_terms")
        if total is None:
            total = len(symptoms)
        print(f" {rank}. ORPHA:{d['orpha_code']} {BOLD}{d['name']}{RESET}")
        print(f" Phenotype matches: {mc}/{total}")
        matched = d.get("matched_hpo") or []
        if matched:
            # A term name can be null in the graph; fall back to the HPO ID.
            terms = ", ".join(
                str(m.get("term") or m.get("hpo_id", "?")) for m in matched[:4]
            )
            print(f" {DIM}Matched: {terms}{RESET}")


def print_chroma_hits(hits: list[dict], backend: str) -> None:
    """Print the top 5 semantic hits with a similarity bar."""
    print_section(f"[ Semantic Search — BioLORD-2023 | {backend} ]", GREEN)
    if not hits:
        print(f" {YELLOW}No semantic matches.{RESET}")
        return
    for rank, h in enumerate(hits[:5], 1):
        sim = h["cosine_similarity"]
        # Cosine similarity can be negative; clamp so the bar is always
        # exactly BAR_WIDTH cells wide.
        filled = max(0, min(BAR_WIDTH, int(sim * BAR_WIDTH)))
        bar = "█" * filled + "░" * (BAR_WIDTH - filled)
        print(f" {rank}. [{bar}] {sim:.4f} ORPHA:{h['orpha_code']} {h['name']}")


def print_fused(fused: list[dict]) -> None:
    """Print the top 10 fused candidates with their per-backend ranks."""
    print_section("[ Fused Differential Diagnosis (RRF) ]", MAGENTA)
    print(f" {'Rank':<5} {'RRF':>6} {'Graph':>5} {'Chroma':>6} Disease")
    print(f" {'-'*4} {'-'*6} {'-'*5} {'-'*6} {'-'*35}")
    for rank, d in enumerate(fused[:10], 1):
        gr = f"#{d['graph_rank']}" if d["graph_rank"] is not None else " - "
        cr = f"#{d['chroma_rank']}" if d["chroma_rank"] is not None else " - "
        rrf = d["rrf_score"]
        print(f" {rank:<5} {rrf:.4f} {gr:>5} {cr:>6} {d['name']}")


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def _result_or_fallback(future, fallback, label: str):
    """Return ``future.result()``; on failure report it on stderr and return *fallback*.

    This lets one failed backend show up as a MISS in the summary (exit code 1)
    instead of aborting the whole run with a traceback.
    """
    try:
        return future.result()
    except Exception as exc:
        print(f"{YELLOW}{label} failed: {type(exc).__name__}: {exc}{RESET}",
              file=sys.stderr)
        return fallback


def main() -> None:
    """Run graph + semantic search on ``symptoms``, fuse, print, and exit 1 on a miss."""
    print("=" * 66)
    print("RareDx — Week 2A Milestone: Symptom-to-Diagnosis")
    print("=" * 66)
    print(f"\n{BOLD}Clinical query symptoms:{RESET}")
    for s in symptoms:
        print(f" - {s}")

    # Flush so the message is visible before the (slow) model load.
    print("\nLoading BioLORD-2023 ...", flush=True)
    t0 = time.perf_counter()
    model = SentenceTransformer(EMBED_MODEL)
    print(f" Ready in {time.perf_counter() - t0:.1f}s")

    # Parallel: graph traversal + semantic search
    print("\nRunning graph traversal and semantic search in parallel...", flush=True)
    t_start = time.perf_counter()
    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
        graph_fut = pool.submit(graph_search, symptoms)
        chroma_fut = pool.submit(chroma_search, symptoms, model, 10)
        graph_diseases, hpo_ids, graph_backend = _result_or_fallback(
            graph_fut, ([], [], "unavailable"), "Graph traversal"
        )
        chroma_hits, chroma_backend = _result_or_fallback(
            chroma_fut, ([], "unavailable"), "Semantic search"
        )
    elapsed = time.perf_counter() - t_start
    print(f" Completed in {elapsed:.2f}s")

    # Display individual results
    print_graph_hits(graph_diseases, hpo_ids, graph_backend)
    print_chroma_hits(chroma_hits, chroma_backend)

    # Fuse
    fused = fuse_rankings(graph_diseases, chroma_hits)
    print_fused(fused)

    # Summary
    graph_ok = len(graph_diseases) > 0
    chroma_ok = len(chroma_hits) > 0
    fused_ok = len(fused) > 0

    print(f"\n{LINE}")
    print(f"{BOLD}Week 2A Milestone Summary{RESET}")
    print(LINE)
    print(f" Graph traversal : {'OK' if graph_ok else 'MISS'} — {len(graph_diseases)} candidates — {graph_backend}")
    print(f" Semantic search : {'OK' if chroma_ok else 'MISS'} — {len(chroma_hits)} candidates — {chroma_backend}")
    print(f" Fused ranking : {'OK' if fused_ok else 'MISS'} — {len(fused)} unique candidates")
    print()

    if graph_ok and chroma_ok and fused_ok:
        top = fused[0]
        print(f" {BOLD}{GREEN}PASSED{RESET} — Top diagnosis: {top['name']} (ORPHA:{top['orpha_code']})")
    else:
        print(f" {YELLOW}PARTIAL or FAILED — check individual backends above{RESET}")
        sys.exit(1)

    print()


if __name__ == "__main__":
    main()