"""
Kokoro-TTS Local Generator
-------------------------
A high-performance text-to-speech system with both Gradio UI and REST API support.
Provides multiple voice models, audio formats, and cross-platform compatibility.

Key Features:
- Multiple voice models support (26+ voices)
- Real-time generation with progress tracking
- WAV, MP3, and AAC output formats
- REST API for programmatic access
- Network sharing capabilities
- Cross-platform compatibility (Windows, macOS, Linux)
- Configurable caching and model management
"""

import json
import logging
import os
import platform
import shutil
import socket
import sys
import threading
import time
import uuid
from datetime import datetime
from pathlib import Path
from typing import Dict, Generator, List, Optional, Tuple, Union

import gradio as gr
import soundfile as sf
from pydub import AudioSegment
import torch
import numpy as np
from werkzeug.middleware.dispatcher import DispatcherMiddleware
from werkzeug.serving import run_simple

# Flask for API
from flask import Flask, request, jsonify, send_file

# Import Kokoro models
from models import (
    list_available_voices,
    build_model,
    generate_speech
)

# Configure logging: every record goes both to the console and to a log file.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler("kokoro_tts.log")
    ]
)
logger = logging.getLogger("kokoro_tts")

# Global configuration
CONFIG_FILE = "tts_config.json"
DEFAULT_OUTPUT_DIR = "outputs"
SAMPLE_RATE = 24000

# Output formats offered by the UI and accepted by the REST API.
SUPPORTED_FORMATS = ("wav", "mp3", "aac")

# pydub/ffmpeg export arguments per non-WAV format. Raw AAC is written through
# ffmpeg's "adts" muxer; ffmpeg has no output format literally named "aac".
_EXPORT_SETTINGS = {
    "mp3": {"format": "mp3"},
    "aac": {"format": "adts"},
}

# MIME type served for each file extension the API can return.
_AUDIO_MIME_TYPES = {
    "wav": "audio/wav",
    "mp3": "audio/mpeg",
    "aac": "audio/aac",
}

# Model and configuration
device = 'cuda' if torch.cuda.is_available() else 'cpu'
logger.info(f"Using device: {device}")
model = None
config = {
    "output_dir": DEFAULT_OUTPUT_DIR,
    "default_voice": None,
    "default_format": "wav",
    "api_enabled": True,
    "api_port": 5000,
    "ui_port": 7860,
    "share_ui": True
}


def load_config() -> Dict:
    """Load configuration from CONFIG_FILE, or create the file from defaults.

    Keys present in the in-memory defaults but missing from the file are
    added to the loaded mapping so newly introduced options get a value.

    Returns:
        The loaded configuration, or the current in-memory ``config`` when
        the file does not exist or cannot be read/parsed.
    """
    try:
        if not os.path.exists(CONFIG_FILE):
            save_config(config)
            return config

        with open(CONFIG_FILE, 'r', encoding="utf-8") as f:
            loaded_config = json.load(f)

        if not isinstance(loaded_config, dict):
            raise ValueError("configuration file must contain a JSON object")

        # Update with any new config options
        for key, value in config.items():
            loaded_config.setdefault(key, value)
        return loaded_config
    except Exception as e:
        logger.error(f"Error loading config: {e}")
        return config


def save_config(config_data: Dict) -> None:
    """Write ``config_data`` to CONFIG_FILE as indented JSON.

    Failures are logged and swallowed (best effort): a read-only working
    directory must not prevent the application from running.
    """
    try:
        with open(CONFIG_FILE, 'w', encoding="utf-8") as f:
            json.dump(config_data, f, indent=4)
    except Exception as e:
        logger.error(f"Error saving config: {e}")


def initialize_model() -> None:
    """Build the global TTS model on first use; later calls are no-ops.

    Raises:
        Exception: whatever ``build_model`` raises, after logging it.
    """
    global model
    try:
        if model is None:
            logger.info("Initializing Kokoro TTS model...")
            model = build_model(None, device)
            logger.info("Model initialization complete")
    except Exception as e:
        logger.error(f"Error initializing model: {e}")
        raise


def get_available_voices() -> List[str]:
    """Return the available voice names, or an empty list on any error."""
    try:
        # Initialize model to trigger voice downloads
        initialize_model()
        voices = list_available_voices()
        if not voices:
            logger.warning("No voices found after initialization.")
        logger.info(f"Available voices: {voices}")
        return voices
    except Exception as e:
        logger.error(f"Error getting voices: {e}")
        return []


