# kokoro/app.py — Kokoro TTS FastAPI server.
from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.responses import StreamingResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, validator
import numpy as np
import io
import wave
from kokoro_onnx import Kokoro
from kokoro_onnx.tokenizer import Tokenizer
from typing import Optional, Dict, Tuple
import uvicorn
from ui import html_content
import asyncio
import concurrent.futures
from functools import lru_cache
import threading
from queue import Queue
import time
import hashlib
# FastAPI application; all routes below are registered on this instance.
app = FastAPI(title="Kokoro TTS API", version="1.0.0")
# Thread pool for CPU-intensive tasks (phonemization, ONNX inference, WAV encoding).
executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)
# Initialize models once at import time so every request reuses them.
# NOTE(review): model/voice file paths are relative — assumes the process CWD
# is the app directory; confirm deployment layout.
tokenizer = Tokenizer()
kokoro = Kokoro("kokoro-v1.0.onnx", "voices-v1.0.bin")
# Languages accepted by the /tts endpoints.
SUPPORTED_LANGUAGES = ["en-us"]
# Cache for phonemes and voice styles.
# NOTE(review): phoneme_cache appears unused — cached_phonemize memoizes via
# lru_cache instead; confirm before removing.
phoneme_cache: Dict[str, str] = {}
# Voice name (or "voiceA+voiceB" blend key) -> style vector, filled lazily.
voice_style_cache: Dict[str, np.ndarray] = {}
# Request cache key -> (samples, sample_rate); insertion is bounded to 100 entries.
audio_cache: Dict[str, Tuple[np.ndarray, int]] = {}
# Request queue for batching.
# NOTE(review): request_queue, batch_size and batch_timeout look unused in this
# file — possibly leftovers from a planned batching worker; verify before removal.
request_queue = Queue()
batch_size = 4
batch_timeout = 0.1  # 100ms
class TTSRequest(BaseModel):
    """Request payload accepted by the /tts/* endpoints."""

    text: str  # text to synthesize
    voice: str = "af_heart"  # primary voice; must be in kokoro.get_voices()
    language: str = "en-us"  # must be one of SUPPORTED_LANGUAGES
    blend_voice_name: Optional[str] = None  # optional second voice, mixed 50/50
    speed: float = 1.0  # speed multiplier forwarded to kokoro.create
class TTSResponse(BaseModel):
    """Response body for /tts/info: phonemes plus the output sample rate."""

    phonemes: str  # phonemized form of the request text
    sample_rate: int  # audio sample rate in Hz
def get_cache_key(text: str, language: str, voice: str, blend_voice: Optional[str], speed: float) -> str:
    """Build a deterministic cache key for one TTS request.

    The key is the MD5 hex digest of the pipe-joined request fields, so
    identical requests always hit the same cache slot.
    """
    fields = (text, language, voice, blend_voice, speed)
    joined = "|".join(str(field) for field in fields)
    return hashlib.md5(joined.encode()).hexdigest()
@lru_cache(maxsize=1000)
def cached_phonemize(text: str, language: str) -> str:
    """Convert *text* to a phoneme string via the global tokenizer.

    Memoized for up to 1000 distinct (text, language) pairs, so repeated
    requests skip the phonemizer entirely.
    """
    return tokenizer.phonemize(text, lang=language)
def get_cached_voice_style(voice_name: str) -> np.ndarray:
    """Return the style vector for *voice_name*, loading it at most once."""
    try:
        return voice_style_cache[voice_name]
    except KeyError:
        style = kokoro.get_voice_style(voice_name)
        voice_style_cache[voice_name] = style
        return style
def process_voice_blend(voice: str, blend_voice_name: Optional[str]) -> np.ndarray:
    """Return the style vector for *voice*, optionally 50/50-blended with a second voice.

    Blended styles are memoized in ``voice_style_cache`` under the key
    ``"<voice>+<blend>"`` so each pair is mixed only once.
    """
    if not blend_voice_name:
        # No blend requested: plain (cached) single-voice style.
        return get_cached_voice_style(voice)

    blend_key = f"{voice}+{blend_voice_name}"
    cached = voice_style_cache.get(blend_key)
    if cached is None:
        primary = get_cached_voice_style(voice)
        secondary = get_cached_voice_style(blend_voice_name)
        cached = np.add(primary * 0.5, secondary * 0.5)
        voice_style_cache[blend_key] = cached
    return cached
def batch_process_tts(requests: list) -> list:
    """Synthesize audio for every request in *requests*.

    Returns one ``(samples, sample_rate, phonemes, error)`` tuple per
    request, in order; on a per-request failure the first three fields
    are ``None`` and *error* carries the exception message.
    """
    # Phonemize everything up front (repeats hit the lru cache).
    phoneme_list = [cached_phonemize(req.text, req.language) for req in requests]

    output = []
    for req, phonemes in zip(requests, phoneme_list):
        try:
            style = process_voice_blend(req.voice, req.blend_voice_name)
            # lang=None + is_phonemes=True: input is already phonemized.
            samples, sample_rate = kokoro.create(
                phonemes, voice=style, speed=req.speed, lang=None, is_phonemes=True
            )
            output.append((samples, sample_rate, phonemes, None))
        except Exception as exc:
            output.append((None, None, None, str(exc)))
    return output
