File size: 12,513 Bytes
89c6379
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
"""
milestone_2a.py
---------------
Week 2A Milestone: Symptom-to-candidate-disease via graph phenotype matching.

Given a list of clinical symptoms, this script:
  1. Maps symptoms to HPO term IDs via the graph (HPOTerm name search)
  2. Runs the MANIFESTS_AS graph traversal to find matching diseases
  3. Runs BioLORD-2023 semantic search in ChromaDB in parallel
  4. Merges both rankings into a unified differential diagnosis list

This is the first real diagnostic query in RareDx.

Usage:
  python milestone_2a.py
  python milestone_2a.py "arachnodactyly" "ectopia lentis" "aortic dilation"
"""

import io
import os
import sys
import time
import concurrent.futures
from pathlib import Path

import chromadb
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer
from dotenv import load_dotenv

# UTF-8 output for Windows: re-wrap the raw stdout/stderr byte streams so the
# box-drawing and dash characters printed below are encoded as UTF-8, with
# encoding errors replaced instead of raised. This runs at import time.
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", errors="replace")
sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding="utf-8", errors="replace")

# Load the .env file located two levels above this file's directory. It must
# run before the os.getenv() reads below so they can see its values.
load_dotenv(Path(__file__).parents[2] / ".env")

# Connection and model settings, each overridable through the environment.
CHROMA_HOST = os.getenv("CHROMA_HOST", "localhost")
CHROMA_PORT = int(os.getenv("CHROMA_PORT", "8000"))
COLLECTION_NAME = os.getenv("CHROMA_COLLECTION", "rare_diseases")
EMBED_MODEL = os.getenv("EMBED_MODEL", "FremyCompany/BioLORD-2023")
# On-disk ChromaDB location used when the HTTP server cannot be reached.
CHROMA_PERSIST_DIR = Path(__file__).parents[2] / "data" / "chromadb"

# Default test case: classic Marfan syndrome presentation
DEFAULT_SYMPTOMS = [
    "arachnodactyly",
    "ectopia lentis",
    "aortic root dilatation",
    "scoliosis",
    "tall stature",
]

# Query symptoms: the command-line arguments when any are given, otherwise the
# default case above. Read as a module global by print_graph_hits() and main().
symptoms = sys.argv[1:] if len(sys.argv) > 1 else DEFAULT_SYMPTOMS


# ---------------------------------------------------------------------------
# Graph query
# ---------------------------------------------------------------------------

def _neo4j_graph_search(symptom_list: list[str]) -> tuple[list[dict], list[str], str]:
    """Resolve symptoms and rank diseases against a Neo4j server; raises on any failure."""
    from neo4j import GraphDatabase

    backend = "Neo4j (Docker)"
    neo4j_uri = os.getenv("NEO4J_URI", "bolt://localhost:7687")
    neo4j_user = os.getenv("NEO4J_USER", "neo4j")
    neo4j_pass = os.getenv("NEO4J_PASSWORD", "raredx_password")

    driver = GraphDatabase.driver(neo4j_uri, auth=(neo4j_user, neo4j_pass))
    try:
        driver.verify_connectivity()

        with driver.session() as session:
            # Resolve each symptom to the first HPO term whose name contains it.
            hpo_ids: list[str] = []
            for sym in symptom_list:
                needle = sym.strip()
                if not needle:
                    # CONTAINS "" is true for every term, so a blank symptom
                    # would resolve to an arbitrary HPO ID.
                    continue
                rec = session.run(
                    "MATCH (h:HPOTerm) WHERE toLower(h.term) CONTAINS toLower($s) "
                    "RETURN h.hpo_id AS hid LIMIT 1",
                    s=needle,
                ).single()
                # De-duplicate: a repeated ID would be unwound twice below and
                # inflate match_count for every disease linked to it.
                if rec and rec["hid"] is not None and rec["hid"] not in hpo_ids:
                    hpo_ids.append(rec["hid"])

            if not hpo_ids:
                return [], [], backend

            # Graph traversal. Relationships with frequency_order 5 are skipped
            # and lower orders earn a higher weight (1 -> 5 ... 4 -> 2).
            # NOTE(review): order 5 presumably means "excluded" - confirm
            # against the loader that writes frequency_order.
            result = session.run(
                """
                UNWIND $hpo_ids AS hid
                MATCH (d:Disease)-[r:MANIFESTS_AS]->(h:HPOTerm {hpo_id: hid})
                WHERE r.frequency_order <> 5
                WITH d, count(h) AS match_count,
                     sum(CASE r.frequency_order
                           WHEN 1 THEN 5 WHEN 2 THEN 4
                           WHEN 3 THEN 3 WHEN 4 THEN 2
                           ELSE 1 END) AS freq_score,
                     collect({hpo_id: h.hpo_id, term: h.term,
                              freq: r.frequency_label}) AS matched_hpo
                WHERE match_count >= 1
                RETURN d.orpha_code AS orpha_code, d.name AS name,
                       d.definition AS definition,
                       match_count, freq_score, matched_hpo
                ORDER BY match_count DESC, freq_score DESC
                LIMIT 10
                """,
                hpo_ids=hpo_ids,
            )
            diseases = [dict(record) for record in result]

        return diseases, hpo_ids, backend
    finally:
        # Always release the driver, including when a query raises.
        driver.close()


