# tts_gallery / app.py
# Author: Michael Hu
# fix: ensure Kokoro TTS model is initialized before use and handle missing audio file edge case
# commit: 9b4a9ad
import gradio as gr
import torch
import tempfile
import os
import numpy as np
import soundfile as sf
# Import our model factory
from src.models.factory import ModelFactory
# Patch torch.load so checkpoints always deserialize onto the CPU by default.
# Several bundled models ship weights saved on GPU machines, and this app
# must run in a CPU-only environment.
original_torch_load = torch.load

def patched_torch_load(f, map_location=None, **kwargs):
    """Delegate to the original torch.load, defaulting map_location to 'cpu'.

    An explicitly supplied map_location is passed through unchanged.
    """
    target = map_location if map_location is not None else 'cpu'
    return original_torch_load(f, map_location=target, **kwargs)

torch.load = patched_torch_load
# Human-readable descriptions for each registered model.
MODEL_DESCRIPTIONS = ModelFactory.get_model_descriptions()

# Model identifier -> display name used by the UI.
MODELS = {
    "ResembleAI/chatterbox": "Chatterbox",
    "KittenML/KittenTTS": "KittenTTS",
    "piper-tts": "Piper (no voice cloning)",
    "SYSTRAN/faster-whisper": "Faster Whisper",
    "hexgrad/kokoro": "Kokoro-82M",
    "nari-labs/Dia-1.6B": "Dia TTS",
}

# Instantiate the model wrapper objects from the factory.
tts_models = ModelFactory.get_tts_models()
stt_models = ModelFactory.get_stt_models()

# Eagerly initialize the TTS models that must be ready at startup; the
# remaining TTS models are initialized lazily inside their UI handlers.
for _eager_name in ("ResembleAI/chatterbox", "KittenML/KittenTTS"):
    if _eager_name in tts_models:
        tts_models[_eager_name].initialize()

# The speech-to-text model is initialized up front as well.
whisper_model = stt_models.get("SYSTRAN/faster-whisper")
if whisper_model:
    whisper_model.initialize()
# Helper function to get Kokoro voices
def get_kokoro_voices(language_code):
    """Return the list of Kokoro voice names for a one-letter language code.

    The inventory follows the official voice list:
    https://huggingface.co/hexgrad/Kokoro-82M/blob/main/VOICES.md
    Names are prefixed with language + gender, e.g. "af_" = American
    English female, "bm_" = British English male.

    Args:
        language_code: Kokoro language code, one of
            "a", "b", "e", "f", "h", "i", "j", "p", "z".

    Returns:
        list[str]: Voice names for that language. Unknown codes fall back
        to the American English default ["af_heart"].
    """
    voice_map = {
        # American English (a)
        "a": [
            "af_heart", "af_alloy", "af_aoede", "af_bella", "af_jessica",
            "af_kore", "af_nicole", "af_nova", "af_river", "af_sarah",
            "af_sky", "am_adam", "am_echo", "am_eric", "am_fenrir",
            "am_liam", "am_michael", "am_onyx", "am_puck", "am_santa"
        ],
        # British English (b)
        "b": [
            "bf_alice", "bf_emma", "bf_isabella", "bf_lily",
            "bm_daniel", "bm_fable", "bm_george", "bm_lewis"
        ],
        # Spanish (e)
        "e": ["ef_dora", "em_alex", "em_santa"],
        # French (f)
        "f": ["ff_siwis"],
        # Hindi (h)
        "h": ["hf_alpha", "hf_beta", "hm_omega", "hm_psi"],
        # Italian (i)
        "i": ["if_sara", "im_nicola"],
        # Japanese (j)
        "j": ["jf_alpha", "jf_gongitsune", "jf_nezumi", "jf_tebukuro", "jm_kumo"],
        # Brazilian Portuguese (p)
        # FIX: previous entries ("pt_heart", "pt_sun", ...) do not exist in
        # VOICES.md; Kokoro's Portuguese voices use pf_/pm_ prefixes.
        "p": ["pf_dora", "pm_alex", "pm_santa"],
        # Mandarin Chinese (z)
        "z": [
            "zf_xiaobei", "zf_xiaoni", "zf_xiaoxiao", "zf_xiaoyi",
            "zm_yunjian", "zm_yunxi", "zm_yunxia", "zm_yunyang"
        ]
    }
    return voice_map.get(language_code, ["af_heart"])  # Default to American English voice
