Spaces:
Running
Running
Michael Hu
fix: ensure Kokoro TTS model is initialized before use and handle missing audio file edge case
9b4a9ad
# Standard library
import functools
import os
import tempfile

# Third-party
import gradio as gr
import numpy as np
import soundfile as sf
import torch

# Import our model factory
from src.models.factory import ModelFactory
# Patch torch.load so checkpoints saved on GPU machines can be loaded on
# CPU-only hosts (e.g. a Spaces container without CUDA).
original_torch_load = torch.load


@functools.wraps(original_torch_load)
def patched_torch_load(f, map_location=None, **kwargs):
    """Call the original ``torch.load`` with ``map_location`` defaulting to CPU.

    Args:
        f: File-like object or path, forwarded to ``torch.load``.
        map_location: Passed through unchanged when the caller supplies one;
            only a missing (``None``) value is replaced with ``'cpu'``.
        **kwargs: Any other ``torch.load`` keyword arguments, forwarded as-is.

    Returns:
        Whatever ``torch.load`` returns for the given arguments.
    """
    if map_location is None:
        map_location = 'cpu'
    return original_torch_load(f, map_location=map_location, **kwargs)


torch.load = patched_torch_load
# Human-readable model descriptions supplied by the factory (used by the UI).
MODEL_DESCRIPTIONS = ModelFactory.get_model_descriptions()

# Models dictionary for UI display: maps model repo id -> display name.
MODELS = {
    "ResembleAI/chatterbox": "Chatterbox",
    "KittenML/KittenTTS": "KittenTTS",
    "piper-tts": "Piper (no voice cloning)",
    "SYSTRAN/faster-whisper": "Faster Whisper",
    "hexgrad/kokoro": "Kokoro-82M",
    "nari-labs/Dia-1.6B": "Dia TTS",
}

# Instantiate model wrapper objects from the factory, keyed by repo id.
tts_models = ModelFactory.get_tts_models()
stt_models = ModelFactory.get_stt_models()

# Eagerly initialize the models that need immediate initialization; the
# remaining TTS models are initialized lazily inside their UI handlers.
for model_name in ["ResembleAI/chatterbox", "KittenML/KittenTTS"]:
    if model_name in tts_models:
        tts_models[model_name].initialize()

# Eagerly initialize the STT model as well, if the factory provided one.
whisper_model = stt_models.get("SYSTRAN/faster-whisper")
if whisper_model:
    whisper_model.initialize()
# Helper function to get Kokoro voices
def get_kokoro_voices(language_code):
    """Return the Kokoro voice ids available for a language code.

    Args:
        language_code: Single-letter Kokoro language code ("a", "b", "e",
            "f", "h", "i", "j", "p", or "z").

    Returns:
        list[str]: Voice ids for that language. Unknown codes fall back to
        ``["af_heart"]`` (the American English default voice).

    Voice ids follow the official list at
    https://huggingface.co/hexgrad/Kokoro-82M/blob/main/VOICES.md
    """
    voice_map = {
        # American English (a)
        "a": [
            "af_heart", "af_alloy", "af_aoede", "af_bella", "af_jessica",
            "af_kore", "af_nicole", "af_nova", "af_river", "af_sarah",
            "af_sky", "am_adam", "am_echo", "am_eric", "am_fenrir",
            "am_liam", "am_michael", "am_onyx", "am_puck", "am_santa"
        ],
        # British English (b)
        "b": [
            "bf_alice", "bf_emma", "bf_isabella", "bf_lily",
            "bm_daniel", "bm_fable", "bm_george", "bm_lewis"
        ],
        # Spanish (e)
        "e": ["ef_dora", "em_alex", "em_santa"],
        # French (f)
        "f": ["ff_siwis"],
        # Hindi (h)
        "h": ["hf_alpha", "hf_beta", "hm_omega", "hm_psi"],
        # Italian (i)
        "i": ["if_sara", "im_nicola"],
        # Japanese (j)
        "j": ["jf_alpha", "jf_gongitsune", "jf_nezumi", "jf_tebukuro", "jm_kumo"],
        # Brazilian Portuguese (p) — fixed: previous ids ("pt_heart", etc.)
        # do not exist in Kokoro's VOICES.md; these are the published ones.
        "p": ["pf_dora", "pm_alex", "pm_santa"],
        # Mandarin Chinese (z)
        "z": [
            "zf_xiaobei", "zf_xiaoni", "zf_xiaoxiao", "zf_xiaoyi",
            "zm_yunjian", "zm_yunxi", "zm_yunxia", "zm_yunyang"
        ]
    }
    return voice_map.get(language_code, ["af_heart"])  # Fallback for unknown codes
