"""Gradio gallery app for exploring Text-to-Speech and Speech-to-Text models.

Builds a tabbed UI with one tab per TTS model (Chatterbox, KittenTTS, Piper,
Kokoro, Dia) plus a Faster-Whisper STT tab. Model instances come from the
project's ModelFactory; heavyweight models are initialized eagerly at import.
"""

import gradio as gr
import torch
import tempfile
import os
import numpy as np
import soundfile as sf

# Import our model factory
from src.models.factory import ModelFactory

# Patch torch.load to always default to CPU so checkpoints saved on GPU
# machines can still be loaded in CPU-only environments.
original_torch_load = torch.load


def patched_torch_load(f, map_location=None, **kwargs):
    """Wrap torch.load, forcing map_location='cpu' when the caller gave none."""
    if map_location is None:
        map_location = 'cpu'
    return original_torch_load(f, map_location=map_location, **kwargs)


torch.load = patched_torch_load

# Get model descriptions (module-level so other modules can reuse them)
MODEL_DESCRIPTIONS = ModelFactory.get_model_descriptions()

# Models dictionary for UI display: factory key -> human-readable name
MODELS = {
    "ResembleAI/chatterbox": "Chatterbox",
    "KittenML/KittenTTS": "KittenTTS",
    "piper-tts": "Piper (no voice cloning)",
    "SYSTRAN/faster-whisper": "Faster Whisper",
    "hexgrad/kokoro": "Kokoro-82M",
    "nari-labs/Dia-1.6B": "Dia TTS",
}

# Initialize model instances
tts_models = ModelFactory.get_tts_models()
stt_models = ModelFactory.get_stt_models()

# Initialize the models that need immediate initialization
for model_name in ["ResembleAI/chatterbox", "KittenML/KittenTTS"]:
    if model_name in tts_models:
        tts_models[model_name].initialize()

# Initialize the STT model
whisper_model = stt_models.get("SYSTRAN/faster-whisper")
if whisper_model:
    whisper_model.initialize()


def get_kokoro_voices(language_code):
    """
    Get available voices for a specific Kokoro language code.

    Based on: https://huggingface.co/hexgrad/Kokoro-82M/blob/main/VOICES.md

    Args:
        language_code: Single-letter Kokoro language code ("a", "b", "e", ...).

    Returns:
        List of voice identifiers; defaults to ["af_heart"] for unknown codes.
    """
    voice_map = {
        # American English (a)
        "a": [
            "af_heart", "af_alloy", "af_aoede", "af_bella", "af_jessica",
            "af_kore", "af_nicole", "af_nova", "af_river", "af_sarah",
            "af_sky", "am_adam", "am_echo", "am_eric", "am_fenrir",
            "am_liam", "am_michael", "am_onyx", "am_puck", "am_santa",
        ],
        # British English (b)
        "b": [
            "bf_alice", "bf_emma", "bf_isabella", "bf_lily",
            "bm_daniel", "bm_fable", "bm_george", "bm_lewis",
        ],
        # Spanish (e)
        "e": ["ef_dora", "em_alex", "em_santa"],
        # French (f)
        "f": ["ff_siwis"],
        # Hindi (h)
        "h": ["hf_alpha", "hf_beta", "hm_omega", "hm_psi"],
        # Italian (i)
        "i": ["if_sara", "im_nicola"],
        # Japanese (j)
        "j": ["jf_alpha", "jf_gongitsune", "jf_nezumi", "jf_tebukuro", "jm_kumo"],
        # Brazilian Portuguese (p) — fixed: VOICES.md lists pf_/pm_ voices;
        # the previous pt_heart/pt_sun/... names do not exist in the model.
        "p": ["pf_dora", "pm_alex", "pm_santa"],
        # Mandarin Chinese (z)
        "z": [
            "zf_xiaobei", "zf_xiaoni", "zf_xiaoxiao", "zf_xiaoyi",
            "zm_yunjian", "zm_yunxi", "zm_yunxia", "zm_yunyang",
        ],
    }
    return voice_map.get(language_code, ["af_heart"])  # Default to American English voices


# UI Functions for TTS Models

def tts_chatterbox(text, language, audio_prompt=None):
    """UI function for Chatterbox TTS.

    Returns (audio_path, error_message); exactly one of the two is meaningful.
    """
    model = tts_models.get("ResembleAI/chatterbox")
    if not model:
        return None, "Model not available"
    try:
        audio_path = model.generate_speech(text, language=language, audio_prompt=audio_prompt)
        return audio_path, ""
    except Exception as e:
        return None, f"Error: {str(e)}"


def tts_kittentts(text, audio_prompt=None):
    """UI function for KittenTTS. Returns (audio_path, error_message)."""
    model = tts_models.get("KittenML/KittenTTS")
    if not model:
        return None, "Model not available"
    try:
        audio_path = model.generate_speech(text, audio_prompt=audio_prompt)
        return audio_path, ""
    except Exception as e:
        return None, f"Error: {str(e)}"


