|
|
|
|
|
import gradio as gr |
|
|
import torch |
|
|
import torchaudio |
|
|
from transformers import VitsModel, AutoTokenizer |
|
|
import numpy as np |
|
|
import io |
|
|
import soundfile as sf |
|
|
from datetime import datetime |
|
|
import os |
|
|
import tempfile |
|
|
|
|
|
|
|
|
# uroman romanizes non-Latin scripts; it is required at call time for
# Amharic/Tigrinya preprocessing (see MMS_TTS_Service.preprocess_text).
# NOTE(review): installing a package with pip at import time is a heavy
# side effect — prefer declaring uroman in requirements.txt and failing fast.
try:
    from uroman import uroman
except ImportError:
    # Best-effort self-install into the current interpreter, then retry.
    import subprocess
    import sys
    subprocess.check_call([sys.executable, "-m", "pip", "install", "uroman"])
    from uroman import uroman
|
|
|
|
|
|
|
|
# UI language name -> Hugging Face MMS-TTS checkpoint id.
MODELS = {
    "Amharic": "facebook/mms-tts-amh",
    "Somali": "facebook/mms-tts-som",
    "Swahili": "facebook/mms-tts-swh",
    "Afan Oromo": "facebook/mms-tts-orm",
    "Tigrinya": "facebook/mms-tts-tir",
    # Chichewa deliberately reuses the Swahili checkpoint as a fallback
    # (the app's own "Language-Specific Information" text states this).
    "Chichewa": "facebook/mms-tts-swh"
}
|
|
|
|
|
class MMS_TTS_Service:
    """Loads and caches MMS-TTS (VITS) models per language and synthesizes speech.

    Models are downloaded from the Hugging Face Hub on first use and kept in
    memory for the lifetime of the process.
    """

    def __init__(self):
        # Per-language caches: language name -> model / tokenizer.
        self.models = {}
        self.tokenizers = {}
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

    def load_model(self, language):
        """Return ``(model, tokenizer)`` for *language*, loading on first use.

        Raises:
            KeyError: if *language* has no entry in ``MODELS``.
            Exception: anything ``from_pretrained`` raises (e.g. network errors).
        """
        if language in self.models:
            return self.models[language], self.tokenizers[language]

        try:
            model_name = MODELS[language]
            print(f"Loading model for {language}: {model_name}")

            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = VitsModel.from_pretrained(model_name)
            model = model.to(self.device)
            model.eval()  # inference only: disable dropout etc.

            self.models[language] = model
            self.tokenizers[language] = tokenizer

            print(f"Successfully loaded model for {language}")
            return model, tokenizer

        except Exception as e:
            print(f"Error loading model for {language}: {e}")
            # Re-raise preserving the original traceback (was `raise e`).
            raise

    def preprocess_text(self, text, language):
        """Romanize Ge'ez-script text (Amharic/Tigrinya); pass other languages through."""
        if language not in ("Amharic", "Tigrinya"):
            return text

        print(f"Romanizing {language} text...")
        try:
            romanized_text = uroman(text)
            print(f"Original: {text}")
            print(f"Romanized: {romanized_text}")
            return romanized_text
        except Exception as e:
            # Best-effort: fall back to the raw text instead of failing the request.
            print(f"Romanization failed, using original text: {e}")
            return text

    def generate_speech(self, text, language, speed=1.0):
        """Synthesize *text* in *language*.

        Returns:
            ``((sample_rate, waveform), None)`` on success — *waveform* is a
            1-D float numpy array — or ``(None, error_message)`` on failure.
        """
        try:
            model, tokenizer = self.load_model(language)
            processed_text = self.preprocess_text(text, language)

            inputs = tokenizer(processed_text, return_tensors="pt")
            input_ids = inputs["input_ids"].to(self.device)

            with torch.no_grad():
                outputs = model(input_ids)
                waveform = outputs.waveform[0].cpu().numpy()
                sample_rate = model.config.sampling_rate

            if speed != 1.0:
                waveform = self.adjust_speed(waveform, sample_rate, speed)

            return (sample_rate, waveform), None

        except Exception as e:
            error_msg = f"Error generating speech: {str(e)}"
            print(error_msg)
            return None, error_msg

    def adjust_speed(self, waveform, sample_rate, speed_factor):
        """Change playback speed by linear-interpolation resampling.

        NOTE: plain resampling also shifts pitch (this is not time-stretching).
        *sample_rate* is currently unused; kept for interface stability.
        """
        try:
            if speed_factor != 1.0:
                new_length = int(len(waveform) / speed_factor)
                indices = np.linspace(0, len(waveform) - 1, new_length)
                waveform = np.interp(indices, np.arange(len(waveform)), waveform)
            return waveform
        except Exception:
            # Best-effort: on any failure return audio at its original speed.
            # (Was a bare `except:`, which would also swallow KeyboardInterrupt.)
            return waveform

    def get_available_languages(self):
        """Return the UI language names that have a configured model."""
        return list(MODELS.keys())