def convert_audio(input_path: str, output_format: str) -> str:
    """Convert a WAV file to ``output_format`` next to the original.

    Args:
        input_path: Path to an existing WAV file.
        output_format: One of ``SUPPORTED_FORMATS``.

    Returns:
        Path of the converted file. ``input_path`` is returned unchanged when
        the requested format is "wav", is unsupported, or the conversion
        fails, so callers must inspect the returned extension.
    """
    if output_format == "wav":
        return input_path

    export_args = _EXPORT_SETTINGS.get(output_format)
    if export_args is None:
        logger.warning(f"Unsupported format: {output_format}, defaulting to wav")
        return input_path

    try:
        output_path = os.path.splitext(input_path)[0] + f".{output_format}"
        audio = AudioSegment.from_wav(input_path)
        audio.export(output_path, bitrate="192k", **export_args)
        logger.info(f"Converted audio to {output_format}: {output_path}")
        return output_path
    except Exception as e:
        logger.error(f"Error converting audio: {e}")
        return input_path


def generate_tts(
    text: str,
    voice_name: str,
    output_format: str = "wav",
    output_path: Optional[str] = None,
    speed: float = 1.0
) -> Optional[str]:
    """
    Generate TTS audio and return the path to the generated file.

    A WAV file is always written first; other formats are produced by
    converting that WAV, which is left on disk alongside the converted file.

    Args:
        text: Text to convert to speech
        voice_name: Name of the voice to use
        output_format: Output audio format (wav, mp3, aac)
        output_path: Optional custom output path
        speed: Speech speed multiplier

    Returns:
        Path to the generated audio file, or None if generation failed
    """
    try:
        # Initialize model if needed
        initialize_model()

        # Create output directory
        os.makedirs(config["output_dir"], exist_ok=True)

        if output_path:
            wav_path = os.path.splitext(output_path)[0] + ".wav"
            # A custom path may point outside the configured output directory.
            parent_dir = os.path.dirname(wav_path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)
        else:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            base_name = f"tts_{timestamp}_{str(uuid.uuid4())[:8]}"
            wav_path = os.path.join(config["output_dir"], f"{base_name}.wav")

        logger.info(f"Generating speech for text: '{text[:50]}...' using voice: {voice_name}")

        # Prepare voice path, falling back to the first available voice when
        # the requested one is unknown.
        voice_path = f"voices/{voice_name}.pt"
        if not os.path.exists(voice_path):
            logger.warning(f"Voice file not found: {voice_path}")
            voices = get_available_voices()
            if not voices:
                raise Exception("No voices available")
            if voice_name not in voices:
                logger.warning(f"Using default voice instead of {voice_name}")
                voice_name = voices[0]
                voice_path = f"voices/{voice_name}.pt"

        # Generate speech; the model yields one tuple per text segment
        # (split on blank lines / newlines by split_pattern).
        generator = model(text, voice=voice_path, speed=speed, split_pattern=r'\n+')

        all_audio = []
        for i, (gs, ps, audio) in enumerate(generator):
            if audio is None:
                continue
            if isinstance(audio, np.ndarray):
                audio = torch.from_numpy(audio).float()
            all_audio.append(audio)
            logger.debug(f"Generated segment {i+1}: {gs[:30]}...")

        if not all_audio:
            raise Exception("No audio generated")

        # Combine audio segments and save. Move to CPU first: .numpy() raises
        # on tensors that live on another device or require grad.
        final_audio = torch.cat(all_audio, dim=0)
        sf.write(wav_path, final_audio.detach().cpu().numpy(), SAMPLE_RATE)
        logger.info(f"Saved WAV file to {wav_path}")

        # Convert to requested format if needed
        if output_format != "wav":
            return convert_audio(wav_path, output_format)
        return wav_path

    except Exception as e:
        # logger.exception records the traceback along with the message.
        logger.exception(f"Error generating speech: {e}")
        return None


