|
|
""" |
|
|
Kokoro-TTS Local Generator |
|
|
------------------------- |
|
|
A high-performance text-to-speech system with both Gradio UI and REST API support. |
|
|
Provides multiple voice models, audio formats, and cross-platform compatibility. |
|
|
|
|
|
Key Features: |
|
|
- Support for multiple voice models (26+ voices)
|
|
- Real-time generation with progress tracking |
|
|
- WAV, MP3, and AAC output formats |
|
|
- REST API for programmatic access |
|
|
- Network sharing capabilities |
|
|
- Cross-platform compatibility (Windows, macOS, Linux) |
|
|
- Configurable caching and model management |
|
|
""" |
|
|
|
|
|
import gradio as gr |
|
|
import json |
|
|
import platform |
|
|
|
|
|
import shutil |
|
|
from pathlib import Path |
|
|
import soundfile as sf |
|
|
from pydub import AudioSegment |
|
|
import torch |
|
|
import numpy as np |
|
|
import time |
|
|
import uuid |
|
|
from typing import Dict, List, Optional, Union, Tuple, Generator |
|
|
import threading |
|
|
import os |
|
|
import sys |
|
|
import time |
|
|
import socket |
|
|
import threading |
|
|
import logging |
|
|
from datetime import datetime |
|
|
from werkzeug.middleware.dispatcher import DispatcherMiddleware |
|
|
from werkzeug.serving import run_simple |
|
|
|
|
|
from models import ( |
|
|
list_available_voices, build_model, |
|
|
generate_speech |
|
|
) |
|
|
|
|
|
|
|
|
from flask import Flask, request, jsonify, send_file |
|
|
|
|
|
|
|
|
# Log to both the console and a persistent log file.
# The file handler is opened as UTF-8 explicitly: request text (the first
# 50 characters of every TTS request) is written to this log, and the
# platform default encoding (e.g. cp1252 on many Windows installs) cannot
# represent arbitrary Unicode, which would produce logging errors.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler("kokoro_tts.log", encoding="utf-8"),
    ],
)
logger = logging.getLogger("kokoro_tts")

# Name of the JSON configuration file, default directory for generated audio,
# and the sample rate (Hz) used when writing WAV output.
CONFIG_FILE = "tts_config.json"
DEFAULT_OUTPUT_DIR = "outputs"
SAMPLE_RATE = 24000

# Prefer the GPU whenever PyTorch can see one.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
logger.info(f"Using device: {device}")

# Built lazily by initialize_model(); None until then.
model = None

# Built-in defaults. load_config() fills any keys missing from the config
# file with these values; share_ui is passed as ``share`` to Gradio's launch().
config = {
    "output_dir": DEFAULT_OUTPUT_DIR,
    "default_voice": None,
    "default_format": "wav",
    "api_enabled": True,
    "api_port": 5000,
    "ui_port": 7860,
    "share_ui": True,
}
|
|
|
|
|
def load_config() -> Dict:
    """Load the configuration file, filling in missing keys from the defaults.

    If ``CONFIG_FILE`` does not exist, the in-memory ``config`` defaults are
    written to it and returned. If the file cannot be read, is not valid JSON,
    or its top-level value is not a JSON object, the problem is logged and the
    in-memory ``config`` is returned.

    Returns:
        The configuration dict: the file's values, plus any default keys the
        file does not define.
    """
    try:
        if not os.path.exists(CONFIG_FILE):
            save_config(config)
            return config

        # JSON text is UTF-8 by specification, so don't rely on the locale's
        # default encoding (e.g. cp1252 on Windows) for hand-edited files.
        # 'utf-8-sig' also accepts files saved with a UTF-8 BOM.
        with open(CONFIG_FILE, 'r', encoding='utf-8-sig') as f:
            loaded_config = json.load(f)

        if not isinstance(loaded_config, dict):
            logger.error(
                f"Error loading config: expected a JSON object in {CONFIG_FILE}, "
                f"got {type(loaded_config).__name__}"
            )
            return config

        # Values from the file win; defaults only fill in absent keys.
        for key, default_value in config.items():
            if key not in loaded_config:
                loaded_config[key] = default_value
        return loaded_config
    except Exception as e:
        logger.error(f"Error loading config: {e}")
        return config