def numpy_to_wav_bytes(audio_data: np.ndarray, sample_rate: int) -> bytes:
    """Serialize mono audio samples to an in-memory 16-bit PCM WAV file.

    Args:
        audio_data: Samples, either int16 (written as-is) or floating
            point in [-1.0, 1.0] (clipped, then scaled to the int16 range).
        sample_rate: Frames per second written into the WAV header.

    Returns:
        A complete WAV file (header + PCM frames) as bytes.
    """
    if audio_data.dtype != np.int16:
        # Clip first so out-of-range floats saturate at full scale instead
        # of wrapping around when cast to int16 (wrap-around is audible
        # distortion). Fixes the unclipped `* 32767` cast.
        audio_data = (np.clip(audio_data, -1.0, 1.0) * 32767).astype(np.int16)
    # Note: the previous BytesIO.truncate() "pre-allocation" was a no-op —
    # BytesIO.truncate never extends the buffer — so it was removed.
    buffer = io.BytesIO()
    with wave.open(buffer, "wb") as wav_file:
        wav_file.setnchannels(1)  # mono
        wav_file.setsampwidth(2)  # 16-bit samples
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(audio_data.tobytes())
    return buffer.getvalue()
async def run_in_executor(func, *args, **kwargs):
    """Run a blocking function on the shared thread pool and await its result.

    Keeps CPU-bound work (phonemization, ONNX inference, WAV encoding) off
    the event loop. All positional and keyword arguments are forwarded to
    *func* verbatim; returns whatever *func* returns.
    """
    from functools import partial

    # get_running_loop() is the correct call from inside a coroutine;
    # get_event_loop() is deprecated for this use since Python 3.10.
    loop = asyncio.get_running_loop()
    if kwargs:
        # loop.run_in_executor does not accept keyword args — bind them first.
        return await loop.run_in_executor(executor, partial(func, *args, **kwargs))
    return await loop.run_in_executor(executor, func, *args)
@app.get("/", response_class=HTMLResponse)
async def get_home():
    """Serve the bundled single-page UI (static HTML from ui.py)."""
    return HTMLResponse(content=html_content)
@app.get("/voices")
async def get_voices():
    """List all available voice names, sorted; computed once and memoized
    on the handler function itself."""
    cached = getattr(get_voices, "_cached_voices", None)
    if cached is None:
        cached = {"voices": sorted(kokoro.get_voices())}
        get_voices._cached_voices = cached
    return cached
@app.get("/languages")
async def get_languages():
    """List the languages accepted by the /tts endpoints."""
    return {"languages": SUPPORTED_LANGUAGES}
@app.post("/tts/audio")
async def generate_audio(request: TTSRequest):
    """Generate speech for *request* and stream it back as a WAV download.

    Validates language/voice/blend voice, serves repeats from the in-memory
    audio cache, and offloads synthesis and WAV encoding to the thread pool.

    Raises:
        HTTPException 400: unsupported language, voice, or blend voice.
        HTTPException 500: unexpected synthesis/encoding failure.
    """
    try:
        # Validate inputs.
        if request.language not in SUPPORTED_LANGUAGES:
            raise HTTPException(
                status_code=400, detail=f"Unsupported language: {request.language}"
            )
        available_voices = kokoro.get_voices()
        if request.voice not in available_voices:
            raise HTTPException(
                status_code=400, detail=f"Unsupported voice: {request.voice}"
            )
        if request.blend_voice_name and request.blend_voice_name not in available_voices:
            raise HTTPException(
                status_code=400, detail=f"Unsupported blend voice: {request.blend_voice_name}"
            )

        # Serve from cache when the identical request was seen before.
        cache_key = get_cache_key(
            request.text, request.language, request.voice,
            request.blend_voice_name, request.speed
        )
        if cache_key in audio_cache:
            samples, sample_rate = audio_cache[cache_key]
        else:
            phonemes = cached_phonemize(request.text, request.language)
            voice = process_voice_blend(request.voice, request.blend_voice_name)
            # lang=None + is_phonemes=True: the text is already phonemized above.
            samples, sample_rate = await run_in_executor(
                kokoro.create,
                phonemes,
                voice=voice,
                speed=request.speed,
                lang=None,
                is_phonemes=True
            )
            # Bounded cache: stop inserting once 100 entries exist.
            if len(audio_cache) < 100:
                audio_cache[cache_key] = (samples, sample_rate)

        wav_bytes = await run_in_executor(numpy_to_wav_bytes, samples, sample_rate)
        return StreamingResponse(
            io.BytesIO(wav_bytes),
            media_type="audio/wav",
            headers={"Content-Disposition": "attachment; filename=output.wav"},
        )
    except HTTPException:
        # Bug fix: re-raise intentional 4xx errors unchanged — the generic
        # handler below used to rewrite them into 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/tts/info", response_model=TTSResponse)
