Spaces:
Build error
Build error
| # -*- coding: utf-8 -*- | |
| # Copyright (c) Louis Brulé Naudet. All Rights Reserved. | |
| # This software may be used and distributed according to the terms of the License Agreement. | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import gradio as gr | |
| import polars as pl | |
| import spaces | |
| import torch | |
| from typing import Tuple, List, Union | |
| from dataset import Dataset | |
| from similarity_search import SimilaritySearch | |
def setup(
    description: str,
    model_name: str,
    device: str,
    ndim: int,
    metric: str,
    dtype: str
) -> Tuple:
    """
    Build the similarity-search engine and load the data the demo relies on.

    A ``SimilaritySearch`` object is created from the given settings, its
    two on-disk indices are loaded, and the dataset is loaded and converted
    to a Polars DataFrame.

    Parameters
    ----------
    description : str
        Text handed back unchanged as the last element of the result.
    model_name : str
        Name of the pre-trained model, forwarded to ``SimilaritySearch``.
    device : str
        Device forwarded to ``SimilaritySearch``, e.g., 'cuda' or 'cpu'.
    ndim : int
        Dimensionality forwarded to ``SimilaritySearch``.
    metric : str
        Similarity metric forwarded to ``SimilaritySearch``.
    dtype : str
        Data type forwarded to ``SimilaritySearch``.

    Returns
    -------
    tuple
        ``(instance, dataset, dataframe, description)``: the configured
        ``SimilaritySearch`` object, the object returned by
        ``Dataset.load``, its Polars conversion, and the untouched
        ``description`` argument.

    Raises
    ------
    RuntimeError
        If any step fails; the original exception is chained as the cause.
    """
    try:
        engine = SimilaritySearch(
            model_name=model_name,
            device=device,
            ndim=ndim,
            metric=metric,
            dtype=dtype,
        )

        # Both indices are read from files sitting next to the application.
        engine.load_usearch_index_view(index_path="./usearch_int8.index")
        engine.load_faiss_index(index_path="./faiss_ubinary.index")

        records = Dataset.load(dataset_path="./legalkit.hf")
        frame = Dataset.convert_to_polars(dataset=records)
    except Exception as exc:
        # Surface every setup failure under one exception type while keeping
        # the root cause attached for debugging.
        raise RuntimeError(
            f"An error occurred during setup: {str(exc)}"
        ) from exc

    return engine, records, frame, description
# Markdown rendered at the top of the demo page.
DESCRIPTION = """\
# LegalKit Retrieval, a binary Search with Scalar (int8) Rescoring through French legal codes
This space showcases the [tsdae-lemone-mbert-base](https://huggingface.co/louisbrulenaudet/tsdae-lemone-mbert-base)
model by Louis Brulé Naudet, a sentence embedding model based on BERT fitted using Transformer-based Sequential Denoising Auto-Encoder for unsupervised sentence embedding learning with one objective : french legal domain adaptation.
This process is designed to be memory efficient and fast, with the binary index being small enough to fit in memory and the int8 index being loaded as a view to save memory.
Additionally, the binary index is much faster (up to 32x) to search than the float32 index, while the rescoring is also extremely efficient.
"""

# Runs once at import time: builds the search engine and loads the dataset.
# ``setup`` hands the description back unchanged, so DESCRIPTION is rebound
# to the same text.
instance, dataset, dataframe, DESCRIPTION = setup(
    description=DESCRIPTION,
    model_name="louisbrulenaudet/tsdae-lemone-mbert-base",
    device="cpu",
    ndim=768,
    metric="ip",
    dtype="i8",
)
def search(
    query: str,
    top_k: int,
    rescore_multiplier: int
) -> "gr.Dataframe":
    """
    Run a similarity search for ``query`` and return the hits as a table.

    The module-level ``instance`` performs the search; the returned indices
    are then matched against the module-level ``dataframe`` to fetch the
    corresponding rows.

    Parameters
    ----------
    query : str
        The query for which similarity search is performed.
    top_k : int
        The number of top results to return.
    rescore_multiplier : int
        Multiplier forwarded to ``instance.search`` for the rescoring step.

    Returns
    -------
    gr.Dataframe
        A visible Gradio dataframe holding the columns ``score``, ``input``,
        ``output``, ``start`` and ``expiration``, sorted by descending score.

    Examples
    --------
    >>> results = search(query="example query", top_k=10, rescore_multiplier=2)
    """
    # ``instance`` and ``dataframe`` are only read here, so no ``global``
    # declaration is required.
    top_k_scores, top_k_indices = instance.search(
        query=query,
        top_k=top_k,
        rescore_multiplier=rescore_multiplier
    )

    # NOTE(review): the cast presumably aligns the key dtype with the
    # ``index`` column of ``dataframe`` so the join below matches — confirm.
    scores_df = pl.DataFrame(
        {
            "index": top_k_indices,
            "score": top_k_scores
        }
    ).with_columns(
        pl.col("index").cast(pl.UInt32)
    )

    # Keep only the rows that were hit, attach their scores, and order the
    # best matches first.
    results_df = dataframe.filter(
        pl.col("index").is_in(top_k_indices)
    ).join(
        scores_df,
        how="inner",
        on="index"
    ).sort(
        by="score",
        descending=True
    ).select(
        [
            "score",
            "input",
            "output",
            "start",
            "expiration"
        ]
    )

    # The output component starts hidden; returning it with visible=True
    # reveals it once results exist.
    return gr.Dataframe(
        value=results_df,
        visible=True
    )
# NOTE(review): the nesting below was reconstructed from a flattened copy of
# the file — confirm the layout against the running app.
with gr.Blocks(title="Quantized Retrieval") as demo:
    gr.Markdown(value=DESCRIPTION)
    gr.DuplicateButton()

    # Free-text query box.
    with gr.Row():
        with gr.Column():
            query = gr.Textbox(
                label="Query for French legal codes articles",
                placeholder="Enter a query to search for relevant texts from the French law.",
            )

    # Search parameters, side by side.
    with gr.Row():
        with gr.Column(scale=2):
            top_k = gr.Slider(
                minimum=1,
                maximum=100,
                step=1,
                value=20,
                label="Number of documents to retrieve",
                info="Number of documents to retrieve from the binary search.",
            )

        with gr.Column(scale=2):
            rescore_multiplier = gr.Slider(
                minimum=1,
                maximum=10,
                step=1,
                value=4,
                label="Rescore multiplier",
                info="Search for 'rescore_multiplier' as many documents to rescore.",
            )

    search_button = gr.Button(value="Search")

    # Results table; hidden until ``search`` returns a visible dataframe.
    with gr.Row():
        with gr.Column():
            output = gr.Dataframe(
                visible=False,
                type="polars"
            )

    # Pressing Enter in the textbox and clicking the button both run the
    # same search with the same inputs and output.
    for trigger in (query.submit, search_button.click):
        trigger(
            search,
            inputs=[
                query,
                top_k,
                rescore_multiplier
            ],
            outputs=output
        )

if __name__ == "__main__":
    demo.queue().launch(
        show_api=False
    )