|
|
|
|
|
def save_config(config_data: Dict) -> None:
    """Write ``config_data`` to ``CONFIG_FILE`` as 4-space-indented JSON.

    The data is serialized before the file is opened. A value that cannot be
    encoded as JSON then leaves the existing config file untouched, instead of
    truncating it and leaving partial JSON behind. Errors are logged, not
    raised; saving is best-effort.
    """
    try:
        payload = json.dumps(config_data, indent=4)
        # Explicit UTF-8 to match how load_config reads the file.
        with open(CONFIG_FILE, 'w', encoding='utf-8') as f:
            f.write(payload)
    except Exception as e:
        logger.error(f"Error saving config: {e}")
|
|
|
|
|
def initialize_model() -> None:
    """Build the shared TTS model the first time it is needed.

    This is a no-op once the module-level ``model`` has been set. Any error
    raised while building the model is logged and then re-raised, so callers
    decide how to react.
    """
    global model
    try:
        if model is not None:
            return
        logger.info("Initializing Kokoro TTS model...")
        model = build_model(None, device)
        logger.info("Model initialization complete")
    except Exception as e:
        logger.error(f"Error initializing model: {e}")
        raise
|
|
|
|
|
def get_available_voices() -> List[str]:
    """Return the names of the voice models available for synthesis.

    The model is initialized first, then ``list_available_voices`` is queried.
    An empty result is logged as a warning. Any error is logged, and an empty
    list is returned instead of propagating.
    """
    try:
        initialize_model()
        found_voices = list_available_voices()
        if not found_voices:
            logger.warning("No voices found after initialization.")
        logger.info(f"Available voices: {found_voices}")
        return found_voices
    except Exception as err:
        logger.error(f"Error getting voices: {err}")
        return []
|
|
|
|
|
def convert_audio(input_path: str, output_format: str) -> str:
    """Convert a WAV file to ``output_format`` and return the resulting path.

    Args:
        input_path: Path to the source WAV file.
        output_format: Target format, case-insensitive: "wav", "mp3" or "aac".

    Returns:
        The path of the converted file, which sits next to ``input_path`` with
        the new extension. ``input_path`` itself is returned when the target is
        WAV, the format is unsupported, or conversion fails. Failures are
        logged, not raised.
    """
    # Public format name -> container name passed to ffmpeg (via pydub's
    # ``format``). ffmpeg has no "aac" muxer; raw .aac files are written with
    # the ADTS muxer, so exporting with format="aac" always failed and silently
    # fell back to WAV.
    ffmpeg_formats = {"mp3": "mp3", "aac": "adts"}

    fmt = str(output_format).lower()
    if fmt == "wav":
        return input_path

    ffmpeg_format = ffmpeg_formats.get(fmt)
    if ffmpeg_format is None:
        # Checked before decoding so unsupported targets don't load the WAV.
        logger.warning(f"Unsupported format: {output_format}, defaulting to wav")
        return input_path

    output_path = os.path.splitext(input_path)[0] + f".{fmt}"
    try:
        audio = AudioSegment.from_wav(input_path)
        audio.export(output_path, format=ffmpeg_format, bitrate="192k")
    except Exception as e:
        logger.error(f"Error converting audio: {e}")
        return input_path

    logger.info(f"Converted audio to {fmt}: {output_path}")
    return output_path
|
|
|
|
|
def _resolve_voice(voice_name):
    """Pick the voice to synthesize with; return ``(voice_name, voice_path)``.

    A plain voice name (no path separators) whose ``voices/<name>.pt`` file
    exists is used as-is. Otherwise the available voices are listed. A listed
    plain name keeps its original path. Anything else falls back to the first
    available voice: unknown names, ``None``, or names containing path
    separators.

    Raises:
        Exception: if no voices are available at all.
    """
    # voice_name can come straight from an API request. Rejecting names like
    # "../x" keeps the model from being pointed at arbitrary .pt files
    # outside the voices directory.
    is_plain_name = (
        isinstance(voice_name, str)
        and bool(voice_name)
        and os.path.basename(voice_name) == voice_name
    )
    voice_path = f"voices/{voice_name}.pt"
    if is_plain_name and os.path.exists(voice_path):
        return voice_name, voice_path

    if is_plain_name:
        logger.warning(f"Voice file not found: {voice_path}")
    else:
        logger.warning(f"Rejecting invalid voice name: {voice_name!r}")

    voices = get_available_voices()
    if not voices:
        raise Exception("No voices available")
    if not is_plain_name or voice_name not in voices:
        logger.warning(f"Using default voice instead of {voice_name}")
        voice_name = voices[0]
        voice_path = f"voices/{voice_name}.pt"
    return voice_name, voice_path


