Tom Aarsen committed on
Commit
e005eea
·
1 Parent(s): cf19736

Rewrite the app frontend; fix accidental exact search bug

Browse files
Files changed (1) hide show
  1. app.py +112 -83
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import time
 
2
  import gradio as gr
3
  from datasets import load_dataset, load_from_disk
4
  from huggingface_hub import hf_hub_download
@@ -10,12 +11,26 @@ import numpy as np
10
 
11
  # Load titles, texts, and int8 embeddings in a lazy Dataset, allowing us to efficiently access specific rows on demand
12
  # Note that we never actually use the int8 embeddings for search directly, they are only used for rescoring after the binary search
13
- title_text_int8_dataset = load_dataset("sentence-transformers/quantized-retrieval-data", split="train").select_columns(["url", "title", "text", "embedding"])
 
 
14
  # title_text_int8_dataset = load_from_disk("wikipedia-mxbai-embed-int8-index").select_columns(["url", "title", "text", "embedding"])
15
 
 
 
16
  # Load the binary indices
17
- binary_index_path = hf_hub_download(repo_id="sentence-transformers/quantized-retrieval-data", filename="wikipedia_ubinary_faiss_50m.index", local_dir=".", repo_type="dataset")
18
- binary_ivf_index_path = hf_hub_download(repo_id="sentence-transformers/quantized-retrieval-data", filename="wikipedia_ubinary_ivf_faiss_50m.index", local_dir=".", repo_type="dataset")
 
 
 
 
 
 
 
 
 
 
19
 
20
  binary_index: faiss.IndexBinaryFlat = faiss.read_index_binary(binary_index_path)
21
  binary_ivf_index: faiss.IndexBinaryIVF = faiss.read_index_binary(binary_ivf_index_path)
@@ -32,16 +47,14 @@ warmup_queries = [
32
  "How to bake a chocolate cake?",
33
  "What is the theory of relativity?",
34
  ]
35
- model.encode(warmup_queries)
36
 
37
 
38
  def search(
39
  query,
40
  top_k: int = 20,
41
  rescore_multiplier: int = 4,
42
- use_approx: bool = False,
43
- display_score: bool = True,
44
- display_binary_rank: bool = False,
45
  ):
46
  # 1. Embed the query as float32
47
  start_time = time.time()
@@ -63,6 +76,7 @@ def search(
63
  )
64
  binary_ids = binary_ids[0]
65
  search_time = time.time() - start_time
 
66
 
67
  # 4. Load the corresponding int8 embeddings
68
  start_time = time.time()
