ataberkkilavuzcu committed on
Commit
16e9700
·
verified ·
1 Parent(s): 8d3e172

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -17
app.py CHANGED
@@ -5,6 +5,7 @@ import uuid
5
  from pathlib import Path
6
  from threading import Lock
7
  from typing import Dict, Optional
 
8
 
9
  import requests
10
  import torch
@@ -28,11 +29,16 @@ MODEL_DIR = os.getenv("MODEL_DIR", "/data/indextts2")
28
  MAX_TEXT_LENGTH = 1000
29
  DEFAULT_LANGUAGE = "en"
30
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
31
 
32
  # Job management
33
  JOBS: Dict[str, Dict[str, str]] = {}
34
  JOB_LOCK = Lock()
35
 
 
 
 
 
36
  # Set token in environment before importing
37
  if HF_TOKEN:
38
  os.environ["HUGGING_FACE_HUB_TOKEN"] = HF_TOKEN
@@ -45,7 +51,6 @@ if HF_TOKEN:
45
 
46
  # Download model checkpoints from Hugging Face
47
  os.makedirs(MODEL_DIR, exist_ok=True)
48
-
49
  try:
50
  from huggingface_hub import snapshot_download
51
 
@@ -62,7 +67,7 @@ except Exception as exc:
62
  print(f"Warning: Could not download model: {exc}")
63
  # Continue anyway - model might already be present
64
 
65
- # Initialize IndexTTS2
66
  try:
67
  from indextts.infer_v2 import IndexTTS2
68
 
@@ -72,14 +77,56 @@ try:
72
  f"Config file not found at {cfg_path}. Model may not be downloaded."
73
  )
74
 
 
 
 
75
  tts_model = IndexTTS2(
76
  cfg_path=cfg_path,
77
  model_dir=MODEL_DIR,
78
- use_fp16=False, # CPU doesn't support FP16
79
- use_cuda_kernel=False, # CPU mode
80
- use_deepspeed=False, # CPU mode
81
  )
