|
|
import gradio as gr |
|
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
|
import torch |
|
|
from functools import lru_cache |
|
|
|
|
|
|
|
|
# Display name (shown in the UI dropdowns) -> FLORES-200 language code
# understood by the NLLB tokenizer.  The suffix identifies the writing
# system (Arab / Latn / Cyrl / ...).  Unknown names fall back to
# English / Arabic inside translate().
LANGUAGE_CODES = {
    "Arabic": "arb_Arab",
    "English": "eng_Latn",
    "French": "fra_Latn",
    "Spanish": "spa_Latn",
    "German": "deu_Latn",
    "Italian": "ita_Latn",
    "Portuguese": "por_Latn",
    "Russian": "rus_Cyrl",
    "Japanese": "jpn_Jpan",
    "Korean": "kor_Hang",
    "Chinese (Simplified)": "zho_Hans",
    "Hindi": "hin_Deva",
    "Turkish": "tur_Latn",
    "Dutch": "nld_Latn",
    "Polish": "pol_Latn",
    "Swedish": "swe_Latn",
    "Arabic (Egyptian)": "arz_Arab",
    "Arabic (Moroccan)": "ary_Arab",
    "Indonesian": "ind_Latn",
    "Vietnamese": "vie_Latn",
    "Thai": "tha_Thai",
    "Ukrainian": "ukr_Cyrl",
    "Romanian": "ron_Latn",
    "Greek": "ell_Grek",
    "Hebrew": "heb_Hebr",
}
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Model loading — runs once at import time; the tokenizer, model and device
# are module-level and shared by translate() below.
# ---------------------------------------------------------------------------
print("Loading NLLB-200 model...")

model_name = "facebook/nllb-200-distilled-600M"

# Resolve the device up front so the freshly loaded model can be moved
# to it in one chained call.
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)

print(f"Model loaded on {device}")
|
|
|
|
|
|
|
|
|
|
|
# In-memory memoization of completed translations.  Bounded so a
# long-running Space does not grow memory without limit (the original
# dict grew forever); oldest entries are evicted first.
translation_cache = {}
_CACHE_MAX = 1024


def translate(text, src_lang, tgt_lang):
    """Translate `text` from `src_lang` to `tgt_lang` with NLLB-200.

    Parameters are the human-readable language names used as keys of
    LANGUAGE_CODES; unknown names fall back to English (source) and
    Arabic (target).  Returns the translated string, "" for empty input,
    or a "Translation error: ..." message on failure — the function never
    raises, so the Gradio UI stays up.
    """
    if not text or not text.strip():
        return ""

    text = text.strip()
    src_lang_code = LANGUAGE_CODES.get(src_lang, "eng_Latn")
    tgt_lang_code = LANGUAGE_CODES.get(tgt_lang, "arb_Arab")

    cache_key = f"{src_lang_code}:{tgt_lang_code}:{text}"
    if cache_key in translation_cache:
        return translation_cache[cache_key]

    try:
        tokenizer.src_lang = src_lang_code
        inputs = tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            max_length=512,
            truncation=True,
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            translated_tokens = model.generate(
                **inputs,
                # FIX: `tokenizer.lang_code_to_id` was removed from the NLLB
                # tokenizer in recent transformers releases.  The language
                # codes are ordinary special tokens, so convert_tokens_to_ids
                # works on every version.
                forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_lang_code),
                max_length=512,
                num_beams=5,
                early_stopping=True,
            )

        translation = tokenizer.batch_decode(
            translated_tokens, skip_special_tokens=True
        )[0]

        # FIFO eviction: dicts preserve insertion order, so the first key
        # is the oldest cached entry.
        if len(translation_cache) >= _CACHE_MAX:
            translation_cache.pop(next(iter(translation_cache)))
        translation_cache[cache_key] = translation
        return translation
    except Exception as e:
        # Best-effort boundary: surface the failure in the output box
        # rather than crashing the request.
        return f"Translation error: {str(e)}"
