"""
pipeline.py
-----------
DiagnosisPipeline — the core reasoning engine for RareDx.

Shared between the FastAPI app (loaded once at startup) and the
milestone_2b.py script (instantiated directly).

Steps:
  1. SymptomParser → map clinical note phrases to HPO IDs (BioLORD semantic)
  2. GraphSearch   → MANIFESTS_AS traversal ranked by phenotype overlap
  3. ChromaSearch  → BioLORD semantic search over HPO-enriched embeddings
  4. RRF Fusion    → merge both rankings via Reciprocal Rank Fusion
  5. FusionNode    → hallucination guard: flag candidates lacking evidence
"""

import os
import sys
import time
import concurrent.futures
from pathlib import Path

import chromadb
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer
from dotenv import load_dotenv

# Load the .env file before any os.getenv() call below reads its values.
load_dotenv(Path(__file__).parents[2] / ".env")

# Ensure scripts/ and api/ are importable
SCRIPTS_DIR = Path(__file__).parents[1] / "scripts"
API_DIR = Path(__file__).parent
sys.path.insert(0, str(SCRIPTS_DIR))
sys.path.insert(0, str(API_DIR))

CHROMA_HOST = os.getenv("CHROMA_HOST", "localhost")
CHROMA_PORT = int(os.getenv("CHROMA_PORT", "8000"))
COLLECTION_NAME = os.getenv("CHROMA_COLLECTION", "rare_diseases")
EMBED_MODEL = os.getenv("EMBED_MODEL", "FremyCompany/BioLORD-2023")
CHROMA_PERSIST = Path(__file__).parents[2] / "data" / "chromadb"

RRF_K = 60  # Standard constant for Reciprocal Rank Fusion


