MedEmotion-Assistant / rag_engine.py
shreenikethjoshi's picture
Update rag_engine.py
cf7c6f4 verified
"""
Phase 2: RAG Engine — FAISS Vector Database for Medical Knowledge Retrieval.
Builds a FAISS index from doctor responses in the training set.
At inference, retrieves Top-K relevant medical passages for a patient query.
Usage:
# Build index
python rag_engine.py --mode build --csv ./data/train.csv
# Query index
python rag_engine.py --mode query --query "I have chest pain and shortness of breath"
"""
import argparse
import os
import numpy as np
import pandas as pd
import faiss
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
import sys
sys.path.insert(0, os.path.dirname(__file__))
from config import (
SENTENCE_EMBED_MODEL,
FAISS_INDEX_PATH,
RAG_TOP_K,
RAG_EMBED_DIM,
)
class RAGEngine:
    """FAISS-backed retrieval engine for medical passages.

    Passages are embedded with ``SENTENCE_EMBED_MODEL`` using L2-normalised
    embeddings and stored in an inner-product (``IndexFlatIP``) index, so
    search scores are cosine similarities. The raw passage texts are kept in a
    sidecar ``.npy`` file; row ``i`` of that array is the text of FAISS
    vector id ``i``.
    """

    def __init__(self, index_path=FAISS_INDEX_PATH, passages_path=None):
        """
        Args:
            index_path: File the FAISS index is written to / read from.
            passages_path: File the passage texts are written to / read from.
                When omitted it is derived from ``index_path``
                (``foo.bin`` -> ``foo_passages.npy``).
        """
        self.index_path = index_path
        # An explicitly supplied path (e.g. a Hugging Face download) wins.
        if passages_path:
            self.passages_path = passages_path
        else:
            self.passages_path = self._default_passages_path(index_path)
        self.embedder = SentenceTransformer(SENTENCE_EMBED_MODEL)
        self.index = None
        self.passages = None

    @staticmethod
    def _default_passages_path(index_path):
        """Derive the sidecar passages filename from the index filename."""
        # Swap only the trailing extension. str.replace(".bin", ...) also hit
        # ".bin" inside directory names, and for paths without ".bin" it
        # returned index_path unchanged: np.save then wrote "<index>.npy"
        # while np.load tried to read the FAISS file itself.
        root, ext = os.path.splitext(index_path)
        if ext == ".bin":
            return root + "_passages.npy"
        return index_path + "_passages.npy"

    # ----------------------------------------------------------
    # Build
    # ----------------------------------------------------------
    def build_index(self, passages: list[str], batch_size: int = 256):
        """Encode ``passages``, build a FAISS inner-product index and save it.

        Args:
            passages: Texts to index; vector id ``i`` maps to ``passages[i]``.
            batch_size: Batch size passed to the sentence encoder.

        Raises:
            ValueError: If ``passages`` is empty.
        """
        if not passages:
            raise ValueError("build_index() needs at least one passage")

        print(f"Encoding {len(passages)} passages...")
        embeddings = self.embedder.encode(
            passages,
            batch_size=batch_size,
            show_progress_bar=True,
            normalize_embeddings=True,
        )
        embeddings = np.asarray(embeddings, dtype="float32")

        # Size the index from the data, not the config constant, so a
        # model/config mismatch cannot yield an index FAISS refuses to fill.
        dim = embeddings.shape[1]
        if dim != RAG_EMBED_DIM:
            print(f"Warning: embedding dim {dim} != RAG_EMBED_DIM "
                  f"{RAG_EMBED_DIM}; building index with dim {dim}")

        # Inner product on unit-norm vectors == cosine similarity.
        self.index = faiss.IndexFlatIP(dim)
        self.index.add(embeddings)
        self.passages = np.array(passages)

        # The two files may live in different directories.
        for path in (self.index_path, self.passages_path):
            os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        faiss.write_index(self.index, self.index_path)
        np.save(self.passages_path, self.passages)
        print(f"✅ FAISS index built: {self.index.ntotal} vectors")
        print(f" Saved to: {self.index_path}")

    # ----------------------------------------------------------
    # Load
    # ----------------------------------------------------------
    def load_index(self):
        """Load a previously built FAISS index and its passage texts.

        Raises:
            FileNotFoundError: If the index file or the passages file is
                missing.
        """
        if not os.path.exists(self.index_path):
            raise FileNotFoundError(f"No FAISS index at {self.index_path}")
        if not os.path.exists(self.passages_path):
            raise FileNotFoundError(f"No passages file at {self.passages_path}")
        self.index = faiss.read_index(self.index_path)
        # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary
        # objects; only load passage files from trusted sources.
        self.passages = np.load(self.passages_path, allow_pickle=True)
        print(f"Loaded FAISS index: {self.index.ntotal} vectors")

    # ----------------------------------------------------------
    # Retrieve
    # ----------------------------------------------------------
    def _to_hits(self, score_row, id_row) -> list[dict]:
        """Turn one row of FAISS search output into passage/score dicts."""
        n_passages = len(self.passages)
        # FAISS pads with id -1 when fewer than top_k results exist; a plain
        # `idx < n` check would let -1 through and return passages[-1].
        return [
            {"passage": str(self.passages[idx]), "score": float(score)}
            for score, idx in zip(score_row, id_row)
            if 0 <= idx < n_passages
        ]

    def retrieve(self, query: str, top_k: int = RAG_TOP_K) -> list[dict]:
        """Return up to ``top_k`` passages most similar to ``query``.

        Loads the index on first use.

        Returns:
            list of {"passage": str, "score": float}, in the order FAISS
            returns them.
        """
        return self.retrieve_batch([query], top_k=top_k)[0]

    def retrieve_batch(self, queries: list[str], top_k: int = RAG_TOP_K) -> list[list[dict]]:
        """Batch retrieval for training efficiency.

        Loads the index on first use.

        Returns:
            One list of {"passage": str, "score": float} per query, in the
            same order as ``queries``; ``[]`` for an empty ``queries``.
        """
        if not queries:
            return []
        if self.index is None:
            self.load_index()
        q_embs = np.asarray(
            self.embedder.encode(queries, normalize_embeddings=True, batch_size=128),
            dtype="float32",
        )
        scores, indices = self.index.search(q_embs, top_k)
        return [self._to_hits(s_row, i_row) for s_row, i_row in zip(scores, indices)]