# UI INTERFACE
def create_ui_interface():
    """Create and return the Gradio interface.

    Side effect: when the configured default voice is missing or unknown, the
    first available voice becomes the default and the config file is saved.
    """
    # Get available voices
    voices = get_available_voices()
    if not voices:
        logger.error("No voices found! Please check the voices directory.")
        # Don't return None, continue with empty list to allow UI to still load
        voices = []

    # Set default voice
    default_voice = config.get("default_voice")
    if not default_voice or default_voice not in voices:
        default_voice = voices[0] if voices else None
        if default_voice:
            config["default_voice"] = default_voice
            save_config(config)

    # Create interface
    with gr.Blocks(title="CB's TTS Generator") as interface:
        gr.Markdown("# **Welcome to CB's TTS Generator**")
        gr.Markdown("There are multiple voices available for you to choose. This TTS is powered by Kokoro.")

        with gr.Row():
            with gr.Column(scale=1):
                # Group voice selection and text input without using Box
                voice = gr.Dropdown(
                    choices=voices,
                    value=default_voice,
                    label="Voice"
                )
                text = gr.Textbox(
                    lines=8,
                    placeholder="Enter text to convert to speech...",
                    label="Text Input"
                )
                format_choice = gr.Radio(
                    choices=list(SUPPORTED_FORMATS),
                    value=config.get("default_format", "wav"),
                    label="Output Format"
                )
                speed = gr.Slider(
                    minimum=0.5,
                    maximum=2.0,
                    value=1.0,
                    step=0.1,
                    label="Speech Speed"
                )
                generate_btn = gr.Button("Generate Speech", variant="primary")

            with gr.Column(scale=1):
                output = gr.Audio(label="Generated Audio")
                status = gr.Textbox(label="Status", interactive=False)

        def generate_wrapper(voice_name, text_input, format_choice, speed_value):
            """Run generation and return ``(audio_path_or_None, status_message)``."""
            if not text_input or not text_input.strip():
                return None, "Error: Please enter some text to convert."

            try:
                output_path = generate_tts(
                    text=text_input,
                    voice_name=voice_name,
                    output_format=format_choice,
                    speed=speed_value
                )
                if output_path:
                    return output_path, f"Success! Generated audio with voice: {voice_name}"
                return None, "Error: Failed to generate audio. Check logs for details."
            except Exception as e:
                logger.error(f"UI generation error: {e}")
                return None, f"Error: {str(e)}"

        def examples_wrapper(text_input, voice_name, format_choice, speed_value):
            """Adapt the examples' (text, voice, ...) column order to generate_wrapper."""
            return generate_wrapper(voice_name, text_input, format_choice, speed_value)

        generate_btn.click(
            fn=generate_wrapper,
            inputs=[voice, text, format_choice, speed],
            outputs=[output, status]
        )

        # Add movie quote examples if we have voices
        if voices:
            gr.Examples(
                [
                    ["May the Force be with you.", default_voice, "wav", 1.0],
                    ["Here's looking at you, kid.", default_voice, "mp3", 1.0],
                    ["I'll be back.", default_voice, "wav", 1.0],
                    ["Houston, we have a problem.", default_voice, "mp3", 1.0]
                ],
                # The example rows list text before voice, so the function
                # must accept its arguments in that same order.
                fn=examples_wrapper,
                inputs=[text, voice, format_choice, speed],
                outputs=[output, status]
            )

    return interface