82
- print("IndexTTS2 model loaded successfully.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  except Exception as exc:
84
  raise RuntimeError(f"Failed to load IndexTTS2 model: {exc}") from exc
85
 
@@ -102,8 +149,8 @@ def _require_api_key(x_api_key: Optional[str]):
102
 
103
 
104
  def _write_temp_audio_from_url(url: HttpUrl) -> str:
105
- """Download audio from URL to temporary file."""
106
- response = requests.get(url, stream=True, timeout=30)
107
  if response.status_code >= 400:
108
  raise HTTPException(
109
  status_code=400,
@@ -150,6 +197,8 @@ def _preprocess_audio_wav(
150
  - convert to mono
151
  - resample to target_sr
152
  - peak-normalize to target_peak (avoid clipping)
 
 
153
  """
154
  wav, sr = torchaudio.load(path)
155
 
@@ -204,28 +253,37 @@ def _cleanup_files(*files: str):
204
 
205
 
206
  def _run_generate_job(job_id: str, payload: Dict[str, str]):
207
- """Background job for TTS generation."""
208
  speaker_file = None
209
  output_file = None
210
  _set_job(job_id, status="processing")
211
 
212
  try:
 
 
 
213
  speaker_file = _temp_speaker_file(payload["speaker_wav"])
214
  speaker_file = _preprocess_audio_wav(speaker_file)
 
215
 
216
  output_file = os.path.join(
217
  tempfile.gettempdir(),
218
  f"indextts2-{uuid.uuid4()}.wav"
219
  )
220
 
221
- tts_model.infer(
222
- spk_audio_prompt=speaker_file,
223
- text=payload["text"],
224
- output_path=output_file,
225
- use_random=False,
226
- verbose=False,
227
- )
 
 
 
 
228
 
 
229
  output_file = _preprocess_audio_wav(output_file)
230
 
231
  if not Path(output_file).exists():
@@ -233,9 +291,13 @@ def _run_generate_job(job_id: str, payload: Dict[str, str]):
233
  f"TTS generation failed: output file was not created at {output_file}"
234
  )
235
 
 
 
 
236
  _cleanup_files(speaker_file)
237
  _set_job(job_id, status="completed", output_file=output_file)
238
  except Exception as exc:
 
239
  _cleanup_files(speaker_file, output_file)
240
  _set_job(job_id, status="error", error=str(exc))
241
 
@@ -244,7 +306,13 @@ def _run_generate_job(job_id: str, payload: Dict[str, str]):
244
  def health(x_api_key: Optional[str] = Header(default=None)):
245
  """Health check endpoint."""
246
  _require_api_key(x_api_key)
247
- return {"status": "ok", "model": "indextts2", "device": DEVICE}
 
 
 
 
 
 
248
 
249
 
250
  @app.post("/generate")
@@ -337,6 +405,8 @@ def root():
337
  """API root with available endpoints."""
338
  return {
339
  "name": "indextts2-api",
 
 
340
  "endpoints": [
341
  "/health",
342
  "/generate",
 
5
  from pathlib import Path
6
  from threading import Lock
7
  from typing import Dict, Optional
8
+ import time
9
 
10
  import requests
11
  import torch
 
29
  MAX_TEXT_LENGTH = 1000
30
  DEFAULT_LANGUAGE = "en"
31
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
32
+ USE_GPU = DEVICE == "cuda"
33
 
34
  # Job management
35
  JOBS: Dict[str, Dict[str, str]] = {}
36
  JOB_LOCK = Lock()
37
 
38
+ # Connection pooling for faster URL downloads
39
+ HTTP_SESSION = requests.Session()
40
+ HTTP_SESSION.headers.update({"User-Agent": "IndexTTS2-API/1.0"})
41
+
42
  # Set token in environment before importing
43
  if HF_TOKEN:
44
  os.environ["HUGGING_FACE_HUB_TOKEN"] = HF_TOKEN
 
51
 
52
  # Download model checkpoints from Hugging Face
53
  os.makedirs(MODEL_DIR, exist_ok=True)
 
54
  try:
55
  from huggingface_hub import snapshot_download
56
 
 
67
  print(f"Warning: Could not download model: {exc}")
68
  # Continue anyway - model might already be present
69
 
70
+ # Initialize IndexTTS2 with GPU optimizations if available
71
  try:
72
  from indextts.infer_v2 import IndexTTS2
73
 
 
77
  f"Config file not found at {cfg_path}. Model may not be downloaded."
78
  )
79
 
80
+ print(f"Loading IndexTTS2 model on {DEVICE}...")
81
+ load_start = time.time()
82
+
83
  tts_model = IndexTTS2(
84
  cfg_path=cfg_path,
85
  model_dir=MODEL_DIR,
86
+ use_fp16=USE_GPU, # Enable FP16 on GPU for ~30-40% speedup
87
+ use_cuda_kernel=USE_GPU, # Enable CUDA kernels on GPU
88
+ use_deepspeed=False, # Keep disabled for stability
89
  )
90
+
91
+ load_time = time.time() - load_start
92
+ print(f"IndexTTS2 model loaded successfully in {load_time:.2f}s on {DEVICE}")
93
+
94
+ # Warmup inference to initialize all model components
95
+ # This moves the initialization cost from first request to startup
96
+ print("Running warmup inference...")
97
+ warmup_start = time.time()
98
+ try:
99
+ # Create a minimal warmup audio file
100
+ warmup_audio_path = os.path.join(tempfile.gettempdir(), "warmup.wav")
101
+ warmup_output_path = os.path.join(tempfile.gettempdir(), "warmup_out.wav")
102
+
103
+ # Generate a short sine wave for warmup (1 second at 24kHz)
104
+ sample_rate = 24000
105
+ duration = 1.0
106
+ t = torch.linspace(0, duration, int(sample_rate * duration))
107
+ warmup_wav = (0.5 * torch.sin(2 * 3.14159 * 440 * t)).unsqueeze(0)
108
+ torchaudio.save(warmup_audio_path, warmup_wav, sample_rate)
109
+
110
+ # Run minimal inference with inference_mode for speed
111
+ with torch.inference_mode():
112
+ tts_model.infer(
113
+ spk_audio_prompt=warmup_audio_path,
114
+ text="Hello.",
115
+ output_path=warmup_output_path,
116
+ use_random=False,
117
+ verbose=False,
118
+ )
119
+
120
+ # Cleanup warmup files
121
+ Path(warmup_audio_path).unlink(missing_ok=True)
122
+ Path(warmup_output_path).unlink(missing_ok=True)
123
+
124
+ warmup_time = time.time() - warmup_start
125
+ print(f"Warmup complete in {warmup_time:.2f}s - model is ready!")
126
+ except Exception as warmup_exc:
127
+ print(f"Warmup failed (non-fatal): {warmup_exc}")
128
+ # Continue anyway - first request will just be slower
129
+
130
  except Exception as exc:
131
  raise RuntimeError(f"Failed to load IndexTTS2 model: {exc}") from exc
132
 
 
149
 
150
 
151
  def _write_temp_audio_from_url(url: HttpUrl) -> str:
152
+ """Download audio from URL to temporary file using connection pooling."""
153
+ response = HTTP_SESSION.get(str(url), stream=True, timeout=30)
154
  if response.status_code >= 400:
155
  raise HTTPException(
156
  status_code=400,
 
197
  - convert to mono
198
  - resample to target_sr
199
  - peak-normalize to target_peak (avoid clipping)
200
+
201
+ Optimized to minimize disk I/O.
202
  """
203
  wav, sr = torchaudio.load(path)
204
 
 
253
 
254
 
255
  def _run_generate_job(job_id: str, payload: Dict[str, str]):
256
+ """Background job for TTS generation with optimizations."""
257
  speaker_file = None
258
  output_file = None
259
  _set_job(job_id, status="processing")
260
 
261
  try:
262
+ start_time = time.time()
263
+
264
+ # Download/decode speaker audio
265
  speaker_file = _temp_speaker_file(payload["speaker_wav"])
266
  speaker_file = _preprocess_audio_wav(speaker_file)
267
+ prep_time = time.time() - start_time
268
 
269
  output_file = os.path.join(
270
  tempfile.gettempdir(),
271
  f"indextts2-{uuid.uuid4()}.wav"
272
  )
273
 
274
+ # Run inference with torch.inference_mode() for faster execution
275
+ infer_start = time.time()
276
+ with torch.inference_mode():
277
+ tts_model.infer(
278
+ spk_audio_prompt=speaker_file,
279
+ text=payload["text"],
280
+ output_path=output_file,
281
+ use_random=False,
282
+ verbose=True, # Keep verbose for timing info
283
+ )
284
+ infer_time = time.time() - infer_start
285
 
286
+ # Post-process output
287
  output_file = _preprocess_audio_wav(output_file)
288
 
289
  if not Path(output_file).exists():
 
291
  f"TTS generation failed: output file was not created at {output_file}"
292
  )
293
 
294
+ total_time = time.time() - start_time
295
+ print(f">> Job {job_id[:8]} completed: prep={prep_time:.2f}s, infer={infer_time:.2f}s, total={total_time:.2f}s")
296
+
297
  _cleanup_files(speaker_file)
298
  _set_job(job_id, status="completed", output_file=output_file)
299
  except Exception as exc:
300
+ print(f">> Job {job_id[:8]} failed: {exc}")
301
  _cleanup_files(speaker_file, output_file)
302
  _set_job(job_id, status="error", error=str(exc))
303
 
 
306
  def health(x_api_key: Optional[str] = Header(default=None)):
307
  """Health check endpoint."""
308
  _require_api_key(x_api_key)
309
+ return {
310
+ "status": "ok",
311
+ "model": "indextts2",
312
+ "device": DEVICE,
313
+ "gpu_enabled": USE_GPU,
314
+ "fp16_enabled": USE_GPU,
315
+ }
316
 
317
 
318
  @app.post("/generate")
 
405
  """API root with available endpoints."""
406
  return {
407
  "name": "indextts2-api",
408
+ "device": DEVICE,
409
+ "gpu_enabled": USE_GPU,
410
  "endpoints": [
411
  "/health",
412
  "/generate",