ataberkkilavuzcu commited on
Commit
fd98daf
·
verified ·
1 Parent(s): 48b31ff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -43
app.py CHANGED
@@ -5,7 +5,6 @@ import uuid
5
  from pathlib import Path
6
  from threading import Lock
7
  from typing import Dict, Optional
8
- import time
9
 
10
  import requests
11
  import torch
@@ -29,16 +28,11 @@ MODEL_DIR = os.getenv("MODEL_DIR", "/data/indextts2")
29
  MAX_TEXT_LENGTH = 1000
30
  DEFAULT_LANGUAGE = "en"
31
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
32
- USE_GPU = DEVICE == "cuda"
33
 
34
  # Job management
35
  JOBS: Dict[str, Dict[str, str]] = {}
36
  JOB_LOCK = Lock()
37
 
38
- # Connection pooling for faster URL downloads
39
- HTTP_SESSION = requests.Session()
40
- HTTP_SESSION.headers.update({"User-Agent": "IndexTTS2-API/1.0"})
41
-
42
  # Set token in environment before importing
43
  if HF_TOKEN:
44
  os.environ["HUGGING_FACE_HUB_TOKEN"] = HF_TOKEN
@@ -51,6 +45,7 @@ if HF_TOKEN:
51
 
52
  # Download model checkpoints from Hugging Face
53
  os.makedirs(MODEL_DIR, exist_ok=True)
 
54
  try:
55
  from huggingface_hub import snapshot_download
56
 
@@ -67,7 +62,7 @@ except Exception as exc:
67
  print(f"Warning: Could not download model: {exc}")
68
  # Continue anyway - model might already be present
69
 
70
- # Initialize IndexTTS2 with GPU optimizations if available
71
  try:
72
  from indextts.infer_v2 import IndexTTS2
73
 
@@ -77,20 +72,14 @@ try:
77
  f"Config file not found at {cfg_path}. Model may not be downloaded."
78
  )
79
 
80
- print(f"Loading IndexTTS2 model on {DEVICE}...")
81
- load_start = time.time()
82
-
83
  tts_model = IndexTTS2(
84
  cfg_path=cfg_path,
85
  model_dir=MODEL_DIR,
86
- use_fp16=False, # Keep disabled - can cause issues
87
- use_cuda_kernel=False, # Keep disabled for stability
88
- use_deepspeed=False, # Keep disabled for stability
89
  )
90
-
91
- load_time = time.time() - load_start
92
- print(f"IndexTTS2 model loaded successfully in {load_time:.2f}s on {DEVICE}")
93
-
94
  except Exception as exc:
95
  raise RuntimeError(f"Failed to load IndexTTS2 model: {exc}") from exc
96
 
@@ -113,8 +102,8 @@ def _require_api_key(x_api_key: Optional[str]):
113
 
114
 
115
  def _write_temp_audio_from_url(url: HttpUrl) -> str:
116
- """Download audio from URL to temporary file using connection pooling."""
117
- response = HTTP_SESSION.get(str(url), stream=True, timeout=30)
118
  if response.status_code >= 400:
