File size: 14,006 Bytes
6d28d4b 943a8da 6d28d4b 755fa07 6d28d4b 943a8da 6d28d4b 755fa07 6d28d4b 755fa07 6d28d4b 755fa07 6d28d4b 943a8da 6d28d4b 943a8da 6d28d4b 755fa07 6d28d4b 755fa07 6d28d4b 943a8da 6d28d4b 755fa07 6d28d4b 943a8da 6d28d4b 943a8da 6d28d4b 943a8da 6d28d4b 943a8da 6d28d4b 943a8da 6d28d4b 943a8da 6d28d4b 943a8da 6d28d4b 943a8da 6d28d4b 943a8da 6d28d4b 943a8da 6d28d4b 943a8da 6d28d4b 755fa07 6d28d4b 755fa07 6d28d4b 943a8da 6d28d4b 755fa07 6d28d4b 755fa07 6d28d4b 943a8da 6d28d4b 943a8da 6d28d4b 755fa07 6d28d4b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 |
# app.py
import gradio as gr
import torch
import torchaudio
from transformers import VitsModel, AutoTokenizer
import numpy as np
import io
import soundfile as sf
from datetime import datetime
import os
import tempfile

# Ensure the `uroman` romanization package is importable; if it is missing
# (e.g. on a fresh deployment), install it on the fly and retry the import.
try:
    from uroman import uroman
except ImportError:
    import subprocess
    import sys

    subprocess.check_call([sys.executable, "-m", "pip", "install", "uroman"])
    from uroman import uroman
# Hugging Face checkpoint per supported language. Facebook's MMS-TTS project
# publishes one single-language VITS checkpoint per ISO code.
MODELS = {
    "Amharic": "facebook/mms-tts-amh",
    "Somali": "facebook/mms-tts-som",
    "Swahili": "facebook/mms-tts-swh",
    "Afan Oromo": "facebook/mms-tts-orm",
    "Tigrinya": "facebook/mms-tts-tir",
    # No dedicated Chichewa checkpoint wired up here -- Swahili is used
    # as a stand-in (fallback).
    "Chichewa": "facebook/mms-tts-swh",
}
class MMS_TTS_Service:
    """Lazily loads, caches, and runs facebook/mms-tts VITS models per language."""

    def __init__(self):
        # Per-language caches so each checkpoint is downloaded/loaded at most once.
        self.models = {}
        self.tokenizers = {}
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

    def load_model(self, language):
        """Return (model, tokenizer) for `language`, loading and caching on first use.

        Raises:
            KeyError: if `language` is not a key of MODELS.
            Exception: whatever `from_pretrained` raises (network, auth, ...),
                re-raised after logging.
        """
        if language in self.models:
            return self.models[language], self.tokenizers[language]
        try:
            model_name = MODELS[language]
            print(f"Loading model for {language}: {model_name}")
            # Load tokenizer and model from the Hub.
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = VitsModel.from_pretrained(model_name)
            model = model.to(self.device)
            model.eval()  # inference only: disables dropout etc.
            # Cache so subsequent requests skip the expensive load.
            self.models[language] = model
            self.tokenizers[language] = tokenizer
            print(f"✅ Successfully loaded model for {language}")
            return model, tokenizer
        except Exception as e:
            print(f"❌ Error loading model for {language}: {e}")
            raise  # bare re-raise preserves the original traceback (was `raise e`)

    def preprocess_text(self, text, language):
        """Romanize Ge'ez-script input for Amharic/Tigrinya; pass others through.

        The MMS checkpoints for these languages expect Latin-script input, so the
        text is run through uroman first. Falls back to the raw text on failure.
        """
        if language not in ("Amharic", "Tigrinya"):
            # Other languages are fed to the model as-is.
            return text
        print(f"Romanizing {language} text...")
        try:
            # NOTE(review): assumes the installed `uroman` package exposes a
            # callable `uroman(text)` -- verify against the installed version's API.
            romanized_text = uroman(text)
            print(f"Original: {text}")
            print(f"Romanized: {romanized_text}")
            return romanized_text
        except Exception as e:
            # Best-effort: degrade to the original text rather than failing the request.
            print(f"Romanization failed, using original text: {e}")
            return text

    def generate_speech(self, text, language, speed=1.0):
        """Synthesize `text` in `language`.

        Args:
            text: input text (romanized internally where needed).
            language: key of MODELS.
            speed: playback-speed factor; 1.0 means unmodified.

        Returns:
            ((sample_rate, waveform_ndarray), None) on success,
            (None, error_message) on failure.
        """
        try:
            model, tokenizer = self.load_model(language)
            processed_text = self.preprocess_text(text, language)
            inputs = tokenizer(processed_text, return_tensors="pt")
            input_ids = inputs["input_ids"].to(self.device)
            with torch.no_grad():  # inference only; skip autograd bookkeeping
                outputs = model(input_ids)
            waveform = outputs.waveform[0].cpu().numpy()
            sample_rate = model.config.sampling_rate
            if speed != 1.0:
                waveform = self.adjust_speed(waveform, sample_rate, speed)
            return (sample_rate, waveform), None
        except Exception as e:
            error_msg = f"Error generating speech: {str(e)}"
            print(error_msg)
            return None, error_msg

    def adjust_speed(self, waveform, sample_rate, speed_factor):
        """Time-stretch by linear resampling (also shifts pitch; crude but cheap)."""
        try:
            if speed_factor != 1.0:
                new_length = int(len(waveform) / speed_factor)
                indices = np.linspace(0, len(waveform) - 1, new_length)
                waveform = np.interp(indices, np.arange(len(waveform)), waveform)
            return waveform
        except Exception:
            # Was a bare `except:` -- narrowed so SystemExit/KeyboardInterrupt
            # are not swallowed; still best-effort: return the unmodified audio.
            return waveform

    def get_available_languages(self):
        """Return the supported language names (the keys of MODELS)."""
        return list(MODELS.keys())
