Spaces:
Sleeping
Sleeping
Fix syntax error in TTS stage and complete pipeline
Browse files
app.py
CHANGED
|
@@ -380,8 +380,18 @@ try:
|
|
| 380 |
TTS_MODEL.load_state_dict(state["state_dict"])
|
| 381 |
else:
|
| 382 |
TTS_MODEL.load_state_dict(state) # Assume the whole file is the state_dict
|
|
|
|
| 383 |
TTS_MODEL.eval()
|
| 384 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 385 |
# Try torch.compile for additional speedup (PyTorch 2.0+)
|
| 386 |
try:
|
| 387 |
TTS_MODEL = torch.compile(TTS_MODEL, mode="reduce-overhead")
|
|
@@ -389,10 +399,7 @@ try:
|
|
| 389 |
except Exception as compile_error:
|
| 390 |
print(f"Torch compile not available: {compile_error}, using standard model.")
|
| 391 |
|
| 392 |
-
print("TTS model loaded successfully.")
|
| 393 |
-
except Exception as e:
|
| 394 |
-
print(f"Error loading TTS model: {e}")
|
| 395 |
-
TTS_MODEL = None
|
| 396 |
|
| 397 |
# Load STT (Whisper) Model from Hub
|
| 398 |
try:
|
|
|
|
| 380 |
TTS_MODEL.load_state_dict(state["state_dict"])
|
| 381 |
else:
|
| 382 |
TTS_MODEL.load_state_dict(state) # Assume the whole file is the state_dict
|
| 383 |
+
|
| 384 |
TTS_MODEL.eval()
|
| 385 |
|
| 386 |
+
# Set all submodules to eval mode and disable gradients permanently
|
| 387 |
+
for module in TTS_MODEL.modules():
|
| 388 |
+
if hasattr(module, 'training'):
|
| 389 |
+
module.train(False)
|
| 390 |
+
|
| 391 |
+
# Disable gradients for all parameters permanently
|
| 392 |
+
for param in TTS_MODEL.parameters():
|
| 393 |
+
param.requires_grad = False
|
| 394 |
+
|
| 395 |
# Try torch.compile for additional speedup (PyTorch 2.0+)
|
| 396 |
try:
|
| 397 |
TTS_MODEL = torch.compile(TTS_MODEL, mode="reduce-overhead")
|
|
|
|
| 399 |
except Exception as compile_error:
|
| 400 |
print(f"Torch compile not available: {compile_error}, using standard model.")
|
| 401 |
|
| 402 |
+
print("TTS model loaded successfully (optimized for inference).")
|
|
|
|
|
|
|
|
|
|
| 403 |
|
| 404 |
# Load STT (Whisper) Model from Hub
|
| 405 |
try:
|