suryaprakash01's picture
Upload 2277 files
d5c0065 verified
#!/usr/bin/env python3
"""
Complete Medical RAG Pipeline
Query → Hybrid Retrieval → Cross-Encoder Rerank → Gemini Answer
Features
--------
• FAISS semantic retrieval
• BM25 keyword retrieval
• Hybrid score fusion
• Cross-encoder reranking
• Gemini medical QA generation
• Report-specific filtering (dropdown support)
"""
import os
import re
import pickle
import time
import numpy as np
from pathlib import Path
from datetime import datetime
from typing import List, Dict, Optional
import faiss
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer, CrossEncoder
from google import genai
# ============================================
# CONFIG
# ============================================
# Default number of chunks MedicalReranker.rerank keeps after reranking.
DEFAULT_TOP_K = 5
# ============================================
# QUERY PROCESSOR
# ============================================
class MedicalQueryProcessor:
    """Prepare a free-text question for retrieval.

    Wraps a SentenceTransformer model for embedding the query and a small
    set of regular expressions for pulling keywords out of it.
    """

    # Compiled once at class creation rather than on every call.
    _KEYWORD_PATTERNS = tuple(
        re.compile(pattern, flags=re.I)
        for pattern in (
            r"\b(cancer|carcinoma|tumor|neoplasm)\b",
            r"\b(ER|PR|HER2)\b",
            r"\b(stage\s*[IVX]+)\b",
            r"\b(grade\s*[123])\b",
            r"\b(lymph\s*node)\b",
        )
    )

    def __init__(self, embedding_model: str):
        """Load the sentence-embedding model named *embedding_model*."""
        print(f"Loading embedding model: {embedding_model}")
        self.model = SentenceTransformer(embedding_model)
        self.dim = self.model.get_sentence_embedding_dimension()
        print(f"Embedding dimension: {self.dim}")

    def extract_keywords(self, query: str) -> List[str]:
        """Return the distinct keyword matches found in *query*.

        Matching is case-insensitive, but each hit keeps the casing it has
        in the query. Results are ordered by pattern, then by position in
        the query, with exact duplicates removed.
        """
        found = []
        for pattern in self._KEYWORD_PATTERNS:
            found.extend(pattern.findall(query))
        # dict.fromkeys keeps first-seen order; a plain set() would make the
        # output order vary between interpreter runs.
        return list(dict.fromkeys(found))

    def embed(self, text: str) -> np.ndarray:
        """Encode *text* with the model, requesting normalized embeddings."""
        return self.model.encode(
            text,
            normalize_embeddings=True,
        )

    def process(self, query: str) -> Dict:
        """Return a dict with the raw query, its keywords and its embedding."""
        return {
            "query": query,
            "keywords": self.extract_keywords(query),
            "embedding": self.embed(query),
        }
# ============================================
# HYBRID RETRIEVER
# ============================================
class HybridRetriever:
    """Combine FAISS vector search with BM25 keyword scores over one chunk store."""

    def __init__(self, faiss_db_path: str):
        """Load the index and chunk metadata from *faiss_db_path*.

        The directory must contain ``faiss.index`` and ``metadata.pkl``; the
        pickle must hold a dict with a ``"chunks"`` list whose items carry a
        ``"text"`` field.
        """
        db = Path(faiss_db_path)
        print(f"Loading FAISS index from: {db}")
        self.index = faiss.read_index(str(db / "faiss.index"))
        # SECURITY: pickle.load can execute arbitrary code. Only point this
        # at a metadata file produced by a trusted build of the database.
        with open(db / "metadata.pkl", "rb") as f:
            data = pickle.load(f)
        self.chunks = data["chunks"]
        print(f"Loaded {len(self.chunks)} chunks")
        # BM25 works on whitespace tokens of the lower-cased chunk text.
        tokenized = [c["text"].lower().split() for c in self.chunks]
        self.bm25 = BM25Okapi(tokenized)

    def get_available_reports(self) -> List[str]:
        """Return the sorted, de-duplicated ``filename`` values of all chunks."""
        return sorted({c["filename"] for c in self.chunks})

    def search(
        self,
        query_embedding: np.ndarray,
        query_text: str,
        top_k: int = 40
    ) -> List[Dict]:
        """Return up to *top_k* chunks ranked by a fused semantic/keyword score.

        Args:
            query_embedding: Query vector; reshaped to one row of float32.
            query_text: Raw query, tokenized the same way as the chunks.
            top_k: Number of neighbours requested from the FAISS index.

        Returns:
            A list of ``{"chunk": ..., "score": ...}`` dicts, best first.
            Empty when *top_k* is not positive or the index yields no hits.
        """
        if top_k <= 0:
            # Asking the index for zero or negative neighbours is meaningless.
            return []

        # -----------------------------------
        # FAISS SEMANTIC SEARCH
        # -----------------------------------
        distances, indices = self.index.search(
            query_embedding.reshape(1, -1).astype("float32"),
            top_k
        )
        # Slots without a hit carry a negative index and are skipped.
        # NOTE(review): the raw FAISS value is used as a higher-is-better
        # score, which assumes a similarity (inner-product) index; confirm
        # the metric of the index on disk.
        faiss_scores = {
            int(idx): float(dist)
            for idx, dist in zip(indices[0], distances[0])
            if idx >= 0
        }
        if not faiss_scores:
            return []

        # -----------------------------------
        # BM25 KEYWORD SEARCH
        # -----------------------------------
        bm25_raw = self.bm25.get_scores(query_text.lower().split())
        # Scale by the best score in the whole corpus; fall back to 1.0 so an
        # all-zero score vector does not divide by zero.
        bm25_peak = float(np.max(bm25_raw))
        if bm25_peak <= 0:
            bm25_peak = 1.0

        # -----------------------------------
        # HYBRID SCORE FUSION
        # -----------------------------------
        fused_scores = {
            idx: 0.7 * semantic + 0.3 * (float(bm25_raw[idx]) / bm25_peak)
            for idx, semantic in faiss_scores.items()
        }
        ranked = sorted(
            fused_scores.items(),
            key=lambda item: item[1],
            reverse=True
        )
        return [
            {"chunk": self.chunks[idx], "score": score}
            for idx, score in ranked
        ]