# UI Functions for TTS Models
def tts_chatterbox(text, language, audio_prompt=None):
    """Gradio handler: synthesize speech with the Chatterbox model.

    Returns an ``(audio_path, error_message)`` pair; exactly one of the
    two is meaningful (the other is ``None`` / empty string).
    """
    chatterbox = tts_models.get("ResembleAI/chatterbox")
    if not chatterbox:
        return None, "Model not available"
    try:
        generated_path = chatterbox.generate_speech(
            text, language=language, audio_prompt=audio_prompt
        )
    except Exception as e:
        return None, f"Error: {str(e)}"
    return generated_path, ""
def tts_kittentts(text, audio_prompt=None):
    """Gradio handler: synthesize speech with KittenTTS.

    Returns an ``(audio_path, error_message)`` pair; only one of the two
    carries information.
    """
    kitten = tts_models.get("KittenML/KittenTTS")
    if not kitten:
        return None, "Model not available"
    try:
        generated_path = kitten.generate_speech(text, audio_prompt=audio_prompt)
    except Exception as e:
        return None, f"Error: {str(e)}"
    return generated_path, ""
def tts_piper(text, language, voice):
    """Gradio handler: synthesize speech with Piper TTS.

    Piper selects a named voice rather than cloning from audio.
    Returns an ``(audio_path, error_message)`` pair.
    """
    piper = tts_models.get("piper-tts")
    if not piper:
        return None, "Model not available"
    try:
        # Re-running initialize ensures the installed voices are scanned.
        piper.initialize()
        generated_path = piper.generate_speech(text, language=language, voice=voice)
    except Exception as e:
        return None, f"Error: {str(e)}"
    return generated_path, ""
def tts_kokoro(text, language_code, voice_name):
    """Gradio handler: synthesize speech with Kokoro-82M.

    Args:
        text: Text to synthesize.
        language_code: Single-letter Kokoro language code (e.g. "a").
        voice_name: Kokoro voice id (e.g. "af_heart").

    Returns:
        ``(audio_path, error_message)``: exactly one of the two is
        meaningful.
    """
    model = tts_models.get("hexgrad/kokoro")
    if not model:
        return None, "Model not available"
    try:
        # Lazily initialize on first use. getattr (rather than a direct
        # attribute read) guards against wrappers that don't expose the
        # private _initialized flag — the old code raised AttributeError.
        if not getattr(model, "_initialized", False):
            model.initialize()
        audio_path = model.generate_speech(text, lang_code=language_code, voice_name=voice_name)
        # Guard against the backend reporting success without actually
        # writing an audio file to disk.
        if audio_path and os.path.exists(audio_path):
            return audio_path, ""
        return None, "Error: Audio file was not generated"
    except Exception as e:
        return None, f"Error: {str(e)}"
def tts_dia(text, audio_prompt=None):
    """Gradio handler: synthesize speech with Dia TTS.

    Returns an ``(audio_path, error_message)`` pair.
    """
    dia = tts_models.get("nari-labs/Dia-1.6B")
    if not dia:
        return None, "Model not available"
    try:
        # Ensure the model weights are loaded before generating.
        dia.initialize()
        generated_path = dia.generate_speech(text, audio_prompt=audio_prompt)
    except Exception as e:
        return None, f"Error: {str(e)}"
    return generated_path, ""
# UI Function for STT Model
def stt_whisper(audio_path, language=None):
    """Gradio handler: transcribe audio with Faster Whisper.

    Returns the transcription string, or an error/status message when the
    model is unavailable or transcription fails.
    """
    whisper = stt_models.get("SYSTRAN/faster-whisper")
    if not whisper:
        return "Model not available"
    try:
        return whisper.transcribe(audio_path, language=language)
    except Exception as e:
        return f"Error: {str(e)}"
