File size: 4,834 Bytes
93372ef
e75d706
93372ef
 
 
e75d706
93372ef
 
 
36f5ede
6571a73
 
 
 
b65bee4
 
36f5ede
 
 
93372ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b65bee4
93372ef
 
 
b65bee4
 
93372ef
b65bee4
93372ef
 
 
b65bee4
 
 
 
 
 
 
 
93372ef
b65bee4
 
 
 
93372ef
b65bee4
 
93372ef
 
 
 
 
 
 
 
b65bee4
93372ef
b65bee4
 
 
 
 
93372ef
 
 
b65bee4
93372ef
 
b65bee4
 
 
93372ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d2800da
 
36f5ede
93372ef
b135825
627db7a
 
 
b135825
93372ef
 
 
 
 
36f5ede
93372ef
 
36f5ede
93372ef
94d2dd6
93372ef
 
58f896d
b135825
 
93372ef
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import re
import os
import sys

import torch
import gradio as gr
from transformers import MBart50TokenizerFast, MBartForConditionalGeneration


# Maps the Chinese display names shown in the UI dropdowns to mBART-50 language codes.
# Insertion order is preserved, so it also fixes the order of the dropdown choices.
language_options = dict(
    [
        ("中文", "zh_CN"),  # Chinese
        ("英语", "en_XX"),  # English
        ("越南语", "vi_VN"),  # Vietnamese
        ("泰语", "th_TH"),  # Thai
        ("日语", "ja_XX"),  # Japanese
        ("韩语", "ko_KR"),  # Korean
    ]
)
# Display names only, in dictionary order, used as the dropdown choices.
languages = [display_name for display_name in language_options]

class MBartTranslator:
    """Simple interface for translating text with an mBART-50 many-to-many model.

    The class accepts the 52 language codes listed in ``SUPPORTED_LANGUAGES`` and
    defaults to the "facebook/mbart-large-50-many-to-many-mmt" pre-trained checkpoint.
    A different mBART checkpoint can be used by passing its name.

    Attributes:
        supported_languages (list[str]): Language codes accepted by ``translate``.
        model (MBartForConditionalGeneration): The mBART language model.
        tokenizer (MBart50TokenizerFast): The mBART tokenizer.
    """

    # mBART-50 language codes accepted by ``translate`` (same order as before).
    SUPPORTED_LANGUAGES = (
        "ar_AR", "cs_CZ", "de_DE", "en_XX", "es_XX", "et_EE", "fi_FI", "fr_XX",
        "gu_IN", "hi_IN", "it_IT", "ja_XX", "kk_KZ", "ko_KR", "lt_LT", "lv_LV",
        "my_MM", "ne_NP", "nl_XX", "ro_RO", "ru_RU", "si_LK", "tr_TR", "vi_VN",
        "zh_CN", "af_ZA", "az_AZ", "bn_IN", "fa_IR", "he_IL", "hr_HR", "id_ID",
        "ka_GE", "km_KH", "mk_MK", "ml_IN", "mn_MN", "mr_IN", "pl_PL", "ps_AF",
        "pt_XX", "sv_SE", "sw_KE", "ta_IN", "te_IN", "th_TH", "tl_XX", "uk_UA",
        "ur_PK", "xh_ZA", "gl_ES", "sl_SI",
    )

    def __init__(self, model_name="facebook/mbart-large-50-many-to-many-mmt", src_lang=None, tgt_lang=None,
                 device=None):
        """Download (if needed) and load the model and tokenizer.

        Args:
            model_name: Hugging Face model id or local path of the mBART checkpoint.
            src_lang: Default source language code passed to the tokenizer.
            tgt_lang: Default target language code passed to the tokenizer.
            device: Device (or device string) to load the model on. When ``None``,
                ``cuda:0`` is used if CUDA is available, otherwise the CPU.
        """
        # Resolve the device locally instead of reading a module-level global, so the
        # class works no matter where (or whether) such a global is defined.
        if device is None:
            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        device = torch.device(device)

        # Per-instance list copy: callers may mutate it without touching the class constant.
        self.supported_languages = list(self.SUPPORTED_LANGUAGES)
        print("Building translator")
        print("Loading generator (this may take few minutes the first time as I need to download the model)")
        self.model = MBartForConditionalGeneration.from_pretrained(model_name).to(device)
        print("Loading tokenizer")
        self.tokenizer = MBart50TokenizerFast.from_pretrained(model_name, src_lang=src_lang, tgt_lang=tgt_lang)
        print("Translator is ready")

    def translate(self, text: str, input_language: str, output_language: str) -> str:
        """Translate the given text from the input language to the output language.

        Args:
            text (str): The text to translate.
            input_language (str): The input language code (e.g. "hi_IN" for Hindi).
            output_language (str): The output language code (e.g. "en_XX" for English).

        Returns:
            str: The translated text.

        Raises:
            ValueError: If either language code is not in ``supported_languages``.
        """
        if input_language not in self.supported_languages:
            raise ValueError(f"Input language not supported. Supported languages: {self.supported_languages}")
        if output_language not in self.supported_languages:
            raise ValueError(f"Output language not supported. Supported languages: {self.supported_languages}")

        # The tokenizer prefixes the input with the source-language token taken from src_lang.
        self.tokenizer.src_lang = input_language
        # Send the inputs to wherever the model currently lives, instead of relying on a global.
        encoded_input = self.tokenizer(text, return_tensors="pt").to(self.model.device)
        # Force the first generated token to be the target-language code token.
        # convert_tokens_to_ids is the public tokenizer API for this lookup.
        target_token_id = self.tokenizer.convert_tokens_to_ids(output_language)
        generated_tokens = self.model.generate(**encoded_input, forced_bos_token_id=target_token_id)
        translated_text = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)

        return translated_text[0]