# UI Functions for TTS Models
def tts_chatterbox(text, language, audio_prompt=None):
    """Generate speech with Chatterbox.

    Returns (audio_path, error_message); exactly one carries information.
    """
    engine = tts_models.get("ResembleAI/chatterbox")
    if not engine:
        return None, "Model not available"
    try:
        path = engine.generate_speech(text, language=language, audio_prompt=audio_prompt)
    except Exception as e:
        return None, f"Error: {str(e)}"
    return path, ""
def tts_kittentts(text, audio_prompt=None):
    """Generate speech with KittenTTS.

    Returns (audio_path, error_message); exactly one carries information.
    """
    engine = tts_models.get("KittenML/KittenTTS")
    if not engine:
        return None, "Model not available"
    try:
        path = engine.generate_speech(text, audio_prompt=audio_prompt)
    except Exception as e:
        return None, f"Error: {str(e)}"
    return path, ""
def tts_piper(text, language, voice):
    """Generate speech with Piper.

    Returns (audio_path, error_message); exactly one carries information.
    """
    engine = tts_models.get("piper-tts")
    if not engine:
        return None, "Model not available"
    try:
        # Ensure the installed voices have been scanned before synthesis.
        engine.initialize()
        path = engine.generate_speech(text, language=language, voice=voice)
    except Exception as e:
        return None, f"Error: {str(e)}"
    return path, ""
def tts_kokoro(text, language_code, voice_name):
    """Generate speech with Kokoro TTS.

    Args:
        text: Text to synthesize.
        language_code: One-letter Kokoro language code (see get_kokoro_voices).
        voice_name: Kokoro voice identifier (e.g. "af_heart"), forwarded
            to the wrapper's generate_speech.

    Returns:
        tuple: (audio_path or None, error_message). An empty error string
        means success.
    """
    model = tts_models.get("hexgrad/kokoro")
    if not model:
        return None, "Model not available"
    try:
        # Lazily initialize on first use. getattr guards against wrappers
        # that do not expose the private _initialized flag — previously a
        # missing attribute raised AttributeError and skipped initialization.
        if not getattr(model, "_initialized", False):
            model.initialize()
        audio_path = model.generate_speech(text, lang_code=language_code, voice_name=voice_name)
        # Generation can report a path without actually writing the file;
        # verify it exists before handing it to the UI.
        if audio_path and os.path.exists(audio_path):
            return audio_path, ""
        return None, "Error: Audio file was not generated"
    except Exception as e:
        return None, f"Error: {str(e)}"
def tts_dia(text, audio_prompt=None):
    """Generate speech with Dia.

    Returns (audio_path, error_message); exactly one carries information.
    """
    engine = tts_models.get("nari-labs/Dia-1.6B")
    if not engine:
        return None, "Model not available"
    try:
        # Ensure the model weights are loaded before synthesis.
        engine.initialize()
        path = engine.generate_speech(text, audio_prompt=audio_prompt)
    except Exception as e:
        return None, f"Error: {str(e)}"
    return path, ""
# UI Function for STT Model
def stt_whisper(audio_path, language=None):
    """Transcribe audio with Faster Whisper.

    Returns the transcription text, or an error string on failure.
    """
    engine = stt_models.get("SYSTRAN/faster-whisper")
    if not engine:
        return "Model not available"
    try:
        return engine.transcribe(audio_path, language=language)
    except Exception as e:
        return f"Error: {str(e)}"
