LouisMonawe commited on
Commit
cc13458
·
1 Parent(s): 8f8dd51

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -14
app.py CHANGED
@@ -89,25 +89,56 @@
89
  # # Launch the app
90
  # interface.launch()
91
 
92
-
93
  import gradio as gr
94
- from transformers import MarianMTModel, MarianTokenizer
95
 
96
- # Pick your language pair
97
- model_name = "Helsinki-NLP/opus-mt-en-zu" # English to Zulu
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
- # Load tokenizer and model locally
100
- tokenizer = MarianTokenizer.from_pretrained(model_name)
101
- model = MarianMTModel.from_pretrained(model_name)
 
102
 
 
 
103
 
104
- def translate(text):
105
- inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
106
- translated = model.generate(**inputs)
107
- return tokenizer.decode(translated[0], skip_special_tokens=True)
 
108
 
109
 
110
  # Gradio interface
111
- gr.Interface(
112
- fn=translate, inputs="text", outputs="text", title="English to Zulu Translator"
113
- ).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  # # Launch the app
90
  # interface.launch()
91
 
92
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
93
  import gradio as gr
 
94
 
95
+ # Load the tokenizer and model
96
+ model_name = "facebook/nllb-200-distilled-600M"
97
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
98
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
99
+
100
+ # Language code map
101
+ lang_map = {
102
+ "English": "eng_Latn",
103
+ "Afrikaans": "afr_Latn",
104
+ "Zulu": "zul_Latn",
105
+ "Xhosa": "xho_Latn",
106
+ "French": "fra_Latn",
107
+ "Spanish": "spa_Latn",
108
+ "Swahili": "swh_Latn",
109
+ }
110
+
111
 
112
+ # Translation function
113
+ def translate(text, src_lang, tgt_lang):
114
+ src_code = lang_map[src_lang]
115
+ tgt_code = lang_map[tgt_lang]
116
 
117
+ tokenizer.src_lang = src_code
118
+ inputs = tokenizer(text, return_tensors="pt", padding=True)
119
 
120
+ generated_tokens = model.generate(
121
+ **inputs, forced_bos_token_id=tokenizer.lang_code_to_id[tgt_code]
122
+ )
123
+ translated = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
124
+ return translated
125
 
126
 
127
  # Gradio interface
128
+ iface = gr.Interface(
129
+ fn=translate,
130
+ inputs=[
131
+ gr.Textbox(label="Enter text"),
132
+ gr.Dropdown(
133
+ choices=list(lang_map.keys()), label="From Language", value="English"
134
+ ),
135
+ gr.Dropdown(
136
+ choices=list(lang_map.keys()), label="To Language", value="Afrikaans"
137
+ ),
138
+ ],
139
+ outputs="text",
140
+ title="NLLB-200 Custom Language Translator",
141
+ description="Translate text using Facebook's distilled NLLB-200 model with selectable languages.",
142
+ )
143
+
144
+ iface.launch()