# Prefer the first CUDA GPU when one is available; otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# A single translator shared by every request; constructing it loads the model weights.
translator = MBartTranslator()


def translate(src, dst, content):
    """Gradio callback: translate ``content`` between two UI language names.

    Args:
        src: Display name of the source language (a key of ``language_options``).
        dst: Display name of the target language (a key of ``language_options``).
        content: Text entered by the user.

    Returns:
        The translated text. Returns an empty string when ``content`` is empty,
        ``None`` or whitespace-only; no model call is made in that case.
    """
    # Blank input has nothing to translate; skip the generate() pass entirely.
    if not content or not content.strip():
        return ""
    return translator.translate(content, language_options[src], language_options[dst])

# Sample rows shown under the Gradio form, one per example:
# (source language name, target language name, text to translate).
# Each row is turned into a list, matching the list-of-lists layout Gradio receives.
examples = [
    list(row)
    for row in (
        ("中文", "英语", "今天天气真不错!"),
        ("英语", "中文", "Life was a box of chocolates, you never know what you're gonna get."),
        ("中文", "泰语", "别放弃你的梦想,迟早有一天它会在你手里发光。"),
    )
]

# Web UI: two language dropdowns plus a text box, wired to the `translate` callback
# in the same order as its (src, dst, content) parameters.
demo = gr.Interface(
    fn=translate,
    inputs=[
        # Source language; defaults to the first entry in `languages`.
        gr.Dropdown(languages, label="源语言", value=languages[0], show_label=True),
        # Target language; defaults to the second entry in `languages`.
        gr.Dropdown(languages, label="目标语言", value=languages[1], show_label=True),
        # Text to translate.
        gr.Textbox(label='内容', placeholder='这里输入要翻译的内容', lines=5),
    ],
    outputs=[
        # Translation result.
        gr.Textbox(label='结果', lines=5),
    ],
    examples=examples,
)


if __name__ == "__main__":
    # `launch(enable_queue=True)` was removed in Gradio 4. `queue()` is the supported
    # way to enable request queuing and is also available in Gradio 3.x.
    demo.queue().launch()