Spaces:
Runtime error
Runtime error
| # utils/rag_utils.py | |
| import os | |
| import pickle | |
| import numpy as np | |
class RAGRetriever:
    """Load per-patient retrieval artifacts and look up relevant text chunks.

    Files are resolved under ``base_dir`` with this layout::

        <base_dir>/vectorstores/<patient_id>_index.pkl
        <base_dir>/rag_text_enriched/<patient_id>.txt
        <base_dir>/cleaned/<patient_id>.json
    """

    def __init__(self, base_dir):
        """
        Initialize a retriever for one disease dataset.

        Args:
            base_dir: Path to the dataset folder (e.g. ``datasets/diabetes``).
                Sub-directory paths are only joined here; nothing is read or
                checked for existence.
        """
        self.base_dir = base_dir
        self.vector_dir = os.path.join(base_dir, "vectorstores")
        self.text_dir = os.path.join(base_dir, "rag_text_enriched")
        self.json_dir = os.path.join(base_dir, "cleaned")

    def load_patient_data(self, patient_id):
        """
        Load the pickled vector index and the RAG text chunks for one patient.

        Args:
            patient_id: Identifier used to build the per-patient file names.

        Returns:
            A tuple ``(index, text_chunks, json_path)``:

            - ``index``: the object unpickled from ``<patient_id>_index.pkl``
              (presumably a FAISS index -- TODO confirm against the indexer).
            - ``text_chunks``: the text file split on newline characters, one
              chunk per line. A trailing newline yields a final empty string.
            - ``json_path``: path to the cleaned JSON file. It is returned
              without being opened or checked for existence.

        Raises:
            FileNotFoundError: If the index file or the text file is missing.
        """
        vector_path = os.path.join(self.vector_dir, f"{patient_id}_index.pkl")
        text_path = os.path.join(self.text_dir, f"{patient_id}.txt")
        json_path = os.path.join(self.json_dir, f"{patient_id}.json")

        # SECURITY: pickle.load can execute arbitrary code from the file.
        # Only load index files produced by this project's own pipeline,
        # never files that could be supplied or modified by users.
        with open(vector_path, "rb") as f:
            index = pickle.load(f)

        # NOTE(review): retrieve() maps index ids to positions in this list,
        # so line order must match the order vectors were added to the
        # index -- confirm against the script that builds the vectorstores.
        with open(text_path, "r", encoding="utf-8") as f:
            text_chunks = f.read().split("\n")

        return index, text_chunks, json_path

    def retrieve(self, query_embedding, index, text_chunks, top_k=3):
        """
        Return up to ``top_k`` text chunks whose vectors best match the query.

        Ranking is whatever ``index.search`` implements. It is cosine
        similarity only if the index was built for it (e.g. inner product on
        normalized vectors). NOTE(review): confirm the index metric.

        Args:
            query_embedding: One query vector, either 1-D ``(dim,)`` or a
                single row ``(1, dim)``. It is converted to contiguous float32.
            index: Object with a FAISS-style ``search(x, k)`` method that
                returns ``(distances, indices)``, each of shape ``(1, k)``.
            text_chunks: Chunk texts, positionally aligned with the index ids.
            top_k: Maximum number of chunks to return.

        Returns:
            List of chunk strings, in the order returned by the index. Ids that
            are negative (FAISS pads with -1 when fewer than ``top_k`` results
            exist) or out of range for ``text_chunks`` are skipped.
        """
        # reshape(1, -1) produces a single-row query for both 1-D and (1, dim)
        # inputs. Wrapping with np.array([q]) would make the latter 3-D.
        query = np.ascontiguousarray(query_embedding, dtype=np.float32).reshape(1, -1)
        _distances, indices = index.search(query, top_k)
        n_chunks = len(text_chunks)
        # The lower bound matters: an id of -1 would otherwise index
        # text_chunks[-1] and silently return the last chunk as a false match.
        return [text_chunks[i] for i in indices[0] if 0 <= i < n_chunks]