raredx / backend /scripts /milestone_2a.py
Aswin92's picture
Upload folder using huggingface_hub
89c6379 verified
"""
milestone_2a.py
---------------
Week 2A Milestone: Symptom-to-candidate-disease via graph phenotype matching.
Given a list of clinical symptoms, this script:
1. Maps symptoms to HPO term IDs via the graph (HPOTerm name search)
2. Runs the MANIFESTS_AS graph traversal to find matching diseases
3. Runs BioLORD-2023 semantic search in ChromaDB in parallel
4. Merges both rankings into a unified differential diagnosis list
This is the first real diagnostic query in RareDx.
Usage:
python milestone_2a.py
python milestone_2a.py "arachnodactyly" "ectopia lentis" "aortic dilation"
"""
import io
import os
import sys
import time
import concurrent.futures
from pathlib import Path
import chromadb
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer
from dotenv import load_dotenv
# Force UTF-8 console output (needed on Windows, where the console encoding
# may not be able to represent the bar characters and dashes printed below).
# reconfigure() changes the encoding of the existing stream in place, so its
# buffering mode is kept. Re-wrapping sys.stdout.buffer in a new TextIOWrapper
# would create a stream without line buffering (progress messages would show
# up late) and would raise AttributeError for streams that have no ``.buffer``.
for _stream in (sys.stdout, sys.stderr):
    _reconfigure = getattr(_stream, "reconfigure", None)
    if _reconfigure is not None:
        _reconfigure(encoding="utf-8", errors="replace")

# Directory three levels up from this file; it holds the .env file and the
# embedded ChromaDB data directory used as a fallback.
_BASE_DIR = Path(__file__).parents[2]
load_dotenv(_BASE_DIR / ".env")

# Connection and model settings; each can be overridden via the environment.
CHROMA_HOST = os.getenv("CHROMA_HOST", "localhost")
CHROMA_PORT = int(os.getenv("CHROMA_PORT", "8000"))
COLLECTION_NAME = os.getenv("CHROMA_COLLECTION", "rare_diseases")
EMBED_MODEL = os.getenv("EMBED_MODEL", "FremyCompany/BioLORD-2023")
CHROMA_PERSIST_DIR = _BASE_DIR / "data" / "chromadb"

# Default test case: classic Marfan syndrome presentation
DEFAULT_SYMPTOMS = [
    "arachnodactyly",
    "ectopia lentis",
    "aortic root dilatation",
    "scoliosis",
    "tall stature",
]

# Symptoms given on the command line win; otherwise use the default case.
# This module-level list is read by main() and print_graph_hits().
symptoms = sys.argv[1:] or DEFAULT_SYMPTOMS
# ---------------------------------------------------------------------------
# Graph query
# ---------------------------------------------------------------------------
def graph_search(symptom_list: list[str]) -> tuple[list[dict], list[str], str]:
    """
    Find candidate diseases whose phenotypes match the given symptoms.

    Each symptom is resolved to at most one HPO term by a case-insensitive
    substring match on the term name, and the resolved IDs are then matched
    against diseases. Neo4j is tried first; if it cannot be imported, reached
    or queried, the reason is written to stderr and LocalGraphStore is used.

    Blank symptoms are skipped (an empty string is a substring of every term
    name and would resolve to an arbitrary HPO term), and HPO IDs resolved
    more than once are kept only once so a phenotype is not counted twice.

    Returns (ranked_diseases, resolved_hpo_ids, backend_label).
    """
    terms = [sym for sym in symptom_list if sym.strip()]

    try:
        diseases, hpo_ids = _graph_search_neo4j(terms)
    except Exception as exc:
        # Best-effort backend: any failure falls through to the local store,
        # but say why instead of swallowing the error silently.
        print(
            f"[graph_search] Neo4j unavailable ({type(exc).__name__}: {exc}); "
            "falling back to LocalGraphStore",
            file=sys.stderr,
        )
    else:
        return diseases, hpo_ids, "Neo4j (Docker)"

    diseases, hpo_ids = _graph_search_local(terms)
    return diseases, hpo_ids, "LocalGraphStore (JSON)"


