# tts / app.py
# IAMCB's picture
# update the ui and api_ v3.1 fucking bugs
# cf0add7
"""
Kokoro-TTS Local Generator
-------------------------
A high-performance text-to-speech system with both Gradio UI and REST API support.
Provides multiple voice models, audio formats, and cross-platform compatibility.
Key Features:
- Multiple voice models support (26+ voices)
- Real-time generation with progress tracking
- WAV, MP3, and AAC output formats
- REST API for programmatic access
- Network sharing capabilities
- Cross-platform compatibility (Windows, macOS, Linux)
- Configurable caching and model management
"""
import gradio as gr
import json
import platform
import shutil
from pathlib import Path
import soundfile as sf
from pydub import AudioSegment
import torch
import numpy as np
import time
import uuid
from typing import Dict, List, Optional, Union, Tuple, Generator
import threading
import os
import sys
import time
import socket
import threading
import logging
from datetime import datetime
from werkzeug.middleware.dispatcher import DispatcherMiddleware
from werkzeug.serving import run_simple
# Import Kokoro models
from models import (
list_available_voices, build_model,
generate_speech
)
# Flask for API
from flask import Flask, request, jsonify, send_file
# Configure logging: every record goes to the console and to kokoro_tts.log.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),
        # Explicit UTF-8: log lines include user-supplied text, and the
        # platform default encoding (e.g. cp1252 on Windows) cannot encode
        # arbitrary Unicode, which would make those records fail to log.
        logging.FileHandler("kokoro_tts.log", encoding="utf-8")
    ]
)
logger = logging.getLogger("kokoro_tts")
# Global configuration
CONFIG_FILE = "tts_config.json"  # JSON settings file, relative to the working directory
DEFAULT_OUTPUT_DIR = "outputs"
SAMPLE_RATE = 24000  # sample rate (Hz) used when writing generated WAV files
# Model and configuration
device = 'cuda' if torch.cuda.is_available() else 'cpu'
logger.info(f"Using device: {device}")
model = None  # built lazily by initialize_model()
# Default settings; load_config() back-fills any of these keys that are
# missing from CONFIG_FILE.
config = {
    "output_dir": DEFAULT_OUTPUT_DIR,
    "default_voice": None,
    "default_format": "wav",
    "api_enabled": True,
    "api_port": 5000,
    "ui_port": 7860,
    "share_ui": True
}
def load_config() -> Dict:
    """Read settings from CONFIG_FILE, filling in any missing default keys.

    When the file does not exist yet, the in-memory defaults are written to
    disk and returned as-is. Any error while reading is logged and the
    in-memory defaults are returned instead.

    Returns:
        The settings dict to use.
    """
    try:
        if not os.path.exists(CONFIG_FILE):
            # First run: persist the defaults so the user has a file to edit.
            save_config(config)
            return config
        with open(CONFIG_FILE, 'r') as handle:
            stored = json.load(handle)
        # Options added after the file was written receive their defaults;
        # values already present in the file are kept.
        for option, fallback in config.items():
            stored.setdefault(option, fallback)
        return stored
    except Exception as err:
        logger.error(f"Error loading config: {err}")
        return config
def save_config(config_data: Dict) -> None:
    """Save configuration to CONFIG_FILE as indented JSON.

    Failures are logged rather than raised.

    Args:
        config_data: Settings dict to persist; must be JSON-serializable.
    """
    try:
        # Serialize before opening the file: opening with 'w' truncates it,
        # so a serialization error would otherwise wipe the existing config.
        payload = json.dumps(config_data, indent=4)
        with open(CONFIG_FILE, 'w') as f:
            f.write(payload)
    except Exception as e:
        logger.error(f"Error saving config: {e}")
def initialize_model() -> None:
    """Build the global TTS model on first use; later calls do nothing.

    Raises:
        Exception: re-raises whatever ``build_model`` raised, after logging it.
    """
    global model
    if model is not None:
        return
    try:
        logger.info("Initializing Kokoro TTS model...")
        model = build_model(None, device)
        logger.info("Model initialization complete")
    except Exception as exc:
        logger.error(f"Error initializing model: {exc}")
        raise