# API SERVER
def create_api_server() -> Flask:
    """Create and configure the Flask API server.

    Routes:
        GET  /api/voices  - available voices and the default voice
        POST /api/tts     - synthesize speech; responds with the audio file
        GET  /api/health  - liveness and model status
        GET/PUT /api/config - read or partially update the configuration
    """
    app = Flask("KokoroTTS-API")

    @app.route('/api/voices', methods=['GET'])
    def api_voices():
        """Get available voices."""
        try:
            voices = get_available_voices()
            return jsonify({"voices": voices, "default": config.get("default_voice")})
        except Exception as e:
            logger.error(f"API error in voices: {e}")
            return jsonify({"error": str(e)}), 500

    @app.route('/api/tts', methods=['POST'])
    def api_tts():
        """Generate speech from text.

        Expects a JSON object with ``text`` (required) and optional ``voice``,
        ``format`` and ``speed``. Malformed input yields 400; generation
        failures yield 500.
        """
        try:
            # silent=True: a missing or non-JSON body becomes None (-> 400)
            # instead of an HTTP exception that would be reported as 500.
            data = request.get_json(silent=True)
            if not isinstance(data, dict) or 'text' not in data:
                return jsonify({"error": "Missing 'text' field"}), 400

            text = data['text']
            if not isinstance(text, str) or not text.strip():
                return jsonify({"error": "'text' must be a non-empty string"}), 400

            voice = data.get('voice') or config.get("default_voice")

            # The format ends up in a filesystem path, so only accept known
            # values; anything else could be used for path traversal.
            output_format = data.get('format') or config.get("default_format", "wav")
            if not isinstance(output_format, str) or output_format.lower() not in SUPPORTED_FORMATS:
                return jsonify({
                    "error": f"Unsupported format. Supported formats: {', '.join(SUPPORTED_FORMATS)}"
                }), 400
            output_format = output_format.lower()

            try:
                speed = float(data.get('speed', 1.0))
            except (TypeError, ValueError):
                return jsonify({"error": "'speed' must be a number"}), 400
            # The chained comparison also rejects NaN and infinity.
            if not 0 < speed < float("inf"):
                return jsonify({"error": "'speed' must be a positive, finite number"}), 400

            # Create a dedicated output filename for this request
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            request_id = str(uuid.uuid4())[:8]
            filename = f"api_tts_{timestamp}_{request_id}.{output_format}"
            output_path = os.path.join(config["output_dir"], filename)

            # Generate audio
            generated_path = generate_tts(
                text=text,
                voice_name=voice,
                output_format=output_format,
                output_path=output_path,
                speed=speed
            )

            if not generated_path or not os.path.exists(generated_path):
                logger.error(f"Generated path doesn't exist: {generated_path}")
                return jsonify({"error": "Failed to generate audio file"}), 500

            # Verify file size
            file_size = os.path.getsize(generated_path)
            if file_size < 100:  # Very small file likely indicates an error
                logger.error(f"Generated file is too small ({file_size} bytes)")
                return jsonify({"error": "Generated audio file appears to be empty or corrupted"}), 500

            logger.info(f"Sending audio file: {generated_path} ({file_size} bytes)")

            # Describe the file that was actually produced: when conversion
            # fails, generate_tts falls back to the intermediate WAV file.
            actual_ext = os.path.splitext(generated_path)[1].lstrip(".").lower() or output_format

            # Return audio file. An absolute path avoids depending on how
            # Flask resolves relative paths.
            return send_file(
                os.path.abspath(generated_path),
                as_attachment=True,
                download_name=f"tts_output.{actual_ext}",
                mimetype=_AUDIO_MIME_TYPES.get(actual_ext, "application/octet-stream")
            )

        except Exception as e:
            logger.exception(f"API error in TTS: {e}")
            return jsonify({"error": str(e)}), 500

    # Add a health check endpoint
    @app.route('/api/health', methods=['GET'])
    def api_health():
        """Health check endpoint."""
        return jsonify({
            "status": "ok",
            "model_loaded": model is not None,
            "voices_count": len(get_available_voices())
        })

    @app.route('/api/config', methods=['GET', 'PUT'])
    def api_config():
        """Get or update configuration.

        PUT accepts a JSON object; only ``output_dir``, ``default_voice`` and
        ``default_format`` are applied, other keys are ignored.
        """
        if request.method == 'GET':
            return jsonify(config)

        try:
            data = request.get_json(silent=True)
            if not isinstance(data, dict):
                return jsonify({"error": "Request body must be a JSON object"}), 400

            if 'default_format' in data and data['default_format'] not in SUPPORTED_FORMATS:
                return jsonify({
                    "error": f"Unsupported format. Supported formats: {', '.join(SUPPORTED_FORMATS)}"
                }), 400

            # Only update specific fields.
            # NOTE(review): 'output_dir' lets any API client choose where
            # files are written on this host; restrict it if the API is
            # reachable by untrusted clients.
            for key in ['output_dir', 'default_voice', 'default_format']:
                if key in data:
                    config[key] = data[key]

            save_config(config)
            return jsonify({"status": "success", "config": config})
        except Exception as e:
            logger.error(f"API error updating config: {e}")
            return jsonify({"error": str(e)}), 500

    return app


# SERVER LAUNCH FUNCTIONS
def launch_api(host="0.0.0.0", port=None):
    """Launch the API server in a separate daemon thread.

    Args:
        host: Interface to bind to.
        port: Port to listen on; defaults to the configured ``api_port``.

    Returns:
        The started thread, or None when the API is disabled in the config.
    """
    if not config.get("api_enabled", True):
        logger.info("API server disabled in configuration")
        return

    api_port = port or config.get("api_port", 5000)
    logger.info(f"Launching API server on port {api_port}")

    app = create_api_server()

    def run_api_server():
        try:
            # Use Werkzeug development server for simplicity
            run_simple(host, api_port, app, threaded=True, use_reloader=False)
        except Exception as e:
            logger.exception(f"Error in API server: {e}")

    # Start in a daemon thread
    api_thread = threading.Thread(target=run_api_server, daemon=True)
    api_thread.start()

    # Give the server a moment to start
    time.sleep(1)
    logger.info(f"API server running at http://{host}:{api_port}")

    return api_thread