@@ -85,43 +99,85 @@ def search(
85
 
86
  # 7. Load titles and texts for the top_k results
87
  start_time = time.time()
88
- top_k_titles = title_text_int8_dataset[top_k_indices]["title"]
89
  top_k_urls = title_text_int8_dataset[top_k_indices]["url"]
90
  top_k_texts = title_text_int8_dataset[top_k_indices]["text"]
91
- top_k_titles = [f"[{title}]({url})" for title, url in zip(top_k_titles, top_k_urls)]
92
  load_text_time = time.time() - start_time
93
 
94
- rank = np.arange(1, top_k + 1)
95
- data = {
96
- "Score": [f"{score:.2f}" for score in top_k_scores],
97
- "#": rank,
98
- "Binary #": indices + 1,
99
- "Title": top_k_titles,
100
- "Text": top_k_texts,
101
- }
102
- if not display_score:
103
- del data["Score"]
104
- if not display_binary_rank:
105
- del data["Binary #"]
106
- del data["#"]
107
- df = pd.DataFrame(data)
108
-
109
- return df, {
110
- "Embed Time": f"{embed_time:.4f} s",
111
- "Quantize Time": f"{quantize_time:.4f} s",
112
- "Search Time": f"{search_time:.4f} s",
113
- "Load int8 Time": f"{load_int8_time:.4f} s",
114
- "Rescore Time": f"{rescore_time:.4f} s",
115
- "Sort Time": f"{sort_time:.4f} s",
116
- "Load Text Time": f"{load_text_time:.4f} s",
117
- "Total Retrieval Time": f"{quantize_time + search_time + load_int8_time + rescore_time + sort_time + load_text_time:.4f} s",
118
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
 
121
  with gr.Blocks(title="Quantized Retrieval") as demo:
122
- gr.Markdown(
123
- """
124
- ## Quantized Retrieval - Binary Search with Scalar (int8) Rescoring
 
 
 
 
 
125
  This demo showcases retrieval using [quantized embeddings](https://huggingface.co/blog/embedding-quantization) on a CPU. The corpus consists of [41 million texts](https://huggingface.co/datasets/sentence-transformers/quantized-retrieval-data) from Wikipedia articles.
126
 
127
  <details><summary>Click to learn about the retrieval process</summary>
@@ -148,41 +204,24 @@ Notes:
148
  - The approximate search index (a binary Inverted File Index (IVF)) is in beta and has not been trained with a lot of data.
149
 
150
  </details>
 
 
151
  """
152
- )
153
- with gr.Row():
154
- with gr.Column(scale=60):
155
  query = gr.Textbox(
156
  label="Query for Wikipedia articles",
157
  placeholder="Enter a query to search for relevant texts from Wikipedia.",
158
  )
159
- with gr.Column(scale=25):
160
- use_approx = gr.Radio(
161
- choices=[("Exact Search", False), ("Approximate Search", True)],
162
- value=True,
163
- label="Search Settings",
164
- )
165
- with gr.Column(scale=15):
166
- display_score = gr.Checkbox(
167
- label="Display Score",
168
- value=True,
169
- )
170
- display_binary_rank = gr.Checkbox(
171
- label='Display Binary Rank',
172
- value=False,
173
- )
174
-
175
- with gr.Row():
176
- with gr.Column(scale=2):
177
  top_k = gr.Slider(
178
  minimum=10,
179
  maximum=1000,
180
  step=1,
181
  value=20,
182
  label="Number of documents to retrieve",
183
- info="Number of documents to retrieve from the binary search",
184
  )
185
- with gr.Column(scale=2):
186
  rescore_multiplier = gr.Slider(
187
  minimum=1,
188
  maximum=10,
@@ -191,17 +230,17 @@ Notes:
191
  label="Rescore multiplier",
192
  info="Search for `rescore_multiplier` as many documents to rescore",
193
  )
194
-
195
- search_button = gr.Button(value="Search")
 
 
 
196
 
197
  with gr.Row():
198
- with gr.Column(scale=4):
199
- output = gr.Dataframe(
200
- headers=["Score", "#", "Binary #", "Title", "Text"],
201
- datatype="markdown",
202
- )
203
  with gr.Column(scale=1):
204
- json = gr.JSON()
205
 
206
  examples = gr.Examples(
207
  examples=[
@@ -212,30 +251,20 @@ Notes:
212
  ],
213
  fn=search,
214
  inputs=[query],
215
- outputs=[output, json],
216
  cache_examples=False,
217
  run_on_click=True,
218
  )
219
 
220
  query.submit(
221
  search,
222
- inputs=[query, top_k, rescore_multiplier, use_approx, display_score, display_binary_rank],
223
- outputs=[output, json],
224
  )
225
  search_button.click(
226
  search,
227
- inputs=[query, top_k, rescore_multiplier, use_approx, display_score, display_binary_rank],
228
- outputs=[output, json],
229
- )
230
- display_score.change(
231
- search,
232
- inputs=[query, top_k, rescore_multiplier, use_approx, display_score, display_binary_rank],
233
- outputs=[output, json],
234
- )
235
- display_binary_rank.change(
236
- search,
237
- inputs=[query, top_k, rescore_multiplier, use_approx, display_score, display_binary_rank],
238
- outputs=[output, json],
239
  )
240
 
241
  demo.queue()
 
1
  import time
2
+ import html
3
  import gradio as gr
4
  from datasets import load_dataset, load_from_disk
5
  from huggingface_hub import hf_hub_download
 
11
 
12
  # Load titles, texts, and int8 embeddings in a lazy Dataset, allowing us to efficiently access specific rows on demand
13
  # Note that we never actually use the int8 embeddings for search directly, they are only used for rescoring after the binary search
14
+ title_text_int8_dataset = load_dataset(
15
+ "sentence-transformers/quantized-retrieval-data", split="train"
16
+ ).select_columns(["url", "title", "text", "embedding"])
17
  # title_text_int8_dataset = load_from_disk("wikipedia-mxbai-embed-int8-index").select_columns(["url", "title", "text", "embedding"])
18
 
19
+ TOTAL_NUM_DOCS = title_text_int8_dataset.num_rows
20
+
21
  # Load the binary indices
22
+ binary_index_path = hf_hub_download(
23
+ repo_id="sentence-transformers/quantized-retrieval-data",
24
+ filename="wikipedia_ubinary_faiss_50m.index",
25
+ local_dir=".",
26
+ repo_type="dataset",
27
+ )
28
+ binary_ivf_index_path = hf_hub_download(
29
+ repo_id="sentence-transformers/quantized-retrieval-data",
30
+ filename="wikipedia_ubinary_ivf_faiss_50m.index",
31
+ local_dir=".",
32
+ repo_type="dataset",
33
+ )
34
 
35
  binary_index: faiss.IndexBinaryFlat = faiss.read_index_binary(binary_index_path)
36
  binary_ivf_index: faiss.IndexBinaryIVF = faiss.read_index_binary(binary_ivf_index_path)
 
47
  "How to bake a chocolate cake?",
48
  "What is the theory of relativity?",
49
  ]
50
+ model.encode_query(warmup_queries)
51
 
52
 
53
  def search(
54
  query,
55
  top_k: int = 20,
56
  rescore_multiplier: int = 4,
57
+ use_approx: bool = True,
 
 
58
  ):
59
  # 1. Embed the query as float32
60
  start_time = time.time()
 
76
  )
