File size: 14,006 Bytes
6d28d4b
 
 
 
 
 
 
 
 
 
943a8da
6d28d4b
755fa07
 
 
 
 
 
 
 
 
6d28d4b
 
 
 
 
 
 
943a8da
6d28d4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
755fa07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6d28d4b
 
 
 
 
 
755fa07
 
 
6d28d4b
755fa07
6d28d4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
943a8da
 
6d28d4b
 
943a8da
6d28d4b
755fa07
6d28d4b
 
 
755fa07
6d28d4b
 
 
943a8da
6d28d4b
 
 
 
 
 
 
755fa07
 
6d28d4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
943a8da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6d28d4b
 
 
 
 
 
 
 
943a8da
6d28d4b
 
 
 
 
943a8da
6d28d4b
943a8da
6d28d4b
 
943a8da
 
 
 
 
 
 
6d28d4b
943a8da
 
6d28d4b
943a8da
 
 
 
 
 
 
 
 
 
 
6d28d4b
943a8da
 
 
 
6d28d4b
943a8da
 
 
6d28d4b
 
 
 
 
 
 
 
 
 
943a8da
6d28d4b
 
 
 
 
 
 
 
 
 
 
943a8da
6d28d4b
 
755fa07
6d28d4b
 
755fa07
 
6d28d4b
 
 
943a8da
6d28d4b
 
 
 
 
 
 
 
 
 
755fa07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6d28d4b
 
 
 
 
 
 
755fa07
6d28d4b
943a8da
6d28d4b
943a8da
6d28d4b
 
 
 
 
 
 
755fa07
6d28d4b
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
# app.py
import gradio as gr
import torch
import torchaudio
from transformers import VitsModel, AutoTokenizer
import numpy as np
import io
import soundfile as sf
from datetime import datetime
import os
import tempfile

# Install uroman if not available
# NOTE: best-effort self-install at import time (useful on hosted platforms
# like HF Spaces where requirements may be missing); re-import after pip runs.
try:
    from uroman import uroman
except ImportError:
    import subprocess
    import sys
    subprocess.check_call([sys.executable, "-m", "pip", "install", "uroman"])
    from uroman import uroman

# Model configuration for each language
# Maps the UI language name to its Hugging Face MMS-TTS checkpoint id.
MODELS = {
    "Amharic": "facebook/mms-tts-amh",
    "Somali": "facebook/mms-tts-som", 
    "Swahili": "facebook/mms-tts-swh",
    "Afan Oromo": "facebook/mms-tts-orm",
    "Tigrinya": "facebook/mms-tts-tir",
    # No dedicated Chichewa checkpoint configured; Swahili is reused.
    "Chichewa": "facebook/mms-tts-swh"  # Using Swahili as fallback
}

