Minte committed on
Commit
f525548
·
1 Parent(s): 570f689

Refactor Swahili and Somali model configurations and update loading logic

Browse files
Files changed (2) hide show
  1. app.py +75 -66
  2. requirements.txt +2 -1
app.py CHANGED
@@ -11,13 +11,13 @@ LANGUAGE_CONFIG = {
11
  },
12
  "Swahili": {
13
  "code": "swh",
14
- "model_type": "swahili_mms",
15
- "swahili_code": "swh"
16
  },
17
  "Somali": {
18
  "code": "som",
19
- "model_type": "somali_m2m",
20
- "somali_code": "so"
21
  },
22
  "Afan Oromo": {
23
  "code": "gaz",
@@ -40,35 +40,29 @@ LANGUAGE_CONFIG = {
40
  models = {}
41
  tokenizers = {}
42
 
43
- print("🚀 Initializing specialized translation models...")
44
 
45
- # Load Swahili MMS model
46
  try:
47
- print("📥 Loading Swahili MMS model...")
48
- swahili_model_id = "Benjamin-png/swahili-mms-tts-finetuned"
49
- # Note: This appears to be a TTS model, so we'll need to check if it supports translation
50
- # If not, we'll fall back to another approach
51
- try:
52
- tokenizers['swahili'] = AutoTokenizer.from_pretrained(swahili_model_id)
53
- models['swahili'] = AutoModelForSeq2SeqLM.from_pretrained(swahili_model_id)
54
- print("✅ Swahili MMS model loaded successfully!")
55
- except:
56
- print("⚠️ Swahili MMS model might be TTS-only, will use fallback")
57
- models['swahili'] = None
58
  except Exception as e:
59
- print(f"❌ Failed to load Swahili MMS model: {e}")
60
- models['swahili'] = None
61
 
62
- # Load Somali M2M100 model
63
  try:
64
- print("📥 Loading Somali M2M100 model...")
65
- somali_model_id = "Ammad1Ali/m2m100_418M-2.0"
66
- tokenizers['somali'] = AutoTokenizer.from_pretrained(somali_model_id)
67
- models['somali'] = M2M100ForConditionalGeneration.from_pretrained(somali_model_id)
68
- print("✅ Somali M2M100 model loaded successfully!")
69
  except Exception as e:
70
- print(f"❌ Failed to load Somali M2M100 model: {e}")
71
- models['somali'] = None
72
 
73
  # Load NLLB model for other languages
74
  try:
@@ -81,56 +75,71 @@ except Exception as e:
81
  print(f"❌ Failed to load NLLB model: {e}")
82
  models['nllb'] = None
83
 
84
- def translate_with_swahili_mms(text):
85
- """Translate Swahili text using specialized model"""
86
  try:
87
- if models.get('swahili') is None:
88
  return "Swahili translation model not available"
89
 
90
- # For MMS models, we need to check the specific approach
91
- # Since this might be a TTS model, we'll use a fallback to NLLB
92
- if models['nllb'] is not None:
93
- return translate_with_nllb(text, "swh_Latn")
94
- else:
95
- return "Translation service temporarily unavailable"
96
-
 
 
 
 
 
 
 
 
 
97
  except Exception as e:
98
- print(f"Swahili translation error: {e}")
99
- if models['nllb'] is not None:
 
 
 
 
100
  return translate_with_nllb(text, "swh_Latn")
101
  return f"Translation failed: {str(e)[:200]}"
102
 
103
- def translate_with_somali_m2m(text):
104
- """Translate Somali text using M2M100 model"""
105
  try:
106
- if models.get('somali') is None or tokenizers.get('somali') is None:
107
- return "Somali translation model not available"
108
 
109
  # Set source language
110
- tokenizers['somali'].src_lang = "so"
111
 
112
  # Tokenize input
113
- inputs = tokenizers['somali'](text, return_tensors="pt", truncation=True, max_length=512)
114
 
115
  # Generate translation to English
116
  with torch.no_grad():
117
- generated_tokens = models['somali'].generate(
118
  **inputs,
119
- forced_bos_token_id=tokenizers['somali'].get_lang_id("en"),
120
  max_length=256,
121
- num_beams=5,
122
  early_stopping=True
123
  )
124
 
125
  # Decode
126
- translation = tokenizers['somali'].batch_decode(generated_tokens, skip_special_tokens=True)[0]
127
  return translation
128
 
129
  except Exception as e:
130
- print(f"Somali M2M100 translation error: {e}")
131
  # Fallback to NLLB if available
132
- if models['nllb'] is not None:
133
- return translate_with_nllb(text, "som_Latn")
 
 
134
  return f"Translation failed: {str(e)[:200]}"
135
 
136
  def translate_with_nllb(text, source_lang_code):
@@ -151,7 +160,7 @@ def translate_with_nllb(text, source_lang_code):
151
  **inputs,
152
  forced_bos_token_id=forced_bos_token_id,
153
  max_length=256,
154
- num_beams=5,
155
  early_stopping=True
156
  )
157
 
@@ -174,10 +183,10 @@ def translate_text(text, source_language):
174
  config = LANGUAGE_CONFIG[source_language]
175
 
176
  try:
177
- if config["model_type"] == "swahili_mms":
178
- return translate_with_swahili_mms(text)
179
- elif config["model_type"] == "somali_m2m":
180
- return translate_with_somali_m2m(text)
181
  else: # nllb
182
  return translate_with_nllb(text, config["nllb_code"])
183
 
@@ -301,11 +310,11 @@ with gr.Blocks(
301
  gr.Markdown("### 🔧 Model Information")
302
 
303
  # Create status display
304
- swahili_status = "✅ Loaded" if models.get('swahili') else "❌ Failed"
305
- somali_status = "✅ Loaded" if models.get('somali') else "❌ Failed"
306
  nllb_status = "✅ Loaded" if models.get('nllb') else "❌ Failed"
307
 
308
- status_text = f"Swahili MMS: {swahili_status} | Somali M2M100: {somali_status} | NLLB: {nllb_status}"
309
  gr.Textbox(
310
  value=status_text,
311
  label="Model Status",
@@ -315,15 +324,15 @@ with gr.Blocks(
315
  # Create model info
316
  gr.Markdown(f"""
317
  **Specialized Models:**
318
- - **Swahili:** Benjamin-png/swahili-mms-tts-finetuned
319
- - **Somali:** Ammad1Ali/m2m100_418M-2.0
320
  - **Other Languages:** Facebook NLLB-200
321
 
322
  **Features:**
323
- - High-quality specialized models for Swahili and Somali
324
- - NLLB-200 for other supported languages
325
- - Fast and accurate translations
326
- - Automatic fallback to ensure service availability
327
  """)
328
 
329
  # Add CSS for better styling
 
11
  },
12
  "Swahili": {
13
  "code": "swh",
14
+ "model_type": "helsinki_swahili",
15
+ "helsinki_code": "swc"
16
  },
17
  "Somali": {
18
  "code": "som",
19
+ "model_type": "m2m",
20
+ "m2m_code": "so"
21
  },
22
  "Afan Oromo": {
23
  "code": "gaz",
 
40
  models = {}
41
  tokenizers = {}
42
 
43
+ print("🚀 Initializing translation models...")
44
 
45
+ # Load Helsinki-NLP Swahili model
46
  try:
47
+ print("📥 Loading Helsinki-NLP Swahili model...")
48
+ swahili_model_id = "Helsinki-NLP/opus-mt-swc-en"
49
+ tokenizers['helsinki_swahili'] = AutoTokenizer.from_pretrained(swahili_model_id)
50
+ models['helsinki_swahili'] = AutoModelForSeq2SeqLM.from_pretrained(swahili_model_id)
51
+ print("✅ Helsinki-NLP Swahili model loaded successfully!")
 
 
 
 
 
 
52
  except Exception as e:
53
+ print(f"❌ Failed to load Helsinki-NLP Swahili model: {e}")
54
+ models['helsinki_swahili'] = None
55
 
56
+ # Load M2M100 model for Somali
57
  try:
58
+ print("📥 Loading M2M100 model for Somali...")
59
+ m2m_model_id = "facebook/m2m100_418M"
60
+ tokenizers['m2m'] = AutoTokenizer.from_pretrained(m2m_model_id)
61
+ models['m2m'] = M2M100ForConditionalGeneration.from_pretrained(m2m_model_id)
62
+ print("✅ M2M100 model loaded successfully!")
63
  except Exception as e:
64
+ print(f"❌ Failed to load M2M100 model: {e}")
65
+ models['m2m'] = None
66
 
67
  # Load NLLB model for other languages
68
  try:
 
75
  print(f"❌ Failed to load NLLB model: {e}")
76
  models['nllb'] = None
77
 
78
+ def translate_with_helsinki_swahili(text):
79
+ """Translate Swahili text using Helsinki-NLP model"""
80
  try:
81
+ if models.get('helsinki_swahili') is None or tokenizers.get('helsinki_swahili') is None:
82
  return "Swahili translation model not available"
83
 
84
+ # Tokenize input
85
+ inputs = tokenizers['helsinki_swahili'](text, return_tensors="pt", truncation=True, max_length=512)
86
+
87
+ # Generate translation
88
+ with torch.no_grad():
89
+ generated_tokens = models['helsinki_swahili'].generate(
90
+ **inputs,
91
+ max_length=256,
92
+ num_beams=5,
93
+ early_stopping=True
94
+ )
95
+
96
+ # Decode
97
+ translation = tokenizers['helsinki_swahili'].batch_decode(generated_tokens, skip_special_tokens=True)[0]
98
+ return translation
99
+
100
  except Exception as e:
101
+ print(f"Helsinki Swahili translation error: {e}")
102
+ # Fallback to M2M100 if available
103
+ if models.get('m2m') is not None:
104
+ return translate_with_m2m(text, "sw")
105
+ # Fallback to NLLB if available
106
+ elif models.get('nllb') is not None:
107
  return translate_with_nllb(text, "swh_Latn")
108
  return f"Translation failed: {str(e)[:200]}"
109
 
110
+ def translate_with_m2m(text, source_lang_code):
111
+ """Translate text using M2M100 model"""
112
  try:
113
+ if models.get('m2m') is None or tokenizers.get('m2m') is None:
114
+ return "M2M100 model not available"
115
 
116
  # Set source language
117
+ tokenizers['m2m'].src_lang = source_lang_code
118
 
119
  # Tokenize input
120
+ inputs = tokenizers['m2m'](text, return_tensors="pt", truncation=True, max_length=512)
121
 
122
  # Generate translation to English
123
  with torch.no_grad():
124
+ generated_tokens = models['m2m'].generate(
125
  **inputs,
126
+ forced_bos_token_id=tokenizers['m2m'].get_lang_id("en"),
127
  max_length=256,
128
+ num_beams=3,
129
  early_stopping=True
130
  )
131
 
132
  # Decode
133
+ translation = tokenizers['m2m'].batch_decode(generated_tokens, skip_special_tokens=True)[0]
134
  return translation
135
 
136
  except Exception as e:
137
+ print(f"M2M100 translation error: {e}")
138
  # Fallback to NLLB if available
139
+ if models.get('nllb') is not None:
140
+ lang_map = {"so": "som_Latn", "sw": "swh_Latn"}
141
+ nllb_code = lang_map.get(source_lang_code, "eng_Latn")
142
+ return translate_with_nllb(text, nllb_code)
143
  return f"Translation failed: {str(e)[:200]}"
144
 
145
  def translate_with_nllb(text, source_lang_code):
 
160
  **inputs,
161
  forced_bos_token_id=forced_bos_token_id,
162
  max_length=256,
163
+ num_beams=3,
164
  early_stopping=True
165
  )
166
 
 
183
  config = LANGUAGE_CONFIG[source_language]
184
 
185
  try:
186
+ if config["model_type"] == "helsinki_swahili":
187
+ return translate_with_helsinki_swahili(text)
188
+ elif config["model_type"] == "m2m":
189
+ return translate_with_m2m(text, config["m2m_code"])
190
  else: # nllb
191
  return translate_with_nllb(text, config["nllb_code"])
192
 
 
310
  gr.Markdown("### 🔧 Model Information")
311
 
312
  # Create status display
313
+ helsinki_status = "✅ Loaded" if models.get('helsinki_swahili') else "❌ Failed"
314
+ m2m_status = "✅ Loaded" if models.get('m2m') else "❌ Failed"
315
  nllb_status = "✅ Loaded" if models.get('nllb') else "❌ Failed"
316
 
317
+ status_text = f"Helsinki Swahili: {helsinki_status} | M2M100: {m2m_status} | NLLB: {nllb_status}"
318
  gr.Textbox(
319
  value=status_text,
320
  label="Model Status",
 
324
  # Create model info
325
  gr.Markdown(f"""
326
  **Specialized Models:**
327
+ - **Swahili:** Helsinki-NLP/opus-mt-swc-en (Specialized Swahili→English)
328
+ - **Somali:** Facebook M2M100
329
  - **Other Languages:** Facebook NLLB-200
330
 
331
  **Features:**
332
+ - High-quality specialized model for Swahili translation
333
+ - Optimized models for each language family
334
+ - Cross-model fallback for reliability
335
+ - Fast and accurate results
336
  """)
337
 
338
  # Add CSS for better styling
requirements.txt CHANGED
@@ -6,4 +6,5 @@ soundfile>=0.12.0
6
  resampy>=0.4.0
7
  numpy>=1.24.0
8
  accelerate>=0.20.0
9
- sentencepiece>=0.1.99
 
 
6
  resampy>=0.4.0
7
  numpy>=1.24.0
8
  accelerate>=0.20.0
9
+ sentencepiece>=0.1.99
10
+ protobuf>=3.20.0