Tom Aarsen commited on
Commit
f60b8b2
·
1 Parent(s): 173908c

Use downloaded int8 embeddings instead of an index in this repository

Browse files
Files changed (2) hide show
  1. app.py +127 -45
  2. requirements.txt +1 -2
app.py CHANGED
@@ -1,52 +1,71 @@
1
  import time
2
  import gradio as gr
3
- from datasets import load_dataset
4
  import pandas as pd
5
  from sentence_transformers import SentenceTransformer
6
  from sentence_transformers.quantization import quantize_embeddings
7
  import faiss
8
- from usearch.index import Index
9
 
10
- # Load titles and texts
11
- title_text_dataset = load_dataset("mixedbread-ai/wikipedia-data-en-2023-11", split="train", num_proc=4).select_columns(["title", "text"])
 
 
12
 
13
- # Load the int8 and binary indices. Int8 is loaded as a view to save memory, as we never actually perform search with it.
14
- int8_view = Index.restore("wikipedia_int8_usearch_50m.index", view=True)
15
  binary_index: faiss.IndexBinaryFlat = faiss.read_index_binary("wikipedia_ubinary_faiss_50m.index")
16
- binary_ivf: faiss.IndexBinaryIVF = faiss.read_index_binary("wikipedia_ubinary_ivf_faiss_50m.index")
17
 
18
  # Load the SentenceTransformer model for embedding the queries
19
- model = SentenceTransformer(
20
- "mixedbread-ai/mxbai-embed-large-v1",
21
- prompts={
22
- "retrieval": "Represent this sentence for searching relevant passages: ",
23
- },
24
- default_prompt_name="retrieval",
25
- )
26
-
27
-
28
- def search(query, top_k: int = 100, rescore_multiplier: int = 1, use_approx: bool = False):
 
 
 
 
 
 
 
 
 
 
 
 
29
  # 1. Embed the query as float32
30
  start_time = time.time()
31
- query_embedding = model.encode(query)
32
  embed_time = time.time() - start_time
33
 
34
  # 2. Quantize the query to ubinary
35
  start_time = time.time()
36
- query_embedding_ubinary = quantize_embeddings(query_embedding.reshape(1, -1), "ubinary")
 
 
37
  quantize_time = time.time() - start_time
38
 
39
  # 3. Search the binary index (either exact or approximate)
40
- index = binary_ivf if use_approx else binary_index
41
  start_time = time.time()
42
- _scores, binary_ids = index.search(query_embedding_ubinary, top_k * rescore_multiplier)
 
 
43
  binary_ids = binary_ids[0]
44
  search_time = time.time() - start_time
45
 
46
  # 4. Load the corresponding int8 embeddings
47
  start_time = time.time()
48
- int8_embeddings = int8_view[binary_ids].astype(int)
49
- load_time = time.time() - start_time
 
 
50
 
51
  # 5. Rescore the top_k * rescore_multiplier using the float32 query embedding and the int8 document embeddings
52
  start_time = time.time()
