ataberkkilavuzcu committed on
Commit
b52eab0
·
verified ·
1 Parent(s): 9cf6a70

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -30
app.py CHANGED
@@ -1,9 +1,7 @@
1
  import base64
2
  import os
3
- import sys
4
  import tempfile
5
  import uuid
6
- from io import StringIO
7
  from pathlib import Path
8
  from typing import Optional
9
 
@@ -21,42 +19,62 @@ HF_TOKEN = (
21
  or os.getenv("HUGGINGFACEHUB_API_TOKEN")
22
  or os.getenv("HF_TOKEN")
23
  )
24
- MODEL_NAME = "tts_models/multilingual/multi-dataset/xtts_v2"
25
  MAX_TEXT_LENGTH = 1000
26
  DEFAULT_LANGUAGE = "en"
27
 
28
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
29
 
30
- # Set token in environment before importing TTS
31
  if HF_TOKEN:
32
  os.environ["HUGGING_FACE_HUB_TOKEN"] = HF_TOKEN
33
  os.environ["HF_TOKEN"] = HF_TOKEN
34
- # Also login explicitly via huggingface_hub
35
  try:
36
  from huggingface_hub import login
37
  login(token=HF_TOKEN, add_to_git_credential=False)
38
  except ImportError:
39
- pass # huggingface_hub might not be installed, that's okay
40
 
41
- # Mock stdin to automatically accept TTS Terms of Service
42
- # This prevents the interactive prompt that causes EOFError in containers
43
- _original_stdin = sys.stdin
44
- sys.stdin = StringIO("y\n") # Auto-accept TOS
45
-
46
- from TTS.api import TTS
47
 
48
  try:
49
- tts_model = TTS(MODEL_NAME, gpu=DEVICE == "cuda", progress_bar=False)
50
- except Exception as exc: # pragma: no cover
51
- hint = ""
52
- if "EOF when reading a line" in str(exc):
53
- hint = " Hint: set HUGGING_FACE_HUB_TOKEN to a Hugging Face token that has accepted the XTTS v2 license."
54
- raise RuntimeError(f"Failed to load XTTS v2 model: {exc}.{hint}") from exc
55
- finally:
56
- # Restore stdin after model loading (TOS check happens during model load)
57
- sys.stdin = _original_stdin
 
 
 
 
 
58
 
59
- app = FastAPI(title="xtts-v2-api", version="1.0.0")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
 
62
  class GenerateRequest(BaseModel):
