# ttsgenie-dev / app.py
# Uploaded by 22333Misaka - "Upload 5 files" (commit a287c67, verified)
"""
Genie-TTS OpenAI Compatible API Server
======================================
This server provides an OpenAI-compatible TTS API endpoint (/v1/audio/speech)
for the Genie-TTS engine.
Usage:
POST /v1/audio/speech
{
"model": "liang", # Voice model name
"input": "要合成的文本", # Text to synthesize
"voice": "alloy", # Ignored - for OpenAI compatibility
"response_format": "wav", # Only wav is supported
"speed": 1.0 # Ignored - for OpenAI compatibility
}
"""
import os
import sys
import io
import wave
import json
import logging
import asyncio
from pathlib import Path
from typing import Optional, Dict, Any, Union
from contextlib import asynccontextmanager
import numpy as np
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import Response, StreamingResponse, JSONResponse
from pydantic import BaseModel, Field
# Process-wide logging: timestamped records at INFO level and above.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Root directory scanned for voice models; override with the MODELS_DIR env var.
MODELS_DIR = Path(os.getenv("MODELS_DIR", "/app/models"))

# Registry of loaded voices, keyed by voice (directory) name.
# Filled in at startup by initialize_genie().
VOICES: Dict[str, Dict[str, Any]] = {}

# PCM format written into the WAV header: 32 kHz, mono, 16-bit samples.
SAMPLE_RATE, CHANNELS, BYTES_PER_SAMPLE = 32000, 1, 2
class SpeechRequest(BaseModel):
    """Request body of ``POST /v1/audio/speech``, shaped like OpenAI's TTS request.

    Only ``model`` and ``input`` are required. ``voice``, ``response_format``
    and ``speed`` exist so that OpenAI clients can send their usual payload;
    the speech endpoint in this module does not read them.
    """

    # Required: name of a loaded voice, and the text to speak.
    model: str = Field(..., description="The voice model to use")
    input: str = Field(..., description="The text to synthesize")
    # Optional OpenAI parameters, accepted but unused.
    voice: Optional[str] = Field("alloy", description="Ignored - for OpenAI compatibility")
    response_format: Optional[str] = Field("wav", description="Only wav is supported")
    speed: Optional[float] = Field(1.0, description="Ignored - for OpenAI compatibility")
class ErrorResponse(BaseModel):
    """Error envelope in the OpenAI style: ``{"error": {...}}``."""

    # Inner error object; the handlers in this module use the keys
    # "message", "type" and "code".
    error: Dict[str, Any]
def load_voice_config(voice_dir: Path) -> Optional[Dict[str, Any]]:
    """Load and validate the voice configuration stored in ``voice_dir``.

    A usable voice directory contains ``config.json`` (a JSON object with at
    least ``reference_audio``, ``reference_text`` and ``language``) and an
    ``onnx`` sub-directory holding the ONNX models.

    Args:
        voice_dir: Directory of a single voice model.

    Returns:
        The parsed config with ``onnx_dir`` and ``voice_dir`` (both ``str``)
        added, or ``None`` if the directory is not a usable voice. Problems
        are logged instead of raised, so one broken voice does not stop the
        others from loading.
    """
    config_path = voice_dir / "config.json"
    if not config_path.exists():
        logger.warning(f"Config file not found: {config_path}")
        return None

    try:
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        logger.error(f"Failed to load config from {config_path}: {e}")
        return None

    # A top-level list/string/number is not a valid config; reject it
    # explicitly instead of relying on the membership tests below.
    if not isinstance(config, dict):
        logger.error(f"Config in {config_path} must be a JSON object")
        return None

    for field in ("reference_audio", "reference_text", "language"):
        if field not in config:
            logger.error(f"Missing required field '{field}' in {config_path}")
            return None

    onnx_dir = voice_dir / "onnx"
    # is_dir() rather than exists(): a stray file named "onnx" is not a model directory.
    if not onnx_dir.is_dir():
        logger.error(f"ONNX model directory not found: {onnx_dir}")
        return None

    config["onnx_dir"] = str(onnx_dir)
    config["voice_dir"] = str(voice_dir)
    return config
