FurqanIshaq's picture
Update utils/rag_utils.py
23739f3 verified
# utils/rag_utils.py
import os
import pickle
import numpy as np
class RAGRetriever:
    """Per-patient RAG helper: load a stored vector index plus its text
    chunks from a dataset folder, and look up the chunks nearest a query.
    """

    def __init__(self, base_dir):
        """
        Initialize retriever for a disease type (e.g., diabetes).

        Args:
            base_dir: path to the dataset folder (e.g., datasets/diabetes).
                Files are read from its ``vectorstores``,
                ``rag_text_enriched`` and ``cleaned`` subfolders.
        """
        self.base_dir = base_dir
        self.vector_dir = os.path.join(base_dir, "vectorstores")
        self.text_dir = os.path.join(base_dir, "rag_text_enriched")
        self.json_dir = os.path.join(base_dir, "cleaned")

    def load_patient_data(self, patient_id):
        """
        Load the pickled vector index and RAG text chunks for one patient.

        Args:
            patient_id: identifier used to build the file names
                ``<patient_id>_index.pkl``, ``<patient_id>.txt`` and
                ``<patient_id>.json``.

        Returns:
            tuple ``(index, text_chunks, json_path)``:
            the unpickled index object (``retrieve`` calls its ``search``
            method), the text file split on newline characters (one chunk
            per line), and the path of the patient's cleaned JSON file
            (the path is only built here; the file is not opened or checked).

        Raises:
            FileNotFoundError: if the index or text file does not exist.
        """
        vector_path = os.path.join(self.vector_dir, f"{patient_id}_index.pkl")
        text_path = os.path.join(self.text_dir, f"{patient_id}.txt")
        json_path = os.path.join(self.json_dir, f"{patient_id}.json")

        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file. Only load vector stores produced by a trusted pipeline.
        with open(vector_path, "rb") as f:
            index = pickle.load(f)

        # Split on "\n" exactly (no splitlines, no filtering of blank lines)
        # so that chunk positions stay aligned with the ids stored in the index.
        # NOTE(review): presumably the index was built from this same split —
        # confirm against the indexing script.
        with open(text_path, "r", encoding="utf-8") as f:
            text_chunks = f.read().split("\n")

        return index, text_chunks, json_path

    def retrieve(self, query_embedding, index, text_chunks, top_k=3):
        """
        Return up to ``top_k`` text chunks nearest to ``query_embedding``.

        Ranking is whatever ``index.search`` produces. It is cosine similarity
        only if the index was built for it (e.g. inner product over
        L2-normalized vectors); that is not determined here.

        Args:
            query_embedding: a single query vector of shape ``(dim,)`` or an
                already-batched ``(1, dim)`` array; converted to float32.
            index: object with a FAISS-style ``search(queries, k)`` method
                returning ``(distances, ids)`` arrays of shape ``(n, k)``.
            text_chunks: chunk texts, where position ``i`` corresponds to
                index id ``i``.
            top_k: number of neighbours to request; ``<= 0`` returns ``[]``.

        Returns:
            list of chunk texts in the order returned by the index. It may be
            shorter than ``top_k`` when the index returns padded (-1) or
            out-of-range ids.
        """
        if top_k <= 0:
            # FAISS rejects k <= 0; there is nothing to retrieve anyway.
            return []

        # FAISS expects a C-contiguous float32 matrix of shape (n_queries, dim).
        query = np.ascontiguousarray(
            np.atleast_2d(np.asarray(query_embedding, dtype="float32"))
        )
        _distances, ids = index.search(query, top_k)

        # FAISS pads missing results with -1. An upper-bound-only check would
        # let -1 through and silently return the LAST chunk via negative indexing.
        n_chunks = len(text_chunks)
        return [text_chunks[i] for i in ids[0] if 0 <= i < n_chunks]