def _local_graph_search(symptom_list: list[str]) -> tuple[list[dict], list[str], str]:
    """Resolve symptoms and rank diseases using the LocalGraphStore fallback."""
    from graph_store import LocalGraphStore

    store = LocalGraphStore()

    # Collect (lower-cased term, hpo_id) pairs once, in graph order, instead of
    # rescanning every node for every symptom.
    hpo_terms = [
        ((attrs.get("term") or "").lower(), attrs["hpo_id"])
        for _, attrs in store.graph.nodes(data=True)
        if attrs.get("type") == "HPOTerm" and attrs.get("hpo_id")
    ]

    hpo_ids: list[str] = []
    for sym in symptom_list:
        needle = sym.strip().lower()
        if not needle:
            # "" is a substring of every term; skip blank symptoms.
            continue
        hid = next((hid for term, hid in hpo_terms if needle in term), None)
        if hid is not None and hid not in hpo_ids:
            hpo_ids.append(hid)

    diseases = store.find_diseases_by_hpo(hpo_ids, top_n=10)
    return diseases, hpo_ids, "LocalGraphStore (JSON)"


def graph_search(symptom_list: list[str]) -> tuple[list[dict], list[str], str]:
    """
    Resolve symptoms to HPO IDs and rank diseases by phenotype overlap.

    Tries Neo4j first. If anything on that path fails (driver not installed,
    server unreachable, query error) the reason is written to stderr and the
    search is repeated against LocalGraphStore.

    Args:
        symptom_list: Free-text symptom names. Each one is matched
            case-insensitively as a substring of an HPO term name and the
            first hit is used. Blank entries are ignored.

    Returns:
        (ranked_diseases, resolved_hpo_ids, backend_label).
        On the Neo4j path each disease dict has the keys orpha_code, name,
        definition, match_count, freq_score and matched_hpo, limited to the
        top 10. On the fallback path the dicts are whatever
        LocalGraphStore.find_diseases_by_hpo returns.
        resolved_hpo_ids holds each resolved ID once, in first-seen order.
    """
    try:
        return _neo4j_graph_search(symptom_list)
    except Exception as exc:
        # Deliberate best-effort: Neo4j is optional, so fall through to the
        # local store, but say why rather than swallowing the error.
        print(
            f"  Neo4j backend unavailable ({type(exc).__name__}: {exc}); "
            "falling back to LocalGraphStore",
            file=sys.stderr,
        )

    return _local_graph_search(symptom_list)


# ---------------------------------------------------------------------------
# ChromaDB semantic search
# ---------------------------------------------------------------------------

