File size: 6,874 Bytes
1574efa 54fb220 f20a8ad 2d85dce 1574efa 2d85dce 54fb220 2d85dce 1574efa 54fb220 2d85dce 54fb220 2d85dce 54fb220 1574efa 2d85dce 54fb220 2d85dce 54fb220 2d85dce 54fb220 2d85dce 1574efa 2d85dce 1574efa 2d85dce 1574efa 2d85dce 54fb220 2d85dce 54fb220 2d85dce f20a8ad 2d85dce 54fb220 2d85dce 1574efa 2d85dce 1574efa 2d85dce 54fb220 1574efa 2d85dce f20a8ad 2d85dce f20a8ad 2d85dce f20a8ad 2d85dce 1574efa 2d85dce 1574efa 2d85dce 1574efa 2d85dce 1574efa 2d85dce 1574efa 2d85dce |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 |
import os
import uvicorn
import sys
import secrets
import json
import logging
from contextlib import asynccontextmanager
from typing import Optional, Dict
from fastapi import FastAPI, HTTPException, Security, status, Depends
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel
# Import your model engines
import supertonic_model
import kokoro_model
# -----------------------------------------------------------------------------
# Setup Logging
# -----------------------------------------------------------------------------
# Root logging: INFO and above, timestamped, emitted through a stream handler.
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
# Module-level logger used by the startup code and the route handlers.
logger = logging.getLogger(name=__name__)
# -----------------------------------------------------------------------------
# Configuration
# -----------------------------------------------------------------------------
# Backend name (the values of the MODELS configuration) -> engine class.
MODEL_FACTORIES = dict(
    supertonic=supertonic_model.StreamingEngine,
    kokoro=kokoro_model.StreamingEngine,
)
# Loaded engine instances keyed by model id; filled in during startup.
engines: Dict[str, object] = dict()
# -----------------------------------------------------------------------------
# Authentication
# -----------------------------------------------------------------------------
# auto_error=False lets requests without an Authorization header reach
# verify_api_key, so the "no API_KEY configured" dev mode is genuinely open
# instead of being rejected by HTTPBearer before the check below runs.
security = HTTPBearer(auto_error=False)


async def verify_api_key(
    credentials: Optional[HTTPAuthorizationCredentials] = Security(security),
) -> bool:
    """Check the request's bearer token against the API_KEY environment variable.

    Args:
        credentials: Bearer credentials extracted from the Authorization
            header, or None when the header is absent or not a Bearer scheme.

    Returns:
        True when the request may proceed: either no API_KEY is configured,
        or the supplied token matches it.

    Raises:
        HTTPException: 401 when API_KEY is set and the token is missing or
            does not match.
    """
    server_key = os.getenv("API_KEY")
    if not server_key:
        # No key configured: the API is intentionally open (dev mode).
        return True

    # server_key is non-empty here, so an empty client key can never match.
    client_key = credentials.credentials if credentials is not None else ""

    # Compare as bytes: compare_digest raises TypeError for non-ASCII str
    # arguments, which would turn a malformed token into a 500 response.
    keys_match = secrets.compare_digest(
        server_key.encode("utf-8"),
        client_key.encode("utf-8"),
    )
    if not keys_match:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API Key",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return True
# -----------------------------------------------------------------------------
# Data Models
# -----------------------------------------------------------------------------
class SpeechRequest(BaseModel):
    """Request body accepted by the POST /v1/audio/speech route."""

    # Model id; the speech route looks it up among the loaded engines and
    # answers 404 when it is not present.
    model: Optional[str] = "tts-1"
    # Input text handed to the engine; the only field without a default.
    input: str
    # Voice identifier, passed through unchanged to the engine.
    voice: str = "alloy"
    # Requested audio format. The speech route treats an empty/None value as
    # "mp3" and replaces anything other than "wav"/"mp3" with "wav".
    format: Optional[str] = "mp3" # OpenAI defaults to mp3 usually
    # Speed value, passed through unchanged to the engine.
    speed: Optional[float] = 1.0
