Medica_DecisionSupportAI / retriever.py
Rajan Sharma
Create retriever.py
c044be1 verified
raw
history blame
1.79 kB
# retriever.py
import os, json
from typing import List
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
class Retriever:
    """Dense-vector retriever over a FAISS index plus a JSON metadata file.

    The metadata JSON is read for two keys: ``"docs"`` (a list of dicts, each
    with a ``"text"`` key; ``docs[i]`` is treated as the document for index id
    ``i``) and ``"model"`` (the SentenceTransformer model name used to embed
    queries).

    If either file is missing, the instance is created in a "not ready" state:
    ``ready()`` returns False, ``reason()`` explains why, and ``retrieve()``
    returns an empty list.
    """

    def __init__(self, index_path: str, meta_path: str):
        """Load the index, metadata and embedding model.

        Args:
            index_path: Path to the FAISS index file.
            meta_path: Path to the JSON metadata file.

        Raises:
            KeyError: If the metadata lacks ``"docs"`` or ``"model"``.
            Errors from FAISS / SentenceTransformer loading propagate unchanged.
        """
        if not (os.path.exists(index_path) and os.path.exists(meta_path)):
            self._ready = False
            self._err = f"Missing index or meta at {index_path} / {meta_path}"
            return
        self.index = faiss.read_index(index_path)
        # Context manager so the metadata file handle is closed promptly.
        with open(meta_path, "r", encoding="utf-8") as fh:
            meta = json.load(fh)
        self.docs = meta["docs"]
        self.model_name = meta["model"]
        self.embed = SentenceTransformer(self.model_name)
        self._ready = True
        self._err = None

    def ready(self) -> bool:
        """Return True if the index and metadata were loaded."""
        return self._ready

    def reason(self) -> str:
        """Return why the retriever is not ready, or "" when it is."""
        return self._err or ""

    def retrieve(self, query: str, k: int = 6) -> List[str]:
        """Return the texts of up to ``k`` nearest documents to ``query``.

        Returns an empty list when the retriever is not ready or ``k`` is not
        positive. Ids outside ``docs`` are skipped. This includes the negative
        ids FAISS uses to pad results when fewer than ``k`` hits exist.
        """
        # FAISS rejects k <= 0, so short-circuit instead of raising.
        if not self._ready or k <= 0:
            return []
        # Query embeddings are L2-normalized; presumably the index was built
        # from normalized vectors too. TODO confirm against the index builder.
        q = self.embed.encode([query], convert_to_numpy=True, normalize_embeddings=True)
        # FAISS's Python API expects a float32 (n, d) array.
        _, ids = self.index.search(q.astype(np.float32), k)
        n_docs = len(self.docs)
        return [self.docs[i]["text"] for i in ids[0] if 0 <= i < n_docs]
# convenience
# Module-level cache for the shared Retriever instance managed by
# init_retriever(); starts out as None.
_retriever = None
def init_retriever(index_path="rag_store/index.faiss", meta_path="rag_store/meta.json"):
    """Return the shared module-level Retriever, creating it on first use.

    A retriever that is not ready (index or metadata file missing) is not kept
    permanently. The next call retries construction, so an index built while
    the process is running takes effect without a restart.

    Once a ready retriever exists, later calls return it unchanged, even if
    different paths are passed.

    Args:
        index_path: Path to the FAISS index file.
        meta_path: Path to the JSON metadata file.

    Returns:
        The shared Retriever instance; callers should check ``ready()``.
    """
    global _retriever
    # Rebuild when nothing is cached yet or the cached instance never became
    # ready. A not-ready construction costs only two existence checks.
    if _retriever is None or not _retriever.ready():
        _retriever = Retriever(index_path, meta_path)
    return _retriever
def retrieve_context(query: str, k: int = 6) -> str:
    """Return up to ``k`` retrieved chunks for ``query``, separated by ``---`` lines.

    When the shared retriever is not ready, a placeholder message is returned
    instead of retrieved text.
    """
    retriever = init_retriever()
    if retriever.ready():
        pieces = retriever.retrieve(query, k=k)
        return "\n---\n".join(pieces)
    # Safe fallback if index not built yet
    return "(No policy index found. Run build_policy_index.py to enable RAG.)"