ataberkkilavuzcu committed on
Commit
edcc9a5
·
verified ·
1 Parent(s): b71bca4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -134
app.py CHANGED
@@ -8,6 +8,8 @@ from typing import Dict, Optional
8
 
9
  import requests
10
  import torch
 
 
11
  from fastapi import BackgroundTasks, Body, FastAPI, Header, HTTPException
12
  from fastapi.responses import FileResponse, JSONResponse
13
  from pydantic import BaseModel, Field, HttpUrl
@@ -21,7 +23,8 @@ HF_TOKEN = (
21
  )
22
 
23
  # Model configuration
24
- MODEL_DIR = os.getenv("MODEL_DIR", "./checkpoints")
 
25
  MAX_TEXT_LENGTH = 1000
26
  DEFAULT_LANGUAGE = "en"
27
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
@@ -34,73 +37,60 @@ JOB_LOCK = Lock()
34
  if HF_TOKEN:
35
  os.environ["HUGGING_FACE_HUB_TOKEN"] = HF_TOKEN
36
  os.environ["HF_TOKEN"] = HF_TOKEN
 
 
 
 
 
37
 
38
- # Download and initialize OpenVoice model
39
  os.makedirs(MODEL_DIR, exist_ok=True)
40
 
41
- print(f"Initializing OpenVoice on {DEVICE}...")
42
-
43
  try:
44
- # Download checkpoints if needed
45
- if not Path(MODEL_DIR, "checkpoints_v2").exists():
46
- print("Downloading OpenVoice V2 checkpoints...")
47
- from huggingface_hub import snapshot_download
48
-
49
  snapshot_download(
50
- repo_id="myshell-ai/OpenVoice",
51
  local_dir=MODEL_DIR,
52
  token=HF_TOKEN,
53
  )
54
  print("Model download complete.")
 
 
 
 
 
 
 
55
 
56
- # Import OpenVoice modules
57
- from melo.api import TTS
58
- from openvoice import se_extractor
59
- from openvoice.api import ToneColorConverter
60
-
61
- # Initialize base TTS (MeloTTS)
62
- ckpt_converter = f'{MODEL_DIR}/checkpoints_v2/converter'
63
 
64
- # Initialize tone color converter
65
- tone_color_converter = ToneColorConverter(
66
- f'{ckpt_converter}/config.json',
67
- device=DEVICE
 
 
68
  )
69
- tone_color_converter.load_ckpt(f'{ckpt_converter}/checkpoint.pth')
70
-
71
- # Initialize base TTS for English
72
- base_speaker_tts = TTS(language='EN', device=DEVICE)
73
- base_speaker = base_speaker_tts.hps.data.spk2id['EN-US']
74
-
75
- print("OpenVoice V2 loaded successfully!")
76
-
77
  except Exception as exc:
78
- print(f"Error loading OpenVoice: {exc}")
79
- print("Trying alternative initialization...")
80
-
81
- try:
82
- # Fallback: Use simpler initialization
83
- from melo.api import TTS
84
-
85
- base_speaker_tts = TTS(language='EN', device=DEVICE)
86
- base_speaker = base_speaker_tts.hps.data.spk2id['EN-US']
87
-
88
- # Mock converter for basic functionality
89
- tone_color_converter = None
90
- print("Loaded base TTS only (voice cloning disabled)")
91
-
92
- except Exception as exc2:
93
- raise RuntimeError(f"Failed to load OpenVoice: {exc2}") from exc2
94
 
95
  # Initialize FastAPI app
96
- app = FastAPI(title="openvoice-api", version="2.0.0")
97
 
98
 
99
  class GenerateRequest(BaseModel):
100
  text: str = Field(..., min_length=1, max_length=MAX_TEXT_LENGTH)
101
  speaker_wav: str = Field(..., description="HTTPS URL or base64-encoded audio")
102
- language: Optional[str] = Field(DEFAULT_LANGUAGE, description="ISO code: en, es, fr, zh, ja, ko")
103
- speed: Optional[float] = Field(1.0, ge=0.5, le=2.0, description="Speech speed (0.5-2.0)")
104
 
105
 
106
  def _require_api_key(x_api_key: Optional[str]):
@@ -150,6 +140,40 @@ def _temp_speaker_file(speaker_wav: str) -> str:
150
  return _write_temp_audio_from_base64(speaker_wav)
151
 
152
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  def _set_job(job_id: str, **kwargs):
154
  """Thread-safe job update."""
155
  with JOB_LOCK:
@@ -180,90 +204,39 @@ def _cleanup_files(*files: str):
180
 
181
 
182
  def _run_generate_job(job_id: str, payload: Dict[str, str]):
183
- """
184
- Background job for TTS generation using OpenVoice V2.
185
- Two-step process:
186
- 1. Generate base speech with MeloTTS
187
- 2. Apply target voice characteristics with ToneColorConverter
188
- """
189
  speaker_file = None
190
- temp_audio = None
191
  output_file = None
192
  _set_job(job_id, status="processing")
193
 
194
  try:
195
- # Step 1: Generate base speech
196
- temp_audio = os.path.join(
 
 
197
  tempfile.gettempdir(),
198
- f"openvoice-temp-{uuid.uuid4()}.wav"
199
  )
200
 
201
- speed = float(payload.get("speed", 1.0))
202
-
203
- base_speaker_tts.tts_to_file(
204
- payload["text"],
205
- base_speaker,
206
- temp_audio,
207
- speed=speed
208
  )
209
 
210
- # Step 2: Apply voice cloning if converter is available
211
- if tone_color_converter is not None:
212
- try:
213
- # Prepare reference audio
214
- speaker_file = _temp_speaker_file(payload["speaker_wav"])
215
-
216
- # Extract target speaker embedding
217
- target_se, _ = se_extractor.get_se(
218
- speaker_file,
219
- tone_color_converter,
220
- vad=True
221
- )
222
-
223
- # Get source speaker embedding
224
- source_se = torch.load(
225
- f'{MODEL_DIR}/checkpoints_v2/base_speakers/ses/en-us.pth',
226
- map_location=DEVICE
227
- )
228
-
229
- # Apply voice conversion
230
- output_file = os.path.join(
231
- tempfile.gettempdir(),
232
- f"openvoice-{uuid.uuid4()}.wav"
233
- )
234
-
235
- tone_color_converter.convert(
236
- audio_src_path=temp_audio,
237
- src_se=source_se,
238
- tgt_se=target_se,
239
- output_path=output_file,
240
- message="@MyShell"
241
- )
242
-
243
- # Cleanup temp audio
244
- _cleanup_files(speaker_file, temp_audio)
245
-
246
- except Exception as convert_error:
247
- print(f"Voice conversion failed: {convert_error}")
248
- # Fall back to base audio without voice cloning
249
- output_file = temp_audio
250
- temp_audio = None
251
- _cleanup_files(speaker_file)
252
- else:
253
- # No converter available, use base audio
254
- output_file = temp_audio
255
- temp_audio = None
256
 
257
- # Verify output exists
258
  if not Path(output_file).exists():
259
  raise RuntimeError(
260
- f"TTS generation failed: output file was not created"
261
  )
262
 
 
263
  _set_job(job_id, status="completed", output_file=output_file)
264
-
265
  except Exception as exc:
266
- _cleanup_files(speaker_file, temp_audio, output_file)
267
  _set_job(job_id, status="error", error=str(exc))
268
 
269
 
@@ -271,13 +244,7 @@ def _run_generate_job(job_id: str, payload: Dict[str, str]):
271
  def health(x_api_key: Optional[str] = Header(default=None)):
272
  """Health check endpoint."""
273
  _require_api_key(x_api_key)
274
- return {
275
- "status": "ok",
276
- "model": "openvoice-v2",
277
- "device": DEVICE,
278
- "voice_cloning": tone_color_converter is not None,
279
- "supported_languages": ["en", "es", "fr", "zh", "ja", "ko"]
280
- }
281
 
282
 
283
  @app.post("/generate")
@@ -287,7 +254,7 @@ def generate(
287
  x_api_key: Optional[str] = Header(default=None),
288
  ):
289
  """
290
- Generate speech from text using voice cloning with OpenVoice.
291
  Returns job information for async processing.
292
  """
293
  _require_api_key(x_api_key)