def generate_tts(
    text: str,
    voice_name: str,
    output_format: str = "wav",
    output_path: Optional[str] = None,
    speed: float = 1.0
) -> Optional[str]:
    """Generate TTS audio and return the path to the generated file.

    Speech is always synthesized to a WAV file first. For other formats the
    WAV is then passed to ``convert_audio``.

    Args:
        text: Text to convert to speech; blank lines split it into segments.
        voice_name: Name of the voice to use. Falls back to the first available
            voice if the name is unknown or invalid.
        output_format: Output audio format ("wav", "mp3" or "aac").
        output_path: Optional custom output path. Its extension is replaced by
            ".wav" for the intermediate WAV file.
        speed: Speech speed multiplier passed to the model.

    Returns:
        Path to the generated audio file, or None if generation failed. On a
        failed conversion this is the intermediate WAV path, because
        ``convert_audio`` returns its input in that case. Errors are logged,
        not raised.
    """
    try:
        initialize_model()
        os.makedirs(config["output_dir"], exist_ok=True)

        if output_path:
            wav_path = os.path.splitext(output_path)[0] + ".wav"
        else:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            base_name = f"tts_{timestamp}_{str(uuid.uuid4())[:8]}"
            wav_path = os.path.join(config["output_dir"], f"{base_name}.wav")

        logger.info(f"Generating speech for text: '{text[:50]}...' using voice: {voice_name}")

        voice_name, voice_path = _resolve_voice(voice_name)

        segments = model(text, voice=voice_path, speed=speed, split_pattern=r'\n+')

        all_audio = []
        for i, (gs, ps, audio) in enumerate(segments):
            if audio is None:
                continue
            if isinstance(audio, np.ndarray):
                audio = torch.from_numpy(audio).float()
            # Keep every segment on the CPU without autograd history, so
            # torch.cat never mixes devices and .numpy() below cannot fail
            # for CUDA tensors.
            all_audio.append(audio.detach().cpu())
            # Lazy %-formatting: only rendered when DEBUG is enabled, and a
            # None grapheme string cannot abort generation.
            logger.debug("Generated segment %d: %.30s...", i + 1, gs)

        if not all_audio:
            raise Exception("No audio generated")

        final_audio = torch.cat(all_audio, dim=0)
        sf.write(wav_path, final_audio.numpy(), SAMPLE_RATE)
        logger.info(f"Saved WAV file to {wav_path}")

        if output_format != "wav":
            return convert_audio(wav_path, output_format)
        return wav_path

    except Exception as e:
        # logger.exception records the message together with the traceback.
        logger.exception(f"Error generating speech: {e}")
        return None
