ataberkkilavuzcu committed on
Commit
7c7c63c
·
verified ·
1 Parent(s): f9f777d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +269 -82
app.py CHANGED
@@ -1,7 +1,10 @@
1
  import base64
 
2
  import os
3
  import tempfile
4
  import uuid
 
 
5
  from pathlib import Path
6
  from threading import Lock
7
  from typing import Dict, Optional
@@ -14,85 +17,158 @@ from fastapi import BackgroundTasks, Body, FastAPI, Header, HTTPException
14
  from fastapi.responses import FileResponse, JSONResponse
15
  from pydantic import BaseModel, Field, HttpUrl
16
 
 
 
 
17
  SPACE_API_KEY = os.getenv("SPACE_API_KEY")
18
  HF_TOKEN = (
19
  os.getenv("HUGGING_FACE_HUB_TOKEN")
20
  or os.getenv("HUGGINGFACEHUB_API_TOKEN")
21
  or os.getenv("HF_TOKEN")
22
  )
23
- MODEL_REPO = "IndexTeam/IndexTTS-2"
24
- MAX_TEXT_LENGTH = 1000
25
- DEFAULT_LANGUAGE = "en"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
28
- JOBS: Dict[str, Dict[str, str]] = {}
29
- JOB_LOCK = Lock()
30
 
31
- # Set token in environment before importing
 
 
32
  if HF_TOKEN:
33
  os.environ["HUGGING_FACE_HUB_TOKEN"] = HF_TOKEN
34
  os.environ["HF_TOKEN"] = HF_TOKEN
35
  try:
36
  from huggingface_hub import login
 
37
  login(token=HF_TOKEN, add_to_git_credential=False)
38
- except ImportError:
39
  pass
40
 
41
- # Download model checkpoints from Hugging Face
42
- MODEL_DIR = os.getenv("MODEL_DIR", "/data/indextts2")
43
- os.makedirs(MODEL_DIR, exist_ok=True)
44
-
45
  try:
46
  from huggingface_hub import snapshot_download
47
-
48
- # Download model if not already present
49
- if not Path(MODEL_DIR, "config.yaml").exists():
50
- print(f"Downloading IndexTTS2 model from {MODEL_REPO}...")
51
- snapshot_download(
52
- repo_id=MODEL_REPO,
53
- local_dir=MODEL_DIR,
54
- token=HF_TOKEN,
55
- )
56
- print("Model download complete.")
57
  except Exception as exc:
58
- print(f"Warning: Could not download model: {exc}")
59
- # Continue anyway - model might already be present
60
 
61
- # Initialize IndexTTS2
 
 
62
  try:
63
  from indextts.infer_v2 import IndexTTS2
64
-
65
- cfg_path = os.path.join(MODEL_DIR, "config.yaml")
66
- if not Path(cfg_path).exists():
67
- raise FileNotFoundError(f"Config file not found at {cfg_path}. Model may not be downloaded.")
68
-
69
- tts_model = IndexTTS2(
70
- cfg_path=cfg_path,
71
- model_dir=MODEL_DIR,
72
- use_fp16=False, # CPU doesn't support FP16
73
- use_cuda_kernel=False, # CPU mode
74
- use_deepspeed=False, # CPU mode
75
- )
76
- print("IndexTTS2 model loaded successfully.")
77
  except Exception as exc:
78
- raise RuntimeError(f"Failed to load IndexTTS2 model: {exc}") from exc
79
-
80
- app = FastAPI(title="indextts2-api", version="1.0.0")
81
-
82
-
83
- class GenerateRequest(BaseModel):
84
- text: str = Field(..., min_length=1, max_length=MAX_TEXT_LENGTH)
85
- speaker_wav: str = Field(..., description="HTTPS URL or base64-encoded audio")
86
- language: Optional[str] = Field(DEFAULT_LANGUAGE, description="ISO code, default en")
87
-
88
-
89
- def _require_api_key(x_api_key: Optional[str]):
90
- if not SPACE_API_KEY:
91
- return
92
- if x_api_key != SPACE_API_KEY:
93
- raise HTTPException(status_code=401, detail="Unauthorized")
94
 
 
 
 
 
 
 
95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  def _write_temp_audio_from_url(url: HttpUrl) -> str:
97
  response = requests.get(url, stream=True, timeout=30)
98
  if response.status_code >= 400:
@@ -108,7 +184,7 @@ def _write_temp_audio_from_url(url: HttpUrl) -> str:
108
  def _write_temp_audio_from_base64(payload: str) -> str:
109
  try:
110
  raw = base64.b64decode(payload)
111
- except Exception as exc: # pragma: no cover
112
  raise HTTPException(status_code=400, detail="Invalid base64 speaker_wav") from exc
113
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
114
  tmp.write(raw)
@@ -121,12 +197,10 @@ def _temp_speaker_file(speaker_wav: str) -> str:
121
  return _write_temp_audio_from_base64(speaker_wav)
122
 
123
 
124
- def _preprocess_audio_wav(path: str, target_sr: int = 24000, target_peak: float = 0.98) -> str:
125
  """
126
- Light preprocessing to stabilize embeddings and output quality:
127
- - convert to mono
128
- - resample to target_sr
129
- - peak-normalize to target_peak (avoid clipping)
130
  """
131
  wav, sr = torchaudio.load(path)
132
 
@@ -146,11 +220,64 @@ def _preprocess_audio_wav(path: str, target_sr: int = 24000, target_peak: float
146
  scale = min(target_peak / peak, 1.0)
147
  wav = wav * scale
148
 
149
- # Overwrite input file to avoid extra temp files
150
  torchaudio.save(path, wav, sr, bits_per_sample=16)
151
  return path
152
 
153
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  def _set_job(job_id: str, **kwargs):
155
  with JOB_LOCK:
156
  JOBS[job_id] = {**JOBS.get(job_id, {}), **kwargs}
@@ -168,48 +295,111 @@ def _pop_job(job_id: str) -> Optional[Dict[str, str]]:
168
 
169
 
170
  def _cleanup_files(*files: str):
171
- """Background task to clean up temporary files after response is sent."""
172
  for file_path in files:
173
  if file_path and Path(file_path).exists():
174
  try:
175
  Path(file_path).unlink(missing_ok=True)
176
  except Exception:
177
- pass # Ignore cleanup errors
178
 
179
 
180
  def _run_generate_job(job_id: str, payload: Dict[str, str]):
 
 
 
181
  speaker_file = None
182
  output_file = None
183
  _set_job(job_id, status="processing")
184
  try:
 
185
  speaker_file = _temp_speaker_file(payload["speaker_wav"])
186
- speaker_file = _preprocess_audio_wav(speaker_file)
187
- output_file = os.path.join(tempfile.gettempdir(), f"indextts2-{uuid.uuid4()}.wav")
 
 
 
 
 
 
 
 
 
 
188
 
189
- tts_model.infer(
190
- spk_audio_prompt=speaker_file,
191
- text=payload["text"],
192
- output_path=output_file,
193
- use_random=False,
194
- verbose=False,
195
- )
196
-
197
- output_file = _preprocess_audio_wav(output_file)
198
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
  if not Path(output_file).exists():
200
- raise RuntimeError(f"TTS generation failed: output file was not created at {output_file}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
 
202
- _cleanup_files(speaker_file)
203
  _set_job(job_id, status="completed", output_file=output_file)
204
  except Exception as exc:
205
  _cleanup_files(speaker_file, output_file)
206
  _set_job(job_id, status="error", error=str(exc))
207
 
208
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
  @app.post("/health")
210
  def health(x_api_key: Optional[str] = Header(default=None)):
211
  _require_api_key(x_api_key)
212
- return {"status": "ok", "model": "indextts2", "device": DEVICE}
213
 
214
 
215
  @app.post("/generate")
@@ -222,8 +412,8 @@ def generate(
222
  job_id = str(uuid.uuid4())
223
  _set_job(job_id, status="queued")
224
 
225
- # Offload the long-running synthesis so the HTTP request stays fast (<100s)
226
- background_tasks.add_task(_run_generate_job, job_id, payload.dict())
227
 
228
  return JSONResponse(
229
  status_code=202,
@@ -276,7 +466,4 @@ def job_result(
276
 
277
  @app.get("/")
278
  def root():
279
- return {
280
- "name": "indextts2-api",
281
- "endpoints": ["/health", "/generate", "/status/{job_id}", "/result/{job_id}"],
282
- }
 
1
  import base64
2
+ import hashlib
3
  import os
4
  import tempfile
5
  import uuid
6
+ import time
7
+ from concurrent.futures import ThreadPoolExecutor
8
  from pathlib import Path
9
  from threading import Lock
10
  from typing import Dict, Optional
 
17
  from fastapi.responses import FileResponse, JSONResponse
18
  from pydantic import BaseModel, Field, HttpUrl
19
 
20
# ----------------------------
# Config / Tunables
# Every value below can be overridden through an environment variable of the
# same name; the literals are the fallbacks.
# ----------------------------
SPACE_API_KEY = os.environ.get("SPACE_API_KEY")
# First non-empty token wins, checked in this order.
HF_TOKEN = (
    os.environ.get("HUGGING_FACE_HUB_TOKEN")
    or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
    or os.environ.get("HF_TOKEN")
)
MODEL_REPO = os.environ.get("MODEL_REPO", "IndexTeam/IndexTTS-2")
MODEL_DIR = os.environ.get("MODEL_DIR", "/data/indextts2")
os.makedirs(MODEL_DIR, exist_ok=True)

MAX_TEXT_LENGTH = int(os.environ.get("MAX_TEXT_LENGTH", "1000"))
DEFAULT_LANGUAGE = os.environ.get("DEFAULT_LANGUAGE", "en")
TARGET_SR = int(os.environ.get("TARGET_SR", "16000"))  # 16 kHz, chosen for speed
TORCH_NUM_THREADS = int(os.environ.get("TORCH_NUM_THREADS", "2"))

# Speaker-embedding cache limits: entry count and time-to-live in seconds.
EMBED_CACHE_MAX = int(os.environ.get("EMBED_CACHE_MAX", "128"))
EMBED_CACHE_TTL = int(os.environ.get("EMBED_CACHE_TTL", "86400"))  # 24 hours

# Size of the thread pool that runs TTS jobs; keep it low on CPU.
WORKER_COUNT = int(os.environ.get("WORKER_COUNT", "1"))

# ----------------------------
# Torch settings
# ----------------------------
torch.set_num_threads(TORCH_NUM_THREADS)
try:
    # Optional: inter-op pool gets half of the intra-op threads, at least one.
    torch.set_num_interop_threads(max(1, TORCH_NUM_THREADS // 2))
except Exception:
    pass

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
 
56
 
57
# ----------------------------
# Hugging Face login (only when a token was found)
# ----------------------------
if HF_TOKEN:
    # Export the token under both variable names before attempting the login.
    for _token_var in ("HUGGING_FACE_HUB_TOKEN", "HF_TOKEN"):
        os.environ[_token_var] = HF_TOKEN
    try:
        from huggingface_hub import login

        login(token=HF_TOKEN, add_to_git_credential=False)
    except Exception:
        # Best effort: a missing package or a failed login must not stop startup.
        pass
69
 
70
# ----------------------------
# Download the model snapshot, but only when config.yaml is not already in MODEL_DIR
# ----------------------------
try:
    from huggingface_hub import snapshot_download

    cfg_path = Path(MODEL_DIR, "config.yaml")
    if not cfg_path.exists():
        print(f"Config missing; downloading model snapshot {MODEL_REPO} to {MODEL_DIR} ...")
        snapshot_download(repo_id=MODEL_REPO, local_dir=MODEL_DIR, token=HF_TOKEN)
        print("Download complete.")
except Exception as exc:
    # Non-fatal here; the model-loading step checks for config.yaml again.
    print(f"Warning: snapshot_download skipped or failed: {exc}")
 
83
 
84
# ----------------------------
# Load IndexTTS2 model (CPU mode safe defaults)
# ----------------------------
try:
    from indextts.infer_v2 import IndexTTS2
except Exception as exc:
    raise RuntimeError("indextts.infer_v2 import failed. Make sure IndexTTS2 is installed.") from exc

cfg_path = os.path.join(MODEL_DIR, "config.yaml")
if not Path(cfg_path).exists():
    raise FileNotFoundError(f"Config file not found at {cfg_path}. Place model files in {MODEL_DIR}.")

# All acceleration switches are off; these are the CPU-safe settings.
# If a GPU becomes available, use_fp16 / use_cuda_kernel can be toggled.
_tts_options = {
    "cfg_path": cfg_path,
    "model_dir": MODEL_DIR,
    "use_fp16": False,
    "use_cuda_kernel": False,
    "use_deepspeed": False,
}
tts_model = IndexTTS2(**_tts_options)
print("IndexTTS2 loaded.")
 
105
 
106
# ----------------------------
# App + job state
# ----------------------------
app = FastAPI(title="indextts2-api-optimized", version="1.0.0")

# job_id -> job fields; guarded by JOB_LOCK.
JOBS: Dict[str, Dict[str, str]] = {}
JOB_LOCK = Lock()

# Pool that runs TTS jobs; at most WORKER_COUNT of them execute at once.
EXECUTOR = ThreadPoolExecutor(max_workers=WORKER_COUNT)
115
+
116
+ # ----------------------------
117
+ # Simple LRU-like embedding cache (in-memory)
118
+ # ----------------------------
119
class _EmbedCacheEntry:
    """One cached speaker embedding together with the time it was last touched."""

    def __init__(self, emb_tensor: torch.Tensor):
        # Wall-clock seconds; used for both TTL expiry and oldest-first eviction.
        self.ts = time.time()
        # Stored detached from any autograd graph and on the CPU.
        self.emb = emb_tensor.detach().cpu()


# cache key -> entry; every access goes through EMBED_CACHE_LOCK.
EMBED_CACHE: Dict[str, _EmbedCacheEntry] = {}
EMBED_CACHE_LOCK = Lock()
127
+
128
+
129
def _evict_cache_if_needed():
    """Shrink EMBED_CACHE back to EMBED_CACHE_MAX entries, oldest `ts` first.

    Acquires EMBED_CACHE_LOCK itself, so the caller must not already hold it
    (it is a plain, non-reentrant threading.Lock).
    """
    with EMBED_CACHE_LOCK:
        overflow = len(EMBED_CACHE) - EMBED_CACHE_MAX
        if overflow <= 0:
            return
        # Keys ordered from least to most recently stamped.
        oldest_first = sorted(EMBED_CACHE, key=lambda name: EMBED_CACHE[name].ts)
        for stale_key in oldest_first[:overflow]:
            EMBED_CACHE.pop(stale_key, None)
137
+
138
+
139
def _get_cache_key_for_file(path: str) -> str:
    """Return the hex SHA-256 digest of the bytes stored at `path`.

    The file is streamed in 8 KiB pieces so it never has to fit in memory at once.
    """
    digest = hashlib.sha256()
    with open(path, "rb") as stream:
        # iter() with a sentinel stops at the first empty read, i.e. end of file.
        for piece in iter(lambda: stream.read(8192), b""):
            digest.update(piece)
    return digest.hexdigest()
149
+
150
+
151
def _cache_get(key: str) -> Optional[torch.Tensor]:
    """Look up a cached embedding.

    Returns a clone of the stored tensor, or None when `key` is absent or its
    entry is older than EMBED_CACHE_TTL seconds (an expired entry is removed).
    A hit re-stamps the entry so it counts as recently used.
    """
    with EMBED_CACHE_LOCK:
        hit = EMBED_CACHE.get(key)
        if hit is None:
            return None

        age = time.time() - hit.ts
        if age > EMBED_CACHE_TTL:
            EMBED_CACHE.pop(key, None)
            return None

        hit.ts = time.time()
        # Hand out a copy so callers cannot mutate the cached tensor.
        return hit.emb.clone()
162
+
163
+
164
def _cache_set(key: str, emb: torch.Tensor):
    """Store `emb` under `key`, then trim the cache if it has outgrown its limit.

    Args:
        key: Cache key (the content hash of the speaker file).
        emb: Embedding tensor; _EmbedCacheEntry keeps a detached CPU copy.
    """
    # Build the entry before taking the lock so the detach/CPU copy does not
    # lengthen the critical section.
    fresh_entry = _EmbedCacheEntry(emb)
    with EMBED_CACHE_LOCK:
        EMBED_CACHE[key] = fresh_entry
    # Must run only after the lock is released: _evict_cache_if_needed()
    # acquires EMBED_CACHE_LOCK itself, and threading.Lock is not reentrant,
    # so calling it while still holding the lock would deadlock this thread.
    _evict_cache_if_needed()
168
+
169
+ # ----------------------------
170
+ # Utilities for audio input handling
171
+ # ----------------------------
172
  def _write_temp_audio_from_url(url: HttpUrl) -> str:
173
  response = requests.get(url, stream=True, timeout=30)
174
  if response.status_code >= 400:
 
184
  def _write_temp_audio_from_base64(payload: str) -> str:
185
  try:
186
  raw = base64.b64decode(payload)
187
+ except Exception as exc:
188
  raise HTTPException(status_code=400, detail="Invalid base64 speaker_wav") from exc
189
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
190
  tmp.write(raw)
 
197
  return _write_temp_audio_from_base64(speaker_wav)
198
 
199
 
200
+ def _preprocess_audio_wav(path: str, target_sr: int = TARGET_SR, target_peak: float = 0.98) -> str:
201
  """
202
+ Convert to mono, resample to target_sr, and peak-normalize.
203
+ Overwrites the input file.
 
 
204
  """
205
  wav, sr = torchaudio.load(path)
206
 
 
220
  scale = min(target_peak / peak, 1.0)
221
  wav = wav * scale
222
 
223
+ # Overwrite file in 16-bit PCM
224
  torchaudio.save(path, wav, sr, bits_per_sample=16)
225
  return path
226
 
227
 
228
+ # ----------------------------
229
+ # Embedding extraction helper (tries multiple API variants)
230
+ # ----------------------------
231
def _compute_spk_embedding(speaker_path: str) -> torch.Tensor:
    """Return a float32 CPU tensor holding the speaker embedding for `speaker_path`.

    The result is cached under the SHA-256 of the file as it was passed in
    (the hash is taken before the preprocessing step rewrites the file).
    On a cache miss the audio is preprocessed and the first available of
    `tts_model.get_spk_emb`, `tts_model.extract_spk_emb` or
    `tts_model.spk_encoder.embed_utterance` is used.

    Raises:
        RuntimeError: if none of those entry points exists or the chosen one
            fails; the underlying exception is chained as the cause.
    """
    cache_key = _get_cache_key_for_file(speaker_path)
    hit = _cache_get(cache_key)
    if hit is not None:
        return hit

    # Mono / resample / normalize in place before handing the file to the model.
    _preprocess_audio_wav(speaker_path, target_sr=TARGET_SR)

    try:
        direct_method = next(
            (name for name in ("get_spk_emb", "extract_spk_emb") if hasattr(tts_model, name)),
            None,
        )
        if direct_method is not None:
            raw = getattr(tts_model, direct_method)(speaker_path)
        elif hasattr(tts_model, "spk_encoder") and hasattr(tts_model.spk_encoder, "embed_utterance"):
            # Some wrappers only expose the internal encoder, which is fed a
            # 1-D numpy waveform here.
            wav, _sr = torchaudio.load(speaker_path)
            if wav.shape[0] > 1:
                wav = wav.mean(dim=0, keepdim=True)
            samples = wav.squeeze(0).numpy()
            raw = torch.from_numpy(tts_model.spk_encoder.embed_utterance(samples))
        else:
            raise RuntimeError("No known speaker embedding method available on tts_model.")
    except Exception as exc:
        # Every failure is re-raised; the caller decides whether to fall back
        # to passing the audio file to infer() instead.
        raise RuntimeError(f"Failed to compute speaker embedding: {exc}") from exc

    # Coerce whatever came back into a detached float32 tensor on the CPU.
    if isinstance(raw, torch.Tensor):
        result = raw.detach().cpu().float()
    else:
        result = torch.tensor(raw, dtype=torch.float32, device="cpu")

    _cache_set(cache_key, result)
    return result
276
+
277
+
278
+ # ----------------------------
279
+ # Job helpers
280
+ # ----------------------------
281
  def _set_job(job_id: str, **kwargs):
282
  with JOB_LOCK:
283
  JOBS[job_id] = {**JOBS.get(job_id, {}), **kwargs}
 
295
 
296
 
297
def _cleanup_files(*files: str):
    """Delete each given file that currently exists; falsy entries are skipped.

    Deletion is best effort: any error raised while unlinking is ignored.
    """
    for candidate in files:
        if not candidate:
            continue
        target = Path(candidate)
        if not target.exists():
            continue
        try:
            target.unlink(missing_ok=True)
        except Exception:
            pass
304
 
305
 
306
def _run_generate_job(job_id: str, payload: Dict[str, str]):
    """Run one TTS job end to end and record its outcome in the job table.

    Reads `payload["speaker_wav"]` and `payload["text"]`. The job is marked
    "processing", then either "completed" (with `output_file`) or "error"
    (with the exception text). On error both temp files are removed; on
    success only the speaker file is, so the output stays available.
    """
    speaker_file = None
    output_file = None
    _set_job(job_id, status="processing")
    try:
        # Materialize the speaker audio, then mono/resample/normalize it in place.
        # Kept as two steps so the temp file is tracked even if preprocessing fails.
        speaker_file = _temp_speaker_file(payload["speaker_wav"])
        speaker_file = _preprocess_audio_wav(speaker_file, target_sr=TARGET_SR)

        # Cached embedding if obtainable; None means "use the audio prompt instead".
        spk_emb = None
        try:
            spk_emb = _compute_spk_embedding(speaker_file)
        except Exception as exc_emb:
            print(f"Warning: embedding extraction failed, falling back to audio prompt: {exc_emb}")

        output_file = os.path.join(tempfile.gettempdir(), f"indextts2-{uuid.uuid4()}.wav")

        infer_kwargs = {
            "text": payload["text"],
            "output_path": output_file,
            "use_random": False,
            "verbose": False,
            "sample_rate": TARGET_SR,
        }

        if spk_emb is None:
            # Slower path: the audio file itself is passed as the prompt.
            tts_model.infer(spk_audio_prompt=speaker_file, **infer_kwargs)
        else:
            try:
                tts_model.infer(spk_emb=spk_emb, **infer_kwargs)
            except TypeError:
                # Retry with the alternative keyword name.
                tts_model.infer(speaker_emb=spk_emb, **infer_kwargs)

        if not Path(output_file).exists():
            raise RuntimeError(f"TTS generation failed: output file not created at {output_file}")

        # Only touch the output when its sample rate differs from TARGET_SR;
        # a failure here leaves the model's original file in place.
        try:
            out_wav, out_sr = torchaudio.load(output_file)
            if out_sr != TARGET_SR:
                out_wav = Resample(orig_freq=out_sr, new_freq=TARGET_SR)(out_wav)
                torchaudio.save(output_file, out_wav, TARGET_SR, bits_per_sample=16)
        except Exception:
            pass

        # The speaker temp file is no longer needed; the output is kept.
        if speaker_file:
            try:
                Path(speaker_file).unlink(missing_ok=True)
            except Exception:
                pass

        _set_job(job_id, status="completed", output_file=output_file)
    except Exception as exc:
        _cleanup_files(speaker_file, output_file)
        _set_job(job_id, status="error", error=str(exc))
381
 
382
 
383
+ # ----------------------------
384
+ # FastAPI endpoints
385
+ # ----------------------------
386
# Request body for POST /generate.
class GenerateRequest(BaseModel):
    # Text to synthesize; between 1 and MAX_TEXT_LENGTH characters.
    text: str = Field(default=..., min_length=1, max_length=MAX_TEXT_LENGTH)
    # Reference voice sample.
    speaker_wav: str = Field(default=..., description="HTTPS URL or base64-encoded audio")
    # Optional language code; falls back to DEFAULT_LANGUAGE.
    language: Optional[str] = Field(default=DEFAULT_LANGUAGE, description="ISO code, default en")
390
+
391
+
392
def _require_api_key(x_api_key: Optional[str]):
    """Reject the request unless it carries the configured API key.

    When SPACE_API_KEY is unset or empty, authentication is disabled and every
    request passes.

    Args:
        x_api_key: Value of the caller-supplied key header, or None if absent.

    Raises:
        HTTPException: 401 "Unauthorized" when a key is configured and the
            supplied value does not match it.
    """
    import hmac  # local import: only this function needs it

    if not SPACE_API_KEY:
        return

    # Compare as bytes with a constant-time routine so response timing does not
    # reveal how many leading characters of a guessed key were correct.
    # Encoding first also lets compare_digest accept non-ASCII input.
    supplied = (x_api_key or "").encode("utf-8")
    expected = SPACE_API_KEY.encode("utf-8")
    if not hmac.compare_digest(supplied, expected):
        raise HTTPException(status_code=401, detail="Unauthorized")
397
+
398
+
399
# Health probe: checks the API key, then reports model name, device and torch thread count.
@app.post("/health")
def health(x_api_key: Optional[str] = Header(default=None)):
    _require_api_key(x_api_key)
    report = {"status": "ok", "model": "indextts2", "device": DEVICE}
    report["torch_threads"] = torch.get_num_threads()
    return report
403
 
404
 
405
  @app.post("/generate")
 
412
  job_id = str(uuid.uuid4())
413
  _set_job(job_id, status="queued")
414
 
415
+ # Submit to bounded threadpool to avoid uncontrolled concurrency on CPU
416
+ EXECUTOR.submit(_run_generate_job, job_id, payload.dict())
417
 
418
  return JSONResponse(
419
  status_code=202,
 
466
 
467
# Service index: returns the service name and the list of endpoint paths.
@app.get("/")
def root():
    endpoints = ["/health", "/generate", "/status/{job_id}", "/result/{job_id}"]
    return {"name": "indextts2-api-optimized", "endpoints": endpoints}