class MMS_TTS_Service:
    """Lazily loads and caches Facebook MMS-TTS models per language.

    Models are downloaded/loaded on first request for a language and kept
    in memory for subsequent calls. Inference runs on CUDA when available,
    otherwise CPU.
    """

    def __init__(self):
        # Per-language caches, populated lazily by load_model().
        self.models = {}
        self.tokenizers = {}
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")
    
    def load_model(self, language):
        """Return (model, tokenizer) for *language*, loading and caching on first use.

        Raises:
            KeyError: if *language* is not a key of MODELS.
            Exception: re-raised from transformers if download/load fails.
        """
        if language in self.models:
            return self.models[language], self.tokenizers[language]
        
        try:
            model_name = MODELS[language]
            print(f"Loading model for {language}: {model_name}")
            
            # Load tokenizer and model
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = VitsModel.from_pretrained(model_name)
            model = model.to(self.device)
            model.eval()  # inference only; disables dropout etc.
            
            # Cache so later requests skip the (slow) load.
            self.models[language] = model
            self.tokenizers[language] = tokenizer
            
            print(f"βœ… Successfully loaded model for {language}")
            return model, tokenizer
            
        except Exception as e:
            print(f"❌ Error loading model for {language}: {e}")
            # Bare re-raise preserves the original traceback (raise e did not).
            raise
    
    def preprocess_text(self, text, language):
        """Preprocess text, romanizing Ge'ez-script input for Amharic/Tigrinya."""
        if language in ["Amharic", "Tigrinya"]:
            print(f"Romanizing {language} text...")
            try:
                # The MMS checkpoints for these languages expect Latin-script
                # input, so romanize before tokenizing.
                romanized_text = uroman(text)
                print(f"Original: {text}")
                print(f"Romanized: {romanized_text}")
                return romanized_text
            except Exception as e:
                # Best-effort: fall back to the raw text rather than failing.
                print(f"Romanization failed, using original text: {e}")
                return text
        # Other languages are passed through unchanged.
        return text
    
    def generate_speech(self, text, language, speed=1.0):
        """Synthesize *text* in *language*.

        Returns:
            ((sample_rate, waveform), None) on success, or (None, error_msg)
            on failure — callers check the second element.
        """
        try:
            # Load model if not already cached.
            model, tokenizer = self.load_model(language)
            
            # Romanize for Amharic and Tigrinya; no-op for other languages.
            processed_text = self.preprocess_text(text, language)
            
            inputs = tokenizer(processed_text, return_tensors="pt")
            input_ids = inputs["input_ids"].to(self.device)
            
            # no_grad: inference only, avoids building autograd graphs.
            with torch.no_grad():
                outputs = model(input_ids)
                waveform = outputs.waveform[0].cpu().numpy()
                sample_rate = model.config.sampling_rate
            
            if speed != 1.0:
                waveform = self.adjust_speed(waveform, sample_rate, speed)
            
            return (sample_rate, waveform), None
            
        except Exception as e:
            error_msg = f"Error generating speech: {str(e)}"
            print(error_msg)
            return None, error_msg
    
    def adjust_speed(self, waveform, sample_rate, speed_factor):
        """Time-stretch *waveform* by linear resampling (note: also shifts pitch)."""
        try:
            if speed_factor != 1.0:
                # >1.0 shortens the signal (faster), <1.0 lengthens (slower).
                new_length = int(len(waveform) / speed_factor)
                indices = np.linspace(0, len(waveform) - 1, new_length)
                waveform = np.interp(indices, np.arange(len(waveform)), waveform)
            return waveform
        except Exception as e:
            # Best-effort: a failed speed change must not kill synthesis, but
            # log it instead of swallowing silently (was a bare `except:`).
            print(f"Speed adjustment failed, returning original audio: {e}")
            return waveform
    
    def get_available_languages(self):
        """Return the language names accepted by generate_speech()."""
        return list(MODELS.keys())

# Initialize TTS service
# Single module-level instance shared by all Gradio event handlers;
# model weights are loaded lazily on first synthesis request.
tts_service = MMS_TTS_Service()

def text_to_speech(text, language, speed=1.0):
    """Validate *text* and synthesize it via the shared TTS service.

    Returns ((sample_rate, waveform), status) for gr.Audio on success,
    or (None, message) when validation or synthesis fails.
    """
    # Guard clauses: reject empty or over-long input before touching a model.
    if not text.strip():
        return None, "Please enter some text to convert to speech."
    if len(text) > 500:
        return None, "Text too long. Please keep it under 500 characters."

    print(f"Generating speech for: '{text[:50]}...' in {language}")

    audio, failure = tts_service.generate_speech(text, language, speed)
    if failure:
        return None, failure

    rate, samples = audio
    # gr.Audio with type="numpy" accepts a (sample_rate, ndarray) pair.
    return (rate, samples), "βœ… Speech generated successfully!"

def create_demo_audio(language):
    """Return a canned sample sentence for *language* (English fallback)."""
    samples = {
        "Amharic": "αˆ°αˆ‹αˆα£ α‹­αˆ… α‹¨α‹΅αˆα… αˆ›αˆ˜αŠ•αŒ« αˆžα‹΄αˆ αŠα‹α’ αŠ αˆ˜αˆ°αŒαŠ“αˆˆαˆ!",
        "Somali": "Salaam, kani waa modelka cod-sameynta.",
        "Swahili": "Halo, hii ni modeli ya kutengeneza sauti.",
        "Afan Oromo": "Akkam, kun modeli sagalee uumuudha.",
        "Tigrinya": "αˆ°αˆ‹αˆα£ αŠ₯α‹š α‹΅αˆαŒΊ α‹αŒˆα‰₯ር αˆžα‹΄αˆ αŠ₯ዩፒ α‹¨α‰αŠ•α‹¨αˆˆα‹­!",
        "Chichewa": "Moni, iyi ndi modeli yopanga mawu."
    }
    # Unknown languages fall back to a generic English sentence.
    return samples.get(language, "Hello, this is a text-to-speech model.")