@@ -295,7 +262,7 @@ def generate(
295
  job_id = str(uuid.uuid4())
296
  _set_job(job_id, status="queued")
297
 
298
- # Offload the synthesis to background task
299
  background_tasks.add_task(_run_generate_job, job_id, payload.dict())
300
 
301
  return JSONResponse(
@@ -369,20 +336,11 @@ def job_result(
369
  def root():
370
  """API root with available endpoints."""
371
  return {
372
- "name": "openvoice-api",
373
- "version": "2.0.0",
374
- "model": "OpenVoice V2",
375
- "voice_cloning": tone_color_converter is not None,
376
  "endpoints": [
377
  "/health",
378
  "/generate",
379
  "/status/{job_id}",
380
  "/result/{job_id}"
381
  ],
382
- "features": [
383
- "Voice cloning with 3-10s reference audio" if tone_color_converter else "Base TTS only",
384
- "Multi-language support (EN, ES, FR, ZH, JA, KO)",
385
- "Adjustable speech speed (0.5-2.0x)",
386
- "Fast CPU performance"
387
- ]
388
  }
 
8
 
9
  import requests
10
  import torch
11
+ import torchaudio
12
+ from torchaudio.transforms import Resample
13
  from fastapi import BackgroundTasks, Body, FastAPI, Header, HTTPException
14
  from fastapi.responses import FileResponse, JSONResponse
15
  from pydantic import BaseModel, Field, HttpUrl
 
23
  )
24
 
25
  # Model configuration
26
+ MODEL_REPO = "IndexTeam/IndexTTS-2"
27
+ MODEL_DIR = os.getenv("MODEL_DIR", "/data/indextts2")
28
  MAX_TEXT_LENGTH = 1000
29
  DEFAULT_LANGUAGE = "en"
30
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
37
  if HF_TOKEN:
38
  os.environ["HUGGING_FACE_HUB_TOKEN"] = HF_TOKEN
39
  os.environ["HF_TOKEN"] = HF_TOKEN
40
    # Authenticate the hub client as well (some private repos require an
    # explicit login, not just the env vars set above). Optional dependency:
    # silently skip if huggingface_hub is absent.
    try:
        from huggingface_hub import login
        login(token=HF_TOKEN, add_to_git_credential=False)
    except ImportError:
        pass

# Download model checkpoints from Hugging Face
os.makedirs(MODEL_DIR, exist_ok=True)

try:
    from huggingface_hub import snapshot_download

    # Download model if not already present — config.yaml is used as the
    # sentinel for a completed download.
    if not Path(MODEL_DIR, "config.yaml").exists():
        print(f"Downloading IndexTTS2 model from {MODEL_REPO}...")
        snapshot_download(
            repo_id=MODEL_REPO,
            local_dir=MODEL_DIR,
            token=HF_TOKEN,
        )
        print("Model download complete.")
except Exception as exc:
    # Deliberately non-fatal: the init step below re-checks for the config
    # and raises there if the model is truly missing.
    print(f"Warning: Could not download model: {exc}")
    # Continue anyway - model might already be present
64
+
65
# Initialize IndexTTS2 at import time; a failure here is fatal because every
# /generate job depends on the global `tts_model`.
try:
    from indextts.infer_v2 import IndexTTS2

    cfg_path = os.path.join(MODEL_DIR, "config.yaml")
    if not Path(cfg_path).exists():
        raise FileNotFoundError(
            f"Config file not found at {cfg_path}. Model may not be downloaded."
        )

    # GPU-specific features are disabled; presumably this Space runs on CPU
    # (DEVICE may still be "cuda" if available — verify deployment target).
    tts_model = IndexTTS2(
        cfg_path=cfg_path,
        model_dir=MODEL_DIR,
        use_fp16=False,  # CPU doesn't support FP16
        use_cuda_kernel=False,  # CPU mode
        use_deepspeed=False,  # CPU mode
    )
    print("IndexTTS2 model loaded successfully.")
except Exception as exc:
    # Re-raise with context so the container crash log points at model load.
    raise RuntimeError(f"Failed to load IndexTTS2 model: {exc}") from exc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
  # Initialize FastAPI app
87
+ app = FastAPI(title="indextts2-api", version="1.0.0")
88
 
89
 
90
class GenerateRequest(BaseModel):
    """Request body for POST /generate."""

    # Text to synthesize; capped at MAX_TEXT_LENGTH (1000) characters.
    text: str = Field(..., min_length=1, max_length=MAX_TEXT_LENGTH)
    # Reference voice sample: an HTTPS URL to download or base64-encoded audio.
    speaker_wav: str = Field(..., description="HTTPS URL or base64-encoded audio")
    # Optional language hint; defaults to DEFAULT_LANGUAGE ("en").
    language: Optional[str] = Field(DEFAULT_LANGUAGE, description="ISO code, default en")
 
94
 
95
 
96
  def _require_api_key(x_api_key: Optional[str]):
 
140
  return _write_temp_audio_from_base64(speaker_wav)
141
 
142
 
143
def _preprocess_audio_wav(
    path: str,
    target_sr: int = 24000,
    target_peak: float = 0.98
) -> str:
    """
    Normalize a WAV file in place to stabilize embeddings and output quality.

    Performs three light steps: downmix to mono, resample to ``target_sr``,
    and attenuate so the absolute peak does not exceed ``target_peak``
    (quiet audio is never amplified). Rewrites the file as 16-bit PCM and
    returns the same ``path``.
    """
    wav, sr = torchaudio.load(path)

    # Downmix: average every channel into a single mono track.
    if wav.shape[0] > 1:
        wav = wav.mean(dim=0, keepdim=True)

    # Resample only when the source rate differs from the target.
    if sr != target_sr:
        wav = Resample(orig_freq=sr, new_freq=target_sr)(wav)
        sr = target_sr

    # Peak-normalize downward only (scale capped at 1.0) to avoid clipping.
    peak = wav.abs().max().item() if wav.numel() else 0.0
    if peak > 0:
        wav = wav * min(target_peak / peak, 1.0)

    # Overwrite the input file so no extra temp files need cleanup.
    torchaudio.save(path, wav, sr, bits_per_sample=16)
    return path
175
+
176
+
177
  def _set_job(job_id: str, **kwargs):
178
  """Thread-safe job update."""
179
  with JOB_LOCK:
 
204
 
205
 
206
def _run_generate_job(job_id: str, payload: Dict[str, str]):
    """Background job for TTS generation.

    Scheduled via FastAPI BackgroundTasks: marks the job "processing",
    synthesizes speech cloned from the reference audio, and records either
    "completed" (with the output path) or "error" in the job store.
    """
    speaker_file = None  # temp copy of the reference audio
    output_file = None   # final synthesized WAV, served later by /result
    _set_job(job_id, status="processing")

    try:
        # Materialize the reference voice (URL download or base64 decode),
        # then normalize it in place (mono / resample / peak-limit).
        # Kept as two assignments so a failure inside preprocessing still
        # leaves `speaker_file` set for the cleanup in `except`.
        speaker_file = _temp_speaker_file(payload["speaker_wav"])
        speaker_file = _preprocess_audio_wav(speaker_file)

        output_file = os.path.join(
            tempfile.gettempdir(),
            f"indextts2-{uuid.uuid4()}.wav"
        )

        # Voice-cloned synthesis; use_random=False presumably makes output
        # deterministic for a given (text, reference) pair — TODO confirm
        # against the IndexTTS2 inference API.
        tts_model.infer(
            spk_audio_prompt=speaker_file,
            text=payload["text"],
            output_path=output_file,
            use_random=False,
            verbose=False,
        )

        # Apply the same normalization to the generated audio.
        output_file = _preprocess_audio_wav(output_file)

        # Guard against the model silently producing nothing.
        if not Path(output_file).exists():
            raise RuntimeError(
                f"TTS generation failed: output file was not created at {output_file}"
            )

        # Success: drop the reference temp file; the output file is kept
        # until the client fetches it.
        _cleanup_files(speaker_file)
        _set_job(job_id, status="completed", output_file=output_file)
    except Exception as exc:
        # Best-effort removal of any temp files, then surface the failure
        # through the job store for /status polling.
        _cleanup_files(speaker_file, output_file)
        _set_job(job_id, status="error", error=str(exc))
241
 
242
 
 
244
def health(x_api_key: Optional[str] = Header(default=None)):
    """Health probe: validates the API key, then reports model and device."""
    _require_api_key(x_api_key)
    info = {"status": "ok", "model": "indextts2", "device": DEVICE}
    return info
 
 
 
 
 
 
248
 
249
 
250
  @app.post("/generate")
 
254
  x_api_key: Optional[str] = Header(default=None),
255
  ):
256
  """
257
+ Generate speech from text using voice cloning.
258
  Returns job information for async processing.
259
  """
260
  _require_api_key(x_api_key)
 
262
  job_id = str(uuid.uuid4())
263
  _set_job(job_id, status="queued")
264
 
265
+ # Offload the long-running synthesis so the HTTP request stays fast (<100s)
266
  background_tasks.add_task(_run_generate_job, job_id, payload.dict())
267
 
268
  return JSONResponse(
 
336
def root():
    """API root with available endpoints."""
    available = [
        "/health",
        "/generate",
        "/status/{job_id}",
        "/result/{job_id}",
    ]
    return {
        "name": "indextts2-api",
        "endpoints": available,
    }