| from fastapi import FastAPI, HTTPException, BackgroundTasks |
| from fastapi.responses import StreamingResponse, HTMLResponse |
| from fastapi.staticfiles import StaticFiles |
| from pydantic import BaseModel, validator |
| import numpy as np |
| import io |
| import wave |
| from kokoro_onnx import Kokoro |
| from kokoro_onnx.tokenizer import Tokenizer |
| from typing import Optional, Dict, Tuple |
| import uvicorn |
| from ui import html_content |
| import asyncio |
| import concurrent.futures |
| from functools import lru_cache |
| import threading |
| from queue import Queue |
| import time |
| import hashlib |
|
|
# FastAPI application instance; endpoints below are registered via decorators.
app = FastAPI(title="Kokoro TTS API", version="1.0.0")


# Shared thread pool for CPU-heavy work (ONNX inference, WAV encoding) so the
# asyncio event loop is never blocked.
executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)


# Model assets: the tokenizer converts text -> phonemes; Kokoro synthesizes audio.
# NOTE(review): model/voice file paths are relative to the working directory —
# confirm the deployment layout provides them.
tokenizer = Tokenizer()
kokoro = Kokoro("kokoro-v1.0.onnx", "voices-v1.0.bin")
SUPPORTED_LANGUAGES = ["en-us"]  # language codes accepted by the endpoints


# In-process caches.
phoneme_cache: Dict[str, str] = {}  # currently unused; cached_phonemize uses lru_cache instead
voice_style_cache: Dict[str, np.ndarray] = {}  # voice name (or "a+b" blend key) -> style vector
audio_cache: Dict[str, Tuple[np.ndarray, int]] = {}  # request hash -> (samples, sample_rate); capped at 100 by writers


# Batch-processing knobs. NOTE(review): request_queue, batch_size and
# batch_timeout are not consumed anywhere in this file — /tts/batch batches
# per HTTP request instead. Possibly leftovers from an earlier design.
request_queue = Queue()
batch_size = 4
batch_timeout = 0.1
|
|
class TTSRequest(BaseModel):
    """Request payload shared by all /tts/* endpoints."""

    # Text to synthesize.
    text: str
    # Primary voice name; validated against kokoro.get_voices() by the endpoints.
    voice: str = "af_heart"
    # Language code; must be one of SUPPORTED_LANGUAGES.
    language: str = "en-us"
    # Optional second voice; when set, the two styles are averaged 50/50.
    blend_voice_name: Optional[str] = None
    # Speed multiplier passed straight through to kokoro.create().
    speed: float = 1.0
|
|
class TTSResponse(BaseModel):
    """Response model for /tts/info: phoneme string plus output sample rate."""

    # Phoneme string produced by the tokenizer.
    phonemes: str
    # Sample rate of the audio that synthesis would produce, in Hz.
    sample_rate: int
|
|
def get_cache_key(text: str, language: str, voice: str, blend_voice: Optional[str], speed: float) -> str:
    """Build a deterministic audio-cache key from every field that affects output.

    The fields are joined with '|' and hashed with MD5 (fine here: the hash is
    a cache key, not a security boundary).
    """
    parts = (text, language, voice, str(blend_voice), str(speed))
    return hashlib.md5("|".join(parts).encode()).hexdigest()
|
|
@lru_cache(maxsize=1000)
def cached_phonemize(text: str, language: str) -> str:
    """Memoized text->phoneme conversion (bounded LRU of 1000 entries)."""
    phonemes = tokenizer.phonemize(text, lang=language)
    return phonemes
|
|
def get_cached_voice_style(voice_name: str) -> np.ndarray:
    """Return the style vector for a voice, loading and memoizing it on first use."""
    try:
        return voice_style_cache[voice_name]
    except KeyError:
        style = kokoro.get_voice_style(voice_name)
        voice_style_cache[voice_name] = style
        return style
|
|
def process_voice_blend(voice: str, blend_voice_name: Optional[str]) -> np.ndarray:
    """Resolve the style vector for a request's voice selection.

    Without a blend voice this is just the cached style of `voice`; with one,
    the two styles are averaged 50/50 and the result memoized in
    voice_style_cache under the key "voice+blend".
    """
    if not blend_voice_name:
        return get_cached_voice_style(voice)

    blend_key = f"{voice}+{blend_voice_name}"
    try:
        return voice_style_cache[blend_key]
    except KeyError:
        mixed = get_cached_voice_style(voice) * 0.5 + get_cached_voice_style(blend_voice_name) * 0.5
        voice_style_cache[blend_key] = mixed
        return mixed
|
|
def batch_process_tts(requests: list) -> list:
    """Synthesize a list of TTSRequest objects sequentially.

    Returns one ``(samples, sample_rate, phonemes, error)`` tuple per request,
    in input order. On failure the first three fields are None and ``error``
    holds the message, so one bad request never aborts the whole batch.

    Fix: phonemization now happens inside the per-request try block — it
    previously ran in a separate loop outside any try, so a single tokenizer
    error raised out of this function and failed every request in the batch.
    """
    results = []
    for req in requests:
        try:
            phonemes = cached_phonemize(req.text, req.language)
            voice = process_voice_blend(req.voice, req.blend_voice_name)
            samples, sample_rate = kokoro.create(
                phonemes, voice=voice, speed=req.speed, lang=None, is_phonemes=True
            )
            results.append((samples, sample_rate, phonemes, None))
        except Exception as e:
            # Broad catch is deliberate: per-request error isolation.
            results.append((None, None, None, str(e)))
    return results
