Spaces:
Sleeping
Sleeping
zyu
fix: resolved the issue that the input text disappears while generating translation for the first run.
3cb0c3e
| import json | |
| import os | |
| import random | |
| import re | |
| import numpy as np | |
| import streamlit as st | |
| import torch | |
| from transformers import FlaxAutoModelForSeq2SeqLM, AutoTokenizer | |
| import logging | |
# Module-level logger used by load_selected_model and the __main__ block.
# NOTE(review): this file never configures logging (no basicConfig / handlers), so
# whether its INFO messages are visible depends on the hosting environment — confirm.
logger = logging.getLogger(__name__)
@st.cache_resource(show_spinner=False)
def load_model(model_name, tokenizer_name):
    """Load a Flax seq2seq model and its tokenizer, falling back to DEFAULT_MODEL.

    The result is cached with ``st.cache_resource`` (keyed on the two name
    arguments). The model loaded when the user presses "Select Model" is then
    reused by later "Translate" clicks instead of being loaded again on every
    rerun. Callers show their own spinner, so the cache spinner is disabled.

    Args:
        model_name: Hub id or local path of the model weights.
        tokenizer_name: Hub id or local path of the tokenizer.

    Returns:
        ``(model, tokenizer)``.

    Raises:
        RuntimeError: if loading fails with anything other than ``OSError``.
            An ``OSError`` (e.g. model not found) triggers the DEFAULT_MODEL
            fallback instead.
    """
    try:
        model = FlaxAutoModelForSeq2SeqLM.from_pretrained(model_name)
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    except OSError as e:
        st.error(f"Error loading model: {e}")
        st.error(f"Model not found. Use {DEFAULT_MODEL} instead")
        # The fallback replaces both the model and the tokenizer with DEFAULT_MODEL's.
        model = FlaxAutoModelForSeq2SeqLM.from_pretrained(DEFAULT_MODEL)
        tokenizer = AutoTokenizer.from_pretrained(DEFAULT_MODEL)
    except Exception as e:
        st.error(f"Error loading model: {e}")
        # Chain the cause so the original traceback is not lost.
        raise RuntimeError("Error loading model") from e
    return model, tokenizer
def load_json(file_path):
    """Read the UTF-8 encoded JSON file at *file_path* and return the decoded object."""
    with open(file_path, encoding='utf-8') as handle:
        return json.loads(handle.read())
def preprocess(input_text, tokenizer, src_lang, tgt_lang):
    """Tokenize *input_text* into NumPy arrays padded/truncated to MAX_SEQ_LEN tokens.

    ``src_lang`` and ``tgt_lang`` are accepted but currently unused. A T5-style
    task prefix (``f"translate {src_lang} to {tgt_lang}: "``) used to be
    prepended here, but that code was commented out, so the raw text is
    tokenized as-is.
    """
    tokenizer_options = {
        "max_length": MAX_SEQ_LEN,
        "padding": "max_length",
        "truncation": True,
        "return_tensors": "np",
    }
    return tokenizer(input_text, **tokenizer_options)
def translate(input_text, model, tokenizer, src_lang, tgt_lang):
    """Translate *input_text* with beam search (NUM_BEAMS beams).

    The text is encoded via ``preprocess``, decoded without special tokens, and
    the first decoded sequence is returned.
    """
    generated_ids = model.generate(
        **preprocess(input_text, tokenizer, src_lang, tgt_lang),
        num_beams=NUM_BEAMS,
    ).sequences
    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
def hold_deterministic(seed):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) random generators with *seed*.

    NOTE(review): the model in this file is a Flax model (see load_model). Confirm
    whether seeding these generators has any effect on its generation.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
def postprocess(output_text):
    """Remove every ``<extra_id…>`` tag (T5-style sentinel token) from *output_text*."""
    sentinel_tag = re.compile(r"<extra_id[^>]*>")
    return sentinel_tag.sub("", output_text)
def display_ui():
    """Configure the page, draw the title, links and intro text, and create the layout.

    Returns:
        ``(left, right)``: the two columns of a 2-column layout. ``main`` puts the
        model selectors in the left one and the translation input/output in the
        right one.
    """
    st.set_page_config(page_title="DP-NMT DEMO", layout="wide")
    st.title("Neural Machine Translation with DP-SGD")
    badge_links = " ".join([
        "[](https://github.com/trusthlt/dp-nmt)",
        "[](https://aclanthology.org/2024.eacl-demo.11/)",
    ])
    st.write(badge_links)
    st.write("This is a demo for private neural machine translation with DP-SGD.")
    columns = st.columns(2)
    return columns[0], columns[1]
