Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import polars as pl | |
| import bm25s | |
| import Stemmer | |
| import stopwordsiso | |
# Display name -> language code; the code selects stopword lists for BM25 tokenization.
favourite_langs = dict(English="en", Romanian="ro", German="de")
# Dropdown choices, in the same order as favourite_langs.
options = [*favourite_langs]
# Dataset configurations on the Hub; lower-cased when building the parquet URL.
models = ["ENRO", "DERO"]
def type_search(input_text, sselected_language, tselected_language, model_name, hits=10, search_type="Similarity", toggle_case=True):
    """Route a query to the search backend chosen in the UI.

    "Word search" runs the substring filter (``search_text``); any other
    value of ``search_type``, including the UI's "Best match search", runs
    the BM25 ranking (``similarity_search``). All arguments are forwarded
    unchanged and the backend's ``(rows, message)`` result is returned.
    """
    wants_word_search = search_type == "Word search"
    backend = search_text if wants_word_search else similarity_search
    return backend(input_text, sselected_language, tselected_language, model_name, hits, search_type, toggle_case)
def search_text(input_text, sselected_language, tselected_language, model_name, hits=10, search_type="Similarity", toggle_case=True):
    """Return up to ``hits`` sentence pairs whose source text contains ``input_text``.

    The query is matched as a plain substring, not as a regular expression,
    so characters such as ``(``, ``?`` or ``*`` typed by the user are searched
    for literally instead of raising a regex error or matching unintended rows.

    Args:
        input_text: Text typed by the user.
        sselected_language: Source column name (e.g. "English") to search in.
        tselected_language: Target column name returned alongside the source.
        model_name: Dataset configuration; lower-cased to build the parquet URL.
        hits: Maximum number of rows to return.
        search_type: Unused here; accepted for a uniform signature with
            ``similarity_search``.
        toggle_case: True for case-sensitive matching, False to ignore case.

    Returns:
        ``(rows, message)``: a polars DataFrame with the source and target
        columns, and a status string.
    """
    path_to_model = f"https://huggingface.co/api/datasets/TiberiuCristianLeon/2RO/parquet/{model_name.lower()}/train/0.parquet"
    df = pl.read_parquet(path_to_model)
    # Coerce in case the Gradio slider delivers a float; head() needs an int.
    limit = int(hits)
    column = pl.col(sselected_language)
    if toggle_case:
        condition = column.str.contains(input_text, literal=True)
    else:
        # Lower-case both sides so the literal match ignores case without regex syntax.
        condition = column.str.to_lowercase().str.contains(input_text.lower(), literal=True)
    list_of_arrays = df.filter(condition).select([sselected_language, tselected_language]).head(limit)
    print(toggle_case, list_of_arrays)
    message_text = f'Done! Found {len(list_of_arrays)} entries'
    return list_of_arrays, message_text
def similarity_search(input_text, sselected_language, tselected_language, model_name, hits=10, search_type="Similarity", toggle_case=True):
    """Rank sentence pairs by BM25 relevance of their source text to ``input_text``.

    Args:
        input_text: Free-text query typed by the user.
        sselected_language: Source column name (e.g. "English"); also selects
            the stemmer and stopwords inside ``bmretriever``.
        tselected_language: Target column name returned alongside each match.
        model_name: Dataset configuration; lower-cased to build the parquet URL.
        hits: Maximum number of pairs to return.
        search_type: Unused here; accepted for a uniform signature with ``search_text``.
        toggle_case: Unused here; accepted for a uniform signature with ``search_text``.

    Returns:
        ``(rows, message)``: a list of ``(source, target)`` string tuples from
        ``bmretriever`` and a status string.
    """
    if not (input_text or "").strip():
        # Nothing to rank; skip the dataset download and index work entirely.
        return [], 'Please enter text to search.'
    path_to_model = f"https://huggingface.co/api/datasets/TiberiuCristianLeon/2RO/parquet/{model_name.lower()}/train/0.parquet"
    df = pl.read_parquet(path_to_model)
    # Drop rows missing either side. The source column used to be listed twice,
    # so rows with a null target survived and were rendered as the text "None".
    df = df.drop_nulls(subset=[sselected_language, tselected_language])
    source_corpus = df.get_column(sselected_language).to_list()
    target_corpus = df.get_column(tselected_language).to_list()
    # Keep only pairs where both sides contain non-whitespace text. Filtering the
    # pairs together keeps source and target index-aligned for bmretriever.
    valid_entries = [
        (src, tgt)
        for src, tgt in zip(source_corpus, target_corpus)
        if str(src).strip() and str(tgt).strip()
    ]
    filtered_source = [src for src, _ in valid_entries]
    filtered_target = [tgt for _, tgt in valid_entries]
    # The saved BM25 index is reused by name, so the name must identify the
    # dataset as well; otherwise an index built for one dataset could be
    # loaded and its hits mapped onto another dataset's rows.
    index_name = f"index{model_name}{sselected_language}{tselected_language}"
    list_of_arrays = bmretriever(filtered_source, input_text, sselected_language, filtered_target, index_name, int(hits))
    message_text = f'Done! Found {len(list_of_arrays)} entries'
    return list_of_arrays, message_text
