Spaces:
Runtime error
Runtime error
| import re | |
| import os | |
| import sys | |
| import torch | |
| import gradio as gr | |
| from transformers import MBart50TokenizerFast, MBartForConditionalGeneration | |
# Languages offered in the UI, mapped from their (Chinese) display names to
# MBart-50 language codes: Chinese, English, Vietnamese, Thai, Japanese, Korean.
# Insertion order matters: it is the order shown in the dropdowns.
language_options = dict([
    ('中文', 'zh_CN'),
    ('英语', 'en_XX'),
    ('越南语', 'vi_VN'),
    ('泰语', 'th_TH'),
    ('日语', 'ja_XX'),
    ('韩语', 'ko_KR'),
])

# Display names only, in the same order as above (used as dropdown choices).
languages = [display_name for display_name in language_options]
class MBartTranslator:
    """Simple interface for translating text with an MBart-50 many-to-many model.

    By default the "facebook/mbart-large-50-many-to-many-mmt" pre-trained
    checkpoint is loaded, but a different MBart model can be used by passing
    its name. Translation is possible between any two of the language codes
    listed in ``supported_languages``.

    Attributes:
        supported_languages (list[str]): MBart-50 language codes accepted by
            ``translate``.
        device (torch.device): Device the model and encoded inputs are moved to.
        model (MBartForConditionalGeneration): The MBart language model.
        tokenizer (MBart50TokenizerFast): The MBart tokenizer.
    """

    def __init__(self, model_name="facebook/mbart-large-50-many-to-many-mmt",
                 src_lang=None, tgt_lang=None, device=None):
        """Load the model and tokenizer.

        Args:
            model_name (str): Hugging Face model id or local path to load.
            src_lang (str | None): Initial source language passed to the tokenizer.
            tgt_lang (str | None): Initial target language passed to the tokenizer.
            device (str | torch.device | None): Where to run the model. ``None``
                selects "cuda:0" when CUDA is available, otherwise "cpu".
        """
        self.supported_languages = [
            "ar_AR",
            "cs_CZ",
            "de_DE",
            "en_XX",
            "es_XX",
            "et_EE",
            "fi_FI",
            "fr_XX",
            "gu_IN",
            "hi_IN",
            "it_IT",
            "ja_XX",
            "kk_KZ",
            "ko_KR",
            "lt_LT",
            "lv_LV",
            "my_MM",
            "ne_NP",
            "nl_XX",
            "ro_RO",
            "ru_RU",
            "si_LK",
            "tr_TR",
            "vi_VN",
            "zh_CN",
            "af_ZA",
            "az_AZ",
            "bn_IN",
            "fa_IR",
            "he_IL",
            "hr_HR",
            "id_ID",
            "ka_GE",
            "km_KH",
            "mk_MK",
            "ml_IN",
            "mn_MN",
            "mr_IN",
            "pl_PL",
            "ps_AF",
            "pt_XX",
            "sv_SE",
            "sw_KE",
            "ta_IN",
            "te_IN",
            "th_TH",
            "tl_XX",
            "uk_UA",
            "ur_PK",
            "xh_ZA",
            "gl_ES",
            "sl_SI",
        ]
        # Resolve the device here instead of reading a module-level global that
        # is defined after this class; the default matches the previous behavior.
        self.device = self._resolve_device(device)
        print("Building translator")
        print("Loading generator (this may take few minutes the first time as I need to download the model)")
        self.model = MBartForConditionalGeneration.from_pretrained(model_name).to(self.device)
        print("Loading tokenizer")
        self.tokenizer = MBart50TokenizerFast.from_pretrained(model_name, src_lang=src_lang, tgt_lang=tgt_lang)
        print("Translator is ready")

    @staticmethod
    def _resolve_device(device=None):
        """Return a ``torch.device``; ``None`` means "cuda:0" if available, else "cpu"."""
        if device is None:
            return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        if isinstance(device, torch.device):
            return device
        return torch.device(device)

    def translate(self, text: str, input_language: str, output_language: str) -> str:
        """Translate the given text from the input language to the output language.

        Args:
            text (str): The text to translate.
            input_language (str): The input language code (e.g. "hi_IN" for Hindi).
            output_language (str): The output language code (e.g. "en_XX" for English).

        Returns:
            str: The translated text.

        Raises:
            ValueError: If either language code is not in ``supported_languages``.
        """
        if input_language not in self.supported_languages:
            raise ValueError(f"Input language not supported. Supported languages: {self.supported_languages}")
        if output_language not in self.supported_languages:
            raise ValueError(f"Output language not supported. Supported languages: {self.supported_languages}")
        # The tokenizer prefixes the input with the source-language token.
        self.tokenizer.src_lang = input_language
        encoded_input = self.tokenizer(text, return_tensors="pt").to(self.device)
        # Force the decoder to start with the target-language token. Use
        # convert_tokens_to_ids rather than the lang_code_to_id mapping, which
        # newer transformers releases deprecate/remove on some tokenizers.
        target_token_id = self.tokenizer.convert_tokens_to_ids(output_language)
        generated_tokens = self.model.generate(**encoded_input, forced_bos_token_id=target_token_id)
        translated_text = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        return translated_text[0]
# Use the first CUDA GPU when one is present, otherwise fall back to the CPU.
# Defined before the translator is built in case MBartTranslator relies on
# this module-level name.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Single shared translator, created once at import time (loads the model).
translator = MBartTranslator()
def translate(src, dst, content):
    """Gradio callback: translate ``content`` from language ``src`` to ``dst``.

    Args:
        src: Display name of the source language (a key of ``language_options``).
        dst: Display name of the target language (a key of ``language_options``).
        content: Text entered by the user.

    Returns:
        The translated text. Returns an empty string for blank input, and the
        input unchanged when source and target languages are the same.
    """
    # Blank input: nothing to translate, so skip running the model on a
    # sequence made only of special tokens.
    if content is None or not content.strip():
        return ""
    # Same language on both sides: return the text as-is instead of asking
    # the model to "translate" it into itself.
    if src == dst:
        return content
    return translator.translate(content, language_options[src], language_options[dst])
# Sample rows shown under the interface: [source language, target language, text].
examples = [
    ['中文', '英语', '今天天气真不错!'],
    ['英语', '中文', "Life was a box of chocolates, you never know what you're gonna get."],
    ['中文', '泰语', '别放弃你的梦想,迟早有一天它会在你手里发光。'],
]

# UI labels are in Chinese: "source language", "target language",
# "content" (with placeholder "enter the text to translate here"), "result".
demo = gr.Interface(
    fn=translate,
    inputs=[
        gr.Dropdown(
            languages, label="源语言", value=languages[0], show_label=True
        ),
        gr.Dropdown(
            languages, label="目标语言", value=languages[1], show_label=True
        ),
        gr.Textbox(label='内容', placeholder='这里输入要翻译的内容', lines=5),
    ],
    outputs=[
        gr.Textbox(label='结果', lines=5),
    ],
    examples=examples,
)

if __name__ == "__main__":
    # `launch(enable_queue=...)` was removed in Gradio 4 and raises TypeError.
    # Request queuing is enabled via `Blocks.queue()`, which exists in Gradio 3 too.
    demo.queue().launch()