# -----------------------------------------------------------------------------
# Lifecycle (Startup/Shutdown)
# -----------------------------------------------------------------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the configured TTS engines at startup and drop them at shutdown.

    The MODELS environment variable must hold a JSON object that maps each
    model id to a backend name registered in MODEL_FACTORIES, for example
    {"tts-1": "kokoro"}. Every entry whose backend is known is instantiated
    and stored in the module-level ``engines`` dict. Entries with an unknown
    backend, or whose engine constructor raises, are logged and skipped.

    The process exits with status 1 when MODELS is unset, is not valid JSON,
    is not a JSON object, or when no engine could be loaded.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).
    """
    # 1. API Key Check
    if os.getenv("API_KEY"):
        logger.info("Secure Mode: API Key protection enabled.")
    else:
        logger.warning("API_KEY not set. API is open to the public.")

    # 2. Load Models Configuration
    models_env = os.getenv("MODELS")
    if not models_env:
        logger.error("MODELS environment variable not set. Exiting.")
        sys.exit(1)
    try:
        # json.loads (never eval) because the value comes from the environment.
        models_config = json.loads(models_env)
    except json.JSONDecodeError as e:
        logger.error(f"Failed to parse MODELS JSON: {e}")
        sys.exit(1)
    if not isinstance(models_config, dict):
        # Valid JSON that is not an object (list, string, number) has no
        # .items(); fail with a clear message instead of an AttributeError.
        logger.error(
            "MODELS must be a JSON object mapping model ids to backend names. Exiting."
        )
        sys.exit(1)

    # 3. Initialize Engines
    logger.info(f"Loading models configuration: {models_config}")
    for model_id, backend_type in models_config.items():
        # Only string backend names can be registered; checking the type first
        # also avoids a TypeError on unhashable values such as lists.
        if not isinstance(backend_type, str) or backend_type not in MODEL_FACTORIES:
            logger.error(f"Unknown backend type '{backend_type}' for model '{model_id}'")
            continue
        try:
            logger.info(f"Initializing {model_id} -> {backend_type}...")
            engine_class = MODEL_FACTORIES[backend_type]
            engines[model_id] = engine_class(f"{model_id}-->{backend_type}")
        except Exception as e:
            # Best effort: one broken engine must not prevent the others from
            # loading. logger.exception keeps the traceback for diagnosis.
            logger.exception(f"Failed to load {model_id}: {e}")

    if not engines:
        logger.error("No engines loaded successfully. Exiting.")
        sys.exit(1)

    try:
        yield
    finally:
        # Release engine references even if shutdown is triggered by an error.
        engines.clear()
# Application instance; startup and shutdown are driven by `lifespan`.
app = FastAPI(title="Streaming TTS API", lifespan=lifespan)
# -----------------------------------------------------------------------------
# Routes
# -----------------------------------------------------------------------------
# Response media type per supported output format. The registered media type
# for MP3 audio is audio/mpeg; "audio/mp3" is not a registered type.
_AUDIO_MEDIA_TYPES = {"wav": "audio/wav", "mp3": "audio/mpeg"}


@app.post("/v1/audio/speech", dependencies=[Depends(verify_api_key)])
async def text_to_speech(request: SpeechRequest):
    """Stream synthesized speech for the requested model, voice and format.

    Args:
        request: Parsed request body (model id, input text, voice, format,
            speed).

    Returns:
        A StreamingResponse wrapping the engine's stream generator, or a 404
        JSONResponse with a "model_not_found" error when the requested model
        is not loaded.

    Raises:
        HTTPException: 500 when no engines are loaded, or when creating the
            stream generator fails.
    """
    if not engines:
        raise HTTPException(status_code=500, detail="No TTS engines loaded")

    # Validate Model
    if request.model not in engines:
        valid_models = list(engines)
        return JSONResponse(
            status_code=404,
            content={
                "error": {
                    "message": f"Model '{request.model}' not found. Available: {valid_models}",
                    "type": "invalid_request_error",
                    "code": "model_not_found",
                }
            },
        )

    # Validate Format: missing/empty means mp3, anything unsupported means wav.
    audio_format = request.format or "mp3"
    if audio_format not in _AUDIO_MEDIA_TYPES:
        audio_format = "wav"  # Fallback

    logger.info(f"Generating: model={request.model} voice={request.voice} fmt={audio_format} len={len(request.input)}")
    try:
        generator = engines[request.model].stream_generator(
            request.input,
            request.voice,
            request.speed,
            audio_format,
        )
        # NOTE(review): if stream_generator is lazy, errors raised while the
        # response is being streamed occur after this handler has returned
        # and are not covered by the except clause below.
        return StreamingResponse(
            generator,
            media_type=_AUDIO_MEDIA_TYPES[audio_format],
        )
    except Exception as e:
        # Keep the traceback in the log and chain the cause for debugging.
        logger.exception(f"Generation failed: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/v1/models", dependencies=[Depends(verify_api_key)])
async def list_models():
    """
    Report every currently loaded engine as one entry of a "list" payload.
    """
    # "owned_by" is the engine's `name` attribute when it has one,
    # otherwise the literal "system".
    entries = [
        {
            "id": loaded_id,
            "object": "model",
            "created": 1677610602,
            "owned_by": getattr(loaded_engine, "name", "system"),
        }
        for loaded_id, loaded_engine in engines.items()
    ]
    return {"object": "list", "data": entries}
# -----------------------------------------------------------------------------
# Entry Point
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    # Launching through the uvicorn CLI is preferable; this branch keeps
    # `python app.py` working too.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("--host", default="0.0.0.0")
    cli.add_argument("--port", type=int, default=8000)
    options = cli.parse_args()
    uvicorn.run(app, host=options.host, port=options.port)