# app.py
"""MMS Text-to-Speech Gradio app for African languages.

Wraps Facebook's MMS-TTS (VITS) checkpoints for Amharic, Somali, Swahili,
Afan Oromo, Tigrinya and Chichewa (Chichewa falls back to the Swahili model),
with automatic uroman romanization for the Ge'ez-script languages
(Amharic, Tigrinya).
"""

import gradio as gr
import torch
import torchaudio
from transformers import VitsModel, AutoTokenizer
import numpy as np
import io
import soundfile as sf
from datetime import datetime
import os
import tempfile

# Install uroman on the fly if it is not available.
# NOTE(review): installing a package at import time is a side effect; prefer
# declaring uroman in requirements.txt. Also confirm that
# `from uroman import uroman` yields a callable in the deployed environment --
# the PyPI "uroman" package exposes a `Uroman` class, so `uroman(text)` may
# need to be `uroman.Uroman().romanize_string(text)` instead.
try:
    from uroman import uroman
except ImportError:
    import subprocess
    import sys
    subprocess.check_call([sys.executable, "-m", "pip", "install", "uroman"])
    from uroman import uroman

# Hugging Face model id for each supported language.
MODELS = {
    "Amharic": "facebook/mms-tts-amh",
    "Somali": "facebook/mms-tts-som",
    "Swahili": "facebook/mms-tts-swh",
    "Afan Oromo": "facebook/mms-tts-orm",
    "Tigrinya": "facebook/mms-tts-tir",
    "Chichewa": "facebook/mms-tts-swh",  # Using Swahili as fallback
}


class MMS_TTS_Service:
    """Lazily loads and caches one MMS-TTS model/tokenizer per language."""

    def __init__(self):
        # Per-language caches so each checkpoint is downloaded/loaded once.
        self.models = {}
        self.tokenizers = {}
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

    def load_model(self, language):
        """Load (and cache) the model and tokenizer for *language*.

        Returns:
            (model, tokenizer) tuple.

        Raises:
            Exception: re-raised from Hugging Face download/load failures.
        """
        if language in self.models:
            return self.models[language], self.tokenizers[language]

        try:
            model_name = MODELS[language]
            print(f"Loading model for {language}: {model_name}")

            # Load tokenizer and model, then move to the chosen device.
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = VitsModel.from_pretrained(model_name)
            model = model.to(self.device)
            model.eval()  # inference only

            # Cache the loaded model so subsequent requests are fast.
            self.models[language] = model
            self.tokenizers[language] = tokenizer

            print(f"✅ Successfully loaded model for {language}")
            return model, tokenizer
        except Exception as e:
            print(f"❌ Error loading model for {language}: {e}")
            raise  # bare raise preserves the original traceback

    def preprocess_text(self, text, language):
        """Romanize Ge'ez-script text (Amharic/Tigrinya); pass others through."""
        if language in ["Amharic", "Tigrinya"]:
            print(f"Romanizing {language} text...")
            try:
                # The MMS amh/tir checkpoints are trained on romanized input.
                romanized_text = uroman(text)
                print(f"Original: {text}")
                print(f"Romanized: {romanized_text}")
                return romanized_text
            except Exception as e:
                # Best effort: fall back to the raw text if romanization fails.
                print(f"Romanization failed, using original text: {e}")
                return text
        else:
            # For other languages, use text as is.
            return text

    def generate_speech(self, text, language, speed=1.0):
        """Synthesize *text* in *language* at the given *speed*.

        Returns:
            ((sample_rate, waveform_ndarray), None) on success, or
            (None, error_message) on failure -- errors are reported, not raised.
        """
        try:
            model, tokenizer = self.load_model(language)

            # Romanize for Amharic/Tigrinya; no-op for the other languages.
            processed_text = self.preprocess_text(text, language)

            inputs = tokenizer(processed_text, return_tensors="pt")
            input_ids = inputs["input_ids"].to(self.device)

            # Inference only: skip autograd bookkeeping.
            with torch.no_grad():
                outputs = model(input_ids)
                waveform = outputs.waveform[0].cpu().numpy()
            sample_rate = model.config.sampling_rate

            if speed != 1.0:
                waveform = self.adjust_speed(waveform, sample_rate, speed)

            return (sample_rate, waveform), None
        except Exception as e:
            error_msg = f"Error generating speech: {str(e)}"
            print(error_msg)
            return None, error_msg

    def adjust_speed(self, waveform, sample_rate, speed_factor):
        """Time-stretch *waveform* by linear resampling.

        NOTE(review): plain resampling changes pitch along with speed; a
        phase-vocoder time-stretch would keep pitch constant if that matters.
        *sample_rate* is accepted for interface stability but unused here.
        """
        try:
            if speed_factor != 1.0:
                new_length = int(len(waveform) / speed_factor)
                indices = np.linspace(0, len(waveform) - 1, new_length)
                waveform = np.interp(indices, np.arange(len(waveform)), waveform)
            return waveform
        except Exception:
            # Best effort: on any failure return the audio unchanged.
            return waveform

    def get_available_languages(self):
        """Return the list of supported language names."""
        return list(MODELS.keys())


