# Minte
# tts space
# 755fa07
# app.py
import gradio as gr
import torch
import torchaudio
from transformers import VitsModel, AutoTokenizer
import numpy as np
import io
import soundfile as sf
from datetime import datetime
import os
import tempfile
# Install uroman if not available.
# NOTE(review): installing a package at import time needs network access and a
# writable environment; listing `uroman` in requirements.txt would be more reliable.
# NOTE(review): with the PyPI `uroman` package, `from uroman import uroman` appears
# to bind a module (which exposes the `Uroman` class), not a callable function —
# confirm before calling `uroman(text)` directly.
try:
    from uroman import uroman
except ImportError:
    import importlib
    import subprocess
    import sys

    subprocess.check_call([sys.executable, "-m", "pip", "install", "uroman"])
    # The import system caches directory listings of sys.path entries; invalidate
    # them so the package installed a moment ago is visible to the retry below.
    importlib.invalidate_caches()
    from uroman import uroman
# Model configuration: display language name -> Hugging Face MMS-TTS checkpoint id.
# Insertion order is preserved and becomes the order of the language dropdown
# (get_available_languages returns list(MODELS.keys())).
MODELS = {
    language: f"facebook/mms-tts-{lang_code}"
    for language, lang_code in (
        ("Amharic", "amh"),
        ("Somali", "som"),
        ("Swahili", "swh"),
        ("Afan Oromo", "orm"),
        ("Tigrinya", "tir"),
        # Chichewa reuses the Swahili checkpoint as a fallback.
        ("Chichewa", "swh"),
    )
}
class MMS_TTS_Service:
    """Load, cache and run Facebook MMS-TTS (VITS) models, one per language.

    Models and tokenizers are loaded lazily on first use and cached in
    ``self.models`` / ``self.tokenizers`` keyed by language name (the keys of
    ``MODELS``), so each checkpoint is loaded at most once per instance.
    """

    # Languages whose input text is romanized with uroman before tokenization.
    ROMANIZED_LANGUAGES = ("Amharic", "Tigrinya")

    def __init__(self):
        self.models = {}
        self.tokenizers = {}
        # uroman romanizer object, created on the first romanization (see _romanize).
        self._romanizer = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

    def load_model(self, language):
        """Return ``(model, tokenizer)`` for *language*, loading and caching on first use.

        The model is moved to ``self.device`` and switched to eval mode.

        Raises:
            KeyError: if *language* is not a key of ``MODELS``.
            Exception: any error from ``from_pretrained`` (e.g. a failed download),
                re-raised unchanged after being logged.
        """
        if language in self.models:
            return self.models[language], self.tokenizers[language]
        try:
            model_name = MODELS[language]
            print(f"Loading model for {language}: {model_name}")
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = VitsModel.from_pretrained(model_name).to(self.device)
            model.eval()
        except Exception as e:
            print(f"❌ Error loading model for {language}: {e}")
            raise
        # Cache only after both objects loaded successfully.
        self.models[language] = model
        self.tokenizers[language] = tokenizer
        print(f"✅ Successfully loaded model for {language}")
        return model, tokenizer

    def preprocess_text(self, text, language):
        """Return *text* romanized with uroman for Amharic/Tigrinya, unchanged otherwise.

        If romanization fails, the error is logged and the original text is returned.
        """
        if language not in self.ROMANIZED_LANGUAGES:
            return text
        print(f"Romanizing {language} text...")
        try:
            romanized_text = self._romanize(text)
        except Exception as e:
            print(f"Romanization failed, using original text: {e}")
            return text
        print(f"Original: {text}")
        print(f"Romanized: {romanized_text}")
        return romanized_text

    def _romanize(self, text):
        """Romanize *text* through the module-level ``uroman`` binding.

        Works whether ``uroman`` is a plain callable or (as with the PyPI ``uroman``
        package, where ``from uroman import uroman`` appears to bind a module) an
        object exposing the ``Uroman`` class. Calling the module directly raised a
        TypeError, which was swallowed and silently disabled romanization.
        """
        if callable(uroman):
            return uroman(text)
        if getattr(self, "_romanizer", None) is None:
            self._romanizer = uroman.Uroman()
        return self._romanizer.romanize_string(text)

    def generate_speech(self, text, language, speed=1.0):
        """Synthesize *text* in *language*.

        Args:
            text: Input text; romanized first for Amharic/Tigrinya.
            language: A key of ``MODELS``.
            speed: Playback-speed factor applied by ``adjust_speed`` (1.0 = unchanged).

        Returns:
            ``((sample_rate, waveform), None)`` on success, where *waveform* is a NumPy
            array taken from the first item of the model's ``waveform`` output, or
            ``(None, error_message)`` on failure (exceptions are caught and reported).
        """
        try:
            model, tokenizer = self.load_model(language)
            processed_text = self.preprocess_text(text, language)
            inputs = tokenizer(processed_text, return_tensors="pt")
            input_ids = inputs["input_ids"].to(self.device)
            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                outputs = model(input_ids)
            waveform = outputs.waveform[0].cpu().numpy()
            sample_rate = model.config.sampling_rate
            if speed != 1.0:
                waveform = self.adjust_speed(waveform, sample_rate, speed)
            return (sample_rate, waveform), None
        except Exception as e:
            error_msg = f"Error generating speech: {str(e)}"
            print(error_msg)
            return None, error_msg

    def adjust_speed(self, waveform, sample_rate, speed_factor):
        """Change playback speed by linearly resampling *waveform*.

        The result has ``int(len(waveform) / speed_factor)`` samples, so factors > 1
        shorten (speed up) the audio. Because this is plain resampling played back at
        the same rate, pitch changes together with tempo. *sample_rate* is unused.

        Returns *waveform* unchanged (same object) when *speed_factor* is 1.0 or not
        positive, when *waveform* is empty, or when resampling fails (best effort).
        """
        try:
            if speed_factor == 1.0 or speed_factor <= 0 or len(waveform) == 0:
                return waveform
            new_length = int(len(waveform) / speed_factor)
            positions = np.linspace(0, len(waveform) - 1, new_length)
            return np.interp(positions, np.arange(len(waveform)), waveform)
        except (TypeError, ValueError) as e:
            print(f"Speed adjustment failed, returning original audio: {e}")
            return waveform

    def get_available_languages(self):
        """Return the selectable language names, in ``MODELS`` order."""
        return list(MODELS)