77
  binary_ids = binary_ids[0]
78
  search_time = time.time() - start_time
79
+ num_docs_searched = len(binary_ids)
80
 
81
  # 4. Load the corresponding int8 embeddings
82
  start_time = time.time()
 
99
 
100
  # 7. Load titles and texts for the top_k results
101
  start_time = time.time()
102
+ raw_top_k_titles = title_text_int8_dataset[top_k_indices]["title"]
103
  top_k_urls = title_text_int8_dataset[top_k_indices]["url"]
104
  top_k_texts = title_text_int8_dataset[top_k_indices]["text"]
 
105
  load_text_time = time.time() - start_time
106
 
107
+ # Build HTML cards for each result so the full row is visible at once
108
+ cards = []
109
+ for i in range(len(top_k_indices)):
110
+ title = html.escape(str(raw_top_k_titles[i]))
111
+ url = html.escape(str(top_k_urls[i]))
112
+ text = html.escape(str(top_k_texts[i]))
113
+ score_str = f"{top_k_scores[i]:.2f}"
114
+ rank_str = str(i + 1)
115
+ binary_rank_str = str(indices[i] + 1)
116
+ card_html = f"""
117
+ <div style=\"border: 1px solid var(--border-color-primary, #e0e0e0); border-radius: 10px; padding: 10px 12px; margin-bottom: 10px; background-color: var(--block-background-fill, transparent); color: inherit;\">
118
+ <div style=\"display: flex; align-items: flex-start; justify-content: space-between; gap: 8px; margin-bottom: 4px;\">
119
+ <div style=\"font-size: 16px; font-weight: 600; min-width: 0;\">
120
+ <a href=\"{url}\" target=\"_blank\" style=\"text-decoration: none; color: var(--link-text-color, #1f6feb);\">{title}</a>
121
+ </div>
122
+ <div style=\"font-size: 12px; color: var(--body-text-color-subdued, #586069); text-align: right; white-space: nowrap;\">
123
+ Score: {score_str} • Rank: {rank_str} • Binary rank: {binary_rank_str}
124
+ </div>
125
+ </div>
126
+ <div style=\"font-size: 13px; line-height: 1.4; max-height: 8em; overflow: hidden;\">{text}</div>
127
+ </div>
128
+ """
129
+ cards.append(card_html)
130
+
131
+ if cards:
132
+ cards_html = "\n".join(cards)
133
+ else:
134
+ cards_html = "<div>No results.</div>"
135
+
136
+ total_retrieval_time = (
137
+ quantize_time
138
+ + search_time
139
+ + load_int8_time
140
+ + rescore_time
141
+ + sort_time
142
+ + load_text_time
143
+ )
144
+ num_docs_retrieved = len(top_k_indices)
145
+ search_mode = "Approximate (IVF)" if use_approx else "Exact"
146
+
147
+ summary_md = f"""
148
+ <div style=\"border: 1px solid var(--border-color-primary, #e0e0e0); border-radius: 10px; padding: 10px 12px; background-color: var(--block-background-fill, transparent);\">
149
+ <h3 style=\"margin-top: 0;\">Search Summary</h3>
150
+ <ul style=\"margin-top: 0; margin-bottom: 8px; padding-left: 18px;\">
151
+ <li>Total docs in corpus: {TOTAL_NUM_DOCS:,}</li>
152
+ <li>Docs searched: {num_docs_searched}</li>
153
+ <li>Docs retrieved: {num_docs_retrieved}</li>
154
+ <li>Search mode: {search_mode}</li>
155
+ </ul>
156
+ <h4>Timings (in seconds)</h4>
157
+ <ul style=\"margin-top: 0; margin-bottom: 0; padding-left: 18px;\">
158
+ <li>Embed on CPU: {embed_time:.4f}</li>
159
+ <li>Quantize: {quantize_time:.4f}</li>
160
+ <li>Search: {search_time:.4f}</li>
161
+ <li>Load int8: {load_int8_time:.4f}</li>
162
+ <li>Rescore: {rescore_time:.4f}</li>
163
+ <li>Sort: {sort_time:.4f}</li>
164
+ <li>Load text: {load_text_time:.4f}</li>
165
+ </ul>
166
+ <strong>Total retrieval time: {total_retrieval_time:.4f} seconds</strong>
167
+ </div>"""
168
+
169
+ return cards_html, summary_md
170
 