def _graph_search_neo4j(terms: list[str]) -> tuple[list[dict], list[str]]:
    """Resolve terms and rank diseases in Neo4j; raises if Neo4j is unusable."""
    from neo4j import GraphDatabase

    uri = os.getenv("NEO4J_URI", "bolt://localhost:7687")
    user = os.getenv("NEO4J_USER", "neo4j")
    password = os.getenv("NEO4J_PASSWORD", "raredx_password")

    # Using the driver as a context manager closes it on every exit path,
    # including when the connectivity check or a query raises.
    with GraphDatabase.driver(uri, auth=(user, password)) as driver:
        driver.verify_connectivity()
        with driver.session() as session:
            # Resolve symptoms to HPO IDs (first matching term per symptom).
            resolved = []
            for term in terms:
                record = session.run(
                    "MATCH (h:HPOTerm) WHERE toLower(h.term) CONTAINS toLower($s) "
                    "RETURN h.hpo_id AS hid LIMIT 1",
                    s=term,
                ).single()
                if record:
                    resolved.append(record["hid"])

            # Order-preserving de-duplication: UNWIND would otherwise emit one
            # row per duplicate and inflate match_count / freq_score.
            hpo_ids = list(dict.fromkeys(resolved))
            if not hpo_ids:
                return [], []

            # Graph traversal: relationships with frequency_order 5 are
            # ignored; lower frequency_order values score higher.
            result = session.run(
                """
                UNWIND $hpo_ids AS hid
                MATCH (d:Disease)-[r:MANIFESTS_AS]->(h:HPOTerm {hpo_id: hid})
                WHERE r.frequency_order <> 5
                WITH d, count(h) AS match_count,
                     sum(CASE r.frequency_order
                           WHEN 1 THEN 5 WHEN 2 THEN 4
                           WHEN 3 THEN 3 WHEN 4 THEN 2
                           ELSE 1 END) AS freq_score,
                     collect({hpo_id: h.hpo_id, term: h.term,
                              freq: r.frequency_label}) AS matched_hpo
                WHERE match_count >= 1
                RETURN d.orpha_code AS orpha_code, d.name AS name,
                       d.definition AS definition,
                       match_count, freq_score, matched_hpo
                ORDER BY match_count DESC, freq_score DESC
                LIMIT 10
                """,
                hpo_ids=hpo_ids,
            )
            # Materialise the records while the session is still open.
            return [dict(record) for record in result], hpo_ids


def _graph_search_local(terms: list[str]) -> tuple[list[dict], list[str]]:
    """Resolve terms and rank diseases using the in-process LocalGraphStore."""
    from graph_store import LocalGraphStore

    store = LocalGraphStore()

    # Collect the HPO term nodes once (with lower-cased names) instead of
    # rescanning and re-lowercasing the whole graph for every symptom.
    hpo_nodes = [
        (attrs.get("term", "").lower(), attrs)
        for _node_id, attrs in store.graph.nodes(data=True)
        if attrs.get("type") == "HPOTerm"
    ]

    resolved = []
    for term in terms:
        needle = term.lower()
        match = next((attrs for name, attrs in hpo_nodes if needle in name), None)
        if match is not None:
            resolved.append(match["hpo_id"])

    hpo_ids = list(dict.fromkeys(resolved))
    return store.find_diseases_by_hpo(hpo_ids, top_n=10), hpo_ids
