Spaces:
Sleeping
Sleeping
| import time | |
| import gradio as gr | |
| from datasets import load_dataset | |
| import pandas as pd | |
| from sentence_transformers import SentenceTransformer | |
| from sentence_transformers.quantization import quantize_embeddings | |
| import faiss | |
| from usearch.index import Index | |
| import datetime | |
# Corpus metadata (url / title / text) used to render the result table.
title_text_dataset = (
    load_dataset("suanan/BP_CBG_POC", split="train", num_proc=4)
    .select_columns(["url", "title", "text"])
)

# Quantized document indices. The int8 index is restored as a view to save
# memory: it is only used to look up stored vectors for rescoring and is never
# searched directly.
int8_view = Index.restore("index/BP_CBG_int8_usearch_1m_v2.index", view=True)
binary_index: faiss.IndexBinaryFlat = faiss.read_index_binary(
    "index/BP_CBG_ubinary_faiss_1m_v2.index"
)
# The approximate (IVF) binary index is currently not loaded:
# binary_ivf: faiss.IndexBinaryIVF = faiss.read_index_binary("BP_ubinary_ivf_faiss_50m.index")

# Query encoder. The "retrieval" prompt is applied to every encoded query by
# default.
# NOTE(review): the prompt text is English while the checkpoint name contains
# "zh" -- confirm this matches how the document index was built.
_QUERY_PROMPTS = {
    "retrieval": "Represent this sentence for searching relevant passages: ",
}
model = SentenceTransformer(
    "BAAI/bge-large-zh-v1.5",
    prompts=_QUERY_PROMPTS,
    default_prompt_name="retrieval",
)
def search(query, top_k: int = 100, rescore_multiplier: int = 1, use_approx: bool = False):
    """Run a binary-index search with int8 rescoring and report stage timings.

    Pipeline:
      1. Embed ``query`` with the module-level ``model``.
      2. Quantize the embedding to ``ubinary``.
      3. Fetch ``top_k * rescore_multiplier`` candidates from ``binary_index``.
      4. Look up the candidates' int8 embeddings in ``int8_view``.
      5. Rescore candidates with a dot product against the unquantized query.
      6. Keep the ``top_k`` best and join them with ``title_text_dataset``.

    Args:
        query: Search text entered by the user.
        top_k: Number of results to return.
        rescore_multiplier: Oversampling factor for the binary search stage.
        use_approx: Currently ignored -- only the exact binary index is
            loaded, so the search is always exact.

    Returns:
        A ``(DataFrame, dict)`` pair. The DataFrame has the columns ``Score``,
        ``Url``, ``Title`` and ``Text`` (empty when nothing was found). The
        dict maps stage names to formatted durations; "Total search Time"
        excludes the embedding time.
    """
    # Coerce defensively: the search ``k`` and the slice bound below must be ints.
    top_k = int(top_k)
    rescore_multiplier = int(rescore_multiplier)

    # Log the current time together with the incoming query.
    now = datetime.datetime.now()
    print(f"當前時間: {now}, 問題: {query}")

    # 1. Embed the query as float32
    start_time = time.time()
    query_embedding = model.encode(query)
    embed_time = time.time() - start_time

    # 2. Quantize the query to ubinary
    start_time = time.time()
    query_embedding_ubinary = quantize_embeddings(query_embedding.reshape(1, -1), "ubinary")
    quantize_time = time.time() - start_time

    # 3. Search the binary index (exact only for now)
    # index = binary_ivf if use_approx else binary_index
    index = binary_index
    start_time = time.time()
    _scores, binary_ids = index.search(query_embedding_ubinary, top_k * rescore_multiplier)
    binary_ids = binary_ids[0]
    # FAISS pads the label array with -1 when fewer than k hits exist. Drop
    # the padding so it is not used as an embedding key or as a (negative)
    # dataset row index, which would silently return the last row.
    binary_ids = binary_ids[binary_ids >= 0]
    search_time = time.time() - start_time

    if binary_ids.size:
        # 4. Load the corresponding int8 embeddings
        start_time = time.time()
        int8_embeddings = int8_view[binary_ids].astype(int)
        load_time = time.time() - start_time

        # 5. Rescore the candidates using the float32 query embedding and
        #    the int8 document embeddings
        start_time = time.time()
        scores = query_embedding @ int8_embeddings.T
        rescore_time = time.time() - start_time

        # 6. Sort the scores (descending) and keep the top_k
        start_time = time.time()
        indices = scores.argsort()[::-1][:top_k]
        top_k_ids = binary_ids[indices].tolist()
        top_k_scores = [round(value, 2) for value in scores[indices]]
    else:
        # Nothing to load or rescore; fall through to an empty result table.
        load_time = rescore_time = 0.0
        start_time = time.time()
        top_k_ids, top_k_scores = [], []

    # Fetch each dataset row once instead of once per displayed column.
    rows = [title_text_dataset[idx] for idx in top_k_ids]
    df = pd.DataFrame(
        {
            "Score": top_k_scores,
            "Url": [row["url"] for row in rows],
            "Title": [row["title"] for row in rows],
            "Text": [row["text"] for row in rows],
        }
    )
    sort_time = time.time() - start_time

    return df, {
        "Embed Time": f"{embed_time:.4f} s",
        "Quantize Time": f"{quantize_time:.4f} s",
        "Search Time": f"{search_time:.4f} s",
        "Load Time": f"{load_time:.4f} s",
        "Rescore Time": f"{rescore_time:.4f} s",
        "Sort Time": f"{sort_time:.4f} s",
        "Total search Time": f"{quantize_time + search_time + load_time + rescore_time + sort_time:.4f} s",
    }