def launch_ui(server_name="0.0.0.0", server_port=None, share=None):
    """Build and launch the Gradio UI without blocking the calling thread.

    Args:
        server_name: Interface to bind to.
        server_port: Port to listen on; defaults to the configured ``ui_port``.
        share: Whether to request a public share link; defaults to the
            configured ``share_ui``.

    Returns:
        True once ``launch`` has returned.
    """
    port = server_port or config.get("ui_port", 7860)
    share_ui = share if share is not None else config.get("share_ui", True)

    logger.info(f"Launching UI on port {port} (share={share_ui})")

    interface = create_ui_interface()

    # Disable queue if running on Hugging Face Spaces
    if os.environ.get("HF_SPACE") is None:
        interface.queue()  # Only enable queue for local deployments

    interface.launch(
        server_name=server_name,
        server_port=port,
        share=share_ui,
        prevent_thread_lock=True
    )

    logger.info(f"UI server running at http://{server_name}:{port}")
    return True


# MAIN APPLICATION
def main():
    """Main application entry point.

    Loads the configuration, initializes the model (exiting with status 1 on
    failure) and then serves either a combined UI+API app on one port (when
    the HF_SPACE or SINGLE_PORT=1 environment variable is set) or the API and
    UI on their own ports.
    """
    print("\n" + "="*50)
    print("Starting Kokoro-TTS")
    print("="*50)

    # Load configuration and create output directory
    global config
    config = load_config()
    os.makedirs(config["output_dir"], exist_ok=True)

    # Initialize model
    try:
        initialize_model()
    except Exception as e:
        logger.error(f"Failed to initialize model: {e}")
        print(f"ERROR: Failed to initialize model: {e}")
        sys.exit(1)

    # Get the network IP address for WSL access. The lookup is informational
    # only, so a host name that does not resolve must not abort startup.
    try:
        network_ip = socket.gethostbyname(socket.gethostname())
    except OSError as e:
        logger.warning(f"Could not determine network IP: {e}")
        network_ip = "127.0.0.1"

    # Check if we are running on Hugging Face Spaces or if you want to combine the servers locally
    if os.environ.get("HF_SPACE") is not None or os.environ.get("SINGLE_PORT") == "1":
        # Create the API Flask app and the Gradio interface
        api_app = create_api_server()
        interface = create_ui_interface()

        # Combine the Gradio app and Flask API under the same port using DispatcherMiddleware
        # All routes under '/api' go to the Flask API, all other routes go to Gradio.
        # NOTE(review): Gradio's `.app` appears to be an ASGI (FastAPI) app that
        # is only set by launch(), while DispatcherMiddleware/run_simple are
        # WSGI-only - confirm this branch works with the installed Gradio.
        # NOTE(review): DispatcherMiddleware strips the mount prefix, so the
        # Flask routes declared as '/api/...' would be served at '/api/api/...'.
        combined_app = DispatcherMiddleware(interface.app, {
            '/api': api_app
        })

        # Use the UI port (or any single port you want)
        port = config.get("ui_port", 7860)
        print(f"Combined UI and API running on port: {port}")
        print(f"Localhost: http://localhost:{port}")
        print(f"Network: http://{network_ip}:{port}")

        # Run the combined app on a single port
        run_simple("0.0.0.0", port, combined_app, use_reloader=False, threaded=True)
    else:
        # Local deployment: run API and UI separately
        if config.get("api_enabled", True):
            launch_api()  # launches API on its own thread (port 5000 by default)

        ui_thread = threading.Thread(target=launch_ui, daemon=True)
        ui_thread.start()

        print(f"UI (localhost): http://localhost:{config.get('ui_port', 7860)}")
        print(f"UI (network): http://{network_ip}:{config.get('ui_port', 7860)}")
        if config.get("api_enabled", True):
            print(f"API (localhost): http://localhost:{config.get('api_port', 5000)}")
            print(f"API (network): http://{network_ip}:{config.get('api_port', 5000)}")

        # Keep the main thread alive; both servers run in daemon threads and
        # stop when it exits.
        try:
            while True:
                time.sleep(1)
        except KeyboardInterrupt:
            print("\nShutting down servers...")
            print("Press Ctrl+C again to force quit")


if __name__ == "__main__":
    main()