# ---------------------------------------------------------------------------
# ChromaDB semantic search
# ---------------------------------------------------------------------------
def chroma_search(
    symptom_list: list[str],
    model: SentenceTransformer,
    n: int = 10,
) -> tuple[list[dict], str]:
    """
    Semantic search over the disease collection for a list of symptoms.

    The symptoms are joined into a single clinical sentence, embedded with
    ``model`` and used as the query vector. The HTTP ChromaDB server is
    preferred; if creating the client or its heartbeat fails, the embedded
    persistent client is used instead.

    Returns (hits, backend_label). Each hit carries ``orpha_code``, ``name``,
    ``definition`` and ``cosine_similarity`` (1 - distance, rounded to four
    places).
    NOTE(review): the similarity is only a true cosine similarity if the
    collection was built with cosine distance - confirm.
    """
    joined_symptoms = ", ".join(symptom_list)
    query = "Patient presents with: " + joined_symptoms + "."

    try:
        db = chromadb.HttpClient(
            host=CHROMA_HOST,
            port=CHROMA_PORT,
            settings=Settings(anonymized_telemetry=False),
        )
        db.heartbeat()
    except Exception:
        db = chromadb.PersistentClient(
            path=str(CHROMA_PERSIST_DIR),
            settings=Settings(anonymized_telemetry=False),
        )
        label = "ChromaDB Embedded"
    else:
        label = "ChromaDB HTTP"

    collection = db.get_collection(COLLECTION_NAME)
    query_vectors = model.encode([query], normalize_embeddings=True)
    response = collection.query(
        query_embeddings=query_vectors.tolist(),
        n_results=n,
        include=["metadatas", "distances"],
    )

    # A single query was sent, so only the first result row is relevant.
    metadatas = response["metadatas"][0]
    distances = response["distances"][0]
    hits = [
        {
            "orpha_code": meta.get("orpha_code"),
            "name": meta.get("name"),
            "definition": meta.get("definition", ""),
            "cosine_similarity": round(1 - distance, 4),
        }
        for meta, distance in zip(metadatas, distances)
    ]
    return hits, label
# ---------------------------------------------------------------------------
# Score fusion
# ---------------------------------------------------------------------------
def fuse_rankings(
    graph_results: list[dict],
    chroma_results: list[dict],
    *,
    k: float = 60,
) -> list[dict]:
    """
    Reciprocal Rank Fusion (RRF) of graph and semantic rankings.

    RRF score = sum(1 / (k + rank)) over each list the disease appears in,
    with ranks starting at 1. Diseases are identified by ``str(orpha_code)``
    so integer and string codes from the two backends refer to the same entry.

    Args:
        graph_results: Ranked graph hits; each needs ``orpha_code`` and
            ``name``, and may carry ``definition`` and ``match_count``.
        chroma_results: Ranked semantic hits; each needs ``orpha_code`` and
            ``name``, and may carry ``definition`` and ``cosine_similarity``.
        k: RRF smoothing constant (60 is the standard value). Must be >= 0.

    Returns:
        One dict per disease, sorted by ``rrf_score`` descending (ties keep
        first-seen order). ``graph_rank`` / ``graph_matches`` and
        ``chroma_rank`` / ``chroma_sim`` are None for a list the disease did
        not appear in.

    Raises:
        ValueError: If ``k`` is negative.
    """
    if k < 0:
        raise ValueError("k must be non-negative")

    scores: dict[str, dict] = {}
    _add_ranking(scores, graph_results, k,
                 "graph_rank", "graph_matches", "match_count", 0)
    _add_ranking(scores, chroma_results, k,
                 "chroma_rank", "chroma_sim", "cosine_similarity", None)
    return sorted(scores.values(), key=lambda entry: entry["rrf_score"], reverse=True)


def _add_ranking(scores, ranking, k, rank_field, detail_field, source_key, default):
    """Fold one ranked list into ``scores``, adding its RRF contribution."""
    for rank, item in enumerate(ranking, 1):
        key = str(item["orpha_code"])
        entry = scores.get(key)
        if entry is None:
            # First sighting: the identifying fields come from this list.
            entry = scores[key] = {
                "orpha_code": item["orpha_code"],
                "name": item["name"],
                "definition": item.get("definition", ""),
                "graph_rank": None,
                "chroma_rank": None,
                "graph_matches": None,
                "chroma_sim": None,
                "rrf_score": 0.0,
            }
        entry["rrf_score"] += 1 / (k + rank)
        entry[rank_field] = rank
        entry[detail_field] = item.get(source_key, default)