@@ -58,22 +77,40 @@ def search(query, top_k: int = 100, rescore_multiplier: int = 1, use_approx: boo
58
  indices = scores.argsort()[::-1][:top_k]
59
  top_k_indices = binary_ids[indices]
60
  top_k_scores = scores[indices]
61
- top_k_titles, top_k_texts = zip(
62
- *[(title_text_dataset[idx]["title"], title_text_dataset[idx]["text"]) for idx in top_k_indices.tolist()]
63
- )
64
- df = pd.DataFrame(
65
- {"Score": [round(value, 2) for value in top_k_scores], "Title": top_k_titles, "Text": top_k_texts}
66
- )
67
  sort_time = time.time() - start_time
68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  return df, {
70
  "Embed Time": f"{embed_time:.4f} s",
71
  "Quantize Time": f"{quantize_time:.4f} s",
72
  "Search Time": f"{search_time:.4f} s",
73
- "Load Time": f"{load_time:.4f} s",
74
  "Rescore Time": f"{rescore_time:.4f} s",
75
  "Sort Time": f"{sort_time:.4f} s",
76
- "Total Retrieval Time": f"{quantize_time + search_time + load_time + rescore_time + sort_time:.4f} s",
 
77
  }
78
 
79
 
@@ -81,17 +118,18 @@ with gr.Blocks(title="Quantized Retrieval") as demo:
81
  gr.Markdown(
82
  """
83
  ## Quantized Retrieval - Binary Search with Scalar (int8) Rescoring
84
- This demo showcases retrieval using [quantized embeddings](https://huggingface.co/blog/embedding-quantization) on a CPU. The corpus consists of 41 million texts from Wikipedia articles.
85
 
86
  <details><summary>Click to learn about the retrieval process</summary>
87
 
88
  Details:
89
  1. The query is embedded using the [`mixedbread-ai/mxbai-embed-large-v1`](https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1) SentenceTransformer model.
90
  2. The query is quantized to binary using the `quantize_embeddings` function from the SentenceTransformers library.
91
- 3. A binary index (41M binary embeddings; 5.2GB of memory/disk space) is searched using the quantized query for the top 40 documents.
92
- 4. The top 40 documents are loaded on the fly from an int8 index on disk (41M int8 embeddings; 0 bytes of memory, 47.5GB of disk space).
93
- 5. The top 40 documents are rescored using the float32 query and the int8 embeddings to get the top 10 documents.
94
- 6. The top 10 documents are sorted by score and displayed.
 
95
 
96
  This process is designed to be memory efficient and fast, with the binary index being small enough to fit in memory and the int8 index being loaded as a view to save memory.
97
  In total, this process requires keeping 1) the model in memory, 2) the binary index in memory, and 3) the int8 index on disk. With a dimensionality of 1024,
@@ -103,13 +141,13 @@ Additionally, the binary index is much faster (up to 32x) to search than the flo
103
  Feel free to check out the [code for this demo](https://huggingface.co/spaces/sentence-transformers/quantized-retrieval/blob/main/app.py) to learn more about how to apply this in practice.
104
 
105
  Notes:
106
- - The approximate search index (a binary Inverted File Index (IVF)) is in beta and has not been trained with a lot of data. A better IVF index will be released soon.
107
 
108
  </details>
109
  """
110
  )
111
  with gr.Row():
112
- with gr.Column(scale=75):
113
  query = gr.Textbox(
114
  label="Query for Wikipedia articles",
115
  placeholder="Enter a query to search for relevant texts from Wikipedia.",
@@ -118,7 +156,16 @@ Notes:
118
  use_approx = gr.Radio(
119
  choices=[("Exact Search", False), ("Approximate Search", True)],
120
  value=True,
121
- label="Search Index",
 
 
 
 
 
 
 
 
 
122
  )
123
 
124
  with gr.Row():
@@ -126,8 +173,8 @@ Notes:
126
  top_k = gr.Slider(
127
  minimum=10,
128
  maximum=1000,
129
- step=5,
130
- value=100,
131
  label="Number of documents to retrieve",
132
  info="Number of documents to retrieve from the binary search",
133
  )
@@ -136,7 +183,7 @@ Notes:
136
  minimum=1,
137
  maximum=10,
138
  step=1,
139
- value=1,
140
  label="Rescore multiplier",
141
  info="Search for `rescore_multiplier` as many documents to rescore",
142
  )
@@ -145,12 +192,47 @@ Notes:
145
 
146
  with gr.Row():
147
  with gr.Column(scale=4):
148
- output = gr.Dataframe(headers=["Score", "Title", "Text"])
 
 
 
149
  with gr.Column(scale=1):
150
  json = gr.JSON()
151
 
152
- query.submit(search, inputs=[query, top_k, rescore_multiplier, use_approx], outputs=[output, json])
153
- search_button.click(search, inputs=[query, top_k, rescore_multiplier, use_approx], outputs=[output, json])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
 
155
  demo.queue()
156
  demo.launch()
 
1
  import time
2
  import gradio as gr
3
+ from datasets import load_dataset, load_from_disk
4
  import pandas as pd
5
  from sentence_transformers import SentenceTransformer
6
  from sentence_transformers.quantization import quantize_embeddings
7
  import faiss
8
+ import numpy as np
9
 
10
+ # Load titles, texts, and int8 embeddings in a lazy Dataset, allowing us to efficiently access specific rows on demand
11
+ # Note that we never actually use the int8 embeddings for search directly, they are only used for rescoring after the binary search
12
+ title_text_int8_dataset = load_dataset("sentence-transformers/wikipedia-mxbai-embed-int8-index", split="train").select_columns(["title", "text", "embedding"])
13
+ # title_text_int8_dataset = load_from_disk("wikipedia-mxbai-embed-int8-index").select_columns(["url", "title", "text", "embedding"])
14
 
15
+ # Load the binary indices
 
16
  binary_index: faiss.IndexBinaryFlat = faiss.read_index_binary("wikipedia_ubinary_faiss_50m.index")
17
+ binary_ivf_index: faiss.IndexBinaryIVF = faiss.read_index_binary("wikipedia_ubinary_ivf_faiss_50m.index")
18
 
19
  # Load the SentenceTransformer model for embedding the queries
20
+ model = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
21
+ if model.device.type == "cuda":
22
+ model.bfloat16()
23
+
24
+ warmup_queries = [
25
+ "What is the capital of France?",
26
+ "Who is the president of the United States?",
27
+ "What is the largest mammal?",
28
+ "How to bake a chocolate cake?",
29
+ "What is the theory of relativity?",
30
+ ]
31
+ model.encode(warmup_queries)
32
+
33
+
34
+ def search(
35
+ query,
36
+ top_k: int = 20,
37
+ rescore_multiplier: int = 4,
38
+ use_approx: bool = False,
39
+ display_score: bool = True,
40
+ display_binary_rank: bool = False,
41
+ ):
42
  # 1. Embed the query as float32
43
  start_time = time.time()
44
+ query_embedding = model.encode_query(query)
45
  embed_time = time.time() - start_time
46
 
47
  # 2. Quantize the query to ubinary
48
  start_time = time.time()
49
+ query_embedding_ubinary = quantize_embeddings(
50
+ query_embedding.reshape(1, -1), "ubinary"
51
+ )
52
  quantize_time = time.time() - start_time
53
 
54
  # 3. Search the binary index (either exact or approximate)
55
+ index = binary_ivf_index if use_approx else binary_index
56
  start_time = time.time()
57
+ _scores, binary_ids = index.search(
58
+ query_embedding_ubinary, top_k * rescore_multiplier
59
+ )
60
  binary_ids = binary_ids[0]
61
  search_time = time.time() - start_time
62
 
63
  # 4. Load the corresponding int8 embeddings
64
  start_time = time.time()
65
+ int8_embeddings = np.array(
66
+ title_text_int8_dataset[binary_ids]["embedding"], dtype=np.int8
67
+ )
68
+ load_int8_time = time.time() - start_time
69
 
70
  # 5. Rescore the top_k * rescore_multiplier using the float32 query embedding and the int8 document embeddings
71
  start_time = time.time()
 
77
  indices = scores.argsort()[::-1][:top_k]
78
  top_k_indices = binary_ids[indices]
79
  top_k_scores = scores[indices]
 
 
 
 
 
 
80
  sort_time = time.time() - start_time
81
 
82
+ # 7. Load titles and texts for the top_k results
83
+ start_time = time.time()
84
+ top_k_titles = title_text_int8_dataset[top_k_indices]["title"]
85
+ top_k_urls = title_text_int8_dataset[top_k_indices]["url"]
86
+ top_k_texts = title_text_int8_dataset[top_k_indices]["text"]
87
+ top_k_titles = [f"[{title}]({url})" for title, url in zip(top_k_titles, top_k_urls)]
88
+ load_text_time = time.time() - start_time
89
+
90
+ rank = np.arange(1, top_k + 1)
91
+ data = {
92
+ "Score": [f"{score:.2f}" for score in top_k_scores],
93
+ "#": rank,
94
+ "Binary #": indices + 1,
95
+ "Title": top_k_titles,
96
+ "Text": top_k_texts,
97
+ }
98
+ if not display_score:
99
+ del data["Score"]
100
+ if not display_binary_rank:
101
+ del data["Binary #"]
102
+ del data["#"]
103
+ df = pd.DataFrame(data)
104
+
105
  return df, {
106
  "Embed Time": f"{embed_time:.4f} s",
107
  "Quantize Time": f"{quantize_time:.4f} s",
108
  "Search Time": f"{search_time:.4f} s",
109
+ "Load int8 Time": f"{load_int8_time:.4f} s",
110
  "Rescore Time": f"{rescore_time:.4f} s",
111
  "Sort Time": f"{sort_time:.4f} s",
112
+ "Load Text Time": f"{load_text_time:.4f} s",
113
+ "Total Retrieval Time": f"{quantize_time + search_time + load_int8_time + rescore_time + sort_time + load_text_time:.4f} s",
114
  }
115
 
116
 
 
118
  gr.Markdown(
119
  """
120
  ## Quantized Retrieval - Binary Search with Scalar (int8) Rescoring
121
+ This demo showcases retrieval using [quantized embeddings](https://huggingface.co/blog/embedding-quantization) on a CPU. The corpus consists of [41 million texts](https://huggingface.co/datasets/sentence-transformers/wikipedia-mxbai-embed-int8-index) from Wikipedia articles.
122
 
123
  <details><summary>Click to learn about the retrieval process</summary>
124
 
125
  Details:
126
  1. The query is embedded using the [`mixedbread-ai/mxbai-embed-large-v1`](https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1) SentenceTransformer model.
127
  2. The query is quantized to binary using the `quantize_embeddings` function from the SentenceTransformers library.
128
+ 3. A binary index (41M binary embeddings; 5.2GB of memory/disk space) is searched using the quantized query for the top 80 documents.
129
+ 4. The top 80 documents are loaded on the fly from an int8 index on disk (41M int8 embeddings; 0 bytes of memory, 47.5GB of disk space).
130
+ 5. The top 80 documents are rescored using the float32 query and the int8 embeddings to get the top 20 documents.
131
+ 6. The top 20 documents are sorted by score.
132
+ 7. The titles and texts of the top 20 documents are loaded on the fly from disk and displayed.
133
 
134
  This process is designed to be memory efficient and fast, with the binary index being small enough to fit in memory and the int8 index being loaded as a view to save memory.
135
  In total, this process requires keeping 1) the model in memory, 2) the binary index in memory, and 3) the int8 index on disk. With a dimensionality of 1024,
 
141
  Feel free to check out the [code for this demo](https://huggingface.co/spaces/sentence-transformers/quantized-retrieval/blob/main/app.py) to learn more about how to apply this in practice.
142
 
143
  Notes:
144
+ - The approximate search index (a binary Inverted File Index (IVF)) is in beta and has not been trained with a lot of data.
145
 
146
  </details>
147
  """
148
  )
149
  with gr.Row():
150
+ with gr.Column(scale=60):
151
  query = gr.Textbox(
152
  label="Query for Wikipedia articles",
153
  placeholder="Enter a query to search for relevant texts from Wikipedia.",
 
156
  use_approx = gr.Radio(
157
  choices=[("Exact Search", False), ("Approximate Search", True)],
158
  value=True,
159
+ label="Search Settings",
160
+ )
161
+ with gr.Column(scale=15):
162
+ display_score = gr.Checkbox(
163
+ label="Display Score",
164
+ value=True,
165
+ )
166
+ display_binary_rank = gr.Checkbox(
167
+ label='Display Binary Rank',
168
+ value=False,
169
  )
170
 
171
  with gr.Row():
 
173
  top_k = gr.Slider(
174
  minimum=10,
175
  maximum=1000,
176
+ step=1,
177
+ value=20,
178
  label="Number of documents to retrieve",
179
  info="Number of documents to retrieve from the binary search",
180
  )
 
183
  minimum=1,
184
  maximum=10,
185
  step=1,
186
+ value=4,
187
  label="Rescore multiplier",
188
  info="Search for `rescore_multiplier` as many documents to rescore",
189
  )
 
192
 
193
  with gr.Row():
194
  with gr.Column(scale=4):
195
+ output = gr.Dataframe(
196
+ headers=["Score", "#", "Binary #", "Title", "Text"],
197
+ datatype="markdown",
198
+ )
199
  with gr.Column(scale=1):
200
  json = gr.JSON()
201
 
202
+ examples = gr.Examples(
203
+ examples=[
204
+ "What is the coldest metal to the touch?",
205
+ "Who won the FIFA World Cup in 2018?",
206
+ "How to make a paper airplane?",
207
+ "Who was the first woman to cross the Pacific ocean by plane?",
208
+ ],
209
+ fn=search,
210
+ inputs=[query],
211
+ outputs=[output, json],
212
+ cache_examples=False,
213
+ run_on_click=True,
214
+ )
215
+
216
+ query.submit(
217
+ search,
218
+ inputs=[query, top_k, rescore_multiplier, use_approx, display_score, display_binary_rank],
219
+ outputs=[output, json],
220
+ )
221
+ search_button.click(
222
+ search,
223
+ inputs=[query, top_k, rescore_multiplier, use_approx, display_score, display_binary_rank],
224
+ outputs=[output, json],
225
+ )
226
+ display_score.change(
227
+ search,
228
+ inputs=[query, top_k, rescore_multiplier, use_approx, display_score, display_binary_rank],
229
+ outputs=[output, json],
230
+ )
231
+ display_binary_rank.change(
232
+ search,
233
+ inputs=[query, top_k, rescore_multiplier, use_approx, display_score, display_binary_rank],
234
+ outputs=[output, json],
235
+ )
236
 
237
  demo.queue()
238
  demo.launch()
requirements.txt CHANGED
@@ -1,7 +1,6 @@
1
- git+https://github.com/tomaarsen/sentence-transformers@feat/quantization
2
  datasets
3
  pandas
4
  huggingface_hub>=0.24.0
5
 
6
- usearch
7
  faiss-cpu
 
1
+ sentence-transformers==5.2.0
2
  datasets
3
  pandas
4
  huggingface_hub>=0.24.0
5
 
 
6
  faiss-cpu