def tts_piper(text, language, voice):
    """UI function for Piper TTS. Returns (audio_path, error_message)."""
    model = tts_models.get("piper-tts")
    if not model:
        return None, "Model not available"
    try:
        model.initialize()  # Ensure voices are scanned
        audio_path = model.generate_speech(text, language=language, voice=voice)
        return audio_path, ""
    except Exception as e:
        return None, f"Error: {str(e)}"


def tts_kokoro(text, language_code, voice_name):
    """UI function for Kokoro TTS. Returns (audio_path, error_message)."""
    model = tts_models.get("hexgrad/kokoro")
    if not model:
        return None, "Model not available"
    try:
        # Initialize the model if not already initialized.
        # NOTE(review): this reads the wrapper's private _initialized flag;
        # a public accessor on the model class would be preferable — confirm.
        if not model._initialized:
            model.initialize()
        # Generate speech (voice_name is kept for interface consistency but not used by Kokoro)
        audio_path = model.generate_speech(text, lang_code=language_code, voice_name=voice_name)
        # Check if audio file was actually created
        if audio_path and os.path.exists(audio_path):
            return audio_path, ""
        return None, "Error: Audio file was not generated"
    except Exception as e:
        return None, f"Error: {str(e)}"


def tts_dia(text, audio_prompt=None):
    """UI function for Dia TTS. Returns (audio_path, error_message)."""
    model = tts_models.get("nari-labs/Dia-1.6B")
    if not model:
        return None, "Model not available"
    try:
        model.initialize()  # Ensure model is loaded
        audio_path = model.generate_speech(text, audio_prompt=audio_prompt)
        return audio_path, ""
    except Exception as e:
        return None, f"Error: {str(e)}"


# UI Function for STT Model

def stt_whisper(audio_path, language=None):
    """UI function for Faster Whisper STT.

    Returns the transcription text, or an error string on failure.
    """
    model = stt_models.get("SYSTRAN/faster-whisper")
    if not model:
        return "Model not available"
    try:
        transcription = model.transcribe(audio_path, language=language)
        return transcription
    except Exception as e:
        return f"Error: {str(e)}"


# Gradio UI Components

