# Yareda21
# a
# 47684ff
# import gradio as gr
# import tempfile
# import traceback
# from dotenv import load_dotenv
# from transformers import pipeline
# import soundfile as sf
# import transformers
# import torch
# load_dotenv()
# print(f"Transformers version: {transformers.__version__}")
# print(f"Torch version: {torch.__version__}")
# # -----------------------------
# # Romanizer Setup (CRITICAL FIX)
# # -----------------------------
# try:
# from uroman import Uroman
# romanizer = Uroman()
# print("βœ“ uroman successfully loaded.")
# except Exception:
# romanizer = None
# print("⚠ uroman not available. Falling back to unidecode.")
# from unidecode import unidecode
# # -----------------------------
# # Available Models
# # -----------------------------
# AVAILABLE_MODELS = {
# "facebook/mms-tts-amh": "facebook/mms-tts-amh",
# "AddisuSeteye/speecht5_tts_amharic": "AddisuSeteye/speecht5_tts_amharic",
# "Walelign/speecht5_tts_amharic": "Walelign/speecht5_tts_amharic",
# }
# # Cache loaded pipelines
# loaded_pipelines = {}
# def load_tts_pipeline(model_id: str):
# """Load and cache a text-to-speech pipeline."""
# if model_id not in loaded_pipelines:
# print(f"Loading pipeline for: {model_id}")
# loaded_pipelines[model_id] = pipeline(
# task="text-to-speech",
# model=model_id,
# device=-1 # CPU
# )
# return loaded_pipelines[model_id]
# def romanize_text(text: str) -> str:
# """
# Convert Amharic text to Latin script for MMS-TTS compatibility.
# Prevents tokenizer from producing empty token sequences.
# """
# if romanizer:
# try:
# return romanizer.romanize_string(text)
# except Exception:
# pass
# # fallback
# return unidecode(text)
# def synthesize_text(text: str, model_choice: str):
# """Generate speech locally using selected model."""
# if not text or not text.strip():
# return None, "Please enter some Amharic text."
# model_id = AVAILABLE_MODELS.get(model_choice)
# if not model_id:
# return None, "Selected model not found."
# # Only MMS supports local inference properly
# # if model_id != "facebook/mms-tts-amh":
# # return None, (
# # f"Model `{model_choice}` is not supported for local inference.\n"
# # "Use facebook/mms-tts-amh locally or deploy other models via API."
# # )
# try:
# tts = load_tts_pipeline(model_id)
# # -------- CRITICAL FIX --------
# romanized_text = romanize_text(text)
# if not romanized_text.strip():
# return None, (
# "Romanization produced empty text.\n"
# "Ensure uroman-python is installed properly."
# )
# print("Original text:", text)
# print("Romanized text:", romanized_text)
# result = tts(romanized_text)
# audio = result["audio"]
# sr = result["sampling_rate"]
# if audio is None or len(audio) == 0:
# return None, "Model returned empty audio."
# with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
# sf.write(tmp_file.name, audio, sr)
# wav_path = tmp_file.name
# return wav_path, f"βœ“ Successfully synthesized using {model_choice}"
# except RuntimeError as e:
# if "length must be non-negative" in str(e):
# return None, (
# "Tokenizer produced invalid sequence.\n"
# "Install uroman-python and restart the app.\n"
# "If problem persists, verify transformers >= 4.33."
# )
# traceback.print_exc()
# return None, f"Runtime error: {str(e)}"
# except Exception as e:
# traceback.print_exc()
# return None, f"Unexpected error: {str(e)}"
# # -----------------------------
# # Gradio UI
# # -----------------------------
# with gr.Blocks(title="Amharic Text-to-Speech") as app:
# gr.Markdown("""
# πŸ‡ͺπŸ‡Ή **Amharic Text-to-Speech (Local MMS Model)**
# **Steps**
# 1. Select model
# 2. Enter Amharic text
# 3. Click Generate
# """)
# with gr.Row():
# with gr.Column(scale=2):
# model_dropdown = gr.Dropdown(
# choices=list(AVAILABLE_MODELS.keys()),
# value="facebook/mms-tts-amh",
# label="Choose TTS Model"
# )
# text_input = gr.Textbox(
# label="Enter Amharic Text",
# placeholder="αˆ°αˆ‹αˆ α‹­αˆ… α‹¨αˆ™αŠ¨αˆ« α…αˆα αŠα‹α’",
# lines=5,
# )
# generate_button = gr.Button("Generate Audio πŸ”Š")
# with gr.Column(scale=1):
# audio_output = gr.Audio(
# label="Generated Audio",
# type="filepath"
# )
# status_text = gr.Textbox(
# label="Status",
# interactive=False
# )
# generate_button.click(
# fn=synthesize_text,
# inputs=[text_input, model_dropdown],
# outputs=[audio_output, status_text],
# )
# gr.Markdown("""
# ---
# **Local Model**
# - facebook/mms-tts-amh
# **API-only Models**
# - AddisuSeteye/speecht5_tts_amharic
# - Walelign/speecht5_tts_amharic
# """)
# if __name__ == "__main__":
# # Theme moved to launch() for Gradio 6 compatibility
# app.launch(theme=gr.themes.Base())
import gradio as gr
import tempfile
import traceback
from dotenv import load_dotenv
import torch
import soundfile as sf
from transformers import (
pipeline,
SpeechT5Processor,
SpeechT5ForTextToSpeech,
SpeechT5HifiGan
)
# Load environment variables from a local .env file (e.g. HF access tokens).
load_dotenv()
print("Torch:", torch.__version__)
# -----------------------------
# Models
# -----------------------------
# Dropdown label -> Hugging Face model id. Labels and ids are identical
# today, but the mapping keeps them decoupled for future renames.
AVAILABLE_MODELS = {
    "facebook/mms-tts-amh": "facebook/mms-tts-amh",
    "AddisuSeteye/speecht5_tts_amharic": "AddisuSeteye/speecht5_tts_amharic",
    "Walelign/speecht5_tts_amharic": "Walelign/speecht5_tts_amharic",
}
# Process-wide cache: model id -> loaded pipeline (MMS) or component dict
# (SpeechT5), so each model is loaded at most once per process.
loaded_models = {}
# -----------------------------
# MMS Loader
# -----------------------------
def load_mms(model_id):
    """Return a cached text-to-speech pipeline for the given MMS model id.

    On the first call for *model_id*, builds a CPU ``text-to-speech``
    pipeline and stores it in the module-level ``loaded_models`` cache;
    later calls return the cached instance.
    """
    if model_id in loaded_models:
        return loaded_models[model_id]
    print(f"Loading MMS model: {model_id}")
    tts_pipe = pipeline(
        "text-to-speech",
        model=model_id,
        device=-1,  # force CPU
    )
    loaded_models[model_id] = tts_pipe
    return tts_pipe