async def generate_info(request: TTSRequest):
    """Return phonemes and sample rate for *request* without synthesizing audio.

    Raises:
        HTTPException 400: unsupported language.
        HTTPException 500: unexpected phonemization failure.
    """
    try:
        if request.language not in SUPPORTED_LANGUAGES:
            raise HTTPException(
                status_code=400, detail=f"Unsupported language: {request.language}"
            )
        phonemes = cached_phonemize(request.text, request.language)
        # 24000 Hz assumed to be Kokoro's fixed output rate — matches the
        # rate the synthesis endpoints actually produce; confirm with model docs.
        sample_rate = 24000
        return TTSResponse(phonemes=phonemes, sample_rate=sample_rate)
    except HTTPException:
        # Bug fix: don't let the generic handler turn the 400 above into a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/tts/batch")
async def generate_batch(requests: list[TTSRequest]):
    """Synthesize several requests in one call.

    All requests are validated up front; per-request synthesis failures are
    reported inline as ``{"error": ...}`` entries. Audio is returned
    base64-encoded so everything fits in one JSON response.

    Raises:
        HTTPException 400: any request has an unsupported language or voice.
        HTTPException 500: unexpected batch-level failure.
    """
    import base64  # hoisted out of the result loop — one import, not N

    try:
        # Validate every request before doing any work.
        available_voices = kokoro.get_voices()
        for req in requests:
            if req.language not in SUPPORTED_LANGUAGES:
                raise HTTPException(
                    status_code=400, detail=f"Unsupported language: {req.language}"
                )
            if req.voice not in available_voices:
                raise HTTPException(
                    status_code=400, detail=f"Unsupported voice: {req.voice}"
                )

        # Run the whole batch in the thread pool.
        results = await run_in_executor(batch_process_tts, requests)

        response_data = []
        for samples, sample_rate, phonemes, error in results:
            if error:
                response_data.append({"error": error})
                continue
            wav_bytes = await run_in_executor(numpy_to_wav_bytes, samples, sample_rate)
            response_data.append({
                "phonemes": phonemes,
                "sample_rate": sample_rate,
                "audio_base64": base64.b64encode(wav_bytes).decode(),
                "audio_format": "wav"
            })
        return {"results": response_data}
    except HTTPException:
        # Bug fix: re-raise the 400 validation errors raised above — the
        # blanket handler below used to convert them into 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/tts/both")
async def generate_both(request: TTSRequest):
    """Generate audio and its metadata (phonemes, sample rate) in one response.

    Same validation and caching as /tts/audio, but the WAV bytes come back
    base64-encoded inside a JSON body instead of as a streamed download.

    Raises:
        HTTPException 400: unsupported language, voice, or blend voice.
        HTTPException 500: unexpected synthesis/encoding failure.
    """
    import base64  # hoisted to the top of the handler

    try:
        # Validate inputs.
        if request.language not in SUPPORTED_LANGUAGES:
            raise HTTPException(
                status_code=400, detail=f"Unsupported language: {request.language}"
            )
        available_voices = kokoro.get_voices()
        if request.voice not in available_voices:
            raise HTTPException(
                status_code=400, detail=f"Unsupported voice: {request.voice}"
            )
        if request.blend_voice_name and request.blend_voice_name not in available_voices:
            raise HTTPException(
                status_code=400, detail=f"Unsupported blend voice: {request.blend_voice_name}"
            )

        cache_key = get_cache_key(
            request.text, request.language, request.voice,
            request.blend_voice_name, request.speed
        )
        if cache_key in audio_cache:
            samples, sample_rate = audio_cache[cache_key]
            phonemes = cached_phonemize(request.text, request.language)
        else:
            phonemes = cached_phonemize(request.text, request.language)
            voice = process_voice_blend(request.voice, request.blend_voice_name)
            # lang=None + is_phonemes=True: the text is already phonemized above.
            samples, sample_rate = await run_in_executor(
                kokoro.create,
                phonemes,
                voice=voice,
                speed=request.speed,
                lang=None,
                is_phonemes=True
            )
            # Bounded cache: stop inserting once 100 entries exist.
            if len(audio_cache) < 100:
                audio_cache[cache_key] = (samples, sample_rate)

        wav_bytes = await run_in_executor(numpy_to_wav_bytes, samples, sample_rate)
        return {
            "phonemes": phonemes,
            "sample_rate": sample_rate,
            "audio_base64": base64.b64encode(wav_bytes).decode(),
            "audio_format": "wav",
        }
    except HTTPException:
        # Bug fix: re-raise intentional 4xx errors — the generic handler
        # below used to rewrite them as 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.on_event("startup")
async def startup_event():
    """Preload commonly used voice styles so first requests skip voice loading."""
    common_voices = ["af_heart", "af_bella", "af_sarah"]
    # Hoisted: query the voice list once instead of once per candidate voice.
    available = kokoro.get_voices()
    for voice in common_voices:
        if voice in available:
            get_cached_voice_style(voice)
if __name__ == "__main__":
    # Single worker: the ONNX model and in-process caches are per-process state.
    uvicorn.run(app, host="0.0.0.0", port=7860, workers=1)