171
 
172
  with gr.Blocks(title="Quantized Retrieval") as demo:
173
+ with gr.Row():
174
+ with gr.Column(scale=3):
175
+ gr.Markdown(
176
+ """
177
+ <div style='border: 1px solid var(--border-color-primary, #e0e0e0); border-radius: 10px; padding: 12px 14px; background-color: var(--block-background-fill, transparent);'>
178
+
179
+ <h1 style='margin-top: 0;'>Quantized Retrieval - Binary Search with Scalar (int8) Rescoring</h1>
180
+
181
  This demo showcases retrieval using [quantized embeddings](https://huggingface.co/blog/embedding-quantization) on a CPU. The corpus consists of [41 million texts](https://huggingface.co/datasets/sentence-transformers/quantized-retrieval-data) from Wikipedia articles.
182
 
183
  <details><summary>Click to learn about the retrieval process</summary>
 
204
  - The approximate search index (a binary Inverted File Index (IVF)) is in beta and has not been trained with a lot of data.
205
 
206
  </details>
207
+
208
+ </div>
209
  """
210
+ )
 
 
211
  query = gr.Textbox(
212
  label="Query for Wikipedia articles",
213
  placeholder="Enter a query to search for relevant texts from Wikipedia.",
214
  )
215
+ search_button = gr.Button(value="Search", variant="secondary")
216
+ with gr.Column(scale=1, min_width=0):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
  top_k = gr.Slider(
218
  minimum=10,
219
  maximum=1000,
220
  step=1,
221
  value=20,
222
  label="Number of documents to retrieve",
223
+ info="Number of documents to retrieve using binary search",
224
  )
 
225
  rescore_multiplier = gr.Slider(
226
  minimum=1,
227
  maximum=10,
 
230
  label="Rescore multiplier",
231
  info="Search for `rescore_multiplier` as many documents to rescore",
232
  )
233
+ use_approx = gr.Radio(
234
+ choices=[("Approximate Search", True), ("Exact Search", False)],
235
+ value=True,
236
+ label="Search Settings",
237
+ )
238
 
239
  with gr.Row():
240
+ with gr.Column(scale=3):
241
+ cards = gr.HTML(label="Results")
 
 
 
242
  with gr.Column(scale=1):
243
+ summary = gr.Markdown(label="Search Summary")
244
 
245
  examples = gr.Examples(
246
  examples=[
 
251
  ],
252
  fn=search,
253
  inputs=[query],
254
+ outputs=[cards, summary],
255
  cache_examples=False,
256
  run_on_click=True,
257
  )
258
 
259
  query.submit(
260
  search,
261
+ inputs=[query, top_k, rescore_multiplier, use_approx],
262
+ outputs=[cards, summary],
263
  )
264
  search_button.click(
265
  search,
266
+ inputs=[query, top_k, rescore_multiplier, use_approx],
267
+ outputs=[cards, summary],
 
 
 
 
 
 
 
 
 
 
268
  )
269
 
270
  demo.queue()