"""
Phase 2: RAG Engine — FAISS Vector Database for Medical Knowledge Retrieval.

Builds a FAISS index from doctor responses in the training set.
At inference, retrieves the Top-K most relevant medical passages for a
patient query.

Usage:
    # Build index
    python rag_engine.py --mode build --csv ./data/train.csv

    # Query index
    python rag_engine.py --mode query --query "I have chest pain and shortness of breath"
"""

import argparse
import os
import sys

import numpy as np
import pandas as pd
import faiss
from sentence_transformers import SentenceTransformer
from tqdm import tqdm  # noqa: F401  (currently unused in this module)

# Make the sibling `config` module importable regardless of the current
# working directory.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from config import (  # noqa: E402  (must follow the sys.path tweak above)
    SENTENCE_EMBED_MODEL,
    FAISS_INDEX_PATH,
    RAG_TOP_K,
    RAG_EMBED_DIM,
)


def _default_passages_path(index_path):
    """Return the passages-file path stored alongside a FAISS index file.

    Only the final file extension is replaced, e.g. ``models/index.bin`` ->
    ``models/index_passages.npy``. Paths without an extension get the suffix
    appended, so the passages file can never collide with the index file.
    """
    root, _ext = os.path.splitext(index_path)
    return root + "_passages.npy"


class RAGEngine:
    """FAISS-backed retrieval engine for medical passages.

    Passages are embedded with ``SENTENCE_EMBED_MODEL`` (L2-normalised) and
    stored in a flat inner-product index, so search scores are cosine
    similarities. Passage ``i`` in ``self.passages`` corresponds to FAISS id ``i``.
    """

    def __init__(self, index_path=FAISS_INDEX_PATH, passages_path=None):
        """
        Args:
            index_path: Where the FAISS index is written/read.
            passages_path: Where the passage texts are written/read. When not
                given, it is derived from ``index_path`` (see
                ``_default_passages_path``). Passing it explicitly is useful
                when the files live in different places (e.g. downloaded
                from the Hugging Face Hub).
        """
        self.index_path = index_path
        self.passages_path = passages_path or _default_passages_path(index_path)
        self.embedder = SentenceTransformer(SENTENCE_EMBED_MODEL)
        self.index = None
        self.passages = None

    # ----------------------------------------------------------
    # Build
    # ----------------------------------------------------------
    def build_index(self, passages: list[str], batch_size: int = 256):
        """
        Encode all passages, build a FAISS Inner-Product index and save it.

        Args:
            passages: Texts to index; list position ``i`` becomes FAISS id ``i``.
            batch_size: Encoding batch size passed to the sentence embedder.

        Raises:
            ValueError: if ``passages`` is empty (FAISS cannot index zero vectors).
        """
        if not passages:
            raise ValueError("Cannot build a FAISS index from zero passages")

        print(f"Encoding {len(passages)} passages...")
        embeddings = self.embedder.encode(
            passages,
            batch_size=batch_size,
            show_progress_bar=True,
            normalize_embeddings=True,
        )
        # FAISS requires a C-contiguous float32 matrix of shape (n, dim).
        embeddings = np.ascontiguousarray(embeddings, dtype="float32")

        # Use the model's real output width; a stale RAG_EMBED_DIM would
        # otherwise make `index.add` fail with an opaque assertion error.
        dim = embeddings.shape[1]
        if dim != RAG_EMBED_DIM:
            print(
                f"Warning: embedding dim {dim} != RAG_EMBED_DIM "
                f"{RAG_EMBED_DIM}; building the index with dim {dim}"
            )

        # Inner product on normalized vectors == cosine similarity.
        self.index = faiss.IndexFlatIP(dim)
        self.index.add(embeddings)
        self.passages = np.array(passages)

        # Save
        for path in (self.index_path, self.passages_path):
            os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        faiss.write_index(self.index, self.index_path)
        # Save through a file handle: np.save(<str>) silently appends ".npy"
        # when missing, which would desynchronise the save and load paths.
        with open(self.passages_path, "wb") as fh:
            np.save(fh, self.passages)

        print(f"✅ FAISS index built: {self.index.ntotal} vectors")
        print(f" Saved to: {self.index_path}")

    # ----------------------------------------------------------
    # Load
    # ----------------------------------------------------------
    def load_index(self):
        """Load a previously built FAISS index and its passage texts.

        Raises:
            FileNotFoundError: if the index file or the passages file is missing.
        """
        if not os.path.exists(self.index_path):
            raise FileNotFoundError(f"No FAISS index at {self.index_path}")
        if not os.path.exists(self.passages_path):
            raise FileNotFoundError(f"No passages file at {self.passages_path}")

        self.index = faiss.read_index(self.index_path)
        # NOTE(security): allow_pickle=True lets np.load unpickle object
        # arrays, which can execute arbitrary code. Only load passages files
        # from trusted sources. (Files written by build_index hold a plain
        # string array and do not need pickling.)
        self.passages = np.load(self.passages_path, allow_pickle=True)

        if len(self.passages) != self.index.ntotal:
            print(
                f"Warning: {self.index.ntotal} vectors but "
                f"{len(self.passages)} passages; ids beyond the passage "
                f"list will be ignored"
            )
        print(f"Loaded FAISS index: {self.index.ntotal} vectors")

    # ----------------------------------------------------------
    # Retrieve
    # ----------------------------------------------------------
    def _to_results(self, score_row, id_row) -> list[dict]:
        """Convert one row of FAISS search output into result dicts.

        FAISS pads results with id -1 when the index holds fewer than
        ``top_k`` vectors. Those ids (and any id without a stored passage)
        are skipped rather than wrapping around to ``passages[-1]``.
        """
        n_passages = len(self.passages)
        return [
            {"passage": str(self.passages[idx]), "score": float(score)}
            for score, idx in zip(score_row, id_row)
            if 0 <= idx < n_passages
        ]

    def retrieve(self, query: str, top_k: int = RAG_TOP_K) -> list[dict]:
        """
        Return the top-K passages for a single query, best first.

        Loads the index from disk on first use.

        Returns:
            list of {"passage": str, "score": float}; may hold fewer than
            ``top_k`` entries if the index is small or ``top_k <= 0``.
        """
        batch = self.retrieve_batch([query], top_k=top_k)
        return batch[0] if batch else []

    def retrieve_batch(
        self, queries: list[str], top_k: int = RAG_TOP_K
    ) -> list[list[dict]]:
        """Batch retrieval for training efficiency.

        Loads the index from disk on first use.

        Returns:
            One result list per query, in input order; each is shaped like
            the return value of ``retrieve``.
        """
        if self.index is None:
            self.load_index()
        if not queries or top_k <= 0:
            # Nothing to search; avoid handing FAISS a degenerate array.
            return [[] for _ in queries]

        q_embs = self.embedder.encode(
            queries, normalize_embeddings=True, batch_size=128
        )
        q_embs = np.ascontiguousarray(q_embs, dtype="float32")
        scores, indices = self.index.search(q_embs, top_k)
        return [
            self._to_results(score_row, id_row)
            for score_row, id_row in zip(scores, indices)
        ]