|
|
|
|
|
|
|
|
def create_ui_interface():
    """Create and return the Gradio Blocks interface.

    The UI offers four controls: a voice dropdown, a text box, an
    output-format radio and a speed slider. "Generate Speech" calls
    ``generate_tts`` and shows the resulting audio file plus a status message.

    If the configured default voice is unset or not among the available voices,
    the first available voice becomes the default and is saved to the config
    file.
    """
    voices = get_available_voices()
    if not voices:
        logger.error("No voices found! Please check the voices directory.")
        # Normalize any falsy result (e.g. None) to an empty choice list.
        voices = []

    default_voice = config.get("default_voice")
    if not default_voice or default_voice not in voices:
        default_voice = voices[0] if voices else None
        if default_voice:
            config["default_voice"] = default_voice
            save_config(config)

    with gr.Blocks(title="CB's TTS Generator") as interface:
        gr.Markdown("# **Welcome to CB's TTS Generator**")
        gr.Markdown("There are multiple voices available for you to choose. This TTS is powered by Kokoro.")

        with gr.Row():
            with gr.Column(scale=1):
                voice = gr.Dropdown(
                    choices=voices,
                    value=default_voice,
                    label="Voice"
                )
                text = gr.Textbox(
                    lines=8,
                    placeholder="Enter text to convert to speech...",
                    label="Text Input"
                )
                format_choice = gr.Radio(
                    choices=["wav", "mp3", "aac"],
                    value=config.get("default_format", "wav"),
                    label="Output Format"
                )
                speed = gr.Slider(
                    minimum=0.5,
                    maximum=2.0,
                    value=1.0,
                    step=0.1,
                    label="Speech Speed"
                )
                generate_btn = gr.Button("Generate Speech", variant="primary")

            with gr.Column(scale=1):
                output = gr.Audio(label="Generated Audio")
                status = gr.Textbox(label="Status", interactive=False)

        def generate_wrapper(voice_name, text_input, format_value, speed_value):
            """Run TTS for the UI; return ``(audio_path_or_None, status_message)``."""
            # A cleared textbox may arrive as None rather than "".
            if not text_input or not text_input.strip():
                return None, "Error: Please enter some text to convert."

            try:
                output_path = generate_tts(
                    text=text_input,
                    voice_name=voice_name,
                    output_format=format_value,
                    speed=speed_value
                )
                if output_path:
                    return output_path, f"Success! Generated audio with voice: {voice_name}"
                return None, "Error: Failed to generate audio. Check logs for details."
            except Exception as e:
                logger.error(f"UI generation error: {e}")
                return None, f"Error: {str(e)}"

        generate_btn.click(
            fn=generate_wrapper,
            inputs=[voice, text, format_choice, speed],
            outputs=[output, status]
        )

        if voices:
            # Each example row and the ``inputs`` list must follow
            # generate_wrapper's parameter order (voice, text, format, speed).
            # The rows fill the components positionally, and ``fn`` receives
            # the values in ``inputs`` order when examples are run or cached.
            # Previously text and voice were swapped relative to ``fn``.
            gr.Examples(
                [
                    [default_voice, "May the Force be with you.", "wav", 1.0],
                    [default_voice, "Here's looking at you, kid.", "mp3", 1.0],
                    [default_voice, "I'll be back.", "wav", 1.0],
                    [default_voice, "Houston, we have a problem.", "mp3", 1.0]
                ],
                fn=generate_wrapper,
                inputs=[voice, text, format_choice, speed],
                outputs=[output, status]
            )

    return interface
