yukee1992 commited on
Commit
2eaadb8
Β·
verified Β·
1 Parent(s): 697cc6f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -44
app.py CHANGED
@@ -2,6 +2,7 @@ import os
2
  import tempfile
3
  import uuid
4
  import time
 
5
  from datetime import datetime
6
  from typing import List, Optional
7
  from pathlib import Path
@@ -36,9 +37,10 @@ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
36
 
37
  print(f"βœ… Using device: {DEVICE}")
38
 
39
- # Initialize TTS model with automatic TOS acceptance
40
  tts = None
41
  model_loaded = False
 
42
 
43
  try:
44
  # Set environment variable to automatically accept terms
@@ -56,18 +58,37 @@ try:
56
  sys.stdin = StringIO('y\n')
57
 
58
  try:
59
- print("πŸš€ Loading TTS model...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  tts = TTS(DEFAULT_MODEL).to(DEVICE)
61
  model_loaded = True
62
- print("βœ… TTS model loaded successfully")
 
 
63
  except Exception as e:
64
- print(f"❌ Primary model failed: {e}")
65
  # Try fallback model
66
  try:
67
  print("πŸ”„ Trying fallback model...")
68
  tts = TTS("tts_models/en/ljspeech/tacotron2-DDC").to(DEVICE)
69
  model_loaded = True
70
- print("βœ… Fallback TTS model loaded successfully")
 
71
  except Exception as fallback_error:
72
  print(f"❌ Fallback model also failed: {fallback_error}")
73
  tts = None
@@ -182,11 +203,9 @@ def clone_voice(voice_name: str, audio_files: List[str], description: str = ""):
182
  # Copy audio files to voice directory
183
  for i, audio_file in enumerate(audio_files):
184
  dest_path = f"{voice_dir}/sample_{i+1}.wav"
185
- # For now, just create a placeholder since we can't copy files in this context
186
- # In a real implementation, you'd copy the files here
187
- print(f" Would copy sample {i+1} to: {dest_path}")
188
 
189
- # For XTTS model, we can use the samples directly
190
  print(f"βœ… Voice cloning setup completed for {voice_name}")
191
 
192
  return True, f"Voice {voice_name} is ready for use"
@@ -194,6 +213,10 @@ def clone_voice(voice_name: str, audio_files: List[str], description: str = ""):
194
  except Exception as e:
195
  return False, f"Voice cloning failed: {str(e)}"
196
 
 
 
 
 
197
  # API endpoints
198
  @app.post("/api/tts")
199
  async def generate_tts(request: TTSRequest):
@@ -212,6 +235,14 @@ async def generate_tts(request: TTSRequest):
212
  print(f" Voice: {request.voice_name}")
213
  print(f" Language: {request.language}")
214
 
 
 
 
 
 
 
 
 
215
  # Generate unique filename with sequential naming
216
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
217
  filename = f"voiceover_{timestamp}.wav"
@@ -227,13 +258,21 @@ async def generate_tts(request: TTSRequest):
227
  "message": f"Voice '{request.voice_name}' not found"
228
  }
229
 
230
- # Generate TTS
231
- tts.tts_to_file(
232
- text=request.text,
233
- speaker_wav=speaker_wav,
234
- language=request.language,
235
- file_path=output_path
236
- )
 
 
 
 
 
 
 
 
237
 
238
  print(f"βœ… TTS generated: {output_path}")
239
 
@@ -264,7 +303,9 @@ async def generate_tts(request: TTSRequest):
264
  "status": "success",
265
  "message": "TTS generated and uploaded successfully",
266
  "filename": filename,
267
- "oci_path": upload_result.get("path", f"{request.project_id}/voiceover/{filename}")
 
 
268
  }
269
 
270
  except Exception as e:
@@ -283,6 +324,13 @@ async def batch_generate_tts(request: BatchTTSRequest):
283
  print(f" Voice: {request.voice_name}")
284
  print(f" Language: {request.language}")
285
 
 
 
 
 
 
 
 
286
  # Get voice path if custom voice is requested
287
  speaker_wav = None
288
  if request.voice_name != "default":
@@ -299,13 +347,21 @@ async def batch_generate_tts(request: BatchTTSRequest):
299
  filename = f"voiceover_{i+1:02d}.wav"
300
  output_path = f"/tmp/output/{filename}"
301
 
302
- # Generate TTS
303
- tts.tts_to_file(
304
- text=text,
305
- speaker_wav=speaker_wav,
306
- language=request.language,
307
- file_path=output_path
308
- )
 
 
 
 
 
 
 
 
309
 
310
  # Upload to OCI