def update_info(value):
    """Return the caption that echoes the currently selected result count."""
    caption = f"{value}筆顯示出來"
    return caption
# Gradio UI: query box, search options, result table and timing panel.
# NOTE(review): the user-facing text below mentions BAAI/bge-m3, while this
# file loads "BAAI/bge-large-zh-v1.5" for query encoding -- confirm which one
# is intended.
with gr.Blocks(title="") as demo:
    gr.Markdown(
        """
## 官網 Dataset & opensource model BAAI/bge-m3
### v1 測試POC
Details:
1. 中文搜尋ok,英文像是:iphone 15,embedding的時候沒有轉成小寫,需要 寫成iPhone才可以準確搜尋到
2. 環境資源: python 3.10, linux: ubuntu 22.04, only cpu, ram max:7.7GB min:4.5GB 使用以上資源
3.
建立步驟:
1. excel 轉成 [dataset](https://huggingface.co/datasets/suanan/BP_POC) [CBG_dataset](https://huggingface.co/datasets/suanan/BP_CBG_POC), 花費約10秒內
2. dataset 內 轉成 title & text 做 embedding,以後可以新增keyword來加強搜尋出來的結果排序往前
3. 之後透過 Quantized Retrieval - Binary Search solution進行搜尋
"""
    )

    with gr.Row():
        with gr.Column(scale=75):
            query = gr.Textbox(
                label="官網 Dataset & opensource model BAAI/bge-m3, v1 測試POC",
                placeholder="輸入搜尋關鍵字或問句",
            )
        with gr.Column(scale=25):
            # NOTE(review): `search` currently ignores this flag because only
            # the exact binary index is loaded.
            use_approx = gr.Radio(
                choices=[("精確搜尋", False), ("相關搜尋", True)],
                value=False,
                label="搜尋方法",
            )

    with gr.Row():
        with gr.Column(scale=2):
            top_k = gr.Slider(
                minimum=10,
                maximum=1000,
                step=5,
                value=100,
                label="顯示搜尋前幾筆",
            )
            # Read-only caption that echoes the slider value.
            info_text = gr.Textbox(value=update_info(top_k.value), interactive=False)
        with gr.Column(scale=2):
            rescore_multiplier = gr.Slider(
                minimum=1,
                maximum=10,
                step=1,
                value=1,
                label="Rescore multiplier",
                info="Search for `rescore_multiplier` as many documents to rescore",
            )

    search_button = gr.Button(value="Search")

    with gr.Row():
        with gr.Column(scale=4):
            # Headers mirror the columns of the DataFrame returned by `search`
            # (Score, Url, Title, Text).
            output = gr.Dataframe(headers=["Score", "Url", "Title", "Text"])
        with gr.Column(scale=1):
            # Named `timings_json` rather than `json` to avoid shadowing the
            # standard-library module name.
            timings_json = gr.JSON()

    # Keep the caption in sync with the slider.
    top_k.change(fn=update_info, inputs=top_k, outputs=info_text)

    # Pressing Enter in the textbox and clicking the button run the same search.
    search_inputs = [query, top_k, rescore_multiplier, use_approx]
    search_outputs = [output, timings_json]
    query.submit(search, inputs=search_inputs, outputs=search_outputs)
    search_button.click(search, inputs=search_inputs, outputs=search_outputs)

demo.queue()
demo.launch(share=True)