# Shared service instance used by all Gradio handlers.
tts_service = MMS_TTS_Service()


def text_to_speech(text, language, speed=1.0):
    """Main function for the Gradio interface.

    Validates the input, then delegates to the TTS service.

    Returns:
        ((sample_rate, waveform), status_message) on success, or
        (None, error_message) on validation/synthesis failure.
    """
    if not text.strip():
        return None, "Please enter some text to convert to speech."
    if len(text) > 500:
        return None, "Text too long. Please keep it under 500 characters."

    print(f"Generating speech for: '{text[:50]}...' in {language}")

    result, error = tts_service.generate_speech(text, language, speed)
    if error:
        return None, error

    sample_rate, waveform = result
    # (sample_rate, ndarray) is the tuple format gr.Audio(type="numpy") expects.
    return (sample_rate, waveform), "✅ Speech generated successfully!"


def create_demo_audio(language):
    """Return a short demo sentence for *language* (English fallback)."""
    demo_texts = {
        "Amharic": "ሰላም፣ ይህ የድምፅ ማመንጫ ሞዴል ነው። አመሰግናለሁ!",
        "Somali": "Salaam, kani waa modelka cod-sameynta.",
        "Swahili": "Halo, hii ni modeli ya kutengeneza sauti.",
        "Afan Oromo": "Akkam, kun modeli sagalee uumuudha.",
        "Tigrinya": "ሰላም፣ እዚ ድምጺ ዝገብር ሞዴል እዩ። የቐንየለይ!",
        "Chichewa": "Moni, iyi ndi modeli yopanga mawu.",
    }
    return demo_texts.get(language, "Hello, this is a text-to-speech model.")