def get_available_voices() -> List[str]:
    """Return the names of the available voice models.

    The model is initialized first (the original comment notes this triggers
    voice downloads). Any error is logged and an empty list is returned.
    """
    try:
        # Model initialization is expected to make the voice files available.
        initialize_model()
        found = list_available_voices()
        if not found:
            logger.warning("No voices found after initialization.")
        logger.info(f"Available voices: {found}")
        return found
    except Exception as exc:
        logger.error(f"Error getting voices: {exc}")
        return []
def convert_audio(input_path: str, output_format: str) -> str:
    """Convert a WAV file to another format, writing it next to the original.

    Args:
        input_path: Path to the source WAV file.
        output_format: Target format name: "wav", "mp3" or "aac"
            (case-insensitive).

    Returns:
        Path to the converted file (same base name, new extension), or
        ``input_path`` unchanged when the target is WAV, the format is
        unsupported, or the conversion fails.
    """
    # pydub export keyword arguments per target format. ffmpeg has no muxer
    # named "aac"; raw AAC streams are written with the "adts" muxer, so
    # passing format="aac" would always fail and fall back to WAV.
    export_args = {
        "mp3": {"format": "mp3", "bitrate": "192k"},
        "aac": {"format": "adts", "bitrate": "192k"},
    }
    fmt = str(output_format or "wav").lower()
    if fmt == "wav":
        return input_path
    if fmt not in export_args:
        logger.warning(f"Unsupported format: {output_format}, defaulting to wav")
        return input_path
    try:
        output_path = os.path.splitext(input_path)[0] + f".{fmt}"
        audio = AudioSegment.from_wav(input_path)
        audio.export(output_path, **export_args[fmt])
        logger.info(f"Converted audio to {fmt}: {output_path}")
        return output_path
    except Exception as e:
        logger.error(f"Error converting audio: {e}")
        return input_path
def _resolve_voice_path(voice_name: Optional[str]) -> str:
    """Map a voice name to its ``voices/<name>.pt`` file, falling back if missing.

    If the file for ``voice_name`` does not exist, the available voices are
    listed. When ``voice_name`` is not among them, the first available voice
    is used instead.

    Raises:
        RuntimeError: if no voices are available at all.
    """
    voice_path = f"voices/{voice_name}.pt"
    if os.path.exists(voice_path):
        return voice_path
    logger.warning(f"Voice file not found: {voice_path}")
    voices = get_available_voices()
    if not voices:
        raise RuntimeError("No voices available")
    if voice_name not in voices:
        logger.warning(f"Using default voice instead of {voice_name}")
        voice_path = f"voices/{voices[0]}.pt"
    # NOTE(review): if voice_name is listed but its file is missing, the
    # missing path is still handed to the model (unchanged behaviour).
    return voice_path


def _synthesize(text: str, voice_path: str, speed: float) -> torch.Tensor:
    """Run the model over ``text`` and return every audio segment concatenated.

    The model is asked to split the text on runs of newlines; segments whose
    audio is None are skipped.

    Raises:
        RuntimeError: if the model yields no audio at all.
    """
    segments = []
    generator = model(text, voice=voice_path, speed=speed, split_pattern=r'\n+')
    for i, (gs, _ps, audio) in enumerate(generator):
        if audio is None:
            continue
        if isinstance(audio, np.ndarray):
            audio = torch.from_numpy(audio)
        # Detach and move to CPU: a GPU-resident or grad-tracking tensor
        # cannot be converted with .numpy() for soundfile, and torch.cat
        # requires all segments to be on the same device.
        segments.append(audio.detach().float().cpu())
        logger.debug(f"Generated segment {i+1}: {gs[:30]}...")
    if not segments:
        raise RuntimeError("No audio generated")
    return torch.cat(segments, dim=0)