# Initialize TTS service
# Module-level singleton: one shared instance reused by every Gradio event
# handler below, so the per-language model caches persist across requests.
tts_service = MMS_TTS_Service()
def text_to_speech(text, language, speed=1.0):
    """Validate input and synthesize speech via the shared `tts_service`.

    Args:
        text: text to synthesize; must be non-blank and at most 500 characters.
        language: one of the keys of MODELS.
        speed: playback-speed factor forwarded to the service.

    Returns:
        ((sample_rate, waveform), status_message) on success,
        (None, error_message) on validation failure or synthesis error.
    """
    if not text.strip():
        return None, "Please enter some text to convert to speech."
    if len(text) > 500:
        return None, "Text too long. Please keep it under 500 characters."
    print(f"Generating speech for: '{text[:50]}...' in {language}")
    result, error = tts_service.generate_speech(text, language, speed)
    if error:
        return None, error
    sample_rate, waveform = result
    # (sample_rate, ndarray) is the tuple format gr.Audio(type="numpy") expects.
    # The success string was mojibake'd/split in the source; restored here.
    return (sample_rate, waveform), "✅ Speech generated successfully!"
def create_demo_audio(language):
    """Return a short canned demo sentence for `language` (English fallback)."""
    # NOTE(review): the Amharic and Tigrinya literals below arrived mojibake'd
    # in this source (the Amharic one was even split across lines; re-joined
    # here). They should be re-entered as proper Ge'ez script and verified.
    demo_texts = {
        "Amharic": "α°ααα£ αα α¨α΅αα αααα« αα΄α ααα’ α αα°αααα!",
        "Somali": "Salaam, kani waa modelka cod-sameynta.",
        "Swahili": "Halo, hii ni modeli ya kutengeneza sauti.",
        "Afan Oromo": "Akkam, kun modeli sagalee uumuudha.",
        "Tigrinya": "α°ααα£ α₯α α΅ααΊ ααα₯α αα΄α α₯α©α’ α¨ααα¨αα!",
        "Chichewa": "Moni, iyi ndi modeli yopanga mawu.",
    }
    return demo_texts.get(language, "Hello, this is a text-to-speech model.")
