tts / app.py
waxz
change MODELS env
2d85dce
import os
import uvicorn
import sys
import secrets
import json
import logging
from contextlib import asynccontextmanager
from typing import Optional, Dict
from fastapi import FastAPI, HTTPException, Security, status, Depends
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel
# Import your model engines
import supertonic_model
import kokoro_model
# -----------------------------------------------------------------------------
# Setup Logging
# -----------------------------------------------------------------------------
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(message)s",
handlers=[logging.StreamHandler()]
)
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Configuration
# -----------------------------------------------------------------------------
# Map config names to Model Classes
MODEL_FACTORIES = {
"supertonic": supertonic_model.StreamingEngine,
"kokoro": kokoro_model.StreamingEngine
}
# Global storage for loaded engines
engines: Dict[str, object] = {}
# -----------------------------------------------------------------------------
# Authentication
# -----------------------------------------------------------------------------
security = HTTPBearer()
async def verify_api_key(credentials: HTTPAuthorizationCredentials = Security(security)):
server_key = os.getenv("API_KEY")
if not server_key:
# Warning already logged in lifespan, safe to pass here for dev mode
return True
client_key = credentials.credentials
if not secrets.compare_digest(server_key, client_key):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API Key",
headers={"WWW-Authenticate": "Bearer"},
)
return True
# -----------------------------------------------------------------------------
# Data Models
# -----------------------------------------------------------------------------
class SpeechRequest(BaseModel):
model: Optional[str] = "tts-1"
input: str
voice: str = "alloy"
format: Optional[str] = "mp3" # OpenAI defaults to mp3 usually
speed: Optional[float] = 1.0
# -----------------------------------------------------------------------------
# Lifecycle (Startup/Shutdown)
# -----------------------------------------------------------------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
global engines
# 1. API Key Check
if not os.getenv("API_KEY"):
logger.warning("API_KEY not set. API is open to the public.")
else:
logger.info("Secure Mode: API Key protection enabled.")
# 2. Load Models Configuration
models_env = os.getenv("MODELS")
if not models_env:
logger.error("MODELS environment variable not set. Exiting.")
sys.exit(1)
try:
# SECURITY FIX: Use json.loads instead of eval
models_config = json.loads(models_env)
except json.JSONDecodeError as e:
logger.error(f"Failed to parse MODELS JSON: {e}")
sys.exit(1)
# 3. Initialize Engines
logger.info(f"Loading models configuration: {models_config}")
for model_id, backend_type in models_config.items():
if backend_type not in MODEL_FACTORIES:
logger.error(f"Unknown backend type '{backend_type}' for model '{model_id}'")
continue
try:
logger.info(f"Initializing {model_id} -> {backend_type}...")
engine_class = MODEL_FACTORIES[backend_type]
engines[model_id] = engine_class(f"{model_id}-->{backend_type}")
except Exception as e:
logger.error(f"Failed to load {model_id}: {e}")
# Optional: sys.exit(1) if you want strict startup failure
if not engines:
logger.error("No engines loaded successfully. Exiting.")
sys.exit(1)
yield
# Cleanup (if needed)
engines.clear()
app = FastAPI(lifespan=lifespan, title="Streaming TTS API")
# -----------------------------------------------------------------------------
# Routes
# -----------------------------------------------------------------------------
@app.post("/v1/audio/speech", dependencies=[Depends(verify_api_key)])
async def text_to_speech(request: SpeechRequest):
global engines
if not engines:
raise HTTPException(status_code=500, detail="No TTS engines loaded")
# Validate Model
if request.model not in engines:
valid_models = list(engines.keys())
return JSONResponse(
status_code=404,
content={
"error": {
"message": f"Model '{request.model}' not found. Available: {valid_models}",
"type": "invalid_request_error",
"code": "model_not_found"
}
}
)
# Validate Format
audio_format = request.format if request.format else "mp3"
if audio_format not in ["wav", "mp3"]:
audio_format = "wav" # Fallback
logger.info(f"Generating: model={request.model} voice={request.voice} fmt={audio_format} len={len(request.input)}")
try:
generator = engines[request.model].stream_generator(
request.input,
request.voice,
request.speed,
audio_format
)
return StreamingResponse(
generator,
media_type=f"audio/{audio_format}"
)
except Exception as e:
logger.error(f"Generation failed: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/v1/models", dependencies=[Depends(verify_api_key)])
async def list_models():
"""
Returns the list of currently loaded models dynamically.
"""
model_list = []
for model_id, engine_inst in engines.items():
# Try to get inner name if available, else use backend name
owned_by = getattr(engine_inst, "name", "system")
model_list.append({
"id": model_id,
"object": "model",
"created": 1677610602,
"owned_by": owned_by
})
return {"object": "list", "data": model_list}
# -----------------------------------------------------------------------------
# Entry Point
# -----------------------------------------------------------------------------
if __name__ == "__main__":
# It's better to run uvicorn from CLI, but this supports python app.py
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--host", default="0.0.0.0")
parser.add_argument("--port", type=int, default=8000)
args = parser.parse_args()
uvicorn.run(app, host=args.host, port=args.port)