Rajan Sharma committed on
Commit
14fa872
·
verified ·
1 Parent(s): 5042fa6

Update retriever.py

Browse files
Files changed (1) hide show
  1. retriever.py +43 -47
retriever.py CHANGED
@@ -1,49 +1,45 @@
1
- \
2
- import os, json
3
- from typing import List
4
- import faiss
5
- import numpy as np
 
 
 
 
6
  from sentence_transformers import SentenceTransformer
7
 
8
class Retriever:
    """Text retriever backed by a prebuilt FAISS index and a JSON metadata file.

    The metadata file is parsed as JSON and must provide a ``"docs"`` list
    (each entry is read through its ``"text"`` key by ``retrieve``) and a
    ``"model"`` name, which is handed to ``SentenceTransformer``.

    When either file is missing the instance is still created, but
    ``ready()`` returns ``False`` and ``reason()`` explains why.
    """

    def __init__(self, index_path: str, meta_path: str):
        """Load the index and metadata, or record why they are unavailable.

        Args:
            index_path: Path of the serialized FAISS index.
            meta_path: Path of the JSON metadata file.
        """
        # Give every attribute a defined value up front so a not-ready
        # instance never raises AttributeError when inspected.
        self.index = None
        self.docs = []
        self.model_name = None
        self.embed = None
        self._ready = False
        self._err = None

        if not (os.path.exists(index_path) and os.path.exists(meta_path)):
            self._err = f"Missing index or meta at {index_path} / {meta_path}"
            return

        self.index = faiss.read_index(index_path)
        # Use a context manager so the file handle is closed deterministically
        # instead of being left for the garbage collector.
        with open(meta_path, "r", encoding="utf-8") as meta_file:
            meta = json.load(meta_file)
        self.docs = meta["docs"]
        self.model_name = meta["model"]
        self.embed = SentenceTransformer(self.model_name)
        self._ready = True

    def ready(self) -> bool:
        """Return True when the index, metadata and model were all loaded."""
        return self._ready

    def reason(self) -> str:
        """Return the recorded load-failure message, or "" if there is none."""
        return self._err or ""

    def retrieve(self, query: str, k: int = 6) -> List[str]:
        """Return the ``"text"`` of up to ``k`` documents matching ``query``.

        Args:
            query: Free-text query; it is embedded with normalized embeddings.
            k: Maximum number of results requested from the index.

        Returns:
            The matching texts in the order the index returns them. The list
            is empty when the retriever is not ready or ``k`` is not positive.
        """
        if not self._ready or k <= 0:
            return []
        query_vec = self.embed.encode(
            [query], convert_to_numpy=True, normalize_embeddings=True
        )
        _, hits = self.index.search(query_vec.astype(np.float32), k)
        # Out-of-range labels (including negative ones) are skipped rather
        # than being used to index into the document list.
        return [
            self.docs[idx]["text"]
            for idx in hits[0]
            if 0 <= idx < len(self.docs)
        ]
37
-
38
# Process-wide Retriever instance, created on first use.
_retriever = None


def init_retriever(index_path="rag_store/index.faiss", meta_path="rag_store/meta.json"):
    """Return the shared Retriever, constructing it on the first call.

    The paths are only used when the shared instance does not exist yet;
    later calls return the existing instance regardless of their arguments.
    """
    global _retriever
    if _retriever is not None:
        return _retriever
    _retriever = Retriever(index_path, meta_path)
    return _retriever
44
-
45
def retrieve_context(query: str, k: int = 6) -> str:
    """Return the retrieved chunks for ``query`` joined into one string.

    Chunks are separated by a ``---`` line. When the shared retriever is not
    ready, a fixed placeholder message is returned instead.
    """
    retriever = init_retriever()
    if retriever.ready():
        chunks = retriever.retrieve(query, k=k)
        return "\n---\n".join(chunks)
    return "(No policy index found. Run build_policy_index.py to enable RAG.)"
 
1
+ import logging
2
+
3
+ try:
4
+ import faiss
5
+ _HAS_FAISS = True
6
+ except ImportError:
7
+ logging.warning("FAISS not installed — retrieval will be disabled. Install faiss-cpu or faiss-gpu for full functionality.")
8
+ _HAS_FAISS = False
9
+
10
  from sentence_transformers import SentenceTransformer
11
 
12
# The embedding model is created eagerly at import time. It is constructed
# independently of FAISS, so it is available even when FAISS is missing.
_model = SentenceTransformer("all-MiniLM-L6-v2")

# Module-level retrieval state; both start out empty.
_index = None  # FAISS index, or None while nothing has been indexed
_docs = []  # documents, in the same order as the vectors in _index
17
+
18
def init_retriever(docs=None):
    """Build the in-memory FAISS index over ``docs``.

    Args:
        docs: Strings to index. With FAISS available, ``None`` or an empty
            collection leaves any previously built index untouched. Without
            FAISS, the documents are only stored and no index is built.

    Returns:
        None. The results are stored in the module-level ``_index`` and
        ``_docs``.
    """
    global _index, _docs

    # Snapshot the input so later mutation of the caller's list cannot shift
    # documents away from the positions recorded in the index.
    doc_list = list(docs) if docs else []

    if not _HAS_FAISS:
        _docs = doc_list
        return

    if not doc_list:
        return

    # Local import: this module does not import numpy at file level, but
    # encode() is asked for a NumPy array, so numpy is expected to be present.
    import numpy as np

    # FAISS works on float32, C-contiguous matrices; make that explicit
    # instead of relying on whatever dtype the encoder happens to return.
    embeddings = np.ascontiguousarray(
        _model.encode(doc_list, convert_to_numpy=True), dtype=np.float32
    )
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)

    # Publish both globals together, and only after encoding and indexing
    # succeeded, so a failure never pairs new documents with a stale index.
    _index, _docs = index, doc_list
34
+
35
def retrieve_context(query: str, k: int = 5):
    """Return up to ``k`` indexed documents for ``query``.

    Args:
        query: Free-text query to embed and search for.
        k: Maximum number of documents to return.

    Returns:
        A list of documents in the order the index returns them. The list is
        empty when FAISS is unavailable, nothing has been indexed, or ``k``
        is not positive.
    """
    if not _HAS_FAISS or _index is None or not _docs or k <= 0:
        return []

    # Local import: this module does not import numpy at file level, but
    # encode() is asked for a NumPy array, so numpy is expected to be present.
    import numpy as np

    query_vec = np.ascontiguousarray(
        _model.encode([query], convert_to_numpy=True), dtype=np.float32
    )
    # Never ask for more neighbours than there are documents.
    _, neighbours = _index.search(query_vec, min(k, len(_docs)))

    # FAISS reports missing neighbours with the label -1. Checking only the
    # upper bound would let -1 through and silently return the last document
    # via negative indexing, so both bounds are checked.
    return [_docs[i] for i in neighbours[0] if 0 <= i < len(_docs)]