|
|
|
|
|
|
|
|
def create_api_server() -> Flask:
    """Create and configure the Flask REST API.

    Routes:
        GET  /api/voices      -- available voices and the configured default.
        POST /api/tts         -- JSON body ``{"text", "voice"?, "format"?,
                                 "speed"?}``; returns the audio as an attachment.
        GET  /api/health      -- model/voice status.
        GET|PUT /api/config   -- read, or update ``output_dir``,
                                 ``default_voice`` and ``default_format``.

    Malformed requests are answered with HTTP 400. Generation failures and
    unexpected errors return HTTP 500 with a JSON ``error`` message.
    """
    import math

    app = Flask("KokoroTTS-API")

    # Formats generate_tts/convert_audio can produce, mapped to their MIME
    # types. The registered MP3 type is audio/mpeg, not "audio/mp3".
    mime_types = {"wav": "audio/wav", "mp3": "audio/mpeg", "aac": "audio/aac"}

    @app.route('/api/voices', methods=['GET'])
    def api_voices():
        """Get available voices."""
        try:
            voices = get_available_voices()
            return jsonify({"voices": voices, "default": config.get("default_voice")})
        except Exception as e:
            logger.error(f"API error in voices: {e}")
            return jsonify({"error": str(e)}), 500

    @app.route('/api/tts', methods=['POST'])
    def api_tts():
        """Generate speech from text and send the audio file."""
        try:
            # silent=True: a missing or non-JSON body yields None (-> 400)
            # instead of raising and surfacing as a 500.
            data = request.get_json(silent=True)
            if not isinstance(data, dict) or 'text' not in data:
                return jsonify({"error": "Missing 'text' field"}), 400

            text = data['text']
            if not isinstance(text, str) or not text.strip():
                return jsonify({"error": "'text' must be a non-empty string"}), 400

            voice = data.get('voice', config.get("default_voice"))

            # The format becomes part of a file name and of the MIME type, so
            # only known values are accepted. This also blocks path tricks
            # such as "../x".
            requested_format = data.get('format') or config.get("default_format") or "wav"
            output_format = str(requested_format).lower()
            if output_format not in mime_types:
                return jsonify({
                    "error": f"Unsupported format '{requested_format}'. "
                             f"Use one of: {', '.join(mime_types)}"
                }), 400

            try:
                speed = float(data.get('speed', 1.0))
            except (TypeError, ValueError):
                return jsonify({"error": "'speed' must be a number"}), 400
            if not math.isfinite(speed) or speed <= 0:
                return jsonify({"error": "'speed' must be a positive number"}), 400

            # Timestamp plus a short random id keeps concurrent requests from
            # overwriting each other's files.
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            request_id = str(uuid.uuid4())[:8]
            filename = f"api_tts_{timestamp}_{request_id}.{output_format}"
            output_path = os.path.join(config["output_dir"], filename)

            generated_path = generate_tts(
                text=text,
                voice_name=voice,
                output_format=output_format,
                output_path=output_path,
                speed=speed
            )

            if not generated_path or not os.path.exists(generated_path):
                logger.error(f"Generated path doesn't exist: {generated_path}")
                return jsonify({"error": "Failed to generate audio file"}), 500

            # Treat a near-empty file as a failed generation.
            file_size = os.path.getsize(generated_path)
            if file_size < 100:
                logger.error(f"Generated file is too small ({file_size} bytes)")
                return jsonify({"error": "Generated audio file appears to be empty or corrupted"}), 500

            # convert_audio falls back to the WAV file when conversion fails,
            # so label the response by the file that actually exists.
            actual_format = os.path.splitext(generated_path)[1].lstrip('.').lower()
            if actual_format != output_format:
                logger.warning(
                    f"Requested {output_format} but produced {actual_format}; "
                    f"sending {actual_format}"
                )

            logger.info(f"Sending audio file: {generated_path} ({file_size} bytes)")

            return send_file(
                generated_path,
                as_attachment=True,
                download_name=f"tts_output.{actual_format}",
                mimetype=mime_types.get(actual_format, "application/octet-stream")
            )

        except Exception as e:
            logger.exception(f"API error in TTS: {e}")
            return jsonify({"error": str(e)}), 500

    @app.route('/api/health', methods=['GET'])
    def api_health():
        """Health check endpoint."""
        return jsonify({
            "status": "ok",
            "model_loaded": model is not None,
            "voices_count": len(get_available_voices())
        })

    @app.route('/api/config', methods=['GET', 'PUT'])
    def api_config():
        """Get or update configuration."""
        if request.method == 'GET':
            return jsonify(config)

        try:
            data = request.get_json(silent=True)
            if not isinstance(data, dict):
                return jsonify({"error": "Request body must be a JSON object"}), 400

            if 'default_format' in data and data['default_format'] not in mime_types:
                return jsonify({
                    "error": f"Unsupported default_format. Use one of: {', '.join(mime_types)}"
                }), 400
            if 'output_dir' in data and (
                not isinstance(data['output_dir'], str) or not data['output_dir'].strip()
            ):
                return jsonify({"error": "'output_dir' must be a non-empty string"}), 400

            # NOTE(review): output_dir is accepted from any client that can
            # reach this API (launch_api binds 0.0.0.0 by default), which lets
            # callers redirect where audio files are written. Consider
            # restricting or removing this key.
            for key in ['output_dir', 'default_voice', 'default_format']:
                if key in data:
                    config[key] = data[key]

            save_config(config)
            return jsonify({"status": "success", "config": config})
        except Exception as e:
            logger.error(f"API error updating config: {e}")
            return jsonify({"error": str(e)}), 500

    return app
|
|
|
|
|
|
|
|
def launch_api(host="0.0.0.0", port=None):
    """Start the Flask REST API on a background daemon thread.

    Args:
        host: Interface to bind.
        port: Port to listen on; defaults to ``config["api_port"]`` (5000).

    Returns:
        The started thread, or None when the API is disabled in ``config``.
    """
    if not config.get("api_enabled", True):
        logger.info("API server disabled in configuration")
        return None

    api_port = port or config.get("api_port", 5000)
    logger.info(f"Launching API server on port {api_port}")
    app = create_api_server()

    def run_api_server():
        # Errors inside the server thread are logged with a traceback,
        # since nothing else would report them.
        try:
            run_simple(host, api_port, app, threaded=True, use_reloader=False)
        except Exception as e:
            logger.error(f"Error in API server: {e}")
            import traceback
            logger.error(traceback.format_exc())

    server_thread = threading.Thread(target=run_api_server, daemon=True)
    server_thread.start()

    # Give the server a moment to bind before announcing it.
    time.sleep(1)
    logger.info(f"API server running at http://{host}:{api_port}")
    return server_thread
