Minte committed on
Commit
2f77ad3
Β·
1 Parent(s): 5eebcbf

Add initial implementation of translation models and Gradio interface

Browse files
Files changed (2) hide show
  1. app.py +308 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoProcessor
4
+
5
# Language configuration with proper model handling.
# Maps the UI display name of each language to the backend that serves it
# ("seamless" or "nllb") and the language code that backend expects.

def _seamless_entry(code):
    # SeamlessM4T uses the same 3-letter code for routing and generation.
    return {"code": code, "model_type": "seamless", "seamless_code": code}

def _nllb_entry(code, nllb_code):
    # NLLB uses FLORES-style codes with an explicit script suffix.
    return {"code": code, "model_type": "nllb", "nllb_code": nllb_code}

LANGUAGE_CONFIG = {
    "Amharic": _seamless_entry("amh"),
    "Swahili": _seamless_entry("swh"),
    "Somali": _seamless_entry("som"),
    "Afan Oromo": _nllb_entry("gaz", "gaz_Latn"),
    "Tigrinya": _nllb_entry("tir", "tir_Ethi"),
    "Chichewa": _nllb_entry("nya", "nya_Latn"),
}
38
+
39
# Model instances
# Module-level caches keyed by backend name ('seamless' / 'nllb').
# A value of None marks a backend whose load failed; the translate_*
# functions check for None and degrade gracefully.
models = {}
tokenizers = {}
processors = {}

print("πŸš€ Initializing translation models...")

# Load SeamlessM4T model for Amharic, Swahili, Somali
try:
    print("πŸ“₯ Loading SeamlessM4T model...")
    seamless_model_id = "facebook/seamless-m4t-v2-large"
    processors['seamless'] = AutoProcessor.from_pretrained(seamless_model_id)
    # NOTE(review): AutoModelForSeq2SeqLM is expected to resolve to the
    # text-to-text head of SeamlessM4T v2 — confirm against the installed
    # transformers version.
    models['seamless'] = AutoModelForSeq2SeqLM.from_pretrained(seamless_model_id)
    print("βœ… SeamlessM4T model loaded successfully!")
except Exception as e:
    # Keep the app usable even if this backend fails to download or load.
    print(f"❌ Failed to load SeamlessM4T model: {e}")
    models['seamless'] = None
    processors['seamless'] = None

# Load NLLB model for other languages
try:
    print("πŸ“₯ Loading NLLB model...")
    nllb_model_id = "facebook/nllb-200-distilled-600M"
    tokenizers['nllb'] = AutoTokenizer.from_pretrained(nllb_model_id)
    models['nllb'] = AutoModelForSeq2SeqLM.from_pretrained(nllb_model_id)
    print("βœ… NLLB model loaded successfully!")
except Exception as e:
    # Same graceful degradation as above.
    print(f"❌ Failed to load NLLB model: {e}")
    models['nllb'] = None
    tokenizers['nllb'] = None
70
def translate_with_seamless(text, source_lang_code):
    """Translate ``text`` into English using the SeamlessM4T model.

    Args:
        text: Source-language text to translate.
        source_lang_code: SeamlessM4T 3-letter language code (e.g. "amh").

    Returns:
        The English translation, or a human-readable error string if the
        model is unavailable or generation fails (never raises).
    """
    try:
        if models['seamless'] is None or processors['seamless'] is None:
            return "SeamlessM4T model not available"

        # Preprocess text; src_lang tells the encoder the input language.
        inputs = processors['seamless'](text=text, src_lang=source_lang_code, return_tensors="pt")

        # Generate translation.
        # BUG FIX: the previous code computed
        # forced_bos_token_id = tokenizer.convert_tokens_to_ids("<|eng|>"),
        # but the SeamlessM4T tokenizer has no "<|eng|>" token, so that call
        # silently returned the <unk> id and the model was never steered to
        # English. The documented API is to pass tgt_lang to generate().
        with torch.no_grad():
            generated_tokens = models['seamless'].generate(
                **inputs,
                tgt_lang="eng",
                max_length=256
            )

        # Decode and return the first (only) sequence.
        translation = processors['seamless'].batch_decode(generated_tokens, skip_special_tokens=True)[0]
        return translation

    except Exception as e:
        # Never propagate — the UI expects a string either way.
        print(f"SeamlessM4T translation error: {e}")
        return f"Translation failed: {str(e)[:200]}"