class DiagnosisPipeline:
    """
    Initialise once per process; call .diagnose(note) for each request.

    Within a single request the graph traversal and the ChromaDB query run
    in parallel on two worker threads.

    NOTE(review): diagnose() writes ``threshold`` onto the shared
    SymptomParser, so concurrent requests that pass different thresholds
    can affect one another — confirm whether callers rely on that.
    """

    def __init__(self) -> None:
        """Load the embedding model and wire up every pipeline component."""
        print("Initialising DiagnosisPipeline...")

        # BioLORD model (shared by symptom parser + ChromaDB query)
        print(" Loading BioLORD-2023...")
        self.model = SentenceTransformer(EMBED_MODEL)

        # SymptomParser (also builds / loads HPO index).
        # Imported here, not at module top, because it is resolved through
        # the sys.path entries inserted above.
        from symptom_parser import SymptomParser
        self.symptom_parser = SymptomParser(self.model)

        # ChromaDB client
        self.chroma_col, self.chroma_backend = self._init_chroma()

        # Graph store
        from graph_store import LocalGraphStore
        self.graph_store = LocalGraphStore()
        self.graph_backend = "LocalGraphStore (JSON)"

        # Hallucination guard
        from hallucination_guard import FusionNode
        self.fusion_node = FusionNode(
            min_graph_matches=2,
            min_vector_sim=0.65,
            require_frequent_match=True,
        )

        print("Pipeline ready.")

    # ------------------------------------------------------------------
    # Initialisation helpers
    # ------------------------------------------------------------------

    def _init_chroma(self):
        """
        Open the ChromaDB collection, preferring the HTTP server.

        Returns:
            A ``(collection, backend_label)`` pair. The label is
            ``"ChromaDB HTTP (Docker)"`` when the server answered, otherwise
            ``"ChromaDB Embedded"``.

        Raises:
            Whatever the embedded client raises if the fallback also fails
            (for example when the collection does not exist).
        """
        try:
            client = chromadb.HttpClient(
                host=CHROMA_HOST,
                port=CHROMA_PORT,
                settings=Settings(anonymized_telemetry=False),
            )
            client.heartbeat()
            col = client.get_collection(COLLECTION_NAME)
            return col, "ChromaDB HTTP (Docker)"
        except Exception as exc:
            # Deliberately broad: any failure to reach or use the server
            # falls back to the on-disk store. Report the reason so the
            # switch of backend is not silent.
            print(f" ChromaDB HTTP unavailable ({exc!r}); using embedded store.")
            client = chromadb.PersistentClient(
                path=str(CHROMA_PERSIST),
                settings=Settings(anonymized_telemetry=False),
            )
            col = client.get_collection(COLLECTION_NAME)
            return col, "ChromaDB Embedded"

    # ------------------------------------------------------------------
    # Core diagnosis
    # ------------------------------------------------------------------

    def diagnose(
        self,
        note: str,
        top_n: int = 10,
        threshold: float = 0.55,
    ) -> dict:
        """
        Run the full pipeline on one clinical note.

        Args:
            note: Free-text clinical note.
            top_n: Number of hits requested from each retriever, and the cap
                on the fused candidate list.
            threshold: Value assigned to the symptom parser's ``threshold``
                attribute before parsing.

        Returns:
            A dict holding the input note, the extracted phrases and HPO
            matches, the fused candidates split into passed / flagged lists,
            the chosen top diagnosis (``None`` when nothing was retrieved),
            the backend labels and the elapsed wall-clock seconds.
        """
        t_start = time.time()

        # Step 1: symptom parsing
        # NOTE(review): this mutates the parser shared by every request.
        self.symptom_parser.threshold = threshold
        hpo_matches = self.symptom_parser.parse(note)
        hpo_ids = [m.hpo_id for m in hpo_matches]
        phrases = [m.phrase for m in hpo_matches]

        # Steps 2 & 3: parallel graph + vector search
        with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
            graph_fut = pool.submit(self._graph_search, hpo_ids, top_n)
            chroma_fut = pool.submit(self._chroma_search, note, top_n)
            graph_hits = graph_fut.result()
            chroma_hits = chroma_fut.result()

        # Step 4: RRF fusion
        fused = self._rrf_fuse(graph_hits, chroma_hits)[:top_n]

        # Step 5: Hallucination guard
        passed, flagged = self.fusion_node.filter(
            fused, total_query_terms=len(hpo_ids)
        )

        # Top diagnosis is the highest-ranked *passed* candidate;
        # fall back to highest-ranked overall if everything is flagged.
        top = passed[0] if passed else (fused[0] if fused else None)

        return {
            "note": note,
            "phrases_extracted": phrases,
            "hpo_matches": [
                {
                    "phrase": m.phrase,
                    "hpo_id": m.hpo_id,
                    "term": m.term,
                    "score": m.score,
                }
                for m in hpo_matches
            ],
            "hpo_ids_used": hpo_ids,
            "candidates": fused,  # all candidates, flag fields attached
            "passed_candidates": passed,
            "flagged_candidates": flagged,
            "top_diagnosis": top,
            "graph_backend": self.graph_backend,
            "chroma_backend": self.chroma_backend,
            "elapsed_seconds": round(time.time() - t_start, 3),
        }

    # ------------------------------------------------------------------
    # Graph traversal
    # ------------------------------------------------------------------

    def _graph_search(self, hpo_ids: list[str], top_n: int) -> list[dict]:
        """Return graph-store disease hits for *hpo_ids*; ``[]`` if none given."""
        if not hpo_ids:
            return []
        return self.graph_store.find_diseases_by_hpo(hpo_ids, top_n=top_n)

    # ------------------------------------------------------------------
    # ChromaDB semantic search
    # ------------------------------------------------------------------

    def _chroma_search(self, note: str, top_n: int) -> list[dict]:
        """
        Embed *note* and return the *top_n* nearest collection entries.

        Each hit carries ``orpha_code``, ``name``, ``definition`` and
        ``cosine_similarity``, in the order ChromaDB returned them.
        """
        emb = self.model.encode([note], normalize_embeddings=True)
        results = self.chroma_col.query(
            query_embeddings=emb.tolist(),
            n_results=top_n,
            include=["metadatas", "distances"],
        )
        # A single query embedding was sent, so only result row 0 is read.
        # NOTE(review): ``1 - dist`` is a cosine similarity only if the
        # collection was created with the cosine distance metric — confirm.
        return [
            {
                "orpha_code": meta.get("orpha_code"),
                "name": meta.get("name"),
                "definition": meta.get("definition", ""),
                "cosine_similarity": round(1 - dist, 4),
            }
            for meta, dist in zip(results["metadatas"][0], results["distances"][0])
        ]

    # ------------------------------------------------------------------
    # RRF fusion
    # ------------------------------------------------------------------

    def _rrf_fuse(
        self,
        graph_results: list[dict],
        chroma_results: list[dict],
    ) -> list[dict]:
        """
        Merge both rankings with Reciprocal Rank Fusion.

        Every hit adds ``1 / (RRF_K + rank)`` (rank is 1-based) to the score
        of its disease, keyed by the stringified ``orpha_code``.

        Args:
            graph_results: Ranked graph hits; each must contain
                ``orpha_code`` and may contain ``match_count`` and
                ``matched_hpo``.
            chroma_results: Ranked vector hits; each must contain
                ``orpha_code`` and may contain ``cosine_similarity``.

        Returns:
            One entry per distinct disease, sorted by descending RRF score,
            with ``rank`` filled in from 1 and ``rrf_score`` rounded to five
            decimals.
        """
        scores: dict[str, dict] = {}

        for rank, d in enumerate(graph_results, 1):
            key = str(d["orpha_code"])
            if key not in scores:
                scores[key] = self._base_entry(d)
            entry = scores[key]
            entry["rrf_score"] += 1 / (RRF_K + rank)
            entry["graph_rank"] = rank
            entry["graph_matches"] = d.get("match_count", 0)
            entry["matched_hpo"] = d.get("matched_hpo", [])

        for rank, d in enumerate(chroma_results, 1):
            key = str(d["orpha_code"])
            if key not in scores:
                scores[key] = self._base_entry(d)
            entry = scores[key]
            # The entry may have been seeded from a graph hit that carried
            # no name/definition; take them from the vector hit rather than
            # returning the candidate with blank descriptive fields.
            for field in ("name", "definition"):
                if not entry[field] and d.get(field):
                    entry[field] = d[field]
            entry["rrf_score"] += 1 / (RRF_K + rank)
            entry["chroma_rank"] = rank
            entry["chroma_sim"] = d.get("cosine_similarity")

        ranked = sorted(scores.values(), key=lambda x: x["rrf_score"], reverse=True)
        for i, entry in enumerate(ranked, 1):
            entry["rank"] = i
            # Round only after sorting so ordering uses the full precision.
            entry["rrf_score"] = round(entry["rrf_score"], 5)
        return ranked

    @staticmethod
    def _base_entry(d: dict) -> dict:
        """Return a fresh fused-candidate record seeded from hit *d*."""
        return {
            "rank": 0,
            "orpha_code": str(d["orpha_code"]),
            "name": d.get("name", ""),
            "definition": d.get("definition", ""),
            "rrf_score": 0.0,
            "graph_rank": None,
            "chroma_rank": None,
            "graph_matches": None,
            "chroma_sim": None,
            "matched_hpo": [],
        }