def generate_tts(
    text: str,
    voice_name: str,
    output_format: str = "wav",
    output_path: Optional[str] = None,
    speed: float = 1.0
) -> Optional[str]:
    """
    Generate TTS audio and return the path to the generated file.

    Speech is always written as a WAV file first (at SAMPLE_RATE). It is then
    converted with ``convert_audio`` when another format is requested.

    Args:
        text: Text to convert to speech
        voice_name: Name of the voice to use (falls back to the first
            available voice if unknown)
        output_format: Output audio format (wav, mp3, aac)
        output_path: Optional custom output path; its extension is replaced
            by the one matching the produced format
        speed: Speech speed multiplier
    Returns:
        Path to the generated audio file, or None if generation failed
        (the error and traceback are logged).
    """
    try:
        initialize_model()
        os.makedirs(config["output_dir"], exist_ok=True)

        if output_path:
            wav_path = os.path.splitext(output_path)[0] + ".wav"
            # A caller-supplied path may point outside output_dir.
            wav_dir = os.path.dirname(wav_path)
            if wav_dir:
                os.makedirs(wav_dir, exist_ok=True)
        else:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            base_name = f"tts_{timestamp}_{str(uuid.uuid4())[:8]}"
            wav_path = os.path.join(config["output_dir"], f"{base_name}.wav")

        logger.info(f"Generating speech for text: '{text[:50]}...' using voice: {voice_name}")
        voice_path = _resolve_voice_path(voice_name)
        final_audio = _synthesize(text, voice_path, speed)

        sf.write(wav_path, final_audio.numpy(), SAMPLE_RATE)
        logger.info(f"Saved WAV file to {wav_path}")

        if output_format != "wav":
            return convert_audio(wav_path, output_format)
        return wav_path
    except Exception as e:
        # logger.exception records the message together with the traceback.
        logger.exception(f"Error generating speech: {e}")
        return None
# UI INTERFACE
def create_ui_interface():
    """Create and return the Gradio interface.

    The UI has a voice dropdown, a text box, an output-format radio
    (wav/mp3/aac) and a speed slider (0.5-2.0), all wired to ``generate_tts``.
    If no voices are found, the UI still loads with an empty dropdown.
    Side effect: the resolved default voice is stored in
    ``config["default_voice"]`` and persisted with ``save_config``.

    Returns:
        The constructed ``gr.Blocks`` instance (not yet launched).
    """
    voices = get_available_voices()
    if not voices:
        logger.error("No voices found! Please check the voices directory.")
        # Don't return None; continue with an empty list so the UI still loads.
        voices = []

    # Fall back to the first available voice when the configured default is
    # missing or no longer available.
    default_voice = config.get("default_voice")
    if not default_voice or default_voice not in voices:
        default_voice = voices[0] if voices else None
    if default_voice:
        config["default_voice"] = default_voice
        save_config(config)

    with gr.Blocks(title="CB's TTS Generator") as interface:
        gr.Markdown("# **Welcome to CB's TTS Generator**")
        gr.Markdown("There are multiple voices available for you to choose. This TTS is powered by Kokoro.")
        with gr.Row():
            with gr.Column(scale=1):
                voice = gr.Dropdown(
                    choices=voices,
                    value=default_voice,
                    label="Voice"
                )
                text = gr.Textbox(
                    lines=8,
                    placeholder="Enter text to convert to speech...",
                    label="Text Input"
                )
                format_choice = gr.Radio(
                    choices=["wav", "mp3", "aac"],
                    value=config.get("default_format", "wav"),
                    label="Output Format"
                )
                speed = gr.Slider(
                    minimum=0.5,
                    maximum=2.0,
                    value=1.0,
                    step=0.1,
                    label="Speech Speed"
                )
                generate_btn = gr.Button("Generate Speech", variant="primary")
            with gr.Column(scale=1):
                output = gr.Audio(label="Generated Audio")
                status = gr.Textbox(label="Status", interactive=False)

        def generate_wrapper(voice_name, text_input, output_format, speed_value):
            """Run generate_tts for the UI and return (audio_path, status_message)."""
            # Guard against None as well as blank text.
            if not text_input or not text_input.strip():
                return None, "Error: Please enter some text to convert."
            try:
                output_path = generate_tts(
                    text=text_input,
                    voice_name=voice_name,
                    output_format=output_format,
                    speed=speed_value
                )
                if output_path:
                    return output_path, f"Success! Generated audio with voice: {voice_name}"
                return None, "Error: Failed to generate audio. Check logs for details."
            except Exception as e:
                logger.error(f"UI generation error: {e}")
                return None, f"Error: {str(e)}"

        # Inputs are listed in generate_wrapper's positional parameter order.
        generate_btn.click(
            fn=generate_wrapper,
            inputs=[voice, text, format_choice, speed],
            outputs=[output, status]
        )

        # Movie-quote examples, shown only when at least one voice exists.
        # Each row follows the order of ``inputs``, and ``inputs`` follows
        # generate_wrapper's parameter order. When Gradio runs or caches
        # examples it calls ``fn`` positionally with these values, so a
        # different order would swap the voice and the text.
        if voices:
            gr.Examples(
                examples=[
                    [default_voice, "May the Force be with you.", "wav", 1.0],
                    [default_voice, "Here's looking at you, kid.", "mp3", 1.0],
                    [default_voice, "I'll be back.", "wav", 1.0],
                    [default_voice, "Houston, we have a problem.", "mp3", 1.0]
                ],
                inputs=[voice, text, format_choice, speed],
                outputs=[output, status],
                fn=generate_wrapper
            )
    return interface