97
+
98
def translate_with_nllb(text, source_lang_code):
    """Translate ``text`` into English using the NLLB-200 model.

    Args:
        text: Source-language text to translate.
        source_lang_code: FLORES-style NLLB code (e.g. "tir_Ethi").

    Returns:
        The English translation, or a human-readable error string if the
        model is unavailable or generation fails (never raises).
    """
    try:
        if models['nllb'] is None or tokenizers['nllb'] is None:
            return "NLLB model not available"

        tokenizer = tokenizers['nllb']

        # BUG FIX: source_lang_code was previously ignored, so the tokenizer
        # fell back to its default src_lang ("eng_Latn") and the model was
        # told every input was English. Setting src_lang before tokenizing
        # prepends the correct source-language token.
        tokenizer.src_lang = source_lang_code

        # Tokenize input
        inputs = tokenizer(text, return_tensors="pt")

        # Force the decoder to start with the English language token.
        forced_bos_token_id = tokenizer.convert_tokens_to_ids("eng_Latn")

        # Generate translation using beam search for better quality.
        with torch.no_grad():
            generated_tokens = models['nllb'].generate(
                **inputs,
                forced_bos_token_id=forced_bos_token_id,
                max_length=256,
                num_beams=5,
                early_stopping=True
            )

        # Decode the first (only) sequence.
        translation = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
        return translation

    except Exception as e:
        # Never propagate — the UI expects a string either way.
        print(f"NLLB translation error: {e}")
        return f"Translation failed: {str(e)[:200]}"
127
+
128
def translate_text(text, source_language):
    """Dispatch a translation request to the backend configured for the language.

    Returns either the English translation or a human-readable status/error
    string; never raises.
    """
    # Guard: nothing to translate.
    if not text.strip():
        return "Please enter text to translate"

    # Guard: language not in the routing table.
    config = LANGUAGE_CONFIG.get(source_language)
    if config is None:
        return f"Translation for {source_language} is not supported"

    try:
        if config["model_type"] == "seamless":
            return translate_with_seamless(text, config["seamless_code"])
        # Everything else is routed through NLLB.
        return translate_with_nllb(text, config["nllb_code"])

    except Exception as e:
        print(f"Translation error for {source_language}: {e}")
        return f"Translation failed: {str(e)[:200]}"
147
+
148
# Example texts for each language
# One sample sentence per supported language, used to pre-fill the input box
# when the matching "<lang> Example" button is clicked in the UI.
EXAMPLE_TEXTS = {
    # NOTE(review): this entry contains U+FFFD replacement characters — the
    # original Amharic text was corrupted by an encoding round-trip at some
    # point; restore it from the original source.
    "Amharic": "αˆαˆ‰αˆ αˆ°α‹ α‰ αˆαˆ‰αˆ መα‰₯α‰ΆοΏ½οΏ½οΏ½ αŠ₯ኩል αŠα‹α’",
    "Swahili": "Habari za asubuhi, leo tunajifunza teknolojia ya usemi.",
    "Somali": "Maanta waa maalin qurux badan oo qoraxdu si wanaagsan u iftiimayso.",
    "Afan Oromo": "Akkam bulte, har'a technology dubbachuu baranna.",
    "Tigrinya": "αˆ˜α‹“αˆα‰² αˆ°αŠ“α‹­α‘ ሎሚ α‰΄αŠ­αŠ–αˆŽαŒ‚ α‹˜αˆ¨α‰£ αŠ•αˆαˆαŒ₯ፒ",
    "Chichewa": "Alipo wina aliyense ali ndi ufulu wachibadwidwe."
}
157
+
158
# Test the models on startup
def test_models():
    """Run one quick translation per supported language and log the result."""
    print("πŸ§ͺ Testing translation models...")

    # Short sample per language; dict preserves insertion order, so the
    # languages are exercised in the same sequence as before.
    samples = {
        "Swahili": "Habari za asubuhi",
        "Somali": "Maanta waa maalin fiican",
        "Amharic": "αˆ°αˆ‹αˆ",
        "Afan Oromo": "Akkam jirta",
        "Tigrinya": "αˆ°αˆ‹αˆ",
        "Chichewa": "Moni",
    }

    for lang, text in samples.items():
        try:
            result = translate_text(text, lang)
            print(f"βœ… {lang} test: '{text}' β†’ '{result}'")
        except Exception as e:
            print(f"❌ {lang} test failed: {e}")

