#!/usr/bin/env python3
"""
Complete Medical RAG Pipeline
Query → Hybrid Retrieval → Cross-Encoder Rerank → Gemini Answer
Supports:
- Global RAG (all reports)
- Report-specific RAG via metadata filtering (dropdown-compatible)
"""
import os
import re
import pickle
import numpy as np
from pathlib import Path
from datetime import datetime
from typing import List, Dict, Optional
import time
# ============================
# CONFIG
# ============================
DEFAULT_TOP_K = 5
# ============================
# Embeddings & Reranking
# ============================
from sentence_transformers import SentenceTransformer, CrossEncoder
# ============================
# FAISS
# ============================
import faiss
# ============================
# BM25
# ============================
from rank_bm25 import BM25Okapi
# ============================
# Gemini
# ============================
from google import genai
# =========================================================
# QUERY PROCESSOR
# =========================================================
class MedicalQueryProcessor:
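    """Embeds queries and pulls out medical keywords (cancer terms, receptor
    status, stage/grade, lymph node mentions) with simple regex patterns."""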
def __init__(self, embedding_model: str):
print(f"Loading embedding model: {embedding_model}")
self.model = SentenceTransformer(embedding_model)
self.dim = self.model.get_sentence_embedding_dimension()
print(f"Embedding dimension: {self.dim}")
def extract_keywords(self, query: str) -> List[str]:
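        """Return the distinct medical terms in the query that match the
        patterns below (case-insensitive)."""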
patterns = [
r"\b(cancer|carcinoma|tumor|neoplasm)\b",
r"\b(ER|PR|HER2)\b",
r"\b(stage\s*[IVX]+)\b",
r"\b(grade\s*[123])\b",
r"\b(lymph\s*node)\b",
]
found = []
for p in patterns:
found.extend(re.findall(p, query, flags=re.I))
return list(set(found))
def embed(self, text: str) -> np.ndarray:
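        """Encode the text into a normalized embedding vector."""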
return self.model.encode(text, normalize_embeddings=True)
def process(self, query: str) -> Dict:
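        """Bundle the raw query, its extracted keywords, and its embedding."""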
return {
"query": query,
"keywords": self.extract_keywords(query),
"embedding": self.embed(query),
}
# =========================================================
# HYBRID RETRIEVER
# =========================================================
class HybridRetriever:
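    """Loads the FAISS index and chunk metadata from disk and builds a BM25
    index over the same chunks, so searches can combine dense and sparse
    scores."""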
def __init__(self, faiss_db_path: str):
db = Path(faiss_db_path)
print(f"Loading FAISS index from: {db}")
self.index = faiss.read_index(str(db / "faiss.index"))
with open(db / "metadata.pkl", "rb") as f:
data = pickle.load(f)
self.chunks = data["chunks"]
print(f"Loaded {len(self.chunks)} chunks")
tokenized = [c["text"].lower().split() for c in self.chunks]
self.bm25 = BM25Okapi(tokenized)
def get_available_reports(self) -> List[str]:
return sorted({c["filename"] for c in self.chunks})
def search(
self,
query_embedding: np.ndarray,
query_text: str,
top_k: int = 40,
) -> List[Dict]:
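        """Retrieve candidate chunks for the query, scored by a combination
        of FAISS similarity and BM25, and return them sorted by fused score."""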
        # Dense retrieval: FAISS nearest-neighbour search over the chunk embeddings
        distances, indices = self.index.search(
            query_embedding.reshape(1, -1).astype("float32"),
            top_k,
        )
        results = {}
        for idx, dist in zip(indices[0], distances[0]):
            if idx >= 0:
                # Convert the FAISS distance into a similarity-style score
                results[int(idx)] = 1 - float(dist)
        # Sparse retrieval: BM25 over the same chunks, max-normalised and
        # summed with the dense scores (simple score fusion)
        bm25_scores = self.bm25.get_scores(query_text.lower().split())
        bm25_top = np.argsort(bm25_scores)[::-1][:top_k]
        max_bm25 = float(bm25_scores[bm25_top[0]]) or 1.0
        for idx in bm25_top:
            idx = int(idx)
            results[idx] = results.get(idx, 0.0) + float(bm25_scores[idx]) / max_bm25
        ranked = sorted(results.items(), key=lambda x: x[1], reverse=True)
        return [
            {
                "chunk": self.chunks[idx],
                "score": score,
            }
            for idx, score in ranked
        ]
# =========================================================
# CROSS-ENCODER RERANKER
# =========================================================
class MedicalReranker:
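    """Cross-encoder reranker that rescores (query, chunk) pairs with
    ms-marco-MiniLM-L-6-v2 and keeps the highest-scoring chunks."""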
def __init__(self):
print("Loading cross-encoder...")
self.model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
print("Cross-encoder loaded")
def rerank(
self,
query: str,
candidates: List[Dict],
top_k: int = DEFAULT_TOP_K,
) -> List[Dict]:
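        """Score each candidate chunk against the query with the cross-encoder
        and return the top_k candidates by cross-encoder score."""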
if not candidates:
return []
pairs = [(query, c["chunk"]["text"]) for c in candidates]
scores = self.model.predict(pairs)
for c, s in zip(candidates, scores):
c["ce_score"] = float(s)
return sorted(
candidates,
key=lambda x: x["ce_score"],
reverse=True,
)[:top_k]
# =========================================================
# GEMINI GENERATOR
# =========================================================
class GeminiGenerator:
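    """Thin wrapper around the Gemini API that answers a question from the
    retrieved chunks, citing them as numbered sources."""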
def __init__(self, model_name="models/gemini-flash-lite-latest"):
api_key = os.getenv("GOOGLE_API_KEY")
if not api_key:
raise RuntimeError("GOOGLE_API_KEY not set")
self.client = genai.Client(api_key=api_key)
self.model_name = model_name
print(f"Gemini model selected: {model_name}")
def generate(self, query: str, chunks: list) -> str:
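        """Build a numbered-source prompt from the chunks and ask Gemini to
        answer using only those sources; retries once on rate-limit errors."""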
if not chunks:
return "No relevant information found."
context = ""
for i, c in enumerate(chunks, 1):
context += f"[{i}] {c['chunk']['text']}\n\n"
prompt = f"""
Answer the question using ONLY the sources below.
Cite sources as [1], [2], etc.
SOURCES:
{context}
QUESTION:
{query}
ANSWER:
"""
try:
response = self.client.models.generate_content(
model=self.model_name,
contents=prompt,
)
return response.text
except genai.errors.ClientError as e:
if "RESOURCE_EXHAUSTED" in str(e):
print("Rate limit hit. Retrying in 30s...")
time.sleep(30)
response = self.client.models.generate_content(
model=self.model_name,
contents=prompt,
)
return response.text
raise
# =========================================================
# COMPLETE RAG PIPELINE (WITH REPORT FILTERING)
# =========================================================
class CompleteRAGPipeline:
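    """Wires the query processor, hybrid retriever, cross-encoder reranker,
    and Gemini generator into a single ask() entry point, with optional
    per-report filtering for the dropdown UI."""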
def __init__(self, faiss_db_path: str, embedding_model: str):
print("Initializing Complete RAG Pipeline...")
self.query_processor = MedicalQueryProcessor(embedding_model)
self.retriever = HybridRetriever(faiss_db_path)
self.reranker = MedicalReranker()
self.llm = GeminiGenerator()
print("RAG Pipeline ready")
def get_available_reports(self) -> List[str]:
return self.retriever.get_available_reports()
def ask(
self,
query: str,
report_name: Optional[str] = None,
) -> Dict:
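        """Answer a question over all reports, or over a single report when
        report_name is given. Returns the answer, the source chunks, and a
        timestamp."""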
processed = self.query_processor.process(query)
candidates = self.retriever.search(
processed["embedding"],
query,
)
# ----------------------------------
# METADATA FILTERING (DROPDOWN MODE)
# ----------------------------------
if report_name:
candidates = [
c for c in candidates
if c["chunk"].get("filename") == report_name
]
if not candidates:
                return {
                    "query": query,
                    "answer": f"No information found for report: {report_name}",
                    "sources": [],
                    "timestamp": datetime.now().isoformat(),
                }
top_chunks = self.reranker.rerank(
query,
candidates,
top_k=DEFAULT_TOP_K,
)
answer = self.llm.generate(query, top_chunks)
return {
"query": query,
"answer": answer,
"sources": top_chunks,
"timestamp": datetime.now().isoformat(),
}
# =========================================================
# MAIN (TEST)
# =========================================================
def main():
FAISS_DB = "output/biomedbert_vector_db"
EMB_MODEL = "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext"
pipeline = CompleteRAGPipeline(FAISS_DB, EMB_MODEL)
reports = pipeline.get_available_reports()
print("Available reports:", reports)
result = pipeline.ask(
"What are the abnormal findings?",
report_name=reports[0] if reports else None,
)
print(result["answer"])
if __name__ == "__main__":
main()