# MBart-50 many-to-many translation demo (Gradio app).
import re
import os
import sys
import torch
import gradio as gr
from transformers import MBart50TokenizerFast, MBartForConditionalGeneration
# UI display name (Chinese) -> MBart-50 language code.
language_options = {
    '中文': 'zh_CN',
    '英语': 'en_XX',
    '越南语': 'vi_VN',
    '泰语': 'th_TH',
    '日语': 'ja_XX',
    '韩语': 'ko_KR',
}

# Dropdown choices, in the dict's insertion order.
languages = [*language_options]
class MBartTranslator:
    """Thin wrapper around the MBart-50 many-to-many translation model.

    Loads the "facebook/mbart-large-50-many-to-many-mmt" checkpoint (or any
    other MBart checkpoint passed via ``model_name``) together with its fast
    tokenizer, and exposes a single :meth:`translate` method.

    Attributes:
        model (MBartForConditionalGeneration): The MBart language model.
        tokenizer (MBart50TokenizerFast): The MBart tokenizer.
        supported_languages (list): Language codes accepted by translate().
    """

    def __init__(self, model_name="facebook/mbart-large-50-many-to-many-mmt", src_lang=None, tgt_lang=None):
        # Codes understood by the mbart-large-50 family; translate() validates
        # its arguments against this list before touching the model.
        self.supported_languages = (
            "ar_AR cs_CZ de_DE en_XX es_XX et_EE fi_FI fr_XX gu_IN hi_IN "
            "it_IT ja_XX kk_KZ ko_KR lt_LT lv_LV my_MM ne_NP nl_XX ro_RO "
            "ru_RU si_LK tr_TR vi_VN zh_CN af_ZA az_AZ bn_IN fa_IR he_IL "
            "hr_HR id_ID ka_GE km_KH mk_MK ml_IN mn_MN mr_IN pl_PL ps_AF "
            "pt_XX sv_SE sw_KE ta_IN te_IN th_TH tl_XX uk_UA ur_PK xh_ZA "
            "gl_ES sl_SI"
        ).split()
        print("Building translator")
        print("Loading generator (this may take few minutes the first time as I need to download the model)")
        # `device` is a module-level torch.device selected at import time.
        self.model = MBartForConditionalGeneration.from_pretrained(model_name).to(device)
        print("Loading tokenizer")
        self.tokenizer = MBart50TokenizerFast.from_pretrained(model_name, src_lang=src_lang, tgt_lang=tgt_lang)
        print("Translator is ready")

    def translate(self, text: str, input_language: str, output_language: str) -> str:
        """Translate *text* from *input_language* to *output_language*.

        Args:
            text (str): The text to translate.
            input_language (str): Source language code (e.g. "hi_IN" for Hindi).
            output_language (str): Target language code (e.g. "en_XX" for English).

        Returns:
            str: The translated text.

        Raises:
            ValueError: If either language code is not in supported_languages.
        """
        # Reject unknown codes up front; the error text names the bad side.
        for code, side in ((input_language, "Input"), (output_language, "Output")):
            if code not in self.supported_languages:
                raise ValueError(f"{side} language not supported. Supported languages: {self.supported_languages}")
        self.tokenizer.src_lang = input_language
        encoded = self.tokenizer(text, return_tensors="pt").to(device)
        # Force the decoder to start with the target-language token.
        target_bos = self.tokenizer.lang_code_to_id[output_language]
        generated = self.model.generate(**encoded, forced_bos_token_id=target_bos)
        decoded = self.tokenizer.batch_decode(generated, skip_special_tokens=True)
        return decoded[0]
# Prefer the first CUDA GPU when one is available; otherwise run on CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
# Built eagerly at import time: per the messages printed by __init__, this
# downloads the checkpoint on first run, so startup can take a while.
translator = MBartTranslator()
def translate(src, dst, content):
    """Gradio callback: map UI display names to MBart codes and translate."""
    src_code = language_options[src]
    dst_code = language_options[dst]
    return translator.translate(content, src_code, dst_code)
# Sample [source language, target language, text] rows shown under the form.
examples = [
    ['中文', '英语', '今天天气真不错!'],
    ['英语', '中文', "Life was a box of chocolates, you never know what you're gonna get."],
    ['中文', '泰语', '别放弃你的梦想,迟早有一天它会在你手里发光。'],
]
# Build the UI widgets first, then wire them into the Interface:
# two language dropdowns plus a text box in, one text box out.
source_dropdown = gr.Dropdown(languages, label="源语言", value=languages[0], show_label=True)
target_dropdown = gr.Dropdown(languages, label="目标语言", value=languages[1], show_label=True)
content_box = gr.Textbox(label='内容', placeholder='这里输入要翻译的内容', lines=5)
result_box = gr.Textbox(label='结果', lines=5)

demo = gr.Interface(
    fn=translate,
    inputs=[source_dropdown, target_dropdown, content_box],
    outputs=[result_box],
    examples=examples,
)
# Fix: `enable_queue` was removed from launch() in Gradio 4.x, so
# `demo.launch(enable_queue=True)` raises TypeError at startup there.
# Enabling the queue via .queue() before launch() has the same effect
# and works on both Gradio 3.x and 4.x.
demo.queue()
demo.launch()