# One shared service instance for every Gradio handler, so each language's
# model is loaded and cached at most once per process.
tts_service: MMS_TTS_Service = MMS_TTS_Service()
def text_to_speech(text, language, speed=1.0):
    """Validate *text* and synthesize it with the shared ``tts_service``.

    Args:
        text: Text to speak; ``None`` and blank strings are rejected.
        language: A key of ``MODELS``.
        speed: Playback-speed factor (1.0 = unchanged).

    Returns:
        ``((sample_rate, waveform), status)`` on success, or ``(None, message)``
        when the text is empty/blank, longer than 500 characters, or generation fails.
    """
    # Accept None as well as blank strings.
    if not text or not text.strip():
        return None, "Please enter some text to convert to speech."
    if len(text) > 500:
        return None, "Text too long. Please keep it under 500 characters."
    # Only mark the log preview as truncated when it actually is.
    preview = text if len(text) <= 50 else f"{text[:50]}..."
    print(f"Generating speech for: '{preview}' in {language}")
    result, error = tts_service.generate_speech(text, language, speed)
    if error:
        return None, error
    # result is already (sample_rate, waveform), the tuple form gr.Audio accepts.
    return result, "✅ Speech generated successfully!"
def create_demo_audio(language):
    """Return a short sample sentence in *language* for the "Load Demo Text" button.

    Despite its name this returns text, not audio. Unknown languages get an
    English fallback sentence.
    """
    # The Amharic/Tigrinya samples are the decoded forms of literals that had been
    # stored as mojibake (UTF-8 bytes shown through a Greek code page).
    # NOTE(review): have a native speaker confirm the wording.
    demo_texts = {
        "Amharic": "ሰላም፣ ይህ የድምፅ ማመንጫ ሞዴል ነው። አመሰግናለሁ!",
        "Somali": "Salaam, kani waa modelka cod-sameynta.",
        "Swahili": "Halo, hii ni modeli ya kutengeneza sauti.",
        "Afan Oromo": "Akkam, kun modeli sagalee uumuudha.",
        "Tigrinya": "ሰላም፣ እዚ ድምጺ ዝገብር ሞዴል እዩ። የቐንየለይ!",
        "Chichewa": "Moni, iyi ndi modeli yopanga mawu.",
    }
    return demo_texts.get(language, "Hello, this is a text-to-speech model.")