|
|
|
|
|
|
|
|
# Module-level singleton so loaded models are shared across all Gradio requests.
tts_service = MMS_TTS_Service()
|
|
|
|
|
def text_to_speech(text, language, speed=1.0):
    """Main Gradio handler: validate input and synthesize speech.

    Args:
        text: Text to speak (max 500 characters; whitespace-only rejected).
        language: Key into the language dropdown / MODELS mapping.
        speed: Playback speed factor passed through to the TTS service.

    Returns:
        ``((sample_rate, waveform), status_message)`` on success, or
        ``(None, error_message)`` on validation failure or synthesis error.
    """
    if not text.strip():
        return None, "Please enter some text to convert to speech."

    if len(text) > 500:
        return None, "Text too long. Please keep it under 500 characters."

    print(f"Generating speech for: '{text[:50]}...' in {language}")

    result, error = tts_service.generate_speech(text, language, speed)

    if error:
        return None, error

    sample_rate, waveform = result

    # Status literal repaired: the original contained a mis-encoded emoji
    # that broke the string across two lines.
    return (sample_rate, waveform), "Speech generated successfully!"
|
|
|
|
|
def create_demo_audio(language):
    """Return a canned demo sentence for *language* (English fallback otherwise).

    NOTE(review): the Amharic/Tigrinya entries below appear mojibake'd
    (mis-decoded Ge'ez script); the Amharic literal was additionally broken
    across three source lines and has been rejoined here. Re-encode these
    strings as proper UTF-8 Ge'ez text when the originals are recovered.
    """
    demo_texts = {
        "Amharic": "α°ααα£ αα α¨α΅αα αααα« αα΄α ααα’ α αα°αααα!",
        "Somali": "Salaam, kani waa modelka cod-sameynta.",
        "Swahili": "Halo, hii ni modeli ya kutengeneza sauti.",
        "Afan Oromo": "Akkam, kun modeli sagalee uumuudha.",
        "Tigrinya": "α°ααα£ α₯α α΅ααΊ ααα₯α αα΄α α₯α©α’ α¨ααα¨αα!",
        "Chichewa": "Moni, iyi ndi modeli yopanga mawu."
    }

    return demo_texts.get(language, "Hello, this is a text-to-speech model.")