# API SERVER
def create_api_server() -> Flask:
    """Create and configure the Flask API server.

    Registered routes:
        GET      /api/voices  -> {"voices": [...], "default": <default voice>}
        POST     /api/tts     -> generated audio as a file attachment. JSON body:
                                 {"text": str, "voice"?: str,
                                  "format"?: "wav"|"mp3"|"aac", "speed"?: number}
        GET      /api/health  -> {"status": "ok", "model_loaded": bool,
                                  "voices_count": int}
        GET/PUT  /api/config  -> current config; PUT updates output_dir,
                                 default_voice and default_format

    Client errors (bad or missing JSON, invalid fields) return 400; server
    errors return 500. Both use the body {"error": message}.

    Returns:
        A configured Flask app (not started).
    """
    app = Flask("KokoroTTS-API")

    # Output formats the API accepts, mapped to the MIME type sent back.
    # "audio/mpeg" is the registered MIME type for MP3 ("audio/mp3" is not).
    mime_types = {"wav": "audio/wav", "mp3": "audio/mpeg", "aac": "audio/aac"}
    supported = ", ".join(mime_types)

    def _bad_request(message):
        """Build a 400 JSON error response."""
        return jsonify({"error": message}), 400

    @app.route('/api/voices', methods=['GET'])
    def api_voices():
        """Get available voices."""
        try:
            voices = get_available_voices()
            return jsonify({"voices": voices, "default": config.get("default_voice")})
        except Exception as e:
            logger.error(f"API error in voices: {e}")
            return jsonify({"error": str(e)}), 500

    @app.route('/api/tts', methods=['POST'])
    def api_tts():
        """Generate speech from text and return it as a file download."""
        # silent=True yields None for a missing or non-JSON body instead of
        # raising, so malformed requests get a 400 rather than a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or 'text' not in data:
            return _bad_request("Missing 'text' field")
        text = data['text']
        if not isinstance(text, str) or not text.strip():
            return _bad_request("'text' must be a non-empty string")
        voice = data.get('voice', config.get("default_voice"))
        # The format becomes part of the output filename, so only
        # whitelisted values may reach the filesystem path.
        output_format = str(data.get('format', config.get("default_format", "wav"))).lower()
        if output_format not in mime_types:
            return _bad_request(f"Unsupported format '{output_format}'; expected one of: {supported}")
        try:
            speed = float(data.get('speed', 1.0))
        except (TypeError, ValueError):
            return _bad_request("'speed' must be a number")
        # Rejects zero, negative values, NaN and infinity.
        if not 0 < speed < float("inf"):
            return _bad_request("'speed' must be a positive finite number")

        try:
            # Create a dedicated output filename for this request
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            request_id = str(uuid.uuid4())[:8]
            filename = f"api_tts_{timestamp}_{request_id}.{output_format}"
            output_path = os.path.join(config["output_dir"], filename)

            generated_path = generate_tts(
                text=text,
                voice_name=voice,
                output_format=output_format,
                output_path=output_path,
                speed=speed
            )
            if not generated_path or not os.path.exists(generated_path):
                logger.error(f"Generated path doesn't exist: {generated_path}")
                return jsonify({"error": "Failed to generate audio file"}), 500

            # A very small file most likely indicates a failed generation.
            file_size = os.path.getsize(generated_path)
            if file_size < 100:
                logger.error(f"Generated file is too small ({file_size} bytes)")
                return jsonify({"error": "Generated audio file appears to be empty or corrupted"}), 500

            # convert_audio falls back to the WAV file when conversion fails,
            # so label the response with the format actually produced.
            actual_format = os.path.splitext(generated_path)[1].lstrip('.').lower()
            logger.info(f"Sending audio file: {generated_path} ({file_size} bytes)")
            return send_file(
                generated_path,
                as_attachment=True,
                download_name=f"tts_output.{actual_format}",
                mimetype=mime_types.get(actual_format, "application/octet-stream")
            )
        except Exception as e:
            logger.exception(f"API error in TTS: {e}")
            return jsonify({"error": str(e)}), 500

    @app.route('/api/health', methods=['GET'])
    def api_health():
        """Health check endpoint."""
        return jsonify({
            "status": "ok",
            "model_loaded": model is not None,
            "voices_count": len(get_available_voices())
        })

    @app.route('/api/config', methods=['GET', 'PUT'])
    def api_config():
        """Get or update configuration."""
        if request.method == 'GET':
            return jsonify(config)
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return _bad_request("Request body must be a JSON object")
        if 'default_format' in data:
            requested = str(data['default_format']).lower()
            if requested not in mime_types:
                return _bad_request(f"Unsupported format '{requested}'; expected one of: {supported}")
            data = dict(data, default_format=requested)
        try:
            # NOTE(review): this endpoint is unauthenticated and lets any
            # client change output_dir, i.e. where generated files are
            # written. Consider restricting it if the API is reachable
            # beyond localhost.
            for key in ('output_dir', 'default_voice', 'default_format'):
                if key in data:
                    config[key] = data[key]
            save_config(config)
            return jsonify({"status": "success", "config": config})
        except Exception as e:
            logger.error(f"API error updating config: {e}")
            return jsonify({"error": str(e)}), 500

    return app
