Spaces:
Running
Running
Commit
·
cc13458
1
Parent(s):
8f8dd51
Update app.py
Browse files
app.py
CHANGED
|
@@ -89,25 +89,56 @@
|
|
| 89 |
# # Launch the app
|
| 90 |
# interface.launch()
|
| 91 |
|
| 92 |
-
|
| 93 |
import gradio as gr
|
| 94 |
-
from transformers import MarianMTModel, MarianTokenizer
|
| 95 |
|
| 96 |
-
#
|
| 97 |
-
model_name = "
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
|
| 99 |
-
#
|
| 100 |
-
|
| 101 |
-
|
|
|
|
| 102 |
|
|
|
|
|
|
|
| 103 |
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
|
|
|
| 108 |
|
| 109 |
|
| 110 |
# Gradio interface
|
| 111 |
-
gr.Interface(
|
| 112 |
-
fn=translate,
|
| 113 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
# # Launch the app
|
| 90 |
# interface.launch()
|
| 91 |
|
| 92 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 93 |
import gradio as gr
|
|
|
|
| 94 |
|
| 95 |
+
# Load the tokenizer and model
|
| 96 |
+
model_name = "facebook/nllb-200-distilled-600M"
|
| 97 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 98 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
| 99 |
+
|
| 100 |
+
# Language code map
|
| 101 |
+
lang_map = {
|
| 102 |
+
"English": "eng_Latn",
|
| 103 |
+
"Afrikaans": "afr_Latn",
|
| 104 |
+
"Zulu": "zul_Latn",
|
| 105 |
+
"Xhosa": "xho_Latn",
|
| 106 |
+
"French": "fra_Latn",
|
| 107 |
+
"Spanish": "spa_Latn",
|
| 108 |
+
"Swahili": "swh_Latn",
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
|
| 112 |
+
# Translation function
|
| 113 |
+
def translate(text, src_lang, tgt_lang):
|
| 114 |
+
src_code = lang_map[src_lang]
|
| 115 |
+
tgt_code = lang_map[tgt_lang]
|
| 116 |
|
| 117 |
+
tokenizer.src_lang = src_code
|
| 118 |
+
inputs = tokenizer(text, return_tensors="pt", padding=True)
|
| 119 |
|
| 120 |
+
generated_tokens = model.generate(
|
| 121 |
+
**inputs, forced_bos_token_id=tokenizer.lang_code_to_id[tgt_code]
|
| 122 |
+
)
|
| 123 |
+
translated = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
|
| 124 |
+
return translated
|
| 125 |
|
| 126 |
|
| 127 |
# Gradio interface
|
| 128 |
+
iface = gr.Interface(
|
| 129 |
+
fn=translate,
|
| 130 |
+
inputs=[
|
| 131 |
+
gr.Textbox(label="Enter text"),
|
| 132 |
+
gr.Dropdown(
|
| 133 |
+
choices=list(lang_map.keys()), label="From Language", value="English"
|
| 134 |
+
),
|
| 135 |
+
gr.Dropdown(
|
| 136 |
+
choices=list(lang_map.keys()), label="To Language", value="Afrikaans"
|
| 137 |
+
),
|
| 138 |
+
],
|
| 139 |
+
outputs="text",
|
| 140 |
+
title="NLLB-200 Custom Language Translator",
|
| 141 |
+
description="Translate text using Facebook's distilled NLLB-200 model with selectable languages.",
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
iface.launch()
|