# raredx/backend/api/pipeline.py
# Uploaded by Aswin92 via huggingface_hub (commit 89c6379, verified).
"""
pipeline.py
-----------
DiagnosisPipeline — the core reasoning engine for RareDx.
Shared between the FastAPI app (loaded once at startup) and the
milestone_2b.py script (instantiated directly).
Steps:
1. SymptomParser → map clinical note phrases to HPO IDs (BioLORD semantic)
2. GraphSearch → MANIFESTS_AS traversal ranked by phenotype overlap
3. ChromaSearch → BioLORD semantic search over HPO-enriched embeddings
4. RRF Fusion → merge both rankings via Reciprocal Rank Fusion
5. FusionNode → hallucination guard: flag candidates lacking evidence
"""
import os
import sys
import time
import concurrent.futures
from pathlib import Path
import chromadb
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer
from dotenv import load_dotenv
# Load environment overrides (CHROMA_*, EMBED_MODEL) from the .env file at the
# repository root, i.e. two directories above the one containing this module.
load_dotenv(Path(__file__).parents[2] / ".env")

# Ensure scripts/ and api/ are importable. Both are pushed to the front of
# sys.path; api/ is inserted last, so it ends up ahead of scripts/.
SCRIPTS_DIR = Path(__file__).parents[1] / "scripts"
API_DIR = Path(__file__).parent
for _import_dir in (SCRIPTS_DIR, API_DIR):
    sys.path.insert(0, str(_import_dir))
del _import_dir

# ChromaDB connection settings (HTTP server first, embedded store as fallback).
CHROMA_HOST = os.environ.get("CHROMA_HOST", "localhost")
CHROMA_PORT = int(os.environ.get("CHROMA_PORT", "8000"))
COLLECTION_NAME = os.environ.get("CHROMA_COLLECTION", "rare_diseases")
CHROMA_PERSIST = Path(__file__).parents[2] / "data" / "chromadb"

# Sentence-embedding model shared by symptom parsing and vector search.
EMBED_MODEL = os.environ.get("EMBED_MODEL", "FremyCompany/BioLORD-2023")