119
  raise HTTPException(
120
  status_code=400,
@@ -161,8 +150,6 @@ def _preprocess_audio_wav(
161
  - convert to mono
162
  - resample to target_sr
163
  - peak-normalize to target_peak (avoid clipping)
164
-
165
- Optimized to minimize disk I/O.
166
  """
167
  wav, sr = torchaudio.load(path)
168
 
@@ -223,30 +210,22 @@ def _run_generate_job(job_id: str, payload: Dict[str, str]):
223
  _set_job(job_id, status="processing")
224
 
225
  try:
226
- start_time = time.time()
227
-
228
- # Download/decode speaker audio
229
  speaker_file = _temp_speaker_file(payload["speaker_wav"])
230
  speaker_file = _preprocess_audio_wav(speaker_file)
231
- prep_time = time.time() - start_time
232
 
233
  output_file = os.path.join(
234
  tempfile.gettempdir(),
235
  f"indextts2-{uuid.uuid4()}.wav"
236
  )
237
 
238
- # Run inference (no wrapper - let the model handle its own optimizations)
239
- infer_start = time.time()
240
  tts_model.infer(
241
  spk_audio_prompt=speaker_file,
242
  text=payload["text"],
243
  output_path=output_file,
244
  use_random=False,
245
- verbose=True, # Keep verbose for timing info
246
  )
247
- infer_time = time.time() - infer_start
248
 
249
- # Post-process output
250
  output_file = _preprocess_audio_wav(output_file)
251
 
252
  if not Path(output_file).exists():
@@ -254,13 +233,9 @@ def _run_generate_job(job_id: str, payload: Dict[str, str]):
254
  f"TTS generation failed: output file was not created at {output_file}"
255
  )
256
 
257
- total_time = time.time() - start_time
258
- print(f">> Job {job_id[:8]} completed: prep={prep_time:.2f}s, infer={infer_time:.2f}s, total={total_time:.2f}s")
259
-
260
  _cleanup_files(speaker_file)
261
  _set_job(job_id, status="completed", output_file=output_file)
262
  except Exception as exc:
263
- print(f">> Job {job_id[:8]} failed: {exc}")
264
  _cleanup_files(speaker_file, output_file)
265
  _set_job(job_id, status="error", error=str(exc))
266
 
@@ -269,13 +244,7 @@ def _run_generate_job(job_id: str, payload: Dict[str, str]):
269
  def health(x_api_key: Optional[str] = Header(default=None)):
270
  """Health check endpoint."""
271
  _require_api_key(x_api_key)
272
- return {
273
- "status": "ok",
274
- "model": "indextts2",
275
- "device": DEVICE,
276
- "gpu_enabled": USE_GPU,
277
- "fp16_enabled": USE_GPU,
278
- }
279
 
280
 
281
  @app.post("/generate")
@@ -368,8 +337,6 @@ def root():
368
  """API root with available endpoints."""
369
  return {
370
  "name": "indextts2-api",
371
- "device": DEVICE,
372
- "gpu_enabled": USE_GPU,
373
  "endpoints": [
374
  "/health",
375
  "/generate",
 
5
  from pathlib import Path
6
  from threading import Lock
7
  from typing import Dict, Optional
 
8
 
9
  import requests
10
  import torch
 
28
  MAX_TEXT_LENGTH = 1000
29
  DEFAULT_LANGUAGE = "en"
30
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
31
 
32
  # Job management
33
  JOBS: Dict[str, Dict[str, str]] = {}
34
  JOB_LOCK = Lock()
35
 
 
 
 
 
36
  # Set token in environment before importing
37
  if HF_TOKEN:
38
  os.environ["HUGGING_FACE_HUB_TOKEN"] = HF_TOKEN
 
45
 
46
  # Download model checkpoints from Hugging Face
47
  os.makedirs(MODEL_DIR, exist_ok=True)
48
+
49
  try:
50
  from huggingface_hub import snapshot_download
51
 
 
62
  print(f"Warning: Could not download model: {exc}")
63
  # Continue anyway - model might already be present
64
 
65
+ # Initialize IndexTTS2
66
  try:
67
  from indextts.infer_v2 import IndexTTS2
68
 
 
72
  f"Config file not found at {cfg_path}. Model may not be downloaded."
73
  )
74
 
 
 
 
75
  tts_model = IndexTTS2(
76
  cfg_path=cfg_path,
77
  model_dir=MODEL_DIR,
78
+ use_fp16=False, # CPU doesn't support FP16
79
+ use_cuda_kernel=False, # CPU mode
80
+ use_deepspeed=False, # CPU mode
81
  )
82
+ print("IndexTTS2 model loaded successfully.")
 
 
 
83
  except Exception as exc:
84
  raise RuntimeError(f"Failed to load IndexTTS2 model: {exc}") from exc
85
 
 
102
 
103
 
104
  def _write_temp_audio_from_url(url: HttpUrl) -> str:
105
+ """Download audio from URL to temporary file."""
106
+ response = requests.get(url, stream=True, timeout=30)
107
  if response.status_code >= 400:
108
  raise HTTPException(
109
  status_code=400,
 
150
  - convert to mono
151
  - resample to target_sr
152
  - peak-normalize to target_peak (avoid clipping)
 
 
153
  """
154
  wav, sr = torchaudio.load(path)
155
 
 
210
  _set_job(job_id, status="processing")
211
 
212
  try:
 
 
 
213
  speaker_file = _temp_speaker_file(payload["speaker_wav"])
214
  speaker_file = _preprocess_audio_wav(speaker_file)
 
215
 
216
  output_file = os.path.join(
217
  tempfile.gettempdir(),
218
  f"indextts2-{uuid.uuid4()}.wav"
219
  )
220
 
 
 
221
  tts_model.infer(
222
  spk_audio_prompt=speaker_file,
223
  text=payload["text"],
224
  output_path=output_file,
225
  use_random=False,
226
+ verbose=False,
227
  )
 
228
 
 
229
  output_file = _preprocess_audio_wav(output_file)
230
 
231
  if not Path(output_file).exists():
 
233
  f"TTS generation failed: output file was not created at {output_file}"
234
  )
235
 
 
 
 
236
  _cleanup_files(speaker_file)
237
  _set_job(job_id, status="completed", output_file=output_file)
238
  except Exception as exc:
 
239
  _cleanup_files(speaker_file, output_file)
240
  _set_job(job_id, status="error", error=str(exc))
241
 
 
244
  def health(x_api_key: Optional[str] = Header(default=None)):
245
  """Health check endpoint."""
246
  _require_api_key(x_api_key)
247
+ return {"status": "ok", "model": "indextts2", "device": DEVICE}
 
 
 
 
 
 
248
 
249
 
250
  @app.post("/generate")
 
337
  """API root with available endpoints."""
338
  return {
339
  "name": "indextts2-api",
 
 
340
  "endpoints": [
341
  "/health",
342
  "/generate",