# Gradio UI Components
def create_tts_tab():
    """Create the TTS tab for the Gradio interface.

    Builds one sub-tab per TTS model; each sub-tab wires its
    "Generate Speech" button to the matching tts_* handler defined above.
    Must be called inside an open gr.Blocks context.
    """
    with gr.Tab("Text-to-Speech"):
        gr.Markdown("## Text-to-Speech Models")
        with gr.Tabs():
            # Chatterbox Tab: text + language + optional voice-reference clip.
            with gr.Tab("Chatterbox"):
                with gr.Row():
                    with gr.Column():
                        chatterbox_text = gr.Textbox(
                            label="Text to speak",
                            placeholder="Enter text here...",
                            lines=5
                        )
                        chatterbox_language = gr.Dropdown(
                            choices=["English", "Chinese"],
                            value="English",
                            label="Language"
                        )
                        chatterbox_audio_prompt = gr.Audio(
                            label="Voice reference (optional)",
                            type="filepath"
                        )
                        chatterbox_submit = gr.Button("Generate Speech")
                    with gr.Column():
                        chatterbox_output = gr.Audio(label="Generated Speech")
                        # Hidden unless the handler returns an error string.
                        chatterbox_error = gr.Textbox(label="Error", visible=False)
                chatterbox_submit.click(
                    tts_chatterbox,
                    inputs=[chatterbox_text, chatterbox_language, chatterbox_audio_prompt],
                    outputs=[chatterbox_output, chatterbox_error]
                )
            # KittenTTS Tab: text + optional voice-reference clip.
            with gr.Tab("KittenTTS"):
                with gr.Row():
                    with gr.Column():
                        kittentts_text = gr.Textbox(
                            label="Text to speak",
                            placeholder="Enter text here...",
                            lines=5
                        )
                        kittentts_audio_prompt = gr.Audio(
                            label="Voice reference (optional)",
                            type="filepath"
                        )
                        kittentts_submit = gr.Button("Generate Speech")
                    with gr.Column():
                        kittentts_output = gr.Audio(label="Generated Speech")
                        kittentts_error = gr.Textbox(label="Error", visible=False)
                kittentts_submit.click(
                    tts_kittentts,
                    inputs=[kittentts_text, kittentts_audio_prompt],
                    outputs=[kittentts_output, kittentts_error]
                )
            # Piper Tab: the voice dropdown depends on the selected language.
            with gr.Tab("Piper"):
                with gr.Row():
                    with gr.Column():
                        piper_text = gr.Textbox(
                            label="Text to speak",
                            placeholder="Enter text here...",
                            lines=5
                        )
                        # Initialize Piper now so its language/voice catalog
                        # is available while building the dropdowns.
                        piper_model = tts_models.get("piper-tts")
                        if piper_model:
                            piper_model.initialize()
                            languages = piper_model.get_supported_languages()
                        else:
                            languages = ["English"]
                        piper_language = gr.Dropdown(
                            choices=languages,
                            value="English",
                            label="Language"
                        )
                        def update_piper_voices(language):
                            # Refresh the voice dropdown whenever the
                            # language selection changes.
                            if piper_model:
                                voices = piper_model.get_available_voices(language)
                                return gr.update(choices=voices, value=voices[0] if voices else None)
                            return gr.update(choices=[], value=None)
                        piper_voice = gr.Dropdown(
                            label="Voice",
                            choices=[]
                        )
                        piper_language.change(
                            update_piper_voices,
                            inputs=[piper_language],
                            outputs=[piper_voice]
                        )
                        piper_submit = gr.Button("Generate Speech")
                    with gr.Column():
                        piper_output = gr.Audio(label="Generated Speech")
                        piper_error = gr.Textbox(label="Error", visible=False)
                piper_submit.click(
                    tts_piper,
                    inputs=[piper_text, piper_language, piper_voice],
                    outputs=[piper_output, piper_error]
                )
            # Kokoro Tab: languages are displayed as "Name (code)"; the
            # single-letter code in parentheses is extracted before use.
            with gr.Tab("Kokoro"):
                with gr.Row():
                    with gr.Column():
                        kokoro_text = gr.Textbox(
                            label="Text to speak",
                            placeholder="Enter text here...",
                            lines=5
                        )
                        kokoro_language = gr.Dropdown(
                            choices=[
                                "American English (a)", "British English (b)",
                                "Spanish (e)", "French (f)", "Hindi (h)",
                                "Italian (i)", "Japanese (j)",
                                "Brazilian Portuguese (p)", "Mandarin Chinese (z)"
                            ],
                            value="American English (a)",
                            label="Language"
                        )
                        def get_lang_code(language):
                            # e.g. "American English (a)" -> "a"
                            return language.split("(")[-1].split(")")[0].strip()
                        def update_kokoro_voices(language):
                            # Refresh the voice dropdown for the new language.
                            lang_code = get_lang_code(language)
                            voices = get_kokoro_voices(lang_code)
                            return gr.update(choices=voices, value=voices[0] if voices else None)
                        kokoro_voice = gr.Dropdown(
                            label="Voice",
                            choices=get_kokoro_voices("a"),
                            value="af_heart"
                        )
                        kokoro_language.change(
                            update_kokoro_voices,
                            inputs=[kokoro_language],
                            outputs=[kokoro_voice]
                        )
                        kokoro_submit = gr.Button("Generate Speech")
                    with gr.Column():
                        kokoro_output = gr.Audio(label="Generated Speech")
                        kokoro_error = gr.Textbox(label="Error", visible=False)
                kokoro_submit.click(
                    # Translate the display label to a language code before
                    # delegating to the tts_kokoro handler.
                    lambda text, lang, voice: tts_kokoro(text, get_lang_code(lang), voice),
                    inputs=[kokoro_text, kokoro_language, kokoro_voice],
                    outputs=[kokoro_output, kokoro_error]
                )
            # Dia Tab: text + optional voice-reference clip.
            with gr.Tab("Dia"):
                with gr.Row():
                    with gr.Column():
                        dia_text = gr.Textbox(
                            label="Text to speak",
                            placeholder="Enter text here...",
                            lines=5
                        )
                        dia_audio_prompt = gr.Audio(
                            label="Voice reference (optional)",
                            type="filepath"
                        )
                        dia_submit = gr.Button("Generate Speech")
                    with gr.Column():
                        dia_output = gr.Audio(label="Generated Speech")
                        dia_error = gr.Textbox(label="Error", visible=False)
                dia_submit.click(
                    tts_dia,
                    inputs=[dia_text, dia_audio_prompt],
                    outputs=[dia_output, dia_error]
                )
