|
|
import gradio as gr |
|
|
import torch |
|
|
|
|
|
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer |
|
|
|
|
|
# Load the 1.2B-parameter M2M100 many-to-many multilingual translation model
# and its matching tokenizer from the Hugging Face hub.
# NOTE: this downloads several GB of weights on first run and happens at
# import time, so module import is slow.
model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_1.2B")


tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_1.2B")
|
|
|
|
|
# Supported languages as "Name (code)" entries in one comma-separated string.
# Fix: the original misspelled "Greeek (el)" — corrected to "Greek (el)".
langs = """Afrikaans (af), Amharic (am), Arabic (ar), Asturian (ast), Azerbaijani (az), Bashkir (ba), Belarusian (be), Bulgarian (bg), Bengali (bn), Breton (br), Bosnian (bs), Catalan; Valencian (ca), Cebuano (ceb), Czech (cs), Welsh (cy), Danish (da), German (de), Greek (el), English (en), Spanish (es), Estonian (et), Persian (fa), Fulah (ff), Finnish (fi), French (fr), Western Frisian (fy), Irish (ga), Gaelic; Scottish Gaelic (gd), Galician (gl), Gujarati (gu), Hausa (ha), Hebrew (he), Hindi (hi), Croatian (hr), Haitian; Haitian Creole (ht), Hungarian (hu), Armenian (hy), Indonesian (id), Igbo (ig), Iloko (ilo), Icelandic (is), Italian (it), Japanese (ja), Javanese (jv), Georgian (ka), Kazakh (kk), Central Khmer (km), Kannada (kn),
Korean (ko), Luxembourgish; Letzeburgesch (lb), Ganda (lg), Lingala (ln), Lao (lo), Lithuanian (lt), Latvian (lv), Malagasy (mg), Macedonian (mk), Malayalam (ml), Mongolian (mn), Marathi (mr), Malay (ms), Burmese (my), Nepali (ne), Dutch; Flemish (nl), Norwegian (no), Northern Sotho (ns), Occitan (post 1500) (oc), Oriya (or), Panjabi; Punjabi (pa), Polish (pl), Pushto; Pashto (ps), Portuguese (pt), Romanian; Moldavian; Moldovan (ro), Russian (ru), Sindhi (sd), Sinhala; Sinhalese (si), Slovak (sk),
Slovenian (sl), Somali (so), Albanian (sq), Serbian (sr), Swati (ss), Sundanese (su), Swedish (sv), Swahili (sw), Tamil (ta), Thai (th), Tagalog (tl), Tswana (tn),
Turkish (tr), Ukrainian (uk), Urdu (ur), Uzbek (uz), Vietnamese (vi), Wolof (wo), Xhosa (xh), Yiddish (yi), Yoruba (yo), Chinese (zh), Zulu (zu)"""

# One dropdown entry per language; strip() absorbs the newlines embedded in
# the triple-quoted string so every entry is a clean "Name (code)".
lang_list = [lang.strip() for lang in langs.split(',')]
|
|
|
|
|
def translate(src, tgt, text):
    """Translate *text* between two of M2M100's 100 supported languages.

    Parameters
    ----------
    src, tgt : str
        Dropdown entries of the form "English (en)"; only the ISO code
        inside the trailing parentheses is used.
    text : str
        The text to translate.

    Returns
    -------
    str
        The translated text (first beam-search hypothesis).
    """
    # Pull the ISO code out of the trailing "(xx)" token, e.g.
    # "English (en)" -> "en".  Taking the LAST whitespace-separated token
    # keeps this correct even for names that themselves contain
    # parentheses, e.g. "Occitan (post 1500) (oc)".
    src = src.split(" ")[-1][1:-1]
    tgt = tgt.split(" ")[-1][1:-1]

    tokenizer.src_lang = src
    encoded_src = tokenizer(text, return_tensors="pt")

    # Inference only: disable autograd so generation does not build a
    # gradient graph (the original tracked gradients needlessly, wasting
    # memory on every request).
    with torch.no_grad():
        generated_tokens = model.generate(
            **encoded_src,
            # Force the decoder to start in the target language.
            forced_bos_token_id=tokenizer.get_lang_id(tgt),
            use_cache=True,
        )

    result = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
    return result
|
|
|
|
|
# Build and launch the Gradio demo: two language dropdowns plus a text box
# in, translated text out.
# NOTE(review): gr.inputs / gr.outputs are the legacy (pre-3.x) Gradio API —
# kept unchanged here, but they are removed in newer Gradio releases; confirm
# the pinned gradio version before upgrading.
# Fix: the `description` string literal was broken across two physical lines
# (a syntax error); rejoined into a single literal.
# (Korean description: "A demo page that translates between 100 languages
# using the M2M100-1.2B model.")
output_text = gr.outputs.Textbox()


gr.Interface(
    translate,
    inputs=[
        gr.inputs.Dropdown(lang_list, label="Source Language"),
        gr.inputs.Dropdown(lang_list, label="Target Language"),
        'text',
    ],
    outputs=output_text,
    title="Translate Between 100 languages",
    description="M2M100-1.2B 모델을 가지고 100개국어 언어를 번역하는 데모페이지 입니다.",
).launch()
|
|
|