Minte committed on
Commit
a5a20e8
·
1 Parent(s): 2f77ad3

resource issue

Browse files
Files changed (1) hide show
  1. app.py +73 -111
app.py CHANGED
@@ -1,23 +1,23 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoProcessor
4
 
5
- # Language configuration with proper model handling
6
  LANGUAGE_CONFIG = {
7
  "Amharic": {
8
  "code": "amh",
9
- "model_type": "seamless",
10
- "seamless_code": "amh"
11
  },
12
  "Swahili": {
13
  "code": "swh",
14
- "model_type": "seamless",
15
- "seamless_code": "swh"
16
  },
17
  "Somali": {
18
  "code": "som",
19
- "model_type": "seamless",
20
- "seamless_code": "som"
21
  },
22
  "Afan Oromo": {
23
  "code": "gaz",
@@ -37,110 +37,74 @@ LANGUAGE_CONFIG = {
37
  }
38
 
39
  # Model instances
40
- models = {}
41
- tokenizers = {}
42
- processors = {}
43
 
44
- print("🚀 Initializing translation models...")
45
 
46
- # Load SeamlessM4T model for Amharic, Swahili, Somali
47
  try:
48
- print("📥 Loading SeamlessM4T model...")
49
- seamless_model_id = "facebook/seamless-m4t-v2-large"
50
- processors['seamless'] = AutoProcessor.from_pretrained(seamless_model_id)
51
- models['seamless'] = AutoModelForSeq2SeqLM.from_pretrained(seamless_model_id)
52
- print("✅ SeamlessM4T model loaded successfully!")
53
- except Exception as e:
54
- print(f"❌ Failed to load SeamlessM4T model: {e}")
55
- models['seamless'] = None
56
- processors['seamless'] = None
57
-
58
- # Load NLLB model for other languages
59
- try:
60
- print("📥 Loading NLLB model...")
61
- nllb_model_id = "facebook/nllb-200-distilled-600M"
62
- tokenizers['nllb'] = AutoTokenizer.from_pretrained(nllb_model_id)
63
- models['nllb'] = AutoModelForSeq2SeqLM.from_pretrained(nllb_model_id)
64
  print("✅ NLLB model loaded successfully!")
65
  except Exception as e:
66
- print(f"❌ Failed to load NLLB model: {e}")
67
- models['nllb'] = None
68
- tokenizers['nllb'] = None
69
-
70
- def translate_with_seamless(text, source_lang_code):
71
- """Translate text using SeamlessM4T model"""
72
  try:
73
- if models['seamless'] is None or processors['seamless'] is None:
74
- return "SeamlessM4T model not available"
75
-
76
- # Preprocess text
77
- inputs = processors['seamless'](text=text, src_lang=source_lang_code, return_tensors="pt")
78
-
79
- # Get BOS token for target language (English)
80
- forced_bos_token_id = processors['seamless'].tokenizer.convert_tokens_to_ids("<|eng|>")
81
-
82
- # Generate translation
83
- with torch.no_grad():
84
- generated_tokens = models['seamless'].generate(
85
- **inputs,
86
- forced_bos_token_id=forced_bos_token_id,
87
- max_length=256
88
- )
89
-
90
- # Decode and return
91
- translation = processors['seamless'].batch_decode(generated_tokens, skip_special_tokens=True)[0]
92
- return translation
93
-
94
- except Exception as e:
95
- print(f"SeamlessM4T translation error: {e}")
96
- return f"Translation failed: {str(e)[:200]}"
97
 
98
- def translate_with_nllb(text, source_lang_code):
99
- """Translate text using NLLB model"""
 
 
 
 
 
 
 
 
 
 
 
100
  try:
101
- if models['nllb'] is None or tokenizers['nllb'] is None:
102
- return "NLLB model not available"
103
-
104
  # Tokenize input
105
- inputs = tokenizers['nllb'](text, return_tensors="pt")
 
 
 
106
 
107
  # Define target language (English)
108
- forced_bos_token_id = tokenizers['nllb'].convert_tokens_to_ids("eng_Latn")
109
 
110
- # Generate translation using beam search for better quality
111
  with torch.no_grad():
112
- generated_tokens = models['nllb'].generate(
113
  **inputs,
114
  forced_bos_token_id=forced_bos_token_id,
115
  max_length=256,
116
- num_beams=5,
117
- early_stopping=True
 
118
  )
119
 
120
  # Decode
121
- translation = tokenizers['nllb'].batch_decode(generated_tokens, skip_special_tokens=True)[0]
122
  return translation
123
 
124
- except Exception as e:
125
- print(f"NLLB translation error: {e}")
126
- return f"Translation failed: {str(e)[:200]}"
127
-
128
- def translate_text(text, source_language):
129
- """Main translation function"""
130
- if not text.strip():
131
- return "Please enter text to translate"
132
-
133
- if source_language not in LANGUAGE_CONFIG:
134
- return f"Translation for {source_language} is not supported"
135
-
136
- config = LANGUAGE_CONFIG[source_language]
137
-
138
- try:
139
- if config["model_type"] == "seamless":
140
- return translate_with_seamless(text, config["seamless_code"])
141
- else: # nllb
142
- return translate_with_nllb(text, config["nllb_code"])
143
-
144
  except Exception as e:
145
  print(f"Translation error for {source_language}: {e}")
146
  return f"Translation failed: {str(e)[:200]}"
@@ -155,17 +119,18 @@ EXAMPLE_TEXTS = {
155
  "Chichewa": "Alipo wina aliyense ali ndi ufulu wachibadwidwe."
156
  }
157
 
158
- # Test the models on startup
159
- def test_models():
160
- print("🧪 Testing translation models...")
 
 
 
 
161
 
162
  test_cases = [
163
  ("Swahili", "Habari za asubuhi"),
164
  ("Somali", "Maanta waa maalin fiican"),
165
  ("Amharic", "ሰላም"),
166
- ("Afan Oromo", "Akkam jirta"),
167
- ("Tigrinya", "ሰላም"),
168
- ("Chichewa", "Moni")
169
  ]
170
 
171
  for lang, text in test_cases:
@@ -175,8 +140,9 @@ def test_models():
175
  except Exception as e:
176
  print(f"❌ {lang} test failed: {e}")
177
 
178
- # Run tests on startup
179
- test_models()
 
180
 
181
  # Create Gradio interface
182
  with gr.Blocks(
@@ -261,10 +227,9 @@ with gr.Blocks(
261
  gr.Markdown("### 🔧 Model Information")
262
 
263
  # Create status display
264
- seamless_status = "✅ Loaded" if models.get('seamless') else "❌ Failed"
265
- nllb_status = "✅ Loaded" if models.get('nllb') else "❌ Failed"
266
 
267
- status_text = f"SeamlessM4T: {seamless_status} | NLLB: {nllb_status}"
268
  gr.Textbox(
269
  value=status_text,
270
  label="Model Status",
@@ -272,18 +237,16 @@ with gr.Blocks(
272
  )
273
 
274
  # Create model info
275
- seamless_langs = [lang for lang, config in LANGUAGE_CONFIG.items() if config["model_type"] == "seamless"]
276
- nllb_langs = [lang for lang, config in LANGUAGE_CONFIG.items() if config["model_type"] == "nllb"]
277
-
278
  gr.Markdown(f"""
279
- **Advanced Models (SeamlessM4T):** {', '.join(seamless_langs)}
280
- **Standard Models (NLLB-200):** {', '.join(nllb_langs)}
 
281
 
282
  **Features:**
283
  - High-quality translations for African languages
284
  - Support for text input and copy-paste functionality
285
- - Fast and accurate results using beam search
286
- - Proper tokenization for each language family
287
  """)
288
 
289
  # Add CSS for better styling
@@ -304,5 +267,4 @@ if __name__ == "__main__":
304
  server_port=7860,
305
  share=False,
306
  show_error=True
307
- )
308
-
 
1
  import gradio as gr
2
  import torch
3
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
 
5
+ # Language configuration with optimized model selection
6
  LANGUAGE_CONFIG = {
7
  "Amharic": {
8
  "code": "amh",
9
+ "model_type": "nllb",
10
+ "nllb_code": "amh_Ethi"
11
  },
12
  "Swahili": {
13
  "code": "swh",
14
+ "model_type": "nllb",
15
+ "nllb_code": "swh_Latn"
16
  },
17
  "Somali": {
18
  "code": "som",
19
+ "model_type": "nllb",
20
+ "nllb_code": "som_Latn"
21
  },
22
  "Afan Oromo": {
23
  "code": "gaz",
 
37
  }
38
 
39
  # Model instances
40
+ model = None
41
+ tokenizer = None
 
42
 
43
+ print("🚀 Initializing translation model for Hugging Face Spaces...")
44
 
45
+ # Load a smaller, more efficient NLLB model
46
  try:
47
+ print("📥 Loading NLLB-200-1.3B model...")
48
+ model_id = "facebook/nllb-200-1.3B"
49
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
50
+ model = AutoModelForSeq2SeqLM.from_pretrained(
51
+ model_id,
52
+ torch_dtype=torch.float16, # Use half precision to save memory
53
+ device_map="auto"
54
+ )
 
 
 
 
 
 
 
 
55
  print("✅ NLLB model loaded successfully!")
56
  except Exception as e:
57
+ print(f"❌ Failed to load NLLB-200-1.3B: {e}")
 
 
 
 
 
58
  try:
59
+ # Fallback to even smaller model
60
+ print("🔄 Trying smaller model: NLLB-200-distilled-600M...")
61
+ model_id = "facebook/nllb-200-distilled-600M"
62
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
63
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
64
+ print("✅ NLLB distilled model loaded successfully!")
65
+ except Exception as e2:
66
+ print(f"❌ All models failed to load: {e2}")
67
+ model = None
68
+ tokenizer = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
+ def translate_text(text, source_language):
71
+ """Main translation function"""
72
+ if not text.strip():
73
+ return "Please enter text to translate"
74
+
75
+ if source_language not in LANGUAGE_CONFIG:
76
+ return f"Translation for {source_language} is not supported"
77
+
78
+ if model is None or tokenizer is None:
79
+ return "Translation model is not available. Please try again later."
80
+
81
+ config = LANGUAGE_CONFIG[source_language]
82
+
83
  try:
 
 
 
84
  # Tokenize input
85
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
86
+
87
+ # Move to same device as model
88
+ inputs = {k: v.to(model.device) for k, v in inputs.items()}
89
 
90
  # Define target language (English)
91
+ forced_bos_token_id = tokenizer.convert_tokens_to_ids("eng_Latn")
92
 
93
+ # Generate translation with optimized settings for HF Spaces
94
  with torch.no_grad():
95
+ generated_tokens = model.generate(
96
  **inputs,
97
  forced_bos_token_id=forced_bos_token_id,
98
  max_length=256,
99
+ num_beams=3, # Reduced for faster inference
100
+ early_stopping=True,
101
+ no_repeat_ngram_size=2
102
  )
103
 
104
  # Decode
105
+ translation = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
106
  return translation
107
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  except Exception as e:
109
  print(f"Translation error for {source_language}: {e}")
110
  return f"Translation failed: {str(e)[:200]}"
 
119
  "Chichewa": "Alipo wina aliyense ali ndi ufulu wachibadwidwe."
120
  }
121
 
122
+ # Test the model on startup
123
+ def test_model():
124
+ if model is None:
125
+ print("❌ No model available for testing")
126
+ return
127
+
128
+ print("🧪 Testing translation model...")
129
 
130
  test_cases = [
131
  ("Swahili", "Habari za asubuhi"),
132
  ("Somali", "Maanta waa maalin fiican"),
133
  ("Amharic", "ሰላም"),
 
 
 
134
  ]
135
 
136
  for lang, text in test_cases:
 
140
  except Exception as e:
141
  print(f"❌ {lang} test failed: {e}")
142
 
143
+ # Run test if model is loaded
144
+ if model is not None:
145
+ test_model()
146
 
147
  # Create Gradio interface
148
  with gr.Blocks(
 
227
  gr.Markdown("### 🔧 Model Information")
228
 
229
  # Create status display
230
+ model_status = "✅ Loaded" if model is not None else "❌ Failed to load"
 
231
 
232
+ status_text = f"NLLB-200 Model: {model_status}"
233
  gr.Textbox(
234
  value=status_text,
235
  label="Model Status",
 
237
  )
238
 
239
  # Create model info
 
 
 
240
  gr.Markdown(f"""
241
+ **Supported Languages:** {', '.join(LANGUAGE_CONFIG.keys())}
242
+
243
+ **Model:** NLLB-200 (No Language Left Behind)
244
 
245
  **Features:**
246
  - High-quality translations for African languages
247
  - Support for text input and copy-paste functionality
248
+ - Fast and accurate results
249
+ - Optimized for Hugging Face Spaces
250
  """)
251
 
252
  # Add CSS for better styling
 
267
  server_port=7860,
268
  share=False,
269
  show_error=True
270
+ )