# SERVER LAUNCH FUNCTIONS
def launch_api(host="0.0.0.0", port=None):
    """Start the Flask API server in a background daemon thread.

    Args:
        host: Interface to bind to.
        port: Port to bind to; when falsy, config["api_port"] (default 5000)
            is used.

    Returns:
        The started thread, or None when the API is disabled in config.
    """
    if not config.get("api_enabled", True):
        logger.info("API server disabled in configuration")
        return None

    bind_port = port or config.get("api_port", 5000)
    logger.info(f"Launching API server on port {bind_port}")
    flask_app = create_api_server()

    def _serve():
        # Runs inside the worker thread. Nothing joins it, so failures are
        # logged here instead of propagating.
        try:
            run_simple(host, bind_port, flask_app, threaded=True, use_reloader=False)
        except Exception as exc:
            import traceback
            logger.error(f"Error in API server: {exc}")
            logger.error(traceback.format_exc())

    server_thread = threading.Thread(target=_serve, daemon=True)
    server_thread.start()
    # Short pause so the server has a chance to bind before we report it.
    time.sleep(1)
    logger.info(f"API server running at http://{host}:{bind_port}")
    return server_thread
def launch_ui(server_name="0.0.0.0", server_port=None, share=None):
    """Build the Gradio UI and launch it without blocking the caller.

    Args:
        server_name: Interface to bind to.
        server_port: Port to bind to; when falsy, config["ui_port"]
            (default 7860) is used.
        share: Whether to create a public Gradio share link; None means use
            config["share_ui"] (default True).

    Returns:
        True once ``launch`` has returned.
    """
    port = server_port or config.get("ui_port", 7860)
    share_ui = config.get("share_ui", True) if share is None else share
    logger.info(f"Launching UI on port {port} (share={share_ui})")

    interface = create_ui_interface()
    # The request queue is only enabled when HF_SPACE is not set.
    running_on_hf_space = os.environ.get("HF_SPACE") is not None
    if not running_on_hf_space:
        interface.queue()

    interface.launch(
        server_name=server_name,
        server_port=port,
        share=share_ui,
        prevent_thread_lock=True
    )
    logger.info(f"UI server running at http://{server_name}:{port}")
    return True