# ============================================================
# CLI
# ============================================================
def _column_as_text(df, column):
    """Return ``df[column]`` as a string Series; missing column/cells -> ``""``."""
    if column not in df.columns:
        return pd.Series("", index=df.index, dtype=object)
    # fillna before astype(str): otherwise empty CSV cells become the literal
    # text "nan" and get baked into the indexed passages.
    return df[column].fillna("").astype(str)


def main(args):
    """Run the CLI: build the index from ``args.csv`` or query it.

    Args:
        args: Parsed namespace with ``mode``, ``csv``, ``query``, ``top_k``,
            ``batch_size`` and ``index_path``.

    Raises:
        SystemExit: In build mode, if no CSV row yields a usable passage.
    """
    engine = RAGEngine(args.index_path)
    if args.mode == "build":
        df = pd.read_csv(args.csv)
        # Use the 'description' field (short, searchable) concatenated
        # with the doctor's response for maximum retrieval quality.
        descriptions = _column_as_text(df, "description")
        responses = _column_as_text(df, "doctor_response")
        passages = [
            f"{desc} | {resp}"
            for desc, resp in zip(descriptions, responses)
            # Rows with a response of <= 20 chars are dropped (presumably
            # too short to be useful retrieval context).
            if len(resp) > 20
        ]
        if not passages:
            raise SystemExit(
                f"No usable passages (doctor_response > 20 chars) in {args.csv}"
            )
        engine.build_index(passages, batch_size=args.batch_size)
    elif args.mode == "query":
        results = engine.retrieve(args.query, top_k=args.top_k)
        print(f"\nTop-{args.top_k} results for: '{args.query}'\n")
        for rank, hit in enumerate(results, start=1):
            text = hit["passage"]
            # Only mark truncation when the passage was actually cut.
            snippet = text if len(text) <= 200 else text[:200] + "..."
            print(f" [{rank}] (score={hit['score']:.4f})")
            print(f" {snippet}")
            print()
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Build or query the FAISS medical-passage index.",
    )
    parser.add_argument("--mode", choices=["build", "query"], required=True,
                        help="'build' indexes --csv; 'query' searches with --query")
    parser.add_argument("--csv", default="./data/train.csv",
                        help="CSV with 'description' and 'doctor_response' columns (build mode)")
    parser.add_argument("--query", default="",
                        help="patient query text (required in query mode)")
    parser.add_argument("--top_k", type=int, default=RAG_TOP_K,
                        help="number of passages to retrieve")
    parser.add_argument("--batch_size", type=int, default=256,
                        help="encoder batch size (build mode)")
    parser.add_argument("--index_path", default=FAISS_INDEX_PATH,
                        help="FAISS index file to write/read")
    args = parser.parse_args()
    # Fail fast with a usage message instead of loading the embedding model
    # and then searching for an empty string or with a nonsensical k.
    if args.mode == "query" and not args.query.strip():
        parser.error("--query is required (and must be non-empty) when --mode query")
    if args.top_k < 1:
        parser.error("--top_k must be a positive integer")
    if args.batch_size < 1:
        parser.error("--batch_size must be a positive integer")
    main(args)