Agent_studio / modules /rag_retriever.py
Corin1998's picture
Create modules/rag_retriever.py
fa9ab56 verified
raw
history blame contribute delete
890 Bytes
import faiss
import pickle
from typing import List, Tuple
from pathlib import Path
from sentence_transformers import SentenceTransformer
# Directory (relative to the current working directory) holding the
# persisted vector store.
DATA_DIR = Path("data")

# Serialized FAISS index file.
INDEX_PATH = DATA_DIR.joinpath("vector_store.faiss")

# Pickled metadata; retrieve_contexts indexes it by FAISS row id and reads
# the "text" field of each entry.
META_PATH = DATA_DIR.joinpath("vector_store_meta.pkl")
# Process-wide cache for the embedding model; populated on first use.
_model = None


def _embedder():
    """Return the shared SentenceTransformer, constructing it on first call.

    The model is created lazily so that importing this module does not
    trigger model loading; later calls reuse the cached instance.
    """
    global _model
    if _model is not None:
        return _model
    _model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    return _model
def retrieve_contexts(query: str, k: int = 5) -> List[str]:
    """Return the stored texts of the ``k`` nearest neighbours of ``query``.

    The FAISS index at ``INDEX_PATH`` and the pickled metadata at
    ``META_PATH`` are read from disk on every call, so a store that was
    rebuilt between calls is picked up automatically.

    Args:
        query: Text to embed and search for.
        k: Maximum number of contexts to return. Non-positive values yield
            an empty list.

    Returns:
        The ``"text"`` field of each matching metadata entry, in the order
        FAISS returned them. Empty when ``k <= 0`` or when either store
        file is missing.
    """
    # FAISS rejects a non-positive k; asking for zero results is simply empty.
    if k <= 0:
        return []

    # Both files are needed; treat a partially written store like "no store"
    # instead of failing with FileNotFoundError on the metadata file.
    if not (INDEX_PATH.exists() and META_PATH.exists()):
        return []

    index = faiss.read_index(str(INDEX_PATH))

    # NOTE(security): pickle.load can execute arbitrary code. This is only
    # safe while META_PATH is written exclusively by trusted code.
    with open(META_PATH, "rb") as f:
        meta = pickle.load(f)

    # NOTE(review): the query vector is L2-normalized; presumably the index
    # was built from normalized vectors as well -- confirm against the
    # indexing code.
    vec = _embedder().encode([query], normalize_embeddings=True)
    _, ids = index.search(vec, k)

    # FAISS pads the result with -1 when fewer than k vectors are available.
    # Convert the numpy integer to a plain int before indexing the metadata.
    return [meta[int(i)]["text"] for i in ids[0] if i != -1]