CIAZIZ committed on
Commit
01d7189
·
verified ·
1 Parent(s): b295110

Update retrieval.py

Browse files
Files changed (1) hide show
  1. retrieval.py +40 -22
retrieval.py CHANGED
@@ -1,8 +1,13 @@
1
- # retrieval.py FAISS + optional BM25 (no reranker; CPU-friendly)
2
- import json, faiss, numpy as np, os, re
3
  from typing import List, Dict, Any
4
- from sentence_transformers import SentenceTransformer
5
 
 
 
 
 
 
 
6
  try:
7
  from rank_bm25 import BM25Okapi
8
  except Exception:
@@ -20,11 +25,17 @@ class Retriever:
20
  for line in f:
21
  self.chunks.append(json.loads(line))
22
 
23
- # faiss
24
- self.faiss = faiss.read_index(FAISS_PATH)
25
- self.embed = SentenceTransformer(embed_model_name)
 
 
 
 
 
 
26
 
27
- # bm25 (optional)
28
  if BM25Okapi is not None:
29
  tokenized = [self._tokenize(c["chunk"]) for c in self.chunks]
30
  self.bm25 = BM25Okapi(tokenized)
@@ -37,26 +48,33 @@ class Retriever:
37
  def _tokenize(self, s):
38
  return self._normalize(s).split()
39
 
40
- def search(self, query: str, k: int = 5) -> List[Dict[str, Any]]:
41
- q_norm = self._normalize(query)
42
- q_emb = self.embed.encode([q_norm], normalize_embeddings=True)
 
 
43
  D, I = self.faiss.search(np.asarray(q_emb, dtype="float32"), max(k*4, k))
44
- faiss_hits = [self.chunks[i] for i in I[0] if i >= 0]
45
 
46
- # optional BM25
47
- bm25_hits = []
48
- if self.bm25 is not None:
49
- bm25_scores = self.bm25.get_scores(self._tokenize(q_norm))
50
- bm25_ranked = np.argsort(-bm25_scores)[:k*2]
51
- bm25_hits = [self.chunks[i] for i in bm25_ranked]
52
 
53
- # merge uniques (prefer FAISS order)
54
- seen = set(); merged=[]
 
 
 
55
  for h in faiss_hits + bm25_hits:
56
  key = (h["source_row"], h["chunk_id"])
57
  if key in seen:
58
  continue
59
- seen.add(key); merged.append(h)
60
- if len(merged) >= k:
 
61
  break
62
- return merged
 
 
1
+ import json, os, re
 
2
  from typing import List, Dict, Any
3
+ import numpy as np
4
 
5
+ try:
6
+ import faiss
7
+ except Exception:
8
+ faiss = None
9
+
10
+ from sentence_transformers import SentenceTransformer
11
  try:
12
  from rank_bm25 import BM25Okapi
13
  except Exception:
 
25
  for line in f:
26
  self.chunks.append(json.loads(line))
27
 
28
+ # try FAISS
29
+ self.faiss = None
30
+ self.embed = None
31
+ if faiss is not None and os.path.exists(FAISS_PATH):
32
+ try:
33
+ self.faiss = faiss.read_index(FAISS_PATH)
34
+ self.embed = SentenceTransformer(embed_model_name)
35
+ except Exception:
36
+ self.faiss = None
37
 
38
+ # BM25 is immediate
39
  if BM25Okapi is not None:
40
  tokenized = [self._tokenize(c["chunk"]) for c in self.chunks]
41
  self.bm25 = BM25Okapi(tokenized)
 
48
  def _tokenize(self, s):
49
  return self._normalize(s).split()
50
 
51
+ def _faiss_hits(self, query, k):
52
+ if self.faiss is None or self.embed is None:
53
+ return []
54
+ q = self._normalize(query)
55
+ q_emb = self.embed.encode([q], normalize_embeddings=True)
56
  D, I = self.faiss.search(np.asarray(q_emb, dtype="float32"), max(k*4, k))
57
+ return [self.chunks[i] for i in I[0] if i >= 0]
58
 
59
+ def _bm25_hits(self, query, k):
60
+ if self.bm25 is None:
61
+ return []
62
+ scores = self.bm25.get_scores(self._tokenize(query))
63
+ order = np.argsort(-scores)[:k*2]
64
+ return [self.chunks[i] for i in order]
65
 
66
+ def search(self, query: str, k: int = 5) -> List[Dict[str, Any]]:
67
+ faiss_hits = self._faiss_hits(query, k)
68
+ bm25_hits = self._bm25_hits(query, k)
69
+ pool = []
70
+ seen = set()
71
  for h in faiss_hits + bm25_hits:
72
  key = (h["source_row"], h["chunk_id"])
73
  if key in seen:
74
  continue
75
+ seen.add(key)
76
+ pool.append(h)
77
+ if len(pool) >= k:
78
  break
79
+ # if everything missing, just return first k docs to avoid empty UI
80
+ return pool or self.chunks[:k]