ataberkkilavuzcu committed on
Commit
68b44a1
·
verified ·
1 Parent(s): 64b50bd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +195 -74
app.py CHANGED
@@ -1,75 +1,100 @@
 
1
  import os
2
- import uuid
3
  import tempfile
4
- import base64
5
  from pathlib import Path
6
  from threading import Lock
7
- from typing import Optional, Dict
8
 
9
  import requests
10
  import torch
11
  import torchaudio
12
  from torchaudio.transforms import Resample
13
- from fastapi import FastAPI, Body, Header, HTTPException, BackgroundTasks
14
  from fastapi.responses import FileResponse, JSONResponse
15
  from pydantic import BaseModel, Field, HttpUrl
16
 
17
- # ========== Configuration ==========
18
-
19
  SPACE_API_KEY = os.getenv("SPACE_API_KEY")
20
  HF_TOKEN = (
21
  os.getenv("HUGGING_FACE_HUB_TOKEN")
22
  or os.getenv("HUGGINGFACEHUB_API_TOKEN")
23
  or os.getenv("HF_TOKEN")
24
  )
 
 
25
  MODEL_REPO = "IndexTeam/IndexTTS-2"
26
  MODEL_DIR = os.getenv("MODEL_DIR", "/data/indextts2")
27
- os.makedirs(MODEL_DIR, exist_ok=True)
28
-
29
- # Max length for input text
30
  MAX_TEXT_LENGTH = 1000
 
 
31
 
32
- # Use 16 kHz sample rate for faster/audio-size tradeoff
33
- TARGET_SR = 16000
34
-
35
- # Limit PyTorch threads on CPU
36
- torch.set_num_threads(1)
37
 
38
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
 
 
 
39
 
40
- # ========== Download / Load Model ==========
 
41
 
42
  try:
43
  from huggingface_hub import snapshot_download
44
- from indextts.infer_v2 import IndexTTS2
45
- except Exception as e:
46
- raise RuntimeError("Required library missing: ensure `huggingface_hub` and `indextts` are installed.") from e
47
-
48
- # Only download if not already present
49
- config_file = Path(MODEL_DIR) / "config.yaml"
50
- if not config_file.exists():
51
- print(f"Downloading model {MODEL_REPO} to {MODEL_DIR} …")
52
- snapshot_download(repo_id=MODEL_REPO, local_dir=MODEL_DIR, token=HF_TOKEN)
53
- print("Download complete.")
54
-
55
- tts_model = IndexTTS2(cfg_path=str(config_file), model_dir=MODEL_DIR, use_fp16=False, use_cuda_kernel=False, use_deepspeed=False)
56
- print("IndexTTS-2 loaded, device:", DEVICE)
57
-
58
- # ========== FastAPI app ==========
59
 
60
- app = FastAPI(title="IndexTTS2 API")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
- JOBS: Dict[str, Dict[str, str]] = {}
63
- JOB_LOCK = Lock()
64
 
65
 
66
  class GenerateRequest(BaseModel):
67
  text: str = Field(..., min_length=1, max_length=MAX_TEXT_LENGTH)
68
  speaker_wav: str = Field(..., description="HTTPS URL or base64-encoded audio")
69
- language: Optional[str] = Field("en", description="Language code")
70
 
71
 
72
  def _require_api_key(x_api_key: Optional[str]):
 
73
  if not SPACE_API_KEY:
74
  return
75
  if x_api_key != SPACE_API_KEY:
@@ -77,99 +102,150 @@ def _require_api_key(x_api_key: Optional[str]):
77
 
78
 
79
  def _write_temp_audio_from_url(url: HttpUrl) -> str:
 
80
  response = requests.get(url, stream=True, timeout=30)
81
  if response.status_code >= 400:
82
- raise HTTPException(status_code=400, detail=f"Could not fetch speaker audio: {response.status_code}")
 
 
 
 
83
  suffix = Path(url.path).suffix or ".wav"
84
  with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
85
  for chunk in response.iter_content(chunk_size=8192):
86
  if chunk:
87
  tmp.write(chunk)
88
- return tmp.name
89
 
90
 
91
  def _write_temp_audio_from_base64(payload: str) -> str:
 
92
  try:
93
  raw = base64.b64decode(payload)
94
  except Exception as exc:
95
- raise HTTPException(status_code=400, detail="Invalid base64 speaker_wav") from exc
 
 
 
 
96
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
97
  tmp.write(raw)
98
- return tmp.name
99
 
100
 
