Ashanasri commited on
Commit
0eeb787
Β·
verified Β·
1 Parent(s): 9f0e12c

Upload app/rag/search.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app/rag/search.py +195 -0
app/rag/search.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from __future__ import annotations
2
+ # from typing import List, Dict, Any
3
+ # import re
4
+ # import numpy as np
5
+ # from pathlib import Path
6
+ # import faiss
7
+
8
+ # from app.rag.embeddings import BGEM3Embedder
9
+ # from app.rag.storage import load_faiss, load_jsonl, search as faiss_search
10
+
11
+
12
+ # def clean_text(text: str) -> str:
13
+ # """
14
+ # Clean text extracted from PDFs:
15
+ # - Collapse multiple spaces/tabs
16
+ # - Replace line breaks with spaces (unless paragraph breaks)
17
+ # - Normalize multiple newlines
18
+ # - Add spaces between lowercase-uppercase and letter-digit transitions
19
+ # - Strip leading/trailing whitespace
20
+ # """
21
+ # if not text:
22
+ # return ""
23
+ # text = re.sub(r"[ \t]+", " ", text) # collapse spaces/tabs
24
+ # text = re.sub(r"\n(?!\n)", " ", text) # single newline -> space
25
+ # text = re.sub(r"\n{2,}", "\n", text) # multi newlines -> single newline
26
+ # text = re.sub(r"([a-z])([A-Z])", r"\1 \2", text) # split lowercase-uppercase
27
+ # text = re.sub(r"([a-zA-Z])(\d)", r"\1 \2", text) # split letter-digit
28
+ # text = re.sub(r"(\d)([a-zA-Z])", r"\1 \2", text) # split digit-letter
29
+ # return text.strip()
30
+
31
+
32
+ # class RAGSearcher:
33
+ # def __init__(self, index_path: Path, meta_path: Path, device: str = "cpu"):
34
+ # self.index_path = index_path
35
+ # self.meta_path = meta_path
36
+ # self.meta = load_jsonl(meta_path)
37
+ # self.embedder = BGEM3Embedder(device=device)
38
+ # self.index: faiss.Index = load_faiss(index_path)
39
+ # self.d = self.index.d # embedding dimension sanity check
40
+
41
+ # # Clean text in metadata on load
42
+ # for m in self.meta:
43
+ # m["text"] = clean_text(m.get("text", ""))
44
+
45
+ # def embed_query(self, query: str) -> np.ndarray:
46
+ # return self.embedder.embed_one(query, mode="query")
47
+
48
+ # def top_k(self, query: str, k: int = 5, rerank: bool = True) -> List[Dict[str, Any]]:
49
+ # """
50
+ # Search the FAISS index for top-k passages matching the query.
51
+ # Optionally rerank using fresh embeddings for better accuracy.
52
+ # """
53
+ # q = self.embed_query(query).reshape(1, -1)
54
+ # scores, ids = faiss_search(self.index, q, top_k=k)
55
+ # ids_row = ids[0].tolist()
56
+ # scores_row = scores[0].tolist()
57
+
58
+ # items = []
59
+ # for i, sc in zip(ids_row, scores_row):
60
+ # if i < 0:
61
+ # continue
62
+ # m = self.meta[i]
63
+ # items.append({
64
+ # "id": i,
65
+ # "score": float(sc),
66
+ # "page": m["page"],
67
+ # "chunk_index": m["chunk_index"],
68
+ # "source": m["source"],
69
+ # "text": m["text"],
70
+ # })
71
+
72
+ # if rerank and items:
73
+ # # Re-embed candidate passages and recompute cosine similarity
74
+ # passages = [it["text"] for it in items]
75
+ # P = self.embedder.embed_texts(passages, mode="passage")
76
+ # qv = q.astype("float32") # [1, d]
77
+ # rerank_scores = (P @ qv.T).reshape(-1) # cosine sim with L2 normed vectors
78
+ # for it, rs in zip(items, rerank_scores.tolist()):
79
+ # it["rerank_score"] = float(rs)
80
+ # items.sort(key=lambda x: x.get("rerank_score", x["score"]), reverse=True)
81
+ # else:
82
+ # items.sort(key=lambda x: x["score"], reverse=True)
83
+
84
+ # return items
85
+
86
+
87
+
88
+
89
+ from __future__ import annotations
90
+
91
+ """
92
+ search.py
93
+ =========
94
+ RAGSearcher β€” wraps FAISS index + BGE-M3 embedder.
95
+ Exposes _top_k_sync() (blocking) used by utils.py pipelines.
96
+ """
97
+
98
+ import asyncio
99
+ import re
100
+ from pathlib import Path
101
+ from typing import Any, Dict, List
102
+
103
+ import faiss
104
+ import numpy as np
105
+
106
+ from app.rag.embeddings import BGEM3Embedder
107
+ from app.rag.storage import load_faiss, load_jsonl, search as faiss_search
108
+
109
+
110
+ # ── Text cleaner ──────────────────────────────────────────────────────────────
111
+
112
+ def clean_text(text: str) -> str:
113
+ if not text:
114
+ return ""
115
+ text = re.sub(r"[ \t]+", " ", text)
116
+ text = re.sub(r"\n(?!\n)", " ", text)
117
+ text = re.sub(r"\n{2,}", "\n", text)
118
+ text = re.sub(r"([a-z])([A-Z])", r"\1 \2", text)
119
+ text = re.sub(r"([a-zA-Z])(\d)", r"\1 \2", text)
120
+ text = re.sub(r"(\d)([a-zA-Z])", r"\1 \2", text)
121
+ return text.strip()
122
+
123
+
124
+ # ── RAGSearcher ───────────────────────────────────────────────────────────────
125
+
126
+ class RAGSearcher:
127
+ """
128
+ Loads FAISS index + metadata once.
129
+ _top_k_sync() -> blocking retrieval (called by utils.answer_query*)
130
+ top_k() -> async wrapper (optional direct use)
131
+ """
132
+
133
+ def __init__(self, index_path: Path, meta_path: Path, device: str = "cpu"):
134
+ self.meta = load_jsonl(meta_path)
135
+ self.embedder = BGEM3Embedder(device=device)
136
+ self.index: faiss.Index = load_faiss(index_path)
137
+ self.d = self.index.d
138
+
139
+ # Clean all metadata text once at load time
140
+ for m in self.meta:
141
+ m["text"] = clean_text(m.get("text", ""))
142
+
143
+ def embed_query(self, query: str) -> np.ndarray:
144
+ return self.embedder.embed_one(query, mode="query")
145
+
146
+ # ── Blocking retrieval ────────────────────────────────────────────────────
147
+
148
+ def _top_k_sync(
149
+ self, query: str, k: int = 5, rerank: bool = True
150
+ ) -> List[Dict[str, Any]]:
151
+ """
152
+ 1. Embed query with BGE-M3
153
+ 2. FAISS cosine search (top-k)
154
+ 3. Rerank via fresh passage embeddings (cosine rescore)
155
+ Returns list of hit dicts sorted by best score.
156
+ """
157
+ q = self.embed_query(query).reshape(1, -1)
158
+ scores, ids = faiss_search(self.index, q, top_k=k)
159
+
160
+ items = []
161
+ for i, sc in zip(ids[0].tolist(), scores[0].tolist()):
162
+ if i < 0:
163
+ continue
164
+ m = self.meta[i]
165
+ items.append({
166
+ "id": i,
167
+ "score": float(sc),
168
+ "page": m.get("page"),
169
+ "chunk_index": m.get("chunk_index"),
170
+ "source": m.get("source"),
171
+ "text": m["text"],
172
+ })
173
+
174
+ print(f"Retrieved chunk {i} with initial score {sc:.4f}")
175
+
176
+ if rerank and items:
177
+ passages = [it["text"] for it in items]
178
+ P = self.embedder.embed_texts(passages, mode="passage")
179
+
180
+ rerank_scores = (P @ q.astype("float32").T).reshape(-1)
181
+ for it, rs in zip(items, rerank_scores.tolist()):
182
+ it["rerank_score"] = float(rs)
183
+ items.sort(key=lambda x: x.get("rerank_score", x["score"]), reverse=True)
184
+ else:
185
+ items.sort(key=lambda x: x["score"], reverse=True)
186
+
187
+ return items
188
+
189
+
190
+ async def top_k(
191
+ self, query: str, k: int = 5, rerank: bool = True
192
+ ) -> List[Dict[str, Any]]:
193
+ """Non-blocking version for direct async use."""
194
+ loop = asyncio.get_event_loop()
195
+ return await loop.run_in_executor(None, self._top_k_sync, query, k, rerank)