# Gradio interface: layout, event handlers, and wiring. Indentation was
# stripped in the source and is reconstructed here; mojibake'd emoji in the
# markdown headings are restored to plausible originals.
with gr.Blocks(theme=gr.themes.Soft(), title="MMS Text-to-Speech") as demo:
    gr.Markdown(
        """
# 🎙️ MMS Text-to-Speech for African Languages
Convert text to natural speech in multiple African languages using Facebook's MMS-TTS models.
**Special Features for Amharic & Tigrinya:** Automatic romanization for better pronunciation
"""
    )
    with gr.Row():
        with gr.Column():
            language = gr.Dropdown(
                choices=tts_service.get_available_languages(),
                value="Amharic",
                label="Select Language",
                info="Choose the language for speech generation",
            )
            text_input = gr.Textbox(
                lines=3,
                placeholder="Enter text to convert to speech...",
                label="Input Text",
                info="Maximum 500 characters",
            )
            speed = gr.Slider(
                minimum=0.5,
                maximum=2.0,
                value=1.0,
                step=0.1,
                label="Speech Speed",
                info="Adjust the playback speed",
            )
            with gr.Row():
                generate_btn = gr.Button("Generate Speech", variant="primary")
                clear_btn = gr.Button("Clear")
            # Demo section
            gr.Markdown("### 🎯 Quick Demo")
            demo_btn = gr.Button("Load Demo Text")
            demo_output = gr.Textbox(label="Demo Text", interactive=False)
        with gr.Column():
            audio_output = gr.Audio(
                label="Generated Speech",
                type="numpy",
                interactive=False,
            )
            status = gr.Textbox(
                label="Status",
                interactive=False,
                placeholder="Ready to generate speech...",
            )
    # Batch processing section (simplified)
    with gr.Accordion("📄 Batch Processing (Advanced)", open=False):
        gr.Markdown("Process multiple texts at once. Each line will be converted to a separate audio file.")
        batch_text = gr.Textbox(
            lines=4,
            placeholder="Enter multiple texts, one per line...\nExample:\nHello\nHow are you?\nThank you",
            label="Batch Texts",
            info="Maximum 5 texts, each under 200 characters",
        )
        batch_btn = gr.Button("Process Batch Texts")
        batch_status = gr.Textbox(label="Batch Processing Status")
        # NOTE(review): gr.Gallery is an image component; feeding it wav file
        # paths may not render -- consider gr.File or gr.Dataset instead.
        batch_results = gr.Gallery(
            label="Batch Results",
            show_label=True,
            columns=2,
        )

    # ---- Event handlers -------------------------------------------------

    def generate_speech_handler(text, lang, spd):
        """Click handler: short-circuit on blank input, else delegate."""
        if not text.strip():
            return None, "Please enter some text."
        return text_to_speech(text, lang, spd)

    def clear_all():
        """Reset the six outputs wired to clear_btn (see clear_btn.click below)."""
        return "", "", None, "Cleared!", "", None

    def load_demo(lang):
        """Fetch the canned demo sentence for the selected language."""
        return create_demo_audio(lang)

    def process_batch(texts, lang, spd):
        """Synthesize up to 5 newline-separated texts, each to a temp wav file.

        Returns:
            (preview_audio_path_or_None, status_message, list_of_wav_paths)
        """
        if not texts.strip():
            return None, "No texts provided.", []
        text_list = [t.strip() for t in texts.split('\n') if t.strip()]
        if len(text_list) > 5:
            return None, "Maximum 5 texts allowed for batch processing.", []
        # Validate every line before doing any expensive synthesis.
        for i, text in enumerate(text_list):
            if len(text) > 200:
                return None, f"Text {i+1} is too long (max 200 characters).", []
        results = []
        error_count = 0
        for i, text in enumerate(text_list):
            result, error = tts_service.generate_speech(text, lang, spd)
            if error:
                error_count += 1
                print(f"Error processing text {i+1}: {error}")
            else:
                sample_rate, waveform = result
                # NOTE(review): delete=False temp files are never cleaned up;
                # consider removing them once Gradio has served the result.
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
                    sf.write(f.name, waveform, sample_rate)
                    results.append(f.name)
        if error_count > 0:
            status_msg = f"Processed {len(results)}/{len(text_list)} texts. {error_count} failed."
        else:
            status_msg = f"Successfully processed all {len(text_list)} texts!"
        # The first clip doubles as a preview in the main audio widget.
        preview_audio = (results[0] if results else None)
        return preview_audio, status_msg, results

    # ---- Connect events --------------------------------------------------
    generate_btn.click(
        fn=generate_speech_handler,
        inputs=[text_input, language, speed],
        outputs=[audio_output, status],
    )
    clear_btn.click(
        fn=clear_all,
        outputs=[text_input, demo_output, audio_output, status, batch_text, batch_results],
    )
    demo_btn.click(
        fn=load_demo,
        inputs=[language],
        outputs=[demo_output],
    )
    batch_btn.click(
        fn=process_batch,
        inputs=[batch_text, language, speed],
        outputs=[audio_output, batch_status, batch_results],
    )

    # Examples with better Amharic and Tigrinya samples.
    gr.Markdown("### 💡 Example Texts")
    # NOTE(review): the Amharic/Tigrinya example strings arrived mojibake'd in
    # this source; re-enter them as proper Ge'ez script and verify.
    examples = [
        ["Amharic", "ααα α°α α ααα αα₯αΆα½ α₯α©α ααα’ α αα°αααα!"],
        ["Tigrinya", "α©α α°α₯ αα©α αα°αα΅ α₯α©α α₯α©α’ α¨ααα¨αα!"],
        ["Somali", "Qof walba wuxuu leeyahay xuquuqda aadamaha."],
        ["Swahili", "Kila mtu ana haki zote za binadamu."],
        ["Afan Oromo", "Nama hundi mirga ummataa hundaa waliin dhalate."],
        ["Chichewa", "Alipo wina aliyense ali ndi ufulu wachibadwidwe."],
    ]
    gr.Examples(
        examples=examples,
        inputs=[language, text_input],
        outputs=[audio_output, status],
        fn=generate_speech_handler,
        cache_examples=False,
    )

    # Language-specific information
    with gr.Accordion("ℹ️ Language-Specific Information", open=False):
        gr.Markdown("""
### Amharic & Tigrinya Support
- **Automatic Romanization**: Text is automatically converted to Latin script for better pronunciation
- **Native Script Support**: Works with Ge'ez script (αα°α) characters
- **Enhanced Accuracy**: Romanization improves model performance for these languages
### Other Languages
- **Somali, Swahili, Afan Oromo**: Direct text processing
- **Chichewa**: Uses Swahili model as fallback
### Technical Details
- Uses Facebook's MMS-TTS models
- Automatic uroman romanization for Amharic and Tigrinya
- GPU acceleration when available
""")

    # Footer
    gr.Markdown(
        """
---
### ℹ️ About
**Powered by:** Facebook MMS-TTS Models
**Supported Languages:** Amharic, Somali, Swahili, Afan Oromo, Tigrinya, Chichewa
**Special Features:** Automatic romanization for Amharic & Tigrinya
**Model Type:** Text-to-Speech
**Max Text Length:** 500 characters (single), 200 characters (batch)
Note: First request may take longer as models are downloaded.
"""
    )
if __name__ == "__main__":
    # Startup banner. NOTE(review): the leading emoji in these prints were
    # mojibake'd in the source; restored to plausible originals -- verify.
    print("🚀 Starting MMS Text-to-Speech Service...")
    print("🌍 Supported Languages:", list(MODELS.keys()))
    print("🔤 Special Romanization for: Amharic, Tigrinya")
    # Pre-load the Amharic model (the UI default) so the first request is fast.
    try:
        tts_service.load_model("Amharic")
        print("✅ Pre-loaded Amharic model")
    except Exception as e:
        print("⚠️ Could not pre-load model:", e)
    # Bind to all interfaces so the app is reachable inside a container/Space.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
    )