|
|
|
|
|
|
|
|
# ---- Gradio UI definition (components only; handlers are wired below) ----
with gr.Blocks(theme=gr.themes.Soft(), title="MMS Text-to-Speech") as demo:
    # Page header.
    gr.Markdown(
        """
# ποΈ MMS Text-to-Speech for African Languages
Convert text to natural speech in multiple African languages using Facebook's MMS-TTS models.

**Special Features for Amharic & Tigrinya:** Automatic romanization for better pronunciation
"""
    )

    with gr.Row():
        with gr.Column():
            # Language selector drives both generation and demo-text loading.
            language = gr.Dropdown(
                choices=tts_service.get_available_languages(),
                value="Amharic",
                label="Select Language",
                info="Choose the language for speech generation"
            )

            text_input = gr.Textbox(
                lines=3,
                placeholder="Enter text to convert to speech...",
                label="Input Text",
                info="Maximum 500 characters"
            )

            # Passed straight through to MMS_TTS_Service.adjust_speed.
            speed = gr.Slider(
                minimum=0.5,
                maximum=2.0,
                value=1.0,
                step=0.1,
                label="Speech Speed",
                info="Adjust the playback speed"
            )

            with gr.Row():
                generate_btn = gr.Button("Generate Speech", variant="primary")
                clear_btn = gr.Button("Clear")

            gr.Markdown("### π― Quick Demo")
            demo_btn = gr.Button("Load Demo Text")
            demo_output = gr.Textbox(label="Demo Text", interactive=False)

        with gr.Column():
            # type="numpy" matches the (sample_rate, waveform) tuples the
            # handlers return.
            audio_output = gr.Audio(
                label="Generated Speech",
                type="numpy",
                interactive=False
            )

            status = gr.Textbox(
                label="Status",
                interactive=False,
                placeholder="Ready to generate speech..."
            )

    with gr.Accordion("π Batch Processing (Advanced)", open=False):
        gr.Markdown("Process multiple texts at once. Each line will be converted to a separate audio file.")

        batch_text = gr.Textbox(
            lines=4,
            placeholder="Enter multiple texts, one per line...\nExample:\nHello\nHow are you?\nThank you",
            label="Batch Texts",
            info="Maximum 5 texts, each under 200 characters"
        )

        batch_btn = gr.Button("Process Batch Texts")
        batch_status = gr.Textbox(label="Batch Processing Status")

        # NOTE(review): Gallery is an image component; process_batch feeds it
        # .wav file paths — confirm it renders, or switch to gr.File.
        batch_results = gr.Gallery(
            label="Batch Results",
            show_label=True,
            columns=2
        )
def generate_speech_handler(text, lang, spd):
    """Click handler: reject blank input, otherwise delegate to text_to_speech."""
    stripped = text.strip()
    if stripped:
        return text_to_speech(text, lang, spd)
    return None, "Please enter some text."
|
|
|
|
|
def clear_all():
    """Reset both text boxes, the audio player, status, and the batch fields."""
    blank = ""
    # Order matches the outputs list wired to clear_btn:
    # text_input, demo_output, audio_output, status, batch_text, batch_results.
    return blank, blank, None, "Cleared!", blank, None
|
|
|
|
|
def load_demo(lang):
    """Fetch the canned demo sentence for the currently selected language."""
    demo_sentence = create_demo_audio(lang)
    return demo_sentence
|
|
|
|
|
def process_batch(texts, lang, spd):
    """Convert up to five newline-separated texts to speech, one wav per line.

    Args:
        texts: Newline-separated input; blank lines are ignored.
        lang: Language key for the TTS service.
        spd: Playback speed factor.

    Returns:
        ``(preview_audio_path_or_None, status_message, list_of_wav_paths)``.
    """
    if not texts.strip():
        return None, "No texts provided.", []

    text_list = [t.strip() for t in texts.split('\n') if t.strip()]
    if len(text_list) > 5:
        return None, "Maximum 5 texts allowed for batch processing.", []

    # Validate all lengths up front so no (slow) synthesis work is wasted.
    for i, text in enumerate(text_list):
        if len(text) > 200:
            return None, f"Text {i+1} is too long (max 200 characters).", []

    results = []
    error_count = 0

    for i, text in enumerate(text_list):
        result, error = tts_service.generate_speech(text, lang, spd)
        if error:
            error_count += 1
            print(f"Error processing text {i+1}: {error}")
        else:
            sample_rate, waveform = result

            # mkstemp + close instead of writing through an open
            # NamedTemporaryFile handle: on Windows, sf.write cannot reopen
            # a path whose handle is still held open.
            # NOTE(review): these files are never deleted — consider cleanup
            # for long-running servers.
            fd, wav_path = tempfile.mkstemp(suffix=".wav")
            os.close(fd)
            sf.write(wav_path, waveform, sample_rate)
            results.append(wav_path)

    if error_count > 0:
        status_msg = f"Processed {len(results)}/{len(text_list)} texts. {error_count} failed."
    else:
        status_msg = f"Successfully processed all {len(text_list)} texts!"

    # First clip doubles as a quick preview in the main audio player.
    preview_audio = (results[0] if results else None)
    return preview_audio, status_msg, results
