ataberkkilavuzcu commited on
Commit
f9f777d
·
verified ·
1 Parent(s): 54966d3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +98 -38
app.py CHANGED
@@ -3,7 +3,8 @@ import os
3
  import tempfile
4
  import uuid
5
  from pathlib import Path
6
- from typing import Optional
 
7
 
8
  import requests
9
  import torch
@@ -24,6 +25,8 @@ MAX_TEXT_LENGTH = 1000
24
  DEFAULT_LANGUAGE = "en"
25
 
26
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
 
27
 
28
  # Set token in environment before importing
29
  if HF_TOKEN:
@@ -148,10 +151,20 @@ def _preprocess_audio_wav(path: str, target_sr: int = 24000, target_peak: float
148
  return path
149
 
150
 
151
- @app.post("/health")
152
- def health(x_api_key: Optional[str] = Header(default=None)):
153
- _require_api_key(x_api_key)
154
- return {"status": "ok", "model": "indextts2", "device": DEVICE}
 
 
 
 
 
 
 
 
 
 
155
 
156
 
157
  def _cleanup_files(*files: str):
@@ -164,59 +177,106 @@ def _cleanup_files(*files: str):
164
  pass # Ignore cleanup errors
165
 
166
 
167
- @app.post("/generate")
168
- def generate(
169
- payload: GenerateRequest = Body(...),
170
- background_tasks: BackgroundTasks = BackgroundTasks(),
171
- x_api_key: Optional[str] = Header(default=None),
172
- ):
173
- _require_api_key(x_api_key)
174
-
175
  speaker_file = None
176
  output_file = None
177
-
178
  try:
179
- speaker_file = _temp_speaker_file(payload.speaker_wav)
180
  speaker_file = _preprocess_audio_wav(speaker_file)
181
  output_file = os.path.join(tempfile.gettempdir(), f"indextts2-{uuid.uuid4()}.wav")
182
 
183
- # IndexTTS2 inference
184
- # Note: language parameter is kept for API compatibility but IndexTTS2
185
- # handles multilingual automatically (supports English, Turkish, Chinese, etc.)
186
  tts_model.infer(
187
  spk_audio_prompt=speaker_file,
188
- text=payload.text,
189
  output_path=output_file,
190
- use_random=False, # Deterministic output
191
  verbose=False,
192
  )
193
 
194
- # Light post-process to avoid end-of-file artifacts
195
  output_file = _preprocess_audio_wav(output_file)
196
 
197
- # Verify the output file was created
198
  if not Path(output_file).exists():
199
  raise RuntimeError(f"TTS generation failed: output file was not created at {output_file}")
200
 
201
- # Schedule cleanup after response is sent
202
- background_tasks.add_task(_cleanup_files, speaker_file, output_file)
 
 
 
203
 
204
- return FileResponse(output_file, media_type="audio/wav", filename="output.wav")
205
 
206
- except HTTPException:
207
- # Clean up on HTTPException
208
- if speaker_file and Path(speaker_file).exists():
209
- Path(speaker_file).unlink(missing_ok=True)
210
- raise
211
- except Exception as exc: # pragma: no cover
212
- # Clean up on error
213
- if speaker_file and Path(speaker_file).exists():
214
- Path(speaker_file).unlink(missing_ok=True)
215
- if output_file and Path(output_file).exists():
216
- Path(output_file).unlink(missing_ok=True)
217
- return JSONResponse(status_code=500, content={"error": str(exc)})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218
 
219
 
220
  @app.get("/")
221
  def root():
222
- return {"name": "indextts2-api", "endpoints": ["/health", "/generate"]}
 
 
 
 
3
  import tempfile
4
  import uuid
5
  from pathlib import Path
6
+ from threading import Lock
7
+ from typing import Dict, Optional
8
 
9
  import requests
10
  import torch
 
25
  DEFAULT_LANGUAGE = "en"
26
 
27
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
28
+ JOBS: Dict[str, Dict[str, str]] = {}
29
+ JOB_LOCK = Lock()
30
 
31
  # Set token in environment before importing
32
  if HF_TOKEN:
 
151
  return path
152
 
153
 
154
def _set_job(job_id: str, **kwargs):
    """Create or update the tracked state for *job_id* under the job lock.

    Existing fields are preserved; *kwargs* overwrite or extend them.
    """
    with JOB_LOCK:
        entry = dict(JOBS.get(job_id, {}))
        entry.update(kwargs)
        JOBS[job_id] = entry
157
+
158
+
159
def _get_job(job_id: str) -> Optional[Dict[str, str]]:
    """Return a shallow-copy snapshot of the job state, or None if absent/empty."""
    with JOB_LOCK:
        entry = JOBS.get(job_id)
        if not entry:
            return None
        return dict(entry)
163
+
164
+
165
def _pop_job(job_id: str) -> Optional[Dict[str, str]]:
    """Atomically remove and return the job state; None when the id is unknown."""
    with JOB_LOCK:
        entry = JOBS.pop(job_id, None)
    return entry
168
 
169
 
170
  def _cleanup_files(*files: str):
 
177
  pass # Ignore cleanup errors
178
 
179
 
180
def _run_generate_job(job_id: str, payload: Dict[str, str]):
    """Execute one TTS synthesis job in the background.

    Materializes the speaker reference audio, runs IndexTTS2 inference and
    records the outcome on the in-memory job table: status becomes
    "completed" with the output path on success, or "error" with the
    exception text on failure.  Temporary files are removed either way,
    except the finished WAV, which is kept for later download.
    """
    prompt_path = None
    result_path = None
    _set_job(job_id, status="processing")
    try:
        # Two-step assignment on purpose: if preprocessing raises, the raw
        # temp file is still referenced and gets cleaned up below.
        prompt_path = _temp_speaker_file(payload["speaker_wav"])
        prompt_path = _preprocess_audio_wav(prompt_path)
        result_path = os.path.join(tempfile.gettempdir(), f"indextts2-{uuid.uuid4()}.wav")

        tts_model.infer(
            spk_audio_prompt=prompt_path,
            text=payload["text"],
            output_path=result_path,
            use_random=False,
            verbose=False,
        )

        # Light post-process to avoid end-of-file artifacts.
        result_path = _preprocess_audio_wav(result_path)

        if not Path(result_path).exists():
            raise RuntimeError(f"TTS generation failed: output file was not created at {result_path}")

        _cleanup_files(prompt_path)
        _set_job(job_id, status="completed", output_file=result_path)
    except Exception as exc:
        _cleanup_files(prompt_path, result_path)
        _set_job(job_id, status="error", error=str(exc))
207
 
 
208
 
209
+ @app.post("/health")
210
+ def health(x_api_key: Optional[str] = Header(default=None)):
211
+ _require_api_key(x_api_key)
212
+ return {"status": "ok", "model": "indextts2", "device": DEVICE}
213
+
214
+
215
+ @app.post("/generate")
216
+ def generate(
217
+ payload: GenerateRequest = Body(...),
218
+ background_tasks: BackgroundTasks = BackgroundTasks(),
219
+ x_api_key: Optional[str] = Header(default=None),
220
+ ):
221
+ _require_api_key(x_api_key)
222
+ job_id = str(uuid.uuid4())
223
+ _set_job(job_id, status="queued")
224
+
225
+ # Offload the long-running synthesis so the HTTP request stays fast (<100s)
226
+ background_tasks.add_task(_run_generate_job, job_id, payload.dict())
227
+
228
+ return JSONResponse(
229
+ status_code=202,
230
+ content={
231
+ "job_id": job_id,
232
+ "status": "queued",
233
+ "status_url": f"/status/{job_id}",
234
+ "result_url": f"/result/{job_id}",
235
+ },
236
+ )
237
+
238
+
239
+ @app.get("/status/{job_id}")
240
+ def job_status(job_id: str, x_api_key: Optional[str] = Header(default=None)):
241
+ _require_api_key(x_api_key)
242
+ job = _get_job(job_id)
243
+ if not job:
244
+ raise HTTPException(status_code=404, detail="Job not found")
245
+ payload: Dict[str, str] = {"job_id": job_id, "status": job.get("status", "unknown")}
246
+ if "error" in job:
247
+ payload["error"] = job["error"]
248
+ return payload
249
+
250
+
251
+ @app.get("/result/{job_id}")
252
+ def job_result(
253
+ job_id: str,
254
+ background_tasks: BackgroundTasks = BackgroundTasks(),
255
+ x_api_key: Optional[str] = Header(default=None),
256
+ ):
257
+ _require_api_key(x_api_key)
258
+ job = _get_job(job_id)
259
+ if not job:
260
+ raise HTTPException(status_code=404, detail="Job not found")
261
+ status = job.get("status")
262
+ if status != "completed":
263
+ raise HTTPException(status_code=409, detail=f"Job not ready (status={status})")
264
+
265
+ output_file = job.get("output_file")
266
+ if not output_file or not Path(output_file).exists():
267
+ _pop_job(job_id)
268
+ raise HTTPException(status_code=410, detail="Result expired or missing")
269
+
270
+ # Remove job from memory and cleanup output after sending
271
+ _pop_job(job_id)
272
+ background_tasks.add_task(_cleanup_files, output_file)
273
+
274
+ return FileResponse(output_file, media_type="audio/wav", filename="output.wav")
275
 
276
 
277
  @app.get("/")
278
  def root():
279
+ return {
280
+ "name": "indextts2-api",
281
+ "endpoints": ["/health", "/generate", "/status/{job_id}", "/result/{job_id}"],
282
+ }