def bmretriever(corpus, query, sselected_language, translations, index_name, k=10):
    """Return the ``k`` best BM25 matches for ``query``, each paired with its translation.

    A saved index named ``index_name`` is reused when it can be loaded;
    otherwise the corpus is tokenized, indexed and saved under that name.

    Args:
        corpus: Source-language sentences to rank.
        query: Free-text query.
        sselected_language: Language display name (e.g. "English"); lower-cased
            for PyStemmer and mapped through ``favourite_langs`` to the code
            used to pick stopwords.
        translations: Target sentences, index-aligned with ``corpus``.
        index_name: Directory name of the saved index. It must uniquely
            identify ``corpus``: hits from an index built on different data are
            mapped onto ``corpus`` by position and would be wrong.
        k: Maximum number of results; clamped to ``len(corpus)``.

    Returns:
        List of ``(source, translation)`` string tuples, best match first.
    """
    # bm25s raises when k exceeds the number of indexed documents; an empty
    # corpus simply has no hits.
    k = min(int(k), len(corpus))
    if k <= 0:
        return []
    stemmer = Stemmer.Stemmer(sselected_language.lower())
    retriever = None
    try:
        print('Loading saved retriever index')
        # The corpus is passed to retrieve() explicitly, so the saved copy is not needed.
        retriever = bm25s.BM25.load(index_name, load_corpus=False)
    except Exception as loadingerror:
        # Best effort: any load failure (e.g. no saved index yet) triggers a rebuild.
        print(loadingerror)
    if retriever is None:
        # Tokenize only when building; a loaded index already holds the tokens.
        lang_code = favourite_langs[sselected_language]
        try:
            corpus_tokens = bm25s.tokenize(corpus, stopwords=lang_code, stemmer=stemmer)
        except ValueError:
            # bm25s rejected the stopword code; use the stopwordsiso list instead.
            stopwords = stopwordsiso.stopwords[lang_code]
            corpus_tokens = bm25s.tokenize(corpus, stopwords=stopwords, stemmer=stemmer)
        retriever = bm25s.BM25()
        retriever.index(corpus_tokens)
        retriever.save(index_name, corpus=corpus)  # Save the corpus along with the model
    query_tokens = bm25s.tokenize(query, stemmer=stemmer)
    # Retrieve row positions rather than texts by passing a corpus of indices.
    # Looking the translation up by position replaces corpus.index(doc), which
    # was O(n) per hit and always picked the first duplicate's translation.
    doc_ids, scores = retriever.retrieve(query_tokens, k=k, corpus=list(range(len(corpus))))
    final_results = []
    for rank in range(doc_ids.shape[1]):
        row = int(doc_ids[0, rank])
        doc, translation = corpus[row], translations[row]
        final_results.append((str(doc), str(translation)))
        print(f"Rank {rank + 1} (score: {scores[0, rank]:.2f}): {doc} → {translation}")
    return final_results
def swap_languages(src_lang, tgt_lang):
    """Return the two dropdown values in reverse order: target first, then source."""
    swapped = (tgt_lang, src_lang)
    return swapped
def create_interface():
    """Build the Gradio Blocks UI for searching the parallel-sentence datasets.

    Both the Search button and pressing Enter in the query box call
    ``type_search`` with the same inputs and write to the same outputs.

    Returns:
        The un-launched ``gr.Blocks`` app; the caller invokes ``launch()``.
    """
    with gr.Blocks() as interface:
        gr.Markdown("## Search Text in Dataset")
        with gr.Row():
            input_text = gr.Textbox(label="Enter text to search:", placeholder="Type your text here...", info="Press Enter key to start search")
        with gr.Row():
            sselected_language = gr.Dropdown(choices=options, value=options[0], label="Source language", interactive=True)
            tselected_language = gr.Dropdown(choices=options, value=options[1], label="Target language", interactive=True)
            swap_button = gr.Button("Swap Languages")
            swap_button.click(fn=swap_languages, inputs=[sselected_language, tselected_language], outputs=[sselected_language, tselected_language])
        search_type = gr.Radio(["Best match search", "Word search"], value="Best match search", label="Search type", info="Query word(s) or best match search with BM25")
        toggle_case = gr.Checkbox(info="Toggle case sensitive search", label="Case sensitive search", value=True, interactive=True, visible=True)
        model_name = gr.Dropdown(choices=models, label="Select a dataset", value=models[0], interactive=True)
        search_button = gr.Button("Search")
        translated_text = gr.Dataframe(label="Returned entries:", interactive=False, headers=[options[0], options[1]], datatype=["str", "str"], col_count=(2, "fixed"),
                                       wrap=True, show_row_numbers=False, show_copy_button=True)
        message_text = gr.Textbox(label="Messages:", placeholder="Display field for status and error messages", interactive=False)
        # step=1 keeps the default (10) and the maximum (100) on the slider grid;
        # with minimum=1 and step=5 the grid was 1, 6, 11, ..., 96.
        hits = gr.Slider(
            minimum=1,
            maximum=100,
            value=10,
            step=1,
            label="Number of returned hits")
        # Shared wiring for both triggers; order matches type_search's parameters.
        search_inputs = [input_text, sselected_language, tselected_language, model_name, hits, search_type, toggle_case]
        search_outputs = [translated_text, message_text]
        search_button.click(type_search, inputs=search_inputs, outputs=search_outputs)
        # Submit the form when Enter is pressed in the input_text textbox
        input_text.submit(type_search, inputs=search_inputs, outputs=search_outputs)
    return interface
if __name__ == "__main__":
    # Script entry point: assemble the Blocks app and start the Gradio server.
    create_interface().launch()