File size: 6,874 Bytes
1574efa
 
 
54fb220
f20a8ad
2d85dce
1574efa
2d85dce
 
54fb220
 
2d85dce
1574efa
54fb220
2d85dce
 
 
54fb220
2d85dce
 
 
 
 
 
 
 
 
54fb220
1574efa
2d85dce
54fb220
 
2d85dce
 
 
 
 
 
 
 
 
 
 
 
54fb220
 
 
 
 
 
2d85dce
54fb220
 
 
 
 
 
 
 
 
 
 
 
2d85dce
1574efa
 
 
 
2d85dce
 
1574efa
 
2d85dce
 
 
1574efa
 
2d85dce
 
 
54fb220
2d85dce
54fb220
2d85dce
 
 
 
 
 
 
f20a8ad
 
2d85dce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54fb220
2d85dce
 
 
 
1574efa
2d85dce
1574efa
2d85dce
 
 
54fb220
 
1574efa
2d85dce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f20a8ad
2d85dce
 
 
 
f20a8ad
2d85dce
f20a8ad
2d85dce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1574efa
2d85dce
1574efa
2d85dce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1574efa
2d85dce
 
 
1574efa
2d85dce
 
1574efa
 
 
 
 
2d85dce
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
import os
import uvicorn
import sys
import secrets
import json
import logging
from contextlib import asynccontextmanager
from typing import Optional, Dict

from fastapi import FastAPI, HTTPException, Security, status, Depends
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel

# Import your model engines
import supertonic_model
import kokoro_model

# -----------------------------------------------------------------------------
# Setup Logging
# -----------------------------------------------------------------------------
# Root logger configuration: timestamped, level-tagged lines to stderr.
LOG_FORMAT = "%(asctime)s [%(levelname)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT, handlers=[logging.StreamHandler()])

# Module-level logger, per stdlib convention.
logger = logging.getLogger(__name__)

# -----------------------------------------------------------------------------
# Configuration
# -----------------------------------------------------------------------------

# Registry mapping backend-type names (values in the MODELS env JSON) to the
# engine classes that implement them. Both classes expose a StreamingEngine
# with a stream_generator(...) method (see text_to_speech).
MODEL_FACTORIES = {
    "supertonic": supertonic_model.StreamingEngine,
    "kokoro": kokoro_model.StreamingEngine
}

# Loaded engine instances keyed by public model id; populated in lifespan()
# at startup and cleared on shutdown.
engines: Dict[str, object] = {}

# -----------------------------------------------------------------------------
# Authentication
# -----------------------------------------------------------------------------
# Bearer-token extractor used by verify_api_key.
# NOTE(review): HTTPBearer() defaults to auto_error=True, which presumably
# rejects requests lacking an Authorization header even when API_KEY is
# unset ("open" mode) — confirm this is the intended behavior.
security = HTTPBearer()

async def verify_api_key(credentials: HTTPAuthorizationCredentials = Security(security)):
    """Validate the request's Bearer token against the API_KEY env var.

    Returns True when the key matches, or when no API_KEY is configured
    (open/dev mode). Raises HTTP 401 with a WWW-Authenticate header on
    mismatch.
    """
    server_key = os.getenv("API_KEY")

    if not server_key:
        # No key configured: dev mode, allow the request.
        # (The open-access warning is logged once at startup in lifespan.)
        return True

    client_key = credentials.credentials
    # FIX: compare bytes, not str — secrets.compare_digest raises TypeError
    # for str arguments containing non-ASCII characters, which would turn a
    # malformed client key into a 500. Byte comparison is still constant-time.
    if not secrets.compare_digest(server_key.encode("utf-8"), client_key.encode("utf-8")):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API Key",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return True

# -----------------------------------------------------------------------------
# Data Models
# -----------------------------------------------------------------------------
class SpeechRequest(BaseModel):
    """Request body for POST /v1/audio/speech (OpenAI-compatible shape)."""
    # Model id; must match a key in the loaded `engines` registry.
    model: Optional[str] = "tts-1"
    # Text to synthesize (required field).
    input: str
    # Voice name, passed through unchanged to the engine backend.
    voice: str = "alloy"
    format: Optional[str] = "mp3" # OpenAI defaults to mp3 usually
    # Playback speed multiplier, forwarded to stream_generator.
    speed: Optional[float] = 1.0