101
  def _temp_speaker_file(speaker_wav: str) -> str:
 
102
  if speaker_wav.startswith("http://") or speaker_wav.startswith("https://"):
103
- return _write_temp_audio_from_url(speaker_wav)
104
  return _write_temp_audio_from_base64(speaker_wav)
105
 
106
 
107
- def _preprocess_audio_wav(path: str, target_sr: int = TARGET_SR, target_peak: float = 0.98) -> str:
 
 
 
 
 
 
 
 
 
 
108
  wav, sr = torchaudio.load(path)
 
 
109
  if wav.shape[0] > 1:
110
  wav = wav.mean(dim=0, keepdim=True)
 
 
111
  if sr != target_sr:
112
  resampler = Resample(orig_freq=sr, new_freq=target_sr)
113
  wav = resampler(wav)
114
  sr = target_sr
 
 
115
  peak = wav.abs().max().item() if wav.numel() else 0.0
116
  if peak > 0:
117
- wav = wav * (target_peak / peak)
 
 
 
118
  torchaudio.save(path, wav, sr, bits_per_sample=16)
119
  return path
120
 
121
 
122
  def _set_job(job_id: str, **kwargs):
 
123
  with JOB_LOCK:
124
  JOBS[job_id] = {**JOBS.get(job_id, {}), **kwargs}
125
 
126
 
127
  def _get_job(job_id: str) -> Optional[Dict[str, str]]:
 
128
  with JOB_LOCK:
129
  data = JOBS.get(job_id)
130
- return dict(data) if data else None
131
 
132
 
133
  def _pop_job(job_id: str) -> Optional[Dict[str, str]]:
 
134
  with JOB_LOCK:
135
  return JOBS.pop(job_id, None)
136
 
137
 
138
- def _cleanup_files(*paths: str):
139
- for p in paths:
140
- try:
141
- os.remove(p)
142
- except OSError:
143
- pass
 
 
144
 
145
 
146
  def _run_generate_job(job_id: str, payload: Dict[str, str]):
 
147
  speaker_file = None
148
  output_file = None
149
  _set_job(job_id, status="processing")
 
150
  try:
151
  speaker_file = _temp_speaker_file(payload["speaker_wav"])
152
- speaker_file = _preprocess_audio_wav(speaker_file, target_sr=TARGET_SR)
153
-
154
- output_file = os.path.join(tempfile.gettempdir(), f"indextts2-{uuid.uuid4()}.wav")
155
-
156
- # FIXED ARGUMENT NAME
 
 
157
  tts_model.infer(
 
158
  text=payload["text"],
159
- speaker_prompt=speaker_file,
160
  output_path=output_file,
 
 
161
  )
162
-
 
 
163
  if not Path(output_file).exists():
164
- raise RuntimeError("TTS generation failed — output file not created.")
165
-
 
 
 
166
  _set_job(job_id, status="completed", output_file=output_file)
167
-
168
  except Exception as exc:
169
- _cleanup_files(speaker_file or "", output_file or "")
170
  _set_job(job_id, status="error", error=str(exc))
171
 
172
 
 
 
 
 
 
 
173
 
174
  @app.post("/generate")
