# kokoro-tts-api / main.py
# (renamed from app.py by eder0782; commit 46b8310, file size 3.38 kB —
#  header retained from the Hugging Face file-viewer page this was copied from)
import asyncio
import base64
import io
import time
import uuid
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from pydantic import BaseModel
import torch
import soundfile as sf
import numpy as np
from kokoro import KModel, KPipeline
# FastAPI application serving the Kokoro TTS endpoints below.
app = FastAPI()
# True when a CUDA device is visible to torch at startup.
CUDA_AVAILABLE = torch.cuda.is_available()
# One KModel per device: models[False] is always the CPU model; models[True]
# (GPU) exists only when CUDA is available. Loading happens at import time.
models = {gpu: KModel().to('cuda' if gpu else 'cpu').eval() for gpu in [False] + ([True] if CUDA_AVAILABLE else [])}
# One G2P pipeline per language code; the codes match the first character of
# the voice ids in CHOICES ('a' = American English, 'p' = Brazilian
# Portuguese, 'e' = Spanish, judging by the flags in the voice labels).
pipelines = {lang_code: KPipeline(lang_code=lang_code, model=False) for lang_code in ['a', 'p', 'e']}
# Display label -> Kokoro voice id. The id's first character selects the
# language pipeline above; the second appears to mark gender ('f'/'m').
CHOICES = {
'πŸ‡ΊπŸ‡Έ 🚺 Heart ❀️': 'af_heart',
'πŸ‡ΊπŸ‡Έ 🚺 Alloy': 'af_alloy',
'πŸ‡ΊπŸ‡Έ 🚺 Aoede': 'af_aoede',
'πŸ‡ΊπŸ‡Έ 🚺 Bella πŸ”₯': 'af_bella',
'πŸ‡ΊπŸ‡Έ 🚺 Jessica': 'af_jessica',
'πŸ‡ΊπŸ‡Έ 🚺 Kore': 'af_kore',
'πŸ‡ΊπŸ‡Έ 🚺 Nicole 🎧': 'af_nicole',
'πŸ‡ΊπŸ‡Έ 🚺 Nova': 'af_nova',
'πŸ‡ΊπŸ‡Έ 🚺 River': 'af_river',
'πŸ‡ΊπŸ‡Έ 🚺 Sarah': 'af_sarah',
'πŸ‡ΊπŸ‡Έ 🚺 Sky': 'af_sky',
'πŸ‡ΊπŸ‡Έ 🚹 Adam': 'am_adam',
'πŸ‡ΊπŸ‡Έ 🚹 Echo': 'am_echo',
'πŸ‡ΊπŸ‡Έ 🚹 Eric': 'am_eric',
'πŸ‡ΊπŸ‡Έ 🚹 Fenrir': 'am_fenrir',
'πŸ‡ΊπŸ‡Έ 🚹 Liam': 'am_liam',
'πŸ‡ΊπŸ‡Έ 🚹 Michael': 'am_michael',
'πŸ‡ΊπŸ‡Έ 🚹 Onyx': 'am_onyx',
'πŸ‡ΊπŸ‡Έ 🚹 Puck': 'am_puck',
'πŸ‡ΊπŸ‡Έ 🚹 Santa': 'am_santa',
'πŸ‡§πŸ‡· 🚺 Dora': 'pf_dora',
'πŸ‡§πŸ‡· 🚹 Alex': 'pm_alex',
'πŸ‡§πŸ‡· 🚹 Santa': 'pm_santa',
'πŸ‡ͺπŸ‡Έ 🚺 Dora': 'ef_dora',
'πŸ‡ͺπŸ‡Έ 🚹 Alex': 'em_alex',
'πŸ‡ͺπŸ‡Έ 🚹 Santa': 'em_santa',
}
# Eagerly load every voice pack into its language pipeline at startup so the
# first /predict request does not pay the download/load cost.
for v in CHOICES.values():
    pipelines[v[0]].load_voice(v)
class PredictRequest(BaseModel):
    """Request body for POST /predict."""
    text: str            # text to synthesize
    voice: str = 'af_heart'   # Kokoro voice id; must be one of CHOICES' values
    speed: float = 1.0   # speech-rate multiplier (1.0 = normal)
def generate_audio(text: str, voice: str, speed: float, use_gpu: bool = CUDA_AVAILABLE):
    """Synthesize `text` with the given Kokoro voice.

    Args:
        text: Input text to synthesize.
        voice: Kokoro voice id (e.g. 'af_heart'); its first character selects
            the language pipeline.
        speed: Speech-rate multiplier (1.0 = normal).
        use_gpu: Prefer the CUDA model; ignored when CUDA is unavailable.

    Returns:
        (sample_rate, audio): audio is a float numpy array at 24 kHz, or
        (None, None) when the pipeline produced no segments.
    """
    # The original signature read `use_gpu: bool = CUDA CUDA_AVAILABLE`,
    # which is a syntax error; fixed to default to CUDA availability.
    pipeline = pipelines[voice[0]]
    pack = pipeline.load_voice(voice)
    use_gpu = use_gpu and CUDA_AVAILABLE
    # Only the first generated segment is synthesized and returned
    # (original behavior: the loop body ends in `return`).
    for _, ps, _ in pipeline(text, voice, speed):
        # Reference style vector selected by phoneme-sequence length.
        ref_s = pack[len(ps) - 1]
        try:
            audio = models[use_gpu](ps, ref_s, speed)
        except Exception:
            if not use_gpu:
                raise
            # GPU inference failed (e.g. OOM): retry once on CPU.
            audio = models[False](ps, ref_s, speed)
        # .cpu() is a no-op for CPU tensors and required before .numpy()
        # when the audio tensor lives on the GPU.
        return 24000, audio.cpu().numpy()
    # No segments produced. Return (None, None) — the previous (None, '')
    # value never matched the caller's `audio_data is None` failure check.
    return None, None
@app.post("/predict")
async def predict(request: PredictRequest):
start_time = time.time()
sample_rate, audio_data = generate_audio(request.text, request.voice, request.speed, use_gpu=CUDA_AVAILABLE)
if audio_data is None:
return JSONResponse(status_code=400, content={"error": "Failed to generate audio"})
buffer = io.BytesIO()
sf.write(buffer, audio_data, sample_rate, format='WAV')
buffer.seek(0)
audio_base64 = base64.b64encode(buffer.read()).decode("utf-8")
duration = len(audio_data) / sample_rate
generation_time = time.time() - start_time
return {
"audio_base64": audio_base64,
"duration_seconds": round(duration, 2),
"generation_time_seconds": round(generation_time, 2)
}
@app.get("/voices")
async def get_voices():
return [{"name": k, "value": v} for k, v in CHOICES.items()]
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7860)