Medica_DecisionSupportAI / retriever.py
Rajan Sharma
Update retriever.py
14fa872 verified
raw
history blame
1.21 kB
import logging

# Module-level logger. Using the root-level logging.warning() helper here would
# call logging.basicConfig() when the root logger has no handlers, silently
# installing a handler and hijacking the host application's logging setup.
logger = logging.getLogger(__name__)

try:
    import faiss
    _HAS_FAISS = True
except ImportError:
    logger.warning(
        "FAISS not installed — retrieval will be disabled. "
        "Install faiss-cpu or faiss-gpu for full functionality."
    )
    _HAS_FAISS = False

from sentence_transformers import SentenceTransformer

# Embedding model, loaded eagerly at import time whether or not FAISS is present.
_model = SentenceTransformer("all-MiniLM-L6-v2")

# FAISS index built by init_retriever(); None until then (and always None
# when FAISS is unavailable).
_index = None
# Documents in the same order as the vectors in _index, so a search id maps
# directly to a position in this list.
_docs = []
def init_retriever(docs=None):
    """
    (Re)initialize the retriever with a new document collection.

    Any previously stored documents and index are discarded first. When FAISS
    is available and ``docs`` is non-empty, every document is embedded with
    the module's SentenceTransformer and added to a ``faiss.IndexFlatL2``.
    Otherwise only the document list is stored and ``_index`` stays ``None``.

    Args:
        docs: list[str] to index. ``None`` or an empty list clears the
            retriever.
    """
    global _index, _docs
    # Keep a private copy: positions in _docs must line up with the vector ids
    # in _index, so later mutation of the caller's list must not leak in.
    _docs = list(docs) if docs is not None else []
    # Reset before rebuilding so a re-init (including one with no docs) never
    # leaves a stale index pointing at the previous collection.
    _index = None
    if not _HAS_FAISS or not _docs:
        return
    embeddings = _model.encode(_docs, convert_to_numpy=True)
    dim = embeddings.shape[1]
    index = faiss.IndexFlatL2(dim)
    index.add(embeddings)
    # Assign only after add() succeeds, so a failure part-way leaves _index as
    # None rather than a partially built index.
    _index = index
def retrieve_context(query: str, k: int = 5):
    """
    Retrieve up to ``k`` indexed documents closest to ``query``.

    Args:
        query: Text to embed with the module's SentenceTransformer and search for.
        k: Maximum number of documents to return.

    Returns:
        list[str]: Matching documents in the order returned by the FAISS
        search. Returns an empty list when FAISS is unavailable, the
        retriever has not been initialized, no documents are indexed, or
        ``k`` is not positive.
    """
    if not _HAS_FAISS or _index is None or not _docs:
        return []
    # Never request more neighbours than are indexed: FAISS pads missing
    # results with id -1, which would otherwise select _docs[-1] (the last
    # document) once per padded slot.
    k = min(k, _index.ntotal)
    if k <= 0:
        return []
    q_emb = _model.encode([query], convert_to_numpy=True)
    _, ids = _index.search(q_emb, k)
    # Defensive bounds check, also excluding any -1 padding ids.
    return [_docs[i] for i in ids[0] if 0 <= i < len(_docs)]