|
|
def numpy_to_wav_bytes(audio_data: np.ndarray, sample_rate: int) -> bytes:
    """Encode mono samples as a 16-bit PCM WAV file and return its bytes.

    Float input is assumed normalized to [-1.0, 1.0]; it is clipped before
    scaling so out-of-range samples saturate instead of wrapping around when
    cast to int16 (the previous version could wrap, producing loud clicks).
    int16 input is written as-is.

    Args:
        audio_data: 1-D sample array, normalized float or int16.
        sample_rate: Output sample rate in Hz.

    Returns:
        Complete WAV file contents (44-byte header + PCM data).
    """
    if audio_data.dtype != np.int16:
        # Clip first: e.g. 1.1 * 32767 would otherwise overflow int16.
        audio_data = (np.clip(audio_data, -1.0, 1.0) * 32767).astype(np.int16)

    # The old BytesIO truncate() "pre-allocation" was a no-op (BytesIO.truncate
    # never grows the buffer) and has been removed.
    buffer = io.BytesIO()
    with wave.open(buffer, "wb") as wav_file:
        wav_file.setnchannels(1)   # mono
        wav_file.setsampwidth(2)   # 16-bit
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(audio_data.tobytes())

    return buffer.getvalue()
|
|
async def run_in_executor(func, *args, **kwargs):
    """Run a blocking function on the module thread pool and await its result.

    Fixes: uses asyncio.get_running_loop() — asyncio.get_event_loop() is
    deprecated inside coroutines since Python 3.10 — and always binds the
    arguments with functools.partial, removing the redundant kwargs branch
    (run_in_executor itself accepts only positional args).
    """
    from functools import partial

    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(executor, partial(func, *args, **kwargs))
|
|
@app.get("/", response_class=HTMLResponse)
async def get_home():
    """Serve the bundled single-page UI (html_content comes from ui.py)."""
    return HTMLResponse(content=html_content)
|
|
@app.get("/voices")
async def get_voices():
    """Return the sorted list of available voice names.

    The list is computed once on first call and memoized as an attribute on
    the endpoint function itself.
    """
    cached = getattr(get_voices, "_cached_voices", None)
    if cached is None:
        cached = {"voices": sorted(kokoro.get_voices())}
        get_voices._cached_voices = cached
    return cached
|
|
@app.get("/languages")
async def get_languages():
    """List the language codes this deployment accepts."""
    return {"languages": SUPPORTED_LANGUAGES}
|
|
@app.post("/tts/audio")
async def generate_audio(request: TTSRequest):
    """Synthesize speech and stream it back as a WAV attachment.

    Validates language/voice/blend voice (400 on failure), serves repeated
    requests from audio_cache, and runs inference and WAV encoding on the
    thread pool so the event loop stays responsive.

    Fix: HTTPException is now re-raised unchanged — the blanket
    ``except Exception`` previously caught the 400 validation errors raised
    inside this same try block and rewrapped them as 500s.
    """
    try:
        if request.language not in SUPPORTED_LANGUAGES:
            raise HTTPException(
                status_code=400, detail=f"Unsupported language: {request.language}"
            )

        available_voices = kokoro.get_voices()
        if request.voice not in available_voices:
            raise HTTPException(
                status_code=400, detail=f"Unsupported voice: {request.voice}"
            )

        if request.blend_voice_name and request.blend_voice_name not in available_voices:
            raise HTTPException(
                status_code=400, detail=f"Unsupported blend voice: {request.blend_voice_name}"
            )

        cache_key = get_cache_key(
            request.text, request.language, request.voice,
            request.blend_voice_name, request.speed
        )

        if cache_key in audio_cache:
            samples, sample_rate = audio_cache[cache_key]
        else:
            phonemes = cached_phonemize(request.text, request.language)
            voice = process_voice_blend(request.voice, request.blend_voice_name)

            # Inference is CPU-bound: keep it off the event loop.
            samples, sample_rate = await run_in_executor(
                kokoro.create,
                phonemes,
                voice=voice,
                speed=request.speed,
                lang=None,
                is_phonemes=True
            )

            # Simple bounded cache: stop inserting once 100 entries exist.
            if len(audio_cache) < 100:
                audio_cache[cache_key] = (samples, sample_rate)

        wav_bytes = await run_in_executor(numpy_to_wav_bytes, samples, sample_rate)

        return StreamingResponse(
            io.BytesIO(wav_bytes),
            media_type="audio/wav",
            headers={"Content-Disposition": "attachment; filename=output.wav"},
        )

    except HTTPException:
        # Preserve deliberate client-error status codes raised above.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
