# SearchDataset — Hugging Face Space entry point (app.py).
# (Space file-viewer metadata removed from this header; it was not valid Python.)
import gradio as gr
import polars as pl
import bm25s
import Stemmer
import stopwordsiso
# Display-name -> ISO 639-1 code for the languages the datasets cover.
favourite_langs = {"English": "en", "Romanian": "ro", "German": "de"}
# Dropdown choices, derived from the mapping (dict preserves insertion order).
options = list(favourite_langs.keys())
# Dataset configs available on the Hugging Face hub.
models = ['ENRO', 'DERO']
def type_search(input_text, sselected_language, tselected_language, model_name, hits=10, search_type="Similarity", toggle_case=True):
    """Dispatch a query to plain word search or BM25 best-match search.

    "Word search" routes to :func:`search_text`; anything else (the default
    "Best match search") routes to :func:`similarity_search`. All arguments
    are forwarded unchanged and the handler's (results, message) pair is
    returned as-is.
    """
    handler = search_text if search_type == "Word search" else similarity_search
    return handler(input_text, sselected_language, tselected_language, model_name, hits, search_type, toggle_case)
# English, Romanian
def search_text(input_text, sselected_language, tselected_language, model_name, hits=10, search_type="Similarity", toggle_case=True):
    """Literal substring search over the parallel dataset.

    Downloads the requested dataset config as parquet, keeps rows whose
    source-language column contains ``input_text``, and returns up to
    ``hits`` (source, target) rows plus a status message.

    Parameters:
        input_text: fragment to look for — treated literally, not as a regex.
        sselected_language: source column name, e.g. "English".
        tselected_language: target column name, e.g. "Romanian".
        model_name: dataset config, e.g. "ENRO" (lower-cased into the URL).
        hits: maximum number of rows returned.
        search_type: unused here; kept so all search handlers share one signature.
        toggle_case: True for case-sensitive matching.

    Returns:
        (polars.DataFrame with the two selected columns, status message string)
    """
    path_to_model = f"https://huggingface.co/api/datasets/TiberiuCristianLeon/2RO/parquet/{model_name.lower()}/train/0.parquet"
    df = pl.read_parquet(path_to_model)
    # Bug fix: the query used to be interpreted as a regex, so user input
    # containing "(", "*", "?" etc. raised a ComputeError instead of matching.
    if toggle_case:
        filtered = df.filter(pl.col(sselected_language).str.contains(input_text, literal=True))
    else:
        # Case-insensitive: lower-case both sides and compare literally.
        filtered = df.filter(
            pl.col(sselected_language).str.to_lowercase().str.contains(input_text.lower(), literal=True)
        )
    list_of_arrays = filtered.select([sselected_language, tselected_language]).head(hits)
    # len() on a DataFrame is the row count, so the message reports shown rows.
    message_text = f'Done! Found {len(list_of_arrays)} entries'
    return list_of_arrays, message_text
def similarity_search(input_text, sselected_language, tselected_language, model_name, hits=10, search_type="Similarity", toggle_case=True):
    """BM25 best-match search over the parallel dataset.

    Downloads the requested dataset config as parquet, builds aligned
    source/target sentence lists, and ranks the source corpus against
    ``input_text`` with BM25 via :func:`bmretriever`.

    Parameters:
        input_text: free-text query.
        sselected_language: source column name, e.g. "English".
        tselected_language: target column name, e.g. "Romanian".
        model_name: dataset config, e.g. "ENRO" (lower-cased into the URL).
        hits: number of best matches to return.
        search_type: unused here; kept so all search handlers share one signature.
        toggle_case: unused here (BM25 tokenization is case-insensitive).

    Returns:
        (list of (source, translation) tuples, status message string)
    """
    path_to_model = f"https://huggingface.co/api/datasets/TiberiuCristianLeon/2RO/parquet/{model_name.lower()}/train/0.parquet"
    df = pl.read_parquet(path_to_model)
    # Bug fix: the subset previously listed the source column twice, so null
    # TARGET cells were never dropped and could surface as None translations.
    df = df.drop_nulls(subset=[sselected_language, tselected_language])
    source_corpus = df.get_column(sselected_language).to_list()
    target_corpus = df.get_column(tselected_language).to_list()
    # Keep only pairs where both sides are non-empty so BM25 indexes real text
    # (the original comment promised this filter but never applied it).
    valid_entries = [(src, tgt) for src, tgt in zip(source_corpus, target_corpus) if src and tgt]
    filtered_source = [src for src, _ in valid_entries]
    filtered_target = [tgt for _, tgt in valid_entries]
    # One persisted index per language direction.
    index_name = f"index{sselected_language}{tselected_language}"
    list_of_arrays = bmretriever(filtered_source, input_text, sselected_language, filtered_target, index_name, hits)
    message_text = f'Done! Found {len(list_of_arrays)} entries'
    return list_of_arrays, message_text