# ============================================
# CROSS ENCODER RERANKER
# ============================================
class MedicalReranker:
    """Re-score retrieval candidates with a cross-encoder."""

    def __init__(self):
        """Load the ``cross-encoder/ms-marco-MiniLM-L-6-v2`` model."""
        print("Loading cross-encoder reranker...")
        self.model = CrossEncoder(
            "cross-encoder/ms-marco-MiniLM-L-6-v2"
        )
        print("Cross-encoder ready")

    def rerank(
        self,
        query: str,
        candidates: List[Dict],
        top_k: int = DEFAULT_TOP_K
    ) -> List[Dict]:
        """Return the *top_k* candidates with the highest cross-encoder score.

        Args:
            query: The user question.
            candidates: Dicts holding a ``"chunk"`` dict with a ``"text"`` field.
            top_k: Maximum number of results to return.

        Returns:
            Shallow copies of the best candidates, each with an added
            ``"ce_score"`` float, sorted best first. The input dicts are left
            untouched. Empty when there are no candidates or *top_k* is not
            positive.
        """
        if not candidates or top_k <= 0:
            # Nothing to score, or nothing requested: skip the model call.
            return []
        pairs = [(query, c["chunk"]["text"]) for c in candidates]
        scores = self.model.predict(pairs)
        # Copy each candidate so the caller's list is not modified in place.
        scored = [
            {**candidate, "ce_score": float(score)}
            for candidate, score in zip(candidates, scores)
        ]
        scored.sort(key=lambda item: item["ce_score"], reverse=True)
        return scored[:top_k]
# ============================================
# GEMINI GENERATOR
# ============================================
class GeminiGenerator:
    """Generate an answer from retrieved chunks with a Gemini model."""

    # Seconds to wait before retrying after a rate-limit error.
    RATE_LIMIT_WAIT_SECONDS = 30
    # How many times a rate-limited request is retried before giving up.
    MAX_RATE_LIMIT_RETRIES = 1
    # Returned when the API response carries no text.
    EMPTY_RESPONSE_MESSAGE = "The model returned no text for this question."

    def __init__(
        self,
        model_name="models/gemini-flash-lite-latest"
    ):
        """Create the API client.

        Raises:
            RuntimeError: If the ``GOOGLE_API_KEY`` environment variable is unset.
        """
        api_key = os.getenv("GOOGLE_API_KEY")
        if not api_key:
            raise RuntimeError("GOOGLE_API_KEY not set")
        self.client = genai.Client(api_key=api_key)
        self.model_name = model_name
        print(f"Using Gemini model: {model_name}")

    @staticmethod
    def _chunk_context(chunk: Dict) -> str:
        """Return the text shown to the model for one chunk."""
        chunk_text = chunk.get('text', '') or ''
        # If the database was built without the original text files, the text
        # defaults to "Document: <filename>". The extracted entities are
        # appended in that case so the model still has some context.
        if chunk_text.startswith("Document:") and "entities" in chunk:
            entities = [e.get("text", "") for e in chunk["entities"]]
            chunk_text += "\nExtracted Medical Entities: " + ", ".join(entities)
        return chunk_text

    def generate(
        self,
        query: str,
        chunks: List[Dict]
    ) -> str:
        """Answer *query* using only the numbered *chunks* as context.

        Args:
            query: The user question.
            chunks: Dicts holding a ``"chunk"`` dict; numbered from 1 in the prompt.

        Returns:
            The model's answer, a fixed message when *chunks* is empty, or
            ``EMPTY_RESPONSE_MESSAGE`` when the response has no text.

        Raises:
            Exception: Any API error that is not a rate limit, or a rate
                limit that persists after the allowed retries.
        """
        if not chunks:
            return "No relevant information found."
        context = "".join(
            f"[{i}] {self._chunk_context(c['chunk'])}\n\n"
            for i, c in enumerate(chunks, 1)
        )
        prompt = f"""
You are an expert medical assistant. Answer the medical question using ONLY the provided text or entities from the pathology documents below.
Treat the provided information as your complete source material. Do not state that you cannot access files, as the contents/entities are provided directly below.
Cite your sources in your text using their corresponding numbers like [1], [2], etc.
--- PROVIDED DOCUMENT CONTENTS / ENTITIES ---
{context}
--- QUESTION ---
{query}
--- ANSWER ---
"""
        retries_left = self.MAX_RATE_LIMIT_RETRIES
        while True:
            try:
                response = self.client.models.generate_content(
                    model=self.model_name,
                    contents=prompt
                )
            except Exception as exc:
                # Only rate-limit errors are retried; everything else, and a
                # rate limit that outlives the retry budget, propagates.
                if retries_left <= 0 or "RESOURCE_EXHAUSTED" not in str(exc):
                    raise
                retries_left -= 1
                wait = self.RATE_LIMIT_WAIT_SECONDS
                print(f"Rate limit reached. Waiting {wait} seconds...")
                time.sleep(wait)
                continue
            # The text may be missing; never hand None back to callers that
            # expect a string.
            return response.text or self.EMPTY_RESPONSE_MESSAGE