311
  upload_result, error = upload_to_oci_with_retry(
@@ -340,7 +396,9 @@ async def batch_generate_tts(request: BatchTTSRequest):
340
  return {
341
  "status": "completed",
342
  "project_id": request.project_id,
343
- "results": results
 
 
344
  }
345
 
346
  except Exception as e:
@@ -357,6 +415,13 @@ async def upload_voice_sample(
357
  try:
358
  print(f"πŸ“₯ Voice upload request: {voice_name} for project {project_id}")
359
 
 
 
 
 
 
 
 
360
  # Validate file type
361
  if not file.filename.lower().endswith(('.wav', '.mp3', '.ogg', '.flac')):
362
  raise HTTPException(status_code=400, detail="Only audio files are allowed")
@@ -391,6 +456,13 @@ async def api_clone_voice(
391
  try:
392
  print(f"πŸ“₯ Voice cloning request: {voice_name} for project {project_id}")
393
 
 
 
 
 
 
 
 
394
  # Save uploaded files temporarily
395
  temp_files = []
396
  for i, file in enumerate(files):
@@ -455,7 +527,8 @@ async def list_voices():
455
 
456
  return {
457
  "status": "success",
458
- "voices": voices
 
459
  }
460
 
461
  except Exception as e:
@@ -468,6 +541,8 @@ async def health_check():
468
  return {
469
  "status": "healthy",
470
  "tts_loaded": tts is not None,
 
 
471
  "device": DEVICE,
472
  "timestamp": datetime.now().isoformat()
473
  }
@@ -477,21 +552,4 @@ async def root():
477
  """Root endpoint with API information"""
478
  return {
479
  "message": "TTS API with Coqui TTS and Voice Cloning",
480
- "endpoints": {
481
- "POST /api/tts": "Generate TTS for a single text",
482
- "POST /api/batch-tts": "Generate TTS for multiple texts",
483
- "POST /api/upload-voice": "Upload a voice sample for cloning",
484
- "POST /api/clone-voice": "Clone a voice from multiple samples",
485
- "GET /api/voices": "List available voices",
486
- "GET /api/health": "Health check"
487
- },
488
- "model_loaded": tts is not None,
489
- "model_name": DEFAULT_MODEL if tts else "None"
490
- }
491
-
492
- if __name__ == "__main__":
493
- import uvicorn
494
- print("πŸš€ Starting TTS API with Coqui TTS and Voice Cloning...")
495
- print("πŸ“Š API endpoints available at: http://localhost:7860/")
496
- print("πŸ“š Documentation available at: http://localhost:7860/docs")
497
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
2
  import tempfile
3
  import uuid
4
  import time
5
+ import shutil
6
  from datetime import datetime
7
  from typing import List, Optional
8
  from pathlib import Path
 
37
 
38
  print(f"βœ… Using device: {DEVICE}")
39
 
40
+ # Initialize TTS model with automatic TOS acceptance and safe globals
41
  tts = None
42
  model_loaded = False
43
+ current_model = ""
44
 
45
  try:
46
  # Set environment variable to automatically accept terms
 
58
  sys.stdin = StringIO('y\n')
59
 
60
  try:
61
+ print("πŸš€ Loading XTTS model with safe globals...")
62
+
63
+ # Add safe globals for PyTorch 2.6 compatibility
64
+ try:
65
+ import torch.serialization
66
+ # Import the required classes for safe globals
67
+ from TTS.tts.configs.xtts_config import XttsConfig
68
+ from TTS.tts.models.xtts import Xtts
69
+ from TTS.utils.manage import ModelManager
70
+
71
+ # Add the required classes to safe globals
72
+ torch.serialization.add_safe_globals([XttsConfig, Xtts, ModelManager])
73
+ print("βœ… Added safe globals for XTTS model")
74
+ except Exception as safe_globals_error:
75
+ print(f"⚠️ Could not add safe globals: {safe_globals_error}")
76
+
77
+ # Load the XTTS model
78
  tts = TTS(DEFAULT_MODEL).to(DEVICE)
79
  model_loaded = True
80
+ current_model = DEFAULT_MODEL
81
+ print("βœ… XTTS model loaded successfully with voice cloning support")
82
+
83
  except Exception as e:
84
+ print(f"❌ XTTS model failed: {e}")
85
  # Try fallback model
86
  try:
87
  print("πŸ”„ Trying fallback model...")
88
  tts = TTS("tts_models/en/ljspeech/tacotron2-DDC").to(DEVICE)
89
  model_loaded = True
90
+ current_model = "tts_models/en/ljspeech/tacotron2-DDC"
91
+ print("βœ… Fallback TTS model loaded successfully (English only, no voice cloning)")
92
  except Exception as fallback_error:
93
  print(f"❌ Fallback model also failed: {fallback_error}")
94
  tts = None
 
203
  # Copy audio files to voice directory
204
  for i, audio_file in enumerate(audio_files):
205
  dest_path = f"{voice_dir}/sample_{i+1}.wav"
206
+ shutil.copy2(audio_file, dest_path)
207
+ print(f" Copied sample {i+1} to: {dest_path}")
 
208
 
 
209
  print(f"βœ… Voice cloning setup completed for {voice_name}")
210
 
211
  return True, f"Voice {voice_name} is ready for use"
 
213
  except Exception as e:
214
  return False, f"Voice cloning failed: {str(e)}"
215
 
216
+ def supports_voice_cloning():
217
+ """Check if the current model supports voice cloning"""
218
+ return "xtts" in current_model.lower()
219
+
220
  # API endpoints
221
  @app.post("/api/tts")
222
  async def generate_tts(request: TTSRequest):
 
235
  print(f" Voice: {request.voice_name}")
236
  print(f" Language: {request.language}")
237
 
238
+ # Check if voice cloning is requested but not supported
239
+ if request.voice_name != "default" and not supports_voice_cloning():
240
+ return {
241
+ "status": "error",
242
+ "message": "Voice cloning is not supported with the current model. Only the default voice is available.",
243
+ "model": current_model
244
+ }
245
+
246
  # Generate unique filename with sequential naming
247
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
248
  filename = f"voiceover_{timestamp}.wav"
 
258
  "message": f"Voice '{request.voice_name}' not found"
259
  }
260
 
261
+ # Generate TTS based on model capabilities
262
+ if supports_voice_cloning():
263
+ # XTTS model with voice cloning support
264
+ tts.tts_to_file(
265
+ text=request.text,
266
+ speaker_wav=speaker_wav,
267
+ language=request.language,
268
+ file_path=output_path
269
+ )
270
+ else:
271
+ # Fallback model (Tacotron2)
272
+ tts.tts_to_file(
273
+ text=request.text,
274
+ file_path=output_path
275
+ )
276
 
277
  print(f"βœ… TTS generated: {output_path}")
278
 
 
303
  "status": "success",
304
  "message": "TTS generated and uploaded successfully",
305
  "filename": filename,
306
+ "oci_path": upload_result.get("path", f"{request.project_id}/voiceover/{filename}"),
307
+ "model_used": current_model,
308
+ "voice_cloning": supports_voice_cloning() and request.voice_name != "default"
309
  }
310
 
311
  except Exception as e:
 
324
  print(f" Voice: {request.voice_name}")
325
  print(f" Language: {request.language}")
326
 
327
+ # Check if voice cloning is requested but not supported
328
+ if request.voice_name != "default" and not supports_voice_cloning():
329
+ raise HTTPException(
330
+ status_code=400,
331
+ detail="Voice cloning is not supported with the current model. Only the default voice is available."
332
+ )
333
+
334
  # Get voice path if custom voice is requested
335
  speaker_wav = None
336
  if request.voice_name != "default":
 
347
  filename = f"voiceover_{i+1:02d}.wav"
348
  output_path = f"/tmp/output/{filename}"
349
 
350
+ # Generate TTS based on model capabilities
351
+ if supports_voice_cloning():
352
+ # XTTS model with voice cloning support
353
+ tts.tts_to_file(
354
+ text=text,
355
+ speaker_wav=speaker_wav,
356
+ language=request.language,
357
+ file_path=output_path
358
+ )
359
+ else:
360
+ # Fallback model (Tacotron2)
361
+ tts.tts_to_file(
362
+ text=text,
363
+ file_path=output_path
364
+ )
365
 
366
  # Upload to OCI
367
  upload_result, error = upload_to_oci_with_retry(
 
396
  return {
397
  "status": "completed",
398
  "project_id": request.project_id,
399
+ "results": results,
400
+ "model_used": current_model,
401
+ "voice_cloning": supports_voice_cloning() and request.voice_name != "default"
402
  }
403
 
404
  except Exception as e:
 
415
  try:
416
  print(f"πŸ“₯ Voice upload request: {voice_name} for project {project_id}")
417
 
418
+ # Check if voice cloning is supported
419
+ if not supports_voice_cloning():
420
+ raise HTTPException(
421
+ status_code=400,
422
+ detail="Voice cloning is not supported with the current model. Please use the XTTS model for voice cloning."
423
+ )
424
+
425
  # Validate file type
426
  if not file.filename.lower().endswith(('.wav', '.mp3', '.ogg', '.flac')):
427
  raise HTTPException(status_code=400, detail="Only audio files are allowed")
 
456
  try:
457
  print(f"πŸ“₯ Voice cloning request: {voice_name} for project {project_id}")
458
 
459
+ # Check if voice cloning is supported
460
+ if not supports_voice_cloning():
461
+ raise HTTPException(
462
+ status_code=400,
463
+ detail="Voice cloning is not supported with the current model. Please use the XTTS model for voice cloning."
464
+ )
465
+
466
  # Save uploaded files temporarily
467
  temp_files = []
468
  for i, file in enumerate(files):
 
527
 
528
  return {
529
  "status": "success",
530
+ "voices": voices,
531
+ "voice_cloning_supported": supports_voice_cloning()
532
  }
533
 
534
  except Exception as e:
 
541
  return {
542
  "status": "healthy",
543
  "tts_loaded": tts is not None,
544
+ "model": current_model,
545
+ "voice_cloning_supported": supports_voice_cloning(),
546
  "device": DEVICE,
547
  "timestamp": datetime.now().isoformat()
548
  }
 
552
  """Root endpoint with API information"""
553
  return {
554
  "message": "TTS API with Coqui TTS and Voice Cloning",
555
+ "en