|
|
|
|
|
def launch_ui(server_name="0.0.0.0", server_port=None, share=None):
    """Build the Gradio interface and start serving it without blocking.

    Args:
        server_name: Interface to bind.
        server_port: Port to listen on; defaults to ``config["ui_port"]`` (7860).
        share: Value passed as Gradio's ``share`` option; defaults to
            ``config["share_ui"]`` when None.

    Returns:
        True once ``interface.launch`` has returned.
    """
    port = server_port or config.get("ui_port", 7860)
    share_ui = config.get("share_ui", True) if share is None else share

    logger.info(f"Launching UI on port {port} (share={share_ui})")
    interface = create_ui_interface()

    # Request queuing is enabled everywhere except when HF_SPACE is set.
    running_on_hf_space = os.environ.get("HF_SPACE") is not None
    if not running_on_hf_space:
        interface.queue()

    interface.launch(
        server_name=server_name,
        server_port=port,
        share=share_ui,
        prevent_thread_lock=True,
    )
    logger.info(f"UI server running at http://{server_name}:{port}")
    return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_network_ip() -> str:
    """Return a best-effort LAN address of this machine for the startup banner.

    ``socket.gethostbyname(socket.gethostname())`` raises ``socket.gaierror``
    (an ``OSError``) when the host name does not resolve, e.g. when it has no
    DNS or hosts-file entry. Fall back to the loopback address so a cosmetic
    URL can never abort startup.
    """
    try:
        return socket.gethostbyname(socket.gethostname())
    except OSError:
        return "127.0.0.1"


def main():
    """Main application entry point.

    Loads the configuration, creates the output directory and initializes the
    model (exiting with status 1 if that fails). It then starts the servers in
    one of two modes:

    * Single-port (``HF_SPACE`` set or ``SINGLE_PORT=1``): the Gradio app and
      the Flask API are combined with Werkzeug's DispatcherMiddleware and
      served by ``run_simple`` on the UI port. This call blocks.
    * Default: the API (if enabled) and the UI start on background daemon
      threads, and the main thread sleeps until Ctrl+C.
    """
    global config

    print("\n" + "=" * 50)
    print("Starting Kokoro-TTS")
    print("=" * 50)

    config = load_config()
    os.makedirs(config["output_dir"], exist_ok=True)

    try:
        initialize_model()
    except Exception as e:
        logger.error(f"Failed to initialize model: {e}")
        print(f"ERROR: Failed to initialize model: {e}")
        sys.exit(1)

    network_ip = _get_network_ip()

    if os.environ.get("HF_SPACE") is not None or os.environ.get("SINGLE_PORT") == "1":
        api_app = create_api_server()
        interface = create_ui_interface()

        # NOTE(review): DispatcherMiddleware strips the mount prefix, so a
        # request for /api/voices reaches the Flask app as /voices. Its routes
        # are declared with an /api prefix and would only match
        # /api/api/voices. Gradio's ``Blocks.app`` also appears to be an ASGI
        # (FastAPI) application rather than a WSGI one. Verify this mode
        # actually serves both the UI and the API.
        combined_app = DispatcherMiddleware(interface.app, {
            '/api': api_app
        })

        port = config.get("ui_port", 7860)
        print(f"Combined UI and API running on port: {port}")
        print(f"Localhost: http://localhost:{port}")
        print(f"Network: http://{network_ip}:{port}")

        run_simple("0.0.0.0", port, combined_app, use_reloader=False, threaded=True)
    else:
        if config.get("api_enabled", True):
            launch_api()

        ui_thread = threading.Thread(target=launch_ui, daemon=True)
        ui_thread.start()

        print(f"UI (localhost): http://localhost:{config.get('ui_port', 7860)}")
        print(f"UI (network): http://{network_ip}:{config.get('ui_port', 7860)}")
        if config.get("api_enabled", True):
            print(f"API (localhost): http://localhost:{config.get('api_port', 5000)}")
            print(f"API (network): http://{network_ip}:{config.get('api_port', 5000)}")

        # Keep the main thread alive; the servers run on daemon threads and
        # stop when this function returns.
        try:
            while True:
                time.sleep(1)
        except KeyboardInterrupt:
            print("\nShutting down servers...")
            print("Press Ctrl+C again to force quit")


if __name__ == "__main__":
    main()