# k constant of Reciprocal Rank Fusion: score = sum over lists of 1 / (k + rank).
# 60 is the value proposed by Cormack et al. (2009).
RRF_K = 60
class DiagnosisPipeline:
    """
    Core RareDx reasoning engine.

    Initialise once per process (model, HPO index and stores are loaded in
    ``__init__``), then call ``.diagnose(note)`` for each request.

    Concurrency: each ``diagnose`` call runs graph traversal and the ChromaDB
    query in two worker threads. Symptom parsing is serialised by a lock,
    because the per-request ``threshold`` is stored on the shared
    SymptomParser instance.
    """

    def __init__(self) -> None:
        import threading

        print("Initialising DiagnosisPipeline...")

        # BioLORD model (shared by symptom parser + ChromaDB query)
        print(" Loading BioLORD-2023...")
        self.model = SentenceTransformer(EMBED_MODEL)

        # SymptomParser (also builds / loads HPO index)
        from symptom_parser import SymptomParser
        self.symptom_parser = SymptomParser(self.model)
        # diagnose() sets symptom_parser.threshold per request; this lock keeps
        # concurrent requests from parsing with each other's threshold.
        self._parse_lock = threading.Lock()

        # ChromaDB client
        self.chroma_col, self.chroma_backend = self._init_chroma()

        # Graph store
        from graph_store import LocalGraphStore
        self.graph_store = LocalGraphStore()
        self.graph_backend = "LocalGraphStore (JSON)"

        # Hallucination guard. The meaning of these thresholds is defined in
        # hallucination_guard.FusionNode (not visible here).
        from hallucination_guard import FusionNode
        self.fusion_node = FusionNode(
            min_graph_matches=2,
            min_vector_sim=0.65,
            require_frequent_match=True,
        )
        print("Pipeline ready.")

    # ------------------------------------------------------------------
    # Initialisation helpers
    # ------------------------------------------------------------------
    def _init_chroma(self):
        """
        Open the ChromaDB collection ``COLLECTION_NAME``.

        Tries the HTTP server at ``CHROMA_HOST:CHROMA_PORT`` first. On any
        failure (unreachable server, missing collection, ...) it reports the
        reason and falls back to the embedded persistent store at
        ``CHROMA_PERSIST``.

        Returns:
            ``(collection, backend_label)``.

        Raises:
            Whatever chromadb raises if the collection is also unavailable in
            the embedded store.
        """
        try:
            client = chromadb.HttpClient(
                host=CHROMA_HOST, port=CHROMA_PORT,
                settings=Settings(anonymized_telemetry=False),
            )
            client.heartbeat()
            col = client.get_collection(COLLECTION_NAME)
            return col, "ChromaDB HTTP (Docker)"
        except Exception as exc:  # chromadb raises varied types; any failure -> fallback
            # Make the fallback visible so a misconfigured server is noticed.
            print(
                f" ChromaDB HTTP at {CHROMA_HOST}:{CHROMA_PORT} unavailable "
                f"({type(exc).__name__}: {exc}); using embedded store."
            )

        client = chromadb.PersistentClient(
            path=str(CHROMA_PERSIST),
            settings=Settings(anonymized_telemetry=False),
        )
        col = client.get_collection(COLLECTION_NAME)
        return col, "ChromaDB Embedded"

    # ------------------------------------------------------------------
    # Core diagnosis
    # ------------------------------------------------------------------
    def diagnose(
        self,
        note: str,
        top_n: int = 10,
        threshold: float = 0.55,
    ) -> dict:
        """
        Run the full pipeline on one clinical note.

        Args:
            note: Free-text clinical note.
            top_n: Number of candidates requested from each retriever, and
                the number kept after RRF fusion.
            threshold: Value assigned to ``symptom_parser.threshold`` for
                this call before parsing.

        Returns:
            A dict containing:
            - the extracted phrases and HPO matches;
            - all fused candidates (``candidates``);
            - the guard's ``passed_candidates`` / ``flagged_candidates`` split;
            - ``top_diagnosis``: the best passed candidate, else the best
              overall candidate, else None;
            - backend labels and ``elapsed_seconds``.
        """
        t_start = time.perf_counter()  # monotonic clock for durations

        # Step 1: symptom parsing. The threshold lives on the shared parser,
        # so setting it and parsing must happen atomically per request.
        with self._parse_lock:
            self.symptom_parser.threshold = threshold
            hpo_matches = self.symptom_parser.parse(note)
        hpo_ids = [m.hpo_id for m in hpo_matches]
        phrases = [m.phrase for m in hpo_matches]

        # Steps 2 & 3: parallel graph + vector search
        with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
            graph_fut = pool.submit(self._graph_search, hpo_ids, top_n)
            chroma_fut = pool.submit(self._chroma_search, note, top_n)
            graph_hits = graph_fut.result()
            chroma_hits = chroma_fut.result()

        # Step 4: RRF fusion
        fused = self._rrf_fuse(graph_hits, chroma_hits)[:top_n]

        # Step 5: Hallucination guard
        passed, flagged = self.fusion_node.filter(fused, total_query_terms=len(hpo_ids))

        # Top diagnosis is the highest-ranked *passed* candidate;
        # fall back to highest-ranked overall if everything is flagged.
        top = passed[0] if passed else (fused[0] if fused else None)

        return {
            "note": note,
            "phrases_extracted": phrases,
            "hpo_matches": [
                {"phrase": m.phrase, "hpo_id": m.hpo_id,
                 "term": m.term, "score": m.score}
                for m in hpo_matches
            ],
            "hpo_ids_used": hpo_ids,
            "candidates": fused,  # all candidates, flag fields attached
            "passed_candidates": passed,
            "flagged_candidates": flagged,
            "top_diagnosis": top,
            "graph_backend": self.graph_backend,
            "chroma_backend": self.chroma_backend,
            "elapsed_seconds": round(time.perf_counter() - t_start, 3),
        }

    # ------------------------------------------------------------------
    # Graph traversal
    # ------------------------------------------------------------------
    def _graph_search(self, hpo_ids: list[str], top_n: int) -> list[dict]:
        """Return graph-store hits for ``hpo_ids`` (empty when none were extracted)."""
        if not hpo_ids:
            return []
        return self.graph_store.find_diseases_by_hpo(hpo_ids, top_n=top_n)

    # ------------------------------------------------------------------
    # ChromaDB semantic search
    # ------------------------------------------------------------------
    def _chroma_search(self, note: str, top_n: int) -> list[dict]:
        """
        Embed ``note`` and return the ``top_n`` nearest collection entries.

        Each hit is a dict with ``orpha_code``, ``name``, ``definition`` and
        ``cosine_similarity`` (``1 - distance``, rounded to 4 places), in the
        order returned by ChromaDB.
        """
        emb = self.model.encode([note], normalize_embeddings=True)
        results = self.chroma_col.query(
            query_embeddings=emb.tolist(),
            n_results=top_n,
            include=["metadatas", "distances"],
        )
        hits = []
        for meta, dist in zip(results["metadatas"][0], results["distances"][0]):
            meta = meta or {}  # ChromaDB may return None for entries without metadata
            hits.append({
                "orpha_code": meta.get("orpha_code"),
                "name": meta.get("name"),
                "definition": meta.get("definition", ""),
                # NOTE(review): 1 - distance equals cosine similarity only if the
                # collection uses a cosine (or ip, with normalised stored vectors)
                # distance space; Chroma's default squared-"l2" space would need
                # 1 - dist / 2. Confirm the collection's hnsw:space.
                "cosine_similarity": round(1 - dist, 4),
            })
        return hits

    # ------------------------------------------------------------------
    # RRF fusion
    # ------------------------------------------------------------------
    def _rrf_fuse(
        self,
        graph_results: list[dict],
        chroma_results: list[dict],
        k: "int | None" = None,
    ) -> list[dict]:
        """
        Merge the graph and vector rankings with Reciprocal Rank Fusion.

        Each disease scores ``sum(1 / (k + rank))`` over the lists it appears
        in. Ranks are 1-based and counted among *distinct* ORPHA codes: if a
        list repeats a code, only its first (best-ranked) occurrence counts.

        Args:
            graph_results: Graph hits, best first (``orpha_code`` required;
                ``match_count`` / ``matched_hpo`` used when present).
            chroma_results: Vector hits, best first (``orpha_code`` required;
                ``cosine_similarity`` used when present).
            k: RRF smoothing constant; defaults to the module's ``RRF_K``.

        Returns:
            Candidate dicts (see ``_base_entry``) sorted by descending
            ``rrf_score`` (ties keep first-seen order), with ``rank`` set
            1..n and ``rrf_score`` rounded to 5 places.
        """
        if k is None:
            k = RRF_K
        fused: dict[str, dict] = {}

        for rank, (key, d) in enumerate(self._dedupe_by_code(graph_results), 1):
            # The graph list is processed first and deduped, so every key is new.
            entry = fused[key] = self._base_entry(d)
            entry["rrf_score"] += 1 / (k + rank)
            entry["graph_rank"] = rank
            entry["graph_matches"] = d.get("match_count", 0)
            entry["matched_hpo"] = d.get("matched_hpo", [])

        for rank, (key, d) in enumerate(self._dedupe_by_code(chroma_results), 1):
            entry = fused.get(key)
            if entry is None:
                entry = fused[key] = self._base_entry(d)
            else:
                # Graph hits may lack descriptive text; fill it from the vector hit.
                if not entry["name"] and d.get("name"):
                    entry["name"] = d["name"]
                if not entry["definition"] and d.get("definition"):
                    entry["definition"] = d["definition"]
            entry["rrf_score"] += 1 / (k + rank)
            entry["chroma_rank"] = rank
            entry["chroma_sim"] = d.get("cosine_similarity")

        ranked = sorted(fused.values(), key=lambda e: e["rrf_score"], reverse=True)
        for i, entry in enumerate(ranked, 1):
            entry["rank"] = i
            entry["rrf_score"] = round(entry["rrf_score"], 5)
        return ranked

    @staticmethod
    def _dedupe_by_code(results: list[dict]) -> list[tuple[str, dict]]:
        """Return ``(str(orpha_code), hit)`` pairs, keeping the first hit per code."""
        seen: set[str] = set()
        unique = []
        for d in results:
            key = str(d["orpha_code"])
            if key in seen:
                continue
            seen.add(key)
            unique.append((key, d))
        return unique

    @staticmethod
    def _base_entry(d: dict) -> dict:
        """Create an empty fused-candidate record seeded from hit ``d``."""
        return {
            "rank": 0,
            "orpha_code": str(d["orpha_code"]),
            "name": d.get("name", ""),
            "definition": d.get("definition", ""),
            "rrf_score": 0.0,
            "graph_rank": None,
            "chroma_rank": None,
            "graph_matches": None,
            "chroma_sim": None,
            "matched_hpo": [],
        }