def load_selected_model(config, dataset, language_pair, epsilon, checkpoints_dir=None):
    """Resolve where the model for a dataset / language pair / epsilon choice is loaded from.

    Args:
        config: Parsed dataset/model info. The checkpoint id is looked up at
            ``config[dataset]['languages pairs'][language_pair]['epsilon'][str(epsilon)]``.
        dataset: Selected dataset key.
        language_pair: Selected language-pair key.
        epsilon: Selected privacy budget; converted with ``str()`` for the lookup.
        checkpoints_dir: Root directory of local checkpoints. If omitted, the
            module-level ``CHECKPOINTS_DIR`` is used when it is defined.

    Returns:
        One of three values:
        - the checkpoint id itself, when it contains "privalingo" (a Hub checkpoint);
        - ``<checkpoints_dir>/<ckpt>/<DEFAULT_MODEL basename>``, if that path exists;
        - ``DEFAULT_MODEL`` otherwise, after showing an error in the UI.

    Raises:
        KeyError: if the selection is not present in ``config``.
    """
    ckpt = config[dataset]['languages pairs'][language_pair]['epsilon'][str(epsilon)]
    logger.info("Loading model from %s", ckpt)
    if "privalingo" in ckpt:
        return ckpt  # load model from huggingface hub

    if checkpoints_dir is None:
        # CHECKPOINTS_DIR is not among the constants set in this file's __main__
        # block, so referencing it directly raised NameError for every local
        # checkpoint. Look it up defensively and treat "unset" like "not found".
        checkpoints_dir = globals().get("CHECKPOINTS_DIR")

    model_path = None
    if checkpoints_dir is not None:
        model_name = DEFAULT_MODEL.split('/')[-1]
        model_path = os.path.join(checkpoints_dir, ckpt, model_name)

    if model_path is None or not os.path.exists(model_path):
        st.error(f"Model not found. Using default model: {DEFAULT_MODEL}")
        model_path = DEFAULT_MODEL
    return model_path
def init_session_state():
    """Create the session-state entries the UI relies on and derive the in-progress flags.

    Entries that are missing are created with their initial values; existing
    entries are left untouched. If the "Select Model" or "Translate" button
    (read through its widget key) is pressed in this run, the matching
    ``*_in_progress`` flag is set so ``main`` can disable the widgets.
    """
    # Built fresh on every call, so no mutable dict is shared between sessions.
    initial_values = {
        'model_state': {'loaded': False, 'current_config': None},
        'translate_in_progress': False,
        'load_model_in_progress': False,
        'translation_result': {'input': None, 'output': None},
    }
    for state_key, initial in initial_values.items():
        if state_key not in st.session_state:
            st.session_state[state_key] = initial

    if st.session_state.get("select_model_button", False):
        st.session_state.load_model_in_progress = True
    if st.session_state.get("translate_button", False):
        st.session_state.translate_in_progress = True
def get_translation_result():
    """Return ``(input_text_content, output_text_content)`` for the right-hand panel.

    Reads ``st.session_state.translation_result``. The input falls back to the
    "Enter Text Here" prompt and the output to ``None`` when nothing is stored.
    """
    if "translation_result" not in st.session_state:
        return "Enter Text Here", None
    stored = st.session_state.translation_result
    stored_input = stored['input']
    input_text_content = "Enter Text Here" if stored_input is None else stored_input
    return input_text_content, stored['output']
def set_input_text_content():
    """``on_change`` callback for the "input_text" text area.

    Copies the text area's current value into
    ``st.session_state.translation_result['input']``, where
    ``get_translation_result`` reads it back on later runs.
    """
    if 'input_text' not in st.session_state:
        return
    current_text = st.session_state['input_text']
    st.session_state.translation_result['input'] = current_text
def _model_status_message():
    """Return the line shown under the input box describing the active model.

    A ``None`` ``current_config`` means no model has been confirmed yet, so the
    user is prompted to press "Select Model".
    """
    current_config = st.session_state.model_state['current_config']
    if current_config is None:
        return "Please confirm model selection via the \'Select Model\' Button first!"
    return f"Current Model: {current_config}"


