| """ |
| Phase 2: RAG Engine — FAISS Vector Database for Medical Knowledge Retrieval. |
| |
| Builds a FAISS index from doctor responses in the training set. |
| At inference, retrieves Top-K relevant medical passages for a patient query. |
| |
| Usage: |
| # Build index |
| python rag_engine.py --mode build --csv ./data/train.csv |
| |
| # Query index |
| python rag_engine.py --mode query --query "I have chest pain and shortness of breath" |
| """ |
| import argparse |
| import os |
| import numpy as np |
| import pandas as pd |
| import faiss |
| from sentence_transformers import SentenceTransformer |
| from tqdm import tqdm |
|
|
| import sys |
| sys.path.insert(0, os.path.dirname(__file__)) |
| from config import ( |
| SENTENCE_EMBED_MODEL, |
| FAISS_INDEX_PATH, |
| RAG_TOP_K, |
| RAG_EMBED_DIM, |
| ) |
|
|
|
|
class RAGEngine:
    """FAISS-backed retrieval engine for medical passages.

    Passages are embedded with ``SENTENCE_EMBED_MODEL`` (L2-normalised
    embeddings) and stored in a flat inner-product index, so search scores
    are cosine similarities. The passage texts are saved next to the index
    as a ``.npy`` array so FAISS ids can be mapped back to text.
    """

    def __init__(self, index_path=FAISS_INDEX_PATH, passages_path=None):
        """Create the engine and load the sentence-embedding model.

        The FAISS index itself is not read here. Call :meth:`build_index` or
        :meth:`load_index`, or let the ``retrieve*`` methods load it lazily.

        Args:
            index_path: file the FAISS index is written to / read from
                (``str`` or path-like).
            passages_path: file the passage texts are written to / read from.
                Defaults to ``index_path`` with its final extension replaced
                by ``_passages.npy`` (``rag.bin`` -> ``rag_passages.npy``).
                A ``.npy`` suffix is appended if missing, because ``np.save``
                adds one and ``np.load`` must read the same name.
        """
        self.index_path = os.fspath(index_path)
        if passages_path:
            passages_file = os.fspath(passages_path)
        else:
            # splitext only strips the final extension. A ".bin" inside a
            # directory name is left alone, and an index path with another
            # extension can never map the passages file onto the index file.
            root, _ext = os.path.splitext(self.index_path)
            passages_file = root + "_passages.npy"
        if not passages_file.endswith(".npy"):
            passages_file += ".npy"
        self.passages_path = passages_file

        self.embedder = SentenceTransformer(SENTENCE_EMBED_MODEL)
        self.index = None
        self.passages = None

    def build_index(self, passages: list[str], batch_size: int = 256):
        """Encode all passages, build a FAISS Inner-Product index and save it.

        Both the index and the passage texts are written to disk and kept on
        the instance for immediate querying.

        Args:
            passages: passage texts to index; must be non-empty.
            batch_size: encoding batch size passed to the embedder.

        Raises:
            ValueError: if ``passages`` is empty.
        """
        # len() rather than truthiness so array-likes (numpy, pandas) work too.
        if len(passages) == 0:
            raise ValueError("Cannot build a FAISS index from zero passages")

        print(f"Encoding {len(passages)} passages...")
        embeddings = np.ascontiguousarray(
            self.embedder.encode(
                passages,
                batch_size=batch_size,
                show_progress_bar=True,
                normalize_embeddings=True,
            ),
            dtype="float32",
        )

        # Size the index from the model's actual output width. Trusting the
        # config constant would turn a mismatch into an opaque FAISS assertion.
        dim = embeddings.shape[1]
        if dim != RAG_EMBED_DIM:
            print(
                f"⚠️  Embedding dim {dim} differs from RAG_EMBED_DIM="
                f"{RAG_EMBED_DIM}; using {dim}"
            )
        self.index = faiss.IndexFlatIP(dim)
        self.index.add(embeddings)
        self.passages = np.array(passages)

        # The passages file may live in a different directory than the index.
        for path in (self.index_path, self.passages_path):
            os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        faiss.write_index(self.index, self.index_path)
        np.save(self.passages_path, self.passages)

        print(f"✅ FAISS index built: {self.index.ntotal} vectors")
        print(f" Saved to: {self.index_path}")

    def load_index(self):
        """Load a previously built FAISS index and its passage texts.

        Raises:
            FileNotFoundError: if the index file or the passages file is
                missing.
        """
        if not os.path.exists(self.index_path):
            raise FileNotFoundError(f"No FAISS index at {self.index_path}")
        if not os.path.exists(self.passages_path):
            raise FileNotFoundError(f"No passages file at {self.passages_path}")
        self.index = faiss.read_index(self.index_path)
        # NOTE(review): allow_pickle=True lets np.load unpickle object
        # arrays, which can execute code from a tampered file. build_index
        # writes a plain unicode array that does not need it, so only load
        # passages files from trusted sources.
        self.passages = np.load(self.passages_path, allow_pickle=True)
        if self.index.ntotal != len(self.passages):
            print(
                f"⚠️  Index/passage count mismatch: {self.index.ntotal} vectors "
                f"vs {len(self.passages)} passages"
            )
        print(f"Loaded FAISS index: {self.index.ntotal} vectors")

    def retrieve(self, query: str, top_k: int = RAG_TOP_K) -> list[dict]:
        """
        Returns top-K passages with scores.

        Loads the index from disk on first use. Fewer than ``top_k`` results
        are returned when the index holds fewer passages.

        Returns:
            list of {"passage": str, "score": float}, best match first
        """
        return self.retrieve_batch([query], top_k=top_k)[0]

    def retrieve_batch(
        self,
        queries: list[str],
        top_k: int = RAG_TOP_K,
        batch_size: int = 128,
    ) -> list[list[dict]]:
        """Batch retrieval for training efficiency.

        Args:
            queries: query texts.
            top_k: maximum number of passages per query.
            batch_size: encoding batch size passed to the embedder.

        Returns:
            One list per query of {"passage": str, "score": float}, best match
            first. Queries get empty lists when ``top_k`` < 1 or the index is
            empty.
        """
        if len(queries) == 0:
            return []
        if self.index is None:
            self.load_index()

        # FAISS rejects k <= 0, and asking for more than ntotal only yields
        # -1 padding, so clamp up front.
        k = min(top_k, self.index.ntotal)
        if k <= 0:
            return [[] for _ in queries]

        q_embs = np.ascontiguousarray(
            self.embedder.encode(
                queries, normalize_embeddings=True, batch_size=batch_size
            ),
            dtype="float32",
        )
        scores, indices = self.index.search(q_embs, k)
        return [
            self._hits_to_results(score_row, idx_row)
            for score_row, idx_row in zip(scores, indices)
        ]

    def _hits_to_results(self, scores, indices) -> list[dict]:
        """Turn one row of FAISS search output into result dicts."""
        n_passages = len(self.passages)
        # FAISS pads missing neighbours with id -1. Without the lower bound
        # that id would silently select the *last* passage.
        return [
            {"passage": str(self.passages[idx]), "score": float(score)}
            for score, idx in zip(scores, indices)
            if 0 <= idx < n_passages
        ]
