# NLLB200 / app.py
# Author: TGPro1
# Last change: "Update app.py" (commit 67803e0, verified)
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
from functools import lru_cache
# UI display name -> NLLB-200 language code passed to the tokenizer/model.
# Built from an ordered sequence of (name, code) pairs; dict() keeps that order.
LANGUAGE_CODES = dict([
    ("Arabic", "arb_Arab"),
    ("English", "eng_Latn"),
    ("French", "fra_Latn"),
    ("Spanish", "spa_Latn"),
    ("German", "deu_Latn"),
    ("Italian", "ita_Latn"),
    ("Portuguese", "por_Latn"),
    ("Russian", "rus_Cyrl"),
    ("Japanese", "jpn_Jpan"),
    ("Korean", "kor_Hang"),
    ("Chinese (Simplified)", "zho_Hans"),
    ("Hindi", "hin_Deva"),
    ("Turkish", "tur_Latn"),
    ("Dutch", "nld_Latn"),
    ("Polish", "pol_Latn"),
    ("Swedish", "swe_Latn"),
    ("Arabic (Egyptian)", "arz_Arab"),
    ("Arabic (Moroccan)", "ary_Arab"),
    ("Indonesian", "ind_Latn"),
    ("Vietnamese", "vie_Latn"),
    ("Thai", "tha_Thai"),
    ("Ukrainian", "ukr_Cyrl"),
    ("Romanian", "ron_Latn"),
    ("Greek", "ell_Grek"),
    ("Hebrew", "heb_Hebr"),
])
# Load the NLLB-200 checkpoint once, at import time, so every request reuses it.
model_name = "facebook/nllb-200-distilled-600M"
print("Loading NLLB-200 model...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
# Prefer the GPU when torch can see one; otherwise run on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
model = model.to(device)
print(f"Model loaded on {device}")
# Upper bound on cached translations. An unbounded dict leaks memory on a
# long-running public Space, since every distinct input adds an entry.
MAX_CACHE_ENTRIES = 1024
# Finished translations keyed by "src_code:tgt_code:text". Dict insertion
# order is used as recency order, so the first key is the least recently used.
translation_cache = {}


def translate(text, src_lang, tgt_lang):
    """Translate *text* from *src_lang* to *tgt_lang* with the NLLB-200 model.

    Args:
        text: Text to translate. Leading/trailing whitespace is stripped.
        src_lang: Display name of the source language (a key of
            ``LANGUAGE_CODES``); unknown names fall back to ``"eng_Latn"``.
        tgt_lang: Display name of the target language; unknown names fall
            back to ``"arb_Arab"``.

    Returns:
        The translated text; ``""`` when *text* is empty/blank/None; or a
        string starting with ``"Translation error: "`` when tokenization or
        generation raises. Error strings are never cached.
    """
    if not text or not text.strip():
        return ""
    text = text.strip()
    src_lang_code = LANGUAGE_CODES.get(src_lang, "eng_Latn")
    tgt_lang_code = LANGUAGE_CODES.get(tgt_lang, "arb_Arab")
    cache_key = f"{src_lang_code}:{tgt_lang_code}:{text}"
    if cache_key in translation_cache:
        # Re-insert on a hit so the entry moves to the "most recent" end.
        translation_cache[cache_key] = translation_cache.pop(cache_key)
        return translation_cache[cache_key]
    try:
        # NLLB tokenizers tag the input according to src_lang.
        tokenizer.src_lang = src_lang_code
        inputs = tokenizer(text, return_tensors="pt", padding=True, max_length=512, truncation=True)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        # Language codes are special tokens in the NLLB vocabulary.
        # convert_tokens_to_ids works on all transformers versions, whereas
        # tokenizer.lang_code_to_id was removed in newer releases (it raised
        # AttributeError and made every request return an error string).
        tgt_lang_id = tokenizer.convert_tokens_to_ids(tgt_lang_code)
        with torch.no_grad():
            translated_tokens = model.generate(
                **inputs,
                forced_bos_token_id=tgt_lang_id,  # force output to start in the target language
                max_length=512,
                num_beams=5,
                early_stopping=True,
            )
        translation = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing the request.
        return f"Translation error: {str(e)}"
    if len(translation_cache) >= MAX_CACHE_ENTRIES:
        # Evict the least recently used entry (first in insertion order).
        translation_cache.pop(next(iter(translation_cache)))
    translation_cache[cache_key] = translation
    return translation
def gradio_translate(text, src_lang, tgt_lang):
    """Entry point wired to the Gradio UI events.

    When source and target languages are the same, the input is echoed back
    untouched (no stripping, no model call). Otherwise the request is handed
    to ``translate``.
    """
    if src_lang != tgt_lang:
        return translate(text, src_lang, tgt_lang)
    return text
# Dropdown choices: every display name in LANGUAGE_CODES, alphabetically ordered.
LANGUAGES = sorted(LANGUAGE_CODES)
# Create Gradio Interface.
# Layout: a source column and a target column, a button row, examples,
# then API usage notes.
with gr.Blocks(title="NLLB-200 Translation API", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
# 🌍 NLLB-200 Translation API

**Meta's No Language Left Behind** - 200 Languages Translation

- ✅ High-quality translation for 200+ languages
- ✅ 44% better than previous models
- ✅ +70% improvement for complex languages (Arabic, Hindi, etc.)
- ✅ Direct translation (no pivot through English)
- ✅ Cached for faster repeated translations

**Powered by**: `facebook/nllb-200-distilled-600M`
"""
    )
    with gr.Row():
        with gr.Column():
            src_lang = gr.Dropdown(
                choices=LANGUAGES,
                value="English",
                label="Source Language",
                interactive=True,
            )
            input_text = gr.Textbox(
                label="Text to Translate",
                placeholder="Enter text here...",
                lines=5,
                max_lines=10,
            )
        with gr.Column():
            tgt_lang = gr.Dropdown(
                choices=LANGUAGES,
                value="Arabic",
                label="Target Language",
                interactive=True,
            )
            output_text = gr.Textbox(
                label="Translation",
                lines=5,
                max_lines=10,
                interactive=False,
            )
    with gr.Row():
        translate_btn = gr.Button("Translate 🚀", variant="primary", size="lg")
        clear_btn = gr.Button("Clear", variant="secondary")
    # Examples (strings repaired from UTF-8 mojibake so the Arabic sample is real Arabic)
    gr.Examples(
        examples=[
            ["Hello, how are you?", "English", "Arabic"],
            ["مرحبا، كيف حالك؟", "Arabic", "French"],
            ["Bonjour, comment allez-vous?", "French", "English"],
            ["This is a test of NLLB-200 translation model.", "English", "Spanish"],
        ],
        inputs=[input_text, src_lang, tgt_lang],
        outputs=output_text,
        fn=gradio_translate,
        cache_examples=False,
    )
    # Event handlers.
    # api_name="predict" registers the "/predict" endpoint that the API usage
    # snippet below tells clients to call; without it, Blocks names the
    # endpoint after the function ("/gradio_translate") and the snippet fails.
    translate_btn.click(
        fn=gradio_translate,
        inputs=[input_text, src_lang, tgt_lang],
        outputs=output_text,
        api_name="predict",
    )
    clear_btn.click(
        fn=lambda: ("", ""),
        inputs=None,
        outputs=[input_text, output_text],
    )
    # Also translate on Enter key
    input_text.submit(
        fn=gradio_translate,
        inputs=[input_text, src_lang, tgt_lang],
        outputs=output_text,
    )
    gr.Markdown(
        """
---

### API Usage

You can use this Space programmatically via the Gradio API:

```python
from gradio_client import Client

client = Client("TGPro1/NLLB200")
result = client.predict(
    "Hello, world!",  # text
    "English",  # source language
    "Arabic",  # target language
    api_name="/predict",
)
print(result)
```

**Supported Languages**: 25+ major languages (see dropdown)

For full list of 200 languages, check the [NLLB-200 documentation](https://github.com/facebookresearch/flores/blob/main/flores200/README.md#languages-in-flores-200)
"""
    )
if __name__ == "__main__":
    # Enable request queuing, holding at most 10 pending events.
    demo.queue(max_size=10)
    # Bind on all interfaces at port 7860 without creating a public share link.
    launch_kwargs = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
    }
    demo.launch(**launch_kwargs)