def chroma_search(
    symptom_list: list[str],
    model: SentenceTransformer,
    n: int = 10,
) -> tuple[list[dict], str]:
    """
    Run a semantic nearest-neighbour search for the symptoms in ChromaDB.

    The symptoms are joined into a single "Patient presents with: ..." sentence,
    embedded with ``model`` (normalized), and used to query the collection.
    The HTTP server is tried first; if creating the client or its heartbeat
    fails, the embedded persistent client is used instead.

    Returns:
        (hits, backend_label). Each hit has orpha_code, name, definition and
        cosine_similarity, the latter computed as ``1 - distance`` rounded to
        four decimals.
    """
    query_text = f"Patient presents with: {', '.join(symptom_list)}."

    # Each client gets its own Settings object, as in the original flow.
    try:
        db = chromadb.HttpClient(
            host=CHROMA_HOST,
            port=CHROMA_PORT,
            settings=Settings(anonymized_telemetry=False),
        )
        db.heartbeat()
        backend_label = "ChromaDB HTTP"
    except Exception:
        db = chromadb.PersistentClient(
            path=str(CHROMA_PERSIST_DIR),
            settings=Settings(anonymized_telemetry=False),
        )
        backend_label = "ChromaDB Embedded"

    # Fetch the collection before encoding so a missing collection fails fast.
    coll = db.get_collection(COLLECTION_NAME)
    query_vector = model.encode([query_text], normalize_embeddings=True)
    response = coll.query(
        query_embeddings=query_vector.tolist(),
        n_results=n,
        include=["metadatas", "distances"],
    )

    # Only one query was sent, so the first (and only) result row is used.
    metadatas = response["metadatas"][0]
    distances = response["distances"][0]
    matches = [
        {
            "orpha_code": meta.get("orpha_code"),
            "name": meta.get("name"),
            "definition": meta.get("definition", ""),
            "cosine_similarity": round(1 - distance, 4),
        }
        for meta, distance in zip(metadatas, distances)
    ]
    return matches, backend_label


# ---------------------------------------------------------------------------
# Score fusion
# ---------------------------------------------------------------------------

