# VivekanandaAI / retrieve.py
# jyotirmoy05's picture
# Upload 6 files
# 09f4291 verified
"""
Simple retrieval utility for optional RAG.

Usage:
    python rag/retrieve.py --index vectorstore/faiss_index --query "What is fear?" [--top_k 5]

The --index directory must contain ``index.faiss`` and ``texts.json``.
"""
import json
from pathlib import Path
import argparse
import sys
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
sys.path.append(str(Path(__file__).resolve().parents[1]))
from utils import get_utils
def load_index(path: Path):
    """Load a FAISS index and its companion texts from a directory.

    Args:
        path: Directory holding ``index.faiss`` and ``texts.json``.

    Returns:
        A ``(index, texts)`` tuple: the object returned by
        ``faiss.read_index`` and the parsed JSON content of ``texts.json``.
    """
    index_file = path / "index.faiss"
    texts_file = path / "texts.json"

    # The Path is converted to str before being handed to faiss.
    faiss_index = faiss.read_index(str(index_file))

    with texts_file.open(encoding="utf-8") as handle:
        passages = json.load(handle)

    return faiss_index, passages
def main():
    """Command-line entry point: embed a query and print its nearest passages.

    Parses ``--index``, ``--query`` and ``--top_k``, loads the FAISS index and
    texts through ``load_index``, encodes the query with the
    ``sentence-transformers/all-MiniLM-L6-v2`` model, searches the index, and
    prints each hit (score plus the first 500 characters of its text) followed
    by the full texts of all hits joined into a single block.

    Exits with status 2 (via ``argparse``) when ``--top_k`` is not positive.
    """
    parser = argparse.ArgumentParser(description="Retrieve context from FAISS index")
    parser.add_argument("--index", type=str, required=True)
    parser.add_argument("--query", type=str, required=True)
    parser.add_argument("--top_k", type=int, default=5)
    args = parser.parse_args()

    # Reject a non-positive k up front instead of passing it to the search.
    if args.top_k < 1:
        parser.error("--top_k must be a positive integer")

    index, texts = load_index(Path(args.index))

    utils = get_utils()
    device = utils.device_manager.get_torch_device()
    model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device=device)

    q_emb = model.encode([args.query], convert_to_numpy=True)
    # FAISS operates on contiguous float32 arrays; make that explicit rather
    # than relying on the encoder's default output dtype.
    q_emb = np.ascontiguousarray(q_emb, dtype=np.float32)
    # NOTE(review): normalising the query presumably mirrors how the index
    # vectors were built (cosine similarity via inner product) - confirm
    # against the indexing script.
    faiss.normalize_L2(q_emb)
    scores, idxs = index.search(q_emb, args.top_k)

    # FAISS pads the result with id -1 when fewer than top_k neighbours exist.
    # Indexing ``texts`` with -1 would silently return the *last* passage, so
    # keep only ids that address a real entry of ``texts``.
    hits = [
        (float(score), int(idx))
        for score, idx in zip(scores[0], idxs[0])
        if 0 <= idx < len(texts)
    ]

    print("RETRIEVED CONTEXT\n" + "-" * 80)
    for rank, (score, idx) in enumerate(hits, start=1):
        print(f"[{rank}] score={score:.4f}\n{texts[idx][:500]}\n")

    print("-- Combine into block --\n")
    block = "\n\n".join(texts[idx] for _, idx in hits)
    print(block)
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()