# Gradio interface
# Builds the full UI inside a Blocks context; event handlers and wiring
# are defined further down in the same context.
with gr.Blocks(theme=gr.themes.Soft(), title="MMS Text-to-Speech") as demo:
    # Page header / intro copy.
    gr.Markdown(
        """
        # πŸŽ™οΈ MMS Text-to-Speech for African Languages
        Convert text to natural speech in multiple African languages using Facebook's MMS-TTS models.
        
        **Special Features for Amharic & Tigrinya:** Automatic romanization for better pronunciation
        """
    )
    
    with gr.Row():
        # Left column: all user inputs and action buttons.
        with gr.Column():
            language = gr.Dropdown(
                choices=tts_service.get_available_languages(),
                value="Amharic",
                label="Select Language",
                info="Choose the language for speech generation"
            )
            
            text_input = gr.Textbox(
                lines=3,
                placeholder="Enter text to convert to speech...",
                label="Input Text",
                info="Maximum 500 characters"
            )
            
            speed = gr.Slider(
                minimum=0.5,
                maximum=2.0,
                value=1.0,
                step=0.1,
                label="Speech Speed",
                info="Adjust the playback speed"
            )
            
            with gr.Row():
                generate_btn = gr.Button("Generate Speech", variant="primary")
                clear_btn = gr.Button("Clear")
            
            # Demo section
            gr.Markdown("### 🎯 Quick Demo")
            demo_btn = gr.Button("Load Demo Text")
            demo_output = gr.Textbox(label="Demo Text", interactive=False)
        
        # Right column: generated audio plus status feedback.
        with gr.Column():
            audio_output = gr.Audio(
                label="Generated Speech",
                type="numpy",
                interactive=False
            )
            
            status = gr.Textbox(
                label="Status",
                interactive=False,
                placeholder="Ready to generate speech..."
            )
    
    # Batch processing section (simplified)
    with gr.Accordion("πŸ“š Batch Processing (Advanced)", open=False):
        gr.Markdown("Process multiple texts at once. Each line will be converted to a separate audio file.")
        
        batch_text = gr.Textbox(
            lines=4,
            placeholder="Enter multiple texts, one per line...\nExample:\nHello\nHow are you?\nThank you",
            label="Batch Texts",
            info="Maximum 5 texts, each under 200 characters"
        )
        
        batch_btn = gr.Button("Process Batch Texts")
        batch_status = gr.Textbox(label="Batch Processing Status")
        
        # We'll use a gallery or multiple audio outputs for batch results
        # NOTE(review): gr.Gallery is an image component but receives .wav
        # file paths from process_batch — confirm it renders as intended.
        batch_results = gr.Gallery(
            label="Batch Results",
            show_label=True,
            columns=2
        )
    # Event handlers
    def generate_speech_handler(text, lang, spd):
        """Generate-button handler: validate input, then delegate to text_to_speech."""
        if text.strip():
            return text_to_speech(text, lang, spd)
        return None, "Please enter some text."
    
    def clear_all():
        """Reset every widget wired to clear_btn back to its empty state.

        Order matches clear_btn's outputs: text_input, demo_output,
        audio_output, status, batch_text, batch_results.
        """
        cleared = ("", "", None, "Cleared!", "", None)
        return cleared
    
    def load_demo(lang):
        """Demo-button handler: fetch the canned sample text for *lang*."""
        demo_text = create_demo_audio(lang)
        return demo_text
    
    def process_batch(texts, lang, spd):
        """Synthesize one .wav file per non-empty line of *texts*.

        Returns:
            (preview_audio, status_message, file_paths) — the first generated
            file path as a preview, a human-readable summary, and the list of
            all generated paths for the gallery. On validation failure the
            first element is None and the list is empty.
        """
        if not texts.strip():
            return None, "No texts provided.", []
        
        text_list = [t.strip() for t in texts.split('\n') if t.strip()]
        if len(text_list) > 5:
            return None, "Maximum 5 texts allowed for batch processing.", []
        
        # Validate every line before doing any (expensive) synthesis.
        for i, text in enumerate(text_list):
            if len(text) > 200:
                return None, f"Text {i+1} is too long (max 200 characters).", []
        
        results = []
        error_count = 0
        
        for i, text in enumerate(text_list):
            result, error = tts_service.generate_speech(text, lang, spd)
            if error:
                error_count += 1
                print(f"Error processing text {i+1}: {error}")
                continue
            sample_rate, waveform = result
            # Reserve a temp filename, then close the handle BEFORE writing:
            # on Windows a still-open NamedTemporaryFile cannot be reopened
            # by name, so sf.write inside the with-block would fail there.
            # delete=False keeps the file on disk for Gradio to serve.
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
                wav_path = f.name
            sf.write(wav_path, waveform, sample_rate)
            results.append(wav_path)
        
        if error_count > 0:
            status_msg = f"Processed {len(results)}/{len(text_list)} texts. {error_count} failed."
        else:
            status_msg = f"Successfully processed all {len(text_list)} texts!"
        
        # Return first result as preview and all as files
        preview_audio = (results[0] if results else None)
        return preview_audio, status_msg, results
    
    # Connect events
    # Each button is wired to its handler; `outputs` map positionally to the
    # handler's return tuple.
    generate_btn.click(
        fn=generate_speech_handler,
        inputs=[text_input, language, speed],
        outputs=[audio_output, status]
    )
    
    clear_btn.click(
        fn=clear_all,
        outputs=[text_input, demo_output, audio_output, status, batch_text, batch_results]
    )
    
    demo_btn.click(
        fn=load_demo,
        inputs=[language],
        outputs=[demo_output]
    )
    
    batch_btn.click(
        fn=process_batch,
        inputs=[batch_text, language, speed],
        outputs=[audio_output, batch_status, batch_results]
    )
    
    # Examples with better Amharic and Tigrinya samples
    gr.Markdown("### πŸ’‘ Example Texts")
    examples = [
        ["Amharic", "αˆαˆ‰αˆ αˆ°α‹ α‰ αˆαˆ‰αˆ መα‰₯ቢች αŠ₯ኩል αŠα‹α’ αŠ αˆ˜αˆ°αŒαŠ“αˆˆαˆ!"],
        ["Tigrinya", "αŠ©αˆ‰ ሰα‰₯ αŠ•αŠ©αˆ‰ αˆ˜αˆ°αˆ‹α‰΅ αŠ₯ኩል αŠ₯ዩፒ α‹¨α‰αŠ•α‹¨αˆˆα‹­!"],
        ["Somali", "Qof walba wuxuu leeyahay xuquuqda aadamaha."],
        ["Swahili", "Kila mtu ana haki zote za binadamu."],
        ["Afan Oromo", "Nama hundi mirga ummataa hundaa waliin dhalate."],
        ["Chichewa", "Alipo wina aliyense ali ndi ufulu wachibadwidwe."]
    ]
    
    # Clicking an example fills the inputs and runs the handler immediately
    # (cache_examples=False: no pre-computed audio is stored).
    gr.Examples(
        examples=examples,
        inputs=[language, text_input],
        outputs=[audio_output, status],
        fn=generate_speech_handler,
        cache_examples=False
    )
    
    # Language-specific information
    with gr.Accordion("ℹ️ Language-Specific Information", open=False):
        gr.Markdown("""
        ### Amharic & Tigrinya Support
        - **Automatic Romanization**: Text is automatically converted to Latin script for better pronunciation
        - **Native Script Support**: Works with Ge'ez script (αŠα‹°αˆ) characters
        - **Enhanced Accuracy**: Romanization improves model performance for these languages
        
        ### Other Languages
        - **Somali, Swahili, Afan Oromo**: Direct text processing
        - **Chichewa**: Uses Swahili model as fallback
        
        ### Technical Details
        - Uses Facebook's MMS-TTS models
        - Automatic uroman romanization for Amharic and Tigrinya
        - GPU acceleration when available
        """)
    
    # Footer
    gr.Markdown(
        """
        ---
        ### ℹ️ About
        **Powered by:** Facebook MMS-TTS Models  
        **Supported Languages:** Amharic, Somali, Swahili, Afan Oromo, Tigrinya, Chichewa  
        **Special Features:** Automatic romanization for Amharic & Tigrinya  
        **Model Type:** Text-to-Speech  
        **Max Text Length:** 500 characters (single), 200 characters (batch)
        
        Note: First request may take longer as models are downloaded.
        """
    )

if __name__ == "__main__":
    # Pre-load a model to reduce first-time latency
    print("πŸš€ Starting MMS Text-to-Speech Service...")
    print("πŸ“‹ Supported Languages:", list(MODELS.keys()))
    print("🌟 Special Romanization for: Amharic, Tigrinya")
    
    # Pre-load Amharic model for faster first response
    # (best-effort: the app still launches if the download fails).
    try:
        tts_service.load_model("Amharic")
        print("βœ… Pre-loaded Amharic model")
    except Exception as e:
        print("⚠️ Could not pre-load model:", e)
    
    # Bind on all interfaces (container-friendly) on Gradio's default port.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True
    )