waxz committed on
Commit
2d85dce
·
1 Parent(s): f20a8ad

change MODELS env

Browse files
Files changed (2) hide show
  1. README.md +1 -1
  2. app.py +140 -86
README.md CHANGED
@@ -31,7 +31,7 @@ uv pip install -r ./requirements.txt
31
 
32
  ```bash
33
  export API_KEY=yourapi
34
- export MODELS="{'tts-2':'supertonic','tts-1':'kokoro'}"
35
  python app.py
36
  ```
37
 
 
31
 
32
  ```bash
33
  export API_KEY=yourapi
34
+ export MODELS='{"tts-1": "kokoro", "tts-2": "supertonic"}'
35
  python app.py
36
  ```
37
 
app.py CHANGED
@@ -1,43 +1,57 @@
1
  import os
2
- import argparse
3
  import uvicorn
4
  import sys
5
  import secrets
6
  import json
 
7
  from contextlib import asynccontextmanager
 
 
8
  from fastapi import FastAPI, HTTPException, Security, status, Depends
9
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
10
- from fastapi.responses import StreamingResponse, Response
11
  from pydantic import BaseModel
12
- from typing import Optional, Literal
13
- import supertonic_model,kokoro_model
14
 
 
 
 
15
 
 
 
 
 
 
 
 
 
 
16
 
17
  # -----------------------------------------------------------------------------
18
- # 1. Authentication Logic
19
  # -----------------------------------------------------------------------------
20
 
21
- # Standard Bearer token scheme (used by OpenAI clients)
 
 
 
 
 
 
 
 
 
 
 
22
  security = HTTPBearer()
23
 
24
  async def verify_api_key(credentials: HTTPAuthorizationCredentials = Security(security)):
25
- """
26
- Verifies that the Bearer token sent by the client matches the API_KEY env var.
27
- """
28
  server_key = os.getenv("API_KEY")
29
 
30
- # If no key is set on the server, we can either:
31
- # A) Block everything (Safe default)
32
- # B) Allow everything (Dev mode)
33
- # Let's Allow everything but print a warning if no key is configured.
34
  if not server_key:
35
- # print("WARNING: No API_KEY set. Allowing unauthenticated request.")
36
  return True
37
 
38
  client_key = credentials.credentials
39
-
40
- # Secure string comparison to prevent timing attacks
41
  if not secrets.compare_digest(server_key, client_key):
