Rajan Sharma committed on
Commit
c044be1
·
verified ·
1 Parent(s): b1c2b18

Create retriever.py

Browse files
Files changed (1) hide show
  1. retriever.py +51 -0
retriever.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # retriever.py
2
+ import os, json
3
+ from typing import List
4
+ import faiss
5
+ import numpy as np
6
+ from sentence_transformers import SentenceTransformer
7
+
8
class Retriever:
    """Fetch stored text chunks for a query through a FAISS index.

    The retriever is backed by two files: a FAISS index and a JSON metadata
    file. The metadata must provide a ``"docs"`` list whose entries each carry
    a ``"text"`` field, and a ``"model"`` string naming the
    SentenceTransformer model used to embed queries.

    If either file is missing, the instance is still created but stays in a
    "not ready" state instead of raising; check ``ready()`` and ``reason()``.
    """

    def __init__(self, index_path: str, meta_path: str) -> None:
        """Load the index and metadata, or record why they are unavailable.

        Args:
            index_path: Path of the FAISS index file.
            meta_path: Path of the JSON metadata file.

        Raises:
            KeyError: If the metadata lacks the ``"docs"`` or ``"model"`` key.
        """
        # Give every attribute a value up front so a not-ready instance can
        # still be inspected without hitting AttributeError.
        self.index = None
        self.docs = []
        self.model_name = None
        self.embed = None

        if not (os.path.exists(index_path) and os.path.exists(meta_path)):
            self._ready = False
            self._err = f"Missing index or meta at {index_path} / {meta_path}"
            return

        self.index = faiss.read_index(index_path)
        # Context manager so the file handle is closed deterministically.
        with open(meta_path, "r", encoding="utf-8") as fh:
            meta = json.load(fh)
        self.docs = meta["docs"]
        self.model_name = meta["model"]
        self.embed = SentenceTransformer(self.model_name)
        self._ready = True
        self._err = None

    def ready(self) -> bool:
        """Return True when the index and metadata were loaded."""
        return self._ready

    def reason(self) -> str:
        """Return why the retriever is not ready, or "" when it is."""
        return self._err or ""

    def retrieve(self, query: str, k: int = 6) -> List[str]:
        """Return the text of up to ``k`` chunks matching ``query``.

        Args:
            query: Text to embed and search for.
            k: Maximum number of chunks to return.

        Returns:
            Chunk texts in the order the index returned them. The list is
            empty when the retriever is not ready or ``k`` is not positive.
        """
        if not self._ready or k <= 0:
            return []
        # The query embedding is normalised; whether that yields cosine
        # similarity depends on how the index was built (not visible here).
        q = self.embed.encode([query], convert_to_numpy=True, normalize_embeddings=True)
        _, ids = self.index.search(q.astype(np.float32), k)
        n_docs = len(self.docs)
        # Skip ids outside the docs list; FAISS reports -1 for empty slots
        # when fewer than k neighbours exist.
        return [self.docs[i]["text"] for i in ids[0] if 0 <= i < n_docs]
37
+
38
+ # convenience
39
_retriever = None


def init_retriever(index_path="rag_store/index.faiss", meta_path="rag_store/meta.json"):
    """Return the shared module-level ``Retriever``, creating it if needed.

    Args:
        index_path: Path of the FAISS index file, used only when a new
            ``Retriever`` is constructed.
        meta_path: Path of the JSON metadata file, used only when a new
            ``Retriever`` is constructed.

    Returns:
        The cached ``Retriever``. Once a ready instance is cached, later
        calls return it unchanged regardless of the paths passed.
    """
    global _retriever
    # Retry while the cached instance is not ready, so an index built after
    # the first call is picked up without restarting the process.
    if _retriever is None or not _retriever.ready():
        _retriever = Retriever(index_path, meta_path)
    return _retriever
45
+
46
def retrieve_context(query: str, k: int = 6) -> str:
    """Return retrieved chunks for ``query`` joined into one context string.

    Args:
        query: Text to search the shared retriever with.
        k: Maximum number of chunks to request.

    Returns:
        The chunk texts separated by ``"\\n---\\n"``, or a fixed placeholder
        message when the shared retriever is not ready.
    """
    retriever = init_retriever()
    if retriever.ready():
        passages = retriever.retrieve(query, k=k)
        return "\n---\n".join(passages)
    # Index not built yet: return a harmless hint instead of raising.
    return "(No policy index found. Run build_policy_index.py to enable RAG.)"