def create_stt_tab():
    """Create the STT tab for the Gradio interface.

    Currently contains a single sub-tab for Faster Whisper. Must be called
    inside an open gr.Blocks context.
    """
    with gr.Tab("Speech-to-Text"):
        gr.Markdown("## Speech-to-Text Models")
        with gr.Tabs():
            # Faster Whisper Tab
            with gr.Tab("Faster Whisper"):
                with gr.Row():
                    with gr.Column():
                        whisper_audio = gr.Audio(
                            label="Audio to transcribe",
                            type="filepath"
                        )
                        whisper_language = gr.Dropdown(
                            choices=["Auto-detect", "English", "Chinese", "Spanish", "French", "German", "Japanese"],
                            value="Auto-detect",
                            label="Language (optional)"
                        )
                        whisper_submit = gr.Button("Transcribe")
                    with gr.Column():
                        whisper_output = gr.Textbox(
                            label="Transcription",
                            lines=5
                        )
                whisper_submit.click(
                    # "Auto-detect" is mapped to language=None so the model
                    # performs its own language detection.
                    lambda audio, lang: stt_whisper(audio, None if lang == "Auto-detect" else lang),
                    inputs=[whisper_audio, whisper_language],
                    outputs=[whisper_output]
                )
# Create the Gradio interface
def create_interface():
    """Create the main Gradio interface.

    Assembles the TTS and STT tabs into a single gr.Blocks app and returns
    it (without launching).
    """
    with gr.Blocks(title="TTS & STT Gallery") as demo:
        gr.Markdown("# TTS & STT Model Gallery")
        gr.Markdown("Explore different Text-to-Speech and Speech-to-Text models")
        with gr.Tabs():
            create_tts_tab()
            create_stt_tab()
    return demo
# Launch the app when run as a script (not on import).
if __name__ == "__main__":
    demo = create_interface()
    demo.launch()