Spaces:
Build error
Build error
| import time | |
| import gradio as gr | |
| from datasets import load_dataset | |
| import pandas as pd | |
| from sentence_transformers import SentenceTransformer | |
| from sentence_transformers.util import quantize_embeddings | |
| import faiss | |
| from usearch.index import Index | |
# Passage titles and texts, looked up by row id when building search results.
title_text_dataset = load_dataset(
    "mixedbread-ai/wikipedia-2023-11-embed-en-pre-1",
    split="train",
).select_columns(["title", "text"])

# The int8 embeddings are only read back for rescoring and are never searched,
# so they are opened as a view (to save memory) instead of being fully loaded.
int8_view = Index.restore("wikipedia_int8_usearch_1m.index", view=True)

# Binary (packed-bit) faiss index that is actually searched for candidates.
binary_index: faiss.IndexBinaryFlat = faiss.read_index_binary("wikipedia_ubinary_faiss_1m.index")

# Instruction prefix that default_prompt_name="retrieval" applies to each encode() call.
_retrieval_prompts = {
    "retrieval": "Represent this sentence for searching relevant passages: ",
}

# Query encoder; only queries are ever embedded at runtime.
model = SentenceTransformer(
    "mixedbread-ai/mxbai-embed-large-v1",
    prompts=_retrieval_prompts,
    default_prompt_name="retrieval",
)
def search(query, top_k: int = 10, rerank_multiplier: int = 4):
    """Return the ``top_k`` passages most relevant to ``query`` plus timings.

    Two-stage retrieval:
    the query embedding is binarised and used to fetch
    ``top_k * rerank_multiplier`` candidates from the binary faiss index.
    Those candidates are then rescored with the float32 query embedding
    against their int8 document embeddings.

    Args:
        query: Free-text search query.
        top_k: Number of results to return.
        rerank_multiplier: Candidates fetched from the binary index per
            returned result, before rescoring.

    Returns:
        A ``(DataFrame, dict)`` pair.
        The DataFrame has ``Score``, ``Title`` and ``Text`` columns, ordered
        by descending score.
        The dict maps each stage name to its duration formatted in seconds.
        "Total Retrieval Time" excludes the embedding stage.
    """
    # 1. Embed the query as float32
    start = time.perf_counter()
    query_embedding = model.encode(query)
    embed_time = time.perf_counter() - start

    # 2. Quantize the query to ubinary. quantize_embeddings and faiss both
    # operate on a 2-D (n_queries, dim) batch, so wrap the single vector.
    start = time.perf_counter()
    query_embedding_ubinary = quantize_embeddings(query_embedding.reshape(1, -1), "ubinary")
    quantize_time = time.perf_counter() - start

    # 3. Search the binary index for top_k * rerank_multiplier candidates
    start = time.perf_counter()
    _scores, binary_ids = binary_index.search(query_embedding_ubinary, top_k * rerank_multiplier)
    binary_ids = binary_ids[0]
    # faiss pads missing results with -1; never treat those as document ids.
    binary_ids = binary_ids[binary_ids >= 0]
    search_time = time.perf_counter() - start

    # 4. Load the corresponding int8 embeddings
    start = time.perf_counter()
    int8_embeddings = int8_view[binary_ids].astype(int)
    load_time = time.perf_counter() - start

    # 5. Rerank the candidates using the float32 query embedding and the
    # int8 document embeddings
    start = time.perf_counter()
    scores = query_embedding @ int8_embeddings.T
    rerank_time = time.perf_counter() - start

    # 6. Order candidates by descending score and keep the best top_k.
    # argsort of the negated scores puts the highest score first, so the
    # winners are the leading slice.
    start = time.perf_counter()
    best = (-scores).argsort()[:top_k]
    # Fetch each dataset row once rather than once per column.
    records = [title_text_dataset[doc_id] for doc_id in binary_ids[best].tolist()]
    # Column lists (not zip(*...)) so an empty result still yields a valid frame.
    df = pd.DataFrame({
        "Score": [round(float(value), 2) for value in scores[best]],
        "Title": [record["title"] for record in records],
        "Text": [record["text"] for record in records],
    })
    sort_time = time.perf_counter() - start

    return df, {
        "Embed Time": f"{embed_time:.4f} s",
        "Quantize Time": f"{quantize_time:.4f} s",
        "Search Time": f"{search_time:.4f} s",
        "Load Time": f"{load_time:.4f} s",
        "Rerank Time": f"{rerank_time:.4f} s",
        "Sort Time": f"{sort_time:.4f} s",
        "Total Retrieval Time": f"{quantize_time + search_time + load_time + rerank_time + sort_time:.4f} s"
    }
# Gradio UI:
# a query box and search button above a results table and a per-stage timing panel.
with gr.Blocks(title="Quantized Retrieval") as demo:
    query = gr.Textbox(label="Query")
    search_button = gr.Button(value="Search")
    with gr.Row():
        with gr.Column(scale=4):
            # Score / Title / Text columns; widths add up to the full 100%.
            output = gr.Dataframe(column_widths=["10%", "20%", "70%"], headers=["Score", "Title", "Text"])
        with gr.Column(scale=1):
            # Receives the timing dict that search() returns as its second value.
            json = gr.JSON()
    # Searching works from the button or by pressing Enter in the query box.
    search_button.click(search, inputs=[query], outputs=[output, json])
    query.submit(search, inputs=[query], outputs=[output, json])

demo.queue()
demo.launch(debug=True)