# ---------------------------------------------------------------------------
# Display
# ---------------------------------------------------------------------------
# ANSI escape sequences used to style the terminal output of the print_*
# helpers. Each styled span is terminated with RESET.
BOLD = "\033[1m"
CYAN = "\033[96m"  # graph traversal section title
GREEN = "\033[92m"  # semantic search section title and the PASSED verdict
YELLOW = "\033[93m"  # warnings: no graph matches, partial/failed verdict
MAGENTA= "\033[95m"  # fused ranking section title
DIM = "\033[2m"  # secondary details (resolved HPO IDs, matched terms)
RESET = "\033[0m"
# Horizontal rule (66 dashes) printed under section titles and around the
# summary; same width as the "=" banner printed by main().
LINE = "-" * 66
def print_section(title: str, color: str) -> None:
    """Print a blank line, the title in bold ``color``, then a rule line."""
    heading = f"\n{BOLD}{color}{title}{RESET}"
    for row in (heading, LINE):
        print(row)
def print_graph_hits(diseases: list[dict], hpo_ids: list[str], backend: str) -> None:
    """
    Print the top five graph-traversal candidates.

    Args:
        diseases: Ranked disease dicts as returned by graph_search(); each
            needs ``orpha_code`` and ``name`` and may carry ``match_count``,
            ``total_query_terms`` and ``matched_hpo`` (dicts with ``term``).
        hpo_ids: HPO IDs the query symptoms were resolved to.
        backend: Label of the graph backend that produced the results.

    When a disease has no ``total_query_terms``, the denominator falls back
    to the number of symptoms in the module-level ``symptoms`` list.
    """
    print_section(f"[ Graph Traversal — {backend} ]", CYAN)
    if not diseases:
        print(f" {YELLOW}No graph matches. HPO IDs resolved: {hpo_ids}{RESET}")
        return

    # str() each ID so a non-string identifier cannot break the join.
    resolved = ", ".join(map(str, hpo_ids))
    print(f" {DIM}HPO IDs resolved: {resolved}{RESET}\n")

    for rank, disease in enumerate(diseases[:5], 1):
        match_count = disease.get("match_count", "?")
        total = disease.get("total_query_terms", len(symptoms))
        print(f" {rank}. ORPHA:{disease['orpha_code']} {BOLD}{disease['name']}{RESET}")
        print(f" Phenotype matches: {match_count}/{total}")
        # "or []" also covers a backend that stores None for matched_hpo.
        matched = disease.get("matched_hpo") or []
        if matched:
            # Show at most four matched terms to keep the line short.
            terms = ", ".join(m["term"] for m in matched[:4])
            print(f" {DIM}Matched: {terms}{RESET}")
def print_chroma_hits(hits: list[dict], backend: str) -> None:
    """
    Print the top five semantic-search hits with a 20-cell similarity bar.

    Args:
        hits: Ranked hit dicts as returned by chroma_search(); each needs
            ``cosine_similarity``, ``orpha_code`` and ``name``.
        backend: Label of the ChromaDB client that produced the results.
    """
    bar_width = 20
    print_section(f"[ Semantic Search — BioLORD-2023 | {backend} ]", GREEN)
    for rank, hit in enumerate(hits[:5], 1):
        sim = hit["cosine_similarity"]
        # Clamp the filled part to the bar: the similarity is computed as
        # 1 - distance, which can fall outside [0, 1], and an unclamped value
        # would produce a bar wider than bar_width.
        filled = max(0, min(bar_width, int(sim * bar_width)))
        bar = "█" * filled + "░" * (bar_width - filled)
        print(f" {rank}. [{bar}] {sim:.4f} ORPHA:{hit['orpha_code']} {hit['name']}")
