"""
pipeline.py
-----------
DiagnosisPipeline — the core reasoning engine for RareDx.

Shared between the FastAPI app (loaded once at startup) and the
milestone_2b.py script (instantiated directly).

Steps:
  1. SymptomParser → map clinical note phrases to HPO IDs (BioLORD semantic)
  2. GraphSearch   → MANIFESTS_AS traversal ranked by phenotype overlap
  3. ChromaSearch  → BioLORD semantic search over HPO-enriched embeddings
  4. RRF Fusion    → merge both rankings via Reciprocal Rank Fusion
  5. FusionNode    → hallucination guard: flag candidates lacking evidence
"""
|
|
| import os |
| import sys |
| import time |
| import concurrent.futures |
| from pathlib import Path |
|
|
| import chromadb |
| from chromadb.config import Settings |
| from sentence_transformers import SentenceTransformer |
| from dotenv import load_dotenv |
|
|
# Read the project-level .env (two directories above this file) before any
# of the os.environ lookups below.
load_dotenv(Path(__file__).parents[2].joinpath(".env"))

# Sibling helper modules are imported by bare name inside DiagnosisPipeline,
# so both directories are put at the front of sys.path (API_DIR first).
SCRIPTS_DIR = Path(__file__).parents[1] / "scripts"
API_DIR = Path(__file__).parent
sys.path[:0] = [str(API_DIR), str(SCRIPTS_DIR)]

# Connection and model settings; each environment variable overrides the
# default shown here.
CHROMA_HOST = os.environ.get("CHROMA_HOST", "localhost")
CHROMA_PORT = int(os.environ.get("CHROMA_PORT", "8000"))
COLLECTION_NAME = os.environ.get("CHROMA_COLLECTION", "rare_diseases")
EMBED_MODEL = os.environ.get("EMBED_MODEL", "FremyCompany/BioLORD-2023")
# On-disk store used when the ChromaDB HTTP server cannot be reached.
CHROMA_PERSIST = Path(__file__).parents[2] / "data" / "chromadb"

