saeedbenadeeb committed on
Commit
79ad8ab
·
verified ·
1 Parent(s): 0cf8ad2

Upload retriever.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. retriever.py +115 -0
retriever.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Hybrid retriever: BM25 (sparse) + FAISS/BGE (dense) with Reciprocal Rank Fusion."""
2
+
3
+ import json
4
+ import logging
5
+ import re
6
+
7
+ import faiss
8
+ import numpy as np
9
+ from rank_bm25 import BM25Okapi
10
+ from sentence_transformers import SentenceTransformer
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
+ def _tokenize(text: str) -> list[str]:
16
+ return re.findall(r"\w+", text.lower())
17
+
18
+
19
def reciprocal_rank_fusion(
    ranked_lists: list[list[int]], k: int = 60
) -> list[tuple[int, float]]:
    """Fuse several rankings into one using Reciprocal Rank Fusion.

    Each id receives ``1 / (k + position)`` from every list it appears in,
    where ``position`` is its 1-based rank in that list; contributions are
    summed across lists.

    Args:
        ranked_lists: Rankings of ids, each ordered best-first.
        k: Smoothing constant added to the rank in the denominator.

    Returns:
        ``(id, fused_score)`` pairs sorted by descending fused score. Ids
        with equal scores keep the order in which they were first seen.
    """
    fused: dict[int, float] = {}
    for ranking in ranked_lists:
        for position, doc_id in enumerate(ranking, start=1):
            contribution = 1.0 / (k + position)
            fused[doc_id] = fused.get(doc_id, 0.0) + contribution

    ordered = list(fused.items())
    # list.sort is stable (also with reverse=True), so ties keep insertion order.
    ordered.sort(key=lambda pair: pair[1], reverse=True)
    return ordered
27
+
28
+
29
+ class Retriever:
30
+ def __init__(
31
+ self,
32
+ faiss_index_path: str = "faiss.index",
33
+ chunks_meta_path: str = "chunks_meta.jsonl",
34
+ embedding_model: str = "BAAI/bge-small-en-v1.5",
35
+ top_k: int = 5,
36
+ ):
37
+ self.top_k = top_k
38
+
39
+ logger.info("Loading embedding model: %s", embedding_model)
40
+ self.embed_model = SentenceTransformer(embedding_model)
41
+
42
+ logger.info("Loading FAISS index: %s", faiss_index_path)
43
+ self.index = faiss.read_index(faiss_index_path)
44
+
45
+ logger.info("Loading chunk metadata: %s", chunks_meta_path)
46
+ self.chunks: list[dict] = []
47
+ with open(chunks_meta_path, encoding="utf-8") as f:
48
+ for line in f:
49
+ line = line.strip()
50
+ if line:
51
+ self.chunks.append(json.loads(line))
52
+
53
+ logger.info("Building BM25 index over %d chunks...", len(self.chunks))
54
+ corpus_tokens = [_tokenize(c["text"]) for c in self.chunks]
55
+ self.bm25 = BM25Okapi(corpus_tokens)
56
+
57
+ logger.info("Retriever ready: %d vectors, %d chunks", self.index.ntotal, len(self.chunks))
58
+
59
+ def retrieve(self, query: str, top_k: int | None = None) -> list[dict]:
60
+ k = top_k or self.top_k
61
+ candidates_k = min(k * 20, self.index.ntotal)
62
+
63
+ dense_ranked = self._dense_search(query, candidates_k)
64
+ sparse_ranked = self._sparse_search(query, candidates_k)
65
+ fused = reciprocal_rank_fusion([dense_ranked, sparse_ranked])
66
+
67
+ results = []
68
+ for idx, rrf_score in fused:
69
+ if idx < 0 or idx >= len(self.chunks):
70
+ continue
71
+ chunk = self.chunks[idx].copy()
72
+ chunk["score"] = float(rrf_score)
73
+ results.append(chunk)
74
+
75
+ for r in results:
76
+ if r.get("is_faq"):
77
+ r["score"] = r["score"] * 1.2
78
+ results.sort(key=lambda x: x["score"], reverse=True)
79
+
80
+ return results[:k]
81
+
82
+ def _dense_search(self, query: str, k: int) -> list[int]:
83
+ prefixed = f"Represent this sentence for searching relevant passages: {query}"
84
+ qvec = self.embed_model.encode([prefixed], normalize_embeddings=True)
85
+ qvec = np.array(qvec, dtype=np.float32)
86
+ scores, indices = self.index.search(qvec, k)
87
+ return [int(i) for i in indices[0] if i >= 0]
88
+
89
+ def _sparse_search(self, query: str, k: int) -> list[int]:
90
+ tokens = _tokenize(query)
91
+ if not tokens:
92
+ return []
93
+ bm25_scores = self.bm25.get_scores(tokens)
94
+ top_indices = np.argsort(bm25_scores)[::-1][:k]
95
+ return [int(i) for i in top_indices if bm25_scores[i] > 0]
96
+
97
+ def format_context(self, results: list[dict]) -> str:
98
+ parts = []
99
+ for i, r in enumerate(reversed(results), 1):
100
+ source_label = f"[{r['source'].upper()}]" if r.get("source") else ""
101
+ title_label = f" - {r['title']}" if r.get("title") else ""
102
+ parts.append(f"--- Source {i} {source_label}{title_label} ---\n{r['text']}")
103
+ return "\n\n".join(parts)
104
+
105
+ def format_sources_markdown(self, results: list[dict]) -> str:
106
+ if not results:
107
+ return ""
108
+ lines = ["\n---\n**Sources:**"]
109
+ for i, r in enumerate(results, 1):
110
+ tag = "FAQ" if r.get("is_faq") else r.get("source", "").upper()
111
+ title = r.get("title", "Untitled")[:80]
112
+ score = r.get("score", 0)
113
+ preview = r["text"][:150].replace("\n", " ")
114
+ lines.append(f"{i}. **[{tag}]** {title} (score: {score:.4f})\n _{preview}..._")
115
+ return "\n".join(lines)