42
  raise HTTPException(
43
  status_code=status.HTTP_401_UNAUTHORIZED,
@@ -47,103 +61,143 @@ async def verify_api_key(credentials: HTTPAuthorizationCredentials = Security(se
47
  return True
48
 
49
  # -----------------------------------------------------------------------------
50
- # 2. Text & Audio Utilities
51
- # -----------------------------------------------------------------------------
52
-
53
  # -----------------------------------------------------------------------------
54
- # 2. Streaming Engine with Fallback Logic
55
- # -----------------------------------------------------------------------------
56
-
57
- # -----------------------------------------------------------------------------
58
- # 3. API Setup
59
- # -----------------------------------------------------------------------------
60
-
61
- engine = {}
62
-
63
  class SpeechRequest(BaseModel):
64
  model: Optional[str] = "tts-1"
65
  input: str
66
- voice: str = "alloy" # Default 'alloy'
67
- format: Optional[str] = "wav"
68
  speed: Optional[float] = 1.0
69
 
70
-
 
 
71
  @asynccontextmanager
72
  async def lifespan(app: FastAPI):
73
- global engine
74
- # Check if API Key is set
 
75
  if not os.getenv("API_KEY"):
76
- print("\n!!! WARNING: API_KEY not set. API is open to the public. !!!\n")
77
  else:
78
- print(f"\n*** Secure Mode: API Key protection enabled. ***\n")
 
 
 
 
 
 
79
 
80
- MODELS = None
81
- if not os.getenv("MODELS"):
82
- print(f"\n!!! WARNING: MODELS not set")
83
- sys.exit(0)
84
- else:
85
- MODELS = os.getenv("MODELS")
86
-
87
- print(f"\n!!! WARNING: eval {MODELS}")
88
  try:
89
- MODELS = eval(MODELS)
90
- except:
91
- print(f"\n!!! WARNING: eval {MODELS} failed")
92
- sys.exit(0)
93
-
94
- print(f"\n*** Load {MODELS}. ***\n")
95
- for k,v in MODELS.items():
96
- print(f"Mapping {k}-->{v}")
97
- if "supertonic" == v:
98
- engine[k] = supertonic_model.StreamingEngine(f"{k}-->{v}")
99
- if "kokoro" == v:
100
- engine[k] = kokoro_model.StreamingEngine(f"{k}-->{v}")
101
- yield
 
 
 
 
 
 
 
 
 
 
 
 
102
 
 
 
 
 
103
 
104
- app = FastAPI(lifespan=lifespan)
105
 
 
 
 
106
 
107
- # PROTECTED ROUTE
108
- # The Depends(verify_api_key) enforces auth for this specific endpoint
109
  @app.post("/v1/audio/speech", dependencies=[Depends(verify_api_key)])
110
  async def text_to_speech(request: SpeechRequest):
111
- global engine
112
- if not engine:
113
- raise HTTPException(500, "Engine not loaded")
114
-
115
- print(f"request:{request}")
116
- format = request.format
117
- model = request.model
118
- if format not in ["wav", "mp3"]:
119
- format = "wav"
120
- if model not in engine.keys():
121
- print(f"!!!WARNING {model} not found")
122
-
123
- content = {
124
- "ok": False,
125
- "message": f"!!!WARNING {model} not found"
126
- }
 
 
127
 
128
- content = json.dumps(content)
 
 
 
129
 
130
- return Response(content=content, status_code=404,media_type="application/json")
131
-
132
-
133
 
134
- return StreamingResponse(
135
- engine[model].stream_generator(request.input, request.voice, request.speed, format),
136
- media_type=f"audio/{format}"
137
- )
 
 
 
 
 
 
 
 
 
 
 
138
 
139
- @app.get("/v1/models")
140
  async def list_models():
141
- return {"data": [{"id": "tts-1", "owned_by": "supertonic"}]}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
 
 
 
143
  if __name__ == "__main__":
 
 
144
  parser = argparse.ArgumentParser()
145
  parser.add_argument("--host", default="0.0.0.0")
146
  parser.add_argument("--port", type=int, default=8000)
147
  args = parser.parse_args()
148
 
149
- uvicorn.run(app, host=args.host, port=args.port)
 
1
  import os
 
2
  import uvicorn
3
  import sys
4
  import secrets
5
  import json
6
+ import logging
7
  from contextlib import asynccontextmanager
8
+ from typing import Optional, Dict
9
+
10
  from fastapi import FastAPI, HTTPException, Security, status, Depends
11
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
12
+ from fastapi.responses import StreamingResponse, JSONResponse
13
  from pydantic import BaseModel
 
 
14
 
15
+ # Import your model engines
16
+ import supertonic_model
17
+ import kokoro_model
18
 
19
# -----------------------------------------------------------------------------
# Setup Logging
# -----------------------------------------------------------------------------
# Single stream handler; INFO level so startup/model-loading progress is visible.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[logging.StreamHandler()]
)
logger = logging.getLogger(__name__)

# -----------------------------------------------------------------------------
# Configuration
# -----------------------------------------------------------------------------

# Map backend names (the values of the MODELS env-var JSON) to engine classes.
MODEL_FACTORIES = {
    "supertonic": supertonic_model.StreamingEngine,
    "kokoro": kokoro_model.StreamingEngine
}

# Global storage for loaded engines, keyed by public model id (e.g. "tts-1").
# Populated once at startup inside the lifespan handler.
engines: Dict[str, object] = {}

# -----------------------------------------------------------------------------
# Authentication
# -----------------------------------------------------------------------------
# Standard Bearer token scheme, as used by OpenAI-compatible clients.
security = HTTPBearer()
46
 
47
  async def verify_api_key(credentials: HTTPAuthorizationCredentials = Security(security)):
 
 
 
48
  server_key = os.getenv("API_KEY")
49
 
 
 
 
 
50
  if not server_key:
51
+ # Warning already logged in lifespan, safe to pass here for dev mode
52
  return True
53
 
54
  client_key = credentials.credentials
 
 
55
  if not secrets.compare_digest(server_key, client_key):
56
  raise HTTPException(
57
  status_code=status.HTTP_401_UNAUTHORIZED,
 
61
  return True
62
 
63
  # -----------------------------------------------------------------------------
64
+ # Data Models
 
 
65
  # -----------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
66
class SpeechRequest(BaseModel):
    # Request body for POST /v1/audio/speech, mirroring the OpenAI TTS API shape.
    model: Optional[str] = "tts-1"   # logical model id; must match a key of the MODELS env mapping
    input: str                       # text to synthesize (required)
    voice: str = "alloy"             # voice name, passed through to the engine
    format: Optional[str] = "mp3"    # "wav" or "mp3"; OpenAI defaults to mp3 usually
    speed: Optional[float] = 1.0     # playback speed multiplier
72
 
73
# -----------------------------------------------------------------------------
# Lifecycle (Startup/Shutdown)
# -----------------------------------------------------------------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Startup: check auth configuration, parse the MODELS env var (a JSON
    object mapping model id -> backend name) and instantiate one engine per
    entry. Shutdown: drop all engine references.

    Exits the process with status 1 when MODELS is missing, unparseable,
    or when not a single engine could be loaded.
    """
    global engines

    # 1. API key check — absence means the API runs wide open (dev mode).
    if os.getenv("API_KEY"):
        logger.info("Secure Mode: API Key protection enabled.")
    else:
        logger.warning("API_KEY not set. API is open to the public.")

    # 2. Read and parse the models configuration.
    models_env = os.getenv("MODELS")
    if not models_env:
        logger.error("MODELS environment variable not set. Exiting.")
        sys.exit(1)

    try:
        # json.loads (never eval) — the env var is untrusted input.
        models_config = json.loads(models_env)
    except json.JSONDecodeError as e:
        logger.error(f"Failed to parse MODELS JSON: {e}")
        sys.exit(1)

    # 3. Instantiate one engine per configured model id.
    logger.info(f"Loading models configuration: {models_config}")

    for model_id, backend_type in models_config.items():
        factory = MODEL_FACTORIES.get(backend_type)
        if factory is None:
            logger.error(f"Unknown backend type '{backend_type}' for model '{model_id}'")
            continue

        try:
            logger.info(f"Initializing {model_id} -> {backend_type}...")
            engines[model_id] = factory(f"{model_id}-->{backend_type}")
        except Exception as e:
            logger.error(f"Failed to load {model_id}: {e}")
            # Optional: sys.exit(1) if you want strict startup failure

    if not engines:
        logger.error("No engines loaded successfully. Exiting.")
        sys.exit(1)

    yield

    # Cleanup (if needed)
    engines.clear()
123
 
124
# The lifespan handler loads/unloads the TTS engines around the app's lifetime.
app = FastAPI(lifespan=lifespan, title="Streaming TTS API")

# -----------------------------------------------------------------------------
# Routes
# -----------------------------------------------------------------------------
129
 
 
 
130
@app.post("/v1/audio/speech", dependencies=[Depends(verify_api_key)])
async def text_to_speech(request: SpeechRequest):
    """
    OpenAI-compatible speech-synthesis endpoint.

    Streams generated audio for ``request.input`` using the engine mapped to
    ``request.model``. Returns a 404 with an OpenAI-style error body for an
    unknown model, and a 500 when generation fails.
    """
    if not engines:
        raise HTTPException(status_code=500, detail="No TTS engines loaded")

    # Validate model id against the engines loaded at startup.
    if request.model not in engines:
        valid_models = list(engines.keys())
        return JSONResponse(
            status_code=404,
            content={
                "error": {
                    "message": f"Model '{request.model}' not found. Available: {valid_models}",
                    "type": "invalid_request_error",
                    "code": "model_not_found"
                }
            }
        )

    # Validate format; anything unsupported falls back to wav.
    audio_format = request.format if request.format else "mp3"
    if audio_format not in ("wav", "mp3"):
        audio_format = "wav"  # Fallback

    # Lazy %-style args: no formatting cost when INFO logging is disabled.
    logger.info(
        "Generating: model=%s voice=%s fmt=%s len=%d",
        request.model, request.voice, audio_format, len(request.input),
    )

    try:
        generator = engines[request.model].stream_generator(
            request.input,
            request.voice,
            request.speed,
            audio_format
        )

        # FIX: the registered MIME type for MP3 is "audio/mpeg", not
        # "audio/mp3" (RFC 3003); "audio/wav" stays as-is.
        media_type = "audio/mpeg" if audio_format == "mp3" else f"audio/{audio_format}"
        return StreamingResponse(generator, media_type=media_type)
    except Exception as e:
        # logger.exception keeps the traceback, unlike plain logger.error.
        logger.exception("Generation failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e))
173
 
174
@app.get("/v1/models", dependencies=[Depends(verify_api_key)])
async def list_models():
    """
    Returns the list of currently loaded models dynamically.
    """
    # Build the OpenAI-style model list straight from the loaded engines;
    # "owned_by" falls back to "system" when the engine exposes no name.
    return {
        "object": "list",
        "data": [
            {
                "id": model_id,
                "object": "model",
                "created": 1677610602,
                "owned_by": getattr(engine_inst, "name", "system"),
            }
            for model_id, engine_inst in engines.items()
        ],
    }
191
 
192
# -----------------------------------------------------------------------------
# Entry Point
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    # Prefer launching via the uvicorn CLI; this path supports `python app.py`.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("--host", default="0.0.0.0")
    cli.add_argument("--port", type=int, default=8000)
    options = cli.parse_args()

    uvicorn.run(app, host=options.host, port=options.port)