# ---------------------------------------------------------------------------
# Gradio interface
# ---------------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft(), title="MMS Text-to-Speech") as demo:
    gr.Markdown(
        """
        # 🎙️ MMS Text-to-Speech for African Languages
        Convert text to natural speech in multiple African languages using Facebook's MMS-TTS models.

        **Special Features for Amharic & Tigrinya:** Automatic romanization for better pronunciation
        """
    )

    with gr.Row():
        with gr.Column():
            language = gr.Dropdown(
                choices=tts_service.get_available_languages(),
                value="Amharic",
                label="Select Language",
                info="Choose the language for speech generation",
            )
            text_input = gr.Textbox(
                lines=3,
                placeholder="Enter text to convert to speech...",
                label="Input Text",
                info="Maximum 500 characters",
            )
            speed = gr.Slider(
                minimum=0.5,
                maximum=2.0,
                value=1.0,
                step=0.1,
                label="Speech Speed",
                info="Adjust the playback speed",
            )

            with gr.Row():
                generate_btn = gr.Button("Generate Speech", variant="primary")
                clear_btn = gr.Button("Clear")

            # Demo section
            gr.Markdown("### 🎯 Quick Demo")
            demo_btn = gr.Button("Load Demo Text")
            demo_output = gr.Textbox(label="Demo Text", interactive=False)

        with gr.Column():
            audio_output = gr.Audio(
                label="Generated Speech",
                type="numpy",
                interactive=False,
            )
            status = gr.Textbox(
                label="Status",
                interactive=False,
                placeholder="Ready to generate speech...",
            )

    # Batch processing section (simplified)
    with gr.Accordion("📚 Batch Processing (Advanced)", open=False):
        gr.Markdown("Process multiple texts at once. Each line will be converted to a separate audio file.")
        batch_text = gr.Textbox(
            lines=4,
            placeholder="Enter multiple texts, one per line...\nExample:\nHello\nHow are you?\nThank you",
            label="Batch Texts",
            info="Maximum 5 texts, each under 200 characters",
        )
        batch_btn = gr.Button("Process Batch Texts")
        batch_status = gr.Textbox(label="Batch Processing Status")
        # NOTE(review): a Gallery is an image component; it is fed .wav file
        # paths below -- gr.Files/gr.File would be a better fit. Kept as-is to
        # preserve the existing UI contract.
        batch_results = gr.Gallery(
            label="Batch Results",
            show_label=True,
            columns=2,
        )

    # ------------------------------------------------------------------
    # Event handlers
    # ------------------------------------------------------------------
    def generate_speech_handler(text, lang, spd):
        """Thin wrapper so empty input short-circuits before synthesis."""
        if not text.strip():
            return None, "Please enter some text."
        return text_to_speech(text, lang, spd)

    def clear_all():
        # Order matches clear_btn's outputs:
        # text_input, demo_output, audio_output, status, batch_text, batch_results
        return "", "", None, "Cleared!", "", None

    def load_demo(lang):
        return create_demo_audio(lang)

    def process_batch(texts, lang, spd):
        """Process multiple texts and return file paths.

        Returns:
            (preview_audio_path_or_None, status_message, list_of_wav_paths)
        """
        if not texts.strip():
            return None, "No texts provided.", []

        text_list = [t.strip() for t in texts.split('\n') if t.strip()]
        if len(text_list) > 5:
            return None, "Maximum 5 texts allowed for batch processing.", []

        # Validate each text before doing any synthesis work.
        for i, text in enumerate(text_list):
            if len(text) > 200:
                return None, f"Text {i+1} is too long (max 200 characters).", []

        results = []
        error_count = 0
        for i, text in enumerate(text_list):
            result, error = tts_service.generate_speech(text, lang, spd)
            if error:
                error_count += 1
                print(f"Error processing text {i+1}: {error}")
            else:
                sample_rate, waveform = result
                # delete=False so Gradio can serve the file after this scope.
                # NOTE(review): these temp wavs are never removed -- consider
                # periodic cleanup of the temp directory.
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
                    sf.write(f.name, waveform, sample_rate)
                    results.append(f.name)

        if error_count > 0:
            status_msg = f"Processed {len(results)}/{len(text_list)} texts. {error_count} failed."
        else:
            status_msg = f"Successfully processed all {len(text_list)} texts!"

        # Return first result as preview and all as files.
        preview_audio = (results[0] if results else None)
        return preview_audio, status_msg, results

    # Connect events
    generate_btn.click(
        fn=generate_speech_handler,
        inputs=[text_input, language, speed],
        outputs=[audio_output, status],
    )
    clear_btn.click(
        fn=clear_all,
        outputs=[text_input, demo_output, audio_output, status, batch_text, batch_results],
    )
    demo_btn.click(
        fn=load_demo,
        inputs=[language],
        outputs=[demo_output],
    )
    batch_btn.click(
        fn=process_batch,
        inputs=[batch_text, language, speed],
        outputs=[audio_output, batch_status, batch_results],
    )

    # Examples with better Amharic and Tigrinya samples
    gr.Markdown("### 💡 Example Texts")
    examples = [
        ["Amharic", "ሁሉም ሰው በሁሉም መብቶች እኩል ነው። አመሰግናለሁ!"],
        ["Tigrinya", "ኩሉ ሰብ ንኩሉ መሰላት እኩል እዩ። የቐንየለይ!"],
        ["Somali", "Qof walba wuxuu leeyahay xuquuqda aadamaha."],
        ["Swahili", "Kila mtu ana haki zote za binadamu."],
        ["Afan Oromo", "Nama hundi mirga ummataa hundaa waliin dhalate."],
        ["Chichewa", "Alipo wina aliyense ali ndi ufulu wachibadwidwe."],
    ]
    gr.Examples(
        examples=examples,
        inputs=[language, text_input],
        outputs=[audio_output, status],
        fn=generate_speech_handler,
        cache_examples=False,
    )

    # Language-specific information
    with gr.Accordion("ℹ️ Language-Specific Information", open=False):
        gr.Markdown("""
        ### Amharic & Tigrinya Support
        - **Automatic Romanization**: Text is automatically converted to Latin script for better pronunciation
        - **Native Script Support**: Works with Ge'ez script (ፊደል) characters
        - **Enhanced Accuracy**: Romanization improves model performance for these languages

        ### Other Languages
        - **Somali, Swahili, Afan Oromo**: Direct text processing
        - **Chichewa**: Uses Swahili model as fallback

        ### Technical Details
        - Uses Facebook's MMS-TTS models
        - Automatic uroman romanization for Amharic and Tigrinya
        - GPU acceleration when available
        """)

    # Footer
    gr.Markdown(
        """
        ---
        ### ℹ️ About
        **Powered by:** Facebook MMS-TTS Models
        **Supported Languages:** Amharic, Somali, Swahili, Afan Oromo, Tigrinya, Chichewa
        **Special Features:** Automatic romanization for Amharic & Tigrinya
        **Model Type:** Text-to-Speech
        **Max Text Length:** 500 characters (single), 200 characters (batch)
        Note: First request may take longer as models are downloaded.
        """
    )


if __name__ == "__main__":
    # Pre-load a model to reduce first-time latency.
    print("🚀 Starting MMS Text-to-Speech Service...")
    print("📋 Supported Languages:", list(MODELS.keys()))
    print("🌟 Special Romanization for: Amharic, Tigrinya")

    try:
        tts_service.load_model("Amharic")
        print("✅ Pre-loaded Amharic model")
    except Exception as e:
        print("⚠️ Could not pre-load model:", e)

    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
    )