# Gradio interface
# NOTE(review): the Ge'ez-script literals in the Amharic/Tigrinya examples below look
# like UTF-8 text decoded with a Greek code page (mojibake); re-enter them from a
# UTF-8 source.
with gr.Blocks(theme=gr.themes.Soft(), title="MMS Text-to-Speech") as demo:
    gr.Markdown(
        """
        # 🎙️ MMS Text-to-Speech for African Languages

        Convert text to natural speech in multiple African languages using Facebook's MMS-TTS models.

        **Special Features for Amharic & Tigrinya:** Automatic romanization for better pronunciation
        """
    )

    with gr.Row():
        # Left column: inputs and controls.
        with gr.Column():
            language = gr.Dropdown(
                choices=tts_service.get_available_languages(),
                value="Amharic",
                label="Select Language",
                info="Choose the language for speech generation",
            )
            text_input = gr.Textbox(
                lines=3,
                placeholder="Enter text to convert to speech...",
                label="Input Text",
                info="Maximum 500 characters",
            )
            speed = gr.Slider(
                minimum=0.5,
                maximum=2.0,
                value=1.0,
                step=0.1,
                label="Speech Speed",
                info="Adjust the playback speed",
            )
            with gr.Row():
                generate_btn = gr.Button("Generate Speech", variant="primary")
                clear_btn = gr.Button("Clear")

            # Demo section
            gr.Markdown("### 🎯 Quick Demo")
            demo_btn = gr.Button("Load Demo Text")
            demo_output = gr.Textbox(label="Demo Text", interactive=False)

        # Right column: generated audio and status.
        with gr.Column():
            audio_output = gr.Audio(
                label="Generated Speech",
                type="numpy",
                interactive=False,
            )
            status = gr.Textbox(
                label="Status",
                interactive=False,
                placeholder="Ready to generate speech...",
            )

    # Batch processing section (simplified)
    with gr.Accordion("📚 Batch Processing (Advanced)", open=False):
        gr.Markdown("Process multiple texts at once. Each line will be converted to a separate audio file.")
        batch_text = gr.Textbox(
            lines=4,
            placeholder="Enter multiple texts, one per line...\nExample:\nHello\nHow are you?\nThank you",
            label="Batch Texts",
            info="Maximum 5 texts, each under 200 characters",
        )
        batch_btn = gr.Button("Process Batch Texts")
        batch_status = gr.Textbox(label="Batch Processing Status")
        # A File list (not a Gallery, which displays images/videos) so every
        # generated .wav can be downloaded.
        batch_results = gr.File(
            label="Batch Results",
            file_count="multiple",
            interactive=False,
        )

    # Event handlers
    def generate_speech_handler(text, lang, spd):
        """Generate-button callback: returns ``(audio, status)``."""
        if not text or not text.strip():
            return None, "Please enter some text."
        return text_to_speech(text, lang, spd)

    def clear_all():
        """Reset inputs, outputs and status boxes (order matches clear_btn outputs)."""
        return "", "", None, "Cleared!", "", None, ""

    def load_demo(lang):
        """Return the sample sentence (text) for *lang*."""
        return create_demo_audio(lang)

    def process_batch(texts, lang, spd):
        """Synthesize up to 5 newline-separated texts (each <= 200 chars) into WAV files.

        Returns:
            ``(preview_audio, status_message, wav_paths)`` for
            ``[audio_output, batch_status, batch_results]``; *preview_audio* is the
            first generated file path (or None) and *wav_paths* is None when
            nothing was produced.
        """
        if not texts or not texts.strip():
            return None, "No texts provided.", None
        text_list = [t.strip() for t in texts.splitlines() if t.strip()]
        if len(text_list) > 5:
            return None, "Maximum 5 texts allowed for batch processing.", None
        # Validate every text before spending time on synthesis.
        for i, text in enumerate(text_list, start=1):
            if len(text) > 200:
                return None, f"Text {i} is too long (max 200 characters).", None

        results = []
        error_count = 0
        for i, text in enumerate(text_list, start=1):
            result, error = tts_service.generate_speech(text, lang, spd)
            if error:
                error_count += 1
                print(f"Error processing text {i}: {error}")
                continue
            sample_rate, waveform = result
            # Reserve a unique path, then write after the handle is closed: an open
            # NamedTemporaryFile cannot be reopened by name on Windows.
            # NOTE(review): these temp files are never deleted.
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                wav_path = tmp.name
            sf.write(wav_path, waveform, sample_rate)
            results.append(wav_path)

        if error_count > 0:
            status_msg = f"Processed {len(results)}/{len(text_list)} texts. {error_count} failed."
        else:
            status_msg = f"Successfully processed all {len(text_list)} texts!"
        # Play the first result as a preview; offer all files for download.
        preview_audio = results[0] if results else None
        return preview_audio, status_msg, results or None

    # Connect events
    generate_btn.click(
        fn=generate_speech_handler,
        inputs=[text_input, language, speed],
        outputs=[audio_output, status],
    )
    clear_btn.click(
        fn=clear_all,
        outputs=[text_input, demo_output, audio_output, status, batch_text, batch_results, batch_status],
    )
    demo_btn.click(
        fn=load_demo,
        inputs=[language],
        outputs=[demo_output],
    )
    batch_btn.click(
        fn=process_batch,
        inputs=[batch_text, language, speed],
        outputs=[audio_output, batch_status, batch_results],
    )

    # Examples with better Amharic and Tigrinya samples
    gr.Markdown("### 💡 Example Texts")
    examples = [
        ["Amharic", "αˆαˆ‰αˆ αˆ°α‹ α‰ αˆαˆ‰αˆ መα‰₯ቢች αŠ₯ኩል αŠα‹α’ αŠ αˆ˜αˆ°αŒαŠ“αˆˆαˆ!"],
        ["Tigrinya", "αŠ©αˆ‰ ሰα‰₯ αŠ•αŠ©αˆ‰ αˆ˜αˆ°αˆ‹α‰΅ αŠ₯ኩል αŠ₯ዩፒ α‹¨α‰αŠ•α‹¨αˆˆα‹­!"],
        ["Somali", "Qof walba wuxuu leeyahay xuquuqda aadamaha."],
        ["Swahili", "Kila mtu ana haki zote za binadamu."],
        ["Afan Oromo", "Nama hundi mirga ummataa hundaa waliin dhalate."],
        ["Chichewa", "Alipo wina aliyense ali ndi ufulu wachibadwidwe."],
    ]
    # Examples only fill the inputs (nothing is cached or run). No fn is attached:
    # generate_speech_handler takes (text, lang, spd), which does not match these
    # [language, text] columns.
    gr.Examples(
        examples=examples,
        inputs=[language, text_input],
        cache_examples=False,
    )

    # Language-specific information
    with gr.Accordion("ℹ️ Language-Specific Information", open=False):
        gr.Markdown(
            """
            ### Amharic & Tigrinya Support
            - **Automatic Romanization**: Text is automatically converted to Latin script for better pronunciation
            - **Native Script Support**: Works with Ge'ez script (ፊደል) characters
            - **Enhanced Accuracy**: Romanization improves model performance for these languages

            ### Other Languages
            - **Somali, Swahili, Afan Oromo**: Direct text processing
            - **Chichewa**: Uses Swahili model as fallback

            ### Technical Details
            - Uses Facebook's MMS-TTS models
            - Automatic uroman romanization for Amharic and Tigrinya
            - GPU acceleration when available
            """
        )

    # Footer
    gr.Markdown(
        """
        ---
        ### ℹ️ About

        **Powered by:** Facebook MMS-TTS Models

        **Supported Languages:** Amharic, Somali, Swahili, Afan Oromo, Tigrinya, Chichewa

        **Special Features:** Automatic romanization for Amharic & Tigrinya

        **Model Type:** Text-to-Speech

        **Max Text Length:** 500 characters (single), 200 characters (batch)

        Note: First request may take longer as models are downloaded.
        """
    )
if __name__ == "__main__":
    print("🚀 Starting MMS Text-to-Speech Service...")
    print("📋 Supported Languages:", list(MODELS.keys()))
    print("🌟 Special Romanization for: Amharic, Tigrinya")
    # Pre-load the Amharic model (the dropdown's default language) so the first
    # request does not pay the load cost. Failure is non-fatal: load_model is
    # called again lazily on the first request for that language.
    try:
        tts_service.load_model("Amharic")
        print("✅ Pre-loaded Amharic model")
    except Exception as e:
        print("⚠️ Could not pre-load model:", e)
    # Listen on all interfaces, port 7860, without a public share link.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
    )