Spaces:
Sleeping
Sleeping
Update retrieval.py
Browse files- retrieval.py +40 -22
retrieval.py
CHANGED
|
@@ -1,8 +1,13 @@
|
|
| 1 |
-
|
| 2 |
-
import json, faiss, numpy as np, os, re
|
| 3 |
from typing import List, Dict, Any
|
| 4 |
-
|
| 5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
try:
|
| 7 |
from rank_bm25 import BM25Okapi
|
| 8 |
except Exception:
|
|
@@ -20,11 +25,17 @@ class Retriever:
|
|
| 20 |
for line in f:
|
| 21 |
self.chunks.append(json.loads(line))
|
| 22 |
|
| 23 |
-
#
|
| 24 |
-
self.faiss =
|
| 25 |
-
self.embed =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
-
#
|
| 28 |
if BM25Okapi is not None:
|
| 29 |
tokenized = [self._tokenize(c["chunk"]) for c in self.chunks]
|
| 30 |
self.bm25 = BM25Okapi(tokenized)
|
|
@@ -37,26 +48,33 @@ class Retriever:
|
|
| 37 |
def _tokenize(self, s):
|
| 38 |
return self._normalize(s).split()
|
| 39 |
|
| 40 |
-
def
|
| 41 |
-
|
| 42 |
-
|
|
|
|
|
|
|
| 43 |
D, I = self.faiss.search(np.asarray(q_emb, dtype="float32"), max(k*4, k))
|
| 44 |
-
|
| 45 |
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
|
| 53 |
-
|
| 54 |
-
|
|
|
|
|
|
|
|
|
|
| 55 |
for h in faiss_hits + bm25_hits:
|
| 56 |
key = (h["source_row"], h["chunk_id"])
|
| 57 |
if key in seen:
|
| 58 |
continue
|
| 59 |
-
seen.add(key)
|
| 60 |
-
|
|
|
|
| 61 |
break
|
| 62 |
-
return
|
|
|
|
|
|
| 1 |
+
import json, os, re
|
|
|
|
| 2 |
from typing import List, Dict, Any
|
| 3 |
+
import numpy as np
|
| 4 |
|
| 5 |
+
try:
|
| 6 |
+
import faiss
|
| 7 |
+
except Exception:
|
| 8 |
+
faiss = None
|
| 9 |
+
|
| 10 |
+
from sentence_transformers import SentenceTransformer
|
| 11 |
try:
|
| 12 |
from rank_bm25 import BM25Okapi
|
| 13 |
except Exception:
|
|
|
|
| 25 |
for line in f:
|
| 26 |
self.chunks.append(json.loads(line))
|
| 27 |
|
| 28 |
+
# try FAISS
|
| 29 |
+
self.faiss = None
|
| 30 |
+
self.embed = None
|
| 31 |
+
if faiss is not None and os.path.exists(FAISS_PATH):
|
| 32 |
+
try:
|
| 33 |
+
self.faiss = faiss.read_index(FAISS_PATH)
|
| 34 |
+
self.embed = SentenceTransformer(embed_model_name)
|
| 35 |
+
except Exception:
|
| 36 |
+
self.faiss = None
|
| 37 |
|
| 38 |
+
# BM25 is immediate
|
| 39 |
if BM25Okapi is not None:
|
| 40 |
tokenized = [self._tokenize(c["chunk"]) for c in self.chunks]
|
| 41 |
self.bm25 = BM25Okapi(tokenized)
|
|
|
|
| 48 |
def _tokenize(self, s):
|
| 49 |
return self._normalize(s).split()
|
| 50 |
|
| 51 |
+
def _faiss_hits(self, query, k):
|
| 52 |
+
if self.faiss is None or self.embed is None:
|
| 53 |
+
return []
|
| 54 |
+
q = self._normalize(query)
|
| 55 |
+
q_emb = self.embed.encode([q], normalize_embeddings=True)
|
| 56 |
D, I = self.faiss.search(np.asarray(q_emb, dtype="float32"), max(k*4, k))
|
| 57 |
+
return [self.chunks[i] for i in I[0] if i >= 0]
|
| 58 |
|
| 59 |
+
def _bm25_hits(self, query, k):
|
| 60 |
+
if self.bm25 is None:
|
| 61 |
+
return []
|
| 62 |
+
scores = self.bm25.get_scores(self._tokenize(query))
|
| 63 |
+
order = np.argsort(-scores)[:k*2]
|
| 64 |
+
return [self.chunks[i] for i in order]
|
| 65 |
|
| 66 |
+
def search(self, query: str, k: int = 5) -> List[Dict[str, Any]]:
    """Hybrid search: merge dense (FAISS) and sparse (BM25) candidates.

    Candidates are deduplicated on the (source_row, chunk_id) pair,
    with dense hits taking precedence, and at most *k* chunks are
    returned.  When both retrievers come back empty, the first *k*
    stored chunks are returned instead so the UI never shows an
    empty result set.
    """
    dense = self._faiss_hits(query, k)
    sparse = self._bm25_hits(query, k)
    merged: List[Dict[str, Any]] = []
    visited = set()
    for candidate in dense + sparse:
        identity = (candidate["source_row"], candidate["chunk_id"])
        if identity in visited:
            continue
        visited.add(identity)
        merged.append(candidate)
        if len(merged) >= k:
            break
    if merged:
        return merged
    # Fallback: everything missing — return leading chunks to avoid an empty UI.
    return self.chunks[:k]
|