Spaces:
Sleeping
Sleeping
Fix syntax error in TTS stage and complete pipeline
Browse files
app.py
CHANGED
|
@@ -307,7 +307,8 @@ class TransformerTTS(nn.Module):
|
|
| 307 |
|
| 308 |
@torch.no_grad()
|
| 309 |
def inference(self, text, max_length=800, gate_threshold=1e-5, with_tqdm=True):
|
| 310 |
-
|
|
|
|
| 311 |
text_lengths = torch.tensor(text.shape[1]).unsqueeze(0).to(DEVICE)
|
| 312 |
N = 1
|
| 313 |
SOS = torch.zeros((N, 1, hp.mel_freq), device=DEVICE)
|
|
@@ -379,18 +380,7 @@ try:
|
|
| 379 |
TTS_MODEL.load_state_dict(state["state_dict"])
|
| 380 |
else:
|
| 381 |
TTS_MODEL.load_state_dict(state) # Assume the whole file is the state_dict
|
| 382 |
-
|
| 383 |
TTS_MODEL.eval()
|
| 384 |
-
|
| 385 |
-
# Set all submodules to eval mode and disable gradients permanently
|
| 386 |
-
for module in TTS_MODEL.modules():
|
| 387 |
-
if hasattr(module, 'training'):
|
| 388 |
-
module.train(False)
|
| 389 |
-
|
| 390 |
-
# Disable gradients for all parameters permanently
|
| 391 |
-
for param in TTS_MODEL.parameters():
|
| 392 |
-
param.requires_grad = False
|
| 393 |
-
|
| 394 |
|
| 395 |
# Try torch.compile for additional speedup (PyTorch 2.0+)
|
| 396 |
try:
|
|
|
|
| 307 |
|
| 308 |
@torch.no_grad()
|
| 309 |
def inference(self, text, max_length=800, gate_threshold=1e-5, with_tqdm=True):
|
| 310 |
+
self.eval()
|
| 311 |
+
self.train(False)
|
| 312 |
text_lengths = torch.tensor(text.shape[1]).unsqueeze(0).to(DEVICE)
|
| 313 |
N = 1
|
| 314 |
SOS = torch.zeros((N, 1, hp.mel_freq), device=DEVICE)
|
|
|
|
| 380 |
TTS_MODEL.load_state_dict(state["state_dict"])
|
| 381 |
else:
|
| 382 |
TTS_MODEL.load_state_dict(state) # Assume the whole file is the state_dict
|
|
|
|
| 383 |
TTS_MODEL.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 384 |
|
| 385 |
# Try torch.compile for additional speedup (PyTorch 2.0+)
|
| 386 |
try:
|