# MAIN APPLICATION
def _get_network_ip() -> str:
    """Return this host's address for the startup banner, or loopback if unresolvable."""
    try:
        return socket.gethostbyname(socket.gethostname())
    except OSError:
        # gethostbyname raises socket.gaierror (an OSError subclass) when the
        # hostname cannot be resolved; the address is informational only, so
        # that must not abort startup.
        return "127.0.0.1"


def main():
    """Main application entry point.

    Loads the configuration and initializes the model, exiting with status 1
    if the model fails to load. Then, when HF_SPACE is set or SINGLE_PORT=1,
    it serves UI and API from one port. Otherwise it starts them on separate
    ports in background threads and keeps the main thread alive until Ctrl+C.
    """
    global config
    print("\n" + "="*50)
    print("Starting Kokoro-TTS")
    print("="*50)

    # Load configuration and create output directory
    config = load_config()
    os.makedirs(config["output_dir"], exist_ok=True)

    try:
        initialize_model()
    except Exception as e:
        logger.error(f"Failed to initialize model: {e}")
        print(f"ERROR: Failed to initialize model: {e}")
        sys.exit(1)

    # Network address shown in the banner (e.g. for access from WSL/LAN).
    network_ip = _get_network_ip()

    if os.environ.get("HF_SPACE") is not None or os.environ.get("SINGLE_PORT") == "1":
        # NOTE(review): this branch looks unlikely to work as written.
        # Gradio appears to create ``interface.app`` during ``launch()``, and
        # that app is ASGI (FastAPI), while DispatcherMiddleware/run_simple
        # expect WSGI apps. DispatcherMiddleware also strips the '/api' mount
        # prefix, so Flask routes registered as '/api/...' would be reached at
        # '/api/api/...'. Confirm before relying on SINGLE_PORT=1.
        api_app = create_api_server()
        interface = create_ui_interface()
        # All routes under '/api' go to the Flask API; everything else to Gradio.
        combined_app = DispatcherMiddleware(interface.app, {
            '/api': api_app
        })
        port = config.get("ui_port", 7860)
        print(f"Combined UI and API running on port: {port}")
        print(f"Localhost: http://localhost:{port}")
        print(f"Network: http://{network_ip}:{port}")
        # Blocks until the server stops.
        run_simple("0.0.0.0", port, combined_app, use_reloader=False, threaded=True)
    else:
        # Local deployment: run API and UI separately
        if config.get("api_enabled", True):
            launch_api()  # API on its own thread (port 5000 by default)
        ui_thread = threading.Thread(target=launch_ui, daemon=True)
        ui_thread.start()
        print(f"UI (localhost): http://localhost:{config.get('ui_port', 7860)}")
        print(f"UI (network): http://{network_ip}:{config.get('ui_port', 7860)}")
        if config.get("api_enabled", True):
            print(f"API (localhost): http://localhost:{config.get('api_port', 5000)}")
            print(f"API (network): http://{network_ip}:{config.get('api_port', 5000)}")
        # Server threads are daemons, so the main thread must stay alive.
        try:
            while True:
                time.sleep(1)
        except KeyboardInterrupt:
            print("\nShutting down servers...")
            print("Press Ctrl+C again to force quit")
# Start the application only when this file is executed as a script.
if __name__ == "__main__":
    main()