# -----------------------------------------------------------------------------
# Lifecycle (Startup/Shutdown)
# -----------------------------------------------------------------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application startup/shutdown hook.

    Startup: logs the auth mode, parses the MODELS env var (a JSON object
    mapping model id -> backend type), and instantiates one engine per
    entry into the module-level `engines` dict. Exits the process when the
    configuration is unusable or no engine loads.
    Shutdown: clears the engine registry.
    """
    # 1. API key check (enforcement itself happens per-request in verify_api_key).
    if not os.getenv("API_KEY"):
        logger.warning("API_KEY not set. API is open to the public.")
    else:
        logger.info("Secure Mode: API Key protection enabled.")

    # 2. Load models configuration.
    models_env = os.getenv("MODELS")
    if not models_env:
        logger.error("MODELS environment variable not set. Exiting.")
        sys.exit(1)

    try:
        # SECURITY: json.loads only — never eval() env-supplied content.
        models_config = json.loads(models_env)
    except json.JSONDecodeError as e:
        logger.error(f"Failed to parse MODELS JSON: {e}")
        sys.exit(1)

    # FIX: valid JSON is not necessarily an object (could be a list, string,
    # number...); .items() below would raise an uncaught AttributeError.
    if not isinstance(models_config, dict):
        logger.error("MODELS must be a JSON object mapping model id -> backend type. Exiting.")
        sys.exit(1)

    # 3. Initialize engines.
    logger.info(f"Loading models configuration: {models_config}")

    for model_id, backend_type in models_config.items():
        if backend_type not in MODEL_FACTORIES:
            logger.error(f"Unknown backend type '{backend_type}' for model '{model_id}'")
            continue

        try:
            logger.info(f"Initializing {model_id} -> {backend_type}...")
            engine_class = MODEL_FACTORIES[backend_type]
            engines[model_id] = engine_class(f"{model_id}-->{backend_type}")
        except Exception as e:
            logger.error(f"Failed to load {model_id}: {e}")
            # Optional: sys.exit(1) if you want strict startup failure

    if not engines:
        logger.error("No engines loaded successfully. Exiting.")
        sys.exit(1)

    yield

    # Shutdown: drop engine references so backends can be garbage-collected.
    # (`global` is unnecessary: `engines` is only mutated, never rebound.)
    engines.clear()

# FastAPI application; engine loading/teardown is driven by the lifespan hook.
app = FastAPI(lifespan=lifespan, title="Streaming TTS API")

# -----------------------------------------------------------------------------
# Routes
# -----------------------------------------------------------------------------

@app.post("/v1/audio/speech", dependencies=[Depends(verify_api_key)])
async def text_to_speech(request: SpeechRequest):
    """OpenAI-compatible speech-synthesis endpoint.

    Streams audio generated for request.input by the engine registered
    under request.model. Unknown models yield an OpenAI-style 404 error
    body; unsupported formats silently fall back to wav; engine failures
    become HTTP 500.
    """
    if not engines:
        raise HTTPException(status_code=500, detail="No TTS engines loaded")

    # Validate model id against the loaded engines.
    if request.model not in engines:
        valid_models = list(engines.keys())
        return JSONResponse(
            status_code=404,
            content={
                "error": {
                    "message": f"Model '{request.model}' not found. Available: {valid_models}",
                    "type": "invalid_request_error",
                    "code": "model_not_found"
                }
            }
        )

    # Validate format; anything unsupported falls back to wav.
    audio_format = request.format if request.format else "mp3"
    if audio_format not in ["wav", "mp3"]:
        audio_format = "wav" # Fallback

    logger.info(f"Generating: model={request.model} voice={request.voice} fmt={audio_format} len={len(request.input)}")

    try:
        generator = engines[request.model].stream_generator(
            request.input, 
            request.voice, 
            request.speed, 
            audio_format
        )

        # FIX: the registered MIME type for MP3 streams is audio/mpeg;
        # f"audio/{audio_format}" produced the non-standard "audio/mp3".
        media_type = "audio/mpeg" if audio_format == "mp3" else "audio/wav"
        return StreamingResponse(generator, media_type=media_type)
    except Exception as e:
        logger.error(f"Generation failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/v1/models", dependencies=[Depends(verify_api_key)])
async def list_models():
    """Expose the currently loaded engines in OpenAI's /v1/models shape."""
    data = [
        {
            "id": model_id,
            "object": "model",
            "created": 1677610602,
            # Surface the engine's own name when it exposes one.
            "owned_by": getattr(engine, "name", "system"),
        }
        for model_id, engine in engines.items()
    ]
    return {"object": "list", "data": data}

# -----------------------------------------------------------------------------
# Entry Point
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    # Prefer running uvicorn from the CLI in production; this path supports
    # `python app.py` for convenience.
    import argparse
    parser = argparse.ArgumentParser(description="Streaming TTS API server")
    # Generalized: HOST/PORT env vars override the built-in defaults, while
    # --host/--port still win over both. Defaults are unchanged when the env
    # vars are unset. NOTE: 0.0.0.0 binds all interfaces; use 127.0.0.1 for
    # local-only development.
    parser.add_argument("--host", default=os.getenv("HOST", "0.0.0.0"))
    parser.add_argument("--port", type=int, default=int(os.getenv("PORT", "8000")))
    args = parser.parse_args()

    uvicorn.run(app, host=args.host, port=args.port)