MoHamdyy commited on
Commit
1dab1d3
·
1 Parent(s): 7318551

Fix syntax error in TTS stage and complete pipeline

Browse files
Files changed (1) hide show
  1. app.py +2 -12
app.py CHANGED
@@ -307,7 +307,8 @@ class TransformerTTS(nn.Module):
307
 
308
  @torch.no_grad()
309
  def inference(self, text, max_length=800, gate_threshold=1e-5, with_tqdm=True):
310
-
 
311
  text_lengths = torch.tensor(text.shape[1]).unsqueeze(0).to(DEVICE)
312
  N = 1
313
  SOS = torch.zeros((N, 1, hp.mel_freq), device=DEVICE)
@@ -379,18 +380,7 @@ try:
379
  TTS_MODEL.load_state_dict(state["state_dict"])
380
  else:
381
  TTS_MODEL.load_state_dict(state) # Assume the whole file is the state_dict
382
-
383
  TTS_MODEL.eval()
384
-
385
- # Set all submodules to eval mode and disable gradients permanently
386
- for module in TTS_MODEL.modules():
387
- if hasattr(module, 'training'):
388
- module.train(False)
389
-
390
- # Disable gradients for all parameters permanently
391
- for param in TTS_MODEL.parameters():
392
- param.requires_grad = False
393
-
394
 
395
  # Try torch.compile for additional speedup (PyTorch 2.0+)
396
  try:
 
307
 
308
  @torch.no_grad()
309
  def inference(self, text, max_length=800, gate_threshold=1e-5, with_tqdm=True):
310
+ self.eval()
311
+ self.train(False)
312
  text_lengths = torch.tensor(text.shape[1]).unsqueeze(0).to(DEVICE)
313
  N = 1
314
  SOS = torch.zeros((N, 1, hp.mel_freq), device=DEVICE)
 
380
  TTS_MODEL.load_state_dict(state["state_dict"])
381
  else:
382
  TTS_MODEL.load_state_dict(state) # Assume the whole file is the state_dict
 
383
  TTS_MODEL.eval()
 
 
 
 
 
 
 
 
 
 
384
 
385
  # Try torch.compile for additional speedup (PyTorch 2.0+)
386
  try: