Spaces:
Sleeping
Sleeping
| """ | |
| Simple retrieval utility for optional RAG. | |
| Usage: | |
| python rag/retrieve.py --index vectorstore/faiss_index --query "What is fear?" | |
| """ | |
| import json | |
| from pathlib import Path | |
| import argparse | |
| import sys | |
| import faiss | |
| import numpy as np | |
| from sentence_transformers import SentenceTransformer | |
| sys.path.append(str(Path(__file__).resolve().parents[1])) | |
| from utils import get_utils | |
def load_index(path: Path):
    """Load a FAISS index and its companion texts from a directory.

    Args:
        path: Directory that contains ``index.faiss`` and ``texts.json``.

    Returns:
        A ``(index, texts)`` tuple: the object returned by
        ``faiss.read_index`` and the parsed JSON content of ``texts.json``.
    """
    index_file = path / "index.faiss"
    texts_file = path / "texts.json"

    # The index is read first, then the JSON side-car.
    faiss_index = faiss.read_index(str(index_file))
    with texts_file.open(encoding="utf-8") as handle:
        passages = json.load(handle)
    return faiss_index, passages
def main():
    """Embed a query, search a FAISS index, and print the retrieved texts.

    Command-line arguments:
        --index: Directory handed to ``load_index`` (must hold
            ``index.faiss`` and ``texts.json``).
        --query: Query string to embed and search for.
        --top_k: Maximum number of neighbours to request (default 5,
            must be >= 1).

    Output (stdout): one entry per hit with its score and the first 500
    characters of its text, followed by all retrieved texts joined by
    blank lines.
    """
    parser = argparse.ArgumentParser(description="Retrieve context from FAISS index")
    parser.add_argument("--index", type=str, required=True)
    parser.add_argument("--query", type=str, required=True)
    parser.add_argument("--top_k", type=int, default=5)
    args = parser.parse_args()
    if args.top_k < 1:
        # Fail with a usage message instead of an error from inside FAISS.
        parser.error("--top_k must be a positive integer")

    index, texts = load_index(Path(args.index))

    utils = get_utils()
    device = utils.device_manager.get_torch_device()
    model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device=device)

    q_emb = model.encode([args.query], convert_to_numpy=True)
    faiss.normalize_L2(q_emb)  # normalizes the query embedding in place
    scores, idxs = index.search(q_emb, args.top_k)

    # FAISS pads the result with label -1 when the index holds fewer than
    # top_k vectors. Indexing texts[-1] would silently return the *last*
    # passage as a bogus hit, so drop every label that does not refer to a
    # stored text. (Assumes ``texts`` is a list aligned with the index.)
    hits = [
        (score, int(idx))
        for score, idx in zip(scores[0], idxs[0])
        if 0 <= idx < len(texts)
    ]

    print("RETRIEVED CONTEXT\n" + "-" * 80)
    for rank, (score, idx) in enumerate(hits, start=1):
        print(f"[{rank}] score={score:.4f}\n{texts[idx][:500]}\n")

    print("-- Combine into block --\n")
    block = "\n\n".join(texts[idx] for _, idx in hits)
    print(block)


if __name__ == "__main__":
    main()