|
|
|
|
| |
| |
| |
def main(args):
    """Run the command-line interface: build the index or answer one query.

    Args:
        args: ``argparse.Namespace`` with ``mode`` ("build" or "query"),
            ``csv``, ``query``, ``top_k``, ``batch_size`` and ``index_path``.

    Raises:
        SystemExit: if the CSV lacks a ``doctor_response`` column or yields no
            usable passages, or if ``query`` is blank in query mode.
    """
    if args.mode == "build":
        df = pd.read_csv(args.csv)
        if "doctor_response" not in df.columns:
            raise SystemExit(f"{args.csv} has no 'doctor_response' column")

        # fillna("") turns missing cells into empty strings instead of the
        # literal text "nan" that str(float("nan")) would produce.
        responses = df["doctor_response"].fillna("").astype(str)
        if "description" in df.columns:
            descriptions = df["description"].fillna("").astype(str)
        else:
            descriptions = pd.Series("", index=df.index)

        # Keep only rows whose doctor response is long enough to be a useful
        # passage. The description is prepended as context.
        passages = [
            f"{desc} | {resp}"
            for desc, resp in zip(descriptions, responses)
            if len(resp) > 20
        ]
        if not passages:
            raise SystemExit(
                f"No doctor responses longer than 20 characters in {args.csv}"
            )

        # Create the engine only after validation: it loads the embedding model.
        engine = RAGEngine(args.index_path)
        engine.build_index(passages, batch_size=args.batch_size)

    elif args.mode == "query":
        if not args.query.strip():
            raise SystemExit("A non-empty --query is required in query mode")

        engine = RAGEngine(args.index_path)
        results = engine.retrieve(args.query, top_k=args.top_k)
        print(f"\nTop-{args.top_k} results for: '{args.query}'\n")
        for rank, hit in enumerate(results, start=1):
            passage = hit["passage"]
            # Only signal truncation when the passage was actually cut.
            ellipsis = "..." if len(passage) > 200 else ""
            print(f" [{rank}] (score={hit['score']:.4f})")
            print(f" {passage[:200]}{ellipsis}")
            print()
|
|
|
|
if __name__ == "__main__":
    # Reuse the module docstring (including its usage examples) as --help text.
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--mode", choices=["build", "query"], required=True,
        help="'build' indexes --csv; 'query' searches the existing index",
    )
    parser.add_argument(
        "--csv", default="./data/train.csv",
        help="CSV with a doctor_response (and optional description) column; build mode",
    )
    parser.add_argument(
        "--query", default="",
        help="patient query text; required in query mode",
    )
    parser.add_argument(
        "--top_k", type=int, default=RAG_TOP_K,
        help="number of passages to retrieve",
    )
    parser.add_argument(
        "--batch_size", type=int, default=256,
        help="embedding batch size used when building the index",
    )
    parser.add_argument(
        "--index_path", default=FAISS_INDEX_PATH,
        help="FAISS index file to write (build) or read (query)",
    )
    args = parser.parse_args()

    # Reject invalid combinations before the (slow) embedding model is loaded.
    if args.mode == "query" and not args.query.strip():
        parser.error("--query is required when --mode is 'query'")
    if args.top_k < 1:
        parser.error("--top_k must be at least 1")
    if args.batch_size < 1:
        parser.error("--batch_size must be at least 1")
    main(args)
|
|