# Hugging Face Spaces status at capture time: Build error
| import streamlit as st | |
| import pandas as pd | |
| from pathlib import Path | |
| #from transformers import MBartForConditionalGeneration, MBart50TokenizerFast | |
| from transformers import M2M100ForConditionalGeneration | |
| from tokenization_small100 import SMALL100Tokenizer | |
| import io | |
# Configure the browser tab and wide layout; Streamlit expects this to be
# the first Streamlit command executed by the script.
st.set_page_config(
    layout="wide",
    page_title="Translation Demo",
    page_icon=":milky_way:",
)
def load_model():
    """Return the SMALL-100 seq2seq model, loading it from the hub only once.

    The loaded model is memoised on the function object (``load_model._model``).
    ``get_translation`` calls this for every text it translates, so without
    the memo a column translation reloaded the full model weights once per row.

    NOTE(review): Streamlit re-executes the script on each interaction, which
    redefines this function and drops the memo; ``st.cache_resource`` (if the
    installed Streamlit provides it) would keep the model across reruns.
    """
    model = getattr(load_model, "_model", None)
    if model is None:
        model = M2M100ForConditionalGeneration.from_pretrained("alirezamsh/small100")
        load_model._model = model
    return model
def get_translation(src_code, trg_code, src):
    """Translate ``src`` into the language ``trg_code`` with SMALL-100.

    Args:
        src_code: Source-language code. Not used here: only the target
            language is configured on the tokenizer. Kept so callers can pass
            both codes.
        trg_code: Target-language code (one of the app's ``valid_languages``,
            e.g. ``"de"``).
        src: The text to translate.

    Returns:
        The list produced by ``tokenizer.batch_decode``; callers in this file
        take element ``[0]``.

    Note:
        Uses the module-level ``tokenizer`` global and sets its ``tgt_lang``
        attribute as a side effect.
    """
    model = load_model()
    # SMALL-100 is steered by the target language configured on the tokenizer.
    tokenizer.tgt_lang = trg_code
    encoded = tokenizer(src, return_tensors="pt")
    generated_tokens = model.generate(**encoded)
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
def open_input(the_file):
    """Parse an uploaded TSV or XLSX file into pandas DataFrame(s).

    Args:
        the_file: File-like object with a ``name`` attribute (such as a
            Streamlit upload). The extension of ``name`` (case-insensitive)
            selects the parser.

    Returns:
        A tuple ``(parsed, sheets)``:

        - For a TSV or a single-sheet XLSX, ``parsed`` is a DataFrame and
          ``sheets`` is ``[]``.
        - For a multi-sheet XLSX, ``parsed`` is a list of DataFrames and
          ``sheets`` holds the matching sheet names in the same order.

    Raises:
        ValueError: If the extension is neither ``.tsv`` nor ``.xlsx``. This
            previously surfaced as an ``UnboundLocalError``.
    """
    name = the_file.name.lower()
    if name.endswith('.tsv'):
        return pd.read_csv(the_file, sep="\t"), []
    if name.endswith('.xlsx'):
        with pd.ExcelFile(the_file) as xlsx:
            sheets = list(xlsx.sheet_names)
            if len(sheets) > 1:
                return [xlsx.parse(sheet) for sheet in sheets], sheets
            # Single sheet: reuse the opened workbook instead of re-reading
            # the upload stream.
            return xlsx.parse(sheets[0]), []
    raise ValueError(
        f"Unsupported file type: {the_file.name!r} (expected .tsv or .xlsx)"
    )
def translate_data(df, s_lang, t_lang, col_for_translation, languages):
    """Return a copy of ``df`` with an added ``"SMALL-100 translation"`` column.

    Args:
        df: DataFrame containing the text column.
        s_lang: Source-language code.
        t_lang: Target-language code.
        col_for_translation: Name of the column whose cells are translated.
        languages: Allowed language codes; both ``s_lang`` and ``t_lang`` must
            be members.

    Returns:
        A new DataFrame; ``df`` itself is not modified. Each row's translation
        column holds the translated text, or ``None`` in these cases:

        - the source cell is missing or blank;
        - the languages are invalid;
        - the column does not exist;
        - the row came after a failed translation.

        The column therefore always has one entry per row.
    """
    new_df = df.copy()  # never mutate the caller's DataFrame

    if s_lang not in languages or t_lang not in languages:
        st.write("Please enter the source text, source language and target language.")
        new_df["SMALL-100 translation"] = None
        return new_df
    if col_for_translation not in new_df.columns:
        # e.g. a later workbook sheet that lacks the column chosen from sheet 1
        st.warning(f"Column {col_for_translation!r} not found; nothing translated.")
        new_df["SMALL-100 translation"] = None
        return new_df

    translated_data = []
    failed = False
    with st.spinner("Translating..."):
        for text in new_df[col_for_translation]:
            # Blank/NaN cells are skipped, and after the first failure the
            # remaining rows are left empty rather than retried.
            if failed or pd.isna(text) or not str(text).strip():
                translated_data.append(None)
                continue
            try:
                translated_data.append(get_translation(s_lang, t_lang, str(text))[0])
            except Exception as err:
                st.subheader("Translation failed :sad:")
                st.error(str(err))
                failed = True
                translated_data.append(None)
    new_df["SMALL-100 translation"] = translated_data
    return new_df