def print_fused(fused: list[dict]) -> None:
    """
    Print the first ten fused candidates as a table.

    Columns: position, RRF score (four decimals), graph rank, ChromaDB rank
    and disease name. A rank the disease does not have is shown as a dash.
    """

    def rank_label(rank_value):
        # A missing rank (None) renders as a dash placeholder.
        if rank_value:
            return f"#{rank_value}"
        return " - "

    print_section("[ Fused Differential Diagnosis (RRF) ]", MAGENTA)
    print(f" {'Rank':<5} {'RRF':>6} {'Graph':>5} {'Chroma':>6} Disease")
    print(f" {'-'*4} {'-'*6} {'-'*5} {'-'*6} {'-'*35}")

    for position, entry in enumerate(fused[:10], 1):
        graph_col = rank_label(entry["graph_rank"])
        chroma_col = rank_label(entry["chroma_rank"])
        score = entry["rrf_score"]
        print(f" {position:<5} {score:.4f} {graph_col:>5} {chroma_col:>6} {entry['name']}")
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main() -> None:
    """
    Run the Week 2A milestone end to end and print the results.

    Loads the embedding model, runs graph_search() and chroma_search() on the
    module-level ``symptoms`` in two worker threads, prints each ranking and
    the fused ranking, then prints a summary. Exits with status 1 unless the
    graph search, the semantic search and the fusion all returned candidates.
    """
    banner = "=" * 66
    print(banner)
    print("RareDx — Week 2A Milestone: Symptom-to-Diagnosis")
    print(banner)

    print(f"\n{BOLD}Clinical query symptoms:{RESET}")
    for symptom in symptoms:
        print(f" - {symptom}")

    print("\nLoading BioLORD-2023 ...")
    # perf_counter() is monotonic, so the reported durations cannot be skewed
    # by system clock adjustments the way time.time() differences can.
    load_start = time.perf_counter()
    model = SentenceTransformer(EMBED_MODEL)
    print(f" Ready in {time.perf_counter() - load_start:.1f}s")

    # Parallel: graph traversal + semantic search
    print("\nRunning graph traversal and semantic search in parallel...")
    search_start = time.perf_counter()
    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
        graph_fut = pool.submit(graph_search, symptoms)
        chroma_fut = pool.submit(chroma_search, symptoms, model, 10)
        graph_diseases, hpo_ids, graph_backend = graph_fut.result()
        chroma_hits, chroma_backend = chroma_fut.result()
    elapsed = time.perf_counter() - search_start
    print(f" Completed in {elapsed:.2f}s")

    # Display individual results
    print_graph_hits(graph_diseases, hpo_ids, graph_backend)
    print_chroma_hits(chroma_hits, chroma_backend)

    # Fuse
    fused = fuse_rankings(graph_diseases, chroma_hits)
    print_fused(fused)

    # Summary
    graph_ok = len(graph_diseases) > 0
    chroma_ok = len(chroma_hits) > 0
    fused_ok = len(fused) > 0

    def status(ok: bool) -> str:
        return "OK" if ok else "MISS"

    print(f"\n{LINE}")
    print(f"{BOLD}Week 2A Milestone Summary{RESET}")
    print(LINE)
    # A space separates the status from the count (previously "OK5 candidates").
    print(f" Graph traversal : {status(graph_ok)} {len(graph_diseases)} candidates — {graph_backend}")
    print(f" Semantic search : {status(chroma_ok)} {len(chroma_hits)} candidates — {chroma_backend}")
    print(f" Fused ranking : {status(fused_ok)} {len(fused)} unique candidates")
    print()

    if graph_ok and chroma_ok and fused_ok:
        top = fused[0]
        print(f" {BOLD}{GREEN}PASSED{RESET} — Top diagnosis: {top['name']} (ORPHA:{top['orpha_code']})")
    else:
        print(f" {YELLOW}PARTIAL or FAILED — check individual backends above{RESET}")
        sys.exit(1)
    print()


if __name__ == "__main__":
    main()