|
|
from typing import List |
|
|
from dataclasses import asdict |
|
|
import pandas as pd |
|
|
import gradio as gr |
|
|
|
|
|
from SmartSearch.database.chromadb import ChromaDB |
|
|
from SmartSearch.providers.SentenceTransformerEmbedding import SentenceTransformerEmbedding |
|
|
from utils import combine_metadata_with_distance |
|
|
def _build_store(model_name):
    """Return a ChromaDB handle on "books_collection" embedding with *model_name*."""
    embedder = SentenceTransformerEmbedding(model_name=model_name)
    return ChromaDB(embedding_function=embedder, collection_name="books_collection")


# NOTE(review): both stores share one collection name while using different
# embedding models -- confirm the ChromaDB wrapper keeps their vectors apart.

# Store queried by search_novels when the selected model is 'base'.
st_chroma = _build_store('all-mpnet-base-v2')

# Store queried by search_novels for any other model selection.
multilingual_chroma = _build_store('paraphrase-multilingual-mpnet-base-v2')
|
|
|
|
|
|
|
|
def search_novels(query: str, k: float, model_type: str) -> pd.DataFrame:
    """Run a similarity search and return the hits as a DataFrame.

    Args:
        query: Text forwarded to the store as ``query_text``.
        k: Number of results to request. Anything ``int()`` can convert is
            accepted, because Gradio's ``Number`` component delivers floats
            (e.g. ``10.0``) unless it is configured with ``precision=0``.
        model_type: ``'base'`` queries ``st_chroma``; any other value
            queries ``multilingual_chroma``.

    Returns:
        A DataFrame built from the output of
        ``combine_metadata_with_distance`` applied to the search result's
        ``'metadatas'`` and ``'distances'`` entries.

    Raises:
        ValueError: If ``k`` is ``None`` (cleared input field) or cannot be
            converted to an integer.
    """
    # Validate the count up front so a bad value produces a readable message
    # instead of an error from deep inside the vector store.
    if k is None:
        raise ValueError("Items Count is required")
    try:
        n_results = int(k)
    except (TypeError, ValueError, OverflowError) as err:
        raise ValueError(f"Items Count must be a whole number, got {k!r}") from err

    store = st_chroma if model_type == 'base' else multilingual_chroma
    hits = store.search(query_text=query, n_results=n_results)

    rows = combine_metadata_with_distance(hits['metadatas'], hits['distances'])
    return pd.DataFrame(rows)
|
|
|
|
|
# Search UI: a query box, a model selector plus result count, a results table
# and a button that runs search_novels.
with gr.Blocks() as demo:
    with gr.Row():
        # NOTE(review): the placeholder mentions "courses" while the backing
        # collection is "books_collection" -- confirm the intended wording.
        query = gr.Textbox(label="Search Query", placeholder="write a query to find the courses")

    with gr.Row():
        search_type = gr.Dropdown(label="Model", choices=['base', 'multilingual'], value='base')
        # precision=0 makes the component round to the nearest integer and
        # emit an int, so the result count is not passed on as a float.
        k = gr.Number(label="Items Count", value=10, precision=0)

    results = gr.Dataframe(label="Search Results")

    search_button = gr.Button("Search", variant='primary')
    # Input order must match the search_novels(query, k, model_type) signature.
    search_button.click(fn=search_novels, inputs=[query, k, search_type], outputs=results)

# Starts the Gradio app as soon as this module is executed.
demo.launch()