def discover_voices() -> Dict[str, Dict[str, Any]]:
    """Scan ``MODELS_DIR`` and return the configs of all usable voices.

    Each sub-directory of ``MODELS_DIR`` is one voice, named after the
    directory. Directories rejected by ``load_voice_config`` are skipped.

    Returns:
        Mapping of voice name to config, in sorted directory-name order.
        It is empty if ``MODELS_DIR`` is missing or is not a directory.
    """
    voices: Dict[str, Dict[str, Any]] = {}

    # is_dir() rather than exists(): iterdir() on a regular file would raise.
    if not MODELS_DIR.is_dir():
        logger.warning(f"Models directory not found: {MODELS_DIR}")
        return voices

    # iterdir() order depends on the filesystem; sort so that the advertised
    # model list and the load order are stable between runs.
    for voice_dir in sorted(MODELS_DIR.iterdir()):
        if not voice_dir.is_dir():
            continue
        voice_name = voice_dir.name
        config = load_voice_config(voice_dir)
        if config:
            voices[voice_name] = config
            logger.info(f"Loaded voice: {voice_name} (language: {config.get('language', 'unknown')})")
    return voices
def initialize_genie():
    """Initialize the Genie-TTS engine and load every discovered voice.

    Rebinds the module-level ``VOICES`` registry from ``discover_voices()``.
    Each voice is then registered with genie_tts (model and reference audio).
    Voices that fail are logged and removed from ``VOICES``, so only usable
    models are advertised by the API.

    Raises:
        ImportError: If the ``genie_tts`` package cannot be imported.
    """
    global VOICES
    logger.info("Initializing Genie-TTS engine...")

    try:
        import genie_tts as genie
    except ImportError as e:
        logger.error(f"Failed to import genie_tts: {e}")
        raise

    # Download Genie data if needed.
    logger.info("Checking Genie data...")
    genie.download_genie_data()

    VOICES = discover_voices()
    if not VOICES:
        logger.warning("No voice models found!")
        return

    # Iterate over a snapshot: failed voices are deleted from VOICES inside the
    # loop, and mutating a dict while iterating its live view raises
    # RuntimeError ("dictionary changed size during iteration").
    for voice_name, config in list(VOICES.items()):
        try:
            logger.info(f"Loading voice model: {voice_name}")
            genie.load_character(
                character_name=voice_name,
                onnx_model_dir=config["onnx_dir"],
                language=config["language"],
            )

            # reference_audio is stored relative to the voice directory.
            ref_audio_path = os.path.join(config["voice_dir"], config["reference_audio"])
            genie.set_reference_audio(
                character_name=voice_name,
                audio_path=ref_audio_path,
                audio_text=config["reference_text"],
                language=config["language"],
            )
            logger.info(f"Voice model loaded successfully: {voice_name}")
        except Exception as e:
            # Best effort per voice: log with traceback, drop it, keep going.
            logger.exception(f"Failed to load voice model {voice_name}: {e}")
            del VOICES[voice_name]

    logger.info(f"Genie-TTS initialized with {len(VOICES)} voice(s)")
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: load the voices at startup, log at shutdown."""
    # -- startup: blocking model loading, done once --
    initialize_genie()

    # -- the application serves requests while suspended here --
    yield

    # -- shutdown --
    logger.info("Shutting down Genie-TTS server...")
# The ASGI application; `lifespan` loads the voice models before serving.
app = FastAPI(
    lifespan=lifespan,
    title="Genie-TTS OpenAI Compatible API",
    version="1.0.0",
    description="OpenAI-compatible Text-to-Speech API powered by Genie-TTS",
)
@app.get("/")
async def root():
"""Root endpoint - health check."""
return {
"status": "healthy",
"service": "Genie-TTS OpenAI Compatible API",
"available_models": list(VOICES.keys())
}
@app.get("/health")
async def health():
"""Health check endpoint."""
return {
"status": "healthy",
"models_loaded": len(VOICES),
"available_models": list(VOICES.keys())
}
@app.get("/v1/models")
async def list_models():
"""List available models (OpenAI-compatible)."""
import time
models = []
for voice_name in VOICES.keys():
models.append({
"id": voice_name,
"object": "model",
"created": int(time.time()),
"owned_by": "genie-tts"
})
return {
"object": "list",
"data": models
}
def generate_wav_header(data_size: int) -> bytes:
    """Return the 44-byte RIFF/WAVE PCM header for ``data_size`` bytes of audio.

    The format fields come from SAMPLE_RATE, CHANNELS and BYTES_PER_SAMPLE.
    Every integer field is written as unsigned little-endian, so a size that
    is negative or too large for its field raises OverflowError.
    """
    def u32(value: int) -> bytes:
        return value.to_bytes(4, 'little')

    def u16(value: int) -> bytes:
        return value.to_bytes(2, 'little')

    block_align = CHANNELS * BYTES_PER_SAMPLE
    byte_rate = SAMPLE_RATE * block_align
    parts = [
        # RIFF chunk: size field is the total file size minus 8.
        b'RIFF', u32(data_size + 36), b'WAVE',
        # fmt chunk: 16-byte body, audio format 1 (PCM).
        b'fmt ', u32(16), u16(1),
        u16(CHANNELS),
        u32(SAMPLE_RATE),
        u32(byte_rate),
        u16(block_align),
        u16(BYTES_PER_SAMPLE * 8),  # bits per sample
        # data chunk header; the samples themselves follow the header.
        b'data', u32(data_size),
    ]
    return b''.join(parts)
@app.post("/v1/audio/speech")
async def create_speech(request: SpeechRequest):
"""
Generate speech from text (OpenAI-compatible endpoint).
This endpoint is compatible with the OpenAI TTS API format.
Only the 'model' and 'input' parameters are used.
"""
import genie_tts as genie
# Validate model
if request.model not in VOICES:
return JSONResponse(
status_code=404,
content={
"error": {
"message": f"Model '{request.model}' not found. Available models: {list(VOICES.keys())}",
"type": "invalid_request_error",
"code": "model_not_found"
}
}
)
# Validate input
if not request.input or not request.input.strip():
return JSONResponse(
status_code=400,
content={
"error": {
"message": "Input text cannot be empty",
"type": "invalid_request_error",
"code": "invalid_input"
}
}
)
try:
# Collect audio chunks
audio_chunks = []
async for chunk in genie.tts_async(
character_name=request.model,
text=request.input.strip(),
play=False,
split_sentence=True
):
audio_chunks.append(chunk)
if not audio_chunks:
return JSONResponse(
status_code=500,
content={
"error": {
"message": "Failed to generate audio",
"type": "server_error",
"code": "generation_failed"
}
}
)
# Combine all chunks
audio_data = b''.join(audio_chunks)
# Generate complete WAV file
wav_header = generate_wav_header(len(audio_data))
wav_content = wav_header + audio_data
return Response(
content=wav_content,
media_type="audio/wav",
headers={
"Content-Disposition": "attachment; filename=speech.wav"
}
)
except Exception as e:
logger.error(f"TTS generation failed: {e}", exc_info=True)
return JSONResponse(
status_code=500,
content={
"error": {
"message": f"TTS generation failed: {str(e)}",
"type": "server_error",
"code": "generation_failed"
}
}
)
# Error handlers: render framework-raised errors in the OpenAI error envelope.
@app.exception_handler(404)
async def not_found_handler(request: Request, exc: HTTPException):
    """Answer every 404 (for example an unknown route) with a fixed OpenAI-style error."""
    payload = {
        "error": {
            "message": "Not found",
            "type": "invalid_request_error",
            "code": "not_found",
        }
    }
    return JSONResponse(content=payload, status_code=404)
@app.exception_handler(500)
async def internal_error_handler(request: Request, exc: Exception):
    """Answer unhandled server errors with a generic OpenAI-style 500 body."""
    payload = {
        "error": {
            "message": "Internal server error",
            "type": "server_error",
            "code": "internal_error",
        }
    }
    return JSONResponse(content=payload, status_code=500)
if __name__ == "__main__":
import uvicorn
port = int(os.environ.get("PORT", 7860))
uvicorn.run(app, host="0.0.0.0", port=port)