# Run tests on startup
test_models()
180
+
181
# Create Gradio interface
# Two-column layout: input + language picker + example buttons on the left,
# translation output on the right, followed by a model-status panel.
with gr.Blocks(
    theme=gr.themes.Soft(
        primary_hue="blue",
        secondary_hue="green"
    ),
    title="🌍 GihonTech - Local Language to English Translation"
) as demo:

    gr.Markdown("# 🌍 GihonTech Local Language to English Translation")
    gr.Markdown("Translate text from African languages to English using advanced AI models")

    with gr.Row():
        with gr.Column(scale=1):
            # Left column: source text entry.
            text_input = gr.Textbox(
                label="Source Text",
                placeholder="Enter text to translate...",
                lines=4,
                show_copy_button=True
            )

            # Language choices come straight from the routing table, so adding
            # a language to LANGUAGE_CONFIG automatically exposes it here.
            language_select = gr.Dropdown(
                choices=list(LANGUAGE_CONFIG.keys()),
                value="Amharic",
                label="Source Language",
                info="Select the language of your text"
            )

            # Example buttons in two rows
            with gr.Row():
                for lang in ["Amharic", "Swahili", "Somali"]:
                    gr.Button(
                        f"{lang} Example",
                        size="sm"
                    ).click(
                        # Default-arg binding (l=lang) freezes the loop
                        # variable so each button inserts its own example,
                        # avoiding the late-binding closure pitfall.
                        lambda l=lang: EXAMPLE_TEXTS[l],
                        outputs=text_input
                    )

            with gr.Row():
                for lang in ["Afan Oromo", "Tigrinya", "Chichewa"]:
                    gr.Button(
                        f"{lang} Example",
                        size="sm"
                    ).click(
                        lambda l=lang: EXAMPLE_TEXTS[l],
                        outputs=text_input
                    )

            translate_btn = gr.Button(
                "🎯 Translate to English",
                variant="primary",
                size="lg"
            )

        with gr.Column(scale=1):
            # Right column: read-only translation result.
            translation_output = gr.Textbox(
                label="English Translation",
                placeholder="Your translated text will appear here...",
                lines=5,
                show_copy_button=True
            )

    # Connect the translate button
    translate_btn.click(
        fn=translate_text,
        inputs=[text_input, language_select],
        outputs=translation_output
    )

    # Also allow pressing Enter to translate
    text_input.submit(
        fn=translate_text,
        inputs=[text_input, language_select],
        outputs=translation_output
    )

    # Model status and information
    with gr.Row():
        with gr.Column():
            gr.Markdown("### πŸ”§ Model Information")

            # Create status display
            # Snapshot of load results taken at import time (not refreshed).
            seamless_status = "βœ… Loaded" if models.get('seamless') else "❌ Failed"
            nllb_status = "βœ… Loaded" if models.get('nllb') else "❌ Failed"

            status_text = f"SeamlessM4T: {seamless_status} | NLLB: {nllb_status}"
            gr.Textbox(
                value=status_text,
                label="Model Status",
                interactive=False
            )

            # Create model info: derive the language lists from the routing
            # table so this panel never drifts out of sync with it.
            seamless_langs = [lang for lang, config in LANGUAGE_CONFIG.items() if config["model_type"] == "seamless"]
            nllb_langs = [lang for lang, config in LANGUAGE_CONFIG.items() if config["model_type"] == "nllb"]

            gr.Markdown(f"""
            **Advanced Models (SeamlessM4T):** {', '.join(seamless_langs)}
            **Standard Models (NLLB-200):** {', '.join(nllb_langs)}

            **Features:**
            - High-quality translations for African languages
            - Support for text input and copy-paste functionality
            - Fast and accurate results using beam search
            - Proper tokenization for each language family
            """)

    # Add CSS for better styling
    gr.HTML("""
    <style>
    .gradio-container {
        max-width: 1200px !important;
    }
    .textbox textarea {
        min-height: 120px;
    }
    </style>
    """)
300
+
301
if __name__ == "__main__":
    # Bind to all interfaces on the standard Gradio port; keep the app
    # local-only (no public share link) but surface errors in the UI.
    launch_options = dict(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
    )
    demo.launch(**launch_options)
308
+
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Minimal requirements.txt
2
+ torch>=2.0.1
3
+ transformers>=4.35.0
4
+ gradio>=4.0.0
5
+ soundfile>=0.12.0
6
+ resampy>=0.4.0
7
+ numpy>=1.24.0