Update app.py
Browse files
app.py
CHANGED
|
@@ -305,6 +305,7 @@ def upload_to_oci(file_path: str, filename: str, project_id: str, file_type="voi
|
|
| 305 |
except Exception as e:
|
| 306 |
return None, f"Upload error: {str(e)}"
|
| 307 |
|
|
|
|
| 308 |
def load_tts_model(model_type="tacotron2-ddc"):
|
| 309 |
"""Load TTS model with storage optimization"""
|
| 310 |
global tts, model_loaded, current_model, model_loading
|
|
@@ -317,6 +318,11 @@ def load_tts_model(model_type="tacotron2-ddc"):
|
|
| 317 |
print(f"β Model type '{model_type}' not found.")
|
| 318 |
return False
|
| 319 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 320 |
model_loading = True
|
| 321 |
|
| 322 |
try:
|
|
@@ -337,12 +343,25 @@ def load_tts_model(model_type="tacotron2-ddc"):
|
|
| 337 |
print(f"π Loading {model_config['name']}...")
|
| 338 |
print(f" Languages: {', '.join(model_config['languages'])}")
|
| 339 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 340 |
# Load the selected model
|
| 341 |
tts = TTS(model_config["model_name"]).to(DEVICE)
|
| 342 |
|
| 343 |
# Test the model with appropriate text
|
| 344 |
test_path = "/tmp/test_output.wav"
|
| 345 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 346 |
tts.tts_to_file(text=test_text, file_path=test_path)
|
| 347 |
|
| 348 |
if os.path.exists(test_path):
|
|
@@ -361,6 +380,10 @@ def load_tts_model(model_type="tacotron2-ddc"):
|
|
| 361 |
|
| 362 |
except Exception as e:
|
| 363 |
print(f"β Model failed to load: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 364 |
return False
|
| 365 |
|
| 366 |
finally:
|
|
@@ -372,7 +395,7 @@ def load_tts_model(model_type="tacotron2-ddc"):
|
|
| 372 |
finally:
|
| 373 |
model_loading = False
|
| 374 |
|
| 375 |
-
#
|
| 376 |
def ensure_correct_model(voice_style: str, text: str, language: str = "auto"):
|
| 377 |
"""Ensure the correct model is loaded for the requested voice style and language"""
|
| 378 |
global tts, model_loaded, current_model
|
|
@@ -380,14 +403,20 @@ def ensure_correct_model(voice_style: str, text: str, language: str = "auto"):
|
|
| 380 |
# Determine target model
|
| 381 |
target_model = get_model_for_voice_style(voice_style, language)
|
| 382 |
|
|
|
|
|
|
|
| 383 |
# If no model loaded or wrong model loaded, load the correct one
|
| 384 |
if not model_loaded or current_model != target_model:
|
| 385 |
-
print(f"π Switching to model: {target_model} for voice style: {voice_style}")
|
| 386 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 387 |
|
| 388 |
return True
|
| 389 |
|
| 390 |
-
#
|
| 391 |
@app.post("/api/tts")
|
| 392 |
async def generate_tts(request: TTSRequest):
|
| 393 |
"""Generate TTS with multi-language support"""
|
|
@@ -415,6 +444,7 @@ async def generate_tts(request: TTSRequest):
|
|
| 415 |
print(f" Voice Style: {request.voice_style}")
|
| 416 |
print(f" Language: {detected_language}")
|
| 417 |
print(f" Text length: {len(request.text)} characters")
|
|
|
|
| 418 |
|
| 419 |
# Generate unique filename
|
| 420 |
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
@@ -430,18 +460,42 @@ async def generate_tts(request: TTSRequest):
|
|
| 430 |
|
| 431 |
# Generate TTS
|
| 432 |
try:
|
| 433 |
-
#
|
| 434 |
-
if current_model == "your_tts"
|
| 435 |
-
|
| 436 |
-
text
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 440 |
else:
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 445 |
except Exception as tts_error:
|
| 446 |
print(f"β TTS generation failed: {tts_error}")
|
| 447 |
raise tts_error
|
|
@@ -493,6 +547,7 @@ async def generate_tts(request: TTSRequest):
|
|
| 493 |
"message": f"TTS generation failed: {str(e)}"
|
| 494 |
}
|
| 495 |
|
|
|
|
| 496 |
@app.post("/api/batch-tts")
|
| 497 |
async def batch_generate_tts(request: BatchTTSRequest):
|
| 498 |
"""Batch TTS with multi-language support"""
|
|
@@ -500,6 +555,9 @@ async def batch_generate_tts(request: BatchTTSRequest):
|
|
| 500 |
cleanup_old_files()
|
| 501 |
|
| 502 |
print(f"π₯ Batch TTS request for {len(request.texts)} texts")
|
|
|
|
|
|
|
|
|
|
| 503 |
|
| 504 |
results = []
|
| 505 |
for i, text in enumerate(request.texts):
|
|
@@ -510,6 +568,8 @@ async def batch_generate_tts(request: BatchTTSRequest):
|
|
| 510 |
else:
|
| 511 |
text_language = request.language
|
| 512 |
|
|
|
|
|
|
|
| 513 |
single_request = TTSRequest(
|
| 514 |
text=text,
|
| 515 |
project_id=request.project_id,
|
|
@@ -521,23 +581,37 @@ async def batch_generate_tts(request: BatchTTSRequest):
|
|
| 521 |
result = await generate_tts(single_request)
|
| 522 |
results.append({
|
| 523 |
"text_index": i,
|
|
|
|
| 524 |
"status": result.get("status", "error"),
|
| 525 |
"message": result.get("message", ""),
|
| 526 |
"filename": result.get("filename", ""),
|
| 527 |
"oci_path": result.get("oci_path", ""),
|
| 528 |
-
"language": result.get("language", "unknown")
|
| 529 |
})
|
| 530 |
|
| 531 |
except Exception as e:
|
|
|
|
| 532 |
results.append({
|
| 533 |
"text_index": i,
|
|
|
|
| 534 |
"status": "error",
|
| 535 |
"message": f"Failed to generate TTS: {str(e)}"
|
| 536 |
})
|
| 537 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 538 |
return {
|
| 539 |
"status": "completed",
|
| 540 |
"project_id": request.project_id,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 541 |
"results": results,
|
| 542 |
"model_used": current_model
|
| 543 |
}
|
|
|
|
| 305 |
except Exception as e:
|
| 306 |
return None, f"Upload error: {str(e)}"
|
| 307 |
|
| 308 |
+
# FIXED: Improved model loading with better error handling and memory management
|
| 309 |
def load_tts_model(model_type="tacotron2-ddc"):
|
| 310 |
"""Load TTS model with storage optimization"""
|
| 311 |
global tts, model_loaded, current_model, model_loading
|
|
|
|
| 318 |
print(f"β Model type '{model_type}' not found.")
|
| 319 |
return False
|
| 320 |
|
| 321 |
+
# If we're already using the correct model, no need to reload
|
| 322 |
+
if model_loaded and current_model == model_type:
|
| 323 |
+
print(f"✅ Model {model_type} is already loaded")
|
| 324 |
+
return True
|
| 325 |
+
|
| 326 |
model_loading = True
|
| 327 |
|
| 328 |
try:
|
|
|
|
| 343 |
print(f"π Loading {model_config['name']}...")
|
| 344 |
print(f" Languages: {', '.join(model_config['languages'])}")
|
| 345 |
|
| 346 |
+
# Clear current model from memory first if exists
|
| 347 |
+
if tts is not None:
|
| 348 |
+
print("🧹 Clearing previous model from memory...")
|
| 349 |
+
del tts
|
| 350 |
+
import gc
|
| 351 |
+
gc.collect()
|
| 352 |
+
if torch.cuda.is_available():
|
| 353 |
+
torch.cuda.empty_cache()
|
| 354 |
+
|
| 355 |
# Load the selected model
|
| 356 |
tts = TTS(model_config["model_name"]).to(DEVICE)
|
| 357 |
|
| 358 |
# Test the model with appropriate text
|
| 359 |
test_path = "/tmp/test_output.wav"
|
| 360 |
+
if "zh" in model_config["languages"]:
|
| 361 |
+
test_text = "你好" # Chinese test
|
| 362 |
+
else:
|
| 363 |
+
test_text = "Hello" # English test
|
| 364 |
+
|
| 365 |
tts.tts_to_file(text=test_text, file_path=test_path)
|
| 366 |
|
| 367 |
if os.path.exists(test_path):
|
|
|
|
| 380 |
|
| 381 |
except Exception as e:
|
| 382 |
print(f"β Model failed to load: {e}")
|
| 383 |
+
# Fallback to English model if multilingual fails
|
| 384 |
+
if model_type == "your_tts":
|
| 385 |
+
print("π Falling back to English model...")
|
| 386 |
+
return load_tts_model("tacotron2-ddc")
|
| 387 |
return False
|
| 388 |
|
| 389 |
finally:
|
|
|
|
| 395 |
finally:
|
| 396 |
model_loading = False
|
| 397 |
|
| 398 |
+
# FIXED: Improved model switching logic with better detection
|
| 399 |
def ensure_correct_model(voice_style: str, text: str, language: str = "auto"):
|
| 400 |
"""Ensure the correct model is loaded for the requested voice style and language"""
|
| 401 |
global tts, model_loaded, current_model
|
|
|
|
| 403 |
# Determine target model
|
| 404 |
target_model = get_model_for_voice_style(voice_style, language)
|
| 405 |
|
| 406 |
+
print(f"π Model selection: voice_style={voice_style}, language={language}, target_model={target_model}")
|
| 407 |
+
|
| 408 |
# If no model loaded or wrong model loaded, load the correct one
|
| 409 |
if not model_loaded or current_model != target_model:
|
| 410 |
+
print(f"π Switching to model: {target_model} for voice style: {voice_style}, language: {language}")
|
| 411 |
+
success = load_tts_model(target_model)
|
| 412 |
+
if not success and target_model == "your_tts":
|
| 413 |
+
print("⚠️ Multilingual model failed, falling back to English model")
|
| 414 |
+
return load_tts_model("tacotron2-ddc")
|
| 415 |
+
return success
|
| 416 |
|
| 417 |
return True
|
| 418 |
|
| 419 |
+
# FIXED: Enhanced TTS generation with proper language handling
|
| 420 |
@app.post("/api/tts")
|
| 421 |
async def generate_tts(request: TTSRequest):
|
| 422 |
"""Generate TTS with multi-language support"""
|
|
|
|
| 444 |
print(f" Voice Style: {request.voice_style}")
|
| 445 |
print(f" Language: {detected_language}")
|
| 446 |
print(f" Text length: {len(request.text)} characters")
|
| 447 |
+
print(f" Current Model: {current_model}")
|
| 448 |
|
| 449 |
# Generate unique filename
|
| 450 |
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
|
|
| 460 |
|
| 461 |
# Generate TTS
|
| 462 |
try:
|
| 463 |
+
# FIXED: Proper language handling for multilingual model
|
| 464 |
+
if current_model == "your_tts":
|
| 465 |
+
if detected_language == "zh":
|
| 466 |
+
print("🎯 Using YourTTS for Chinese text with zh-cn language code")
|
| 467 |
+
tts.tts_to_file(
|
| 468 |
+
text=cleaned_text,
|
| 469 |
+
file_path=output_path,
|
| 470 |
+
language="zh-cn" # Use zh-cn for Chinese
|
| 471 |
+
)
|
| 472 |
+
else:
|
| 473 |
+
print("🎯 Using YourTTS for English text")
|
| 474 |
+
tts.tts_to_file(
|
| 475 |
+
text=cleaned_text,
|
| 476 |
+
file_path=output_path,
|
| 477 |
+
language="en"
|
| 478 |
+
)
|
| 479 |
else:
|
| 480 |
+
# Tacotron2-DDC for English only
|
| 481 |
+
if detected_language == "zh":
|
| 482 |
+
# If Chinese text but English model, try to switch to multilingual
|
| 483 |
+
print("π Chinese text detected with English model, attempting to switch to multilingual...")
|
| 484 |
+
if load_tts_model("your_tts"):
|
| 485 |
+
# Retry with multilingual model
|
| 486 |
+
tts.tts_to_file(
|
| 487 |
+
text=cleaned_text,
|
| 488 |
+
file_path=output_path,
|
| 489 |
+
language="zh-cn"
|
| 490 |
+
)
|
| 491 |
+
else:
|
| 492 |
+
raise Exception("Chinese text cannot be processed. Multilingual model failed to load.")
|
| 493 |
+
else:
|
| 494 |
+
print("🎯 Using Tacotron2-DDC for English text")
|
| 495 |
+
tts.tts_to_file(
|
| 496 |
+
text=cleaned_text,
|
| 497 |
+
file_path=output_path
|
| 498 |
+
)
|
| 499 |
except Exception as tts_error:
|
| 500 |
print(f"β TTS generation failed: {tts_error}")
|
| 501 |
raise tts_error
|
|
|
|
| 547 |
"message": f"TTS generation failed: {str(e)}"
|
| 548 |
}
|
| 549 |
|
| 550 |
+
# FIXED: Enhanced batch processing with better logging and error handling
|
| 551 |
@app.post("/api/batch-tts")
|
| 552 |
async def batch_generate_tts(request: BatchTTSRequest):
|
| 553 |
"""Batch TTS with multi-language support"""
|
|
|
|
| 555 |
cleanup_old_files()
|
| 556 |
|
| 557 |
print(f"π₯ Batch TTS request for {len(request.texts)} texts")
|
| 558 |
+
print(f" Project: {request.project_id}")
|
| 559 |
+
print(f" Voice Style: {request.voice_style}")
|
| 560 |
+
print(f" Language: {request.language}")
|
| 561 |
|
| 562 |
results = []
|
| 563 |
for i, text in enumerate(request.texts):
|
|
|
|
| 568 |
else:
|
| 569 |
text_language = request.language
|
| 570 |
|
| 571 |
+
print(f" Processing text {i+1}/{len(request.texts)}: {text_language} - {text[:50]}...")
|
| 572 |
+
|
| 573 |
single_request = TTSRequest(
|
| 574 |
text=text,
|
| 575 |
project_id=request.project_id,
|
|
|
|
| 581 |
result = await generate_tts(single_request)
|
| 582 |
results.append({
|
| 583 |
"text_index": i,
|
| 584 |
+
"text_preview": text[:30] + "..." if len(text) > 30 else text,
|
| 585 |
"status": result.get("status", "error"),
|
| 586 |
"message": result.get("message", ""),
|
| 587 |
"filename": result.get("filename", ""),
|
| 588 |
"oci_path": result.get("oci_path", ""),
|
| 589 |
+
"language": result.get("language", "unknown")
|
| 590 |
})
|
| 591 |
|
| 592 |
except Exception as e:
|
| 593 |
+
print(f"β Failed to process text {i}: {str(e)}")
|
| 594 |
results.append({
|
| 595 |
"text_index": i,
|
| 596 |
+
"text_preview": text[:30] + "..." if len(text) > 30 else text,
|
| 597 |
"status": "error",
|
| 598 |
"message": f"Failed to generate TTS: {str(e)}"
|
| 599 |
})
|
| 600 |
|
| 601 |
+
# Summary
|
| 602 |
+
success_count = sum(1 for r in results if r.get("status") == "success")
|
| 603 |
+
error_count = sum(1 for r in results if r.get("status") == "error")
|
| 604 |
+
|
| 605 |
+
print(f"π Batch completed: {success_count} successful, {error_count} failed")
|
| 606 |
+
|
| 607 |
return {
|
| 608 |
"status": "completed",
|
| 609 |
"project_id": request.project_id,
|
| 610 |
+
"summary": {
|
| 611 |
+
"total": len(results),
|
| 612 |
+
"successful": success_count,
|
| 613 |
+
"failed": error_count
|
| 614 |
+
},
|
| 615 |
"results": results,
|
| 616 |
"model_used": current_model
|
| 617 |
}
|