ataberkkilavuzcu commited on
Commit
e522f34
·
verified ·
1 Parent(s): 7c7c63c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -320
app.py CHANGED
@@ -1,174 +1,81 @@
1
- import base64
2
- import hashlib
3
  import os
4
- import tempfile
5
  import uuid
6
- import time
7
- from concurrent.futures import ThreadPoolExecutor
8
  from pathlib import Path
9
  from threading import Lock
10
- from typing import Dict, Optional
11
 
12
  import requests
13
  import torch
14
  import torchaudio
15
  from torchaudio.transforms import Resample
16
- from fastapi import BackgroundTasks, Body, FastAPI, Header, HTTPException
17
  from fastapi.responses import FileResponse, JSONResponse
18
  from pydantic import BaseModel, Field, HttpUrl
19
 
20
- # ----------------------------
21
- # Config / Tunables
22
- # ----------------------------
23
  SPACE_API_KEY = os.getenv("SPACE_API_KEY")
24
  HF_TOKEN = (
25
  os.getenv("HUGGING_FACE_HUB_TOKEN")
26
  or os.getenv("HUGGINGFACEHUB_API_TOKEN")
27
  or os.getenv("HF_TOKEN")
28
  )
29
- MODEL_REPO = os.getenv("MODEL_REPO", "IndexTeam/IndexTTS-2")
30
  MODEL_DIR = os.getenv("MODEL_DIR", "/data/indextts2")
31
  os.makedirs(MODEL_DIR, exist_ok=True)
32
 
33
- MAX_TEXT_LENGTH = int(os.getenv("MAX_TEXT_LENGTH", "1000"))
34
- DEFAULT_LANGUAGE = os.getenv("DEFAULT_LANGUAGE", "en")
35
- TARGET_SR = int(os.getenv("TARGET_SR", "16000")) # lowered to 16 kHz for speed
36
- TORCH_NUM_THREADS = int(os.getenv("TORCH_NUM_THREADS", "2"))
37
 
38
- # Embedding cache settings
39
- EMBED_CACHE_MAX = int(os.getenv("EMBED_CACHE_MAX", "128")) # max entries
40
- EMBED_CACHE_TTL = int(os.getenv("EMBED_CACHE_TTL", str(60 * 60 * 24))) # 24h by default
41
 