|
@app.post("/tts/info", response_model=TTSResponse)
async def generate_info(request: TTSRequest):
    """Return the phoneme string and sample rate without synthesizing audio.

    Fix: HTTPException is re-raised unchanged so the 400 validation error is
    no longer masked as a 500 by the generic handler below.
    """
    try:
        if request.language not in SUPPORTED_LANGUAGES:
            raise HTTPException(
                status_code=400, detail=f"Unsupported language: {request.language}"
            )

        phonemes = cached_phonemize(request.text, request.language)
        # Assumed fixed Kokoro output rate — matches what /tts/audio produces;
        # TODO(review): confirm against the model version in use.
        sample_rate = 24000

        return TTSResponse(phonemes=phonemes, sample_rate=sample_rate)

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
|
@app.post("/tts/batch")
async def generate_batch(requests: list[TTSRequest]):
    """Process many TTS requests in one call; audio is returned base64-encoded.

    Fixes over the previous version:
    - HTTPException re-raised as-is (validation 400s were masked as 500s).
    - blend_voice_name is now validated, matching the single-request endpoints.
    - ``import base64`` hoisted out of the per-result loop.
    """
    import base64

    try:
        available_voices = kokoro.get_voices()
        for req in requests:
            if req.language not in SUPPORTED_LANGUAGES:
                raise HTTPException(
                    status_code=400, detail=f"Unsupported language: {req.language}"
                )
            if req.voice not in available_voices:
                raise HTTPException(
                    status_code=400, detail=f"Unsupported voice: {req.voice}"
                )
            if req.blend_voice_name and req.blend_voice_name not in available_voices:
                raise HTTPException(
                    status_code=400, detail=f"Unsupported blend voice: {req.blend_voice_name}"
                )

        # Synthesis runs off the event loop; errors come back per-request.
        results = await run_in_executor(batch_process_tts, requests)

        response_data = []
        for samples, sample_rate, phonemes, error in results:
            if error:
                response_data.append({"error": error})
                continue
            wav_bytes = await run_in_executor(numpy_to_wav_bytes, samples, sample_rate)
            response_data.append({
                "phonemes": phonemes,
                "sample_rate": sample_rate,
                "audio_base64": base64.b64encode(wav_bytes).decode(),
                "audio_format": "wav",
            })

        return {"results": response_data}

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
|
@app.post("/tts/both")
async def generate_both(request: TTSRequest):
    """Generate audio (base64) and metadata (phonemes, sample rate) in one call.

    Mirrors /tts/audio's validation and caching.

    Fixes: HTTPException is re-raised unchanged (400s were masked as 500s by
    the blanket handler), and the duplicated cached_phonemize call is hoisted
    out of the cache-hit/miss branches.
    """
    import base64

    try:
        if request.language not in SUPPORTED_LANGUAGES:
            raise HTTPException(
                status_code=400, detail=f"Unsupported language: {request.language}"
            )

        available_voices = kokoro.get_voices()
        if request.voice not in available_voices:
            raise HTTPException(
                status_code=400, detail=f"Unsupported voice: {request.voice}"
            )

        if request.blend_voice_name and request.blend_voice_name not in available_voices:
            raise HTTPException(
                status_code=400, detail=f"Unsupported blend voice: {request.blend_voice_name}"
            )

        cache_key = get_cache_key(
            request.text, request.language, request.voice,
            request.blend_voice_name, request.speed
        )

        # Phonemes are needed in both branches; the call is lru_cached anyway.
        phonemes = cached_phonemize(request.text, request.language)

        if cache_key in audio_cache:
            samples, sample_rate = audio_cache[cache_key]
        else:
            voice = process_voice_blend(request.voice, request.blend_voice_name)

            # CPU-bound inference runs on the thread pool.
            samples, sample_rate = await run_in_executor(
                kokoro.create,
                phonemes,
                voice=voice,
                speed=request.speed,
                lang=None,
                is_phonemes=True
            )

            # Simple bounded cache: stop inserting once 100 entries exist.
            if len(audio_cache) < 100:
                audio_cache[cache_key] = (samples, sample_rate)

        wav_bytes = await run_in_executor(numpy_to_wav_bytes, samples, sample_rate)

        return {
            "phonemes": phonemes,
            "sample_rate": sample_rate,
            "audio_base64": base64.b64encode(wav_bytes).decode(),
            "audio_format": "wav",
        }

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
|
| |
@app.on_event("startup")
async def startup_event():
    """Warm the voice-style cache for the most commonly requested voices.

    Fix: kokoro.get_voices() is now called once instead of once per loop
    iteration, and membership tests use a set. NOTE(review): @app.on_event is
    deprecated in recent FastAPI in favor of lifespan handlers; migrating
    touches app construction, so it is left as-is here.
    """
    available = set(kokoro.get_voices())
    for voice in ("af_heart", "af_bella", "af_sarah"):
        if voice in available:
            get_cached_voice_style(voice)
|
|
if __name__ == "__main__":
    # Development entry point. NOTE(review): passing the app object (rather
    # than an "module:app" import string) means uvicorn cannot spawn multiple
    # workers; with workers=1 that is harmless but the argument is effectively
    # ignored.
    uvicorn.run(app, host="0.0.0.0", port=7860, workers=1)