def main():
    """Render the DP-NMT demo page and handle the "Select Model" / "Translate" actions.

    Streamlit re-runs this function from the top on every interaction. The
    ``load_model_in_progress`` and ``translate_in_progress`` session flags (set
    in ``init_session_state``) disable the widgets while an action runs. Both
    flags are always cleared in ``finally`` so an error cannot lock the UI.
    Each action ends with ``st.rerun()`` to redraw the page with the new state.
    """
    hold_deterministic(SEED)
    # The __main__ block already parsed DATASETS_MODEL_INFO_PATH into
    # DATASETS_MODEL_INFO. Reuse it so the widgets and the checkpoint lookup
    # read the same data.
    config = DATASETS_MODEL_INFO
    left, right = display_ui()
    init_session_state()

    with right:
        right_placeholder = st.empty()

    if st.session_state.load_model_in_progress:
        # While a model is loading, redraw the right column with disabled
        # widgets. Otherwise the text area and Translate button from the
        # previous run stay interactive during loading.
        with right_placeholder.container():
            input_text_content, _ = get_translation_result()
            st.text_area("Enter Text", input_text_content, max_chars=MAX_INPUT_LEN, disabled=True)
            st.write(_model_status_message())
            st.button("Translate",
                      disabled=True,
                      use_container_width=True,
                      key="translate_button")

    with left:
        disable = st.session_state.translate_in_progress or st.session_state.load_model_in_progress
        dataset = st.selectbox("Choose a dataset used for fine-tuning", list(DATASETS_MODEL_INFO.keys()), disabled=disable)
        language_pairs_list = list(DATASETS_MODEL_INFO[dataset]["languages pairs"].keys())
        language_pair = st.selectbox("Language pair for translation", language_pairs_list, disabled=disable)
        epsilon_options = list(DATASETS_MODEL_INFO[dataset]['languages pairs'][language_pair]['epsilon'].keys())
        epsilon = st.radio("Select a privacy budget epsilon", epsilon_options, horizontal=True, disabled=disable)
        btn_select_model = st.button(
            "Select Model",
            disabled=disable,
            use_container_width=True,
            key="select_model_button")
        model_status_box = st.empty()

        # "Select Model": load the chosen model, then record it as the current model.
        if btn_select_model:
            st.session_state.load_model_in_progress = True
            current_config = f"{dataset}_{language_pair}_{epsilon}"
            st.session_state.model_state['loaded'] = False
            model_status_box.write("")
            try:
                with st.spinner(f'Loading model trained on {dataset} with epsilon {epsilon}...'):
                    model_path = load_selected_model(config, dataset, language_pair, epsilon)
                    load_model(model_path, tokenizer_name=DEFAULT_MODEL)
                model_status_box.success('Model loaded!')
                st.session_state.model_state['current_config'] = current_config
            finally:
                # Always clear the flag. If loading raised and the flag stayed
                # set, every widget would remain disabled for the rest of the
                # session.
                st.session_state.load_model_in_progress = False
            st.rerun()

    with right_placeholder.container():
        disable = st.session_state.load_model_in_progress or st.session_state.translate_in_progress
        input_text_content, output_text_content = get_translation_result()
        input_text = st.text_area(
            "Enter Text",
            input_text_content,
            max_chars=MAX_INPUT_LEN,
            disabled=disable,
            key="input_text",
            on_change=set_input_text_content,
        )
        st.write(_model_status_message())
        # `disable` already includes translate_in_progress.
        btn_translate = st.button("Translate",
                                  disabled=disable,
                                  use_container_width=True,
                                  key="translate_button")
        result_container = st.empty()
        if output_text_content is not None and not st.session_state.translate_in_progress:
            with result_container.container():
                st.write("**Translation:**")
                # Nest the bordered box inside the current container. Calling
                # result_container.container() again would replace the st.empty
                # slot's content and drop the heading written above.
                output_container = st.container(border=True)
                output_container.write(postprocess(output_text_content))

    # "Translate": load the model confirmed via "Select Model" and translate the input.
    if not st.session_state.select_model_button and st.session_state.translate_button:
        try:
            model_config = st.session_state.model_state['current_config']
            if model_config is None:
                # Translate was clicked before a model was selected. The finally
                # clause clears translate_in_progress; the rerun refreshes the page.
                st.rerun()
            # Split from the right so a dataset name containing "_" stays intact.
            # Assumes language-pair and epsilon keys contain no "_".
            dataset, language_pair, epsilon = model_config.rsplit("_", 2)
            # Take the languages from the loaded model's config rather than the
            # selectbox, which the user may have changed since loading.
            src_lang, tgt_lang = language_pair.split("-")
            model_path = load_selected_model(config, dataset, language_pair, epsilon)
            model, tokenizer = load_model(model_path, tokenizer_name=DEFAULT_MODEL)
            st.session_state.model_state['loaded'] = True
            if btn_translate:
                with right, st.spinner("Translating..."):
                    prediction = translate(input_text, model, tokenizer, src_lang, tgt_lang)
                st.session_state.translation_result['input'] = input_text
                st.session_state.translation_result['output'] = prediction
        finally:
            # Clear the flag on success, on the early rerun, and on errors alike,
            # so a failure cannot leave the whole UI disabled.
            st.session_state.translate_in_progress = False
        st.rerun()
if __name__ == '__main__':
    # Model, generation and UI settings, read as globals by the functions above.
    DEFAULT_MODEL = 'google/mt5-small'
    MAX_SEQ_LEN = 512
    NUM_BEAMS = 3
    SEED = 2023
    MAX_INPUT_LEN = 500

    # Mapping dataset -> 'languages pairs' -> language pair -> 'epsilon' -> checkpoint id,
    # read from the current working directory.
    DATASETS_MODEL_INFO_PATH = os.path.join(os.getcwd(), "dataset_and_model_info.json")
    logger.info(DATASETS_MODEL_INFO_PATH)
    DATASETS_MODEL_INFO = load_json(DATASETS_MODEL_INFO_PATH)

    main()