175
  def generate(
@@ -177,12 +253,18 @@ def generate(
177
  background_tasks: BackgroundTasks = BackgroundTasks(),
178
  x_api_key: Optional[str] = Header(default=None),
179
  ):
 
 
 
 
180
  _require_api_key(x_api_key)
 
181
  job_id = str(uuid.uuid4())
182
  _set_job(job_id, status="queued")
183
-
 
184
  background_tasks.add_task(_run_generate_job, job_id, payload.dict())
185
-
186
  return JSONResponse(
187
  status_code=202,
188
  content={
@@ -195,31 +277,70 @@ def generate(
195
 
196
 
197
  @app.get("/status/{job_id}")
198
- def status(job_id: str, x_api_key: Optional[str] = Header(default=None)):
 
199
  _require_api_key(x_api_key)
 
200
  job = _get_job(job_id)
201
  if not job:
202
  raise HTTPException(status_code=404, detail="Job not found")
203
- resp = {"job_id": job_id, "status": job.get("status", "unknown")}
 
 
 
 
 
204
  if "error" in job:
205
- resp["error"] = job["error"]
206
- return resp
 
207
 
208
 
209
  @app.get("/result/{job_id}")
210
- def result(job_id: str, x_api_key: Optional[str] = Header(default=None)):
 
 
 
 
 
211
  _require_api_key(x_api_key)
 
212
  job = _get_job(job_id)
213
  if not job:
214
  raise HTTPException(status_code=404, detail="Job not found")
215
- if job.get("status") != "completed":
216
- raise HTTPException(status_code=409, detail=f"Job not ready (status={job.get('status')})")
 
 
 
 
 
 
217
  output_file = job.get("output_file")
218
  if not output_file or not Path(output_file).exists():
219
  _pop_job(job_id)
220
- raise HTTPException(status_code=410, detail="Result missing or expired")
221
- # cleanup after sending
222
- background = BackgroundTasks()
223
- background.add_task(_cleanup_files, output_file)
224
  _pop_job(job_id)
225
- return FileResponse(output_file, media_type="audio/wav", filename="output.wav", background=background)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
  import os
 
3
  import tempfile
4
+ import uuid
5
  from pathlib import Path
6
  from threading import Lock
7
+ from typing import Dict, Optional
8
 
9
  import requests
10
  import torch
11
  import torchaudio
12
  from torchaudio.transforms import Resample
13
+ from fastapi import BackgroundTasks, Body, FastAPI, Header, HTTPException
14
  from fastapi.responses import FileResponse, JSONResponse
15
  from pydantic import BaseModel, Field, HttpUrl
16
 
17
+ # Environment configuration
 
18
  SPACE_API_KEY = os.getenv("SPACE_API_KEY")
19
  HF_TOKEN = (
20
  os.getenv("HUGGING_FACE_HUB_TOKEN")
21
  or os.getenv("HUGGINGFACEHUB_API_TOKEN")
22
  or os.getenv("HF_TOKEN")
23
  )
24
+
25
+ # Model configuration
26
  MODEL_REPO = "IndexTeam/IndexTTS-2"
27
  MODEL_DIR = os.getenv("MODEL_DIR", "/data/indextts2")
 
 
 
28
  MAX_TEXT_LENGTH = 1000
29
+ DEFAULT_LANGUAGE = "en"
30
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
31
 
32
+ # Job management
33
+ JOBS: Dict[str, Dict[str, str]] = {}
34
+ JOB_LOCK = Lock()
 
 
35
 
36
+ # Set token in environment before importing
37
+ if HF_TOKEN:
38
+ os.environ["HUGGING_FACE_HUB_TOKEN"] = HF_TOKEN
39
+ os.environ["HF_TOKEN"] = HF_TOKEN
40
+ try:
41
+ from huggingface_hub import login
42
+ login(token=HF_TOKEN, add_to_git_credential=False)
43
+ except ImportError:
44
+ pass
45
 
46
+ # Download model checkpoints from Hugging Face
47
+ os.makedirs(MODEL_DIR, exist_ok=True)
48
 
49
  try:
50
  from huggingface_hub import snapshot_download
51
+
52
+ # Download model if not already present
53
+ if not Path(MODEL_DIR, "config.yaml").exists():
54
+ print(f"Downloading IndexTTS2 model from {MODEL_REPO}...")
55
+ snapshot_download(
56
+ repo_id=MODEL_REPO,
57
+ local_dir=MODEL_DIR,
58
+ token=HF_TOKEN,
59
+ )
60
+ print("Model download complete.")
61
+ except Exception as exc:
62
+ print(f"Warning: Could not download model: {exc}")
63
+ # Continue anyway - model might already be present
 
 
64
 
65
+ # Initialize IndexTTS2
66
+ try:
67
+ from indextts.infer_v2 import IndexTTS2
68
+
69
+ cfg_path = os.path.join(MODEL_DIR, "config.yaml")
70
+ if not Path(cfg_path).exists():
71
+ raise FileNotFoundError(
72
+ f"Config file not found at {cfg_path}. Model may not be downloaded."
73
+ )
74
+
75
+ tts_model = IndexTTS2(
76
+ cfg_path=cfg_path,
77
+ model_dir=MODEL_DIR,
78
+ use_fp16=False, # CPU doesn't support FP16
79
+ use_cuda_kernel=False, # CPU mode
80
+ use_deepspeed=False, # CPU mode
81
+ )
82
+ print("IndexTTS2 model loaded successfully.")
83
+ except Exception as exc:
84
+ raise RuntimeError(f"Failed to load IndexTTS2 model: {exc}") from exc
85
 
86
+ # Initialize FastAPI app
87
+ app = FastAPI(title="indextts2-api", version="1.0.0")
88
 
89
 
90
  class GenerateRequest(BaseModel):
91
  text: str = Field(..., min_length=1, max_length=MAX_TEXT_LENGTH)
92
  speaker_wav: str = Field(..., description="HTTPS URL or base64-encoded audio")
93
+ language: Optional[str] = Field(DEFAULT_LANGUAGE, description="ISO code, default en")
94
 
95
 
96
  def _require_api_key(x_api_key: Optional[str]):
97
+ """Validate API key if configured."""
98
  if not SPACE_API_KEY:
99
  return
100
  if x_api_key != SPACE_API_KEY:
 
102
 
103
 
104
  def _write_temp_audio_from_url(url: HttpUrl) -> str:
105
+ """Download audio from URL to temporary file."""
106
  response = requests.get(url, stream=True, timeout=30)
107
  if response.status_code >= 400:
108
+ raise HTTPException(
109
+ status_code=400,
110
+ detail=f"Could not fetch speaker audio: {response.status_code}"
111
+ )
112
+
113
  suffix = Path(url.path).suffix or ".wav"
114
  with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
115
  for chunk in response.iter_content(chunk_size=8192):
116
  if chunk:
117
  tmp.write(chunk)
118
+ return tmp.name
119
 
120
 
121
  def _write_temp_audio_from_base64(payload: str) -> str:
122
+ """Decode base64 audio to temporary file."""
123
  try:
124
  raw = base64.b64decode(payload)
125
  except Exception as exc:
126
+ raise HTTPException(
127
+ status_code=400,
128
+ detail="Invalid base64 speaker_wav"
129
+ ) from exc
130
+
131
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
132
  tmp.write(raw)
133
+ return tmp.name
134
 
135
 
136
  def _temp_speaker_file(speaker_wav: str) -> str:
137
+ """Handle speaker audio input from URL or base64."""
138
  if speaker_wav.startswith("http://") or speaker_wav.startswith("https://"):
139
+ return _write_temp_audio_from_url(HttpUrl(speaker_wav))
140
  return _write_temp_audio_from_base64(speaker_wav)
141
 
142
 
143
+ def _preprocess_audio_wav(
144
+ path: str,
145
+ target_sr: int = 24000,
146
+ target_peak: float = 0.98
147
+ ) -> str:
148
+ """
149
+ Light preprocessing to stabilize embeddings and output quality:
150
+ - convert to mono
151
+ - resample to target_sr
152
+ - peak-normalize to target_peak (avoid clipping)
153
+ """
154
  wav, sr = torchaudio.load(path)
155
+
156
+ # Convert to mono
157
  if wav.shape[0] > 1:
158
  wav = wav.mean(dim=0, keepdim=True)
159
+
160
+ # Resample if needed
161
  if sr != target_sr:
162
  resampler = Resample(orig_freq=sr, new_freq=target_sr)
163
  wav = resampler(wav)
164
  sr = target_sr
165
+
166
+ # Peak normalize
167
  peak = wav.abs().max().item() if wav.numel() else 0.0
168
  if peak > 0:
169
+ scale = min(target_peak / peak, 1.0)
170
+ wav = wav * scale
171
+
172
+ # Overwrite input file to avoid extra temp files
173
  torchaudio.save(path, wav, sr, bits_per_sample=16)
174
  return path
175
 
176
 
177
  def _set_job(job_id: str, **kwargs):
178
+ """Thread-safe job update."""
179
  with JOB_LOCK:
180
  JOBS[job_id] = {**JOBS.get(job_id, {}), **kwargs}
181
 
182
 
183
  def _get_job(job_id: str) -> Optional[Dict[str, str]]:
184
+ """Thread-safe job retrieval."""
185
  with JOB_LOCK:
186
  data = JOBS.get(job_id)
187
+ return dict(data) if data else None
188
 
189
 
190
  def _pop_job(job_id: str) -> Optional[Dict[str, str]]:
191
+ """Thread-safe job removal."""
192
  with JOB_LOCK:
193
  return JOBS.pop(job_id, None)
194
 
195
 
196
+ def _cleanup_files(*files: str):
197
+ """Background task to clean up temporary files after response is sent."""
198
+ for file_path in files:
199
+ if file_path and Path(file_path).exists():
200
+ try:
201
+ Path(file_path).unlink(missing_ok=True)
202
+ except Exception:
203
+ pass # Ignore cleanup errors
204
 
205
 
206
  def _run_generate_job(job_id: str, payload: Dict[str, str]):
207
+ """Background job for TTS generation."""
208
  speaker_file = None
209
  output_file = None
210
  _set_job(job_id, status="processing")
211
+
212
  try:
213
  speaker_file = _temp_speaker_file(payload["speaker_wav"])
214
+ speaker_file = _preprocess_audio_wav(speaker_file)
215
+
216
+ output_file = os.path.join(
217
+ tempfile.gettempdir(),
218
+ f"indextts2-{uuid.uuid4()}.wav"
219
+ )
220
+
221
  tts_model.infer(
222
+ spk_audio_prompt=speaker_file,
223
  text=payload["text"],
 
224
  output_path=output_file,
225
+ use_random=False,
226
+ verbose=False,
227
  )
228
+
229
+ output_file = _preprocess_audio_wav(output_file)
230
+
231
  if not Path(output_file).exists():
232
+ raise RuntimeError(
233
+ f"TTS generation failed: output file was not created at {output_file}"
234
+ )
235
+
236
+ _cleanup_files(speaker_file)
237
  _set_job(job_id, status="completed", output_file=output_file)
 
238
  except Exception as exc:
239
+ _cleanup_files(speaker_file, output_file)
240
  _set_job(job_id, status="error", error=str(exc))
241
 
242
 
243
+ @app.post("/health")
244
+ def health(x_api_key: Optional[str] = Header(default=None)):
245
+ """Health check endpoint."""
246
+ _require_api_key(x_api_key)
247
+ return {"status": "ok", "model": "indextts2", "device": DEVICE}
248
+
249
 
250
  @app.post("/generate")
251
  def generate(
 
253
  background_tasks: BackgroundTasks = BackgroundTasks(),
254
  x_api_key: Optional[str] = Header(default=None),
255
  ):
256
+ """
257
+ Generate speech from text using voice cloning.
258
+ Returns job information for async processing.
259
+ """
260
  _require_api_key(x_api_key)
261
+
262
  job_id = str(uuid.uuid4())
263
  _set_job(job_id, status="queued")
264
+
265
+ # Offload the long-running synthesis so the HTTP request stays fast (<100s)
266
  background_tasks.add_task(_run_generate_job, job_id, payload.dict())
267
+
268
  return JSONResponse(
269
  status_code=202,
270
  content={
 
277
 
278
 
279
  @app.get("/status/{job_id}")
280
+ def job_status(job_id: str, x_api_key: Optional[str] = Header(default=None)):
281
+ """Check the status of a generation job."""
282
  _require_api_key(x_api_key)
283
+
284
  job = _get_job(job_id)
285
  if not job:
286
  raise HTTPException(status_code=404, detail="Job not found")
287
+
288
+ payload: Dict[str, str] = {
289
+ "job_id": job_id,
290
+ "status": job.get("status", "unknown")
291
+ }
292
+
293
  if "error" in job:
294
+ payload["error"] = job["error"]
295
+
296
+ return payload
297
 
298
 
299
  @app.get("/result/{job_id}")
300
+ def job_result(
301
+ job_id: str,
302
+ background_tasks: BackgroundTasks = BackgroundTasks(),
303
+ x_api_key: Optional[str] = Header(default=None),
304
+ ):
305
+ """Retrieve the result of a completed generation job."""
306
  _require_api_key(x_api_key)
307
+
308
  job = _get_job(job_id)
309
  if not job:
310
  raise HTTPException(status_code=404, detail="Job not found")
311
+
312
+ status = job.get("status")
313
+ if status != "completed":
314
+ raise HTTPException(
315
+ status_code=409,
316
+ detail=f"Job not ready (status={status})"
317
+ )
318
+
319
  output_file = job.get("output_file")
320
  if not output_file or not Path(output_file).exists():
321
  _pop_job(job_id)
322
+ raise HTTPException(status_code=410, detail="Result expired or missing")
323
+
324
+ # Remove job from memory and cleanup output after sending
 
325
  _pop_job(job_id)
326
+ background_tasks.add_task(_cleanup_files, output_file)
327
+
328
+ return FileResponse(
329
+ output_file,
330
+ media_type="audio/wav",
331
+ filename="output.wav"
332
+ )
333
+
334
+
335
+ @app.get("/")
336
+ def root():
337
+ """API root with available endpoints."""
338
+ return {
339
+ "name": "indextts2-api",
340
+ "endpoints": [
341
+ "/health",
342
+ "/generate",
343
+ "/status/{job_id}",
344
+ "/result/{job_id}"
345
+ ],
346
+ }