ataberkkilavuzcu committed
Commit 48b31ff · verified · 1 Parent(s): 16e9700

Update app.py

Files changed (1)
  1. app.py +11 -48
app.py CHANGED
@@ -83,49 +83,13 @@ try:
     tts_model = IndexTTS2(
         cfg_path=cfg_path,
         model_dir=MODEL_DIR,
-        use_fp16=USE_GPU,  # Enable FP16 on GPU for ~30-40% speedup
-        use_cuda_kernel=USE_GPU,  # Enable CUDA kernels on GPU
+        use_fp16=False,  # Keep disabled - can cause issues
+        use_cuda_kernel=False,  # Keep disabled for stability
         use_deepspeed=False,  # Keep disabled for stability
     )
 
     load_time = time.time() - load_start
     print(f"IndexTTS2 model loaded successfully in {load_time:.2f}s on {DEVICE}")
-
-    # Warmup inference to initialize all model components
-    # This moves the initialization cost from first request to startup
-    print("Running warmup inference...")
-    warmup_start = time.time()
-    try:
-        # Create a minimal warmup audio file
-        warmup_audio_path = os.path.join(tempfile.gettempdir(), "warmup.wav")
-        warmup_output_path = os.path.join(tempfile.gettempdir(), "warmup_out.wav")
-
-        # Generate a short sine wave for warmup (1 second at 24kHz)
-        sample_rate = 24000
-        duration = 1.0
-        t = torch.linspace(0, duration, int(sample_rate * duration))
-        warmup_wav = (0.5 * torch.sin(2 * 3.14159 * 440 * t)).unsqueeze(0)
-        torchaudio.save(warmup_audio_path, warmup_wav, sample_rate)
-
-        # Run minimal inference with inference_mode for speed
-        with torch.inference_mode():
-            tts_model.infer(
-                spk_audio_prompt=warmup_audio_path,
-                text="Hello.",
-                output_path=warmup_output_path,
-                use_random=False,
-                verbose=False,
-            )
-
-        # Cleanup warmup files
-        Path(warmup_audio_path).unlink(missing_ok=True)
-        Path(warmup_output_path).unlink(missing_ok=True)
-
-        warmup_time = time.time() - warmup_start
-        print(f"Warmup complete in {warmup_time:.2f}s - model is ready!")
-    except Exception as warmup_exc:
-        print(f"Warmup failed (non-fatal): {warmup_exc}")
-        # Continue anyway - first request will just be slower
 
 except Exception as exc:
     raise RuntimeError(f"Failed to load IndexTTS2 model: {exc}") from exc
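
For reference, the USE_GPU and DEVICE globals that the removed flags keyed off are not shown in this diff. A minimal sketch of how such values are typically derived in plain PyTorch (an assumption, not necessarily app.py's exact code):

import torch

# Assumed derivation of the USE_GPU / DEVICE globals referenced above.
USE_GPU = torch.cuda.is_available()
DEVICE = "cuda" if USE_GPU else "cpu"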
@@ -253,7 +217,7 @@ def _cleanup_files(*files: str):
 
 
 def _run_generate_job(job_id: str, payload: Dict[str, str]):
-    """Background job for TTS generation with optimizations."""
+    """Background job for TTS generation."""
     speaker_file = None
     output_file = None
     _set_job(job_id, status="processing")
@@ -271,16 +235,15 @@ def _run_generate_job(job_id: str, payload: Dict[str, str]):
             f"indextts2-{uuid.uuid4()}.wav"
         )
 
-        # Run inference with torch.inference_mode() for faster execution
+        # Run inference (no wrapper - let the model handle its own optimizations)
         infer_start = time.time()
-        with torch.inference_mode():
-            tts_model.infer(
-                spk_audio_prompt=speaker_file,
-                text=payload["text"],
-                output_path=output_file,
-                use_random=False,
-                verbose=True,  # Keep verbose for timing info
-            )
+        tts_model.infer(
+            spk_audio_prompt=speaker_file,
+            text=payload["text"],
+            output_path=output_file,
+            use_random=False,
+            verbose=True,  # Keep verbose for timing info
+        )
         infer_time = time.time() - infer_start
 
         # Post-process output
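
After this commit, generation reduces to plain construction plus a bare infer() call. Below is a minimal standalone sketch of that pattern; the import path, MODEL_DIR value, config filename, and audio file names are assumptions, and only the keyword arguments come from the diff above:

import os
import time

# Import path assumed from the index-tts project layout; adjust to your install.
from indextts.infer_v2 import IndexTTS2

MODEL_DIR = "checkpoints"                          # assumed model directory
cfg_path = os.path.join(MODEL_DIR, "config.yaml")  # assumed config filename

# Mirror the post-commit settings: all acceleration flags disabled.
tts_model = IndexTTS2(
    cfg_path=cfg_path,
    model_dir=MODEL_DIR,
    use_fp16=False,
    use_cuda_kernel=False,
    use_deepspeed=False,
)

# Clone the reference speaker's voice onto new text, timing the call
# the same way the job handler in app.py does.
infer_start = time.time()
tts_model.infer(
    spk_audio_prompt="speaker_reference.wav",  # assumed path to a short reference clip
    text="Hello from IndexTTS2.",
    output_path="output.wav",
    use_random=False,
    verbose=True,
)
print(f"Inference took {time.time() - infer_start:.2f}s")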
 