# Smoothing constant of Reciprocal Rank Fusion: each ranker contributes
# 1 / (RRF_K + rank) to a candidate's score.
RRF_K = 60
|
|
|
|
class DiagnosisPipeline:
    """
    Core RareDx reasoning engine: clinical note in, ranked candidates out.

    Initialise once per process (the constructor loads the embedding model
    and opens both stores); call ``.diagnose(note)`` for each request.
    Within one ``diagnose`` call the graph traversal and the ChromaDB query
    run in parallel worker threads.

    NOTE(review): ``diagnose`` writes the per-request ``threshold`` onto the
    shared ``SymptomParser`` instance, so overlapping calls that use
    different thresholds can influence each other — confirm before relying
    on concurrent use of a single pipeline object.
    """

    def __init__(self) -> None:
        """Load the embedding model and wire up parser, stores and guard."""
        print("Initialising DiagnosisPipeline...")

        # One embedding model is shared by the symptom parser and the
        # vector search.
        print("  Loading BioLORD-2023...")
        self.model = SentenceTransformer(EMBED_MODEL)

        # Project-local modules are imported lazily: they only resolve once
        # the module-level sys.path entries have been added.
        from symptom_parser import SymptomParser
        self.symptom_parser = SymptomParser(self.model)

        self.chroma_col, self.chroma_backend = self._init_chroma()

        from graph_store import LocalGraphStore
        self.graph_store = LocalGraphStore()
        self.graph_backend = "LocalGraphStore (JSON)"

        from hallucination_guard import FusionNode
        self.fusion_node = FusionNode(
            min_graph_matches=2,
            min_vector_sim=0.65,
            require_frequent_match=True,
        )

        print("Pipeline ready.")

    def _init_chroma(self):
        """
        Open the disease collection, preferring the ChromaDB HTTP server.

        Returns:
            A ``(collection, backend_label)`` tuple.  The label is
            ``"ChromaDB HTTP (Docker)"`` when the server answered, otherwise
            ``"ChromaDB Embedded"``.

        Raises:
            Whatever the embedded client raises when the collection cannot
            be opened from ``CHROMA_PERSIST`` either.
        """
        try:
            client = chromadb.HttpClient(
                host=CHROMA_HOST, port=CHROMA_PORT,
                settings=Settings(anonymized_telemetry=False),
            )
            client.heartbeat()
            col = client.get_collection(COLLECTION_NAME)
            return col, "ChromaDB HTTP (Docker)"
        except Exception as exc:
            # Deliberately broad: any failure to reach or use the server
            # triggers the embedded fallback.  Report the reason instead of
            # hiding it so a misconfigured host/port is visible at startup.
            print(f"  ChromaDB HTTP unavailable ({exc!r}); using embedded store.")
            client = chromadb.PersistentClient(
                path=str(CHROMA_PERSIST),
                settings=Settings(anonymized_telemetry=False),
            )
            col = client.get_collection(COLLECTION_NAME)
            return col, "ChromaDB Embedded"

    def diagnose(
        self,
        note: str,
        top_n: int = 10,
        threshold: float = 0.55,
    ) -> dict:
        """
        Run the full pipeline on one clinical note.

        Args:
            note: Free-text clinical note.
            top_n: Number of hits requested from each search and the
                maximum number of fused candidates kept.
            threshold: Value assigned to ``symptom_parser.threshold`` before
                the note is parsed.

        Returns:
            A dict with the extracted phrases and HPO matches, the fused
            ``candidates``, the guard's ``passed_candidates`` and
            ``flagged_candidates``, the ``top_diagnosis`` (first passed
            candidate, else first fused candidate, else ``None``), the
            backend labels and ``elapsed_seconds``.
        """
        # perf_counter is monotonic, so the reported duration cannot be
        # distorted by wall-clock adjustments.
        t_start = time.perf_counter()

        # Step 1: map note phrases to HPO terms.
        self.symptom_parser.threshold = threshold
        hpo_matches = self.symptom_parser.parse(note)
        hpo_ids = [m.hpo_id for m in hpo_matches]
        phrases = [m.phrase for m in hpo_matches]

        # Steps 2 + 3: graph and vector search are independent, so they run
        # side by side.
        with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
            graph_fut = pool.submit(self._graph_search, hpo_ids, top_n)
            chroma_fut = pool.submit(self._chroma_search, note, top_n)
            graph_hits = graph_fut.result()
            chroma_hits = chroma_fut.result()

        # Step 4: merge the two rankings.
        fused = self._rrf_fuse(graph_hits, chroma_hits)[:top_n]

        # Step 5: hallucination guard splits the candidates.
        passed, flagged = self.fusion_node.filter(
            fused, total_query_terms=len(hpo_ids)
        )

        # Prefer a candidate that passed the guard; otherwise fall back to
        # the best fused candidate, if there is one.
        top = passed[0] if passed else (fused[0] if fused else None)

        return {
            "note": note,
            "phrases_extracted": phrases,
            "hpo_matches": [
                {"phrase": m.phrase, "hpo_id": m.hpo_id,
                 "term": m.term, "score": m.score}
                for m in hpo_matches
            ],
            "hpo_ids_used": hpo_ids,
            "candidates": fused,
            "passed_candidates": passed,
            "flagged_candidates": flagged,
            "top_diagnosis": top,
            "graph_backend": self.graph_backend,
            "chroma_backend": self.chroma_backend,
            "elapsed_seconds": round(time.perf_counter() - t_start, 3),
        }

    def _graph_search(self, hpo_ids: list[str], top_n: int) -> list[dict]:
        """Return graph-store diseases for ``hpo_ids``; ``[]`` when none were parsed."""
        if not hpo_ids:
            return []
        return self.graph_store.find_diseases_by_hpo(hpo_ids, top_n=top_n)

    def _chroma_search(self, note: str, top_n: int) -> list[dict]:
        """
        Embed the whole note and return the nearest ChromaDB records.

        Each hit carries ``orpha_code``, ``name``, ``definition`` and
        ``cosine_similarity``.  The similarity is computed as
        ``1 - distance``, which assumes the collection was built with cosine
        distance — TODO confirm against the indexing script.

        Records without metadata or without an ``orpha_code`` are skipped:
        they cannot be keyed during fusion and would otherwise all collapse
        into a single bogus ``"None"`` candidate.
        """
        emb = self.model.encode([note], normalize_embeddings=True)
        results = self.chroma_col.query(
            query_embeddings=emb.tolist(),
            n_results=top_n,
            include=["metadatas", "distances"],
        )
        hits = []
        # Index 0: only one query embedding was sent.
        for meta, dist in zip(results["metadatas"][0], results["distances"][0]):
            if not meta or meta.get("orpha_code") is None:
                continue
            hits.append({
                "orpha_code": meta.get("orpha_code"),
                "name": meta.get("name"),
                "definition": meta.get("definition", ""),
                "cosine_similarity": round(1 - dist, 4),
            })
        return hits

    def _rrf_fuse(
        self,
        graph_results: list[dict],
        chroma_results: list[dict],
        k: "int | None" = None,
    ) -> list[dict]:
        """
        Merge the graph and vector rankings with Reciprocal Rank Fusion.

        Every ranker contributes ``1 / (k + rank)`` to a disease's score,
        with ranks starting at 1.  A disease is counted once per ranker: if
        a list contains the same ``orpha_code`` several times, only its
        best-ranked hit is used, and the ranks are taken over the
        de-duplicated list.

        Args:
            graph_results: Graph hits, best first; each needs ``orpha_code``.
            chroma_results: Vector hits, best first; each needs ``orpha_code``.
            k: RRF smoothing constant; ``None`` uses the module-level
                ``RRF_K``.

        Returns:
            Candidate dicts sorted by descending ``rrf_score`` (rounded to
            five decimals) with ``rank`` numbered from 1.  Ties keep
            first-seen order because the sort is stable.
        """
        rrf_k = RRF_K if k is None else k
        scores: dict[str, dict] = {}

        for rank, hit in enumerate(self._unique_by_code(graph_results), 1):
            entry = self._entry_for(scores, hit)
            entry["rrf_score"] += 1 / (rrf_k + rank)
            entry["graph_rank"] = rank
            entry["graph_matches"] = hit.get("match_count", 0)
            entry["matched_hpo"] = hit.get("matched_hpo", [])

        for rank, hit in enumerate(self._unique_by_code(chroma_results), 1):
            entry = self._entry_for(scores, hit)
            entry["rrf_score"] += 1 / (rrf_k + rank)
            entry["chroma_rank"] = rank
            entry["chroma_sim"] = hit.get("cosine_similarity")

        ranked = sorted(
            scores.values(), key=lambda x: x["rrf_score"], reverse=True
        )

        for i, entry in enumerate(ranked, 1):
            entry["rank"] = i
            entry["rrf_score"] = round(entry["rrf_score"], 5)
        return ranked

    @staticmethod
    def _unique_by_code(results: list[dict]) -> list[dict]:
        """Return ``results`` keeping only the first hit per ``orpha_code``."""
        seen: set[str] = set()
        unique = []
        for hit in results:
            key = str(hit["orpha_code"])
            if key not in seen:
                seen.add(key)
                unique.append(hit)
        return unique

    @classmethod
    def _entry_for(cls, scores: dict[str, dict], hit: dict) -> dict:
        """Fetch or create the fused entry for ``hit`` inside ``scores``.

        When the entry already exists, an empty ``name`` or ``definition``
        is filled from ``hit`` so descriptive text supplied by only one
        ranker is not lost.
        """
        key = str(hit["orpha_code"])
        entry = scores.get(key)
        if entry is None:
            entry = scores[key] = cls._base_entry(hit)
        else:
            for field in ("name", "definition"):
                if not entry[field] and hit.get(field):
                    entry[field] = hit[field]
        return entry

    @staticmethod
    def _base_entry(d: dict) -> dict:
        """Build a fresh fused-candidate record from a single search hit."""
        return {
            "rank": 0,
            "orpha_code": str(d["orpha_code"]),
            "name": d.get("name", ""),
            "definition": d.get("definition", ""),
            "rrf_score": 0.0,
            "graph_rank": None,
            "chroma_rank": None,
            "graph_matches": None,
            "chroma_sim": None,
            "matched_hpo": [],
        }
|
|