# ============================================
# COMPLETE RAG PIPELINE
# ============================================
class CompleteRAGPipeline:
    """End-to-end flow: process query, retrieve, rerank, generate an answer."""

    # Maximum number of candidates handed to the reranker for one question.
    _RERANK_POOL_SIZE = 40

    def __init__(
        self,
        faiss_db_path: str,
        embedding_model: str
    ):
        """Build every stage of the pipeline.

        Args:
            faiss_db_path: Directory passed to HybridRetriever.
            embedding_model: Model name passed to MedicalQueryProcessor.
        """
        print("\nInitializing Medical RAG Pipeline\n")
        self.query_processor = MedicalQueryProcessor(
            embedding_model
        )
        self.retriever = HybridRetriever(
            faiss_db_path
        )
        self.reranker = MedicalReranker()
        self.llm = GeminiGenerator()
        print("\nPipeline ready\n")

    def get_available_reports(self) -> List[str]:
        """Return the report names known to the retriever."""
        return self.retriever.get_available_reports()

    def ask(
        self,
        query: str,
        report_name: Optional[str] = None,
        top_k: int = 5
    ) -> Dict:
        """Answer *query*, optionally restricted to a single report.

        Args:
            query: The user question.
            report_name: When given, only chunks whose ``filename`` equals
                this value are considered.
            top_k: Number of reranked chunks passed to the generator.

        Returns:
            A dict with ``query``, ``answer``, ``sources``, ``timestamp``
            and ``num_sources``.
        """
        processed = self.query_processor.process(query)

        if report_name:
            # ----------------------------------
            # REPORT FILTERING
            # ----------------------------------
            # Search the whole index before filtering. Filtering only the
            # default-sized result list would miss a report whose chunks all
            # rank below that cut-off for this query.
            search_depth = max(
                len(self.retriever.chunks),
                self._RERANK_POOL_SIZE
            )
            candidates = self.retriever.search(
                processed["embedding"],
                query,
                top_k=search_depth
            )
            candidates = [
                c for c in candidates
                if c["chunk"].get("filename") == report_name
            ]
            # Results arrive best first, so this keeps the strongest matches
            # and bounds the work done by the reranker.
            candidates = candidates[:self._RERANK_POOL_SIZE]
            if not candidates:
                return {
                    "query": query,
                    "answer": f"No information found for report: {report_name}",
                    "timestamp": datetime.now().isoformat(),
                    "sources": [],
                    "num_sources": 0
                }
        else:
            candidates = self.retriever.search(
                processed["embedding"],
                query
            )

        # ----------------------------------
        # RERANK
        # ----------------------------------
        top_chunks = self.reranker.rerank(
            query,
            candidates,
            top_k=top_k
        )

        # ----------------------------------
        # GENERATE ANSWER
        # ----------------------------------
        answer = self.llm.generate(
            query,
            top_chunks
        )
        return {
            "query": query,
            "answer": answer,
            "sources": top_chunks,
            "timestamp": datetime.now().isoformat(),
            "num_sources": len(top_chunks)
        }
# ============================================
# MAIN TEST
# ============================================
def main():
    """Build the pipeline, list its reports and ask one sample question."""
    db_dir = "output/biomedbert_vector_db"
    encoder_name = "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext"
    rag = CompleteRAGPipeline(db_dir, encoder_name)

    available = rag.get_available_reports()
    print("\nAvailable reports:")
    print(available)

    # Restrict the sample question to the first report when one exists.
    first_report = available[0] if available else None
    outcome = rag.ask(
        "What abnormal findings are present?",
        report_name=first_report,
    )
    print("\nAnswer:\n")
    print(outcome["answer"])


if __name__ == "__main__":
    main()