|
|
|
|
|
|
|
|
def gradio_translate(text, src_lang, tgt_lang):
    """Gradio interface function"""
    # Same source and target language: nothing to do, echo the input.
    if src_lang == tgt_lang:
        return text
    return translate(text, src_lang, tgt_lang)
|
|
|
|
|
|
|
|
|
|
|
# Alphabetized display names for the dropdowns.  Iterating a dict yields
# its keys directly, so the explicit .keys() call is unnecessary.
LANGUAGES = sorted(LANGUAGE_CODES)
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Gradio UI.  FIX: the original file contained mojibake (mis-decoded UTF-8)
# in the user-facing Markdown, the Translate button label, and the Arabic
# example sentence; those strings are restored to the intended characters
# (✅ bullets, 🌍 emoji, "مرحبا، كيف حالك؟").  All component wiring is
# unchanged.
# ---------------------------------------------------------------------------
with gr.Blocks(title="NLLB-200 Translation API", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
# 🌍 NLLB-200 Translation API

**Meta's No Language Left Behind** - 200 Languages Translation

- ✅ High-quality translation for 200+ languages
- ✅ 44% better than previous models
- ✅ +70% improvement for complex languages (Arabic, Hindi, etc.)
- ✅ Direct translation (no pivot through English)
- ✅ Cached for faster repeated translations

**Powered by**: `facebook/nllb-200-distilled-600M`
"""
    )

    with gr.Row():
        with gr.Column():
            src_lang = gr.Dropdown(
                choices=LANGUAGES,
                value="English",
                label="Source Language",
                interactive=True,
            )
            input_text = gr.Textbox(
                label="Text to Translate",
                placeholder="Enter text here...",
                lines=5,
                max_lines=10,
            )

        with gr.Column():
            tgt_lang = gr.Dropdown(
                choices=LANGUAGES,
                value="Arabic",
                label="Target Language",
                interactive=True,
            )
            # Read-only box that receives the translation result.
            output_text = gr.Textbox(
                label="Translation",
                lines=5,
                max_lines=10,
                interactive=False,
            )

    with gr.Row():
        translate_btn = gr.Button("Translate 🌍", variant="primary", size="lg")
        clear_btn = gr.Button("Clear", variant="secondary")

    gr.Examples(
        examples=[
            ["Hello, how are you?", "English", "Arabic"],
            ["مرحبا، كيف حالك؟", "Arabic", "French"],
            ["Bonjour, comment allez-vous?", "French", "English"],
            ["This is a test of NLLB-200 translation model.", "English", "Spanish"],
        ],
        inputs=[input_text, src_lang, tgt_lang],
        outputs=output_text,
        fn=gradio_translate,
        cache_examples=False,
    )

    # Button click and Enter in the textbox both trigger a translation.
    translate_btn.click(
        fn=gradio_translate,
        inputs=[input_text, src_lang, tgt_lang],
        outputs=output_text,
    )

    clear_btn.click(
        fn=lambda: ("", ""),
        inputs=None,
        outputs=[input_text, output_text],
    )

    input_text.submit(
        fn=gradio_translate,
        inputs=[input_text, src_lang, tgt_lang],
        outputs=output_text,
    )

    gr.Markdown(
        """
---
### API Usage

You can use this Space programmatically via the Gradio API:

```python
from gradio_client import Client

client = Client("TGPro1/NLLB200")
result = client.predict(
    "Hello, world!",   # text
    "English",         # source language
    "Arabic",          # target language
    api_name="/predict"
)
print(result)
```

**Supported Languages**: 25+ major languages (see dropdown)

For full list of 200 languages, check the [NLLB-200 documentation](https://github.com/facebookresearch/flores/blob/main/flores200/README.md#languages-in-flores-200)
"""
    )
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Queue incoming requests (at most 10 waiting) so concurrent users are
    # serialized through the single loaded model, then serve on all
    # interfaces at the standard Spaces port.
    demo.queue(max_size=10)
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
|
|
|