# ============================================================
# CLI
# ============================================================
def _column_as_str(df: pd.DataFrame, name: str) -> pd.Series:
    """Return column ``name`` as strings, with NaN/missing mapped to ``""``.

    A missing column yields a column of empty strings.
    """
    if name not in df.columns:
        return pd.Series("", index=df.index, dtype=object)
    return df[name].fillna("").astype(str)


def _build_passages(df: pd.DataFrame) -> list[str]:
    """Build "description | doctor_response" passages from training rows.

    The short, searchable description is concatenated with the doctor's
    response for retrieval quality. Rows whose stripped response has 20 or
    fewer characters are skipped. If the description is empty, the passage
    is just the response (instead of ``"nan | ..."`` or ``" | ..."``).
    """
    passages = []
    descriptions = _column_as_str(df, "description")
    responses = _column_as_str(df, "doctor_response")
    for desc, resp in zip(descriptions, responses):
        desc, resp = desc.strip(), resp.strip()
        if len(resp) <= 20:
            continue
        passages.append(f"{desc} | {resp}" if desc else resp)
    return passages


def main(args):
    """Run the CLI: build the index from a CSV, or query an existing index."""
    engine = RAGEngine(args.index_path)

    if args.mode == "build":
        df = pd.read_csv(args.csv)
        passages = _build_passages(df)
        engine.build_index(passages, batch_size=args.batch_size)

    elif args.mode == "query":
        results = engine.retrieve(args.query, top_k=args.top_k)
        print(f"\nTop-{args.top_k} results for: '{args.query}'\n")
        for rank, hit in enumerate(results, start=1):
            text = hit["passage"]
            ellipsis = "..." if len(text) > 200 else ""
            print(f" [{rank}] (score={hit['score']:.4f})")
            print(f" {text[:200]}{ellipsis}")
            print()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Build or query the FAISS medical-passage index."
    )
    parser.add_argument("--mode", choices=["build", "query"], required=True)
    parser.add_argument("--csv", default="./data/train.csv")
    parser.add_argument("--query", default="")
    parser.add_argument("--top_k", type=int, default=RAG_TOP_K)
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--index_path", default=FAISS_INDEX_PATH)
    args = parser.parse_args()
    if args.mode == "query" and not args.query.strip():
        parser.error("--query must be a non-empty string in query mode")
    main(args)