def create_tts_tab():
    """Create the TTS tab for the Gradio interface."""
    with gr.Tab("Text-to-Speech"):
        gr.Markdown("## Text-to-Speech Models")
        with gr.Tabs():
            # Chatterbox Tab
            with gr.Tab("Chatterbox"):
                with gr.Row():
                    with gr.Column():
                        chatterbox_text = gr.Textbox(
                            label="Text to speak",
                            placeholder="Enter text here...",
                            lines=5
                        )
                        chatterbox_language = gr.Dropdown(
                            choices=["English", "Chinese"],
                            value="English",
                            label="Language"
                        )
                        chatterbox_audio_prompt = gr.Audio(
                            label="Voice reference (optional)",
                            type="filepath"
                        )
                        chatterbox_submit = gr.Button("Generate Speech")
                    with gr.Column():
                        chatterbox_output = gr.Audio(label="Generated Speech")
                        # Fixed: was visible=False, which hid every error message
                        # the callback returned.
                        chatterbox_error = gr.Textbox(label="Error")
                chatterbox_submit.click(
                    tts_chatterbox,
                    inputs=[chatterbox_text, chatterbox_language, chatterbox_audio_prompt],
                    outputs=[chatterbox_output, chatterbox_error]
                )

            # KittenTTS Tab
            with gr.Tab("KittenTTS"):
                with gr.Row():
                    with gr.Column():
                        kittentts_text = gr.Textbox(
                            label="Text to speak",
                            placeholder="Enter text here...",
                            lines=5
                        )
                        kittentts_audio_prompt = gr.Audio(
                            label="Voice reference (optional)",
                            type="filepath"
                        )
                        kittentts_submit = gr.Button("Generate Speech")
                    with gr.Column():
                        kittentts_output = gr.Audio(label="Generated Speech")
                        # Fixed: was visible=False (errors were invisible).
                        kittentts_error = gr.Textbox(label="Error")
                kittentts_submit.click(
                    tts_kittentts,
                    inputs=[kittentts_text, kittentts_audio_prompt],
                    outputs=[kittentts_output, kittentts_error]
                )

            # Piper Tab
            with gr.Tab("Piper"):
                with gr.Row():
                    with gr.Column():
                        piper_text = gr.Textbox(
                            label="Text to speak",
                            placeholder="Enter text here...",
                            lines=5
                        )
                        # Initialize Piper model to get voices
                        piper_model = tts_models.get("piper-tts")
                        if piper_model:
                            piper_model.initialize()
                            languages = piper_model.get_supported_languages()
                        else:
                            languages = ["English"]
                        piper_language = gr.Dropdown(
                            choices=languages,
                            value="English",
                            label="Language"
                        )

                        def update_piper_voices(language):
                            # Refresh the voice dropdown whenever the language changes.
                            if piper_model:
                                voices = piper_model.get_available_voices(language)
                                return gr.update(choices=voices, value=voices[0] if voices else None)
                            return gr.update(choices=[], value=None)

                        piper_voice = gr.Dropdown(
                            label="Voice",
                            choices=[]
                        )
                        piper_language.change(
                            update_piper_voices,
                            inputs=[piper_language],
                            outputs=[piper_voice]
                        )
                        piper_submit = gr.Button("Generate Speech")
                    with gr.Column():
                        piper_output = gr.Audio(label="Generated Speech")
                        # Fixed: was visible=False (errors were invisible).
                        piper_error = gr.Textbox(label="Error")
                piper_submit.click(
                    tts_piper,
                    inputs=[piper_text, piper_language, piper_voice],
                    outputs=[piper_output, piper_error]
                )

            # Kokoro Tab
            with gr.Tab("Kokoro"):
                with gr.Row():
                    with gr.Column():
                        kokoro_text = gr.Textbox(
                            label="Text to speak",
                            placeholder="Enter text here...",
                            lines=5
                        )
                        kokoro_language = gr.Dropdown(
                            choices=[
                                "American English (a)",
                                "British English (b)",
                                "Spanish (e)",
                                "French (f)",
                                "Hindi (h)",
                                "Italian (i)",
                                "Japanese (j)",
                                "Brazilian Portuguese (p)",
                                "Mandarin Chinese (z)"
                            ],
                            value="American English (a)",
                            label="Language"
                        )

                        def get_lang_code(language):
                            # "American English (a)" -> "a"
                            return language.split("(")[-1].split(")")[0].strip()

                        def update_kokoro_voices(language):
                            # Refresh the voice list for the selected language.
                            lang_code = get_lang_code(language)
                            voices = get_kokoro_voices(lang_code)
                            return gr.update(choices=voices, value=voices[0] if voices else None)

                        kokoro_voice = gr.Dropdown(
                            label="Voice",
                            choices=get_kokoro_voices("a"),
                            value="af_heart"
                        )
                        kokoro_language.change(
                            update_kokoro_voices,
                            inputs=[kokoro_language],
                            outputs=[kokoro_voice]
                        )
                        kokoro_submit = gr.Button("Generate Speech")
                    with gr.Column():
                        kokoro_output = gr.Audio(label="Generated Speech")
                        # Fixed: was visible=False (errors were invisible).
                        kokoro_error = gr.Textbox(label="Error")
                kokoro_submit.click(
                    lambda text, lang, voice: tts_kokoro(text, get_lang_code(lang), voice),
                    inputs=[kokoro_text, kokoro_language, kokoro_voice],
                    outputs=[kokoro_output, kokoro_error]
                )

            # Dia Tab
            with gr.Tab("Dia"):
                with gr.Row():
                    with gr.Column():
                        dia_text = gr.Textbox(
                            label="Text to speak",
                            placeholder="Enter text here...",
                            lines=5
                        )
                        dia_audio_prompt = gr.Audio(
                            label="Voice reference (optional)",
                            type="filepath"
                        )
                        dia_submit = gr.Button("Generate Speech")
                    with gr.Column():
                        dia_output = gr.Audio(label="Generated Speech")
                        # Fixed: was visible=False (errors were invisible).
                        dia_error = gr.Textbox(label="Error")
                dia_submit.click(
                    tts_dia,
                    inputs=[dia_text, dia_audio_prompt],
                    outputs=[dia_output, dia_error]
                )


def create_stt_tab():
    """Create the STT tab for the Gradio interface."""
    with gr.Tab("Speech-to-Text"):
        gr.Markdown("## Speech-to-Text Models")
        with gr.Tabs():
            # Faster Whisper Tab
            with gr.Tab("Faster Whisper"):
                with gr.Row():
                    with gr.Column():
                        whisper_audio = gr.Audio(
                            label="Audio to transcribe",
                            type="filepath"
                        )
                        # NOTE(review): faster-whisper natively expects ISO 639-1
                        # codes ("en", "zh", ...), not English names — confirm the
                        # project wrapper translates these before passing through.
                        whisper_language = gr.Dropdown(
                            choices=["Auto-detect", "English", "Chinese", "Spanish", "French", "German", "Japanese"],
                            value="Auto-detect",
                            label="Language (optional)"
                        )
                        whisper_submit = gr.Button("Transcribe")
                    with gr.Column():
                        whisper_output = gr.Textbox(
                            label="Transcription",
                            lines=5
                        )
                whisper_submit.click(
                    lambda audio, lang: stt_whisper(audio, None if lang == "Auto-detect" else lang),
                    inputs=[whisper_audio, whisper_language],
                    outputs=[whisper_output]
                )


# Create the Gradio interface

def create_interface():
    """Create the main Gradio interface."""
    with gr.Blocks(title="TTS & STT Gallery") as demo:
        gr.Markdown("# TTS & STT Model Gallery")
        gr.Markdown("Explore different Text-to-Speech and Speech-to-Text models")
        with gr.Tabs():
            create_tts_tab()
            create_stt_tab()
    return demo


# Launch the app
if __name__ == "__main__":
    demo = create_interface()
    demo.launch()