42
- # Threadpool for bounded parallel jobs (keeps worker threads limited)
43
- WORKER_COUNT = int(os.getenv("WORKER_COUNT", "1")) # keep low on CPU
44
-
45
- # ----------------------------
46
- # Torch settings
47
- # ----------------------------
48
- torch.set_num_threads(TORCH_NUM_THREADS)
49
- try:
50
- # optional: limit interop threads
51
- torch.set_num_interop_threads(max(1, TORCH_NUM_THREADS // 2))
52
- except Exception:
53
- pass
54
 
55
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
56
 
57
- # ----------------------------
58
- # Hugging Face login (if token)
59
- # ----------------------------
60
- if HF_TOKEN:
61
- os.environ["HUGGING_FACE_HUB_TOKEN"] = HF_TOKEN
62
- os.environ["HF_TOKEN"] = HF_TOKEN
63
- try:
64
- from huggingface_hub import login
65
-
66
- login(token=HF_TOKEN, add_to_git_credential=False)
67
- except Exception:
68
- pass
69
 
70
- # ----------------------------
71
- # Optionally download model snapshot (only if missing)
72
- # ----------------------------
73
  try:
74
  from huggingface_hub import snapshot_download
75
-
76
- cfg_path = Path(MODEL_DIR) / "config.yaml"
77
- if not cfg_path.exists():
78
- print(f"Config missing; downloading model snapshot {MODEL_REPO} to {MODEL_DIR} ...")
79
- snapshot_download(repo_id=MODEL_REPO, local_dir=MODEL_DIR, token=HF_TOKEN)
80
- print("Download complete.")
81
- except Exception as exc:
82
- print(f"Warning: snapshot_download skipped or failed: {exc}")
83
-
84
- # ----------------------------
85
- # Load IndexTTS2 model (CPU mode safe defaults)
86
- # ----------------------------
87
- try:
88
  from indextts.infer_v2 import IndexTTS2
89
- except Exception as exc:
90
- raise RuntimeError("indextts.infer_v2 import failed. Make sure IndexTTS2 is installed.") from exc
91
-
92
- cfg_path = os.path.join(MODEL_DIR, "config.yaml")
93
- if not Path(cfg_path).exists():
94
- raise FileNotFoundError(f"Config file not found at {cfg_path}. Place model files in {MODEL_DIR}.")
95
-
96
- # Use CPU-safe options. If GPU becomes available, you can toggle use_fp16/use_cuda_kernel.
97
- tts_model = IndexTTS2(
98
- cfg_path=cfg_path,
99
- model_dir=MODEL_DIR,
100
- use_fp16=False, # CPU doesn't support FP16 reliably
101
- use_cuda_kernel=False,
102
- use_deepspeed=False,
103
- )
104
- print("IndexTTS2 loaded.")
105
 
106
- # ----------------------------
107
- # App + job state
108
- # ----------------------------
109
- app = FastAPI(title="indextts2-api-optimized", version="1.0.0")
110
  JOBS: Dict[str, Dict[str, str]] = {}
111
  JOB_LOCK = Lock()
112
 
113
- # Threadpool for running TTS jobs; limits concurrency to WORKER_COUNT
114
- EXECUTOR = ThreadPoolExecutor(max_workers=WORKER_COUNT)
115
-
116
- # ----------------------------
117
- # Simple LRU-like embedding cache (in-memory)
118
- # ----------------------------
119
- class _EmbedCacheEntry:
120
- def __init__(self, emb_tensor: torch.Tensor):
121
- self.emb = emb_tensor.detach().cpu() # keep on CPU, detached
122
- self.ts = time.time()
123
-
124
-
125
- EMBED_CACHE: Dict[str, _EmbedCacheEntry] = {}
126
- EMBED_CACHE_LOCK = Lock()
127
-
128
-
129
- def _evict_cache_if_needed():
130
- with EMBED_CACHE_LOCK:
131
- if len(EMBED_CACHE) <= EMBED_CACHE_MAX:
132
- return
133
- # Simple eviction: remove oldest entries
134
- items = sorted(EMBED_CACHE.items(), key=lambda kv: kv[1].ts)
135
- for key, _ in items[: max(1, len(items) - EMBED_CACHE_MAX)]:
136
- EMBED_CACHE.pop(key, None)
137
-
138
-
139
- def _get_cache_key_for_file(path: str) -> str:
140
- # Hash the file contents (fast enough for short audio)
141
- h = hashlib.sha256()
142
- with open(path, "rb") as f:
143
- while True:
144
- chunk = f.read(8192)
145
- if not chunk:
146
- break
147
- h.update(chunk)
148
- return h.hexdigest()
149
-
150
-
151
- def _cache_get(key: str) -> Optional[torch.Tensor]:
152
- with EMBED_CACHE_LOCK:
153
- entry = EMBED_CACHE.get(key)
154
- if not entry:
155
- return None
156
- if (time.time() - entry.ts) > EMBED_CACHE_TTL:
157
- EMBED_CACHE.pop(key, None)
158
- return None
159
- # update timestamp for LRU-ish behavior
160
- entry.ts = time.time()
161
- return entry.emb.clone()
162
-
163
-
164
- def _cache_set(key: str, emb: torch.Tensor):
165
- with EMBED_CACHE_LOCK:
166
- EMBED_CACHE[key] = _EmbedCacheEntry(emb)
167
- _evict_cache_if_needed()
168
-
169
- # ----------------------------
170
- # Utilities for audio input handling
171
- # ----------------------------
172
  def _write_temp_audio_from_url(url: HttpUrl) -> str:
173
  response = requests.get(url, stream=True, timeout=30)
174
  if response.status_code >= 400:
@@ -193,91 +100,25 @@ def _write_temp_audio_from_base64(payload: str) -> str:
193
 
194
  def _temp_speaker_file(speaker_wav: str) -> str:
195
  if speaker_wav.startswith("http://") or speaker_wav.startswith("https://"):
196
- return _write_temp_audio_from_url(HttpUrl(speaker_wav))
197
  return _write_temp_audio_from_base64(speaker_wav)
198
 
199
 
200
  def _preprocess_audio_wav(path: str, target_sr: int = TARGET_SR, target_peak: float = 0.98) -> str:
201
- """
202
- Convert to mono, resample to target_sr, and peak-normalize.
203
- Overwrites the input file.
204
- """
205
  wav, sr = torchaudio.load(path)
206
-
207
- # Mono
208
  if wav.shape[0] > 1:
209
  wav = wav.mean(dim=0, keepdim=True)
210
-
211
- # Resample if needed
212
  if sr != target_sr:
213
  resampler = Resample(orig_freq=sr, new_freq=target_sr)
214
  wav = resampler(wav)
215
  sr = target_sr
216
-
217
- # Peak normalize
218
  peak = wav.abs().max().item() if wav.numel() else 0.0
219
  if peak > 0:
220
- scale = min(target_peak / peak, 1.0)
221
- wav = wav * scale
222
-
223
- # Overwrite file in 16-bit PCM
224
  torchaudio.save(path, wav, sr, bits_per_sample=16)
225
  return path
226
 
227
 
228
- # ----------------------------
229
- # Embedding extraction helper (tries multiple API variants)
230
- # ----------------------------
231
- def _compute_spk_embedding(speaker_path: str) -> torch.Tensor:
232
- """
233
- Returns a CPU tensor containing the speaker embedding.
234
- Tries multiple methods to extract embedding (get_spk_emb, extract_spk_emb, etc.)
235
- """
236
- # Key: use hash of file contents
237
- key = _get_cache_key_for_file(speaker_path)
238
- cached = _cache_get(key)
239
- if cached is not None:
240
- return cached
241
-
242
- # Ensure audio preprocessed (mono/resample/normalize)
243
- _preprocess_audio_wav(speaker_path, target_sr=TARGET_SR)
244
-
245
- # Try known wrapper method names (depending on IndexTTS2 version)
246
- emb = None
247
- try:
248
- if hasattr(tts_model, "get_spk_emb"):
249
- emb = tts_model.get_spk_emb(speaker_path)
250
- elif hasattr(tts_model, "extract_spk_emb"):
251
- emb = tts_model.extract_spk_emb(speaker_path)
252
- elif hasattr(tts_model, "spk_encoder") and hasattr(tts_model.spk_encoder, "embed_utterance"):
253
- # some wrappers expose internal encoders
254
- wav, sr = torchaudio.load(speaker_path)
255
- if wav.shape[0] > 1:
256
- wav = wav.mean(dim=0, keepdim=True)
257
- wav = wav.squeeze(0).numpy() # expected shape for some encoders
258
- emb = tts_model.spk_encoder.embed_utterance(wav)
259
- emb = torch.from_numpy(emb)
260
- else:
261
- raise RuntimeError("No known speaker embedding method available on tts_model.")
262
- except Exception as exc:
263
- # If the model doesn't provide a direct API or something fails, fallback to infer path
264
- # where infer() might internally compute embedding. In that case we return None to indicate
265
- # that caller should call infer with spk_audio_prompt.
266
- raise RuntimeError(f"Failed to compute speaker embedding: {exc}") from exc
267
-
268
- # Normalize & store on CPU as float32
269
- if isinstance(emb, torch.Tensor):
270
- emb_cpu = emb.detach().cpu().float()
271
- else:
272
- emb_cpu = torch.tensor(emb, dtype=torch.float32, device="cpu")
273
-
274
- _cache_set(key, emb_cpu)
275
- return emb_cpu
276
-
277
-
278
- # ----------------------------
279
- # Job helpers
280
- # ----------------------------
281
  def _set_job(job_id: str, **kwargs):
282
  with JOB_LOCK:
283
  JOBS[job_id] = {**JOBS.get(job_id, {}), **kwargs}
@@ -286,7 +127,7 @@ def _set_job(job_id: str, **kwargs):
286
  def _get_job(job_id: str) -> Optional[Dict[str, str]]:
287
  with JOB_LOCK:
288
  data = JOBS.get(job_id)
289
- return dict(data) if data else None
290
 
291
 
292
  def _pop_job(job_id: str) -> Optional[Dict[str, str]]:
@@ -294,114 +135,42 @@ def _pop_job(job_id: str) -> Optional[Dict[str, str]]:
294
  return JOBS.pop(job_id, None)
295
 
296
 
297
- def _cleanup_files(*files: str):
298
- for file_path in files:
299
- if file_path and Path(file_path).exists():
300
- try:
301
- Path(file_path).unlink(missing_ok=True)
302
- except Exception:
303
- pass
304
 
305
 
306
  def _run_generate_job(job_id: str, payload: Dict[str, str]):
307
- """
308
- Worker function that computes (or reuses) embedding and performs TTS.
309
- """
310
  speaker_file = None
311
  output_file = None
312
  _set_job(job_id, status="processing")
313
  try:
314
- # prepare speaker audio
315
  speaker_file = _temp_speaker_file(payload["speaker_wav"])
316
- # preprocess (mono + resample + normalize)
317
  speaker_file = _preprocess_audio_wav(speaker_file, target_sr=TARGET_SR)
318
 
319
- # compute or fetch embedding (cached)
320
- try:
321
- spk_emb = _compute_spk_embedding(speaker_file)
322
- use_spk_emb = True
323
- except Exception as exc_emb:
324
- # If embedding extraction fails, fall back to passing audio path to infer
325
- spk_emb = None
326
- use_spk_emb = False
327
- print(f"Warning: embedding extraction failed, falling back to audio prompt: {exc_emb}")
328
-
329
  output_file = os.path.join(tempfile.gettempdir(), f"indextts2-{uuid.uuid4()}.wav")
330
 
331
- # Call inference: prefer spk_emb if available.
332
- infer_kwargs = {
333
- "text": payload["text"],
334
- "output_path": output_file,
335
- "use_random": False,
336
- "verbose": False,
337
- }
338
- # include sample_rate if supported by this wrapper
339
- try:
340
- infer_kwargs["sample_rate"] = TARGET_SR
341
- except Exception:
342
- pass
343
 
344
- if use_spk_emb and spk_emb is not None:
345
- # Use embedding path - many wrappers accept spk_emb or spk_embedding
346
- try:
347
- tts_model.infer(spk_emb=spk_emb, **infer_kwargs)
348
- except TypeError:
349
- # fallback argument name
350
- tts_model.infer(speaker_emb=spk_emb, **infer_kwargs)
351
- else:
352
- # pass the audio file as prompt (slower, model will compute embedding internally)
353
- tts_model.infer(spk_audio_prompt=speaker_file, **infer_kwargs)
354
-
355
- # Minimal validation: ensure file created
356
  if not Path(output_file).exists():
357
- raise RuntimeError(f"TTS generation failed: output file not created at {output_file}")
358
-
359
- # Do NOT re-run heavy preprocess; only resample if the model returned a different sr (rare)
360
- try:
361
- out_wav, out_sr = torchaudio.load(output_file)
362
- if out_sr != TARGET_SR:
363
- resampler = Resample(orig_freq=out_sr, new_freq=TARGET_SR)
364
- out_wav = resampler(out_wav)
365
- torchaudio.save(output_file, out_wav, TARGET_SR, bits_per_sample=16)
366
- except Exception:
367
- # If this fails, still return the original output file
368
- pass
369
-
370
- # cleanup speaker temp (we keep output until client downloads)
371
- if speaker_file:
372
- try:
373
- Path(speaker_file).unlink(missing_ok=True)
374
- except Exception:
375
- pass
376
 
377
  _set_job(job_id, status="completed", output_file=output_file)
378
  except Exception as exc:
379
- _cleanup_files(speaker_file, output_file)
380
  _set_job(job_id, status="error", error=str(exc))
381
 
382
 
383
- # ----------------------------
384
- # FastAPI endpoints
385
- # ----------------------------
386
- class GenerateRequest(BaseModel):
387
- text: str = Field(..., min_length=1, max_length=MAX_TEXT_LENGTH)
388
- speaker_wav: str = Field(..., description="HTTPS URL or base64-encoded audio")
389
- language: Optional[str] = Field(DEFAULT_LANGUAGE, description="ISO code, default en")
390
-
391
-
392
- def _require_api_key(x_api_key: Optional[str]):
393
- if not SPACE_API_KEY:
394
- return
395
- if x_api_key != SPACE_API_KEY:
396
- raise HTTPException(status_code=401, detail="Unauthorized")
397
-
398
-
399
- @app.post("/health")
400
- def health(x_api_key: Optional[str] = Header(default=None)):
401
- _require_api_key(x_api_key)
402
- return {"status": "ok", "model": "indextts2", "device": DEVICE, "torch_threads": torch.get_num_threads()}
403
-
404
-
405
  @app.post("/generate")
406
  def generate(
407
  payload: GenerateRequest = Body(...),
@@ -412,8 +181,7 @@ def generate(
412
  job_id = str(uuid.uuid4())
413
  _set_job(job_id, status="queued")
414
 
415
- # Submit to bounded threadpool to avoid uncontrolled concurrency on CPU
416
- EXECUTOR.submit(_run_generate_job, job_id, payload.dict())
417
 
418
  return JSONResponse(
419
  status_code=202,
@@ -427,43 +195,31 @@ def generate(
427
 
428
 
429
  @app.get("/status/{job_id}")
430
- def job_status(job_id: str, x_api_key: Optional[str] = Header(default=None)):
431
  _require_api_key(x_api_key)
432
  job = _get_job(job_id)
433
  if not job:
434
  raise HTTPException(status_code=404, detail="Job not found")
435
- payload: Dict[str, str] = {"job_id": job_id, "status": job.get("status", "unknown")}
436
  if "error" in job:
437
- payload["error"] = job["error"]
438
- return payload
439
 
440
 
441
  @app.get("/result/{job_id}")
442
- def job_result(
443
- job_id: str,
444
- background_tasks: BackgroundTasks = BackgroundTasks(),
445
- x_api_key: Optional[str] = Header(default=None),
446
- ):
447
  _require_api_key(x_api_key)
448
  job = _get_job(job_id)
449
  if not job:
450
  raise HTTPException(status_code=404, detail="Job not found")
451
- status = job.get("status")
452
- if status != "completed":
453
- raise HTTPException(status_code=409, detail=f"Job not ready (status={status})")
454
-
455
  output_file = job.get("output_file")
456
  if not output_file or not Path(output_file).exists():
457
  _pop_job(job_id)
458
- raise HTTPException(status_code=410, detail="Result expired or missing")
459
-
460
- # Remove job from memory and cleanup output after sending
 
461
  _pop_job(job_id)
462
- background_tasks.add_task(_cleanup_files, output_file)
463
-
464
- return FileResponse(output_file, media_type="audio/wav", filename="output.wav")
465
-
466
-
467
- @app.get("/")
468
- def root():
469
- return {"name": "indextts2-api-optimized", "endpoints": ["/health", "/generate", "/status/{job_id}", "/result/{job_id}"]}
 
 
 
1
  import os
 
2
  import uuid
3
+ import tempfile
4
+ import base64
5
  from pathlib import Path
6
  from threading import Lock
7
+ from typing import Optional, Dict
8
 
9
  import requests
10
  import torch
11
  import torchaudio
12
  from torchaudio.transforms import Resample
13
+ from fastapi import FastAPI, Body, Header, HTTPException, BackgroundTasks
14
  from fastapi.responses import FileResponse, JSONResponse
15
  from pydantic import BaseModel, Field, HttpUrl
16
 
17
+ # ========== Configuration ==========
18
+
 
19
  SPACE_API_KEY = os.getenv("SPACE_API_KEY")
20
  HF_TOKEN = (
21
  os.getenv("HUGGING_FACE_HUB_TOKEN")
22
  or os.getenv("HUGGINGFACEHUB_API_TOKEN")
23
  or os.getenv("HF_TOKEN")
24
  )
25
+ MODEL_REPO = "IndexTeam/IndexTTS-2"
26
  MODEL_DIR = os.getenv("MODEL_DIR", "/data/indextts2")
27
  os.makedirs(MODEL_DIR, exist_ok=True)
28
 
29
+ # Max length for input text
30
+ MAX_TEXT_LENGTH = 1000
 
 
31
 
32
+ # Use 16 kHz sample rate for faster/audio-size tradeoff
33
+ TARGET_SR = 16000
 
34
 
35
+ # Limit PyTorch threads on CPU
36
+ torch.set_num_threads(1)
 
 
 
 
 
 
 
 
 
 
37
 
38
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
39
 
40
+ # ========== Download / Load Model ==========
 
 
 
 
 
 
 
 
 
 
 
41
 
 
 
 
42
  try:
43
  from huggingface_hub import snapshot_download
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  from indextts.infer_v2 import IndexTTS2
45
+ except Exception as e:
46
+ raise RuntimeError("Required library missing: ensure `huggingface_hub` and `indextts` are installed.") from e
47
+
48
+ # Only download if not already present
49
+ config_file = Path(MODEL_DIR) / "config.yaml"
50
+ if not config_file.exists():
51
+ print(f"Downloading model {MODEL_REPO} to {MODEL_DIR} …")
52
+ snapshot_download(repo_id=MODEL_REPO, local_dir=MODEL_DIR, token=HF_TOKEN)
53
+ print("Download complete.")
54
+
55
+ tts_model = IndexTTS2(cfg_path=str(config_file), model_dir=MODEL_DIR, use_fp16=False, use_cuda_kernel=False, use_deepspeed=False)
56
+ print("IndexTTS-2 loaded, device:", DEVICE)
57
+
58
+ # ========== FastAPI app ==========
59
+
60
+ app = FastAPI(title="IndexTTS2 API")
61
 
 
 
 
 
62
  JOBS: Dict[str, Dict[str, str]] = {}
63
  JOB_LOCK = Lock()
64
 
65
+
66
+ class GenerateRequest(BaseModel):
67
+ text: str = Field(..., min_length=1, max_length=MAX_TEXT_LENGTH)
68
+ speaker_wav: str = Field(..., description="HTTPS URL or base64-encoded audio")
69
+ language: Optional[str] = Field("en", description="Language code")
70
+
71
+
72
+ def _require_api_key(x_api_key: Optional[str]):
73
+ if not SPACE_API_KEY:
74
+ return
75
+ if x_api_key != SPACE_API_KEY:
76
+ raise HTTPException(status_code=401, detail="Unauthorized")
77
+
78
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  def _write_temp_audio_from_url(url: HttpUrl) -> str:
80
  response = requests.get(url, stream=True, timeout=30)
81
  if response.status_code >= 400:
 
100
 
101
  def _temp_speaker_file(speaker_wav: str) -> str:
102
  if speaker_wav.startswith("http://") or speaker_wav.startswith("https://"):
103
+ return _write_temp_audio_from_url(speaker_wav)
104
  return _write_temp_audio_from_base64(speaker_wav)
105
 
106
 
107
  def _preprocess_audio_wav(path: str, target_sr: int = TARGET_SR, target_peak: float = 0.98) -> str:
 
 
 
 
108
  wav, sr = torchaudio.load(path)
 
 
109
  if wav.shape[0] > 1:
110
  wav = wav.mean(dim=0, keepdim=True)
 
 
111
  if sr != target_sr:
112
  resampler = Resample(orig_freq=sr, new_freq=target_sr)
113
  wav = resampler(wav)
114
  sr = target_sr
 
 
115
  peak = wav.abs().max().item() if wav.numel() else 0.0
116
  if peak > 0:
117
+ wav = wav * (target_peak / peak)
 
 
 
118
  torchaudio.save(path, wav, sr, bits_per_sample=16)
119
  return path
120
 
121
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  def _set_job(job_id: str, **kwargs):
123
  with JOB_LOCK:
124
  JOBS[job_id] = {**JOBS.get(job_id, {}), **kwargs}
 
127
  def _get_job(job_id: str) -> Optional[Dict[str, str]]:
128
  with JOB_LOCK:
129
  data = JOBS.get(job_id)
130
+ return dict(data) if data else None
131
 
132
 
133
  def _pop_job(job_id: str) -> Optional[Dict[str, str]]:
 
135
  return JOBS.pop(job_id, None)
136
 
137
 
138
+ def _cleanup_files(*paths: str):
139
+ for p in paths:
140
+ try:
141
+ os.remove(p)
142
+ except OSError:
143
+ pass
 
144
 
145
 
146
  def _run_generate_job(job_id: str, payload: Dict[str, str]):
 
 
 
147
  speaker_file = None
148
  output_file = None
149
  _set_job(job_id, status="processing")
150
  try:
 
151
  speaker_file = _temp_speaker_file(payload["speaker_wav"])
 
152
  speaker_file = _preprocess_audio_wav(speaker_file, target_sr=TARGET_SR)
153
 
 
 
 
 
 
 
 
 
 
 
154
  output_file = os.path.join(tempfile.gettempdir(), f"indextts2-{uuid.uuid4()}.wav")
155
 
156
+ # Use spk_audio_prompt this model requires audio prompt
157
+ tts_model.infer(
158
+ text=payload["text"],
159
+ spk_audio_prompt=speaker_file,
160
+ output_path=output_file,
161
+ use_random=False,
162
+ verbose=False,
163
+ )
 
 
 
 
164
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  if not Path(output_file).exists():
166
+ raise RuntimeError(f"TTS generation failed output file not created.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
 
168
  _set_job(job_id, status="completed", output_file=output_file)
169
  except Exception as exc:
170
+ _cleanup_files(speaker_file or "", output_file or "")
171
  _set_job(job_id, status="error", error=str(exc))
172
 
173
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
  @app.post("/generate")
175
  def generate(
176
  payload: GenerateRequest = Body(...),
 
181
  job_id = str(uuid.uuid4())
182
  _set_job(job_id, status="queued")
183
 
184
+ background_tasks.add_task(_run_generate_job, job_id, payload.dict())
 
185
 
186
  return JSONResponse(
187
  status_code=202,
 
195
 
196
 
197
  @app.get("/status/{job_id}")
198
+ def status(job_id: str, x_api_key: Optional[str] = Header(default=None)):
199
  _require_api_key(x_api_key)
200
  job = _get_job(job_id)
201
  if not job:
202
  raise HTTPException(status_code=404, detail="Job not found")
203
+ resp = {"job_id": job_id, "status": job.get("status", "unknown")}
204
  if "error" in job:
205
+ resp["error"] = job["error"]
206
+ return resp
207
 
208
 
209
  @app.get("/result/{job_id}")
210
+ def result(job_id: str, x_api_key: Optional[str] = Header(default=None)):
 
 
 
 
211
  _require_api_key(x_api_key)
212
  job = _get_job(job_id)
213
  if not job:
214
  raise HTTPException(status_code=404, detail="Job not found")
215
+ if job.get("status") != "completed":
216
+ raise HTTPException(status_code=409, detail=f"Job not ready (status={job.get('status')})")
 
 
217
  output_file = job.get("output_file")
218
  if not output_file or not Path(output_file).exists():
219
  _pop_job(job_id)
220
+ raise HTTPException(status_code=410, detail="Result missing or expired")
221
+ # cleanup after sending
222
+ background = BackgroundTasks()
223
+ background.add_task(_cleanup_files, output_file)
224
  _pop_job(job_id)
225
+ return FileResponse(output_file, media_type="audio/wav", filename="output.wav", background=background)