# Gradio UI Components
def create_tts_tab():
    """Create the TTS tab for the Gradio interface.

    Builds one sub-tab per TTS model. Each sub-tab wires its input widgets
    to the matching tts_* handler and shows the generated audio plus an
    error textbox. Must be called inside an active gr.Blocks context.
    """
    with gr.Tab("Text-to-Speech"):
        gr.Markdown("## Text-to-Speech Models")
        with gr.Tabs():
            # Chatterbox Tab
            with gr.Tab("Chatterbox"):
                with gr.Row():
                    with gr.Column():
                        chatterbox_text = gr.Textbox(
                            label="Text to speak",
                            placeholder="Enter text here...",
                            lines=5
                        )
                        chatterbox_language = gr.Dropdown(
                            choices=["English", "Chinese"],
                            value="English",
                            label="Language"
                        )
                        # Optional reference clip for voice cloning.
                        chatterbox_audio_prompt = gr.Audio(
                            label="Voice reference (optional)",
                            type="filepath"
                        )
                        chatterbox_submit = gr.Button("Generate Speech")
                    with gr.Column():
                        chatterbox_output = gr.Audio(label="Generated Speech")
                        chatterbox_error = gr.Textbox(label="Error", visible=False)
                chatterbox_submit.click(
                    tts_chatterbox,
                    inputs=[chatterbox_text, chatterbox_language, chatterbox_audio_prompt],
                    outputs=[chatterbox_output, chatterbox_error]
                )
            # KittenTTS Tab
            with gr.Tab("KittenTTS"):
                with gr.Row():
                    with gr.Column():
                        kittentts_text = gr.Textbox(
                            label="Text to speak",
                            placeholder="Enter text here...",
                            lines=5
                        )
                        kittentts_audio_prompt = gr.Audio(
                            label="Voice reference (optional)",
                            type="filepath"
                        )
                        kittentts_submit = gr.Button("Generate Speech")
                    with gr.Column():
                        kittentts_output = gr.Audio(label="Generated Speech")
                        kittentts_error = gr.Textbox(label="Error", visible=False)
                kittentts_submit.click(
                    tts_kittentts,
                    inputs=[kittentts_text, kittentts_audio_prompt],
                    outputs=[kittentts_output, kittentts_error]
                )
            # Piper Tab
            with gr.Tab("Piper"):
                with gr.Row():
                    with gr.Column():
                        piper_text = gr.Textbox(
                            label="Text to speak",
                            placeholder="Enter text here...",
                            lines=5
                        )
                        # Initialize Piper model to get voices
                        # (done at UI-build time so the language dropdown
                        # can be populated from the installed voices).
                        piper_model = tts_models.get("piper-tts")
                        if piper_model:
                            piper_model.initialize()
                            languages = piper_model.get_supported_languages()
                        else:
                            languages = ["English"]
                        piper_language = gr.Dropdown(
                            choices=languages,
                            value="English",
                            label="Language"
                        )
                        def update_piper_voices(language):
                            """Refresh the voice dropdown for the selected language."""
                            if piper_model:
                                voices = piper_model.get_available_voices(language)
                                return gr.update(choices=voices, value=voices[0] if voices else None)
                            return gr.update(choices=[], value=None)
                        piper_voice = gr.Dropdown(
                            label="Voice",
                            choices=[]
                        )
                        piper_language.change(
                            update_piper_voices,
                            inputs=[piper_language],
                            outputs=[piper_voice]
                        )
                        piper_submit = gr.Button("Generate Speech")
                    with gr.Column():
                        piper_output = gr.Audio(label="Generated Speech")
                        piper_error = gr.Textbox(label="Error", visible=False)
                piper_submit.click(
                    tts_piper,
                    inputs=[piper_text, piper_language, piper_voice],
                    outputs=[piper_output, piper_error]
                )
            # Kokoro Tab
            with gr.Tab("Kokoro"):
                with gr.Row():
                    with gr.Column():
                        kokoro_text = gr.Textbox(
                            label="Text to speak",
                            placeholder="Enter text here...",
                            lines=5
                        )
                        # Display labels carry the one-letter Kokoro code
                        # in parentheses; get_lang_code extracts it.
                        kokoro_language = gr.Dropdown(
                            choices=[
                                "American English (a)", "British English (b)",
                                "Spanish (e)", "French (f)", "Hindi (h)",
                                "Italian (i)", "Japanese (j)",
                                "Brazilian Portuguese (p)", "Mandarin Chinese (z)"
                            ],
                            value="American English (a)",
                            label="Language"
                        )
                        def get_lang_code(language):
                            """Extract the code from a label, e.g. "French (f)" -> "f"."""
                            return language.split("(")[-1].split(")")[0].strip()
                        def update_kokoro_voices(language):
                            """Repopulate the voice dropdown for the selected language."""
                            lang_code = get_lang_code(language)
                            voices = get_kokoro_voices(lang_code)
                            return gr.update(choices=voices, value=voices[0] if voices else None)
                        kokoro_voice = gr.Dropdown(
                            label="Voice",
                            choices=get_kokoro_voices("a"),
                            value="af_heart"
                        )
                        kokoro_language.change(
                            update_kokoro_voices,
                            inputs=[kokoro_language],
                            outputs=[kokoro_voice]
                        )
                        kokoro_submit = gr.Button("Generate Speech")
                    with gr.Column():
                        kokoro_output = gr.Audio(label="Generated Speech")
                        kokoro_error = gr.Textbox(label="Error", visible=False)
                # The lambda converts the display label to the raw language
                # code before invoking the Kokoro handler.
                kokoro_submit.click(
                    lambda text, lang, voice: tts_kokoro(text, get_lang_code(lang), voice),
                    inputs=[kokoro_text, kokoro_language, kokoro_voice],
                    outputs=[kokoro_output, kokoro_error]
                )
            # Dia Tab
            with gr.Tab("Dia"):
                with gr.Row():
                    with gr.Column():
                        dia_text = gr.Textbox(
                            label="Text to speak",
                            placeholder="Enter text here...",
                            lines=5
                        )
                        dia_audio_prompt = gr.Audio(
                            label="Voice reference (optional)",
                            type="filepath"
                        )
                        dia_submit = gr.Button("Generate Speech")
                    with gr.Column():
                        dia_output = gr.Audio(label="Generated Speech")
                        dia_error = gr.Textbox(label="Error", visible=False)
                dia_submit.click(
                    tts_dia,
                    inputs=[dia_text, dia_audio_prompt],
                    outputs=[dia_output, dia_error]
                )