@@ -133,7 +151,7 @@ def _preprocess_audio_wav(path: str, target_sr: int = 24000, target_peak: float
133
  @app.post("/health")
134
  def health(x_api_key: Optional[str] = Header(default=None)):
135
  _require_api_key(x_api_key)
136
- return {"status": "ok", "model": "xtts_v2", "device": DEVICE}
137
 
138
 
139
  def _cleanup_files(*files: str):
@@ -160,14 +178,17 @@ def generate(
160
  try:
161
  speaker_file = _temp_speaker_file(payload.speaker_wav)
162
  speaker_file = _preprocess_audio_wav(speaker_file)
163
- output_file = os.path.join(tempfile.gettempdir(), f"xtts-{uuid.uuid4()}.wav")
164
 
165
- tts_model.tts_to_file(
 
 
 
 
166
  text=payload.text,
167
- file_path=output_file,
168
- speaker_wav=speaker_file,
169
- language=payload.language or DEFAULT_LANGUAGE,
170
- split_sentences=True,
171
  )
172
 
173
  # Light post-process to avoid end-of-file artifacts
@@ -198,4 +219,4 @@ def generate(
198
 
199
  @app.get("/")
200
  def root():
201
- return {"name": "xtts-v2-api", "endpoints": ["/health", "/generate"]}
 
1
  import base64
2
  import os
 
3
  import tempfile
4
  import uuid
 
5
  from pathlib import Path
6
  from typing import Optional
7
 
 
19
  or os.getenv("HUGGINGFACEHUB_API_TOKEN")
20
  or os.getenv("HF_TOKEN")
21
  )
22
+ MODEL_REPO = "IndexTeam/IndexTTS-2"
23
  MAX_TEXT_LENGTH = 1000
24
  DEFAULT_LANGUAGE = "en"
25
 
26
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
27
 
28
+ # Set token in environment before importing
29
  if HF_TOKEN:
30
  os.environ["HUGGING_FACE_HUB_TOKEN"] = HF_TOKEN
31
  os.environ["HF_TOKEN"] = HF_TOKEN
 
32
  try:
33
  from huggingface_hub import login
34
  login(token=HF_TOKEN, add_to_git_credential=False)
35
  except ImportError:
36
+ pass
37
 
38
+ # Download model checkpoints from Hugging Face
39
+ MODEL_DIR = os.getenv("MODEL_DIR", "/data/indextts2")
40
+ os.makedirs(MODEL_DIR, exist_ok=True)
 
 
 
41
 
42
  try:
43
+ from huggingface_hub import snapshot_download
44
+
45
+ # Download model if not already present
46
+ if not Path(MODEL_DIR, "config.yaml").exists():
47
+ print(f"Downloading IndexTTS2 model from {MODEL_REPO}...")
48
+ snapshot_download(
49
+ repo_id=MODEL_REPO,
50
+ local_dir=MODEL_DIR,
51
+ token=HF_TOKEN,
52
+ )
53
+ print("Model download complete.")
54
+ except Exception as exc:
55
+ print(f"Warning: Could not download model: {exc}")
56
+ # Continue anyway - model might already be present
57
 
58
+ # Initialize IndexTTS2
59
+ try:
60
+ from indextts.infer_v2 import IndexTTS2
61
+
62
+ cfg_path = os.path.join(MODEL_DIR, "config.yaml")
63
+ if not Path(cfg_path).exists():
64
+ raise FileNotFoundError(f"Config file not found at {cfg_path}. Model may not be downloaded.")
65
+
66
+ tts_model = IndexTTS2(
67
+ cfg_path=cfg_path,
68
+ model_dir=MODEL_DIR,
69
+ use_fp16=False, # CPU doesn't support FP16
70
+ use_cuda_kernel=False, # CPU mode
71
+ use_deepspeed=False, # CPU mode
72
+ )
73
+ print("IndexTTS2 model loaded successfully.")
74
+ except Exception as exc:
75
+ raise RuntimeError(f"Failed to load IndexTTS2 model: {exc}") from exc
76
+
77
+ app = FastAPI(title="indextts2-api", version="1.0.0")
78
 
79
 
80
  class GenerateRequest(BaseModel):
 
151
  @app.post("/health")
152
  def health(x_api_key: Optional[str] = Header(default=None)):
153
  _require_api_key(x_api_key)
154
+ return {"status": "ok", "model": "indextts2", "device": DEVICE}
155
 
156
 
157
  def _cleanup_files(*files: str):
 
178
  try:
179
  speaker_file = _temp_speaker_file(payload.speaker_wav)
180
  speaker_file = _preprocess_audio_wav(speaker_file)
181
+ output_file = os.path.join(tempfile.gettempdir(), f"indextts2-{uuid.uuid4()}.wav")
182
 
183
+ # IndexTTS2 inference
184
+ # Note: language parameter is kept for API compatibility but IndexTTS2
185
+ # handles multilingual automatically (supports English, Turkish, Chinese, etc.)
186
+ tts_model.infer(
187
+ spk_audio_prompt=speaker_file,
188
  text=payload.text,
189
+ output_path=output_file,
190
+ use_random=False, # Deterministic output
191
+ verbose=False,
 
192
  )
193
 
194
  # Light post-process to avoid end-of-file artifacts
 
219
 
220
  @app.get("/")
221
  def root():
222
+ return {"name": "indextts2-api", "endpoints": ["/health", "/generate"]}