def bmretriever(corpus, query, sselected_language, translations, index_name, k=10):
    """Rank *corpus* against *query* with BM25 and pair hits with translations.

    Parameters:
        corpus: list of source-language sentences to index.
        query: free-text query string.
        sselected_language: language display name (e.g. "English"); selects the
            Snowball stemmer and the stopword list.
        translations: target-language sentences aligned 1:1 with *corpus*.
        index_name: directory name used to persist/load the BM25 index.
        k: number of hits to return.

    Returns:
        List of (source_sentence, translation) string tuples, best match first.
    """
    stemmer = Stemmer.Stemmer(sselected_language.lower())
    try:
        # bm25s only ships stopword lists for some languages; unsupported
        # codes (e.g. "ro") raise ValueError and we fall back to stopwordsiso.
        corpus_tokens = bm25s.tokenize(corpus, stopwords=favourite_langs[sselected_language], stemmer=stemmer)
    except ValueError:
        # Bug fix: stopwordsiso.stopwords is a function, not a mapping —
        # subscripting it raised TypeError before any fallback could happen.
        stopwords = stopwordsiso.stopwords(favourite_langs[sselected_language])
        corpus_tokens = bm25s.tokenize(corpus, stopwords=list(stopwords), stemmer=stemmer)
    try:
        print('Loading saved retriever index')
        # NOTE(review): a previously saved index may be stale relative to the
        # current corpus — consider keying index_name on the dataset revision.
        retriever = bm25s.BM25.load(index_name, load_corpus=True)
    except Exception as loadingerror:
        print(loadingerror)
        retriever = bm25s.BM25()
        retriever.index(corpus_tokens)
        retriever.save(index_name, corpus=corpus)  # persist the corpus with the index
    query_tokens = bm25s.tokenize(query, stemmer=stemmer)
    # Retrieve positions instead of documents: the old corpus.index(doc) lookup
    # was O(n) per hit and paired duplicated sentences with the wrong translation.
    positions, scores = retriever.retrieve(query_tokens, k=k, corpus=list(range(len(corpus))))
    final_results = []
    for rank in range(positions.shape[1]):
        idx = int(positions[0, rank])
        doc, translation = corpus[idx], translations[idx]
        print(f"Rank {rank + 1} (score: {scores[0, rank]:.2f}): {doc}{translation}")
        final_results.append((str(doc), str(translation)))
    return final_results
# Callback wired to the "Swap Languages" button.
def swap_languages(src_lang, tgt_lang):
    """Return the two dropdown selections in reversed order."""
    return (tgt_lang, src_lang)
def create_interface():
    """Build the Gradio Blocks UI: query input, language controls, search
    options, and the results table wired to :func:`type_search`.

    NOTE(review): indentation/nesting below was reconstructed from a
    whitespace-stripped paste — confirm row membership against the original.
    """
    with gr.Blocks() as interface:
        gr.Markdown("## Search Text in Dataset")
        with gr.Row():
            # Pressing Enter in this box also triggers the search (see .submit below).
            input_text = gr.Textbox(label="Enter text to search:", placeholder="Type your text here...", info="Press Enter key to start search")
        with gr.Row():
            sselected_language = gr.Dropdown(choices=options, value = options[0], label="Source language", interactive=True)
            tselected_language = gr.Dropdown(choices=options, value = options[1], label="Target language", interactive=True)
        swap_button = gr.Button("Swap Languages")
        # Feeds both dropdowns back into themselves, reversed.
        swap_button.click(fn=swap_languages, inputs=[sselected_language, tselected_language], outputs=[sselected_language, tselected_language])
        search_type = gr.Radio(["Best match search", "Word search"], value="Best match search", label="Search type", info="Query word(s) or best match search with BM25")
        # Only affects "Word search"; BM25 tokenization ignores this flag.
        toggle_case = gr.Checkbox(info="Toggle case sensitive search", label="Case sensitive search", value=True, interactive=True, visible=True)
        model_name = gr.Dropdown(choices=models, label="Select a dataset", value = models[0], interactive=True)
        search_button = gr.Button("Search")
        # NOTE(review): headers are fixed to the initial language pair and do
        # not follow a swap — confirm whether they should track the dropdowns.
        translated_text = gr.Dataframe(label="Returned entries:", interactive=False, headers=[options[0], options[1]], datatype=["str", "str"], col_count=(2, "fixed"),
                                       wrap=True, show_row_numbers=False, show_copy_button=True)
        message_text = gr.Textbox(label="Messages:", placeholder="Display field for status and error messages", interactive=False)
        hits = gr.Slider(
            minimum=1,
            maximum=100,
            value=10,
            step=5,
            label="Number of returned hits")
        # Button click and Enter in the textbox run the same dispatch.
        search_button.click(
            type_search,
            inputs=[input_text, sselected_language, tselected_language, model_name, hits, search_type, toggle_case],
            outputs=[translated_text, message_text]
        )
        # Submit the form when Enter is pressed in the input_text textbox
        input_text.submit(
            type_search,
            inputs=[input_text, sselected_language, tselected_language, model_name, hits, search_type, toggle_case],
            outputs=[translated_text, message_text]
        )
    return interface
if __name__ == "__main__":
    # Build the UI and start the Gradio server when run as a script.
    app = create_interface()
    app.launch()