def select_column(data, valid_source, valid_target, is_excel=False):
    """Render widgets to pick a text column and language pair.

    Args:
        data: A DataFrame, or (when ``is_excel`` is true) a list of DataFrames
            whose first element supplies the selectable column names.
        valid_source: Iterable of source-language codes offered to the user.
        valid_target: Iterable of target-language codes offered to the user.
        is_excel: Whether ``data`` is a list of per-sheet DataFrames.

    Returns:
        ``(submitted_cols, src_col, col_src_lang, col_trg_lang)``. When no
        column can be selected (``src_col`` is ``None``), the language values
        are ``None`` and ``submitted_cols`` is ``False``. Previously this case
        raised ``UnboundLocalError``.
    """
    frame = data[0] if is_excel else data
    src_col = st.selectbox(
        'Select the column to translate (WARNING: You can only select a single column - please make sure all columns are named accordingly):',
        list(frame.columns),
    )
    col_src_lang = None
    col_trg_lang = None
    submitted_cols = False
    # `is not None` (not truthiness) so a column named 0 or "" is still usable.
    if src_col is not None:
        col_src_lang = st.selectbox(
            'Source language:',
            valid_source,
        )
        col_trg_lang = st.selectbox(
            'Target language:',
            valid_target,
        )
        submitted_cols = st.button("Translate column")
    return submitted_cols, src_col, col_src_lang, col_trg_lang
st.subheader("SMALL-100 Translator")
source = "In the beginning the Universe was created. This has made a lot of people very angry and been widely regarded as a bad move."
target = ""
# Module-level tokenizer; get_translation() reads this global.
tokenizer = SMALL100Tokenizer.from_pretrained("alirezamsh/small100")
valid_languages = ['de', 'it', 'en', 'fr', 'sw', 'wo']

# --- Free-text translation ---------------------------------------------------
with st.form("my_form"):
    src_lang = st.selectbox(
        'Source language',
        valid_languages,
    )
    trg_lang = st.selectbox(
        'Target language',
        valid_languages,
    )
    source = st.text_area("Source", value=source, height=130, placeholder="Enter the source text...")
    submitted = st.form_submit_button("Translate")

if submitted:
    if source.strip() and src_lang in valid_languages and trg_lang in valid_languages:
        with st.spinner("Translating..."):
            try:
                target = get_translation(src_lang, trg_lang, source)[0]
            except Exception as err:
                st.subheader("Translation failed :sad:")
                st.error(str(err))
            else:
                st.subheader("Translation done!")
                target = st.text_area("Target", value=target, height=130)
    else:
        st.write("Please enter the source text, source language and target language.")

# --- File (column) translation -----------------------------------------------
st.subheader('Input XLSX/TSV')
# Restrict to the formats open_input() understands.
uploaded_file = st.file_uploader("Choose a file", type=["tsv", "xlsx"])
if uploaded_file is not None:
    try:
        data, sheets = open_input(uploaded_file)
    except ValueError as err:  # unsupported type or unparsable content
        st.error(f"Could not read {uploaded_file.name}: {err}")
        st.stop()

    done = False
    if sheets:
        # Multi-sheet workbook: one column/language choice (taken from the
        # first sheet) is applied to every sheet.
        submitted_cols, src_col, src_code, trg_code = select_column(
            data, valid_languages, valid_languages, is_excel=True
        )
        if submitted_cols:
            translated_sheets = [
                translate_data(sheet, src_code, trg_code, src_col, valid_languages)
                for sheet in data
            ]
            done = True
    else:
        submitted_cols, src_col, src_code, trg_code = select_column(
            data, valid_languages, valid_languages
        )
        st.subheader("DataFrame")
        st.write(data)
        st.write(data.describe())
        if submitted_cols:
            new_df = translate_data(data, src_code, trg_code, src_col, valid_languages)
            done = True

    if done:
        st.subheader("Translated DataFrame")
        if sheets:
            buffer = io.BytesIO()
            with pd.ExcelWriter(buffer) as writer:
                for sheet_name, sheet in zip(sheets, translated_sheets):
                    # index=False, matching the TSV export below
                    sheet.to_excel(writer, sheet_name=sheet_name, index=False)
            st.download_button('Download XLSX', buffer.getvalue(), 'translated_file.xlsx', 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet', key='download-xlsx')
        else:
            st.write(new_df)
            st.write(new_df.describe())
            to_dl = new_df.to_csv(index=False, sep='\t').encode('utf-8')
            st.download_button('Download TSV', to_dl, 'translated_file.tsv', 'text/tab-separated-values', key='download-tsv')
else:
    st.info("☝️ Upload a TSV or XLSX file")