# -----------------------------
# SpeechT5 Loader
# -----------------------------
def load_speecht5(model_id):
    """Return cached SpeechT5 components for *model_id*.

    Loads the processor and acoustic model for the checkpoint plus the
    shared HiFi-GAN vocoder, switches the model to eval mode, and caches
    the trio as a dict with keys ``processor``/``model``/``vocoder`` in
    the module-level ``loaded_models`` cache.
    """
    if model_id in loaded_models:
        return loaded_models[model_id]
    print(f"Loading SpeechT5 model: {model_id}")
    components = {
        "processor": SpeechT5Processor.from_pretrained(model_id),
        "model": SpeechT5ForTextToSpeech.from_pretrained(model_id),
        "vocoder": SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan"),
    }
    components["model"].eval()
    loaded_models[model_id] = components
    return components
# -----------------------------
# Synthesis
# -----------------------------
def synthesize_text(text: str, model_choice: str):
    """Synthesize Amharic speech from *text* with the selected model.

    Parameters
    ----------
    text : str | None
        Amharic input text. ``None`` or whitespace-only input yields an
        error message instead of raising.
    model_choice : str
        Key into ``AVAILABLE_MODELS`` selecting the TTS backend.

    Returns
    -------
    tuple[str | None, str]
        (path to a temporary WAV file, status message) on success, or
        (None, error message) on failure.
    """
    # Guard against None as well as empty text — a cleared gr.Textbox
    # can hand back None, and None.strip() would raise.
    if not text or not text.strip():
        return None, "Enter Amharic text."
    model_id = AVAILABLE_MODELS.get(model_choice)
    if not model_id:
        # Unknown dropdown value: fail explicitly instead of falling
        # through to the SpeechT5 branch with model_id=None.
        return None, "Selected model not found."
    try:
        # ---------------- MMS ----------------
        if model_id == "facebook/mms-tts-amh":
            tts = load_mms(model_id)
            result = tts(text)
            audio = result["audio"]
            sr = result["sampling_rate"]
        # --------------- SpeechT5 ---------------
        else:
            components = load_speecht5(model_id)
            processor = components["processor"]
            model = components["model"]
            vocoder = components["vocoder"]
            inputs = processor(text=text, return_tensors="pt")
            # Neutral (all-zero) speaker embedding — generate_speech
            # requires one even for single-speaker checkpoints.
            speaker_embeddings = torch.zeros((1, 512))
            with torch.no_grad():
                speech = model.generate_speech(
                    inputs["input_ids"],
                    speaker_embeddings=speaker_embeddings,
                    vocoder=vocoder
                )
            audio = speech.numpy()
            sr = 16000  # SpeechT5 vocoder output rate used by this app
        # Close the handle before writing: sf.write reopens the path by
        # name, which fails on Windows while the file is still open.
        tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
        tmp.close()
        sf.write(tmp.name, audio, sr)
        return tmp.name, f"βœ“ Generated using {model_choice}"
    except Exception as e:
        # Surface the error to the UI status box; full trace to console.
        traceback.print_exc()
        return None, f"Error: {str(e)}"
# -----------------------------
# UI
# -----------------------------
# Gradio front end: text input + model selector on the left, audio player
# and status line on the right. Built declaratively at import time; the
# resulting Blocks object is launched under the __main__ guard below.
with gr.Blocks(title="Amharic Text-to-Speech") as app:
    gr.Markdown("""
    πŸ‡ͺπŸ‡Ή **Fully Local Amharic TTS**
    All models run locally.
    """)
    with gr.Row():
        with gr.Column(scale=2):
            # Model selector; defaults to the MMS checkpoint.
            model_dropdown = gr.Dropdown(
                choices=list(AVAILABLE_MODELS.keys()),
                value="facebook/mms-tts-amh",
                label="Choose Model",
            )
            text_input = gr.Textbox(
                label="Enter Amharic Text",
                placeholder="αˆ°αˆ‹αˆ α‹­αˆ… α‹¨αˆ™αŠ¨αˆ« α…αˆα αŠα‹α’",
                lines=5,
            )
            generate_button = gr.Button("Generate πŸ”Š")
        with gr.Column(scale=1):
            # type="filepath" matches synthesize_text returning a WAV path.
            audio_output = gr.Audio(type="filepath")
            status_text = gr.Textbox(interactive=False)
    # Wire the button: (text, model) -> (audio path, status message).
    generate_button.click(
        synthesize_text,
        inputs=[text_input, model_dropdown],
        outputs=[audio_output, status_text],
    )
if __name__ == "__main__":
    app.launch(theme=gr.themes.Base())