def fuse_rankings(
    graph_results: list[dict],
    chroma_results: list[dict],
    k: int = 60,
) -> list[dict]:
    """
    Reciprocal Rank Fusion (RRF) of graph and semantic rankings.

    RRF score = sum(1 / (k + rank)) over each list the disease appears in,
    with ranks starting at 1. Diseases are matched across the two lists by
    ``str(orpha_code)``, so an int code and its string form are the same entry.

    Args:
        graph_results: Graph hits in rank order. Each needs ``orpha_code`` and
            ``name``; ``definition`` and ``match_count`` are optional.
        chroma_results: Semantic hits in rank order. Each needs ``orpha_code``
            and ``name``; ``definition`` and ``cosine_similarity`` are optional.
        k: RRF smoothing constant (default 60). Must be non-negative.

    Returns:
        One dict per distinct disease with orpha_code, name, definition,
        graph_rank, chroma_rank, graph_matches, chroma_sim and rrf_score,
        sorted by rrf_score descending. Ties keep first-seen order.
        Rank/score fields stay None for a list the disease is absent from.

    Raises:
        ValueError: If ``k`` is negative.
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")

    fused: dict[str, dict] = {}

    def _entry_for(hit: dict) -> dict:
        """Return the fused record for ``hit``, creating it on first sight."""
        key = str(hit["orpha_code"])
        entry = fused.get(key)
        if entry is None:
            entry = fused[key] = {
                "orpha_code": hit["orpha_code"],
                "name": hit["name"],
                "definition": hit.get("definition", ""),
                "graph_rank": None,
                "chroma_rank": None,
                "graph_matches": None,
                "chroma_sim": None,
                "rrf_score": 0.0,
            }
        elif not entry["definition"] and hit.get("definition"):
            # The first source had no definition; take it from this one.
            entry["definition"] = hit["definition"]
        return entry

    for rank, hit in enumerate(graph_results, 1):
        entry = _entry_for(hit)
        entry["rrf_score"] += 1 / (k + rank)
        entry["graph_rank"] = rank
        entry["graph_matches"] = hit.get("match_count", 0)

    for rank, hit in enumerate(chroma_results, 1):
        entry = _entry_for(hit)
        entry["rrf_score"] += 1 / (k + rank)
        entry["chroma_rank"] = rank
        entry["chroma_sim"] = hit.get("cosine_similarity")

    return sorted(fused.values(), key=lambda e: e["rrf_score"], reverse=True)


# ---------------------------------------------------------------------------
# Display
# ---------------------------------------------------------------------------

# ANSI escape sequences used to style terminal output.
BOLD = "\033[1m"
DIM = "\033[2m"
RESET = "\033[0m"  # clears every attribute set by the codes in this block

# Bright foreground colours.
GREEN = "\033[92m"
YELLOW = "\033[93m"
MAGENTA = "\033[95m"
CYAN = "\033[96m"

# Horizontal rule printed under section headings (66 columns wide).
LINE = "-" * 66


def print_section(title: str, color: str) -> None:
    """Print a blank line, then ``title`` in bold ``color``, then a rule line."""
    heading = f"{BOLD}{color}{title}{RESET}"
    print("\n" + heading)
    print(LINE)


def print_graph_hits(diseases: list[dict], hpo_ids: list[str], backend: str) -> None:
    """
    Print the top five graph-traversal candidates under a section heading.

    With no candidates, a single notice listing the resolved HPO IDs is printed.
    Otherwise each candidate shows its ORPHA code, name, the phenotype match
    ratio, and up to four matched HPO term names.
    """
    print_section(f"[ Graph Traversal — {backend} ]", CYAN)

    if not diseases:
        print(f"  {YELLOW}No graph matches. HPO IDs resolved: {hpo_ids}{RESET}")
        return

    resolved = ", ".join(hpo_ids)
    print(f"  {DIM}HPO IDs resolved: {resolved}{RESET}\n")

    for position, disease in enumerate(diseases[:5], start=1):
        hit_count = disease.get("match_count", "?")
        # Falls back to the size of the module-level query when the backend
        # did not report how many terms it was asked about.
        query_size = disease.get("total_query_terms", len(symptoms))

        print(f"  {position}. ORPHA:{disease['orpha_code']}  {BOLD}{disease['name']}{RESET}")
        print(f"     Phenotype matches: {hit_count}/{query_size}")

        matched_terms = disease.get("matched_hpo", [])
        if not matched_terms:
            continue
        shown = ", ".join(entry["term"] for entry in matched_terms[:4])
        print(f"     {DIM}Matched: {shown}{RESET}")


def print_chroma_hits(hits: list[dict], backend: str) -> None:
    """
    Print the top five semantic-search hits with a 20-cell similarity bar.

    Args:
        hits: Hits in rank order; each needs ``cosine_similarity``,
            ``orpha_code`` and ``name``.
        backend: Label of the ChromaDB backend that served the query.
    """
    print_section(f"[ Semantic Search — BioLORD-2023 | {backend} ]", GREEN)
    if not hits:
        # Mirror print_graph_hits: say so instead of printing an empty section.
        print(f"  {YELLOW}No semantic matches.{RESET}")
        return

    bar_width = 20
    for rank, h in enumerate(hits[:5], 1):
        sim = h["cosine_similarity"]
        # Similarity is derived as 1 - distance, so it can fall below 0 (or,
        # with a non-cosine metric, outside [0, 1]). Clamp so the bar is
        # always exactly bar_width cells.
        filled = max(0, min(bar_width, int(sim * bar_width)))
        bar = "█" * filled + "░" * (bar_width - filled)
        print(f"  {rank}. [{bar}] {sim:.4f}  ORPHA:{h['orpha_code']}  {h['name']}")


def print_fused(fused: list[dict]) -> None:
    """
    Print the top ten fused candidates as a table.

    Columns are the fused rank, the RRF score, the rank in the graph list,
    the rank in the semantic list, and the disease name. A disease missing
    from one list shows a dash in that column.
    """

    def rank_cell(value) -> str:
        # Ranks start at 1, so a falsy value means "not in that list".
        return f"#{value}" if value else "  -  "

    print_section("[ Fused Differential Diagnosis (RRF) ]", MAGENTA)

    underline = "  ".join("-" * width for width in (4, 6, 5, 6, 35))
    print(f"  {'Rank':<5} {'RRF':>6}  {'Graph':>5}  {'Chroma':>6}  Disease")
    print(f"  {underline}")

    for position, row in enumerate(fused[:10], start=1):
        graph_col = rank_cell(row["graph_rank"])
        chroma_col = rank_cell(row["chroma_rank"])
        score = row["rrf_score"]
        print(f"  {position:<5} {score:.4f}  {graph_col:>5}  {chroma_col:>6}  {row['name']}")


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main() -> None:
    """
    Run the Week 2A milestone end to end and print the results.

    Steps: echo the module-level ``symptoms``, load the embedding model, run
    graph_search and chroma_search concurrently on two threads, print each
    ranking, fuse them with fuse_rankings, then print a summary.

    Exits the process with status 1 when the graph search, the semantic
    search, or the fused list produced no candidates.
    """
    print("=" * 66)
    print("RareDx — Week 2A Milestone: Symptom-to-Diagnosis")
    print("=" * 66)
    print(f"\n{BOLD}Clinical query symptoms:{RESET}")
    for s in symptoms:
        print(f"  - {s}")

    print("\nLoading BioLORD-2023 ...")
    # perf_counter is monotonic, so the elapsed times cannot go negative or
    # jump if the system clock is adjusted mid-run.
    t0 = time.perf_counter()
    model = SentenceTransformer(EMBED_MODEL)
    print(f"  Ready in {time.perf_counter()-t0:.1f}s")

    # Parallel: graph traversal + semantic search
    print("\nRunning graph traversal and semantic search in parallel...")
    t_start = time.perf_counter()
    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
        graph_fut = pool.submit(graph_search, symptoms)
        chroma_fut = pool.submit(chroma_search, symptoms, model, 10)

        # result() re-raises any exception from the worker thread.
        graph_diseases, hpo_ids, graph_backend = graph_fut.result()
        chroma_hits, chroma_backend = chroma_fut.result()

    elapsed = time.perf_counter() - t_start
    print(f"  Completed in {elapsed:.2f}s")

    # Display individual results
    print_graph_hits(graph_diseases, hpo_ids, graph_backend)
    print_chroma_hits(chroma_hits, chroma_backend)

    # Fuse
    fused = fuse_rankings(graph_diseases, chroma_hits)
    print_fused(fused)

    # Summary
    graph_ok = bool(graph_diseases)
    chroma_ok = bool(chroma_hits)
    fused_ok = bool(fused)

    def status(ok: bool) -> str:
        return "OK" if ok else "MISS"

    print(f"\n{LINE}")
    print(f"{BOLD}Week 2A Milestone Summary{RESET}")
    print(LINE)
    # The status is padded to four columns and followed by a space; previously
    # it ran straight into the count ("OK10 candidates").
    print(f"  Graph traversal : {status(graph_ok):<4} {len(graph_diseases)} candidates — {graph_backend}")
    print(f"  Semantic search : {status(chroma_ok):<4} {len(chroma_hits)} candidates — {chroma_backend}")
    print(f"  Fused ranking   : {status(fused_ok):<4} {len(fused)} unique candidates")
    print()

    if graph_ok and chroma_ok and fused_ok:
        top = fused[0]
        print(f"  {BOLD}{GREEN}PASSED{RESET} — Top diagnosis: {top['name']} (ORPHA:{top['orpha_code']})")
    else:
        print(f"  {YELLOW}PARTIAL or FAILED — check individual backends above{RESET}")
        sys.exit(1)
    print()


# Run the milestone only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()