Spaces:
Sleeping
Sleeping
Rajan Sharma
committed on
Create retriever.py
Browse files- retriever.py +51 -0
retriever.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# retriever.py
|
| 2 |
+
import os, json
|
| 3 |
+
from typing import List
|
| 4 |
+
import faiss
|
| 5 |
+
import numpy as np
|
| 6 |
+
from sentence_transformers import SentenceTransformer
|
| 7 |
+
|
| 8 |
+
class Retriever:
    """Look up stored text chunks with a FAISS index and a sentence-transformer.

    A missing index or metadata file does not raise: the instance is marked
    not ready and ``reason()`` explains why. Errors raised while reading
    files that do exist still propagate to the caller.
    """

    def __init__(self, index_path: str, meta_path: str):
        """Load the index and its metadata when both files exist.

        Args:
            index_path: Path of the FAISS index file.
            meta_path: Path of a UTF-8 JSON file holding a ``"docs"`` list
                and a ``"model"`` name.
        """
        # Define every attribute up front so a not-ready instance is still
        # fully formed and never raises AttributeError on inspection.
        self.index = None
        self.docs = []
        self.model_name = None
        self.embed = None
        self._ready = False
        self._err = None

        if not (os.path.exists(index_path) and os.path.exists(meta_path)):
            self._err = f"Missing index or meta at {index_path} / {meta_path}"
            return

        self.index = faiss.read_index(index_path)
        # Context manager closes the handle even if JSON parsing fails.
        with open(meta_path, "r", encoding="utf-8") as fh:
            meta = json.load(fh)
        self.docs = meta["docs"]
        self.model_name = meta["model"]
        self.embed = SentenceTransformer(self.model_name)
        self._ready = True

    def ready(self) -> bool:
        """Return True when the index, metadata and model were all loaded."""
        return self._ready

    def reason(self) -> str:
        """Return why the retriever is not ready, or "" when there is no error."""
        return self._err or ""

    def retrieve(self, query: str, k: int = 6) -> List[str]:
        """Return the ``"text"`` of up to ``k`` docs nearest to ``query``.

        Args:
            query: Text to embed and search for.
            k: Maximum number of neighbours to request from the index.

        Returns:
            Chunk texts in the order the index returned them; an empty list
            when the retriever is not ready or ``k`` is not positive.
        """
        # A non-positive k cannot yield results; answer directly instead of
        # forwarding an invalid request to the index.
        if not self._ready or k <= 0:
            return []
        q = self.embed.encode([query], convert_to_numpy=True, normalize_embeddings=True)
        _, ids = self.index.search(q.astype(np.float32), k)
        n_docs = len(self.docs)
        # Ids outside the docs range are skipped (FAISS pads with -1 when it
        # finds fewer than k neighbours).
        return [self.docs[i]["text"] for i in ids[0] if 0 <= i < n_docs]
|
| 37 |
+
|
| 38 |
+
# convenience: one shared Retriever per process
_retriever = None


def init_retriever(index_path="rag_store/index.faiss", meta_path="rag_store/meta.json"):
    """Return the shared module-level Retriever, creating it when needed.

    A cached instance that is not ready (its files were missing when it was
    built) is rebuilt on the next call, so an index created after start-up is
    picked up without restarting the process. Once a ready instance exists it
    is returned as-is and the path arguments of later calls are ignored.

    Args:
        index_path: Path of the FAISS index file.
        meta_path: Path of the JSON metadata file.

    Returns:
        The shared ``Retriever``; check ``ready()`` before relying on it.
    """
    global _retriever
    # Retrying while not ready is cheap: construction stops at the
    # file-existence check when the files are still missing.
    if _retriever is None or not _retriever.ready():
        _retriever = Retriever(index_path, meta_path)
    return _retriever
|
| 45 |
+
|
| 46 |
+
def retrieve_context(query: str, k: int = 6) -> str:
    """Return the retrieved chunks for ``query`` joined into one string.

    Args:
        query: Text to search the shared retriever for.
        k: Maximum number of chunks to request.

    Returns:
        The chunk texts separated by a ``---`` line, or a fixed hint message
        when the shared retriever is not ready.
    """
    retriever = init_retriever()
    if retriever.ready():
        passages = retriever.retrieve(query, k=k)
        return "\n---\n".join(passages)
    # Safe fallback if the index has not been built yet.
    return "(No policy index found. Run build_policy_index.py to enable RAG.)"
|