yukee1992 committed on
Commit
1214c7f
·
verified ·
1 Parent(s): 4624555

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -17
app.py CHANGED
@@ -305,6 +305,7 @@ def upload_to_oci(file_path: str, filename: str, project_id: str, file_type="voi
305
  except Exception as e:
306
  return None, f"Upload error: {str(e)}"
307
 
 
308
  def load_tts_model(model_type="tacotron2-ddc"):
309
  """Load TTS model with storage optimization"""
310
  global tts, model_loaded, current_model, model_loading
@@ -317,6 +318,11 @@ def load_tts_model(model_type="tacotron2-ddc"):
317
  print(f"❌ Model type '{model_type}' not found.")
318
  return False
319
 
 
 
 
 
 
320
  model_loading = True
321
 
322
  try:
@@ -337,12 +343,25 @@ def load_tts_model(model_type="tacotron2-ddc"):
337
  print(f"πŸš€ Loading {model_config['name']}...")
338
  print(f" Languages: {', '.join(model_config['languages'])}")
339
 
 
 
 
 
 
 
 
 
 
340
  # Load the selected model
341
  tts = TTS(model_config["model_name"]).to(DEVICE)
342
 
343
  # Test the model with appropriate text
344
  test_path = "/tmp/test_output.wav"
345
- test_text = "Hello" if "en" in model_config["languages"] else "δ½ ε₯½"
 
 
 
 
346
  tts.tts_to_file(text=test_text, file_path=test_path)
347
 
348
  if os.path.exists(test_path):
@@ -361,6 +380,10 @@ def load_tts_model(model_type="tacotron2-ddc"):
361
 
362
  except Exception as e:
363
  print(f"❌ Model failed to load: {e}")
 
 
 
 
364
  return False
365
 
366
  finally:
@@ -372,7 +395,7 @@ def load_tts_model(model_type="tacotron2-ddc"):
372
  finally:
373
  model_loading = False
374
 
375
- # ENHANCED: Model switching logic
376
  def ensure_correct_model(voice_style: str, text: str, language: str = "auto"):
377
  """Ensure the correct model is loaded for the requested voice style and language"""
378
  global tts, model_loaded, current_model
@@ -380,14 +403,20 @@ def ensure_correct_model(voice_style: str, text: str, language: str = "auto"):
380
  # Determine target model
381
  target_model = get_model_for_voice_style(voice_style, language)
382
 
 
 
383
  # If no model loaded or wrong model loaded, load the correct one
384
  if not model_loaded or current_model != target_model:
385
- print(f"πŸ”„ Switching to model: {target_model} for voice style: {voice_style}")
386
- return load_tts_model(target_model)
 
 
 
 
387
 
388
  return True
389
 
390
- # API endpoints
391
  @app.post("/api/tts")
392
  async def generate_tts(request: TTSRequest):
393
  """Generate TTS with multi-language support"""
@@ -415,6 +444,7 @@ async def generate_tts(request: TTSRequest):
415
  print(f" Voice Style: {request.voice_style}")
416
  print(f" Language: {detected_language}")
417
  print(f" Text length: {len(request.text)} characters")
 
418
 
419
  # Generate unique filename
420
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
@@ -430,18 +460,42 @@ async def generate_tts(request: TTSRequest):
430
 
431
  # Generate TTS
432
  try:
433
- # ENHANCED: For multilingual model, specify language
434
- if current_model == "your_tts" and detected_language in ["en", "zh"]:
435
- tts.tts_to_file(
436
- text=cleaned_text,
437
- file_path=output_path,
438
- language=detected_language
439
- )
 
 
 
 
 
 
 
 
 
440
  else:
441
- tts.tts_to_file(
442
- text=cleaned_text,
443
- file_path=output_path
444
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
445
  except Exception as tts_error:
446
  print(f"❌ TTS generation failed: {tts_error}")
447
  raise tts_error
@@ -493,6 +547,7 @@ async def generate_tts(request: TTSRequest):
493
  "message": f"TTS generation failed: {str(e)}"
494
  }
495
 
 
496
  @app.post("/api/batch-tts")
497
  async def batch_generate_tts(request: BatchTTSRequest):
498
  """Batch TTS with multi-language support"""
@@ -500,6 +555,9 @@ async def batch_generate_tts(request: BatchTTSRequest):
500
  cleanup_old_files()
501
 
502
  print(f"πŸ“₯ Batch TTS request for {len(request.texts)} texts")
 
 
 
503
 
504
  results = []
505
  for i, text in enumerate(request.texts):
@@ -510,6 +568,8 @@ async def batch_generate_tts(request: BatchTTSRequest):
510
  else:
511
  text_language = request.language
512
 
 
 
513
  single_request = TTSRequest(
514
  text=text,
515
  project_id=request.project_id,
@@ -521,23 +581,37 @@ async def batch_generate_tts(request: BatchTTSRequest):
521
  result = await generate_tts(single_request)
522
  results.append({
523
  "text_index": i,
 
524
  "status": result.get("status", "error"),
525
  "message": result.get("message", ""),
526
  "filename": result.get("filename", ""),
527
  "oci_path": result.get("oci_path", ""),
528
- "language": result.get("language", "unknown") # ENHANCED: Include language
529
  })
530
 
531
  except Exception as e:
 
532
  results.append({
533
  "text_index": i,
 
534
  "status": "error",
535
  "message": f"Failed to generate TTS: {str(e)}"
536
  })
537
 
 
 
 
 
 
 
538
  return {
539
  "status": "completed",
540
  "project_id": request.project_id,
 
 
 
 
 
541
  "results": results,
542
  "model_used": current_model
543
  }
 
305
  except Exception as e:
306
  return None, f"Upload error: {str(e)}"
307
 
308
+ # FIXED: Improved model loading with better error handling and memory management
309
  def load_tts_model(model_type="tacotron2-ddc"):
310
  """Load TTS model with storage optimization"""
311
  global tts, model_loaded, current_model, model_loading
 
318
  print(f"❌ Model type '{model_type}' not found.")
319
  return False
320
 
321
+ # If we're already using the correct model, no need to reload
322
+ if model_loaded and current_model == model_type:
323
+ print(f"βœ… Model {model_type} is already loaded")
324
+ return True
325
+
326
  model_loading = True
327
 
328
  try:
 
343
  print(f"πŸš€ Loading {model_config['name']}...")
344
  print(f" Languages: {', '.join(model_config['languages'])}")
345
 
346
+ # Clear current model from memory first if exists
347
+ if tts is not None:
348
+ print("🧹 Clearing previous model from memory...")
349
+ del tts
350
+ import gc
351
+ gc.collect()
352
+ if torch.cuda.is_available():
353
+ torch.cuda.empty_cache()
354
+
355
  # Load the selected model
356
  tts = TTS(model_config["model_name"]).to(DEVICE)
357
 
358
  # Test the model with appropriate text
359
  test_path = "/tmp/test_output.wav"
360
+ if "zh" in model_config["languages"]:
361
+ test_text = "δ½ ε₯½" # Chinese test
362
+ else:
363
+ test_text = "Hello" # English test
364
+
365
  tts.tts_to_file(text=test_text, file_path=test_path)
366
 
367
  if os.path.exists(test_path):
 
380
 
381
  except Exception as e:
382
  print(f"❌ Model failed to load: {e}")
383
+ # Fallback to English model if multilingual fails
384
+ if model_type == "your_tts":
385
+ print("πŸ”„ Falling back to English model...")
386
+ return load_tts_model("tacotron2-ddc")
387
  return False
388
 
389
  finally:
 
395
  finally:
396
  model_loading = False
397
 
398
+ # FIXED: Improved model switching logic with better detection
399
  def ensure_correct_model(voice_style: str, text: str, language: str = "auto"):
400
  """Ensure the correct model is loaded for the requested voice style and language"""
401
  global tts, model_loaded, current_model
 
403
  # Determine target model
404
  target_model = get_model_for_voice_style(voice_style, language)
405
 
406
+ print(f"πŸ” Model selection: voice_style={voice_style}, language={language}, target_model={target_model}")
407
+
408
  # If no model loaded or wrong model loaded, load the correct one
409
  if not model_loaded or current_model != target_model:
410
+ print(f"πŸ”„ Switching to model: {target_model} for voice style: {voice_style}, language: {language}")
411
+ success = load_tts_model(target_model)
412
+ if not success and target_model == "your_tts":
413
+ print("⚠️ Multilingual model failed, falling back to English model")
414
+ return load_tts_model("tacotron2-ddc")
415
+ return success
416
 
417
  return True
418
 
419
+ # FIXED: Enhanced TTS generation with proper language handling
420
  @app.post("/api/tts")
421
  async def generate_tts(request: TTSRequest):
422
  """Generate TTS with multi-language support"""
 
444
  print(f" Voice Style: {request.voice_style}")
445
  print(f" Language: {detected_language}")
446
  print(f" Text length: {len(request.text)} characters")
447
+ print(f" Current Model: {current_model}")
448
 
449
  # Generate unique filename
450
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
 
460
 
461
  # Generate TTS
462
  try:
463
+ # FIXED: Proper language handling for multilingual model
464
+ if current_model == "your_tts":
465
+ if detected_language == "zh":
466
+ print("🎯 Using YourTTS for Chinese text with zh-cn language code")
467
+ tts.tts_to_file(
468
+ text=cleaned_text,
469
+ file_path=output_path,
470
+ language="zh-cn" # Use zh-cn for Chinese
471
+ )
472
+ else:
473
+ print("🎯 Using YourTTS for English text")
474
+ tts.tts_to_file(
475
+ text=cleaned_text,
476
+ file_path=output_path,
477
+ language="en"
478
+ )
479
  else:
480
+ # Tacotron2-DDC for English only
481
+ if detected_language == "zh":
482
+ # If Chinese text but English model, try to switch to multilingual
483
+ print("πŸ”„ Chinese text detected with English model, attempting to switch to multilingual...")
484
+ if load_tts_model("your_tts"):
485
+ # Retry with multilingual model
486
+ tts.tts_to_file(
487
+ text=cleaned_text,
488
+ file_path=output_path,
489
+ language="zh-cn"
490
+ )
491
+ else:
492
+ raise Exception("Chinese text cannot be processed. Multilingual model failed to load.")
493
+ else:
494
+ print("🎯 Using Tacotron2-DDC for English text")
495
+ tts.tts_to_file(
496
+ text=cleaned_text,
497
+ file_path=output_path
498
+ )
499
  except Exception as tts_error:
500
  print(f"❌ TTS generation failed: {tts_error}")
501
  raise tts_error
 
547
  "message": f"TTS generation failed: {str(e)}"
548
  }
549
 
550
+ # FIXED: Enhanced batch processing with better logging and error handling
551
  @app.post("/api/batch-tts")
552
  async def batch_generate_tts(request: BatchTTSRequest):
553
  """Batch TTS with multi-language support"""
 
555
  cleanup_old_files()
556
 
557
  print(f"πŸ“₯ Batch TTS request for {len(request.texts)} texts")
558
+ print(f" Project: {request.project_id}")
559
+ print(f" Voice Style: {request.voice_style}")
560
+ print(f" Language: {request.language}")
561
 
562
  results = []
563
  for i, text in enumerate(request.texts):
 
568
  else:
569
  text_language = request.language
570
 
571
+ print(f" Processing text {i+1}/{len(request.texts)}: {text_language} - {text[:50]}...")
572
+
573
  single_request = TTSRequest(
574
  text=text,
575
  project_id=request.project_id,
 
581
  result = await generate_tts(single_request)
582
  results.append({
583
  "text_index": i,
584
+ "text_preview": text[:30] + "..." if len(text) > 30 else text,
585
  "status": result.get("status", "error"),
586
  "message": result.get("message", ""),
587
  "filename": result.get("filename", ""),
588
  "oci_path": result.get("oci_path", ""),
589
+ "language": result.get("language", "unknown")
590
  })
591
 
592
  except Exception as e:
593
+ print(f"❌ Failed to process text {i}: {str(e)}")
594
  results.append({
595
  "text_index": i,
596
+ "text_preview": text[:30] + "..." if len(text) > 30 else text,
597
  "status": "error",
598
  "message": f"Failed to generate TTS: {str(e)}"
599
  })
600
 
601
+ # Summary
602
+ success_count = sum(1 for r in results if r.get("status") == "success")
603
+ error_count = sum(1 for r in results if r.get("status") == "error")
604
+
605
+ print(f"πŸ“Š Batch completed: {success_count} successful, {error_count} failed")
606
+
607
  return {
608
  "status": "completed",
609
  "project_id": request.project_id,
610
+ "summary": {
611
+ "total": len(results),
612
+ "successful": success_count,
613
+ "failed": error_count
614
+ },
615
  "results": results,
616
  "model_used": current_model
617
  }