def create_stt_tab():
    """Create the STT tab for the Gradio interface.

    Wires the Faster Whisper widgets to the stt_whisper handler. Must be
    called inside an active gr.Blocks context.
    """
    with gr.Tab("Speech-to-Text"):
        gr.Markdown("## Speech-to-Text Models")
        with gr.Tabs():
            # Faster Whisper Tab
            with gr.Tab("Faster Whisper"):
                with gr.Row():
                    with gr.Column():
                        whisper_audio = gr.Audio(
                            label="Audio to transcribe",
                            type="filepath"
                        )
                        whisper_language = gr.Dropdown(
                            choices=["Auto-detect", "English", "Chinese", "Spanish", "French", "German", "Japanese"],
                            value="Auto-detect",
                            label="Language (optional)"
                        )
                        whisper_submit = gr.Button("Transcribe")
                    with gr.Column():
                        whisper_output = gr.Textbox(
                            label="Transcription",
                            lines=5
                        )
                whisper_submit.click(
                    # "Auto-detect" is mapped to language=None for the handler.
                    lambda audio, lang: stt_whisper(audio, None if lang == "Auto-detect" else lang),
                    inputs=[whisper_audio, whisper_language],
                    outputs=[whisper_output]
                )
# Create the Gradio interface
def create_interface():
    """Assemble and return the full Gradio Blocks app (TTS + STT tabs)."""
    with gr.Blocks(title="TTS & STT Gallery") as demo:
        gr.Markdown("# TTS & STT Model Gallery")
        gr.Markdown("Explore different Text-to-Speech and Speech-to-Text models")
        with gr.Tabs():
            create_tts_tab()
            create_stt_tab()
    return demo
# Launch the app
if __name__ == "__main__":
    # Build the Blocks UI and start the Gradio server.
    demo = create_interface()
    demo.launch()