|
|
|
|
|
|
|
|
# ---- Event wiring: connect buttons to their handlers ----
generate_btn.click(
    fn=generate_speech_handler,
    inputs=[text_input, language, speed],
    outputs=[audio_output, status]
)

# clear_all returns six values matching this outputs list, in order.
clear_btn.click(
    fn=clear_all,
    outputs=[text_input, demo_output, audio_output, status, batch_text, batch_results]
)

demo_btn.click(
    fn=load_demo,
    inputs=[language],
    outputs=[demo_output]
)

# Batch preview is routed into the main audio player.
batch_btn.click(
    fn=process_batch,
    inputs=[batch_text, language, speed],
    outputs=[audio_output, batch_status, batch_results]
)

gr.Markdown("### π‘ Example Texts")
# NOTE(review): the Amharic/Tigrinya example strings below appear mojibake'd
# (mis-decoded Ge'ez script) — re-encode as UTF-8 when originals are recovered.
examples = [
    ["Amharic", "ααα α°α α ααα αα₯αΆα½ α₯α©α ααα’ α αα°αααα!"],
    ["Tigrinya", "α©α α°α₯ αα©α αα°αα΅ α₯α©α α₯α©α’ α¨ααα¨αα!"],
    ["Somali", "Qof walba wuxuu leeyahay xuquuqda aadamaha."],
    ["Swahili", "Kila mtu ana haki zote za binadamu."],
    ["Afan Oromo", "Nama hundi mirga ummataa hundaa waliin dhalate."],
    ["Chichewa", "Alipo wina aliyense ali ndi ufulu wachibadwidwe."]
]

gr.Examples(
    examples=examples,
    inputs=[language, text_input],
    outputs=[audio_output, status],
    fn=generate_speech_handler,
    cache_examples=False
)
|
|
|
|
|
|
|
|
# Collapsible language notes.
with gr.Accordion("βΉοΈ Language-Specific Information", open=False):
    gr.Markdown("""
### Amharic & Tigrinya Support
- **Automatic Romanization**: Text is automatically converted to Latin script for better pronunciation
- **Native Script Support**: Works with Ge'ez script (αα°α) characters
- **Enhanced Accuracy**: Romanization improves model performance for these languages

### Other Languages
- **Somali, Swahili, Afan Oromo**: Direct text processing
- **Chichewa**: Uses Swahili model as fallback

### Technical Details
- Uses Facebook's MMS-TTS models
- Automatic uroman romanization for Amharic and Tigrinya
- GPU acceleration when available
""")

# Footer / about section.
gr.Markdown(
    """
---
### βΉοΈ About
**Powered by:** Facebook MMS-TTS Models
**Supported Languages:** Amharic, Somali, Swahili, Afan Oromo, Tigrinya, Chichewa
**Special Features:** Automatic romanization for Amharic & Tigrinya
**Model Type:** Text-to-Speech
**Max Text Length:** 500 characters (single), 200 characters (batch)

Note: First request may take longer as models are downloaded.
"""
)
|
|
|
|
|
if __name__ == "__main__":
    print("π Starting MMS Text-to-Speech Service...")
    print("π Supported Languages:", list(MODELS.keys()))
    print("π Special Romanization for: Amharic, Tigrinya")

    # Warm the cache so the first user request doesn't pay the download cost.
    try:
        tts_service.load_model("Amharic")
        # Literal repaired: the original contained a mis-encoded emoji that
        # broke this string across two lines.
        print("Pre-loaded Amharic model")
    except Exception as e:
        print("β οΈ Could not pre-load model:", e)

    # Bind on all interfaces for container/Space deployment.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True
    )